diff --git a/dnn/atlas-stub/include/acl/acl.h b/dnn/atlas-stub/include/acl/acl.h old mode 100755 new mode 100644 diff --git a/dnn/atlas-stub/include/acl/acl_base.h b/dnn/atlas-stub/include/acl/acl_base.h old mode 100755 new mode 100644 diff --git a/dnn/atlas-stub/include/acl/acl_mdl.h b/dnn/atlas-stub/include/acl/acl_mdl.h old mode 100755 new mode 100644 diff --git a/dnn/atlas-stub/include/acl/acl_op.h b/dnn/atlas-stub/include/acl/acl_op.h old mode 100755 new mode 100644 diff --git a/dnn/atlas-stub/include/acl/acl_rt.h b/dnn/atlas-stub/include/acl/acl_rt.h old mode 100755 new mode 100644 diff --git a/dnn/atlas-stub/include/acl/ops/acl_cblas.h b/dnn/atlas-stub/include/acl/ops/acl_cblas.h old mode 100755 new mode 100644 diff --git a/dnn/atlas-stub/include/acl/ops/acl_dvpp.h b/dnn/atlas-stub/include/acl/ops/acl_dvpp.h old mode 100755 new mode 100644 diff --git a/dnn/atlas-stub/include/acl/ops/acl_fv.h b/dnn/atlas-stub/include/acl/ops/acl_fv.h old mode 100755 new mode 100644 diff --git a/dnn/include/hip_header.h b/dnn/include/hip_header.h index d662f578..8c310608 100644 --- a/dnn/include/hip_header.h +++ b/dnn/include/hip_header.h @@ -23,9 +23,9 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Wsign-compare" -#include -#include #include +#include +#include #pragma GCC diagnostic pop #if !defined(__HIP_PLATFORM_HCC__) diff --git a/dnn/include/megcore.h b/dnn/include/megcore.h index 3af2da53..d8ecf5b9 100644 --- a/dnn/include/megcore.h +++ b/dnn/include/megcore.h @@ -11,10 +11,10 @@ #pragma once -#include "megdnn/thin/function.h" -#include "megcore_cdefs.h" #include #include +#include "megcore_cdefs.h" +#include "megdnn/thin/function.h" #include "megdnn/internal/visibility_prologue.h" @@ -26,36 +26,35 @@ namespace megcore { * the caller thread immediately. */ class CPUDispatcher { - public: - using Task = megdnn::thin_function; - using MultiThreadingTask = megdnn::thin_function; - virtual ~CPUDispatcher() noexcept; - - /*! - * \brief dispatch a task on the computing thread - * \param task the task that would be moved away - */ - virtual void dispatch(Task&& task) = 0; - - /*! - * \brief dispatch a multithreading task on the computing thread - * \param task the task would be moved away - * \param parallelism the parallelism of the task. - */ - virtual void dispatch(MultiThreadingTask&& task, - size_t parallelism) = 0; - - /*! - * \brief synchronize the calling thread with the computing thread - */ - virtual void sync() = 0; - - /*! - * \brief the computing thread number. - */ - virtual size_t nr_threads() = 0; +public: + using Task = megdnn::thin_function; + using MultiThreadingTask = megdnn::thin_function; + virtual ~CPUDispatcher() noexcept; + + /*! + * \brief dispatch a task on the computing thread + * \param task the task that would be moved away + */ + virtual void dispatch(Task&& task) = 0; + + /*! + * \brief dispatch a multithreading task on the computing thread + * \param task the task would be moved away + * \param parallelism the parallelism of the task. + */ + virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0; + + /*! + * \brief synchronize the calling thread with the computing thread + */ + virtual void sync() = 0; + + /*! + * \brief the computing thread number. 
+ */ + virtual size_t nr_threads() = 0; }; -} // namespace megcore +} // namespace megcore using MegcoreCPUDispatcher = megcore::CPUDispatcher; @@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher; * \brief Layer 1: device handle */ struct megcoreDeviceContext; -typedef struct megcoreDeviceContext *megcoreDeviceHandle_t; +typedef struct megcoreDeviceContext* megcoreDeviceHandle_t; megcoreStatus_t megcoreCreateDeviceHandle( - megcoreDeviceHandle_t *handle, - megcorePlatform_t platform, - int deviceID = -1, + megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1, unsigned int flags = 0); -megcoreStatus_t megcoreDestroyDeviceHandle( - megcoreDeviceHandle_t handle); - -megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle, - megcorePlatform_t *platform); -megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, - int *deviceID); -megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle, - size_t *memAlignmentInBytes); +megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle); + +megcoreStatus_t megcoreGetPlatform( + megcoreDeviceHandle_t handle, megcorePlatform_t* platform); +megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID); +megcoreStatus_t megcoreGetMemAlignment( + megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes); megcoreStatus_t megcoreGetDeviceFlags( - megcoreDeviceHandle_t handle, - unsigned int *flags); + megcoreDeviceHandle_t handle, unsigned int* flags); megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); -megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle, - void **devPtr, size_t sizeInBytes); -megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, - void *devPtr); +megcoreStatus_t megcoreMalloc( + megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes); +megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr); /** * \brief Layer 2: computing handle */ struct megcoreComputingContext; -typedef struct megcoreComputingContext *megcoreComputingHandle_t; +typedef struct megcoreComputingContext* megcoreComputingHandle_t; megcoreStatus_t megcoreCreateComputingHandle( - megcoreComputingHandle_t *compHandle, - megcoreDeviceHandle_t devHandle, + megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, unsigned int flags = 0); megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( - megcoreComputingHandle_t *compHandle, - megcoreDeviceHandle_t devHandle, + megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, const std::shared_ptr& dispatcher, unsigned int flags = 0); -megcoreStatus_t megcoreDestroyComputingHandle( - megcoreComputingHandle_t handle); +megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle); megcoreStatus_t megcoreGetDeviceHandle( - megcoreComputingHandle_t compHandle, - megcoreDeviceHandle_t *devHandle); + megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle); megcoreStatus_t megcoreGetComputingFlags( - megcoreComputingHandle_t handle, - unsigned int *flags); + megcoreComputingHandle_t handle, unsigned int* flags); MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); megcoreStatus_t megcoreMemcpy( - megcoreComputingHandle_t handle, - void *dst, const void *src, size_t sizeInBytes, + megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes, megcoreMemcpyKind_t kind); megcoreStatus_t megcoreMemset( - 
megcoreComputingHandle_t handle, - void *dst, int value, size_t sizeInBytes); + megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes); megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); /** * \brief Miscellaneous */ -const char *megcoreGetErrorName(megcoreStatus_t status); +const char* megcoreGetErrorName(megcoreStatus_t status); #include "megdnn/internal/visibility_epilogue.h" diff --git a/dnn/include/megcore_atlas.h b/dnn/include/megcore_atlas.h index 308c6d74..fd47064c 100644 --- a/dnn/include/megcore_atlas.h +++ b/dnn/include/megcore_atlas.h @@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext( megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, unsigned int flags, const AtlasContext& ctx); -megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, - AtlasContext* ctx); +megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx); namespace atlas { //! convert acl error code to error string @@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream( megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, unsigned int flags, aclrtStream stream) { megcore::AtlasContext ctx{stream}; - return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle, - flags, ctx); + return megcore::createComputingHandleWithAtlasContext( + compHandle, devHandle, flags, ctx); } -inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle, - aclrtStream* stream) { +inline megcoreStatus_t megcoreGetACLStream( + megcoreComputingHandle_t handle, aclrtStream* stream) { megcore::AtlasContext ctx; auto ret = megcore::getAtlasContext(handle, &ctx); *stream = ctx.stream; diff --git a/dnn/include/megcore_cambricon.h b/dnn/include/megcore_cambricon.h index 01031fb2..7cdeb386 100644 --- a/dnn/include/megcore_cambricon.h +++ b/dnn/include/megcore_cambricon.h @@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext( megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, unsigned int flags, const CambriconContext& ctx); -megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle, - CambriconContext* ctx); +megcoreStatus_t getCambriconContext( + megcoreComputingHandle_t handle, CambriconContext* ctx); } // namespace megcore @@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue( #include "megdnn/internal/visibility_epilogue.h" // vim: syntax=cpp.doxygen - diff --git a/dnn/include/megcore_cdefs.h b/dnn/include/megcore_cdefs.h index 7506d4b2..a1c72cf9 100644 --- a/dnn/include/megcore_cdefs.h +++ b/dnn/include/megcore_cdefs.h @@ -40,7 +40,6 @@ typedef enum { megcoreErrorInternalError = 5, } megcoreStatus_t; - /** * \brief Memcpy kind */ @@ -70,6 +69,6 @@ struct AsyncErrorInfo { char msg[228]; int msg_args[4]; }; -} // namespace megcore +} // namespace megcore // vim: syntax=cpp.doxygen diff --git a/dnn/include/megcore_cuda.h b/dnn/include/megcore_cuda.h index 50f57d17..c21dea2a 100644 --- a/dnn/include/megcore_cuda.h +++ b/dnn/include/megcore_cuda.h @@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext( megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, unsigned int flags, const CudaContext& ctx); -megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, - CudaContext* ctx); +megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx); } // namespace megcore @@ -43,8 +42,8 @@ static inline megcoreStatus_t 
megcoreCreateComputingHandleWithCUDAStream( unsigned int flags, cudaStream_t stream) { megcore::CudaContext ctx; ctx.stream = stream; - return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle, - flags, ctx); + return megcore::createComputingHandleWithCUDAContext( + compHandle, devHandle, flags, ctx); } static inline megcoreStatus_t megcoreGetCUDAStream( diff --git a/dnn/include/megcore_rocm.h b/dnn/include/megcore_rocm.h index 4cb5bea5..7467020e 100644 --- a/dnn/include/megcore_rocm.h +++ b/dnn/include/megcore_rocm.h @@ -23,7 +23,9 @@ struct ROCMContext { hipStream_t stream = nullptr; static std::atomic_bool sm_miopen_algo_search; - static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); } + static inline bool enable_miopen_algo_search() { + return sm_miopen_algo_search.load(); + } static inline void enable_miopen_algo_search(bool enable_algo_search) { sm_miopen_algo_search.store(enable_algo_search); } @@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext( megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, unsigned int flags, const ROCMContext& ctx); -megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, - ROCMContext* ctx); +megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx); // Set MIOpen algo search enabled or disabled megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); @@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream( unsigned int flags, hipStream_t stream) { megcore::ROCMContext ctx; ctx.stream = stream; - return megcore::createComputingHandleWithROCMContext(compHandle, devHandle, - flags, ctx); + return megcore::createComputingHandleWithROCMContext( + compHandle, devHandle, flags, ctx); } static inline megcoreStatus_t megcoreGetROCMStream( diff --git a/dnn/include/megdnn.h b/dnn/include/megdnn.h index 2d60c166..0c5381bd 100644 --- a/dnn/include/megdnn.h +++ b/dnn/include/megdnn.h @@ -10,7 +10,7 @@ */ #pragma once -#include "megdnn/version.h" #include "megdnn/oprs.h" +#include "megdnn/version.h" // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/arch.h b/dnn/include/megdnn/arch.h index 7e2a4341..ebaa59f8 100644 --- a/dnn/include/megdnn/arch.h +++ b/dnn/include/megdnn/arch.h @@ -14,20 +14,20 @@ #include "megdnn/config/config.h" #if defined(__GNUC__) || defined(__clang__) - #if !defined (__clang__) - // gcc specific - #define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) - #if GCC_VERSION < 40800 - #error "GCC version should be at least 4.8.0." - #endif // GCC_VERSION < 40800 - #endif // !defined(__clang__) +#if !defined(__clang__) +// gcc specific +#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) +#if GCC_VERSION < 40800 +#error "GCC version should be at least 4.8.0." +#endif // GCC_VERSION < 40800 +#endif // !defined(__clang__) - #ifndef megdnn_trap - #define megdnn_trap() __builtin_trap() - #endif +#ifndef megdnn_trap +#define megdnn_trap() __builtin_trap() +#endif - #define megdnn_likely(v) __builtin_expect(bool(v), 1) - #define megdnn_unlikely(v) __builtin_expect(bool(v), 0) +#define megdnn_likely(v) __builtin_expect(bool(v), 1) +#define megdnn_unlikely(v) __builtin_expect(bool(v), 0) #if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) //! 
Thumb2 limit code length @@ -36,123 +36,122 @@ #define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) #endif - #define MEGDNN_DEPRECATED __attribute__((deprecated)) - #define MEGDNN_PACKED __attribute__((packed)) - #define MEGDNN_CONSTEXPR constexpr - #define MEGDNN_NOEXCEPT noexcept - #define MEGDNN_STATIC_ASSERT static_assert - #define MEGDNN_FINAL final - #define MEGDNN_NORETURN __attribute__((noreturn)) - #define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) - #define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) - #if defined(__clang_major__) && (__clang_major__ >= 7) - #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) - #else - #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] - #endif - #define MEGDNN_NOINLINE __attribute__((noinline)) - - #define megdnn_isatty(x) isatty(x) +#define MEGDNN_DEPRECATED __attribute__((deprecated)) +#define MEGDNN_PACKED __attribute__((packed)) +#define MEGDNN_CONSTEXPR constexpr +#define MEGDNN_NOEXCEPT noexcept +#define MEGDNN_STATIC_ASSERT static_assert +#define MEGDNN_FINAL final +#define MEGDNN_NORETURN __attribute__((noreturn)) +#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) +#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) +#if defined(__clang_major__) && (__clang_major__ >= 7) +#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) +#else +#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] +#endif +#define MEGDNN_NOINLINE __attribute__((noinline)) + +#define megdnn_isatty(x) isatty(x) #elif defined(__INTEL_COMPILER) || defined(_MSC_VER) #ifndef megdnn_trap #define megdnn_trap() __debugbreak() #endif -#define megdnn_likely(v) (bool(v)) +#define megdnn_likely(v) (bool(v)) #define megdnn_unlikely(v) (bool(v)) #define MEGDNN_DEPRECATED #define MEGDNN_PACKED -#define MEGDNN_CONSTEXPR constexpr -#define MEGDNN_NOEXCEPT noexcept +#define MEGDNN_CONSTEXPR constexpr +#define MEGDNN_NOEXCEPT noexcept #define MEGDNN_STATIC_ASSERT static_assert -#define MEGDNN_FINAL final +#define MEGDNN_FINAL final #if defined(_MSC_VER) - #define MEGDNN_NORETURN __declspec(noreturn) - #define MEGDNN_NOINLINE __declspec(noinline) +#define MEGDNN_NORETURN __declspec(noreturn) +#define MEGDNN_NOINLINE __declspec(noinline) #else - #define MEGDNN_NORETURN - #define MEGDNN_FORCE_NOINLINE -#endif // _MSC_VER +#define MEGDNN_NORETURN +#define MEGDNN_FORCE_NOINLINE +#endif // _MSC_VER #define MEGDNN_WARN_UNUSED_RESULT #define megdnn_isatty(x) _isatty(x) #else - #error "unknown compiler" -#endif // __GNUC__ +#error "unknown compiler" +#endif // __GNUC__ // __cpp_exceptions and __cpp_rtti is referred from // https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations -// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, +// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, // similar for __GXX_RTTI // _CPPUNWIND and _CPPRTTI is used by MSVC, see // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019 #ifndef MEGDNN_ENABLE_EXCEPTIONS - #if __cpp_exceptions || __EXCEPTIONS || \ - (defined(_MSC_VER) && defined(_CPPUNWIND)) - #define MEGDNN_ENABLE_EXCEPTIONS 1 - #else - #define MEGDNN_ENABLE_EXCEPTIONS 0 - #endif +#if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND)) +#define MEGDNN_ENABLE_EXCEPTIONS 1 +#else +#define MEGDNN_ENABLE_EXCEPTIONS 0 +#endif #endif #ifndef MEGDNN_ENABLE_RTTI - #if __cpp_rtti || __GXX_RTTI || 
(defined(_MSC_VER) && defined(_CPPRTTI)) - #define MEGDNN_ENABLE_RTTI 1 - #else - #define MEGDNN_ENABLE_RTTI 0 - #endif +#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) +#define MEGDNN_ENABLE_RTTI 1 +#else +#define MEGDNN_ENABLE_RTTI 0 +#endif #endif #ifdef __CUDACC__ - #define MEGDNN_CC_CUDA 1 - #undef MEGDNN_CONSTEXPR - #define MEGDNN_CONSTEXPR const +#define MEGDNN_CC_CUDA 1 +#undef MEGDNN_CONSTEXPR +#define MEGDNN_CONSTEXPR const #if defined(__CUDACC_VER_MAJOR__) #if __CUDACC_VER_MAJOR__ >= 9 - #undef MEGDNN_STATIC_ASSERT - #define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); +#undef MEGDNN_STATIC_ASSERT +#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); #else - #undef MEGDNN_STATIC_ASSERT - #define MEGDNN_STATIC_ASSERT(cond, msg) +#undef MEGDNN_STATIC_ASSERT +#define MEGDNN_STATIC_ASSERT(cond, msg) #endif #endif - #define nullptr NULL - #undef MEGDNN_FINAL - #define MEGDNN_FINAL +#define nullptr NULL +#undef MEGDNN_FINAL +#define MEGDNN_FINAL #elif defined(__HIPCC__) - #define MEGDNN_CC_CUDA 1 +#define MEGDNN_CC_CUDA 1 #else - #define MEGDNN_CC_HOST 1 -#endif // __CUDACC__ +#define MEGDNN_CC_HOST 1 +#endif // __CUDACC__ // MEGDNN_HOST and MEGDNN_DEVICE #if MEGDNN_CC_CUDA - #define MEGDNN_HOST __host__ - #define MEGDNN_DEVICE __device__ +#define MEGDNN_HOST __host__ +#define MEGDNN_DEVICE __device__ #else - #define MEGDNN_HOST - #define MEGDNN_DEVICE +#define MEGDNN_HOST +#define MEGDNN_DEVICE #endif #if MEGDNN_CC_CUDA - #define MEGDNN_FORCE_INLINE __forceinline__ +#define MEGDNN_FORCE_INLINE __forceinline__ #else #if __GNUC__ || __has_attribute(always_inline) - #define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) +#define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) #else - #define MEGDNN_FORCE_INLINE inline +#define MEGDNN_FORCE_INLINE inline #endif #endif #if defined(_MSC_VER) || defined(WIN32) - #define ATTR_ALIGNED(v) __declspec(align(v)) +#define ATTR_ALIGNED(v) __declspec(align(v)) #else - #define ATTR_ALIGNED(v) __attribute__((aligned(v))) +#define ATTR_ALIGNED(v) __attribute__((aligned(v))) #endif // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/basic_types.h b/dnn/include/megdnn/basic_types.h index c10cd3b9..53f22c9a 100644 --- a/dnn/include/megdnn/basic_types.h +++ b/dnn/include/megdnn/basic_types.h @@ -16,10 +16,10 @@ #include "megdnn/internal/defs.h" #if MEGDNN_CC_HOST +#include #include #include #include -#include #include "megdnn/thin/small_vector.h" #endif // MEGDNN_CC_HOST @@ -35,8 +35,7 @@ class ErrorHandler { protected: MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; - MEGDNN_NORETURN virtual void do_on_tensor_reshape_error( - const std::string& msg) { + MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) { on_megdnn_error(msg); } @@ -70,8 +69,9 @@ public: #if MEGDNN_CC_HOST enum class LogLevel { DEBUG, INFO, WARN, ERROR }; -typedef void (*LogHandler)(LogLevel level, const char* file, const char* func, - int line, const char* fmt, va_list ap); +typedef void (*LogHandler)( + LogLevel level, const char* file, const char* func, int line, const char* fmt, + va_list ap); /*! 
* \brief set the callback to receive all log messages @@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape { ptrdiff_t low_elem, low_byte; size_t high_elem, high_byte; - Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, - size_t high_byte) + Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte) : low_elem(low_elem), low_byte(low_byte), high_elem(high_elem), @@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape { TensorLayout(const TensorShape& shape, DType dtype, Format format); //! creating layout with user-specified shape and stride. - TensorLayout(const TensorShape& shape, const std::vector& stride, - DType dtype); + TensorLayout( + const TensorShape& shape, const std::vector& stride, + DType dtype); - TensorLayout(const TensorShape& shape, const std::vector& stride, - DType dtype, Format format); + TensorLayout( + const TensorShape& shape, const std::vector& stride, DType dtype, + Format format); /* =================== inplace modifiers =================== */ @@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape { * * \throw TensorReshapeError if no stride exists for target shape. */ - TensorLayout reshape(const TensorShape& shape) const - MEGDNN_WARN_UNUSED_RESULT; + TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; /*! * \brief try to reshape to another view; return whether these two shapes @@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape { * \return true iff there exists target stride so this layout can be * converted to target shape and the elements can match. */ - bool try_reshape(TensorLayout& output, - const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; + bool try_reshape(TensorLayout& output, const TensorShape& shape) const + MEGDNN_WARN_UNUSED_RESULT; /*! * \brief Broadcast on dims with shape == 1 to match target *shape*. * \throw TensorReshapeError if could not be satisfied */ - TensorLayout broadcast(const TensorShape& shape) const - MEGDNN_WARN_UNUSED_RESULT; + TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; /*! * \brief Collapse consecutive axes with contiguous layout together @@ -441,8 +440,7 @@ struct Workspace { Workspace() : raw_ptr(NULL), size(0) {} - Workspace(dt_byte* raw_ptr_, size_t size_) - : raw_ptr(raw_ptr_), size(size_) {} + Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {} template T* ptr(size_t offset_in_bytes = 0) const { @@ -467,9 +465,8 @@ public: * \param shape requested output shape * \param user_data extra user data passed in DynOutMallocPolicyCall */ - virtual TensorND alloc_output(size_t id, DType dtype, - const TensorShape& shape, - void* user_data) = 0; + virtual TensorND alloc_output( + size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0; /*! 
* \brief allocate workspace memory @@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall { */ template T* alloc_workspace(size_t nr_elem) { - using real_elem = - typename std::conditional::value, - uint8_t, elem>::type; - return static_cast(policy->alloc_workspace( - nr_elem * sizeof(real_elem), user_data)); + using real_elem = typename std::conditional< + std::is_same::value, uint8_t, elem>::type; + return static_cast( + policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data)); } - void free_workspace(void* ptr) { - return policy->free_workspace(ptr, user_data); - } + void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); } }; - template class EnumClassBit { std::underlying_type_t m_val; @@ -528,8 +521,7 @@ class EnumClassBit { constexpr EnumClassBit(std::underlying_type_t v) : m_val(v) {} public: - constexpr EnumClassBit(T v) - : m_val(static_cast>(v)) {} + constexpr EnumClassBit(T v) : m_val(static_cast>(v)) {} constexpr operator T() const { return static_cast(m_val); } @@ -542,7 +534,7 @@ public: DEF_OPR(&) DEF_OPR(|) - DEF_OPR (^) + DEF_OPR(^) constexpr EnumClassBit operator~() const { return ~m_val; } @@ -553,14 +545,13 @@ public: } // namespace megdnn -#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ - inline constexpr ::megdnn::EnumClassBit operator op(cls x, cls y) { \ - return ::megdnn::EnumClassBit(x) \ - op ::megdnn::EnumClassBit(y); \ - } \ - inline constexpr ::megdnn::EnumClassBit operator op( \ - ::megdnn::EnumClassBit x, cls y) { \ - return x op ::megdnn::EnumClassBit(y); \ +#define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ + inline constexpr ::megdnn::EnumClassBit operator op(cls x, cls y) { \ + return ::megdnn::EnumClassBit(x) op ::megdnn::EnumClassBit(y); \ + } \ + inline constexpr ::megdnn::EnumClassBit operator op( \ + ::megdnn::EnumClassBit x, cls y) { \ + return x op ::megdnn::EnumClassBit(y); \ } #define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ diff --git a/dnn/include/megdnn/common.h b/dnn/include/megdnn/common.h index aa23912f..142b8f4c 100644 --- a/dnn/include/megdnn/common.h +++ b/dnn/include/megdnn/common.h @@ -14,14 +14,14 @@ #include "megbrain_build_config.h" #if MGB_ENABLE_GETENV -#define MGB_GETENV ::std::getenv +#define MGB_GETENV ::std::getenv #else -#define MGB_GETENV(_name) static_cast(nullptr) +#define MGB_GETENV(_name) static_cast(nullptr) #endif #ifdef WIN32 -#define unsetenv(_name) _putenv_s(_name, ""); -#define setenv(name,value,overwrite) _putenv_s(name,value) +#define unsetenv(_name) _putenv_s(_name, ""); +#define setenv(name, value, overwrite) _putenv_s(name, value) #endif namespace megdnn { @@ -32,8 +32,7 @@ namespace megdnn { */ template bool has_available_algo(Opr* opr, Args&&... args) { - const typename Opr::AlgoBase::SizeArgs size_args( - opr, std::forward(args)...); + const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward(args)...); for (auto i : Opr::algo_pack().all_algos) { if (i->is_available(size_args)) { return true; @@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... 
args) { return false; } -} +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/include/megdnn/cuda.h b/dnn/include/megdnn/cuda.h index a8f6f198..b7b5a60a 100644 --- a/dnn/include/megdnn/cuda.h +++ b/dnn/include/megdnn/cuda.h @@ -17,11 +17,11 @@ #include "megdnn/internal/visibility_prologue.h" namespace megdnn { -std::unique_ptr make_cuda_handle_with_stream(cudaStream_t stream, - int device_id = -1); -cudaStream_t get_cuda_stream(Handle *handle); +std::unique_ptr make_cuda_handle_with_stream( + cudaStream_t stream, int device_id = -1); +cudaStream_t get_cuda_stream(Handle* handle); -} // namespace megdnn +} // namespace megdnn #include "megdnn/internal/visibility_epilogue.h" // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index dcc57cfe..fbcce01a 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -26,12 +26,12 @@ #if MEGDNN_DISABLE_FLOAT16 #define DNN_INC_FLOAT16(_x) -#define DNN_FLOAT16_SELECT(_x, _y) _y +#define DNN_FLOAT16_SELECT(_x, _y) _y #else -#include "megdnn/dtype/half.hpp" #include "megdnn/dtype/bfloat16.hpp" -#define DNN_INC_FLOAT16(_x) _x -#define DNN_FLOAT16_SELECT(_x, _y) _x +#include "megdnn/dtype/half.hpp" +#define DNN_INC_FLOAT16(_x) _x +#define DNN_FLOAT16_SELECT(_x, _y) _x #endif namespace megdnn { @@ -39,65 +39,39 @@ namespace megdnn { /*! * \brief iterate through each dtype name */ -#define MEGDNN_FOREACH_DTYPE_NAME(cb) \ - cb(Float32) \ - cb(Uint8) \ - cb(Int8) \ - cb(Int16) \ - cb(Int32) \ - cb(IntB1) \ - cb(IntB2) \ - cb(IntB4) \ - cb(Byte) \ - DNN_INC_FLOAT16(cb(Float16)) \ - DNN_INC_FLOAT16(cb(BFloat16)) \ - cb(UintB4) \ - cb(Bool) \ - cb(Uint16) \ +#define MEGDNN_FOREACH_DTYPE_NAME(cb) \ + cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(IntB1) cb(IntB2) cb(IntB4) \ + cb(Byte) DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) \ + cb(UintB4) cb(Bool) cb(Uint16) /*! * \brief iterate through each full byte dtype */ -#define MEGDNN_FOREACH_FULL_BYTE_DTYPE(cb) \ - cb(Float32) \ - cb(Uint8) \ - cb(Int8) \ - cb(Int16) \ - cb(Int32) \ - cb(Byte) \ - DNN_INC_FLOAT16(cb(Float16)) \ - DNN_INC_FLOAT16(cb(BFloat16)) \ - cb(Bool) \ - cb(Uint16) \ +#define MEGDNN_FOREACH_FULL_BYTE_DTYPE(cb) \ + cb(Float32) cb(Uint8) cb(Int8) cb(Int16) cb(Int32) cb(Byte) \ + DNN_INC_FLOAT16(cb(Float16)) DNN_INC_FLOAT16(cb(BFloat16)) cb(Bool) \ + cb(Uint16) /*! * \brief iterate through each fractional byte dtype */ -#define MEGDNN_FOREACH_LOWBIT_DTYPE(cb) \ - cb(IntB, 1)\ - cb(IntB, 2)\ - cb(IntB, 4)\ - cb(UintB, 4)\ +#define MEGDNN_FOREACH_LOWBIT_DTYPE(cb) cb(IntB, 1) cb(IntB, 2) cb(IntB, 4) cb(UintB, 4) // This is used to make enum definition possible. -#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb) \ - cb(Quantized8Asymm) +#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb) cb(Quantized8Asymm) -#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \ - cb(QuantizedS32) \ - cb(QuantizedS8) \ - cb(Quantized4Asymm) \ - cb(QuantizedS4) \ - cb(QuantizedS16) +#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \ + cb(QuantizedS32) cb(QuantizedS8) cb(Quantized4Asymm) cb(QuantizedS4) \ + cb(QuantizedS16) #define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \ - MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \ + MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \ MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb_others) /*! 
* \brief iterate through each parameterized dtype */ -#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) \ +#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) \ MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb) \ MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) @@ -106,50 +80,42 @@ namespace megdnn { * numeric computing */ -#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ - cb(::megdnn::dtype::Float32) \ - DNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ - DNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) - +#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ + cb(::megdnn::dtype::Float32) DNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ + DNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) /*! * \brief iterate through each dtype object that can be involved in integer * numeric computing */ -#define MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) \ - cb(::megdnn::dtype::Int32) \ - cb(::megdnn::dtype::Int16) \ - cb(::megdnn::dtype::Int8) \ - cb(::megdnn::dtype::Uint8) \ +#define MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) \ + cb(::megdnn::dtype::Int32) cb(::megdnn::dtype::Int16) cb(::megdnn::dtype::Int8) \ + cb(::megdnn::dtype::Uint8) /*! * \brief iterate through each dtype object that can be involved in numeric * computing (i.e. dtypes except Byte) */ -#define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \ +#define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \ MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ - MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) \ + MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) //! In order to avoid an unnecessary increase in binary size, we just //! use QuantizedS16 dtype in winograd_filter_preprocess now. So I didn't add //! this data type here. -#define MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) \ - cb(::megdnn::dtype::Quantized8Asymm) \ - cb(::megdnn::dtype::QuantizedS32) \ - cb(::megdnn::dtype::QuantizedS8) \ +#define MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) \ + cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::QuantizedS32) \ + cb(::megdnn::dtype::QuantizedS8) #define MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) \ - cb(::megdnn::dtype::Quantized4Asymm) \ - cb(::megdnn::dtype::QuantizedS4) + cb(::megdnn::dtype::Quantized4Asymm) cb(::megdnn::dtype::QuantizedS4) -#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ - cb(::megdnn::dtype::QuantizedS32) \ - cb(::megdnn::dtype::QuantizedS8) \ - cb(::megdnn::dtype::QuantizedS4) +#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ + cb(::megdnn::dtype::QuantizedS32) cb(::megdnn::dtype::QuantizedS8) \ + cb(::megdnn::dtype::QuantizedS4) #define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ - cb(::megdnn::dtype::Quantized8Asymm) \ - cb(::megdnn::dtype::Quantized4Asymm) + cb(::megdnn::dtype::Quantized8Asymm) cb(::megdnn::dtype::Quantized4Asymm) /*! * \brief a POD representation of a single byte @@ -164,24 +130,23 @@ namespace megdnn { class dt_byte { unsigned char _; - public: - - //! convert to given type - template - T* as() { - return reinterpret_cast(this); - } +public: + //! convert to given type + template + T* as() { + return reinterpret_cast(this); + } - //! convert to given type - template - const T* as() const { - return reinterpret_cast(this); - } + //! convert to given type + template + const T* as() const { + return reinterpret_cast(this); + } } MEGDNN_PACKED; #define DEFINE_LOWBIT(_name, b) \ - class dt_##_name##b {\ - unsigned char _;\ + class dt_##_name##b { \ + unsigned char _; \ } MEGDNN_PACKED; MEGDNN_FOREACH_LOWBIT_DTYPE(DEFINE_LOWBIT) #undef DEFINE_LOWBIT @@ -189,138 +154,117 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(DEFINE_LOWBIT) class dt_quint8 { uint8_t _; - public: - //! 
Convert to normal uint8_t - MEGDNN_DEVICE uint8_t as_uint8() const { - return _; - } +public: + //! Convert to normal uint8_t + MEGDNN_DEVICE uint8_t as_uint8() const { return _; } - MEGDNN_HOST MEGDNN_DEVICE explicit dt_quint8(uint8_t val):_(val) {} + MEGDNN_HOST MEGDNN_DEVICE explicit dt_quint8(uint8_t val) : _(val) {} #ifdef MEGDNN_CC_HOST - explicit operator uint8_t() { return _; } + explicit operator uint8_t() { return _; } #endif - bool operator<(const dt_quint8& b) const { return _ < b._; } - bool operator>(const dt_quint8& b) const { return _ > b._; } - bool operator==(const dt_quint8& b) const { return _ == b._; } - bool operator!=(const dt_quint8& b) const { return _ != b._; } + bool operator<(const dt_quint8& b) const { return _ < b._; } + bool operator>(const dt_quint8& b) const { return _ > b._; } + bool operator==(const dt_quint8& b) const { return _ == b._; } + bool operator!=(const dt_quint8& b) const { return _ != b._; } } MEGDNN_PACKED; class dt_qint32 { int32_t _; - public: - //! Convert to normal uint32_t - MEGDNN_DEVICE int32_t as_int32() const { - return _; - } +public: + //! Convert to normal uint32_t + MEGDNN_DEVICE int32_t as_int32() const { return _; } - MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint32(int32_t val):_(val) {} + MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint32(int32_t val) : _(val) {} #ifdef MEGDNN_CC_HOST - explicit operator int32_t() { return _; } + explicit operator int32_t() { return _; } #endif - dt_qint32 operator*(const dt_qint32& b) const { - return dt_qint32(_ * b._); - } - dt_qint32 operator+(const dt_qint32& b) const { - return dt_qint32(_ + b._); - } - dt_qint32 operator-(const dt_qint32& b) const { - return dt_qint32(_ - b._); - } + dt_qint32 operator*(const dt_qint32& b) const { return dt_qint32(_ * b._); } + dt_qint32 operator+(const dt_qint32& b) const { return dt_qint32(_ + b._); } + dt_qint32 operator-(const dt_qint32& b) const { return dt_qint32(_ - b._); } #ifdef MEGDNN_CC_HOST - dt_qint32 operator/(int b) const { - return dt_qint32(std::round(_ / static_cast(b))); - } - dt_qint32 operator/(const dt_qint32& b) const { - return dt_qint32(std::round(_ / static_cast(b._))); - } + dt_qint32 operator/(int b) const { + return dt_qint32(std::round(_ / static_cast(b))); + } + dt_qint32 operator/(const dt_qint32& b) const { + return dt_qint32(std::round(_ / static_cast(b._))); + } #endif - dt_qint32 operator+=(const dt_qint32& b) { - _ += b._; - return *this; - } - bool operator<(const dt_qint32& b) const { return _ < b._; } - bool operator>(const dt_qint32& b) const { return _ > b._; } + dt_qint32 operator+=(const dt_qint32& b) { + _ += b._; + return *this; + } + bool operator<(const dt_qint32& b) const { return _ < b._; } + bool operator>(const dt_qint32& b) const { return _ > b._; } } MEGDNN_PACKED; class dt_qint8 { int8_t _; - public: - MEGDNN_DEVICE int8_t as_int8() const { - return _; - } +public: + MEGDNN_DEVICE int8_t as_int8() const { return _; } - MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint8(int8_t val):_(val) {} + MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint8(int8_t val) : _(val) {} #ifdef MEGDNN_CC_HOST - explicit operator int8_t() { return _; } + explicit operator int8_t() { return _; } #endif - bool operator<(const dt_qint8& b) const { return _ < b._; } - bool operator>(const dt_qint8& b) const { return _ > b._; } - bool operator==(const dt_qint8& b) const { return _ == b._; } - bool operator!=(const dt_qint8& b) const { return _ != b._; } + bool operator<(const dt_qint8& b) const { return _ < b._; } + bool operator>(const dt_qint8& b) 
const { return _ > b._; } + bool operator==(const dt_qint8& b) const { return _ == b._; } + bool operator!=(const dt_qint8& b) const { return _ != b._; } } MEGDNN_PACKED; class dt_qint16 { int16_t _; - public: - //! Convert to normal int16_t - MEGDNN_DEVICE int16_t as_int16() const { - return _; - } +public: + //! Convert to normal int16_t + MEGDNN_DEVICE int16_t as_int16() const { return _; } - MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint16(int16_t val):_(val) {} + MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint16(int16_t val) : _(val) {} #ifdef MEGDNN_CC_HOST - explicit operator int16_t() { return _; } + explicit operator int16_t() { return _; } #endif - dt_qint16 operator*(const dt_qint16& b) const { - return dt_qint16(_ * b._); - } - dt_qint16 operator+(const dt_qint16& b) const { - return dt_qint16(_ + b._); - } - dt_qint16 operator-(const dt_qint16& b) const { - return dt_qint16(_ - b._); - } + dt_qint16 operator*(const dt_qint16& b) const { return dt_qint16(_ * b._); } + dt_qint16 operator+(const dt_qint16& b) const { return dt_qint16(_ + b._); } + dt_qint16 operator-(const dt_qint16& b) const { return dt_qint16(_ - b._); } #ifdef MEGDNN_CC_HOST - dt_qint16 operator/(int b) const { - return dt_qint16(std::round(_ / static_cast(b))); - } - dt_qint16 operator/(const dt_qint16& b) const { - return dt_qint16(std::round(_ / static_cast(b._))); - } + dt_qint16 operator/(int b) const { + return dt_qint16(std::round(_ / static_cast(b))); + } + dt_qint16 operator/(const dt_qint16& b) const { + return dt_qint16(std::round(_ / static_cast(b._))); + } #endif - dt_qint16 operator+=(const dt_qint16& b) { - _ += b._; - return *this; - } - bool operator<(const dt_qint16& b) const { return _ < b._; } - bool operator>(const dt_qint16& b) const { return _ > b._; } + dt_qint16 operator+=(const dt_qint16& b) { + _ += b._; + return *this; + } + bool operator<(const dt_qint16& b) const { return _ < b._; } + bool operator>(const dt_qint16& b) const { return _ > b._; } } MEGDNN_PACKED; template class dt_qulowbit { uint8_t _; - public: - //! Convert to normal uint8_t - MEGDNN_DEVICE uint8_t as_uint8() const { - return _; - } - MEGDNN_DEVICE uint8_t as_storage() const { return _; } +public: + //! Convert to normal uint8_t + MEGDNN_DEVICE uint8_t as_uint8() const { return _; } + + MEGDNN_DEVICE uint8_t as_storage() const { return _; } - MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val):_(val) {} + MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val) : _(val) {} #ifdef MEGDNN_CC_HOST - explicit operator uint8_t() { return _; } + explicit operator uint8_t() { return _; } #endif - bool operator<(const dt_qulowbit& b) const { return _ < b._; } - bool operator>(const dt_qulowbit& b) const { return _ > b._; } + bool operator<(const dt_qulowbit& b) const { return _ < b._; } + bool operator>(const dt_qulowbit& b) const { return _ > b._; } - dt_qulowbit& operator=(const uint8_t val) { - _ = val; - return *this; - } + dt_qulowbit& operator=(const uint8_t val) { + _ = val; + return *this; + } }; using dt_quint4 = dt_qulowbit<4>; @@ -328,25 +272,23 @@ template class dt_qlowbit { int8_t _; - public: - //! Convert to normal int8_t - MEGDNN_DEVICE int8_t as_int8() const { - return _; - } +public: + //! 
Convert to normal int8_t + MEGDNN_DEVICE int8_t as_int8() const { return _; } - MEGDNN_DEVICE int8_t as_storage() const { return _; } + MEGDNN_DEVICE int8_t as_storage() const { return _; } - MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val):_(val) {} + MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val) : _(val) {} #ifdef MEGDNN_CC_HOST - explicit operator int8_t() { return _; } + explicit operator int8_t() { return _; } #endif - bool operator<(const dt_qlowbit& b) const { return _ < b._; } - bool operator>(const dt_qlowbit& b) const { return _ > b._; } + bool operator<(const dt_qlowbit& b) const { return _ < b._; } + bool operator>(const dt_qlowbit& b) const { return _ > b._; } - dt_qlowbit& operator=(const int8_t val) { - _ = val; - return *this; - } + dt_qlowbit& operator=(const int8_t val) { + _ = val; + return *this; + } }; using dt_qint4 = dt_qlowbit<4>; @@ -369,63 +311,55 @@ DNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) #define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000 #if MEGDNN_CC_HOST - //! enumeration of dtypes; useful for hash or being used in switch-case - enum class DTypeEnum: uint32_t { +//! enumeration of dtypes; useful for hash or being used in switch-case +enum class DTypeEnum : uint32_t { #else - struct DTypeEnum { - enum Ev { +struct DTypeEnum { + enum Ev { #endif - Float32, - Uint8, - Int8, - Int16, - Int32, - IntB1, - IntB2, - IntB4, - Byte, + Float32, + Uint8, + Int8, + Int16, + Int32, + IntB1, + IntB2, + IntB4, + Byte, #if !MEGDNN_DISABLE_FLOAT16 - Float16, + Float16, #endif - UintB4 = 10, + UintB4 = 10, #if !MEGDNN_DISABLE_FLOAT16 - BFloat16 = 11, + BFloat16 = 11, #endif - Bool = 12, - Uint16 = 13, - #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, - #define D(_name) _name, - MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) - #undef D - #undef FST + Bool = 12, + Uint16 = 13, +#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, +#define D(_name) _name, + MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) +#undef D +#undef FST #if !MEGDNN_CC_HOST - }; - uint32_t ev; - DTypeEnum(): ev(0) {} - DTypeEnum(uint32_t e): ev(e) {} +}; +uint32_t ev; +DTypeEnum() : ev(0) {} +DTypeEnum(uint32_t e) : ev(e) {} #endif - }; +}; #if MEGDNN_CC_HOST - //! dtype numeric category fo - enum class DTypeCategory: int { - OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL - }; - //! dtype signedness - enum class DTypeSignedness: int { - OTHER, UNSIGNED, SIGNED - }; +//! dtype numeric category fo +enum class DTypeCategory : int { OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL }; +//! 
dtype signedness +enum class DTypeSignedness : int { OTHER, UNSIGNED, SIGNED }; #else struct DTypeCategory { - enum Ev { - OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL - }; + enum Ev { OTHER, FLOAT, INT, LOWBIT, QUANTIZED, BOOL }; int ev; }; struct DTypeSignedness { - enum Ev { - OTHER, UNSIGNED, SIGNED - }; + enum Ev { OTHER, UNSIGNED, SIGNED }; int ev; }; #endif @@ -451,149 +385,128 @@ using DTypeParam = DTypeParamImpl::ctype>; * \brief Information about a data type that can be accessed at runtime */ class DType { - private: - MEGDNN_NORETURN void on_request_lowbit_size() const; +private: + MEGDNN_NORETURN void on_request_lowbit_size() const; // HACK: This is required in ParameterizedDType::downcast_from - public: - MEGDNN_NORETURN void on_assert_is_failed(const char *rname) const; - protected: - struct Trait { - const char *const name; - const uint16_t size_log; //!< log2 of sizeof(dt) for non-lowbit - const uint16_t low_bit; //!< 0 for non-lowbit; otherwise num bits - DTypeEnum enumv; - DTypeCategory category; - DTypeSignedness signedness; - const bool has_param; - }; - Trait *m_trait; - - explicit DType(Trait *t): - m_trait(t) - {} - - public: - DType(): - m_trait(nullptr) - {} - - bool valid() const { - return m_trait != nullptr; - } +public: + MEGDNN_NORETURN void on_assert_is_failed(const char* rname) const; + +protected: + struct Trait { + const char* const name; + const uint16_t size_log; //!< log2 of sizeof(dt) for non-lowbit + const uint16_t low_bit; //!< 0 for non-lowbit; otherwise num bits + DTypeEnum enumv; + DTypeCategory category; + DTypeSignedness signedness; + const bool has_param; + }; + Trait* m_trait; - /*! - * \brief name of this data type - */ - const char *name() const { - return m_trait ? m_trait->name : "invalid"; - } + explicit DType(Trait* t) : m_trait(t) {} - /*! - * \brief size of elem_num this data type, if fraction form return ceil - */ - size_t size(size_t elem_num) const { - if (m_trait->low_bit != 0) - return static_cast( (m_trait->low_bit*elem_num + 7)/8 ); - return elem_num << m_trait->size_log; - } +public: + DType() : m_trait(nullptr) {} + + bool valid() const { return m_trait != nullptr; } + + /*! + * \brief name of this data type + */ + const char* name() const { return m_trait ? m_trait->name : "invalid"; } + + /*! + * \brief size of elem_num this data type, if fraction form return ceil + */ + size_t size(size_t elem_num) const { + if (m_trait->low_bit != 0) + return static_cast((m_trait->low_bit * elem_num + 7) / 8); + return elem_num << m_trait->size_log; + } - /*! - * \brief max number of elements within representation - * - * The total size of the tensor (in bytes) should not exceed size_t range. - */ - size_t max_elements() const { - if (m_trait->low_bit != 0) - return std::numeric_limits::max(); + /*! + * \brief max number of elements within representation + * + * The total size of the tensor (in bytes) should not exceed size_t range. 
+ */ + size_t max_elements() const { + if (m_trait->low_bit != 0) + return std::numeric_limits::max(); - return std::numeric_limits::max() >> m_trait->size_log; - } + return std::numeric_limits::max() >> m_trait->size_log; + } - size_t low_bit() const { return m_trait->low_bit; } + size_t low_bit() const { return m_trait->low_bit; } - bool is_low_bit() const { return low_bit() != 0; } + bool is_low_bit() const { return low_bit() != 0; } - bool is_quantized_lowbit() const { - return low_bit() != 0 && + bool is_quantized_lowbit() const { + return low_bit() != 0 && #if MEGDNN_CC_HOST - category() == DTypeCategory::QUANTIZED; + category() == DTypeCategory::QUANTIZED; #else category().ev == DTypeCategory::Ev::QUANTIZED; #endif - } + } - /*! - * \brief size of this data type, in bytes - */ - size_t size() const { - if (m_trait->low_bit == 0) - return 1 << m_trait->size_log; - on_request_lowbit_size(); - } + /*! + * \brief size of this data type, in bytes + */ + size_t size() const { + if (m_trait->low_bit == 0) + return 1 << m_trait->size_log; + on_request_lowbit_size(); + } - //! size() in log2 - size_t size_log() const { - if (m_trait->low_bit == 0) - return m_trait->size_log; - on_request_lowbit_size(); - } + //! size() in log2 + size_t size_log() const { + if (m_trait->low_bit == 0) + return m_trait->size_log; + on_request_lowbit_size(); + } - //! assert this dtype is given type; throw exception on failure - void assert_is(const DType &rhs) const { - if (m_trait != rhs.m_trait) - on_assert_is_failed(rhs.name()); - } + //! assert this dtype is given type; throw exception on failure + void assert_is(const DType& rhs) const { + if (m_trait != rhs.m_trait) + on_assert_is_failed(rhs.name()); + } - template - inline void assert_is_ctype() const; + template + inline void assert_is_ctype() const; - template - inline void assert_is_compatible_ctype() const; + template + inline void assert_is_compatible_ctype() const; - //! get corresponding enum value for this dtype - DTypeEnum enumv() const { - return m_trait->enumv; - } + //! get corresponding enum value for this dtype + DTypeEnum enumv() const { return m_trait->enumv; } - //! get category of this data type - DTypeCategory category() const { - return m_trait->category; - } + //! get category of this data type + DTypeCategory category() const { return m_trait->category; } - //! get signedness of this data type - DTypeSignedness signedness() const { - return m_trait->signedness; - } + //! get signedness of this data type + DTypeSignedness signedness() const { return m_trait->signedness; } - bool has_param() const { - return m_trait->has_param; - } + bool has_param() const { return m_trait->has_param; } - bool operator == (const DType &rhs) const { - return m_trait == rhs.m_trait; - } + bool operator==(const DType& rhs) const { return m_trait == rhs.m_trait; } - bool operator != (const DType &rhs) const { - return m_trait != rhs.m_trait; - } + bool operator!=(const DType& rhs) const { return m_trait != rhs.m_trait; } - //! get dtype object from enum - static DType from_enum(DTypeEnum ev); + //! get dtype object from enum + static DType from_enum(DTypeEnum ev); - //! get a handle of the dtype that could be used for equivalence check - const void* handle() const { - return m_trait; - } + //! 
get a handle of the dtype that could be used for equivalence check + const void* handle() const { return m_trait; } - template - T as() const { - return T::downcast_from(*this); - } + template + T as() const { + return T::downcast_from(*this); + } - template - const DTypeParam& param() const { - return as::dtype>().param(); - } + template + const DTypeParam& param() const { + return as::dtype>().param(); + } }; #ifdef MEGDNN_CC_HOST @@ -613,8 +526,7 @@ class ParameterizedDType MEGDNN_FINAL : public DType { struct Trait : DType::Trait { DTypeParam param; - Trait(const DType::Trait& static_trait, - const DTypeParam& param) + Trait(const DType::Trait& static_trait, const DTypeParam& param) : DType::Trait(static_trait), param(param) {} }; @@ -649,9 +561,7 @@ public: } #pragma GCC diagnostic pop - const DTypeParam& param() { - return static_cast(m_trait)->param; - } + const DTypeParam& param() { return static_cast(m_trait)->param; } }; #endif // MEGDNN_CC_HOST @@ -659,11 +569,12 @@ public: //! dtype implementation classes namespace dtype { -#define IMPL(_name) \ - class _name MEGDNN_FINAL: public DType { \ - static Trait sm_trait; \ - public: \ - _name(): DType(&sm_trait) {} \ +#define IMPL(_name) \ + class _name MEGDNN_FINAL : public DType { \ + static Trait sm_trait; \ + \ + public: \ + _name() : DType(&sm_trait) {} \ }; MEGDNN_FOREACH_DTYPE_NAME(IMPL) @@ -679,53 +590,49 @@ MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) #undef cb //! log function used in DTypeTrait -template struct log { - static MEGDNN_CONSTEXPR size_t value = log<(n>>1)>::value + 1; +template +struct log { + static MEGDNN_CONSTEXPR size_t value = log<(n >> 1)>::value + 1; #if MEGDNN_CC_HOST - MEGDNN_STATIC_ASSERT( (n&(n-1)) == 0, "only full power number can have log"); + MEGDNN_STATIC_ASSERT((n & (n - 1)) == 0, "only full power number can have log"); #endif }; -template<> struct log<1> {static MEGDNN_CONSTEXPR size_t value = 0;}; +template <> +struct log<1> { + static MEGDNN_CONSTEXPR size_t value = 0; +}; -} // namespace dtype +} // namespace dtype // begin define DTypeTrait impls { #if MEGDNN_CC_HOST -#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, \ - _has_param) \ - static MEGDNN_CONSTEXPR const char *name = #_name; \ - using ctype = _ctype; \ - using dtype = ::megdnn::dtype::_name; \ - static MEGDNN_CONSTEXPR DTypeCategory category = DTypeCategory::_cat; \ - static MEGDNN_CONSTEXPR DTypeSignedness \ - signedness = DTypeSignedness::_sign; \ - static MEGDNN_CONSTEXPR uint16_t size_log = \ - ::megdnn::dtype::log::value; \ - static MEGDNN_CONSTEXPR DTypeEnum enumv = DTypeEnum::_name;\ - static MEGDNN_CONSTEXPR uint16_t low_bit = _bits;\ +#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, _has_param) \ + static MEGDNN_CONSTEXPR const char* name = #_name; \ + using ctype = _ctype; \ + using dtype = ::megdnn::dtype::_name; \ + static MEGDNN_CONSTEXPR DTypeCategory category = DTypeCategory::_cat; \ + static MEGDNN_CONSTEXPR DTypeSignedness signedness = DTypeSignedness::_sign; \ + static MEGDNN_CONSTEXPR uint16_t size_log = \ + ::megdnn::dtype::log::value; \ + static MEGDNN_CONSTEXPR DTypeEnum enumv = DTypeEnum::_name; \ + static MEGDNN_CONSTEXPR uint16_t low_bit = _bits; \ static MEGDNN_CONSTEXPR bool has_param = _has_param #else -#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, \ - _has_param) \ - typedef _ctype ctype; \ - typedef ::megdnn::dtype::_name dtype; \ - static const uint16_t size_log = \ - ::megdnn::dtype::log::value; \ - static MEGDNN_CONSTEXPR int enumv = 
DTypeEnum::_name;\ +#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, _has_param) \ + typedef _ctype ctype; \ + typedef ::megdnn::dtype::_name dtype; \ + static const uint16_t size_log = ::megdnn::dtype::log::value; \ + static MEGDNN_CONSTEXPR int enumv = DTypeEnum::_name; \ static MEGDNN_CONSTEXPR uint16_t low_bit = _bits -#endif // MEGDNN_CC_HOST +#endif // MEGDNN_CC_HOST -#define MEGDNN_DEF_DT(_name, _ctype, _cat, _sign, _minval, _maxval) \ - template <> \ - struct DTypeTrait { \ +#define MEGDNN_DEF_DT(_name, _ctype, _cat, _sign, _minval, _maxval) \ + template <> \ + struct DTypeTrait { \ MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, 0, false); \ - MEGDNN_HOST MEGDNN_DEVICE static ctype min() { \ - return _minval; \ - } \ - MEGDNN_HOST MEGDNN_DEVICE static ctype max() { \ - return _maxval; \ - } \ + MEGDNN_HOST MEGDNN_DEVICE static ctype min() { return _minval; } \ + MEGDNN_HOST MEGDNN_DEVICE static ctype max() { return _maxval; } \ } MEGDNN_DEF_DT(Float32, dt_float32, FLOAT, SIGNED, -FLT_MAX, FLT_MAX); @@ -735,29 +642,29 @@ MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX); MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); MEGDNN_DEF_DT(Bool, dt_bool, BOOL, UNSIGNED, false, true); MEGDNN_DEF_DT(Uint16, dt_uint16, INT, UNSIGNED, 0, UINT16_MAX); -DNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, - std::numeric_limits::lowest(), - std::numeric_limits::max())); -DNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED, - std::numeric_limits::lowest(), - std::numeric_limits::max())); +DNN_INC_FLOAT16(MEGDNN_DEF_DT( + Float16, dt_float16, FLOAT, SIGNED, std::numeric_limits::lowest(), + std::numeric_limits::max())); +DNN_INC_FLOAT16(MEGDNN_DEF_DT( + BFloat16, dt_bfloat16, FLOAT, SIGNED, + std::numeric_limits::lowest(), + std::numeric_limits::max())); template <> struct DTypeTrait { MEGDNN_DEF_DT_BASIC_FIELDS(Byte, dt_byte, OTHER, OTHER, 0, false); }; -#define MEGDNN_DEF_FRACTION_DT(_name, b)\ - template <> \ - struct DTypeTrait {\ - MEGDNN_DEF_DT_BASIC_FIELDS(_name##b, dt_##_name##b, LOWBIT, OTHER, b, \ - false); \ +#define MEGDNN_DEF_FRACTION_DT(_name, b) \ + template <> \ + struct DTypeTrait { \ + MEGDNN_DEF_DT_BASIC_FIELDS(_name##b, dt_##_name##b, LOWBIT, OTHER, b, false); \ }; MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT) #undef MEGDNN_DEF_FRACTION_DT -#define MEGDNN_DEF_PARAMETERIZED_DT(_name, _ctype, _itype, _cat, _sign, \ - _minval, _maxval, _bits) \ +#define MEGDNN_DEF_PARAMETERIZED_DT( \ + _name, _ctype, _itype, _cat, _sign, _minval, _maxval, _bits) \ template <> \ struct DTypeTrait { \ MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, true); \ @@ -769,55 +676,50 @@ MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT) } \ }; -MEGDNN_DEF_PARAMETERIZED_DT(Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, - SIGNED, 0, 15, 4); -MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, - SIGNED, -8, 7, 4); -MEGDNN_DEF_PARAMETERIZED_DT(Quantized8Asymm, dt_quint8, dt_quint8, QUANTIZED, - SIGNED, 0, 255, 0); -MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS8, dt_qint8, dt_qint8, QUANTIZED, SIGNED, - INT8_MIN, INT8_MAX, 0); -MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS16, dt_qint16, dt_qint16, QUANTIZED, - SIGNED, INT16_MIN, INT16_MAX, 0); -MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS32, dt_qint32, dt_qint32, QUANTIZED, - SIGNED, INT32_MIN, INT32_MAX, 0); +MEGDNN_DEF_PARAMETERIZED_DT( + Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, SIGNED, 0, 15, 4); +MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, 
QUANTIZED, SIGNED, -8, 7, 4); +MEGDNN_DEF_PARAMETERIZED_DT( + Quantized8Asymm, dt_quint8, dt_quint8, QUANTIZED, SIGNED, 0, 255, 0); +MEGDNN_DEF_PARAMETERIZED_DT( + QuantizedS8, dt_qint8, dt_qint8, QUANTIZED, SIGNED, INT8_MIN, INT8_MAX, 0); +MEGDNN_DEF_PARAMETERIZED_DT( + QuantizedS16, dt_qint16, dt_qint16, QUANTIZED, SIGNED, INT16_MIN, INT16_MAX, 0); +MEGDNN_DEF_PARAMETERIZED_DT( + QuantizedS32, dt_qint32, dt_qint32, QUANTIZED, SIGNED, INT32_MIN, INT32_MAX, 0); #undef MEGDNN_DEF_PARAMETERIZED_DT #undef MEGDNN_DEF_DT #undef MEGDNN_DEF_DT_BASIC_FIELDS // end define DTypeTrait impls } - // alias DTypeTrait for ctypes -#define IMPL(_obj) \ -template <> \ -struct DTypeTrait::ctype>: \ -public DTypeTrait { }; +#define IMPL(_obj) \ + template <> \ + struct DTypeTrait::ctype> \ + : public DTypeTrait {}; MEGDNN_FOREACH_DTYPE_NAME(IMPL) MEGDNN_FOREACH_PARAMETERIZED_DTYPE(IMPL) #undef IMPL - -template +template inline void DType::assert_is_ctype() const { return assert_is(typename DTypeTrait::dtype()); } #ifdef MEGDNN_CC_HOST -#define INST(_dt) \ - template <> \ - inline void DType::assert_is_ctype::ctype>() \ - const { \ - if (enumv() != DTypeTrait::enumv) { \ - on_assert_is_failed(DTypeTrait::name); \ - } \ +#define INST(_dt) \ + template <> \ + inline void DType::assert_is_ctype::ctype>() const { \ + if (enumv() != DTypeTrait::enumv) { \ + on_assert_is_failed(DTypeTrait::name); \ + } \ } MEGDNN_FOREACH_PARAMETERIZED_DTYPE(INST) #undef INST - template inline void DType::assert_is_compatible_ctype() const { if (enumv() != DTypeTrait::enumv) { @@ -825,14 +727,14 @@ inline void DType::assert_is_compatible_ctype() const { } } -#define INST(_dt, _dtype) \ - template <> \ - inline void \ - DType::assert_is_compatible_ctype::ctype>() const { \ - if (enumv() != DTypeTrait::enumv && \ - enumv() != DTypeTrait::enumv) { \ - on_assert_is_failed(DTypeTrait::name); \ - } \ +#define INST(_dt, _dtype) \ + template <> \ + inline void DType::assert_is_compatible_ctype::ctype>() \ + const { \ + if (enumv() != DTypeTrait::enumv && \ + enumv() != DTypeTrait::enumv) { \ + on_assert_is_failed(DTypeTrait::name); \ + } \ } INST(Int8, QuantizedS8) @@ -843,20 +745,18 @@ INST(Int32, QuantizedS32) #else -#define INST(_dt) \ - template <> \ - inline void DType::assert_is_ctype::ctype>() \ - const { \ - if (enumv().ev != DTypeTrait::enumv) { \ - on_assert_is_failed(dtype::_dt().name()); \ - } \ +#define INST(_dt) \ + template <> \ + inline void DType::assert_is_ctype::ctype>() const { \ + if (enumv().ev != DTypeTrait::enumv) { \ + on_assert_is_failed(dtype::_dt().name()); \ + } \ } -MEGDNN_FOREACH_PARAMETERIZED_DTYPE(INST) + MEGDNN_FOREACH_PARAMETERIZED_DTYPE(INST) #undef INST #endif // MEGDNN_CC_HOST - // begin Specialization of DTypeParamImpl for each parameterzied DType { template <> struct DTypeParamImpl { @@ -902,9 +802,7 @@ struct DTypeParamImpl { v = fmin(fmax(-128.f, v), 127.f); return static_cast(v); } - MEGDNN_DEVICE float dequantize(dt_qint8 in) const { - return in.as_int8() * scale; - } + MEGDNN_DEVICE float dequantize(dt_qint8 in) const { return in.as_int8() * scale; } }; template <> @@ -925,9 +823,7 @@ struct DTypeParamImpl { v = fmin(fmax(-32768.f, v), 32767.f); return static_cast(v); } - MEGDNN_DEVICE float dequantize(dt_qint16 in) const { - return in.as_int16() * scale; - } + MEGDNN_DEVICE float dequantize(dt_qint16 in) const { return in.as_int16() * scale; } }; template <> @@ -949,9 +845,7 @@ struct DTypeParamImpl { v = fmin(fmax(-2147483648.f, v), 2147483520.f); return static_cast(v); } - MEGDNN_DEVICE float 
dequantize(dt_qint32 in) const { - return in.as_int32() * scale; - } + MEGDNN_DEVICE float dequantize(dt_qint32 in) const { return in.as_int32() * scale; } }; template <> @@ -996,17 +890,13 @@ struct DTypeParamImpl { v = fmin(fmax(-8.f, v), 7.f); return static_cast(v); } - MEGDNN_DEVICE float dequantize(int8_t in) const { - return in * scale; - } - MEGDNN_DEVICE float dequantize(dt_qint4 in) const { - return in.as_int8() * scale; - } + MEGDNN_DEVICE float dequantize(int8_t in) const { return in * scale; } + MEGDNN_DEVICE float dequantize(dt_qint4 in) const { return in.as_int8() * scale; } }; // end Specialization of DTypeParamImpl for each parameterzied DType } -} // namespace megdnn +} // namespace megdnn #include "megdnn/internal/visibility_epilogue.h" diff --git a/dnn/include/megdnn/dtype/half_common_epilogue.h b/dnn/include/megdnn/dtype/half_common_epilogue.h index 5f15b4dc..17f08aae 100644 --- a/dnn/include/megdnn/dtype/half_common_epilogue.h +++ b/dnn/include/megdnn/dtype/half_common_epilogue.h @@ -3,17 +3,22 @@ * * Copyright (c) 2012-2013 Christian Rau * - * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, - * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE - * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, - * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * Permission is hereby granted, free of charge, to any person obtaining a copy of this + * software and associated documentation files (the "Software"), to deal in the Software + * without restriction, including without limitation the rights to use, copy, modify, + * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be included in all copies + * or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A + * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE + * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
* * Version 1.11.0 * \file @@ -41,8 +46,8 @@ #undef HALF_NOEXCEPT #undef HALF_NOTHROW #ifdef HALF_POP_WARNINGS - #pragma warning(pop) - #undef HALF_POP_WARNINGS +#pragma warning(pop) +#undef HALF_POP_WARNINGS #endif // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/dtype/half_common_prologue.h b/dnn/include/megdnn/dtype/half_common_prologue.h index 14fc79b1..99313d7b 100644 --- a/dnn/include/megdnn/dtype/half_common_prologue.h +++ b/dnn/include/megdnn/dtype/half_common_prologue.h @@ -3,17 +3,22 @@ * * Copyright (c) 2012-2013 Christian Rau * - * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, - * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following conditions: + * Permission is hereby granted, free of charge, to any person obtaining a copy of this + * software and associated documentation files (the "Software"), to deal in the Software + * without restriction, including without limitation the rights to use, copy, modify, + * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following + * conditions: * - * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + * The above copyright notice and this permission notice shall be included in all copies + * or substantial portions of the Software. * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE - * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, - * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A + * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF + * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE + * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. * * Version 1.11.0 * \file @@ -39,166 +44,164 @@ #include "megdnn/arch.h" /// Combined gcc version number. 
-#define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) +#define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) -//check C++11 language features -#if defined(__clang__) //clang - #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) - #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 - #endif - #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) - #define HALF_ENABLE_CPP11_CONSTEXPR 1 - #endif - #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) - #define HALF_ENABLE_CPP11_NOEXCEPT 1 - #endif - #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) - #define HALF_ENABLE_CPP11_USER_LITERALS 1 - #endif - #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) - #define HALF_ENABLE_CPP11_LONG_LONG 1 - #endif -/*#elif defined(__INTEL_COMPILER) //Intel C++ - #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? - #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 - #endif - #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? - #define HALF_ENABLE_CPP11_CONSTEXPR 1 - #endif - #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? - #define HALF_ENABLE_CPP11_NOEXCEPT 1 - #endif - #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? - #define HALF_ENABLE_CPP11_LONG_LONG 1 - #endif*/ -#elif defined(__GNUC__) //gcc - #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L - #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) - #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 - #endif - #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) - #define HALF_ENABLE_CPP11_CONSTEXPR 1 - #endif - #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) - #define HALF_ENABLE_CPP11_NOEXCEPT 1 - #endif - #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) - #define HALF_ENABLE_CPP11_USER_LITERALS 1 - #endif - #if !defined(HALF_ENABLE_CPP11_LONG_LONG) - #define HALF_ENABLE_CPP11_LONG_LONG 1 - #endif - #endif -#elif defined(_MSC_VER) //Visual C++ - #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) - #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 - #endif - #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) - #define HALF_ENABLE_CPP11_LONG_LONG 1 - #endif - #define HALF_POP_WARNINGS 1 - #pragma warning(push) - //! 4521 and 4522 is multiple copy/assigment operator specified - #pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned +// check C++11 language features +#if defined(__clang__) // clang +#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ + !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +/*#elif defined(__INTEL_COMPILER) + //Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) + ???????? 
#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >= + 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define + HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 && + !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define + HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 && + !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define + HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/ +#elif defined(__GNUC__) // gcc +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L +#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) +#define HALF_ENABLE_CPP11_CONSTEXPR 1 +#endif +#if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) +#define HALF_ENABLE_CPP11_NOEXCEPT 1 +#endif +#if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) +#define HALF_ENABLE_CPP11_USER_LITERALS 1 +#endif +#if !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#endif +#elif defined(_MSC_VER) // Visual C++ +#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) +#define HALF_ENABLE_CPP11_STATIC_ASSERT 1 +#endif +#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) +#define HALF_ENABLE_CPP11_LONG_LONG 1 +#endif +#define HALF_POP_WARNINGS 1 +#pragma warning(push) +//! 4521 and 4522 is multiple copy/assigment operator specified +#pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if, + // negative unsigned #endif -//check C++11 library features +// check C++11 library features #include -#if defined(_LIBCPP_VERSION) //libc++ - #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 - #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS - #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 - #endif - #ifndef HALF_ENABLE_CPP11_CSTDINT - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #ifndef HALF_ENABLE_CPP11_CMATH - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #ifndef HALF_ENABLE_CPP11_HASH - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #endif -#elif defined(__GLIBCXX__) //libstdc++ - #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 - #ifdef __clang__ - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) - #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 - #endif - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #else - #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #endif - #endif -#elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ - #if _CPPLIB_VER >= 520 - #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS - #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 - #endif - #ifndef HALF_ENABLE_CPP11_CSTDINT - #define HALF_ENABLE_CPP11_CSTDINT 1 - #endif - #ifndef HALF_ENABLE_CPP11_HASH - #define HALF_ENABLE_CPP11_HASH 1 - #endif - #endif - #if _CPPLIB_VER >= 610 - #ifndef HALF_ENABLE_CPP11_CMATH - #define HALF_ENABLE_CPP11_CMATH 1 - #endif - #endif +#if 
defined(_LIBCPP_VERSION) // libc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#ifndef HALF_ENABLE_CPP11_CSTDINT +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#ifndef HALF_ENABLE_CPP11_CMATH +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#ifndef HALF_ENABLE_CPP11_HASH +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#endif +#elif defined(__GLIBCXX__) // libstdc++ +#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 +#ifdef __clang__ +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#else +#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#endif +#endif +#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ +#if _CPPLIB_VER >= 520 +#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS +#define HALF_ENABLE_CPP11_TYPE_TRAITS 1 +#endif +#ifndef HALF_ENABLE_CPP11_CSTDINT +#define HALF_ENABLE_CPP11_CSTDINT 1 +#endif +#ifndef HALF_ENABLE_CPP11_HASH +#define HALF_ENABLE_CPP11_HASH 1 +#endif +#endif +#if _CPPLIB_VER >= 610 +#ifndef HALF_ENABLE_CPP11_CMATH +#define HALF_ENABLE_CPP11_CMATH 1 +#endif +#endif #endif #undef HALF_GNUC_VERSION -//support constexpr +// support constexpr #if HALF_ENABLE_CPP11_CONSTEXPR - #define HALF_CONSTEXPR constexpr - #define HALF_CONSTEXPR_CONST constexpr +#define HALF_CONSTEXPR constexpr +#define HALF_CONSTEXPR_CONST constexpr #else - #define HALF_CONSTEXPR - #define HALF_CONSTEXPR_CONST const +#define HALF_CONSTEXPR +#define HALF_CONSTEXPR_CONST const #endif -//support noexcept +// support noexcept #if HALF_ENABLE_CPP11_NOEXCEPT - #define HALF_NOEXCEPT noexcept - #define HALF_NOTHROW noexcept +#define HALF_NOEXCEPT noexcept +#define HALF_NOTHROW noexcept #else - #define HALF_NOEXCEPT - #define HALF_NOTHROW throw() +#define HALF_NOEXCEPT +#define HALF_NOTHROW throw() #endif #include -#include #include #include #include -#include #include +#include +#include #if HALF_ENABLE_CPP11_TYPE_TRAITS - #include +#include #endif #if HALF_ENABLE_CPP11_CSTDINT - #include +#include #endif #if HALF_ENABLE_CPP11_HASH - #include +#include #endif // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/handle.h b/dnn/include/megdnn/handle.h index f938d129..a0e07d6c 100644 --- a/dnn/include/megdnn/handle.h +++ b/dnn/include/megdnn/handle.h @@ -12,8 +12,8 @@ #pragma once #include "megcore.h" -#include "megdnn/config/config.h" #include "megdnn/basic_types.h" +#include "megdnn/config/config.h" #include #include @@ -24,150 +24,147 @@ namespace megdnn { class OperatorBase; class Handle { - public: - enum class HandleType { - NAIVE = 0, - FALLBACK = 1, - X86 = 2, - ARM_COMMON = 3, - ARMV7 = 4, - AARCH64 = 5, - CUDA = 6, - ROCM = 11, - ATLAS = 13, - CAMBRICON = 12, - }; - - //! 
Device vendor - enum class HandleVendorType : uint32_t { - NOT_SPEC = 0, - MALI = 1, - ADRENO = 2, - CUDA = 3, - INTEL = 4, - POWERVR = 5, - AMD = 6, - }; - - protected: - Handle(megcoreComputingHandle_t computing_handle, HandleType type); - - public: - /** - * \brief Create a MegDNN handle from a MegCore Computing handle. - * - * \param[in] computing_handle MegCore computing handle. Please note - * that computing_handle would not be released when this Handle is - * destructed - * \param[in] debug_level - * Applicable for CPU computing handle. - * 0 means taking the fastest possible code path; it may contains - * platform-specific instructions such as SSE for x86_64 or NEON for - * armv7v7. - * 1 means taking the fastest possible code path without - * platform-specific instructions in C++ code. Note that the compiled - * binary file still contains platform-specific codes. - * 2 means taking the naive code path. Performance is severely - * hampered, but it is less error-prone since the internal - * implementation is rather straightforward. - * - * **Debug level 1 and 2 should not be used in productions.** - */ - static std::unique_ptr make( - megcoreComputingHandle_t computing_handle, - int debug_level = 0); +public: + enum class HandleType { + NAIVE = 0, + FALLBACK = 1, + X86 = 2, + ARM_COMMON = 3, + ARMV7 = 4, + AARCH64 = 5, + CUDA = 6, + ROCM = 11, + ATLAS = 13, + CAMBRICON = 12, + }; + + //! Device vendor + enum class HandleVendorType : uint32_t { + NOT_SPEC = 0, + MALI = 1, + ADRENO = 2, + CUDA = 3, + INTEL = 4, + POWERVR = 5, + AMD = 6, + }; + +protected: + Handle(megcoreComputingHandle_t computing_handle, HandleType type); + +public: + /** + * \brief Create a MegDNN handle from a MegCore Computing handle. + * + * \param[in] computing_handle MegCore computing handle. Please note + * that computing_handle would not be released when this Handle is + * destructed + * \param[in] debug_level + * Applicable for CPU computing handle. + * 0 means taking the fastest possible code path; it may contains + * platform-specific instructions such as SSE for x86_64 or NEON for + * armv7v7. + * 1 means taking the fastest possible code path without + * platform-specific instructions in C++ code. Note that the compiled + * binary file still contains platform-specific codes. + * 2 means taking the naive code path. Performance is severely + * hampered, but it is less error-prone since the internal + * implementation is rather straightforward. + * + * **Debug level 1 and 2 should not be used in productions.** + */ + static std::unique_ptr make( + megcoreComputingHandle_t computing_handle, int debug_level = 0); #if MEGDNN_WITH_CUDA - static std::unique_ptr make_cuda_handle( - megcoreComputingHandle_t computing_handle); - template - std::unique_ptr create_cuda_operator(); + static std::unique_ptr make_cuda_handle( + megcoreComputingHandle_t computing_handle); + template + std::unique_ptr create_cuda_operator(); #endif #if MEGDNN_WITH_ROCM - static std::unique_ptr make_rocm_handle( - megcoreComputingHandle_t computing_handle); - template - std::unique_ptr create_rocm_operator(); + static std::unique_ptr make_rocm_handle( + megcoreComputingHandle_t computing_handle); + template + std::unique_ptr create_rocm_operator(); #endif - virtual ~Handle(); - - /*! - * \brief Get the underlying megcore computing handle. - */ - megcoreComputingHandle_t megcore_computing_handle() const { - return m_computing_handle; - } - - /*! 
- * \brief set a callback function to be invoked when this handle is - * destructed, so associated resources can be released (e.g. - * computing handle) - * - * This function can be called at most once. - */ - void set_destructor(const thin_function &d); - - /*! - * \brief set a callback to be invoked when an operator is destructed - * \param[in,out] cb the callback function; it would be set to the - * previous callback function - */ - void set_opr_destruct_callback(thin_function &cb) { - cb.swap(m_on_opr_destructed); - } - - void on_opr_destructed(OperatorBase* opr); - - /** - * \brief Create operator of Opr type. - */ - template - std::unique_ptr create_operator(); - - /* - * ============================================================= - * Users should call functions below to query memory requirement. - * ============================================================= - */ - - /** - * \brief The internal data pointer of TensorND should be aligned to - * alignment_requirement() in bytes. - */ - virtual size_t alignment_requirement() const; - - //! get alignment in bytes for rows of image 2D tensor format - virtual size_t image2d_pitch_alignment() const; - - //! get vendor type - virtual HandleVendorType vendor_type() const; - - HandleType type() const { - return m_handle_type; - } - - /** - * \brief Check is the layout satisfy cross device copy constraint. - * 1. The handle of the src and the dst is the same kind - * 2. The dst is continguous. - */ - virtual bool check_cross_dev_copy_constraint(const TensorLayout &src); - - private: - static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; - volatile uint32_t m_alive_magic = ALIVE_MAGIC; - megcoreComputingHandle_t m_computing_handle; - const HandleType m_handle_type; - thin_function m_destructor; - thin_function m_on_opr_destructed; - - Handle() = delete; - Handle(const Handle &rhs) = delete; - Handle &operator=(const Handle &rhs) = delete; + virtual ~Handle(); + + /*! + * \brief Get the underlying megcore computing handle. + */ + megcoreComputingHandle_t megcore_computing_handle() const { + return m_computing_handle; + } + + /*! + * \brief set a callback function to be invoked when this handle is + * destructed, so associated resources can be released (e.g. + * computing handle) + * + * This function can be called at most once. + */ + void set_destructor(const thin_function& d); + + /*! + * \brief set a callback to be invoked when an operator is destructed + * \param[in,out] cb the callback function; it would be set to the + * previous callback function + */ + void set_opr_destruct_callback(thin_function& cb) { + cb.swap(m_on_opr_destructed); + } + + void on_opr_destructed(OperatorBase* opr); + + /** + * \brief Create operator of Opr type. + */ + template + std::unique_ptr create_operator(); + + /* + * ============================================================= + * Users should call functions below to query memory requirement. + * ============================================================= + */ + + /** + * \brief The internal data pointer of TensorND should be aligned to + * alignment_requirement() in bytes. + */ + virtual size_t alignment_requirement() const; + + //! get alignment in bytes for rows of image 2D tensor format + virtual size_t image2d_pitch_alignment() const; + + //! get vendor type + virtual HandleVendorType vendor_type() const; + + HandleType type() const { return m_handle_type; } + + /** + * \brief Check is the layout satisfy cross device copy constraint. + * 1. The handle of the src and the dst is the same kind + * 2. 
The dst is continguous. + */ + virtual bool check_cross_dev_copy_constraint(const TensorLayout& src); + +private: + static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; + volatile uint32_t m_alive_magic = ALIVE_MAGIC; + megcoreComputingHandle_t m_computing_handle; + const HandleType m_handle_type; + thin_function m_destructor; + thin_function m_on_opr_destructed; + + Handle() = delete; + Handle(const Handle& rhs) = delete; + Handle& operator=(const Handle& rhs) = delete; }; -} // namespace megdnn +} // namespace megdnn #include "megdnn/internal/visibility_epilogue.h" diff --git a/dnn/include/megdnn/heuristic_cache.h b/dnn/include/megdnn/heuristic_cache.h index 1298ba85..75f75c67 100644 --- a/dnn/include/megdnn/heuristic_cache.h +++ b/dnn/include/megdnn/heuristic_cache.h @@ -49,8 +49,9 @@ public: mutable std::string m_input; public: - Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr, - size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0) + Key(Handle* opr_handle, Algorithm::OprType opr_type, + const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size, + const void* param_ptr = nullptr, size_t param_size = 0) : m_handle{opr_handle}, m_opr_type{static_cast(opr_type)}, m_inp_layouts_ptr{inp_layouts_ptr}, diff --git a/dnn/include/megdnn/internal/defs.h b/dnn/include/megdnn/internal/defs.h index 18960cb8..871bb58e 100644 --- a/dnn/include/megdnn/internal/defs.h +++ b/dnn/include/megdnn/internal/defs.h @@ -16,20 +16,19 @@ * \brief iterate through small (usually used) ndim values */ #define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ - cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__) + cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__) /*! * \brief iterate through large (rarely used) ndim values */ #define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ - cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \ - cb(7, ##__VA_ARGS__) + cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__) /*! * \brief iterate through all ndim values */ -#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ - MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \ - MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__) +#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) 
\ + MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \ + MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__) // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/internal/opr_header_prologue.h b/dnn/include/megdnn/internal/opr_header_prologue.h index 09f9d75b..e44b8026 100644 --- a/dnn/include/megdnn/internal/opr_header_prologue.h +++ b/dnn/include/megdnn/internal/opr_header_prologue.h @@ -11,14 +11,14 @@ // intentional no header guard here #include "megdnn/handle.h" -#include "megdnn/oprs/base.h" #include "megdnn/opr_param_defs.h" #include "megdnn/opr_result_defs.h" +#include "megdnn/oprs/base.h" #include "./visibility_prologue.h" -#include #include +#include #ifndef _megdnn_in #define _megdnn_in @@ -29,36 +29,37 @@ #endif #ifndef _megdnn_tensor_in -#define _megdnn_tensor_in const TensorND & +#define _megdnn_tensor_in const TensorND& #endif #ifndef _megdnn_tensor_out -#define _megdnn_tensor_out const TensorND & +#define _megdnn_tensor_out const TensorND& #endif #ifndef _megdnn_tensor_inout -#define _megdnn_tensor_inout const TensorND & +#define _megdnn_tensor_inout const TensorND& #endif #ifndef _megdnn_workspace -#define _megdnn_workspace const Workspace & +#define _megdnn_workspace const Workspace& #endif -#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ - public: \ - _opr_name(Handle *handle): _base_name(handle) {} \ +#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ +public: \ + _opr_name(Handle* handle) : _base_name(handle) {} #define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ - DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ - static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ - static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \ + DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ + static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ + static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; -#define DEF_OPR_PARAM(_pname) \ - public: \ - using Param = param::_pname; \ - Param& param() { return m_param; } \ - const Param& param() const { return m_param; } \ - protected: \ - Param m_param +#define DEF_OPR_PARAM(_pname) \ +public: \ + using Param = param::_pname; \ + Param& param() { return m_param; } \ + const Param& param() const { return m_param; } \ + \ +protected: \ + Param m_param // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/internal/visibility_epilogue.h b/dnn/include/megdnn/internal/visibility_epilogue.h index 98c84bf5..aeaf71cf 100644 --- a/dnn/include/megdnn/internal/visibility_epilogue.h +++ b/dnn/include/megdnn/internal/visibility_epilogue.h @@ -20,4 +20,3 @@ #endif // vim: syntax=cpp.doxygen - diff --git a/dnn/include/megdnn/opr_result_defs.h b/dnn/include/megdnn/opr_result_defs.h index f9451d65..7689dccc 100644 --- a/dnn/include/megdnn/opr_result_defs.h +++ b/dnn/include/megdnn/opr_result_defs.h @@ -16,25 +16,21 @@ namespace megdnn { namespace opr_result { - struct Checksum { - uint32_t checksum; - union { - int32_t iv; - float fv; - } last_val; - - bool operator == (const Checksum &rhs) const { - return checksum == rhs.checksum && - last_val.iv == rhs.last_val.iv; - } - - bool operator != (const Checksum &rhs) const { - return !operator==(rhs); - } - }; - -} // namespace opr_result -} // namespace megdnn - +struct Checksum { + uint32_t checksum; + union { + int32_t iv; + float fv; + } last_val; + + bool operator==(const Checksum& rhs) const { + return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv; + } + + bool operator!=(const Checksum& rhs) const { return !operator==(rhs); } +}; + +} // namespace opr_result +} // 
namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/oprs.h b/dnn/include/megdnn/oprs.h index 36f6b11a..30b97b7c 100644 --- a/dnn/include/megdnn/oprs.h +++ b/dnn/include/megdnn/oprs.h @@ -12,11 +12,11 @@ #include "megdnn/oprs/cv.h" #include "megdnn/oprs/general.h" +#include "megdnn/oprs/imgproc.h" +#include "megdnn/oprs/linalg.h" #include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn_int.h" -#include "megdnn/oprs/imgproc.h" #include "megdnn/oprs/utils.h" -#include "megdnn/oprs/linalg.h" template struct OprArityTrait; @@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1); #undef INST_ARITY - - // vim: syntax=cpp.doxygen diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index 3dbdfe67..e3352c3c 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -90,7 +90,7 @@ enum class AlgoDataType : uint32_t { INT8X8X16 = 1 << 4, INT16X16X32 = 1 << 5, INT4X4X16 = 1 << 6, - QINT4x4x32 = 1 << 7, + QINT4x4x32 = 1 << 7, }; /*! @@ -195,16 +195,16 @@ public: Handle::HandleType handle_type() const { return m_handle_type; } Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } - Info info() const { - return {desc(), attribute()}; - } + Info info() const { return {desc(), attribute()}; } template static void serialize_write_pod(const T& val, std::string& result) { - static_assert(std::is_trivially_copyable::value, - "type should be trivially copyable"); - static_assert(!std::is_pointer::value, - "serialize pointer is unsafe in eager execution mode"); + static_assert( + std::is_trivially_copyable::value, + "type should be trivially copyable"); + static_assert( + !std::is_pointer::value, + "serialize pointer is unsafe in eager execution mode"); result.append(reinterpret_cast(&val), sizeof(T)); } @@ -231,9 +231,8 @@ public: return ret; } - static std::string deserialize_read_pod(const std::string& data, - size_t offset = 0, - size_t size = 0) { + static std::string deserialize_read_pod( + const std::string& data, size_t offset = 0, size_t size = 0) { return std::string(data.data() + offset, size); } @@ -286,8 +285,8 @@ public: * \param layouts origin layouts of the parent opr * \param opr parent opr */ - virtual std::vector get_subopr_list(const TensorLayoutArray&, - const OperatorBase*) const { + virtual std::vector get_subopr_list( + const TensorLayoutArray&, const OperatorBase*) const { return {}; } @@ -333,9 +332,7 @@ public: ExecutionPolicy& execution_policy() { return m_execution_policy; } - const ExecutionPolicy& execution_policy() const { - return m_execution_policy; - } + const ExecutionPolicy& execution_policy() const { return m_execution_policy; } virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; @@ -355,8 +352,8 @@ public: using AlgoAttribute = detail::Algorithm::Attribute; //! 
get all possible algorithm decriptions for the specified layouts - std::vector get_all_algorithms_info(const TensorLayout& p0, - const TensorLayout& p1) { + std::vector get_all_algorithms_info( + const TensorLayout& p0, const TensorLayout& p1) { std::vector ret; for (auto&& algo : get_all_algorithms(p0, p1)) { ret.emplace_back(algo->info()); @@ -364,8 +361,8 @@ public: return ret; } - std::vector get_all_algorithms_info_safe(const TensorLayout& p0, - const TensorLayout& p1) { + std::vector get_all_algorithms_info_safe( + const TensorLayout& p0, const TensorLayout& p1) { std::vector ret; for (auto&& algo : get_all_algorithms_safe(p0, p1)) { ret.emplace_back(algo->info()); @@ -382,12 +379,11 @@ public: */ AlgorithmInfo get_algorithm_info_heuristic( const TensorLayout& p0, const TensorLayout& p1, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { - return get_algorithm_heuristic(p0, p1, workspace_limit_in_bytes, - positive_attr, negative_attr) + return get_algorithm_heuristic( + p0, p1, workspace_limit_in_bytes, positive_attr, negative_attr) ->info(); } @@ -408,8 +404,7 @@ protected: */ virtual Algorithm* get_algorithm_heuristic( const TensorLayout& p0, const TensorLayout& p1, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; }; @@ -423,9 +418,8 @@ public: using AlgoAttribute = detail::Algorithm::Attribute; //! get all possible algorithm decriptions for the specified layouts - std::vector get_all_algorithms_info(const TensorLayout& p0, - const TensorLayout& p1, - const TensorLayout& p2) { + std::vector get_all_algorithms_info( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { std::vector ret; for (auto&& algo : get_all_algorithms(p0, p1, p2)) { ret.emplace_back(algo->info()); @@ -433,9 +427,8 @@ public: return ret; } - std::vector get_all_algorithms_info_safe(const TensorLayout& p0, - const TensorLayout& p1, - const TensorLayout& p2) { + std::vector get_all_algorithms_info_safe( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { std::vector ret; for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { ret.emplace_back(algo->info()); @@ -451,14 +444,13 @@ public: * \p workspace_limit_in_bytes. */ AlgorithmInfo get_algorithm_info_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { - return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, - positive_attr, negative_attr) + return get_algorithm_heuristic( + p0, p1, p2, workspace_limit_in_bytes, positive_attr, + negative_attr) ->info(); } @@ -467,11 +459,9 @@ protected: //! 
get all possible algorithms for the specified layouts virtual std::vector get_all_algorithms( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2) = 0; + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; virtual std::vector get_all_algorithms_safe( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2) = 0; + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; /** * \brief Returns the best algorithm by heuristic. @@ -480,10 +470,8 @@ protected: * \p workspace_limit_in_bytes. */ virtual Algorithm* get_algorithm_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; }; @@ -497,10 +485,9 @@ public: using AlgoAttribute = detail::Algorithm::Attribute; //! get all possible algorithm decriptions for the specified layouts - std::vector get_all_algorithms_info(const TensorLayout& p0, - const TensorLayout& p1, - const TensorLayout& p2, - const TensorLayout& p3) { + std::vector get_all_algorithms_info( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3) { std::vector ret; for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { ret.emplace_back(algo->info()); @@ -508,10 +495,9 @@ public: return ret; } - std::vector get_all_algorithms_info_safe(const TensorLayout& p0, - const TensorLayout& p1, - const TensorLayout& p2, - const TensorLayout& p3) { + std::vector get_all_algorithms_info_safe( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3) { std::vector ret; for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { ret.emplace_back(algo->info()); @@ -527,14 +513,14 @@ public: * \p workspace_limit_in_bytes. */ AlgorithmInfo get_algorithm_info_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { - return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, - positive_attr, negative_attr) + return get_algorithm_heuristic( + p0, p1, p2, p3, workspace_limit_in_bytes, positive_attr, + negative_attr) ->info(); } @@ -543,11 +529,11 @@ protected: //! get all possible algorithms for the specified layouts virtual std::vector get_all_algorithms( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3) = 0; + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3) = 0; virtual std::vector get_all_algorithms_safe( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3) = 0; + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3) = 0; /** * \brief Returns the best algorithm by heuristic. @@ -556,10 +542,9 @@ protected: * \p workspace_limit_in_bytes. 
*/ virtual Algorithm* get_algorithm_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; }; @@ -573,11 +558,9 @@ public: using AlgoAttribute = detail::Algorithm::Attribute; //! get all possible algorithm decriptions for the specified layouts - std::vector get_all_algorithms_info(const TensorLayout& p0, - const TensorLayout& p1, - const TensorLayout& p2, - const TensorLayout& p3, - const TensorLayout& p4) { + std::vector get_all_algorithms_info( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4) { std::vector ret; for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { ret.emplace_back(algo->info()); @@ -585,11 +568,9 @@ public: return ret; } - std::vector get_all_algorithms_info_safe(const TensorLayout& p0, - const TensorLayout& p1, - const TensorLayout& p2, - const TensorLayout& p3, - const TensorLayout& p4) { + std::vector get_all_algorithms_info_safe( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4) { std::vector ret; for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { ret.emplace_back(algo->info()); @@ -605,16 +586,14 @@ public: * \p workspace_limit_in_bytes. */ AlgorithmInfo get_algorithm_info_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { - return get_algorithm_heuristic(p0, p1, p2, p3, p4, - workspace_limit_in_bytes, positive_attr, - negative_attr) + return get_algorithm_heuristic( + p0, p1, p2, p3, p4, workspace_limit_in_bytes, positive_attr, + negative_attr) ->info(); } @@ -622,14 +601,12 @@ protected: ~MultiAlgoOpr() = default; //! get all possible algorithms for the specified layouts - virtual std::vector get_all_algorithms( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4) = 0; + virtual std::vector get_all_algorithms( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4) = 0; virtual std::vector get_all_algorithms_safe( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4) = 0; + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4) = 0; /** * \brief Returns the best algorithm by heuristic. @@ -638,11 +615,9 @@ protected: * \p workspace_limit_in_bytes. 
*/ virtual Algorithm* get_algorithm_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; }; @@ -657,9 +632,8 @@ public: //! get all possible algorithm decriptions for the specified layouts std::vector get_all_algorithms_info( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, const TensorLayout& p6, const TensorLayout& p7) { std::vector ret; for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { @@ -669,9 +643,8 @@ public: } std::vector get_all_algorithms_info_safe( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, const TensorLayout& p6, const TensorLayout& p7) { std::vector ret; for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { @@ -687,17 +660,15 @@ public: * The selected algorithm should not use workspace more than */ AlgorithmInfo get_algorithm_info_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, const TensorLayout& p6, const TensorLayout& p7, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { - return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, - workspace_limit_in_bytes, positive_attr, - negative_attr) + return get_algorithm_heuristic( + p0, p1, p2, p3, p4, p5, p6, p7, workspace_limit_in_bytes, + positive_attr, negative_attr) ->info(); } @@ -705,15 +676,13 @@ protected: ~MultiAlgoOpr() = default; //! 
get all possible algorithms for the specified layouts - virtual std::vector get_all_algorithms( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, const TensorLayout& p5, + virtual std::vector get_all_algorithms( + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, const TensorLayout& p6, const TensorLayout& p7) = 0; virtual std::vector get_all_algorithms_safe( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, const TensorLayout& p6, const TensorLayout& p7) = 0; /** @@ -723,12 +692,10 @@ protected: * \p workspace_limit_in_bytes. */ virtual Algorithm* get_algorithm_heuristic( - const TensorLayout& p0, const TensorLayout& p1, - const TensorLayout& p2, const TensorLayout& p3, - const TensorLayout& p4, const TensorLayout& p5, + const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, + const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, const TensorLayout& p6, const TensorLayout& p7, - size_t workspace_limit_in_bytes = - std::numeric_limits::max(), + size_t workspace_limit_in_bytes = std::numeric_limits::max(), const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; }; diff --git a/dnn/include/megdnn/oprs/cv.h b/dnn/include/megdnn/oprs/cv.h index 0c05da32..0e63198b 100644 --- a/dnn/include/megdnn/oprs/cv.h +++ b/dnn/include/megdnn/oprs/cv.h @@ -31,15 +31,17 @@ class FlipForward : public FlipBase { DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Flip = FlipForward; @@ -56,15 +58,17 @@ class RotateForward : public RotateBase { DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Rotate = RotateForward; @@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase { DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); public: - 
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using ROICopy = ROICopyForward; @@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase { DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using CvtColor = CvtColorForward; @@ -130,8 +138,9 @@ public: using BorderMode = Param::BorderMode; protected: - void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, - const TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& trans, + const TensorLayout& dst); std::string param_msg() const; int get_real_coord(int p, int len); }; @@ -148,15 +157,17 @@ public: * \warning src, trans, border_value, dst should be contiguous * The size of trans is N * 2 * 3 */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& trans, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& trans, + const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& trans, - const TensorLayout& dst, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst, + size_t workspace_in_bytes); }; using WarpAffine = WarpAffineForward; @@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase { DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t 
workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using GaussianBlur = GaussianBlurForward; @@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase { DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Resize = ResizeForward; @@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase { DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); public: - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& mat) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& mat) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& mat, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& mat, + size_t workspace_in_bytes); }; /** @@ -251,29 +268,32 @@ public: using BorderMode = Param::BorderMode; protected: - void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, - const TensorLayout& dst); - void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, - TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); }; class RemapForward : public RemapBase { DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, - TensorLayout& dst); + void deduce_layout( + const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& map_xy, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& map_xy, - const TensorLayout& dst, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& dst, size_t workspace_in_bytes); }; using Remap = RemapForward; @@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase { DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); public: - virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual void exec( 
+ _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& map_xy, const TensorLayout& diff, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& map_xy, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class RemapBackwardMat : public RemapBase { DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& map_xy, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& diff, const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& map_xy, - const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& map_xy, + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; class SeparableFilterBase : public OperatorBase { @@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase { DEF_OPR_PARAM(SeparableFilter); protected: - void deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, TensorLayout& dst); - void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, - const TensorLayout& filter_y, - const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst); }; class SeparableFilterForward : public SeparableFilterBase { DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, - const TensorLayout& filter_y, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& filter_x, - const 
TensorLayout& filter_y, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst, + size_t workspace_in_bytes); }; using SeparableFilter = SeparableFilterForward; diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index d57bfcbb..c9af381f 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -26,69 +26,64 @@ namespace megdnn { * is true, e.g. float == float returns value of float). Output layout must be * contiguous. */ -class ElemwiseForward: public OperatorBase { +class ElemwiseForward : public OperatorBase { DEF_OPR_PARAM(Elemwise); DEF_OPR_IMPL(ElemwiseForward, OperatorBase, -1, 1); - public: - using Mode = Param::Mode; - - //! information about a mode - struct ModeTrait { - uint32_t arity; //!< number of inputs needed - bool commutable; //!< whether arity == 2 and inputs commutable - bool allow_int; //!< whether int inputs allowed - bool allow_float; //!< whether float inputs allowed - bool allow_bool; //!< whether bool inputs allowed - const char* name; //!< name of the mode - - - ModeTrait(): - arity(0), commutable(0), allow_int(0), allow_float(0), allow_bool(0), - name(NULL) - {} - - //! get trait from a mode; this function is thread safe - static const ModeTrait& from_mode(Mode mode); - }; - - //! get trait of current mode - const ModeTrait& mode_trait() const { - return ModeTrait::from_mode(m_param.mode); - } - - /** - * \param[in] src input tensor - * \param[out] dst output tensor - * - * src and dst should have the same shape; - * layouts should be contiguous; - * the underlying data pointer can point to the same memory region for - * src and dst. - */ - virtual void exec(_megdnn_in const TensorNDArray &src, - _megdnn_tensor_out dst) = 0; - - //! deduce output shape (do not check whether arity matches) - static void deduce_shape( - const TensorShapeArray &src, - TensorShape &dst); - - static void deduce_format(const TensorFormatArray& src, - TensorFormat& dst); - - //! deduce output layout - void deduce_layout(const TensorLayoutArray &src, - TensorLayout &dst); - - protected: - //! throw exception if incorrect layout; broadcast input shape to - //! output shape - void check_layout_and_broadcast( - const TensorLayoutPtrArray &src, const TensorLayout &dst); - - private: - void check_dtype(DType dtype); +public: + using Mode = Param::Mode; + + //! information about a mode + struct ModeTrait { + uint32_t arity; //!< number of inputs needed + bool commutable; //!< whether arity == 2 and inputs commutable + bool allow_int; //!< whether int inputs allowed + bool allow_float; //!< whether float inputs allowed + bool allow_bool; //!< whether bool inputs allowed + const char* name; //!< name of the mode + + ModeTrait() + : arity(0), + commutable(0), + allow_int(0), + allow_float(0), + allow_bool(0), + name(NULL) {} + + //! get trait from a mode; this function is thread safe + static const ModeTrait& from_mode(Mode mode); + }; + + //! get trait of current mode + const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); } + + /** + * \param[in] src input tensor + * \param[out] dst output tensor + * + * src and dst should have the same shape; + * layouts should be contiguous; + * the underlying data pointer can point to the same memory region for + * src and dst. + */ + virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; + + //! 
deduce output shape (do not check whether arity matches) + static void deduce_shape(const TensorShapeArray& src, TensorShape& dst); + + static void deduce_format(const TensorFormatArray& src, TensorFormat& dst); + + //! deduce output layout + void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); + +protected: + //! throw exception if incorrect layout; broadcast input shape to + //! output shape + void check_layout_and_broadcast( + const TensorLayoutPtrArray& src, const TensorLayout& dst); + +private: + void check_dtype(DType dtype); }; using Elemwise = ElemwiseForward; @@ -111,8 +106,7 @@ public: void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst); //! compatible API for mgb; workspace is not used - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace) { + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) { return exec(src, dst); } @@ -141,8 +135,9 @@ protected: * Note: \p exp_f and \p exp_i must be dereferenced before dispatching any * kernel. They are allocated on the caller's stack. */ - virtual void do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - const float* exp_f, const int* exp_i) = 0; + virtual void do_exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f, + const int* exp_i) = 0; }; /*! @@ -150,45 +145,46 @@ protected: * * dst and delta can have arbitrary layout but must have the same shape. */ -class AddUpdateForward: public OperatorBase { +class AddUpdateForward : public OperatorBase { DEF_OPR_PARAM(AddUpdate); DEF_OPR_IMPL(AddUpdateForward, OperatorBase, -1, 1); - public: - virtual void exec( - _megdnn_tensor_inout dst, _megdnn_tensor_in delta) = 0; +public: + virtual void exec(_megdnn_tensor_inout dst, _megdnn_tensor_in delta) = 0; - protected: - void check_exec(const TensorLayout &dst, const TensorLayout &delta); +protected: + void check_exec(const TensorLayout& dst, const TensorLayout& delta); }; using AddUpdate = AddUpdateForward; -class ReduceForward: public OperatorBase { +class ReduceForward : public OperatorBase { DEF_OPR_PARAM(Reduce); DEF_OPR_IMPL(ReduceForward, OperatorBase, 1, 1); - public: - using Mode = Param::Mode; - using DataType = Param::DataType; - - /** - * \param[in] src input tensor - * \param[out] dst output tensor - * - * src and dst should be contiguous. - * src and dst should be of the same shape for all dimensions except - * param().axis. - * the param().axis-th dimension shape for dst should be one. - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes); +public: + using Mode = Param::Mode; + using DataType = Param::DataType; + + /** + * \param[in] src input tensor + * \param[out] dst output tensor + * + * src and dst should be contiguous. + * src and dst should be of the same shape for all dimensions except + * param().axis. + * the param().axis-th dimension shape for dst should be one. 
+ */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Reduce = ReduceForward; @@ -197,10 +193,11 @@ class CorrelationBase : public OperatorBase { DEF_OPR_PARAM(Correlation); protected: - void deduce_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, - TensorLayout& dst); - void check_layout_fwd(const TensorLayout& data1, const TensorLayout& data2, - const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst); }; class CorrelationForward : public CorrelationBase { @@ -208,20 +205,23 @@ class CorrelationForward : public CorrelationBase { public: /** - * \param[in] data1 (n, c, ih, iw) + * \param[in] data1 (n, c, ih, iw) * \param[in] data2 (n, c, ih, iw) * \param[out] dst (n, q, oh, ow), q is the number of neighborhood - * */ - virtual void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& data1, const TensorLayout& data2, - TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& dst) = 0; + * */ + virtual void exec( + _megdnn_tensor_in data1, _megdnn_tensor_in data2, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst) = 0; + protected: - void check_exec(const TensorLayout& data1, const TensorLayout& data2, - const TensorLayout& dst, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst, size_t workspace_in_bytes); }; using Correlation = CorrelationForward; @@ -235,18 +235,21 @@ public: * \param[in] data2 the `data2' parameter in CorrelationForward::exec * \param[out] grad1 the backpropagated gradient wrt. 
data1 */ - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, - _megdnn_tensor_out grad1, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, - const TensorLayout& data2, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& grad1) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad1, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& diff1, const TensorLayout& data1, + const TensorLayout& data2, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& grad1) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, - const TensorLayout& grad1, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& grad1, + size_t workspace_in_bytes); }; class CorrelationBackwardData2 : public CorrelationBase { @@ -259,119 +262,125 @@ public: * \param[in] data2 the `data2' parameter in CorrelationForward::exec * \param[out] grad2 the backpropagated gradient wrt. data2 */ - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, - _megdnn_tensor_out grad2, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1, - const TensorLayout& data2, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& grad2) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad2, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& diff1, const TensorLayout& data1, + const TensorLayout& data2, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& grad2) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, - const TensorLayout& grad2, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& data1, + const TensorLayout& data2, const TensorLayout& grad2, + size_t workspace_in_bytes); }; -class CumsumForward: public OperatorBase { +class CumsumForward : public OperatorBase { DEF_OPR_PARAM(Cumsum); DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1); - public: - /** - * \param[in] src input tensor - * \param[out] dst output tensor - * - * src and dst should be contiguous. - * src and dst should have the same shape. - * - * The exclusive flag specifies whether the current element it taken - * into account when calculating results. - * - * The reverse flag specifies whether cumsum is forward ( - * from 0 to n) or backward (from n downto 0). - * - * Example: - * exclusive && reverse: - * dst_i = src_{i+1} + src_{i+2} + ... + src_{n-1} - * exclusive && !reverse - * dst_i = src_0 + src_1 + ... + src_{i-1} - * !exclusive && reverse: - * dst_i = src_i + src_{i+1} + ... + src_{n-1} - * !exclusive && !reverse: - * dst_i = src_0 + src_1 + ... 
+ src_i - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes); +public: + /** + * \param[in] src input tensor + * \param[out] dst output tensor + * + * src and dst should be contiguous. + * src and dst should have the same shape. + * + * The exclusive flag specifies whether the current element is taken + * into account when calculating results. + * + * The reverse flag specifies whether cumsum is forward ( + * from 0 to n) or backward (from n downto 0). + * + * Example: + * exclusive && reverse: + * dst_i = src_{i+1} + src_{i+2} + ... + src_{n-1} + * exclusive && !reverse + * dst_i = src_0 + src_1 + ... + src_{i-1} + * !exclusive && reverse: + * dst_i = src_i + src_{i+1} + ... + src_{n-1} + * !exclusive && !reverse: + * dst_i = src_0 + src_1 + ... + src_i + */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Cumsum = CumsumForward; // mxx can be max or min -class ArgmxxBase: public OperatorBase { +class ArgmxxBase : public OperatorBase { DEF_OPR_IMPL_CTOR(ArgmxxBase, OperatorBase); DEF_OPR_PARAM(Axis); - protected: - void check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst); +protected: + void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); }; -class ArgmaxForward: public ArgmxxBase { +class ArgmaxForward : public ArgmxxBase { DEF_OPR_IMPL(ArgmaxForward, ArgmxxBase, 1, 1); - public: - /** - * \param[in] src input tensor - * \param[out] dst output tensor containing the argmax indices - * - * src and dst should be contiguous. - * src and dst should be of the same shape for all dimensions except - * param().axis. - * the param().axis-th dimension shape for dst should be one. - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, - TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \param[in] src input tensor + * \param[out] dst output tensor containing the argmax indices + * + * src and dst should be contiguous. + * src and dst should be of the same shape for all dimensions except + * param().axis. + * the param().axis-th dimension shape for dst should be one.
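// --- Editorial sketch, not part of the original patch: the Argmax contract
// documented above, shown for the simplest 2-D case (param().axis == 1) with
// plain std::vector instead of the megdnn TensorND/workspace machinery; the
// helper name argmax_axis1 is illustrative only.
#include <cstddef>
#include <vector>

// For an (m, n) input, dst conceptually has shape (m, 1) and holds the index
// of the largest element in each row, matching "the param().axis-th dimension
// shape for dst should be one".
static std::vector<int> argmax_axis1(const std::vector<float>& src, size_t m, size_t n) {
    std::vector<int> dst(m);
    for (size_t r = 0; r < m; ++r) {
        size_t best = 0;
        for (size_t c = 1; c < n; ++c)
            if (src[r * n + c] > src[r * n + best])
                best = c;
        dst[r] = static_cast<int>(best);
    }
    return dst;
}
// For example, argmax_axis1({1, 5, 2, 9, 0, 3}, 2, 3) returns {1, 0}.
// ---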
+ */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Argmax = ArgmaxForward; -class ArgminForward: public ArgmxxBase { +class ArgminForward : public ArgmxxBase { DEF_OPR_IMPL(ArgminForward, ArgmxxBase, 1, 1); - public: - /** - * \param[in] src input tensor - * \param[out] dst output tensor containing the argmax indices - * - * src and dst should be contiguous. - * src and dst should be of the same shape for all dimensions except - * param().axis. - * the param().axis-th dimension shape for dst should be one. - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, - TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \param[in] src input tensor + * \param[out] dst output tensor containing the argmax indices + * + * src and dst should be contiguous. + * src and dst should be of the same shape for all dimensions except + * param().axis. + * the param().axis-th dimension shape for dst should be one. + */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Argmin = ArgminForward; @@ -397,34 +406,37 @@ public: virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0; - virtual Output exec(_megdnn_tensor_in data, _megdnn_tensor_in mask, - _megdnn_workspace workspace, - DynOutMallocPolicyCall malloc_policy) = 0; + virtual Output exec( + _megdnn_tensor_in data, _megdnn_tensor_in mask, _megdnn_workspace workspace, + DynOutMallocPolicyCall malloc_policy) = 0; protected: //! 
check input layouts and get flattened size - size_t check_exec_get_size(const TensorLayout& data, - const TensorLayout& mask, - size_t workspace_in_bytes); + size_t check_exec_get_size( + const TensorLayout& data, const TensorLayout& mask, + size_t workspace_in_bytes); }; -class TransposeForward: public OperatorBase { +class TransposeForward : public OperatorBase { DEF_OPR_IMPL(TransposeForward, OperatorBase, 1, 1); DEF_OPR_PARAM(Empty); - public: - /** - * \param[in] src (m, n) stride[0] >= n && stride[1] == 1 - * \param[out] dst (n, m) stride[0] >= m && stride[1] == 1 - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \param[in] src (m, n) stride[0] >= n && stride[1] == 1 + * \param[out] dst (n, m) stride[0] >= m && stride[1] == 1 + */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Transpose = TransposeForward; @@ -440,107 +452,108 @@ using Transpose = TransposeForward; * More contiguous the input/output layouts, higher performance. There is also * special optimization for broadcast case. */ -class RelayoutForward: public OperatorBase { +class RelayoutForward : public OperatorBase { DEF_OPR_IMPL(RelayoutForward, OperatorBase, 1, 1); DEF_OPR_PARAM(Empty); - public: - /*! - * \brief execute relayout opr - * - * This operator should be placed on the same computing device of *dst*. - * - * \param src_handle handle of input tensor; for CUDA d2d copy, the - * src handle can be on a different GPU for copy tensor with - * non-contig dims <= 2 - */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - Handle *src_handle = nullptr) = 0; - protected: - //! check layout and collapse contiguous - void check_layout_and_canonize( - TensorLayout &src, TensorLayout &dst); + +public: + /*! + * \brief execute relayout opr + * + * This operator should be placed on the same computing device of *dst*. + * + * \param src_handle handle of input tensor; for CUDA d2d copy, the + * src handle can be on a different GPU for copy tensor with + * non-contig dims <= 2 + */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + Handle* src_handle = nullptr) = 0; + +protected: + //! check layout and collapse contiguous + void check_layout_and_canonize(TensorLayout& src, TensorLayout& dst); }; using Relayout = RelayoutForward; /** * \brief Base class for Concat and Split operators */ -class ConcatSplitBase: public OperatorBase { - public: - using Param = param::Axis; - - ConcatSplitBase(Handle *handle); - const Param ¶m() const { return m_param; } - Param ¶m() { return m_param; } - protected: - void check_layout_common(const TensorLayoutArray &srcs, - const TensorLayout &dst); - Param m_param; - /** - * \brief a helper function - * - * A = shape[0] * shape[1] * ... * shape[axis-1] - * B = {srcs[0].shape[axis], srcs[1].shape[axis], ...} - * C = shape[axis+1] * shape[axis+2] * ... 
* shape[ndim-1] - */ - void get_ABC(const TensorShapeArray &srcs, - size_t &A, - size_t *B, - size_t &C); - thin_function m_get_layout; - thin_function m_get_shape; +class ConcatSplitBase : public OperatorBase { +public: + using Param = param::Axis; + + ConcatSplitBase(Handle* handle); + const Param& param() const { return m_param; } + Param& param() { return m_param; } + +protected: + void check_layout_common(const TensorLayoutArray& srcs, const TensorLayout& dst); + Param m_param; + /** + * \brief a helper function + * + * A = shape[0] * shape[1] * ... * shape[axis-1] + * B = {srcs[0].shape[axis], srcs[1].shape[axis], ...} + * C = shape[axis+1] * shape[axis+2] * ... * shape[ndim-1] + */ + void get_ABC(const TensorShapeArray& srcs, size_t& A, size_t* B, size_t& C); + thin_function m_get_layout; + thin_function m_get_shape; }; -class ConcatForward: public ConcatSplitBase { +class ConcatForward : public ConcatSplitBase { DEF_OPR_IMPL(ConcatForward, ConcatSplitBase, 1, 1); - public: - /** - * \param[in] srcs a vector containing all inputs to be concatenated - * \param[out] dst the output tensor. - * - * All tensors in srcs and dst should be contiguous. - * All tensors should have the same shape for all axes except - * param().axis. - * For the param().axis-th axis, the axis shape for dst should be the - * sum of corresponding axis shapes for all srcs. - */ - virtual void exec(_megdnn_in const TensorNDArray &srcs, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayoutArray &srcs, - TensorLayout &dst); - virtual size_t get_workspace_in_bytes( - const TensorLayoutArray &srcs, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayoutArray &srcs, - const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \param[in] srcs a vector containing all inputs to be concatenated + * \param[out] dst the output tensor. + * + * All tensors in srcs and dst should be contiguous. + * All tensors should have the same shape for all axes except + * param().axis. + * For the param().axis-th axis, the axis shape for dst should be the + * sum of corresponding axis shapes for all srcs. + */ + virtual void exec( + _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayoutArray& srcs, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayoutArray& srcs, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Concat = ConcatForward; -class SplitForward: public ConcatSplitBase { +class SplitForward : public ConcatSplitBase { DEF_OPR_IMPL(SplitForward, ConcatSplitBase, 1, 1); - public: - /** - * \param[in] src input tensor - * \param[out] dsts a vector containing all splitted result - * - * All tensors in src and dsts should be contiguous. - * All tensors should have the same shape for all axes except - * param().axis. - * For the param().axis-th axis, the axis shape for src should be the - * sum of corresponding axis shapes for all dsts. 
- */ - virtual void exec(_megdnn_tensor_in src, - const TensorNDArray &dsts, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayoutArray &dsts) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayoutArray &dsts, - size_t workspace_in_bytes); + +public: + /** + * \param[in] src input tensor + * \param[out] dsts a vector containing all splitted result + * + * All tensors in src and dsts should be contiguous. + * All tensors should have the same shape for all axes except + * param().axis. + * For the param().axis-th axis, the axis shape for src should be the + * sum of corresponding axis shapes for all dsts. + */ + virtual void exec( + _megdnn_tensor_in src, const TensorNDArray& dsts, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayoutArray& dsts) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayoutArray& dsts, + size_t workspace_in_bytes); }; using Split = SplitForward; @@ -555,17 +568,17 @@ using Split = SplitForward; */ class ParamPackConcatSplitBase : public OperatorBase { protected: - void check_exec(const TensorLayout& concated, const TensorLayout& offsets, - const TensorLayout& parts); + void check_exec( + const TensorLayout& concated, const TensorLayout& offsets, + const TensorLayout& parts); public: using Param = megdnn::param::Empty; ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {} //! generate offsets to be used with ParamPackConcat and ParamPackSplit - static std::vector gen_offsets(const TensorShapeArray& shapes, - size_t alignment, - size_t dtype_size); + static std::vector gen_offsets( + const TensorShapeArray& shapes, size_t alignment, size_t dtype_size); }; /** @@ -573,7 +586,7 @@ public: * Combine multiple gradient tensors into a single large tensor, use copy * strategy due to AddUpdate or other dynamic situation. */ -class ParamPackConcat: public ParamPackConcatSplitBase { +class ParamPackConcat : public ParamPackConcatSplitBase { DEF_OPR_IMPL(ParamPackConcat, ParamPackConcatSplitBase, 2, 1); public: @@ -585,209 +598,209 @@ public: * the begin and the end of srcs[i]'s offsets in dst * \param[out] dst: output TensorND, live on cpu or gpu */ - virtual void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in offsets, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in srcs, _megdnn_tensor_in offsets, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorShapeArray& srcs, - const TensorShape& offsets, - const TensorShape& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorShapeArray& srcs, const TensorShape& offsets, + const TensorShape& dst) = 0; }; /** * \brief base class for Tile and Repeat */ -class TileRepeatBase: public OperatorBase { - public: - TileRepeatBase(Handle *handle): OperatorBase(handle) {} - struct Param { - TensorShape times; - }; - Param ¶m() { return m_param; } - const Param ¶m() const { return m_param; } - protected: - void check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst); - void deduce_layout_fwd(const TensorLayout &src, - TensorLayout &dst); - /** - * Assuming src/dst/times are already simplified on entrance. 
- */ - size_t get_workspace_in_bytes_fwd(const TensorShape &src, - const TensorShape &dst, - const TensorShape ×, - DType dtype); - Param m_param; +class TileRepeatBase : public OperatorBase { +public: + TileRepeatBase(Handle* handle) : OperatorBase(handle) {} + struct Param { + TensorShape times; + }; + Param& param() { return m_param; } + const Param& param() const { return m_param; } + +protected: + void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); + void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); + /** + * Assuming src/dst/times are already simplified on entrance. + */ + size_t get_workspace_in_bytes_fwd( + const TensorShape& src, const TensorShape& dst, const TensorShape& times, + DType dtype); + Param m_param; }; -class TileBase: public TileRepeatBase { - public: - TileBase(Handle *handle): TileRepeatBase(handle) {} - protected: - void simplify_shape(const TensorShape &src, - const TensorShape &dst, - const TensorShape ×, - TensorShape &src2, - TensorShape &dst2, - TensorShape ×2); - /** - * This is a helper function that would facilitate other backends' - * implementation. - */ - size_t get_workspace_in_bytes_fwd(const TensorLayout &src, - const TensorLayout &dst); +class TileBase : public TileRepeatBase { +public: + TileBase(Handle* handle) : TileRepeatBase(handle) {} + +protected: + void simplify_shape( + const TensorShape& src, const TensorShape& dst, const TensorShape& times, + TensorShape& src2, TensorShape& dst2, TensorShape& times2); + /** + * This is a helper function that would facilitate other backends' + * implementation. + */ + size_t get_workspace_in_bytes_fwd(const TensorLayout& src, const TensorLayout& dst); }; -class TileForward: public TileBase { +class TileForward : public TileBase { DEF_OPR_IMPL(TileForward, TileBase, 1, 1); - public: - /** - * \brief Tile src times to get dst. - * \param[in] src input tensor - * \param[out] dst output tensor - * \param[out] workspace temporary workspace - * - * src and dst must be contiguous. - * dst.shape should be {src.shape[0]*param().times[0], - * src.shape[1]*param().times[1], ...} - * - * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html - * - * Difference between Tile and Repeat: - * Tiling `abc' twice yields `abcabc', whereas repeating `abc' twice - * yields `aabbcc'. - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, - TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \brief Tile src times to get dst. + * \param[in] src input tensor + * \param[out] dst output tensor + * \param[out] workspace temporary workspace + * + * src and dst must be contiguous. + * dst.shape should be {src.shape[0]*param().times[0], + * src.shape[1]*param().times[1], ...} + * + * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.tile.html + * + * Difference between Tile and Repeat: + * Tiling `abc' twice yields `abcabc', whereas repeating `abc' twice + * yields `aabbcc'. 
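// --- Editorial sketch, not part of the original patch: the Tile-vs-Repeat
// distinction described in the comment above, reduced to 1-D std::vector form;
// tile1d and repeat1d are illustrative helpers, not megdnn API.
#include <cstddef>
#include <vector>

// Tiling {a, b, c} twice yields {a, b, c, a, b, c}.
static std::vector<char> tile1d(const std::vector<char>& src, size_t times) {
    std::vector<char> dst;
    for (size_t t = 0; t < times; ++t)
        dst.insert(dst.end(), src.begin(), src.end());
    return dst;
}

// Repeating {a, b, c} twice yields {a, a, b, b, c, c}.
static std::vector<char> repeat1d(const std::vector<char>& src, size_t times) {
    std::vector<char> dst;
    for (char v : src)
        dst.insert(dst.end(), times, v);
    return dst;
}
// ---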
+ */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Tile = TileForward; -class TileBackward: public TileBase { +class TileBackward : public TileBase { DEF_OPR_IMPL(TileBackward, TileBase, 1, 1); - public: - /** - * \param[in] diff the backpropagated gradient wrt. dst - * \param[out] grad the backpropagated gradient wrt. src - * \param[out] workspace temporary workspace - */ - virtual void exec(_megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &diff, - const TensorLayout &grad) = 0; - protected: - void check_exec(const TensorLayout &diff, const TensorLayout &grad, - size_t workspace_in_bytes); + +public: + /** + * \param[in] diff the backpropagated gradient wrt. dst + * \param[out] grad the backpropagated gradient wrt. src + * \param[out] workspace temporary workspace + */ + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& grad) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; -class RepeatBase: public TileRepeatBase { - public: - RepeatBase(Handle *handle): TileRepeatBase(handle) {} - protected: - void simplify_shape(const TensorShape &src, - const TensorShape &dst, - const TensorShape ×, - TensorShape &src2, - TensorShape &dst2, - TensorShape ×2); - /** - * This is a helper function that would facilitate other backends' - * implementation. - */ - size_t get_workspace_in_bytes_fwd(const TensorLayout &src, - const TensorLayout &dst); +class RepeatBase : public TileRepeatBase { +public: + RepeatBase(Handle* handle) : TileRepeatBase(handle) {} + +protected: + void simplify_shape( + const TensorShape& src, const TensorShape& dst, const TensorShape& times, + TensorShape& src2, TensorShape& dst2, TensorShape& times2); + /** + * This is a helper function that would facilitate other backends' + * implementation. + */ + size_t get_workspace_in_bytes_fwd(const TensorLayout& src, const TensorLayout& dst); }; -class RepeatForward: public RepeatBase { +class RepeatForward : public RepeatBase { DEF_OPR_IMPL(RepeatForward, RepeatBase, 1, 1); - public: - /** - * \brief Repeat src times to get dst. - * \param[in] src input tensor - * \param[out] dst output tensor - * \param[out] workspace temporary workspace - * - * src and dst must be contiguous. - * dst.shape should be {src.shape[0]*param().times[0], - * src.shape[1]*param().times[1], ...} - * - * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html - * \see TileForward - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, - TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \brief Repeat src times to get dst. 
+ * \param[in] src input tensor + * \param[out] dst output tensor + * \param[out] workspace temporary workspace + * + * src and dst must be contiguous. + * dst.shape should be {src.shape[0]*param().times[0], + * src.shape[1]*param().times[1], ...} + * + * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.repeat.html + * \see TileForward + */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Repeat = RepeatForward; -class RepeatBackward: public RepeatBase { +class RepeatBackward : public RepeatBase { DEF_OPR_IMPL(RepeatBackward, RepeatBase, 1, 1); - public: - /** - * \param[in] diff the backpropagated gradient wrt. dst - * \param[out] grad the backpropagated gradient wrt. src - * \param[out] workspace temporary workspace - */ - virtual void exec(_megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &diff, - const TensorLayout &grad) = 0; - protected: - void check_exec(const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes); + +public: + /** + * \param[in] diff the backpropagated gradient wrt. dst + * \param[out] grad the backpropagated gradient wrt. src + * \param[out] workspace temporary workspace + */ + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& grad) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; -class ArgsortForward: public OperatorBase { +class ArgsortForward : public OperatorBase { DEF_OPR_IMPL(ArgsortForward, OperatorBase, 1, 2); DEF_OPR_PARAM(Argsort); - public: - using Order = Param::Order; - /** - * \param[in] src (m, n) - * \param[out] dst (m, n) - * \param[out] indices (m, n) - * - * src, dst and indices should be contiguous. - * Performing m independent sorting on m arrays of length n. - * Sorting arrays and storing the resulting array in `dst', - * and the corresponding indices in `indices'. - * - * Indices range from 0 to n-1. - * - * Note that indices is a TensorND of type int. - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_tensor_out indices, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, - TensorLayout &dst, - TensorLayout &indices); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst, - const TensorLayout &indices) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &dst, - const TensorLayout &indices, - size_t workspace_in_bytes); + +public: + using Order = Param::Order; + /** + * \param[in] src (m, n) + * \param[out] dst (m, n) + * \param[out] indices (m, n) + * + * src, dst and indices should be contiguous. + * Performing m independent sorting on m arrays of length n. + * Sorting arrays and storing the resulting array in `dst', + * and the corresponding indices in `indices'. + * + * Indices range from 0 to n-1. + * + * Note that indices is a TensorND of type int. 
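// --- Editorial sketch, not part of the original patch: the row-wise argsort
// semantics documented above, written against std::vector and std::sort rather
// than TensorND; the ascending flag stands in for Param::Order, and the helper
// name argsort_rows is illustrative only.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// For an (m, n) src, fills dst with each row sorted and indices with the
// source position of every sorted element (values in [0, n)).
static void argsort_rows(
        const std::vector<float>& src, size_t m, size_t n, bool ascending,
        std::vector<float>& dst, std::vector<int>& indices) {
    dst.resize(m * n);
    indices.resize(m * n);
    for (size_t r = 0; r < m; ++r) {
        int* idx = indices.data() + r * n;
        std::iota(idx, idx + n, 0);  // start from identity permutation
        const float* row = src.data() + r * n;
        std::sort(idx, idx + n, [&](int a, int b) {
            return ascending ? row[a] < row[b] : row[a] > row[b];
        });
        for (size_t c = 0; c < n; ++c)
            dst[r * n + c] = row[idx[c]];
    }
}
// ---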
+ */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, TensorLayout& dst, TensorLayout& indices); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& indices) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& indices, size_t workspace_in_bytes); }; using Argsort = ArgsortForward; @@ -810,15 +823,17 @@ public: * * Constraint: n >= k. Untouched values would be initialized as zero. */ - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& indices, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& indices, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& indices, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& indices, + const TensorLayout& grad, size_t workspace_in_bytes); }; class TopK : public OperatorBase { @@ -827,9 +842,9 @@ class TopK : public OperatorBase { protected: //! impl exec; inputs have been validated - virtual void do_exec(int k, _megdnn_tensor_in data, - _megdnn_tensor_out values, int32_t* indices, - _megdnn_workspace workspace) = 0; + virtual void do_exec( + int k, _megdnn_tensor_in data, _megdnn_tensor_out values, int32_t* indices, + _megdnn_workspace workspace) = 0; public: /*! @@ -843,149 +858,153 @@ public: * \param[out] indices () or (m, ) or (m, k) output values; its shape * depends on mode */ - void exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values, - _megdnn_tensor_out indices, _megdnn_workspace workspace); - virtual size_t get_workspace_in_bytes(int k, const TensorLayout& data, - const TensorLayout& values, - const TensorLayout& indices) = 0; - - void deduce_layout(int k, const TensorLayout& data, TensorLayout& values, - TensorLayout& indices); + void exec( + int k, _megdnn_tensor_in data, _megdnn_tensor_out values, + _megdnn_tensor_out indices, _megdnn_workspace workspace); + virtual size_t get_workspace_in_bytes( + int k, const TensorLayout& data, const TensorLayout& values, + const TensorLayout& indices) = 0; + + void deduce_layout( + int k, const TensorLayout& data, TensorLayout& values, + TensorLayout& indices); }; /*! * \brief convert dtype of *src* to match dtype of *dst*; *src* may have * arbitrary layout and *dst* must be contiguous. 
*/ -class TypeCvtForward: public OperatorBase { +class TypeCvtForward : public OperatorBase { DEF_OPR_PARAM(Empty); DEF_OPR_IMPL(TypeCvtForward, OperatorBase, 1, 1); - public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst); + +public: + virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0; + +protected: + void check_exec(const TensorLayout& src, const TensorLayout& dst); }; using TypeCvt = TypeCvtForward; -class IndexingRemapBase: public OperatorBase { - public: - using Param = param::IndexingRemap; - - IndexingRemapBase(Handle *handle): OperatorBase(handle) {} - Param ¶m() { return m_param; } - const Param ¶m() const { return m_param; } - protected: - Param m_param; - void check_layout_fwd(const TensorLayout &src, - const TensorLayout &map, - const TensorLayout &dst); +class IndexingRemapBase : public OperatorBase { +public: + using Param = param::IndexingRemap; + + IndexingRemapBase(Handle* handle) : OperatorBase(handle) {} + Param& param() { return m_param; } + const Param& param() const { return m_param; } + +protected: + Param m_param; + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& map, const TensorLayout& dst); }; -class IndexingRemapForward: public IndexingRemapBase { +class IndexingRemapForward : public IndexingRemapBase { DEF_OPR_IMPL(IndexingRemapForward, IndexingRemapBase, 2, 1); - public: - /** - * \param[in] src input tensor - * \param[in] map input map - * \param[out] dst output tensor - * - * Suppose: - * the shape of src is \f$(s_0, s_1, ..., s_{m-1}\f$; - * the shape of dst is \f$(d_0, d_1, ..., d_{n-1})\f$; - * then: - * the shape of map must be \f$(d_0, d_1, ..., d_{n-1}, m)\f$. - * - * The last dimension of map indicates the src indices for the - * corresponding dst entry. - * - * src and dst can be non-contiguous in a non-overlapping manner. - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_in map, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout &src, - const TensorLayout &map, - TensorLayout &dst); - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &map, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &map, - const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \param[in] src input tensor + * \param[in] map input map + * \param[out] dst output tensor + * + * Suppose: + * the shape of src is \f$(s_0, s_1, ..., s_{m-1}\f$; + * the shape of dst is \f$(d_0, d_1, ..., d_{n-1})\f$; + * then: + * the shape of map must be \f$(d_0, d_1, ..., d_{n-1}, m)\f$. + * + * The last dimension of map indicates the src indices for the + * corresponding dst entry. + * + * src and dst can be non-contiguous in a non-overlapping manner. + */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in map, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& map, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& map, + const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& map, const TensorLayout& dst, + size_t workspace_in_bytes); }; using IndexingRemap = IndexingRemapForward; // The using directives preserve backward compatibility. 
using TensorRemapForward = IndexingRemap; using TensorRemap = TensorRemapForward; -class IndexingRemapBackward: public IndexingRemapBase { +class IndexingRemapBackward : public IndexingRemapBase { DEF_OPR_IMPL(IndexingRemapBackward, IndexingRemapBase, 2, 1); - public: - /** - * \param[in] diff the backpropagated gradient wrt. dst - * \param[in] map the `map' parameter in IndexingRemapForward::exec - * \param[out] grad the backpropagated gradient wrt. src - */ - virtual void exec(_megdnn_tensor_in diff, - _megdnn_tensor_in map, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &diff, - const TensorLayout &map, - const TensorLayout &grad) = 0; - protected: - void check_exec(const TensorLayout &diff, - const TensorLayout &map, - const TensorLayout &grad, - size_t workspace_in_bytes); + +public: + /** + * \param[in] diff the backpropagated gradient wrt. dst + * \param[in] map the `map' parameter in IndexingRemapForward::exec + * \param[out] grad the backpropagated gradient wrt. src + */ + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in map, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& map, + const TensorLayout& grad) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& map, const TensorLayout& grad, + size_t workspace_in_bytes); }; // The using directives preserve backward compatibility. using TensorRemapBackward = IndexingRemapBackward; -class Linspace: public OperatorBase { +class Linspace : public OperatorBase { DEF_OPR_IMPL(Linspace, OperatorBase, 0, 1); DEF_OPR_PARAM(LinspaceFull); - public: - /** - * \param[out] dst must be 1d. - * - * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html - */ - virtual void exec(_megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); + +public: + /** + * \param[out] dst must be 1d. 
+ * + * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html + */ + virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; + +protected: + void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; -class Eye: public OperatorBase { +class Eye : public OperatorBase { DEF_OPR_IMPL(Eye, OperatorBase, 0, 1); DEF_OPR_PARAM(Eye); - public: - /** - * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.eye.html - */ - virtual void exec(_megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); + +public: + /** + * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.eye.html + */ + virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; + +protected: + void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; -class IndexingOneHotBase: public OperatorBase { +class IndexingOneHotBase : public OperatorBase { DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase); DEF_OPR_PARAM(Axis); - protected: - void deduce_layout_fwd(const TensorLayout &src, - const TensorLayout &index, - TensorLayout &dst); - void check_layout_fwd(const TensorLayout &src, - const TensorLayout &index, - const TensorLayout &dst); +protected: + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& index, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& index, + const TensorLayout& dst); }; /*! @@ -1000,26 +1019,26 @@ class IndexingOneHotBase: public OperatorBase { * \param[in] index (n-1)-dimensional index, must be int * \param[out] dst n-dimensional output data */ -class IndexingOneHotForward: public IndexingOneHotBase { +class IndexingOneHotForward : public IndexingOneHotBase { DEF_OPR_IMPL(IndexingOneHotForward, IndexingOneHotBase, 2, 1); - public: - void deduce_layout(const TensorLayout &src, - const TensorLayout &index, TensorLayout &dst) { - deduce_layout_fwd(src, index, dst); - } - - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_in index, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &index, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &index, const TensorLayout &dst, - size_t workspace_in_bytes); +public: + void deduce_layout( + const TensorLayout& src, const TensorLayout& index, TensorLayout& dst) { + deduce_layout_fwd(src, index, dst); + } + + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in index, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& index, + const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& index, const TensorLayout& dst, + size_t workspace_in_bytes); }; using IndexingOneHot = IndexingOneHotForward; @@ -1031,19 +1050,21 @@ using IndexingOneHot = IndexingOneHotForward; * \param[in] index (n-1)-dimensional index, must be int * \param[in] sub n-dimensional sub tensor to be filled in *data* */ -class IndexingSetOneHotForward: public IndexingOneHotBase { +class IndexingSetOneHotForward : public IndexingOneHotBase 
{ DEF_OPR_IMPL(IndexingSetOneHotForward, IndexingOneHotBase, -1, 1); - public: - virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in index, - _megdnn_tensor_in sub, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &data, - const TensorLayout &index, - const TensorLayout &sub) = 0; - protected: - void check_exec(const TensorLayout &data, - const TensorLayout &index, const TensorLayout &sub, - size_t workspace_in_bytes); +public: + virtual void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in index, _megdnn_tensor_in sub, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& data, const TensorLayout& index, + const TensorLayout& sub) = 0; + +protected: + void check_exec( + const TensorLayout& data, const TensorLayout& index, + const TensorLayout& sub, size_t workspace_in_bytes); }; using IndexingSetOneHot = IndexingSetOneHotForward; @@ -1052,71 +1073,70 @@ using IndexingSetOneHot = IndexingSetOneHotForward; * * Note that the indexing axes are required to be sorted in ascending order */ -class IndexingMultiAxisVecBase: public OperatorBase { +class IndexingMultiAxisVecBase : public OperatorBase { DEF_OPR_IMPL_CTOR(IndexingMultiAxisVecBase, OperatorBase); DEF_OPR_PARAM(Empty); - public: - struct AxisIndexer { - size_t axis; - TensorND vec; - }; - - struct AxisIndexerLayoutOnly { - size_t axis; - TensorLayout layout; - }; - - using IndexDesc = std::vector; - using IndexDescLayoutOnly = std::vector; - - /*! - * \brief convert IndexDesc to IndexDescLayoutOnly - */ - static IndexDescLayoutOnly extract_index_layout(const IndexDesc &index); - - /*! - * \brief get the axes on src that are not used in index - * \param[out] out output buffer; suggested size is - * TensorLayout::MAX_NDIM - * \return number of elements written to *out* - */ - static size_t get_nonindex_axes(size_t src_ndim, const IndexDesc &index, - size_t *out); - - /*! - * \brief get contiguous-collapsed layout for indexing on value - * \param idx_axis indexer axis on value (i.e. ExecInfo::idx_axis) - * \return a tensor layout and an axis to iterate over *value* and also - * access *data*; stride of layout on that axis would be zero, and - * strides on other axes correspond to the strides in *data* - */ - static std::pair get_value_iter_optimized_layout( - const TensorLayout &data, const TensorLayout &value, - const IndexDesc &index, size_t idx_axis); - - //! helper info for kernel implementation - struct ExecInfo { - //! axis in value used by indexer - size_t idx_axis; - ptrdiff_t value_stride; - - void* error_tracker; - megcore::AsyncErrorInfo* error_info; - }; - - protected: - /*! - * \return axis on dst used by indexer (i.e. ExecInfo::idx_axis) - */ - static size_t deduce_layout_fwd( - const TensorLayout &data, - const IndexDescLayoutOnly &index, - TensorLayout &dst); - - static ExecInfo check_exec_noworkspace( - const TensorLayout &data, const TensorLayout &value, - const IndexDesc &index, IndexDescLayoutOnly &index_layout); +public: + struct AxisIndexer { + size_t axis; + TensorND vec; + }; + + struct AxisIndexerLayoutOnly { + size_t axis; + TensorLayout layout; + }; + + using IndexDesc = std::vector; + using IndexDescLayoutOnly = std::vector; + + /*! + * \brief convert IndexDesc to IndexDescLayoutOnly + */ + static IndexDescLayoutOnly extract_index_layout(const IndexDesc& index); + + /*! 
+ * \brief get the axes on src that are not used in index + * \param[out] out output buffer; suggested size is + * TensorLayout::MAX_NDIM + * \return number of elements written to *out* + */ + static size_t get_nonindex_axes( + size_t src_ndim, const IndexDesc& index, size_t* out); + + /*! + * \brief get contiguous-collapsed layout for indexing on value + * \param idx_axis indexer axis on value (i.e. ExecInfo::idx_axis) + * \return a tensor layout and an axis to iterate over *value* and also + * access *data*; stride of layout on that axis would be zero, and + * strides on other axes correspond to the strides in *data* + */ + static std::pair get_value_iter_optimized_layout( + const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, + size_t idx_axis); + + //! helper info for kernel implementation + struct ExecInfo { + //! axis in value used by indexer + size_t idx_axis; + ptrdiff_t value_stride; + + void* error_tracker; + megcore::AsyncErrorInfo* error_info; + }; + +protected: + /*! + * \return axis on dst used by indexer (i.e. ExecInfo::idx_axis) + */ + static size_t deduce_layout_fwd( + const TensorLayout& data, const IndexDescLayoutOnly& index, + TensorLayout& dst); + + static ExecInfo check_exec_noworkspace( + const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, + IndexDescLayoutOnly& index_layout); }; /*! @@ -1124,36 +1144,32 @@ class IndexingMultiAxisVecBase: public OperatorBase { * * src can have arbitrary layout, but dst must be dim1-contig */ -class IndexingMultiAxisVec: public IndexingMultiAxisVecBase { +class IndexingMultiAxisVec : public IndexingMultiAxisVecBase { DEF_OPR_IMPL(IndexingMultiAxisVec, IndexingMultiAxisVecBase, 0, 1); - public: - virtual void exec(_megdnn_tensor_in src, - const IndexDesc &index, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - - /*! - * \brief get workspace size based on output shape and indexing axes - */ - size_t get_workspace_in_bytes( - const TensorShape &dst, - const size_t *axes, size_t nr_axes); - - static void deduce_layout( - const TensorLayout &data, - const IndexDescLayoutOnly &index, - TensorLayout &dst) { - deduce_layout_fwd(data, index, dst); - } - protected: - - virtual size_t get_workspace_in_bytes(size_t dst_idx_size) = 0; - - ExecInfo check_exec( - const TensorLayout &src, - const IndexDesc &index, - const TensorLayout &dst, - size_t workspace_in_bytes); +public: + virtual void exec( + _megdnn_tensor_in src, const IndexDesc& index, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + + /*! + * \brief get workspace size based on output shape and indexing axes + */ + size_t get_workspace_in_bytes( + const TensorShape& dst, const size_t* axes, size_t nr_axes); + + static void deduce_layout( + const TensorLayout& data, const IndexDescLayoutOnly& index, + TensorLayout& dst) { + deduce_layout_fwd(data, index, dst); + } + +protected: + virtual size_t get_workspace_in_bytes(size_t dst_idx_size) = 0; + + ExecInfo check_exec( + const TensorLayout& src, const IndexDesc& index, const TensorLayout& dst, + size_t workspace_in_bytes); }; /*! 
@@ -1161,42 +1177,37 @@ class IndexingMultiAxisVec: public IndexingMultiAxisVecBase { * * data can have arbitrary layout, but value must be dim1-contig */ -class IndexingModifyMultiAxisVecBase: public IndexingMultiAxisVecBase { +class IndexingModifyMultiAxisVecBase : public IndexingMultiAxisVecBase { DEF_OPR_IMPL_CTOR(IndexingModifyMultiAxisVecBase, IndexingMultiAxisVecBase); - public: - virtual void exec( - _megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc &index, - _megdnn_workspace workspace) = 0; - - /*! - * \brief get workspace size based on shape of value input and indexing - * axes - */ - size_t get_workspace_in_bytes( - const TensorShape &value, - const size_t *axes, size_t nr_axes); - - protected: - ExecInfo check_exec( - const TensorLayout &data, const TensorLayout &value, - const IndexDesc &index, - size_t workspace_in_bytes); - - virtual size_t get_workspace_in_bytes(size_t value_idx_size) = 0; +public: + virtual void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index, + _megdnn_workspace workspace) = 0; + + /*! + * \brief get workspace size based on shape of value input and indexing + * axes + */ + size_t get_workspace_in_bytes( + const TensorShape& value, const size_t* axes, size_t nr_axes); + +protected: + ExecInfo check_exec( + const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, + size_t workspace_in_bytes); + + virtual size_t get_workspace_in_bytes(size_t value_idx_size) = 0; }; //! set value to indexed locations; index values must be non-overlapping -class IndexingSetMultiAxisVec: public IndexingModifyMultiAxisVecBase { - DEF_OPR_IMPL(IndexingSetMultiAxisVec, - IndexingModifyMultiAxisVecBase, 0, 0); +class IndexingSetMultiAxisVec : public IndexingModifyMultiAxisVecBase { + DEF_OPR_IMPL(IndexingSetMultiAxisVec, IndexingModifyMultiAxisVecBase, 0, 0); }; //! 
add value to indexed locations; index values must be non-overlapping -class IndexingIncrMultiAxisVec: public IndexingModifyMultiAxisVecBase { - DEF_OPR_IMPL(IndexingIncrMultiAxisVec, - IndexingModifyMultiAxisVecBase, 0, 0); +class IndexingIncrMultiAxisVec : public IndexingModifyMultiAxisVecBase { + DEF_OPR_IMPL(IndexingIncrMultiAxisVec, IndexingModifyMultiAxisVecBase, 0, 0); }; class MeshBase : public OperatorBase { @@ -1206,8 +1217,7 @@ class MeshBase : public OperatorBase { public: using AxisIndexer = IndexingMultiAxisVecBase::AxisIndexer; using IndexDesc = IndexingMultiAxisVecBase::IndexDesc; - using AxisIndexerLayoutOnly = - IndexingMultiAxisVecBase::AxisIndexerLayoutOnly; + using AxisIndexerLayoutOnly = IndexingMultiAxisVecBase::AxisIndexerLayoutOnly; using IndexDescLayoutOnly = IndexingMultiAxisVecBase::IndexDescLayoutOnly; size_t get_workspace_in_bytes(const TensorShape&, const size_t*, size_t) { @@ -1215,54 +1225,58 @@ public: } protected: - virtual void check_exec(const TensorLayout& origin, - const TensorLayout& indexed, const IndexDesc& desc); + virtual void check_exec( + const TensorLayout& origin, const TensorLayout& indexed, + const IndexDesc& desc); }; class NormalMeshBase : public MeshBase { DEF_OPR_IMPL(NormalMeshBase, MeshBase, 0, 0); protected: - virtual void check_exec(const TensorLayout& origin, - const TensorLayout& indexed, - const IndexDesc& desc) override final; + virtual void check_exec( + const TensorLayout& origin, const TensorLayout& indexed, + const IndexDesc& desc) override final; }; class NormalMeshModifyBase : public NormalMeshBase { DEF_OPR_IMPL_CTOR(NormalMeshModifyBase, NormalMeshBase); public: - virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc& desc, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace workspace) = 0; }; class BatchedMeshBase : public MeshBase { DEF_OPR_IMPL_CTOR(BatchedMeshBase, MeshBase); protected: - virtual void check_exec(const TensorLayout& origin, - const TensorLayout& indexed, - const IndexDesc& desc) override final; + virtual void check_exec( + const TensorLayout& origin, const TensorLayout& indexed, + const IndexDesc& desc) override final; }; class BatchedMeshModifyBase : public BatchedMeshBase { DEF_OPR_IMPL_CTOR(BatchedMeshModifyBase, BatchedMeshBase); public: - virtual void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc& desc, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace workspace) = 0; }; class MeshIndexing : public NormalMeshBase { DEF_OPR_IMPL(MeshIndexing, NormalMeshBase, 0, 0); public: - virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; - static void deduce_layout(const TensorLayout& inp, - const IndexDescLayoutOnly& layouts, - TensorLayout& out_layout); + static void deduce_layout( + const TensorLayout& inp, const IndexDescLayoutOnly& layouts, + TensorLayout& out_layout); }; class IncrMeshIndexing : public NormalMeshModifyBase { @@ -1277,13 +1291,13 @@ class BatchedMeshIndexing : public BatchedMeshBase { DEF_OPR_IMPL(BatchedMeshIndexing, BatchedMeshBase, 0, 0); public: - virtual void exec(_megdnn_tensor_in src, const IndexDesc& desc, - 
_megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; - static void deduce_layout(const TensorLayout& inp, - const IndexDescLayoutOnly& layouts, - TensorLayout& out_layout); + static void deduce_layout( + const TensorLayout& inp, const IndexDescLayoutOnly& layouts, + TensorLayout& out_layout); }; class BatchedIncrMeshIndexing : public BatchedMeshModifyBase { @@ -1299,103 +1313,114 @@ class RelayoutFormat : public OperatorBase { DEF_OPR_IMPL(RelayoutFormat, OperatorBase, 1, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); void deduce_format(TensorFormat src, TensorFormat& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: void deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst); void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); - void deduce_exec_layout(const TensorLayout& src, const TensorLayout& dst, - TensorLayout& exec_workspace, - TensorLayout& exec_src, TensorLayout& exec_dst); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); + void deduce_exec_layout( + const TensorLayout& src, const TensorLayout& dst, + TensorLayout& exec_workspace, TensorLayout& exec_src, + TensorLayout& exec_dst); }; /*! * \brief check whether input contains inf or nan value. */ -class CheckNonFinite: public OperatorBase { +class CheckNonFinite : public OperatorBase { DEF_OPR_PARAM(Empty); DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1); - public: - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) = 0; +public: + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); + void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; - protected: - void check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes); +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; /*! * \brief fill the tensor with a scalar value */ -class Fill: public OperatorBase { +class Fill : public OperatorBase { DEF_OPR_PARAM(Fill); DEF_OPR_IMPL(Fill, OperatorBase, 0, 1); public: - virtual void exec(_megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; + protected: void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; /*! 
* \brief standard padding operator - * Inputs must have the same dtype, and the output tensor shape must greater or equal than - * input tensor in every dimensions, the extra space will be fulled with m which default to - * be 0. + * Inputs must have the same dtype, and every dimension of the output tensor shape + * must be greater than or equal to the corresponding dimension of the input tensor; + * the extra space is filled with m, which defaults to 0. */ -class PaddingBase: public OperatorBase { +class PaddingBase : public OperatorBase { DEF_OPR_PARAM(Padding); DEF_OPR_IMPL(PaddingBase, OperatorBase, 1, 1); + public: using Mode = Param::PaddingMode; + protected: SmallVector get_offsets(); void check_exec(const TensorLayout& src, const TensorLayout& dst); }; -class PaddingForward: public PaddingBase { +class PaddingForward : public PaddingBase { DEF_OPR_IMPL(PaddingForward, PaddingBase, 1, 1); + public: virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace) { + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) { return exec(src, dst); } - virtual size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst) = 0; - void deduce_layout(const TensorLayout &src, TensorLayout &dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + void deduce_layout(const TensorLayout& src, TensorLayout& dst); + protected: void forward_check_exec(const TensorLayout& src, const TensorLayout& dst); }; using Padding = PaddingForward; -class PaddingBackward: public PaddingBase { +class PaddingBackward : public PaddingBase { DEF_OPR_IMPL(PaddingBackward, PaddingBase, 1, 1); + public: virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) = 0; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace) { + void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) { return exec(src, dst); } - virtual size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; + protected: void backward_check_exec(const TensorLayout& src, const TensorLayout& dst); }; diff --git a/dnn/include/megdnn/oprs/imgproc.h b/dnn/include/megdnn/oprs/imgproc.h index 71b61566..be0a9474 100644 --- a/dnn/include/megdnn/oprs/imgproc.h +++ b/dnn/include/megdnn/oprs/imgproc.h @@ -13,173 +13,162 @@ namespace megdnn { -class WarpPerspectiveBase: public OperatorBase { +class WarpPerspectiveBase : public OperatorBase { DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); DEF_OPR_PARAM(WarpPerspective); - public: - using InterpolationMode = Param::InterpolationMode; - using BorderMode = Param::BorderMode; - - protected: - void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, - const TensorLayout &dst) { - check_layout_fwd(src, mat, {}, dst); - } - - void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, - const TensorLayout &mat_idx, const TensorLayout &dst); - std::string param_msg() const; - int get_real_coord(int p, int len); + +public: + using InterpolationMode = Param::InterpolationMode; + using BorderMode = Param::BorderMode; + +protected: + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { + check_layout_fwd(src, mat, {}, dst); + } + + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& mat_idx,
const TensorLayout& dst); + std::string param_msg() const; + int get_real_coord(int p, int len); }; -class WarpPerspectiveForward: public WarpPerspectiveBase { +class WarpPerspectiveForward : public WarpPerspectiveBase { DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); - public: - /** - * \param[in] src (n, channel, in_height, in_width) - * \param[in] mat (n, 3, 3) - * \param[out] dst (n, channel, out_height, out_width) - * - * \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine - * - * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] - * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, - * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) - * - * src and dst can have different shapes, as long as their n and c agree. - * src, mat and dst should be contiguous. - */ - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - exec(src, mat, {}, dst, workspace); - } - - /** - * \p src should have batch size m, and \p mat and \p mat_idx should - * both have batch size n. Each item in \p mat_idx must be in the range - * of [0, m-1]. - * - * \param mat_idx the indices of input image that each matrix in \p mat - * should act on. It can also be empty and in such case \p mat - * should have the same batch size as \p src. - */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &dst) { - return get_workspace_in_bytes(src, mat, {}, dst); - } - - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &dst, - size_t workspace_in_bytes); - - void check_exec_allow_nhwc_mat_idx(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + /** + * \param[in] src (n, channel, in_height, in_width) + * \param[in] mat (n, 3, 3) + * \param[out] dst (n, channel, out_height, out_width) + * + * \see + * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine + * + * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] + * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, + * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) + * + * src and dst can have different shapes, as long as their n and c agree. + * src, mat and dst should be contiguous. + */ + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + exec(src, mat, {}, dst, workspace); + } + + /** + * \p src should have batch size m, and \p mat and \p mat_idx should + * both have batch size n. Each item in \p mat_idx must be in the range + * of [0, m-1]. + * + * \param mat_idx the indices of input image that each matrix in \p mat + * should act on. It can also be empty and in such case \p mat + * should have the same batch size as \p src. 
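// The projective mapping documented above can be checked in isolation. The
// snippet below is a minimal standalone sketch (my own helper names, not part
// of the MegDNN API): given a 3x3 matrix `mat` and a destination pixel (h, w),
// it evaluates the source coordinates that WarpPerspectiveForward samples,
// reading the first coordinate as the source row and the second as the column.
#include <array>
#include <cstdio>

static std::array<float, 2> warp_src_coord(const float mat[3][3], float h, float w) {
    // denominator = mat[2][0]*w + mat[2][1]*h + mat[2][2]
    float denominator = mat[2][0] * w + mat[2][1] * h + mat[2][2];
    // dst(h, w) samples src at the two coordinates below.
    float src_h = (mat[1][0] * w + mat[1][1] * h + mat[1][2]) / denominator;
    float src_w = (mat[0][0] * w + mat[0][1] * h + mat[0][2]) / denominator;
    return {src_h, src_w};
}

int main() {
    // A pure translation matrix: dst(h, w) should sample src(h + 2, w + 3).
    const float mat[3][3] = {{1, 0, 3}, {0, 1, 2}, {0, 0, 1}};
    auto c = warp_src_coord(mat, /*h=*/5, /*w=*/7);
    std::printf("src_h=%.1f src_w=%.1f\n", c[0], c[1]);  // src_h=7.0 src_w=10.0
    return 0;
}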
+ */ + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, + _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { + return get_workspace_in_bytes(src, mat, {}, dst); + } + + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& mat_idx, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& mat_idx, const TensorLayout& dst, + size_t workspace_in_bytes); + + void check_exec_allow_nhwc_mat_idx( + const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& mat_idx, const TensorLayout& dst, + size_t workspace_in_bytes); }; using WarpPerspective = WarpPerspectiveForward; -class WarpPerspectiveBackwardData: public WarpPerspectiveBase { +class WarpPerspectiveBackwardData : public WarpPerspectiveBase { DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); - public: - /** - * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec - * \param[in] diff the backpropagated gradient wrt. dst - * \param[out] grad the backpropagated gradient wrt. src - * \param[out] workspace temporary workspace to perform backward - */ - void exec(_megdnn_tensor_in mat, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { - exec(mat, {}, diff, grad, workspace); - } - - virtual void exec(_megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - - size_t get_workspace_in_bytes(const TensorLayout &mat, - const TensorLayout &diff, - const TensorLayout &grad) { - return get_workspace_in_bytes(mat, {}, diff, grad); - } - - virtual size_t get_workspace_in_bytes(const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &diff, - const TensorLayout &grad) = 0; - protected: - void check_exec(const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes); + +public: + /** + * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec + * \param[in] diff the backpropagated gradient wrt. dst + * \param[out] grad the backpropagated gradient wrt. 
src + * \param[out] workspace temporary workspace to perform backward + */ + void exec( + _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + exec(mat, {}, diff, grad, workspace); + } + + virtual void exec( + _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + + size_t get_workspace_in_bytes( + const TensorLayout& mat, const TensorLayout& diff, + const TensorLayout& grad) { + return get_workspace_in_bytes(mat, {}, diff, grad); + } + + virtual size_t get_workspace_in_bytes( + const TensorLayout& mat, const TensorLayout& mat_idx, + const TensorLayout& diff, const TensorLayout& grad) = 0; + +protected: + void check_exec( + const TensorLayout& mat, const TensorLayout& mat_idx, + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; -class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { +class WarpPerspectiveBackwardMat : public WarpPerspectiveBase { DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); - public: - /** - * \param[in] src the `src' parameter in WarpPerspectiveForward::exec - * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec - * \param[in] diff the backpropagated gradient wrt. dst - * \param[out] grad the backpropagated gradient wrt. mat - * \param[out] workspace temporary workspace to perform backward - */ - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { - exec(src, mat, {}, diff, grad, workspace); - } - - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &diff, - const TensorLayout &grad) { - return get_workspace_in_bytes(src, mat, {}, diff, grad); - } - - virtual size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &diff, - const TensorLayout &grad) = 0; - protected: - void check_exec(const TensorLayout &src, - const TensorLayout &mat, - const TensorLayout &mat_idx, - const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes); + +public: + /** + * \param[in] src the `src' parameter in WarpPerspectiveForward::exec + * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec + * \param[in] diff the backpropagated gradient wrt. dst + * \param[out] grad the backpropagated gradient wrt. 
mat + * \param[out] workspace temporary workspace to perform backward + */ + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) { + exec(src, mat, {}, diff, grad, workspace); + } + + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff, + const TensorLayout& grad) { + return get_workspace_in_bytes(src, mat, {}, diff, grad); + } + + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& mat_idx, const TensorLayout& diff, + const TensorLayout& grad) = 0; + +protected: + void check_exec( + const TensorLayout& src, const TensorLayout& mat, + const TensorLayout& mat_idx, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class DctChannelSelectForward : public OperatorBase { @@ -194,37 +183,32 @@ public: * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor * \param[out] workspace temporary workspace to perform forward */ - virtual void exec(_megdnn_tensor_in src, - _megdnn_tensor_in mask_offset, - _megdnn_tensor_in mask_val, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - - void deduce_layout(const TensorLayout& src, - const TensorLayout& mask_offset, - const TensorLayout& mask_val, - TensorLayout& dst); - - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& mask_offset, - const TensorLayout& mask_val, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mask_offset, + _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + + void deduce_layout( + const TensorLayout& src, const TensorLayout& mask_offset, + const TensorLayout& mask_val, TensorLayout& dst); + + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& mask_offset, + const TensorLayout& mask_val, const TensorLayout& dst) = 0; protected: - void check_layout_fwd(const TensorLayout& src, - const TensorLayout& mask_offset, - const TensorLayout& mask_val, - const TensorLayout& dst); - - void deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& mask_offset, - const TensorLayout& mask_val, - TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& mask_offset, + const TensorLayout& mask_val, const TensorLayout& dst); + + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& mask_offset, + const TensorLayout& mask_val, TensorLayout& dst); std::string param_msg() const; }; -} // namespace megdnn +} // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/include/megdnn/oprs/linalg.h b/dnn/include/megdnn/oprs/linalg.h index baf69d3b..54931275 100644 --- a/dnn/include/megdnn/oprs/linalg.h +++ b/dnn/include/megdnn/oprs/linalg.h @@ -33,22 +33,22 @@ public: * op(A) = A if transposeA is false, otherwise op(A) = A^t. * op(B) = B if transposeB is false, otherwise op(B) = B^t. 
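// A minimal standalone sketch of how the transposeA/transposeB flags above are
// usually interpreted (an illustration, not the MegDNN implementation): the
// output C takes its rows from op(A) and its columns from op(B), and the inner
// dimensions of op(A) and op(B) must agree.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <utility>

struct Shape2 {
    size_t rows, cols;
};

// Deduce the shape of C = op(A) * op(B) for 2-D operands.
static Shape2 matmul_out_shape(Shape2 a, Shape2 b, bool transposeA, bool transposeB) {
    if (transposeA) std::swap(a.rows, a.cols);
    if (transposeB) std::swap(b.rows, b.cols);
    assert(a.cols == b.rows && "inner dimensions of op(A) and op(B) must agree");
    return {a.rows, b.cols};
}

int main() {
    // A: (2, 3) with transposeA -> op(A): (3, 2); B: (2, 4) -> C: (3, 4).
    Shape2 c = matmul_out_shape({2, 3}, {2, 4}, /*transposeA=*/true, /*transposeB=*/false);
    std::printf("C = (%zu, %zu)\n", c.rows, c.cols);
    return 0;
}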
*/ - virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, - _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; - void deduce_dtype(DType A, DType B, DType &C); - void deduce_layout(const TensorLayout& A, const TensorLayout& B, - TensorLayout& C); - virtual size_t get_workspace_in_bytes(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) = 0; + virtual void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) = 0; + void deduce_dtype(DType A, DType B, DType& C); + void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); + virtual size_t get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; } protected: - void check_exec(const TensorLayout& A, const TensorLayout& B, - const TensorLayout& C, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes); }; using BatchedMatrixMul = BatchedMatrixMulForward; @@ -70,24 +70,24 @@ public: * op(A) = A if transposeA is false, otherwise op(A) = A^t. * op(B) = B if transposeB is false, otherwise op(B) = B^t. */ - virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, - _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) = 0; void deduce_dtype(DType A, DType B, DType& C); - void deduce_layout(const TensorLayout& A, const TensorLayout& B, - TensorLayout& C); - virtual size_t get_workspace_in_bytes(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) = 0; + void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); + virtual size_t get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; - static size_t pack_size (const Param::Format format); + static size_t pack_size(const Param::Format format); static Algorithm::OprType get_opr_type() { return Algorithm::OprType::MATRIX_MUL_FORWARD; } protected: - void check_exec(const TensorLayout& A, const TensorLayout& B, - const TensorLayout& C, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes); }; using MatrixMul = MatrixMulForward; @@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase { DEF_OPR_PARAM(Empty); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst); + size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst); protected: /*! @@ -116,8 +116,7 @@ protected: * * Note that \p batch and \p n can be null */ - static void canonize_params(const TensorLayout& layout, size_t* batch, - size_t* n); + static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n); /*! 
* \brief canonize and validate input params for exec() impls @@ -125,11 +124,12 @@ protected: * Since get_workspace_in_bytes() would be called, \p batch and \p n can not * be null */ - void check_exec(const TensorLayout& src, const TensorLayout& dst, - _megdnn_workspace workspace, size_t* batch, size_t* n); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + _megdnn_workspace workspace, size_t* batch, size_t* n); - virtual size_t get_workspace_in_bytes(size_t batch, size_t n, - size_t dtype_size) = 0; + virtual size_t get_workspace_in_bytes( + size_t batch, size_t n, size_t dtype_size) = 0; }; //! inter-product of two vectors @@ -147,17 +147,17 @@ public: * A, B, C must be contiguous. A and B must have the same 1-dimensional * shape and non-negative strides. C must be scalar. */ - virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, - _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& A, const TensorLayout& B, - TensorLayout& C); - virtual size_t get_workspace_in_bytes(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) = 0; + virtual void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) = 0; + void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); + virtual size_t get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; protected: - void check_exec(const TensorLayout& A, const TensorLayout& B, - const TensorLayout& C, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes); }; using Dot = DotForward; @@ -193,23 +193,24 @@ public: * if compute_uv is false (default to true). 
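// A standalone numeric check of the singular values themselves (a hand-rolled
// 2x2 example, not the MegDNN SVD API; the usual reading of a compute_uv=false
// flag is that only `s` is produced and `u`/`vt` are skipped): the singular
// values of A are the square roots of the eigenvalues of A^T * A.
#include <cmath>
#include <cstdio>

int main() {
    // A = [[3, 0], [4, 5]]  =>  A^T * A = [[25, 20], [20, 25]].
    const double a = 25, b = 20, c = 20, d = 25;
    const double tr = a + d, det = a * d - b * c;
    const double disc = std::sqrt(tr * tr / 4 - det);
    const double s0 = std::sqrt(tr / 2 + disc);  // ~6.7082
    const double s1 = std::sqrt(tr / 2 - disc);  // ~2.2361
    std::printf("singular values: %.4f %.4f\n", s0, s1);
    return 0;
}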
* */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u, - _megdnn_tensor_out s, _megdnn_tensor_out vt, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, TensorLayout& u, - TensorLayout& s, TensorLayout& vt); - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& u, const TensorLayout& s, - const TensorLayout& vt); + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s, + _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, TensorLayout& u, TensorLayout& s, + TensorLayout& vt); + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, + const TensorLayout& vt); protected: - static void canonize_params(const TensorLayout& layout, size_t* batch, - size_t* m, size_t* n); - virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n, - size_t dtype_size) = 0; - void check_exec(const TensorLayout& src, const TensorLayout& u, - const TensorLayout& s, const TensorLayout& vt, - size_t workspace_in_bytes); + static void canonize_params( + const TensorLayout& layout, size_t* batch, size_t* m, size_t* n); + virtual size_t get_workspace_in_bytes( + size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0; + void check_exec( + const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, + const TensorLayout& vt, size_t workspace_in_bytes); }; using SVD = SVDForward; diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index ca6920a7..e5ea399c 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -22,32 +22,34 @@ public: using Mode = Param::Mode; protected: - void deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, TensorLayout& dst); - void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, - const TensorLayout& filter_y, - const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst); }; class SeparableConvForward : public SeparableConvBase { DEF_OPR_IMPL(SeparableConvForward, SeparableConvBase, 3, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, - const TensorLayout& filter_y, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& filter_x, - const TensorLayout& filter_y, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const 
TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst, + size_t workspace_in_bytes); }; using SeparableConv = SeparableConvForward; @@ -147,12 +149,12 @@ public: protected: // Check or deduce output DType void check_or_deduce_dtype_fwd(DType src, DType filter, DType& dst) const; - CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& filter, - TensorLayout& dst) const; - CanonizedFilterMeta check_layout_fwd(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) const; + CanonizedFilterMeta deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + TensorLayout& dst) const; + CanonizedFilterMeta check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) const; CanonizedFilterMeta make_canonized_filter_meta( size_t src_ndim, const TensorLayout& filter) const; @@ -163,10 +165,11 @@ class MaskPropagate : public OperatorBase { DEF_OPR_PARAM(MaskPropagate); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); }; @@ -178,24 +181,23 @@ class MaskConvForward : public ConvolutionBase { DEF_OPR_IMPL(MaskConvForward, ConvolutionBase, 3, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in mask, _megdnn_tensor_out dst, - _megdnn_workspace worksapce) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& mask, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in mask, + _megdnn_tensor_out dst, _megdnn_workspace worksapce) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& mask, const TensorLayout& dst) = 0; void deduce_dtype(DType src, DType filter, DType mask, DType& dst); - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& mask, TensorLayout& dst); + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& mask, TensorLayout& dst); protected: - CanonizedFilterMeta check_exec(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& mask, - const TensorLayout& dst, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& mask, const TensorLayout& dst, + size_t workspace_in_bytes); }; using MaskConvolution = MaskConvForward; @@ -219,25 +221,24 @@ public: * satisfies the situation that weights is preprocessed * \param[out] dst (n, oc, oh, ow) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - const PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + const PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) = 0; /** * \brief execute weight preprocessing, read weights form filter and write * to preprocessed_filter after 
preprocessed. * * \praram[in] workspace the needed tmp workspace when exec_preprocess */ - virtual void exec_preprocess(const TensorLayout& src_layout, - _megdnn_tensor_in filter, - const TensorLayout& dst_layout, - PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) = 0; + virtual void exec_preprocess( + const TensorLayout& src_layout, _megdnn_tensor_in filter, + const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) = 0; void deduce_dtype(DType src, DType filter, DType& dst); - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst); + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst); /** * \brief query the workspace needed when executing the opr, if the weights @@ -249,8 +250,7 @@ public: */ virtual size_t get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, - const PreprocessedFilter* preprocessed_filter) = 0; + const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) = 0; /** * \brief deduce the preprocessed filter layouts according to the src, @@ -303,25 +303,25 @@ public: * \param[in] diff (n, oc, oh, ow) * \param[out] grad (n, ic, ih, iw) */ - virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) = 0; void deduce_dtype(DType filter, DType diff, DType& grad); - void deduce_layout(const TensorLayout& filter, const TensorLayout& diff, - TensorLayout& grad); + void deduce_layout( + const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad); static Algorithm::OprType get_opr_type() { return Algorithm::OprType::CONVOLUTION_BACKWARD_DATA; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; /** @@ -340,21 +340,21 @@ public: * \param[in] diff (n, oc, oh, ow) * \param[out] grad (oc, ic, fh, fw) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; /** @@ -383,11 +383,11 @@ public: * \note if the format is 
NCHW_WINOGRAD, the filter layout is (alphah, * alphaw, oc, ic) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, - const PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, + _megdnn_tensor_in z, _megdnn_tensor_out dst, + const PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) = 0; /** * \brief execute weight preprocessing, read weights form filter and bias, @@ -396,17 +396,15 @@ public: * \praram[in] workspace the needed tmp workspace when exec_preprocess * running, the size is got by get_preprocess_workspace_in_bytes */ - virtual void exec_preprocess(const TensorLayout& src_layout, - _megdnn_tensor_in filter, - _megdnn_tensor_in bias, - const TensorLayout& z_layout, - const TensorLayout& dst_layout, - PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) = 0; + virtual void exec_preprocess( + const TensorLayout& src_layout, _megdnn_tensor_in filter, + _megdnn_tensor_in bias, const TensorLayout& z_layout, + const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) = 0; void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst); - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - TensorLayout& dst); + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst); /** * \brief query the workspace needed when executing the opr, if the weights @@ -418,8 +416,7 @@ public: */ virtual size_t get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) = 0; /** @@ -524,15 +521,13 @@ public: protected: CanonizedFilterMeta check_exec( const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_in_bytes, - const PreprocessedFilter* preprocessed_filter); + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, + size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter); CanonizedFilterMeta check_exec_allow_noncontiguous( const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_in_bytes, - const PreprocessedFilter* preprocessed_filter); + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, + size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter); }; using ConvBias = ConvBiasForward; @@ -555,13 +550,13 @@ class ConvPoolingBase : public OperatorBase { DEF_OPR_PARAM(ConvPooling); protected: - virtual void deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, TensorLayout& dst) = 0; - virtual void check_layout(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, TensorLayout& dst, - size_t workspace_limit_in_bytes) = 0; + virtual void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, TensorLayout& dst) = 0; + virtual void check_layout( + const 
TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, TensorLayout& dst, + size_t workspace_limit_in_bytes) = 0; }; class ConvPoolingForward : public ConvPoolingBase { @@ -572,23 +567,22 @@ public: * \param[in] src input tensor * \param[out] dst output tensor */ - virtual void exec(const _megdnn_in TensorND src, - const _megdnn_in TensorND filter, - const _megdnn_in TensorND bias, _megdnn_out TensorND dst, - _megdnn_out Workspace workspace) = 0; - virtual void deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, TensorLayout& dst) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& dst) = 0; - -protected: - virtual void check_layout(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, TensorLayout& dst, - size_t workspace_limit_in_bytes) = 0; + virtual void exec( + const _megdnn_in TensorND src, const _megdnn_in TensorND filter, + const _megdnn_in TensorND bias, _megdnn_out TensorND dst, + _megdnn_out Workspace workspace) = 0; + virtual void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& dst) = 0; + +protected: + virtual void check_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, TensorLayout& dst, + size_t workspace_limit_in_bytes) = 0; }; using ConvPooling = ConvPoolingForward; @@ -600,10 +594,11 @@ public: using Mode = Param::Mode; protected: - void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst); - void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst); }; class GroupLocalForward : public GroupLocalBase { @@ -615,19 +610,21 @@ public: * \param[in] filter (G, OH, OW, IC/G, FH, FW, OC/G) * \param[out] dst (N, OC, OH, OW) **/ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst) { + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); } - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst, size_t workspace_in_bytes); }; using GroupLocal = GroupLocalForward; @@ -635,30 +632,34 @@ class GroupLocalBackwardData : public GroupLocalBase { DEF_OPR_IMPL(GroupLocalBackwardData, GroupLocalBase, 2, 1); public: - virtual void exec(_megdnn_tensor_in filter, 
_megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class GroupLocalBackwardFilter : public GroupLocalBase { DEF_OPR_IMPL(GroupLocalBackwardFilter, GroupLocalBase, 2, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class Images2NeibsBase : public OperatorBase { @@ -685,15 +686,17 @@ public: * where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), * iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$. */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Images2Neibs = Images2NeibsForward; @@ -705,14 +708,16 @@ public: * \param[in] diff the backpropagated gradient wrt. dst * \param[out] grad the backpropagated gradient wrt. 
src */ - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; class SlidingWindowTransposeBase : public OperatorBase { @@ -725,43 +730,45 @@ protected: }; class SlidingWindowTransposeForward : public SlidingWindowTransposeBase { - DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, - 1); + DEF_OPR_IMPL(SlidingWindowTransposeForward, SlidingWindowTransposeBase, 1, 1); public: /** * \param[in] src (N, C, IH, IW, window_h, window_w) * \param[out] dst (N, C, OH, OW) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using SlidingWindowTranspose = SlidingWindowTransposeForward; class SlidingWindowTransposeBackward : public SlidingWindowTransposeBase { - DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, - 1); + DEF_OPR_IMPL(SlidingWindowTransposeBackward, SlidingWindowTransposeBase, 1, 1); public: /** * \param[in] diff the backpropagated gradient wrt. dst * \param[out] grad the backpropagated gradient wrt. 
src */ - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; /** @@ -788,19 +795,21 @@ public: * \param[in] src input tensor * \param[out] dst output tensor */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::POOLING_FORWARD; } protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using Pooling = PoolingForward; @@ -816,22 +825,21 @@ public: * \param[in] diff the backpropagated gradient wrt. dst * \param[out] grad the backpropagated gradient wrt. src */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::POOLING_BACKWARD; } protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; /** @@ -842,8 +850,8 @@ class AdaptivePoolingBase : public OperatorBase { DEF_OPR_PARAM(AdaptivePooling); protected: - param::Pooling deduce_pooling_param(const TensorLayout& src, - const TensorLayout& dst); + param::Pooling deduce_pooling_param( + const TensorLayout& src, const TensorLayout& dst); }; class AdaptivePoolingForward : public AdaptivePoolingBase { @@ -854,10 +862,11 @@ public: * \param[in] src input tensor * \param[out] dst output tensor */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& 
dst) = 0; }; using AdaptivePooling = AdaptivePoolingForward; @@ -872,13 +881,12 @@ public: * \param[in] diff the backpropagated gradient wrt. dst * \param[out] grad the backpropagated gradient wrt. src */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) = 0; }; /** @@ -892,10 +900,11 @@ public: using Mode = Param::Mode; protected: - void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst); - void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst); }; class LocalForward : public LocalBase { @@ -907,23 +916,25 @@ public: * \param[in] filter (oh, ow, ic, fh, fw, oc) * \param[out] dst (n, oc, oh, ow) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; /** * \brief Deducing output tensor layouts from input tensor layouts. * * Be aware that the first and second dimension of `filter' are ignored * when deducing `dst' layout. 
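// A naive reference sketch of the locally connected forward pass implied by the
// shapes above (stride 1 and zero padding are my assumptions; this is not the
// library kernel). Unlike convolution, every output position (oh, ow) owns its
// own (ic, fh, fw, oc) filter bank, which is why filter is laid out as
// (oh, ow, ic, fh, fw, oc) and its first two dimensions follow dst.
#include <cstddef>
#include <cstdio>
#include <vector>

// dst[n][oc][oh][ow] = sum over (ic, fh, fw) of
//     src[n][ic][oh + fh][ow + fw] * filter[oh][ow][ic][fh][fw][oc]
static void local_forward_ref(
        const std::vector<float>& src, const std::vector<float>& filter,
        std::vector<float>& dst, size_t N, size_t IC, size_t IH, size_t IW,
        size_t OC, size_t FH, size_t FW) {
    const size_t OH = IH - FH + 1, OW = IW - FW + 1;
    dst.assign(N * OC * OH * OW, 0.f);
    for (size_t n = 0; n < N; ++n)
        for (size_t oc = 0; oc < OC; ++oc)
            for (size_t oh = 0; oh < OH; ++oh)
                for (size_t ow = 0; ow < OW; ++ow) {
                    float acc = 0.f;
                    for (size_t ic = 0; ic < IC; ++ic)
                        for (size_t fh = 0; fh < FH; ++fh)
                            for (size_t fw = 0; fw < FW; ++fw) {
                                size_t s = ((n * IC + ic) * IH + oh + fh) * IW + ow + fw;
                                size_t f = ((((oh * OW + ow) * IC + ic) * FH + fh) * FW + fw) * OC + oc;
                                acc += src[s] * filter[f];
                            }
                    dst[((n * OC + oc) * OH + oh) * OW + ow] = acc;
                }
}

int main() {
    // 1x1x3x3 input of ones, 2x2 windows, one output channel, weights 0.25:
    // each of the 2x2 outputs is 4 * 1.0 * 0.25 = 1.0.
    std::vector<float> src(9, 1.f), filter(2 * 2 * 1 * 2 * 2 * 1, 0.25f), dst;
    local_forward_ref(src, filter, dst, 1, 1, 3, 3, 1, 2, 2);
    std::printf("dst[0] = %.2f\n", dst[0]);
    return 0;
}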
*/ - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst, size_t workspace_in_bytes); }; using Local = LocalForward; @@ -936,16 +947,18 @@ public: * \param[in] diff (n, oc, oh, ow) * \param[out] grad (n, ic, ih, iw) */ - virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class LocalBackwardFilter : public LocalBase { @@ -957,16 +970,18 @@ public: * \param[in] diff (n, oc, oh, ow) * \param[out] grad (oh, ow, ic, fh, fw, oc) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; class BNBase : public OperatorBase { @@ -997,18 +1012,17 @@ public: * src and dst must have the same shape. * src and dst must be contiguous. 
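 * The required sizes of `workspace' and `reserve' can be queried via
 * get_workspace_in_bytes() and get_reserve_in_bytes() declared below.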
*/ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, - _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean, - _megdnn_tensor_inout variance, - _megdnn_tensor_out batch_mean, - _megdnn_tensor_out batch_inv_variance, - _megdnn_tensor_out reserve, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, const TensorLayout& bn_scale, - const TensorLayout& bn_bias, TensorLayout& mean, - TensorLayout& variance, TensorLayout& batch_mean, - TensorLayout& batch_inv_variance, TensorLayout& reserve, - TensorLayout& dst); + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in bn_scale, + _megdnn_tensor_in bn_bias, _megdnn_tensor_inout mean, + _megdnn_tensor_inout variance, _megdnn_tensor_out batch_mean, + _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, + _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& bn_scale, + const TensorLayout& bn_bias, TensorLayout& mean, TensorLayout& variance, + TensorLayout& batch_mean, TensorLayout& batch_inv_variance, + TensorLayout& reserve, TensorLayout& dst); virtual size_t get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& bn_scale, const TensorLayout& bn_bias, const TensorLayout& mean, @@ -1018,13 +1032,12 @@ public: virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& bn_scale, - const TensorLayout& bn_bias, const TensorLayout& mean, - const TensorLayout& variance, - const TensorLayout& batch_mean, - const TensorLayout& batch_inv_variance, - const TensorLayout& dst, size_t workspace_in_bytes, - size_t reserve_in_bytes = 0); + void check_exec( + const TensorLayout& src, const TensorLayout& bn_scale, + const TensorLayout& bn_bias, const TensorLayout& mean, + const TensorLayout& variance, const TensorLayout& batch_mean, + const TensorLayout& batch_inv_variance, const TensorLayout& dst, + size_t workspace_in_bytes, size_t reserve_in_bytes = 0); }; using BN = BNForward; @@ -1045,30 +1058,28 @@ public: Calculated in the forwardpropagation. 
* \param[in] reserve (see cudnnBatchNormalizationBackwardEx) */ - virtual void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, - _megdnn_tensor_in saved_batch_mean, - _megdnn_tensor_in saved_batch_variance, - _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, - _megdnn_tensor_out d_bn_scale, - _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in x, _megdnn_tensor_in dy, + _megdnn_tensor_in saved_batch_mean, _megdnn_tensor_in saved_batch_variance, + _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, + _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, + _megdnn_tensor_out dx, _megdnn_workspace workspace) = 0; virtual size_t get_workspace_in_bytes( const TensorLayout& x, const TensorLayout& dy, const TensorLayout& saved_batch_mean, - const TensorLayout& saved_batch_variance, - const TensorLayout& bn_scale, const TensorLayout& reserve, - const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias, - const TensorLayout& dx) = 0; + const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale, + const TensorLayout& reserve, const TensorLayout& d_bn_scale, + const TensorLayout& d_bn_bias, const TensorLayout& dx) = 0; virtual size_t get_reserve_in_bytes(const TensorLayout& src) = 0; protected: - void check_exec(const TensorLayout& x, const TensorLayout& dy, - const TensorLayout& saved_batch_mean, - const TensorLayout& saved_batch_variance, - const TensorLayout& bn_scale, - const TensorLayout& d_bn_scale, - const TensorLayout& d_bn_bias, const TensorLayout& dx, - size_t workspace_in_bytes, size_t reserve_in_bytes = 0); + void check_exec( + const TensorLayout& x, const TensorLayout& dy, + const TensorLayout& saved_batch_mean, + const TensorLayout& saved_batch_variance, const TensorLayout& bn_scale, + const TensorLayout& d_bn_scale, const TensorLayout& d_bn_bias, + const TensorLayout& dx, size_t workspace_in_bytes, + size_t reserve_in_bytes = 0); }; class LRNBase : public OperatorBase { @@ -1091,15 +1102,17 @@ public: * src and dst must have the same shape. * src and dst must be contiguous. */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_layout(const TensorLayout& src, TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_in_bytes); }; using LRN = LRNForward; @@ -1115,18 +1128,17 @@ public: * * All tensors should be contiguous and of the same shape. 
*/ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class ROIPoolingBase : public OperatorBase { @@ -1134,8 +1146,9 @@ class ROIPoolingBase : public OperatorBase { DEF_OPR_PARAM(ROIPooling); protected: - void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois, - const TensorLayout& dst, const TensorLayout& index); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index); }; class ROIPoolingForward : public ROIPoolingBase { @@ -1157,18 +1170,17 @@ public: * It is used to store argmax indicex in MAX mode, and it is not used * in AVERAGE mode. */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois, - _megdnn_tensor_out dst, _megdnn_tensor_out index, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& rois, - const TensorLayout& dst, - const TensorLayout& index) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst, + _megdnn_tensor_out index, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& rois, - const TensorLayout& dst, const TensorLayout& index, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index, size_t workspace_in_bytes); }; using ROIPooling = ROIPoolingForward; @@ -1183,19 +1195,19 @@ public: * \param[in] index the `index' parameter in ROIPoolingForward::exec * \param[out] grad the backpropagated gradient wrt. 
src */ - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in src, - _megdnn_tensor_in rois, _megdnn_tensor_in index, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& src, - const TensorLayout& rois, - const TensorLayout& index, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in src, _megdnn_tensor_in rois, + _megdnn_tensor_in index, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois, + const TensorLayout& index, const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& src, - const TensorLayout& rois, const TensorLayout& index, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois, + const TensorLayout& index, const TensorLayout& grad, + size_t workspace_in_bytes); }; class Convolution3DBase : public OperatorBase { @@ -1222,20 +1234,19 @@ public: } MEGDNN_PACKED; protected: - CanonizedFilterMeta deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& filter, - TensorLayout& dst) const; - CanonizedFilterMeta check_layout_fwd(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) const; + CanonizedFilterMeta deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + TensorLayout& dst) const; + CanonizedFilterMeta check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) const; CanonizedFilterMeta make_canonized_filter_meta( size_t src_ndim, const TensorLayout& filter) const; }; -class Convolution3DForward - : public Convolution3DBase, - public detail::MultiAlgoOpr { +class Convolution3DForward : public Convolution3DBase, + public detail::MultiAlgoOpr { DEF_OPR_IMPL(Convolution3DForward, Convolution3DBase, 2, 1); public: @@ -1244,23 +1255,23 @@ public: * \param[in] filter (oc, ic, fd, fh, fw) * \param[out] dst (n, oc, od, oh, ow) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::CONVOLUTION3D_FORWARD; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst, size_t workspace_in_bytes); }; using Convolution3D = Convolution3DForward; @@ -1275,24 +1286,24 @@ public: * \param[in] diff (n, oc, od, oh, ow) * \param[out] grad (n, ic, id, ih, iw) */ - virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, 
_megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) = 0; - void deduce_layout(const TensorLayout& filter, const TensorLayout& diff, - TensorLayout& grad); + void deduce_layout( + const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad); static Algorithm::OprType get_opr_type() { return Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class Convolution3DBackwardFilter @@ -1306,21 +1317,21 @@ public: * \param[in] diff (n, oc, od, oh, ow) * \param[out] grad (oc, ic, fd, fh, fw) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; class LocalShareBase : public OperatorBase { @@ -1328,10 +1339,11 @@ class LocalShareBase : public OperatorBase { DEF_OPR_PARAM(LocalShare); protected: - void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst); - void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst); }; class LocalShareForward : public LocalShareBase, @@ -1345,30 +1357,31 @@ public: * FH, FW, OC / G) * \param[out] dst (N, OC, OH, OW) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; /** * \brief deduce layout of the ouput tensor */ - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) 
= 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::LOCAL_SHARE_FORWARD; } protected: - void check_exec(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst, size_t workspace_in_bytes); }; using LocalShare = LocalShareForward; -class LocalShareBackwardData - : public LocalShareBase, - public detail::MultiAlgoOpr { +class LocalShareBackwardData : public LocalShareBase, + public detail::MultiAlgoOpr { DEF_OPR_IMPL(LocalShareBackwardData, LocalShareBase, 2, 1); public: @@ -1378,21 +1391,23 @@ public: * \param[in] diff (N, OC, OH, OW) * \param[out] grad (N, IC, IH, IW) */ - virtual void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) = 0; - void deduce_layout(const TensorLayout& filter, const TensorLayout& diff, - TensorLayout& grad); + virtual void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) = 0; + void deduce_layout( + const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad); static Algorithm::OprType get_opr_type() { return Algorithm::OprType::LOCAL_SHARE_BACKWARD_DATA; } protected: - void check_exec(const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes); }; class LocalShareBackwardFilter @@ -1407,20 +1422,22 @@ public: * \param[out] grad (G, spatial_groups_h, spatial_groups_w, IC / G, * FH, FW, OC / G) */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::LOCAL_SHARE_BACKWARD_FILTER; } protected: - void check_exec(const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes); }; class ROIAlignBase : public OperatorBase { @@ -1428,10 +1445,12 @@ class ROIAlignBase : public OperatorBase { DEF_OPR_PARAM(ROIAlign); protected: - void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& rois, - TensorLayout& dst, TensorLayout& index); - void check_layout_fwd(const TensorLayout& src, const TensorLayout& rois, - const TensorLayout& dst, const TensorLayout& index); + void deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, + TensorLayout& index); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index); }; class ROIAlignForward : 
public ROIAlignBase { @@ -1451,20 +1470,20 @@ public: * It is used to store argmax indicex in MAX mode, and it is not used * in AVERAGE mode. */ - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in rois, - _megdnn_tensor_out dst, _megdnn_tensor_out index, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, const TensorLayout& rois, - TensorLayout& dst, TensorLayout& index); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& rois, - const TensorLayout& dst, - const TensorLayout& index) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in rois, _megdnn_tensor_out dst, + _megdnn_tensor_out index, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, + TensorLayout& index); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& rois, - const TensorLayout& dst, const TensorLayout& index, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index, size_t workspace_in_bytes); }; using ROIAlign = ROIAlignForward; @@ -1478,18 +1497,18 @@ public: * \param[in] index the `index' parameter in ROIAlignForward::exec * \param[out] grad the backpropagated gradient wrt. src */ - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in rois, - _megdnn_tensor_in index, _megdnn_tensor_out grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& rois, - const TensorLayout& index, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in rois, _megdnn_tensor_in index, + _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& rois, + const TensorLayout& index, const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& rois, - const TensorLayout& index, const TensorLayout& grad, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& rois, + const TensorLayout& index, const TensorLayout& grad, + size_t workspace_in_bytes); }; class DeformableConvBase : public OperatorBase { @@ -1506,17 +1525,17 @@ protected: CanonizedFilterMeta make_canonized_filter_meta( size_t src_ndim, const TensorLayout& filter, const TensorLayout& offset) const; - void deduce_layout_fwd(const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& mask, const TensorLayout& offset, - TensorLayout& dst); - void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& mask, const TensorLayout& offset, - const TensorLayout& dst); + void deduce_layout_fwd( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& mask, const TensorLayout& offset, TensorLayout& dst); + void check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& mask, const TensorLayout& offset, + const TensorLayout& dst); }; -class DeformableConvForward - : public DeformableConvBase, - public detail::MultiAlgoOpr { +class DeformableConvForward : public DeformableConvBase, + public detail::MultiAlgoOpr { DEF_OPR_IMPL(DeformableConvForward, DeformableConvBase, 4, 1); 
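// the (4, 1) in DEF_OPR_IMPL above presumably denote 4 tensor inputs (im, filter, offset, mask) and 1 output (dst); cf. exec() below.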
public: @@ -1527,29 +1546,27 @@ public: * \param[in] mask (dg, fh, fw, oh, ow) * \param[out] dst (n, oc, oh, ow) */ - virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - TensorLayout& dst); - virtual size_t get_workspace_in_bytes(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst) = 0; + virtual void exec( + _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, TensorLayout& dst); + virtual size_t get_workspace_in_bytes( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& dst) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::DEFORMABLE_CONV_FORWARD; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& dst, size_t workspace_in_bytes); }; using DeformableConv = DeformableConvForward; @@ -1571,30 +1588,28 @@ public: * \param[in] out_grad (n, oc, oh, ow) * \param[out] filter_grad (oc, ic, ih, iw) */ - virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, - _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, - _megdnn_tensor_out filter_grad, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& im, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& out_grad, - const TensorLayout& filter_grad) = 0; - void deduce_layout(const TensorLayout& im, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& out_grad, - TensorLayout& filter_grad); + virtual void exec( + _megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask, + _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& filter_grad) = 0; + void deduce_layout( + const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + TensorLayout& filter_grad); static Algorithm::OprType get_opr_type() { return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_FILTER; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& im, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& out_grad, - const TensorLayout& filter_grad, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& filter_grad, size_t workspace_in_bytes); }; /** @@ -1618,21 +1633,21 @@ public: * \param[out] offset_grad (dg, 2, fh, fw, oh, ow) * 
\param[out] mask_grad (dg, fh, fw, oh, ow) */ - virtual void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, - _megdnn_tensor_out offset_grad, - _megdnn_tensor_out mask_grad, - _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, + _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad, + _megdnn_tensor_out mask_grad, _megdnn_workspace workspace) = 0; virtual size_t get_workspace_in_bytes( const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad) = 0; - void deduce_layout(const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, TensorLayout& im_grad, - TensorLayout& offset_grad, TensorLayout& mask_grad); + void deduce_layout( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, TensorLayout& im_grad, + TensorLayout& offset_grad, TensorLayout& mask_grad); static Algorithm::OprType get_opr_type() { return Algorithm::OprType::DEFORMABLE_CONV_BACKWARD_DATA; @@ -1652,20 +1667,18 @@ class DeformablePSROIPoolingBase : public OperatorBase { DEF_OPR_PARAM(DeformablePSROIPooling); protected: - void deduce_layout_fwd(const TensorLayout& data, const TensorLayout& trans, - const TensorLayout& rois, TensorLayout& out_data, - TensorLayout& out_count); + void deduce_layout_fwd( + const TensorLayout& data, const TensorLayout& trans, + const TensorLayout& rois, TensorLayout& out_data, TensorLayout& out_count); - void check_layout_fwd(const TensorLayout& data, const TensorLayout& trans, - const TensorLayout& rois, - const TensorLayout& out_data, - const TensorLayout& out_count, - size_t workspace_in_bytes); + void check_layout_fwd( + const TensorLayout& data, const TensorLayout& trans, + const TensorLayout& rois, const TensorLayout& out_data, + const TensorLayout& out_count, size_t workspace_in_bytes); }; class DeformablePSROIPoolingForward : public DeformablePSROIPoolingBase { - DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3, - 2); + DEF_OPR_IMPL(DeformablePSROIPoolingForward, DeformablePSROIPoolingBase, 3, 2); public: /** @@ -1675,28 +1688,27 @@ public: * \param[out] out_data ( n, ic, ih, iw) * \param[out] out_count ( n, ic, ih, iw) */ - virtual size_t get_workspace_in_bytes(const TensorLayout& data, - const TensorLayout& rois, - const TensorLayout& trans, - const TensorLayout& out_data, - const TensorLayout& out_count) = 0; - virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, - _megdnn_tensor_in trans, _megdnn_tensor_out out_data, - _megdnn_tensor_out out_count, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& data, const TensorLayout& rois, - const TensorLayout& trans, TensorLayout& out_data, - TensorLayout& out_count); - void check_exec(const TensorLayout& data, const TensorLayout& rois, - const TensorLayout& trans, const TensorLayout& out_data, - const TensorLayout& out_count, size_t workspace_in_bytes); + virtual size_t get_workspace_in_bytes( + const TensorLayout& data, const TensorLayout& rois, + const TensorLayout& trans, const TensorLayout& 
out_data, + const TensorLayout& out_count) = 0; + virtual void exec( + _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans, + _megdnn_tensor_out out_data, _megdnn_tensor_out out_count, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& data, const TensorLayout& rois, + const TensorLayout& trans, TensorLayout& out_data, TensorLayout& out_count); + void check_exec( + const TensorLayout& data, const TensorLayout& rois, + const TensorLayout& trans, const TensorLayout& out_data, + const TensorLayout& out_count, size_t workspace_in_bytes); }; using DeformablePSROIPooling = DeformablePSROIPoolingForward; class DeformablePSROIPoolingBackward : public DeformablePSROIPoolingBase { - DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5, - 2); + DEF_OPR_IMPL(DeformablePSROIPoolingBackward, DeformablePSROIPoolingBase, 5, 2); public: /** @@ -1708,58 +1720,53 @@ public: * \param[out] data_diff ( n, ic, ih, iw) * \param[out] trans_diff ( n, ic, ih, iw) */ - virtual void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, - _megdnn_tensor_in trans, _megdnn_tensor_in out_diff, - _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff, - _megdnn_tensor_out trans_diff, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& data, - const TensorLayout& rois, - const TensorLayout& trans, - const TensorLayout& out_diff, - const TensorLayout& out_count, - const TensorLayout& data_diff, - const TensorLayout& trans_diff) = 0; - - void check_exec(const TensorLayout& data, const TensorLayout& rois, - const TensorLayout& trans, const TensorLayout& out_diff, - const TensorLayout& out_count, - const TensorLayout& data_diff, - const TensorLayout& trans_diff, size_t workspace_in_bytes); -}; - -class BatchConvBiasForward - : public ConvolutionBase, - public detail::MultiAlgoOpr { + virtual void exec( + _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans, + _megdnn_tensor_in out_diff, _megdnn_tensor_in out_count, + _megdnn_tensor_out data_diff, _megdnn_tensor_out trans_diff, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& data, const TensorLayout& rois, + const TensorLayout& trans, const TensorLayout& out_diff, + const TensorLayout& out_count, const TensorLayout& data_diff, + const TensorLayout& trans_diff) = 0; + + void check_exec( + const TensorLayout& data, const TensorLayout& rois, + const TensorLayout& trans, const TensorLayout& out_diff, + const TensorLayout& out_count, const TensorLayout& data_diff, + const TensorLayout& trans_diff, size_t workspace_in_bytes); +}; + +class BatchConvBiasForward : public ConvolutionBase, + public detail::MultiAlgoOpr { DEF_OPR_IMPL(BatchConvBiasForward, ConvolutionBase, 4, 1); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, + _megdnn_tensor_in z, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; void deduce_dtype(DType src, DType filter, DType bias, DType z, DType& dst); - void deduce_layout(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - TensorLayout& dst); + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, TensorLayout& dst); - 
virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) = 0; static Algorithm::OprType get_opr_type() { return Algorithm::OprType::BATCH_CONV_FORWARD; } protected: - CanonizedFilterMeta check_exec(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst, - size_t workspace_in_bytes); + CanonizedFilterMeta check_exec( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, + size_t workspace_in_bytes); }; using BatchConvBias = BatchConvBiasForward; @@ -1769,29 +1776,31 @@ class FakeQuantBase : public OperatorBase { protected: void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output); - void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& output); + void check_layout_fwd( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& output); }; class FakeQuantForward : public FakeQuantBase { DEF_OPR_IMPL(FakeQuantForward, FakeQuantBase, 3, 1); public: - virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, _megdnn_tensor_out output, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& input, const TensorLayout& scale, - const TensorLayout& zero_point, TensorLayout& output); - virtual size_t get_workspace_in_bytes(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& output) = 0; + virtual void exec( + _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out output, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, TensorLayout& output); + virtual size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& output) = 0; protected: - void check_exec(const TensorLayout& input, const TensorLayout& scale, - const TensorLayout& zero_point, const TensorLayout& output, - size_t workspace_in_bytes); + void check_exec( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& output, + size_t workspace_in_bytes); }; using FakeQuant = FakeQuantForward; @@ -1800,19 +1809,20 @@ class FakeQuantBackward : public FakeQuantBase { DEF_OPR_IMPL(FakeQuantBackward, FakeQuantBase, 4, 1); public: - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& 
input, + const TensorLayout& scale, const TensorLayout& zero_point, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& input, - const TensorLayout& scale, const TensorLayout& zero_point, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& input, + const TensorLayout& scale, const TensorLayout& zero_point, + const TensorLayout& grad, size_t workspace_in_bytes); }; class TQTBase : public OperatorBase { @@ -1821,26 +1831,28 @@ class TQTBase : public OperatorBase { protected: void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output); - void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale, - const TensorLayout& output); + void check_layout_fwd( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& output); }; class TQTForward : public TQTBase { DEF_OPR_IMPL(TQTForward, TQTBase, 2, 1); public: - virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_out output, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& input, const TensorLayout& scale, - TensorLayout& output); - virtual size_t get_workspace_in_bytes(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& output) = 0; + virtual void exec( + _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_out output, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& input, const TensorLayout& scale, TensorLayout& output); + virtual size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& output) = 0; protected: - void check_exec(const TensorLayout& input, const TensorLayout& scale, - const TensorLayout& output, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& output, size_t workspace_in_bytes); }; using TQT = TQTForward; @@ -1848,20 +1860,20 @@ class TQTBackward : public TQTBase { DEF_OPR_IMPL(TQTBackward, TQTBase, 3, 2); public: - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, _megdnn_tensor_out grad_x, - _megdnn_tensor_out grad_s, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& grad_x, - const TensorLayout& grad_s) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& input, + const TensorLayout& scale, const TensorLayout& grad_x, + const TensorLayout& grad_s) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& input, - const TensorLayout& scale, const TensorLayout& grad_x, - const TensorLayout& grad_s, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& input, + const TensorLayout& scale, const TensorLayout& grad_x, + const TensorLayout& grad_s, size_t workspace_in_bytes); }; class LSQBase : public OperatorBase { @@ -1870,34 +1882,34 @@ class LSQBase : public OperatorBase { protected: void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output); - void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale, - const 
TensorLayout& zero_point, - const TensorLayout& grad_scale, - const TensorLayout& output); + void check_layout_fwd( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad_scale, + const TensorLayout& output); }; class LSQForward : public LSQBase { DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1); public: - virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, _megdnn_tensor_out output, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& input, const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad_scale, TensorLayout& output); - virtual size_t get_workspace_in_bytes(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad_scale, - const TensorLayout& output) = 0; - -protected: - void check_exec(const TensorLayout& input, const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad_scale, const TensorLayout& output, - size_t workspace_in_bytes); + virtual void exec( + _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out output, _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad_scale, + TensorLayout& output); + virtual size_t get_workspace_in_bytes( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad_scale, + const TensorLayout& output) = 0; + +protected: + void check_exec( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad_scale, + const TensorLayout& output, size_t workspace_in_bytes); }; using LSQ = LSQForward; @@ -1905,24 +1917,23 @@ class LSQBackward : public LSQBase { DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2); public: - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, - _megdnn_tensor_out grad_s, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad_scale, - const TensorLayout& grad_x, - const TensorLayout& grad_s) = 0; - -protected: - void check_exec(const TensorLayout& diff, const TensorLayout& input, - const TensorLayout& scale, const TensorLayout& zero_point, - const TensorLayout& grad_scale, const TensorLayout& grad_x, - const TensorLayout& grad_s, size_t workspace_in_bytes); + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& input, + const TensorLayout& scale, const TensorLayout& zero_point, + const TensorLayout& grad_scale, const TensorLayout& grad_x, + const TensorLayout& grad_s) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& input, + const TensorLayout& scale, const TensorLayout& zero_point, + const TensorLayout& grad_scale, const TensorLayout& grad_x, + const 
TensorLayout& grad_s, size_t workspace_in_bytes); }; } // namespace megdnn diff --git a/dnn/include/megdnn/oprs/nn_int.h b/dnn/include/megdnn/oprs/nn_int.h index 44378394..9abd6b71 100644 --- a/dnn/include/megdnn/oprs/nn_int.h +++ b/dnn/include/megdnn/oprs/nn_int.h @@ -36,7 +36,7 @@ public: struct ModeTrait { uint32_t arity = 0; //!< number of inputs needed CheckDtypeFunc check_inp[MAX_ARITY]; - SetOrCheckDtypeFunc check_out; //!< dtype of output var + SetOrCheckDtypeFunc check_out; //!< dtype of output var bool need_specify_out_dtype = false; //!< the dtype should be setup externally, otherwise //!< would be inferred by check_out(dtype, false) @@ -46,13 +46,10 @@ public: static const ModeTrait& from_mode(Mode mode); }; - virtual void exec(_megdnn_in const TensorNDArray& src, - _megdnn_tensor_out dst) = 0; + virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; //! get trait of current mode - const ModeTrait& mode_trait() const { - return ModeTrait::from_mode(m_param.mode); - } + const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); } //! deduce output layout void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); @@ -60,8 +57,8 @@ public: protected: //! throw exception if incorrect layout; broadcast input shape to //! output shape - void check_layout_and_broadcast(const TensorLayoutPtrArray& src, - const TensorLayout& dst); + void check_layout_and_broadcast( + const TensorLayoutPtrArray& src, const TensorLayout& dst); }; } // namespace megdnn diff --git a/dnn/include/megdnn/oprs/utils.h b/dnn/include/megdnn/oprs/utils.h index af22a5a8..85ef2dc8 100644 --- a/dnn/include/megdnn/oprs/utils.h +++ b/dnn/include/megdnn/oprs/utils.h @@ -15,84 +15,97 @@ namespace megdnn { //! base class for random number generators -class RNGBase: public OperatorBase { +class RNGBase : public OperatorBase { DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); - public: - virtual void exec(_megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; - protected: - virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; + +public: + virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; + +protected: + virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0; }; //! sample from poisson distribution -class PoissonRNG: public OperatorBase { +class PoissonRNG : public OperatorBase { DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); DEF_OPR_PARAM(PoissonRNG); - public: - virtual void exec(_megdnn_tensor_in lam, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &lam, - const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &lam, const TensorLayout &dst, - size_t workspace_in_bytes); + +public: + virtual void exec( + _megdnn_tensor_in lam, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& lam, const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& lam, const TensorLayout& dst, + size_t workspace_in_bytes); }; //! 
sample from beta distribution -class BetaRNG: public OperatorBase { +class BetaRNG : public OperatorBase { DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); DEF_OPR_PARAM(BetaRNG); - public: - virtual void exec(_megdnn_tensor_in alpha, - _megdnn_tensor_in beta, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, - const TensorLayout &beta, const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &alpha, const TensorLayout &beta, - const TensorLayout &dst, size_t workspace_in_bytes); + +public: + virtual void exec( + _megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& alpha, const TensorLayout& beta, + const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& alpha, const TensorLayout& beta, + const TensorLayout& dst, size_t workspace_in_bytes); }; //! sample from gamma distribution -class GammaRNG: public OperatorBase { +class GammaRNG : public OperatorBase { DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); DEF_OPR_PARAM(GammaRNG); - public: - virtual void exec(_megdnn_tensor_in shape, - _megdnn_tensor_in scale, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout &shape, - const TensorLayout &scale, const TensorLayout &dst) = 0; - protected: - void check_exec(const TensorLayout &shape, const TensorLayout &scale, - const TensorLayout &dst, size_t workspace_in_bytes); + +public: + virtual void exec( + _megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& shape, const TensorLayout& scale, + const TensorLayout& dst) = 0; + +protected: + void check_exec( + const TensorLayout& shape, const TensorLayout& scale, + const TensorLayout& dst, size_t workspace_in_bytes); }; //! sample from uniform distribution on the interval (0, 1] -class UniformRNG: public RNGBase { +class UniformRNG : public RNGBase { DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); DEF_OPR_PARAM(UniformRNG); - protected: - void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); + +protected: + void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; //! 
sample from gaussian distribution -class GaussianRNG: public RNGBase { +class GaussianRNG : public RNGBase { DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); DEF_OPR_PARAM(GaussianRNG); - protected: - void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); + +protected: + void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; -class PermutationRNG: public RNGBase { +class PermutationRNG : public RNGBase { DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); DEF_OPR_PARAM(PermutationRNG); - protected: - void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); + +protected: + void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); }; class ShuffleRNGForward : public OperatorBase { @@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase { DEF_OPR_PARAM(ShuffleRNG); public: - virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_tensor_out indices, - _megdnn_workspace workspace) = 0; - void deduce_layout(const TensorLayout& src, TensorLayout& dst, - TensorLayout& indices); - virtual size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& indices) = 0; + virtual void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& src, TensorLayout& dst, TensorLayout& indices); + virtual size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& indices) = 0; protected: - void check_exec(const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& indices, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& indices, size_t workspace_in_bytes); }; using ShuffleRNG = ShuffleRNGForward; @@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase { DEF_OPR_PARAM(ShuffleRNG); public: - virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, - _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; - virtual size_t get_workspace_in_bytes(const TensorLayout& diff, - const TensorLayout& indices, - const TensorLayout& grad) = 0; + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& indices, + const TensorLayout& grad) = 0; protected: - void check_exec(const TensorLayout& diff, const TensorLayout& indices, - const TensorLayout& grad, size_t workspace_in_bytes); + void check_exec( + const TensorLayout& diff, const TensorLayout& indices, + const TensorLayout& grad, size_t workspace_in_bytes); }; /*! 
* \brief sleep for specific time on the computing device; useful for testing * async problems */ -class SleepForward: public OperatorBase { +class SleepForward : public OperatorBase { DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); DEF_OPR_PARAM(Sleep); - public: - virtual void exec() = 0; +public: + virtual void exec() = 0; }; using Sleep = SleepForward; @@ -149,20 +165,19 @@ using Sleep = SleepForward; * * data must be a one-dimensional contiguous tensor with dtype byte */ -class ChecksumForward: public OperatorBase { +class ChecksumForward : public OperatorBase { DEF_OPR_PARAM(Empty); DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); - public: - using Result = opr_result::Checksum; +public: + using Result = opr_result::Checksum; - virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0; - virtual Result exec(_megdnn_tensor_in data, - _megdnn_workspace workspace) = 0; + virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0; - protected: - void check_exec(const TensorLayout &layout, size_t workspace_in_bytes); +protected: + void check_exec(const TensorLayout& layout, size_t workspace_in_bytes); }; using Checksum = ChecksumForward; @@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase { DEF_OPR_PARAM(Empty); DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); - public: - virtual size_t get_workspace_in_bytes(const TensorLayout& layout1, - const TensorLayout& layout2) = 0; +public: + virtual size_t get_workspace_in_bytes( + const TensorLayout& layout1, const TensorLayout& layout2) = 0; - virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, - _megdnn_workspace workspace) = 0; + virtual float exec( + _megdnn_tensor_in src1, _megdnn_tensor_in src2, + _megdnn_workspace workspace) = 0; - protected: - void check_exec(const TensorLayout& layout1, - const TensorLayout& layout2, size_t workspace_in_bytes); +protected: + void check_exec( + const TensorLayout& layout1, const TensorLayout& layout2, + size_t workspace_in_bytes); }; - -bool check_bias_share_in_channel(const TensorLayout& bias, - const param::ConvBias::Format format); +bool check_bias_share_in_channel( + const TensorLayout& bias, const param::ConvBias::Format format); } // namespace megdnn diff --git a/dnn/include/megdnn/tensor_format.h b/dnn/include/megdnn/tensor_format.h index bb9b68a7..3975ab52 100644 --- a/dnn/include/megdnn/tensor_format.h +++ b/dnn/include/megdnn/tensor_format.h @@ -18,9 +18,9 @@ namespace megdnn { enum class TensorFormat::Type { - DEFAULT = 0, //!< see DefaultTensorFormat - IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat - LOWBITS_ALIGNED_TO_BYTE = 2, //!< + DEFAULT = 0, //!< see DefaultTensorFormat + IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat + LOWBITS_ALIGNED_TO_BYTE = 2, //!< }; class TensorFormat::ImplBase { @@ -33,8 +33,7 @@ public: virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; - virtual TensorLayout collapse_contiguous_spec( - const TensorLayout& layout) const = 0; + virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0; virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; @@ -79,8 +78,7 @@ public: */ bool is_contiguous_spec(const TensorLayout& layout) const override; - TensorLayout collapse_contiguous_spec( - const TensorLayout& layout) const override; + TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; TensorLayout::Span span_spec(const TensorLayout& layout) 
const override; @@ -88,8 +86,7 @@ public: void serialize_append(std::string& result) const override; static TensorFormat make(); - static TensorFormat deserialize(const Handle* handle, const void* buf, - size_t size); + static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); }; namespace detail { @@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase { size_t m_align_axis, m_align_size_in_elements_log2; protected: - Image2DTensorFormatBase(Type type, size_t align_axis, - size_t align_size_in_elements); + Image2DTensorFormatBase( + Type type, size_t align_axis, size_t align_size_in_elements); virtual ~Image2DTensorFormatBase() = default; public: @@ -129,9 +126,7 @@ public: size_t align_axis() const { return m_align_axis; } - size_t align_size_in_elements_log2() const { - return m_align_size_in_elements_log2; - } + size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; } std::string to_string() const override; @@ -145,6 +140,7 @@ public: size_t image_height(const TensorLayout& layout) const; void serialize_append(std::string& result) const override; + protected: struct SerializePack { uint8_t align_axis; @@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { * align COUNT, but mdl needs align size in byte, which equal to * (image_width algin count) * sizeof(data_type) * pixel_size */ - size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements, - const TensorLayout& layout) const; + size_t image_pitch_alignment_in_bytes( + size_t align_size_in_elements, const TensorLayout& layout) const; protected: - Image2DPackedTensorFormatBase(Type type, size_t align_axis, - size_t align_size_in_elements, - Handle::HandleVendorType vendor_type) - : detail::Image2DTensorFormatBase(type, align_axis, - align_size_in_elements), + Image2DPackedTensorFormatBase( + Type type, size_t align_axis, size_t align_size_in_elements, + Handle::HandleVendorType vendor_type) + : detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements), m_vendor_type(vendor_type) {} virtual ~Image2DPackedTensorFormatBase() = default; @@ -197,13 +192,12 @@ public: bool is_contiguous_spec(const TensorLayout& layout) const override; - TensorLayout collapse_contiguous_spec( - const TensorLayout& layout) const override; + TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; }; using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; /*! - * \brief used for tensors storing lowbit data + * \brief used for tensors storing lowbit data * * \param m_size_nbits size in bits of elements in the tensor * \param m_align_size_in_bits aligned size in bits @@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase { size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; protected: //? 
- LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits, - size_t align_size_in_bits); + LowbitsAlignedTensorFormatBase( + Type type, size_t size_nbits, size_t align_size_in_bits); virtual ~LowbitsAlignedTensorFormatBase() = default; public: size_t align_size_in_bits() const { return m_align_size_in_bits; } - + size_t size_nbits() const { return m_size_nbits; } std::string to_string() const override; @@ -238,8 +232,8 @@ public: bool is_contiguous_spec(const TensorLayout& layout) const override; - TensorLayout collapse_contiguous_spec( - const TensorLayout& layout) const override; + TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; + protected: struct SerializePack { uint8_t size_nbits; @@ -254,16 +248,14 @@ protected: * * This is used for OpenCL. */ -class Image2DPack4TensorFormat final - : public detail::Image2DPack4TensorFormatBase { +class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase { public: static constexpr Type TYPE = Type::IMAGE2D_PACK4; //! for internal usage or test purposes - static TensorFormat make_raw(size_t align_axis, - size_t align_size_in_elements, - Handle::HandleVendorType vendor_type = - Handle::HandleVendorType::NOT_SPEC); + static TensorFormat make_raw( + size_t align_axis, size_t align_size_in_elements, + Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC); static TensorFormat make(size_t align_axis, const Handle* handle); @@ -273,13 +265,11 @@ public: * Note that the alignment may be different if deserialized on another * handle */ - static TensorFormat deserialize(const Handle* handle, const void* buf, - size_t size); + static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); static bool is_valid_image(const TensorLayout& layout) { if (layout.format.type() == TYPE) { - layout.format.as_impl().assert_valid( - layout); + layout.format.as_impl().assert_valid(layout); return true; } return false; @@ -288,8 +278,9 @@ public: TensorFormat change_axis(size_t axis) const override; private: - Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements, - Handle::HandleVendorType vendor_type) + Image2DPack4TensorFormat( + size_t align_axis, size_t align_size_in_elements, + Handle::HandleVendorType vendor_type) : detail::Image2DPack4TensorFormatBase( TYPE, align_axis, align_size_in_elements, vendor_type) {} }; @@ -306,13 +297,12 @@ public: static TensorFormat make(size_t size_nbits); - static TensorFormat deserialize(const Handle* handle, const void* buf, - size_t size); + static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); static bool is_valid_layout(const TensorLayout& layout) { if (layout.format.type() == TYPE) { - layout.format.as_impl() - .assert_valid(layout); + layout.format.as_impl().assert_valid( + layout); return true; } return false; @@ -320,8 +310,7 @@ public: private: LowbitsAlignedToBytesTensorFormat(size_t size_nbits) - : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, - BYTE_IN_BITS) {} + : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {} }; } // namespace megdnn diff --git a/dnn/include/megdnn/tensor_iter.h b/dnn/include/megdnn/tensor_iter.h index 8857fc60..c9c51200 100644 --- a/dnn/include/megdnn/tensor_iter.h +++ b/dnn/include/megdnn/tensor_iter.h @@ -167,13 +167,11 @@ public: TensorIter(const TensorND& tensor) : m_tensor(tensor) {} - Iter begin() const { - return Iter::make(const_cast(m_tensor), 0); - } + Iter begin() const { return 
Iter::make(const_cast(m_tensor), 0); } Iter end() const { - return Iter::make(const_cast(m_tensor), - m_tensor.layout.total_nr_elems()); + return Iter::make( + const_cast(m_tensor), m_tensor.layout.total_nr_elems()); } }; /*! diff --git a/dnn/include/megdnn/thin/function.h b/dnn/include/megdnn/thin/function.h index ab3849ea..21a5d4e5 100644 --- a/dnn/include/megdnn/thin/function.h +++ b/dnn/include/megdnn/thin/function.h @@ -11,19 +11,19 @@ #pragma once -#include +#include #include -#include #include -#include +#include +#include #include "megdnn/internal/visibility_prologue.h" namespace megdnn { -template +template using thin_function = ::std::function; -} // namespace megdnn +} // namespace megdnn #include "megdnn/internal/visibility_epilogue.h" diff --git a/dnn/include/megdnn/thin/small_vector.h b/dnn/include/megdnn/thin/small_vector.h index c482bd26..fa4a36ec 100644 --- a/dnn/include/megdnn/thin/small_vector.h +++ b/dnn/include/megdnn/thin/small_vector.h @@ -58,18 +58,16 @@ protected: m_end_ptr(first_elm), m_capacity_ptr(static_cast(first_elm) + size) {} - void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, - size_t type_size); + void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size); public: size_t size_in_bytes() const { - return size_t(static_cast(m_end_ptr) - - static_cast(m_begin_ptr)); + return size_t(static_cast(m_end_ptr) - static_cast(m_begin_ptr)); } size_t capacity_in_bytes() const { - return size_t(static_cast(m_capacity_ptr) - - static_cast(m_begin_ptr)); + return size_t( + static_cast(m_capacity_ptr) - static_cast(m_begin_ptr)); } bool empty() const { return m_begin_ptr == m_end_ptr; } @@ -85,20 +83,15 @@ private: U m_first_elm; protected: - SmallVectorTemplateCommon(size_t size) - : SmallVectorBase(&m_first_elm, size) {} + SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {} void grow_pod(size_t min_sz_in_bytes, size_t type_size) { SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); } - bool is_small() { - return m_begin_ptr == static_cast(&m_first_elm); - } + bool is_small() { return m_begin_ptr == static_cast(&m_first_elm); } - void reset_to_small() { - m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; - } + void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; } void set_end(T* p) { m_end_ptr = p; } @@ -128,20 +121,12 @@ protected: public: // forwarding iterator creation iterator begin() { return static_cast(m_begin_ptr); } - const_iterator begin() const { - return static_cast(m_begin_ptr); - } - const_iterator cbegin() const { - return static_cast(m_begin_ptr); - } + const_iterator begin() const { return static_cast(m_begin_ptr); } + const_iterator cbegin() const { return static_cast(m_begin_ptr); } iterator end() { return static_cast(m_end_ptr); } - const_iterator end() const { - return static_cast(m_end_ptr); - } - const_iterator cend() const { - return static_cast(m_end_ptr); - } + const_iterator end() const { return static_cast(m_end_ptr); } + const_iterator cend() const { return static_cast(m_end_ptr); } reference at(size_type idx) { if (idx >= size()) { @@ -167,13 +152,9 @@ public: // reverse iterator creation method. 
reverse_iterator rbegin() { return reverse_iterator(end()); } - const_reverse_iterator rbegin() const { - return const_reverse_iterator(end()); - } + const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); } reverse_iterator rend() { return reverse_iterator(begin()); } - const_reverse_iterator rend() const { - return const_reverse_iterator(begin()); - } + const_reverse_iterator rend() const { return const_reverse_iterator(begin()); } pointer data() { return pointer(begin()); } const_pointer data() const { return const_pointer(begin()); } @@ -207,8 +188,8 @@ protected: template static void uninitialized_move(It1 first, It1 last, It2 dest) { - std::uninitialized_copy(std::make_move_iterator(first), - std::make_move_iterator(last), dest); + std::uninitialized_copy( + std::make_move_iterator(first), std::make_move_iterator(last), dest); } template @@ -293,9 +274,7 @@ protected: memcpy(dest, first, (last - first) * sizeof(T)); } - void grow(size_t min_sz = 0) { - this->grow_pod(min_sz * sizeof(T), sizeof(T)); - } + void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); } public: void push_back(const T& _elm) { @@ -318,8 +297,7 @@ public: * SmallVector can be converted to SmallVectorImpl to erase N */ template -class SmallVectorImpl - : public SmallVectorTemplateBase::value> { +class SmallVectorImpl : public SmallVectorTemplateBase::value> { using SuperClass = SmallVectorTemplateBase::value>; public: @@ -329,8 +307,7 @@ public: protected: explicit SmallVectorImpl(unsigned n) - : SmallVectorTemplateBase::value>(n * sizeof(T)) { - } + : SmallVectorTemplateBase::value>(n * sizeof(T)) {} public: SmallVectorImpl(const SmallVectorImpl&) = delete; @@ -354,8 +331,7 @@ public: } else if (n > this->size()) { if (this->capacity() < n) this->grow(n); - for (auto it = this->end(), end = this->begin() + n; it != end; - ++it) + for (auto it = this->end(), end = this->begin() + n; it != end; ++it) new (&*it) T(); this->set_end(this->begin() + n); } @@ -389,10 +365,11 @@ public: void swap(SmallVectorImpl& rhs); /// Add the specified range to the end of the SmallVector. - template ::iterator_category, - std::input_iterator_tag>::value>::type> + template < + typename in_iter, + typename = typename std::enable_if::iterator_category, + std::input_iterator_tag>::value>::type> void append(in_iter in_start, in_iter in_end) { size_type num_inputs = std::distance(in_start, in_end); // Grow allocated space if needed. @@ -432,10 +409,11 @@ public: std::uninitialized_fill(this->begin(), this->end(), elm); } - template ::iterator_category, - std::input_iterator_tag>::value>::type> + template < + typename in_iter, + typename = typename std::enable_if::iterator_category, + std::input_iterator_tag>::value>::type> void assign(in_iter in_start, in_iter in_end) { clear(); append(in_start, in_end); @@ -571,8 +549,7 @@ public: std::fill_n(it, num_overwritten, elm); // Insert the non-overwritten middle part. 
- std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, - elm); + std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm); return it; } @@ -646,8 +623,7 @@ public: if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { this->grow(); } - new (static_cast(this->end())) - T(std::forward(args)...); + new (static_cast(this->end())) T(std::forward(args)...); this->set_end(this->end() + 1); } @@ -661,13 +637,11 @@ public: return std::equal(this->begin(), this->end(), rhs.begin()); } - bool operator!=(const SmallVectorImpl& rhs) const { - return !(*this == rhs); - } + bool operator!=(const SmallVectorImpl& rhs) const { return !(*this == rhs); } bool operator<(const SmallVectorImpl& rhs) const { - return std::lexicographical_compare(this->begin(), this->end(), - rhs.begin(), rhs.end()); + return std::lexicographical_compare( + this->begin(), this->end(), rhs.begin(), rhs.end()); } }; @@ -698,15 +672,13 @@ void SmallVectorImpl::swap(SmallVectorImpl& rhs) { // Copy over the extra elms. if (this->size() > rhs.size()) { size_t elm_diff = this->size() - rhs.size(); - this->uninitialized_move(this->begin() + num_shared, this->end(), - rhs.end()); + this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end()); rhs.set_end(rhs.end() + elm_diff); this->destroy_range(this->begin() + num_shared, this->end()); this->set_end(this->begin() + num_shared); } else if (rhs.size() > this->size()) { size_t elm_diff = rhs.size() - this->size(); - this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), - this->end()); + this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end()); this->set_end(this->end() + elm_diff); this->destroy_range(rhs.begin() + num_shared, rhs.end()); rhs.set_end(rhs.begin() + num_shared); @@ -714,8 +686,7 @@ void SmallVectorImpl::swap(SmallVectorImpl& rhs) { } template -SmallVectorImpl& SmallVectorImpl::operator=( - const SmallVectorImpl& rhs) { +SmallVectorImpl& SmallVectorImpl::operator=(const SmallVectorImpl& rhs) { if (this == &rhs) return *this; size_t rhs_sz = rhs.size(); @@ -740,8 +711,7 @@ SmallVectorImpl& SmallVectorImpl::operator=( } else if (cur_sz) { std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); } - std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), - this->begin() + cur_sz); + std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); this->set_end(this->begin() + rhs_sz); return *this; } @@ -785,8 +755,7 @@ SmallVectorImpl& SmallVectorImpl::operator=(SmallVectorImpl&& rhs) { std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); } - this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), - this->begin() + cur_sz); + this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); this->set_end(this->begin() + rhs_sz); @@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl { public: SmallVector() : SmallVectorImpl(N) {} - explicit SmallVector(size_t size, const T& value = T()) - : SmallVectorImpl(N) { + explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl(N) { this->assign(size, value); } @@ -901,15 +869,13 @@ namespace std { /// Implement std::swap in terms of SmallVector swap. template -inline void swap(megdnn::SmallVectorImpl& lhs, - megdnn::SmallVectorImpl& rhs) { +inline void swap(megdnn::SmallVectorImpl& lhs, megdnn::SmallVectorImpl& rhs) { lhs.swap(rhs); } /// Implement std::swap in terms of SmallVector swap. 
template -inline void swap(megdnn::SmallVector& lhs, - megdnn::SmallVector& rhs) { +inline void swap(megdnn::SmallVector& lhs, megdnn::SmallVector& rhs) { lhs.swap(rhs); } } // end namespace std diff --git a/dnn/include/megdnn/version.h b/dnn/include/megdnn/version.h index a9e36c32..e6365327 100644 --- a/dnn/include/megdnn/version.h +++ b/dnn/include/megdnn/version.h @@ -17,13 +17,13 @@ #include "megdnn/internal/visibility_prologue.h" namespace megdnn { - struct Version { - int major, minor, patch; - }; +struct Version { + int major, minor, patch; +}; - //! get megdnn version of the binary - Version get_version(); -} +//! get megdnn version of the binary +Version get_version(); +} // namespace megdnn #include "megdnn/internal/visibility_epilogue.h" diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.cpp b/dnn/src/aarch64/conv_bias/fp16/algos.cpp index 402a580d..399211f0 100644 --- a/dnn/src/aarch64/conv_bias/fp16/algos.cpp +++ b/dnn/src/aarch64/conv_bias/fp16/algos.cpp @@ -22,18 +22,17 @@ using namespace aarch64; /* ===================== stride-2 algo ===================== */ MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) -bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF16DirectStride2::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.filter_meta.format == param::Convolution::Format::NCHW && param.src_type.enumv() == DTypeEnum::Float16 && param.filter_type.enumv() == DTypeEnum::Float16 && - param.dst_type.enumv() == DTypeEnum::Float16 && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && + param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); } MIDOUT_END(); @@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { return get_kimpls(param); @@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( return {}; } -SmallVector -ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( +SmallVector ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( size_t OC = param.filter_meta.ocpg; size_t group = fm.group; bool large_group = group >= param.nr_threads; - using Func = std::function; + using Func = std::function; Func conv = nullptr; if (FH == 2) { conv = fp16::conv_stride2::do_conv_2x2_stride2; @@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { arm_common::MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern_stride( + bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { 
arm_common::MultithreadDirectConvCommon:: - do_conv_kern_stride(bundle, kern_param, ncb_index, conv, - {ncb_index.thread_id, 0, oc}); + do_conv_kern_stride( + bundle, kern_param, ncb_index, conv, + {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); arm_common::MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern_stride( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); - auto do_conv = [bundle, conv](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_conv = [bundle, conv]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); arm_common::MultithreadDirectConvCommon:: - do_conv_kern_stride(bundle, kern_param, ncb_index, conv, - ncb_index.ndrange_id); + do_conv_kern_stride( + bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); }; ret_kerns.push_back({do_conv, {group, N, OC}}); } diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.h b/dnn/src/aarch64/conv_bias/fp16/algos.h index 16aa65d9..cb32df93 100644 --- a/dnn/src/aarch64/conv_bias/fp16/algos.h +++ b/dnn/src/aarch64/conv_bias/fp16/algos.h @@ -18,13 +18,13 @@ namespace aarch64 { /* ===================== stride-2 algo ===================== */ class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; + public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV8F16STRD2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; diff --git a/dnn/src/aarch64/conv_bias/fp16/stride2_kern.h b/dnn/src/aarch64/conv_bias/fp16/stride2_kern.h index 4f7ad3a1..e923e0dd 100644 --- a/dnn/src/aarch64/conv_bias/fp16/stride2_kern.h +++ b/dnn/src/aarch64/conv_bias/fp16/stride2_kern.h @@ -20,9 +20,9 @@ namespace aarch64 { namespace fp16 { namespace conv_stride2 { -static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, - __fp16* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_2x2_stride2( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 3; size_t mod4_left = width & 3; @@ -162,10 +162,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, "5: \n" : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) : "r"(mod4_left), "w"(_k0123) - : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", - "v31"); + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", 
"v18", "v19", "v28", "v29", "v30", "v31"); r0 += tail_step; r1 += tail_step; @@ -175,9 +174,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, } } -static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, - __fp16* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_3x3_stride2( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 3; size_t mod3_left = width % 3; @@ -352,10 +351,10 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, "3: \n" : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) - : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29"); + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29"); r0 += tail_step; r1 += tail_step; @@ -366,9 +365,9 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, } } -static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, - __fp16* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_5x5_stride2( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 3; size_t mod2_left = width & 1; @@ -384,18 +383,12 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, const __fp16* r4 = src_ptr + IW * 4; register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); - register MEGDNN_SIMD_TYPE _k4567 asm("v1") = - MEGDNN_SIMD_LOADU(filter + 4); - register MEGDNN_SIMD_TYPE _k891011 asm("v2") = - MEGDNN_SIMD_LOADU(filter + 8); - register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = - MEGDNN_SIMD_LOADU(filter + 12); - register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = - MEGDNN_SIMD_LOADU(filter + 16); - register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = - MEGDNN_SIMD_LOADU(filter + 20); - register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = - MEGDNN_SIMD_SET1(filter[24]); + register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); + register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); + register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); + register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); + register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); + register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = MEGDNN_SIMD_SET1(filter[24]); for (size_t i = 0; i < OH; i++) { asm volatile( @@ -592,15 +585,14 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, "bne 2b \n" "3: \n" - : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), - "+r"(r3), "+r"(r4) + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), + "+r"(r4) : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), - "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), - "r"(mod2_left) - : "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", - "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", 
"v25", "v26", "v27", - "v28", "v29", "v30", "v31"); + "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) + : "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); r0 += tail_step; r1 += tail_step; @@ -613,9 +605,9 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, } } -static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, - __fp16* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_7x7_stride2( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 3; @@ -632,30 +624,20 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, const __fp16* r6 = src_ptr + IW * 6; register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); - register MEGDNN_SIMD_TYPE _k4567 asm("v1") = - MEGDNN_SIMD_LOADU(filter + 4); - register MEGDNN_SIMD_TYPE _k891011 asm("v2") = - MEGDNN_SIMD_LOADU(filter + 8); - register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = - MEGDNN_SIMD_LOADU(filter + 12); - register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = - MEGDNN_SIMD_LOADU(filter + 16); - register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = - MEGDNN_SIMD_LOADU(filter + 20); - register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = - MEGDNN_SIMD_LOADU(filter + 24); - register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = - MEGDNN_SIMD_LOADU(filter + 28); - register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = - MEGDNN_SIMD_LOADU(filter + 32); - register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = - MEGDNN_SIMD_LOADU(filter + 36); + register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); + register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); + register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); + register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); + register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); + register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = MEGDNN_SIMD_LOADU(filter + 24); + register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = MEGDNN_SIMD_LOADU(filter + 28); + register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = MEGDNN_SIMD_LOADU(filter + 32); + register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = MEGDNN_SIMD_LOADU(filter + 36); register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = MEGDNN_SIMD_LOADU(filter + 40); register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = MEGDNN_SIMD_LOADU(filter + 44); - register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = - MEGDNN_SIMD_SET1(filter[48]); + register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = MEGDNN_SIMD_SET1(filter[48]); for (size_t i = 0; i < OH; i++) { asm volatile( @@ -1005,16 +987,15 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, "bne 2b \n" "3: \n" - : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), - "+r"(r4), "+r"(r5), "+r"(r6) + : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), + "+r"(r5), "+r"(r6) : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), - "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), - "w"(_k48484848) - : "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", - "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", 
"v28", "v29", "v30", "v31"); + "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) + : "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31"); r0 += tail_step; r1 += tail_step; diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.cpp b/dnn/src/aarch64/conv_bias/fp32/algos.cpp index 9517b9d7..a5d3358c 100644 --- a/dnn/src/aarch64/conv_bias/fp32/algos.cpp +++ b/dnn/src/aarch64/conv_bias/fp32/algos.cpp @@ -21,18 +21,17 @@ using namespace megdnn; using namespace aarch64; MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) -bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF32DirectStride2::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.filter_meta.format == param::ConvBias::Format::NCHW && param.src_type.enumv() == DTypeEnum::Float32 && param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && + param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); } MIDOUT_END(); @@ -50,8 +49,7 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { return get_kimpls(param); @@ -60,8 +58,7 @@ ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( return {}; } -SmallVector -ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( +SmallVector ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -70,8 +67,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( size_t OC = param.filter_meta.ocpg; size_t group = fm.group; bool large_group = group >= param.nr_threads; - using Func = std::function; + using Func = std::function; Func conv = nullptr; if (FH == 2) { conv = fp32::conv_stride2::do_conv_2x2_stride2; @@ -83,8 +81,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( conv = fp32::conv_stride2::do_conv_7x7_stride2; } - WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< - float, float>::get_bundle_stride(param, large_group); + WorkspaceBundle bundle = + arm_common::MultithreadDirectConvCommon::get_bundle_stride( + param, large_group); SmallVector ret_kerns; //! 
Dense conv and small group @@ -99,34 +98,34 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { arm_common::MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern_stride( + bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - arm_common::MultithreadDirectConvCommon< - float, float>::do_conv_kern_stride(bundle, kern_param, - ncb_index, conv, - {ncb_index.thread_id, - 0, oc}); + arm_common::MultithreadDirectConvCommon:: + do_conv_kern_stride( + bundle, kern_param, ncb_index, conv, + {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); arm_common::MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern_stride( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); - auto do_conv = [bundle, conv](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_conv = [bundle, conv]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - arm_common::MultithreadDirectConvCommon< - float, float>::do_conv_kern_stride(bundle, kern_param, - ncb_index, conv, - ncb_index.ndrange_id); + arm_common::MultithreadDirectConvCommon::do_conv_kern_stride( + bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); }; ret_kerns.push_back({do_conv, {group, N, OC}}); } diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.h b/dnn/src/aarch64/conv_bias/fp32/algos.h index da810189..3c274410 100644 --- a/dnn/src/aarch64/conv_bias/fp32/algos.h +++ b/dnn/src/aarch64/conv_bias/fp32/algos.h @@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; + public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV8F32STRD2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; diff --git a/dnn/src/aarch64/conv_bias/fp32/stride2_kern.h b/dnn/src/aarch64/conv_bias/fp32/stride2_kern.h index 3db3c197..6faf4613 100644 --- a/dnn/src/aarch64/conv_bias/fp32/stride2_kern.h +++ b/dnn/src/aarch64/conv_bias/fp32/stride2_kern.h @@ -16,16 +16,15 @@ namespace megdnn { namespace aarch64 { -namespace fp32{ +namespace fp32 { namespace conv_stride2 { - //! 
For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` // refer to function do_conv_2x2_stride2_asm_unroll4 -static void do_conv_2x2_stride2(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_2x2_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 2; size_t mod4_left = width & 3; @@ -165,10 +164,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, "5: \n" : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) : "r"(mod4_left), "w"(_k0123) - : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", - "v31"); + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v28", "v29", "v30", "v31"); r0 += tail_step; r1 += tail_step; @@ -179,9 +177,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, } // refer to function do_conv_3x3_stride2_asm_unroll3 -static void do_conv_3x3_stride2(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_3x3_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 2; size_t mod3_left = width % 3; @@ -269,7 +267,7 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, "ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 "ld2 {v5.4s, v6.4s}, [%3], #32 \n" - "ld1 {v3.4s}, [%2] \n" // load src 8 12 ... + "ld1 {v3.4s}, [%2] \n" // load src 8 12 ... 
"fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] "ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 "fmla v0.4s, v2.4s, v22.4s \n" @@ -356,10 +354,10 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, "3: \n" : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) - : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29"); + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29"); r0 += tail_step; r1 += tail_step; @@ -371,9 +369,9 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, } // refer to function do_conv_5x5_stride2_asm_unroll2 -static void do_conv_5x5_stride2(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_5x5_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 2; size_t mod2_left = width & 1; @@ -591,15 +589,13 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, "bne 2b \n" "3: \n" - : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), - "+r"(r3), "+r"(r4) + : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), + "+r"(r4) : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), - "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), - "r"(mod2_left) - : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24"); + "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24"); r0 += tail_step; r1 += tail_step; @@ -613,9 +609,9 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, } // refer to function do_conv_7x7_stride2_asm_unroll2 -static void do_conv_7x7_stride2(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, size_t OH, - size_t OW, size_t IC) { +static void do_conv_7x7_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; size_t width = OW >> 2; @@ -993,16 +989,15 @@ static void do_conv_7x7_stride2(const float* src, const float* filter, "bne 2b \n" "3: \n" - : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), - "+r"(r4), "+r"(r5), "+r"(r6) + : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), + "+r"(r5), "+r"(r6) : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), - "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), - "w"(_k48484848) - : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18"); + "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) + : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", 
"v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18"); r0 += tail_step; r1 += tail_step; diff --git a/dnn/src/aarch64/conv_bias/int8/algos.cpp b/dnn/src/aarch64/conv_bias/int8/algos.cpp index d6d0d53d..79b6b0d3 100644 --- a/dnn/src/aarch64/conv_bias/int8/algos.cpp +++ b/dnn/src/aarch64/conv_bias/int8/algos.cpp @@ -68,9 +68,9 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( size_t N = OH * OW; #if MGB_ENABLE_DOT -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ M, N, K, param.filter_type, param.src_type, param.dst_type); \ part2 = megdnn::matmul::GemmInterleaved< \ @@ -84,11 +84,12 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( DISPATCH_GEMM_BIAS(s8_4x4, 0) } #else -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ M, N, K, param.filter_type, param.src_type, param.dst_type); \ part2 = megdnn::matmul::GemmInterleaved< \ @@ -104,8 +105,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( return {nullptr, {part0, part1, part2}}; } -void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, - const NCBKernIndex& ncb_index) { +void ConvBiasImpl::AlgoS8MatrixMul::kimpl( + const NCBKernParam& param, const NCBKernIndex& ncb_index) { auto is_xcorr = !param.filter_meta.should_flip; UNPACK_CONV_NCB_KERN_SIZES(param); auto bundle = get_bundle(param); @@ -157,29 +158,28 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); } else { if (is_xcorr) - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); else - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); } } { - Workspace workspace(static_cast(bundle.get(2)), - bundle.get_size(2)); + Workspace workspace( + static_cast(bundle.get(2)), bundle.get_size(2)); size_t M = OC; size_t K = IC * FH * FW; size_t N = OH * OW; #if MGB_ENABLE_DOT -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline> \ - gemm_interleaved(M, N, K, false, false, strategy); \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved \ + gemm_interleaved(M, N, K, false, false, strategy); \ 
gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); if (cpuinfo_has_arm_neon_dot()) { @@ -188,19 +188,18 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, DISPATCH_GEMM_BIAS(s8_4x4, 0) } #else -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline> \ - gemm_interleaved(M, N, K, false, false, strategy); \ - gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ - bias); \ - } \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ + } \ MIDOUT_END() DISPATCH_GEMM_BIAS(s8_4x4, 0) #endif diff --git a/dnn/src/aarch64/conv_bias/int8/algos.h b/dnn/src/aarch64/conv_bias/int8/algos.h index c74f856c..b5e2ba05 100644 --- a/dnn/src/aarch64/conv_bias/int8/algos.h +++ b/dnn/src/aarch64/conv_bias/int8/algos.h @@ -12,8 +12,8 @@ #pragma once #include "src/aarch64/conv_bias/opr_impl.h" -#include "src/fallback/conv_bias/opr_impl.h" #include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/opr_impl.h" namespace megdnn { namespace aarch64 { @@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8MATMUL"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override { return get_bundle(param).total_size_in_bytes(); } - SmallVector dispatch_kerns( - const NCBKernSizeParam& param) const override { + SmallVector dispatch_kerns(const NCBKernSizeParam& param) const override { size_t group = param.filter_meta.group; return {{kimpl, {group, 1_z, 1_z}}}; } diff --git a/dnn/src/aarch64/conv_bias/int8/strategy.cpp b/dnn/src/aarch64/conv_bias/int8/strategy.cpp index 6506912f..8d80df1e 100644 --- a/dnn/src/aarch64/conv_bias/int8/strategy.cpp +++ b/dnn/src/aarch64/conv_bias/int8/strategy.cpp @@ -29,9 +29,10 @@ struct KernCaller; #if MGB_ENABLE_DOT template struct KernCaller { - static void run(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, - Op op, const dt_int32* bias, dt_int32* workspace) { + static void run( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, + 
dt_int32* workspace) { megdnn_assert(is_first_k); constexpr size_t A_INTERLEAVE = 8; @@ -49,19 +50,19 @@ struct KernCaller { size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12, - is_first_k); + matmul_8x12x4::kern_8x12( + packA, cur_packB, K, workspace, 12, is_first_k); - arm_common::ConvBiasMatmul::postprocess(bias, workspace, - output, LDC, op); + arm_common::ConvBiasMatmul:: + postprocess(bias, workspace, output, LDC, op); output += B_INTERLEAVE; cur_packB += K12; } for (; n < N; n += 4) { - matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4, - is_first_k, std::min(N - n, 4)); + matmul_8x12x4::kern_8x4( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(N - n, 4)); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ @@ -83,9 +84,9 @@ struct KernCaller { const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12, - is_first_k, - std::min(M - m, 4)); + matmul_8x12x4::kern_4x12( + packA, cur_packB, K, workspace, 12, is_first_k, + std::min(M - m, 4)); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); @@ -97,14 +98,13 @@ struct KernCaller { } for (; n < N; n += 4) { - matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4, - is_first_k, std::min(M - m, 4), - std::min(N - n, 4)); + matmul_8x12x4::kern_4x4( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); - DISPATCH_M(cb, std::min(M - m, 4), - std::min(N - n, 4)); + DISPATCH_M(cb, std::min(M - m, 4), std::min(N - n, 4)); #undef cb output += 4; @@ -122,9 +122,10 @@ struct KernCaller { template struct KernCaller { - static void run(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, - Op op, const dt_int32* bias, dt_int32* workspace) { + static void run( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace) { megdnn_assert(is_first_k); constexpr size_t A_INTERLEAVE = 4; @@ -140,20 +141,18 @@ struct KernCaller { size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, - is_first_k); - arm_common::ConvBiasMatmul::postprocess(bias, workspace, - output, LDC, op); + matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, is_first_k); + arm_common::ConvBiasMatmul::postprocess( + bias, workspace, output, LDC, op); output += B_INTERLEAVE; cur_packB += K4; } for (; n < N; n += B_INTERLEAVE) { - matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace, - 4, is_first_k, 4, - std::min(N - n, 4)); + matmul_4x4x16::kern_4x4_remain( + packA, cur_packB, K, workspace, 4, is_first_k, 4, + std::min(N - n, 4)); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); @@ -182,8 +181,7 @@ struct KernCaller { #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); - DISPATCH_M(cb, std::min(M - m, 4), - std::min(N - n, 4)); + DISPATCH_M(cb, std::min(M - m, 4), std::min(N - n, 4)); #undef cb output += B_INTERLEAVE; cur_packB += K4; @@ -200,21 +198,19 @@ struct KernCaller { 
MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) -void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax, bool transpose) const { +void gemm_s8_4x4_nobias_identity::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { if (transpose) { - matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool transpose) const { +void gemm_s8_4x4_nobias_identity::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); } else { @@ -229,23 +225,21 @@ size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { #if MGB_ENABLE_DOT MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) -void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax, bool transpose) const { +void gemm_s8_8x12_nobias_identity::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); if (transpose) { - matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool transpose) const { +void gemm_s8_8x12_nobias_identity::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); } else { @@ -259,18 +253,17 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { #endif -#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ - void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ - const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ - size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ - const dt_int32* bias, dt_int32* workspace) const { \ - float scale_A = A_dtype.param().scale; \ - float scale_B = B_dtype.param().scale; \ - float scale_C = C_dtype.param().scale; \ - DEFINE_OP(_OP); \ - impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ - packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ - workspace); \ +#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ + void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \ + dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ + dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + float scale_B = 
B_dtype.param().scale; \ + float scale_C = C_dtype.param().scale; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \ } #define DEFINE_OP(_Op) \ @@ -286,18 +279,16 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) #endif #undef DEFINE_OP -#define DEFINE_OP(_Op) \ - arm_common::_Op op(scale_A* scale_B, \ - scale_A* scale_B, scale_C); +#define DEFINE_OP(_Op) \ + arm_common::_Op op( \ + scale_A* scale_B, scale_A* scale_B, scale_C); KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) -KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, - FuseAddHSwishOp) +KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) #if MGB_ENABLE_DOT KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) -KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, - FuseAddHSwishOp) +KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) #endif #undef DEFINE_OP diff --git a/dnn/src/aarch64/conv_bias/int8/strategy.h b/dnn/src/aarch64/conv_bias/int8/strategy.h index d4d0224d..d2d13d86 100644 --- a/dnn/src/aarch64/conv_bias/int8/strategy.h +++ b/dnn/src/aarch64/conv_bias/int8/strategy.h @@ -20,43 +20,42 @@ namespace matmul { * * \name gemm___biasmode_nolinemode */ -MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16, - false, true, - gemm_s8_4x4_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( + dt_int8, dt_int8, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu, - gemm_s8_4x4_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x4_nobias_relu, gemm_s8_4x4_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish, - gemm_s8_4x4_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x4_nobias_hswish, gemm_s8_4x4_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity, - gemm_s8_4x4_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x4_bias_channel_identity, gemm_s8_4x4_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, - gemm_s8_4x4_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x4_bias_channel_relu, gemm_s8_4x4_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, - gemm_s8_4x4_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity); #if MGB_ENABLE_DOT -MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, - false, true, - gemm_s8_8x12_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( + dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true, + gemm_s8_8x12_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu, - gemm_s8_8x12_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_8x12_nobias_relu, gemm_s8_8x12_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish, - gemm_s8_8x12_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_8x12_nobias_hswish, gemm_s8_8x12_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity, - 
gemm_s8_8x12_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_8x12_bias_channel_identity, gemm_s8_8x12_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, - gemm_s8_8x12_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_8x12_bias_channel_relu, gemm_s8_8x12_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, - gemm_s8_8x12_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity); #endif } // namespace matmul diff --git a/dnn/src/aarch64/conv_bias/opr_impl.cpp b/dnn/src/aarch64/conv_bias/opr_impl.cpp index 53da143f..fcc1b5f7 100644 --- a/dnn/src/aarch64/conv_bias/opr_impl.cpp +++ b/dnn/src/aarch64/conv_bias/opr_impl.cpp @@ -13,13 +13,13 @@ #include "src/aarch64/conv_bias/int8/algos.h" #include "src/aarch64/conv_bias/quint8/algos.h" -#include "src/naive/handle.h" -#include "src/common/utils.h" #include "src/common/metahelper.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" -#include "src/fallback/convolution/opr_impl.h" -#include "src/aarch64/conv_bias/fp32/algos.h" #include "src/aarch64/conv_bias/fp16/algos.h" +#include "src/aarch64/conv_bias/fp32/algos.h" +#include "src/fallback/convolution/opr_impl.h" using namespace megdnn; using namespace aarch64; @@ -56,12 +56,10 @@ public: const SmallVector& direct_algos() const { return m_direct_algos; } - const SmallVector& matmul_algos() - const { + const SmallVector& matmul_algos() const { return m_matmul_algos; } const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } - }; const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { @@ -71,15 +69,16 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) -SmallVector -ConvBiasImpl::get_all_packed_algo() { +SmallVector ConvBiasImpl::get_all_packed_algo() { auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); - algos.insert(algos.begin(), algo_pack().direct_algos().begin(), - algo_pack().direct_algos().end()); + algos.insert( + algos.begin(), algo_pack().direct_algos().begin(), + algo_pack().direct_algos().end()); //! We put matmul algos at the begin. Because matmul will get privilege when //! prefer return true. See - algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), - algo_pack().matmul_algos().end()); + algos.insert( + algos.begin(), algo_pack().matmul_algos().begin(), + algo_pack().matmul_algos().end()); return std::move(algos); } diff --git a/dnn/src/aarch64/conv_bias/opr_impl.h b/dnn/src/aarch64/conv_bias/opr_impl.h index f47c9f62..cfabeb6e 100644 --- a/dnn/src/aarch64/conv_bias/opr_impl.h +++ b/dnn/src/aarch64/conv_bias/opr_impl.h @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #pragma once -#include "src/common/utils.h" #include "src/arm_common/conv_bias/opr_impl.h" +#include "src/common/utils.h" namespace megdnn { namespace aarch64 { diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.cpp b/dnn/src/aarch64/conv_bias/quint8/algos.cpp index 713f516a..11596fdc 100644 --- a/dnn/src/aarch64/conv_bias/quint8/algos.cpp +++ b/dnn/src/aarch64/conv_bias/quint8/algos.cpp @@ -70,9 +70,9 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( size_t N = OH * OW; #if MGB_ENABLE_DOT -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ M, N, K, param.filter_type, param.src_type, param.dst_type); \ part2 = megdnn::matmul::GemmInterleaved< \ @@ -86,11 +86,12 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); } #else -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ M, N, K, param.filter_type, param.src_type, param.dst_type); \ part2 = megdnn::matmul::GemmInterleaved< \ @@ -106,8 +107,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( return {nullptr, {part0, part1, part2}}; } -void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, - const NCBKernIndex& ncb_index) { +void ConvBiasImpl::AlgoQU8MatrixMul::kimpl( + const NCBKernParam& param, const NCBKernIndex& ncb_index) { auto is_xcorr = !param.filter_meta.should_flip; UNPACK_CONV_NCB_KERN_SIZES(param); auto bundle = get_bundle(param); @@ -160,29 +161,28 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); } else { if (is_xcorr) - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); else - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); } } { - Workspace workspace(static_cast(bundle.get(2)), - bundle.get_size(2)); + Workspace workspace( + static_cast(bundle.get(2)), bundle.get_size(2)); size_t M = OC; size_t K = IC * FH * FW; size_t N = OH * OW; #if MGB_ENABLE_DOT -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline> \ - gemm_interleaved(M, N, K, false, false, strategy); \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved \ + 
gemm_interleaved(M, N, K, false, false, strategy); \ gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); if (cpuinfo_has_arm_neon_dot()) { @@ -191,19 +191,18 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) } #else -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline> \ - gemm_interleaved(M, N, K, false, false, strategy); \ - gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ - bias); \ - } \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ + _bias_midout_enum, _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ + } \ MIDOUT_END() DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.h b/dnn/src/aarch64/conv_bias/quint8/algos.h index 5c36f75a..bfccbeac 100644 --- a/dnn/src/aarch64/conv_bias/quint8/algos.h +++ b/dnn/src/aarch64/conv_bias/quint8/algos.h @@ -12,8 +12,8 @@ #pragma once #include "src/aarch64/conv_bias/opr_impl.h" -#include "src/fallback/conv_bias/opr_impl.h" #include "src/common/opr_delegate.h" +#include "src/fallback/conv_bias/opr_impl.h" namespace megdnn { namespace aarch64 { @@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { static void kimpl(const NCBKernParam& param, const NCBKernIndex&); public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "QU8MATMUL"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override { return get_bundle(param).total_size_in_bytes(); } - SmallVector dispatch_kerns( - const NCBKernSizeParam& param) const override { + SmallVector dispatch_kerns(const NCBKernSizeParam& param) const override { size_t group = param.filter_meta.group; return {{kimpl, {group, 1_z, 1_z}}}; } diff --git a/dnn/src/aarch64/conv_bias/quint8/strategy.cpp b/dnn/src/aarch64/conv_bias/quint8/strategy.cpp index dda6badf..da9e741f 100644 --- a/dnn/src/aarch64/conv_bias/quint8/strategy.cpp +++ b/dnn/src/aarch64/conv_bias/quint8/strategy.cpp @@ -14,8 +14,8 @@ #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" -#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" +#include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" #include "src/arm_common/conv_bias/matmul_postprocess.h" using namespace megdnn; @@ -29,10 +29,10 @@ 
struct KernCaller; #if MGB_ENABLE_DOT template struct KernCaller { - static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, - size_t N, size_t K, dt_uint8* C, size_t LDC, - bool is_first_k, Op op, const dt_int32* bias, - dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { + static void run( + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, + dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { megdnn_assert(is_first_k); constexpr size_t A_INTERLEAVE = 8; constexpr size_t B_INTERLEAVE = 8; @@ -50,20 +50,19 @@ struct KernCaller { size_t n = 0; const dt_uint8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8, - is_first_k, zp_A, zp_B, zAB); + matmul_8x8x4::kern_8x8( + packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B, zAB); - arm_common::ConvBiasMatmul::postprocess(bias, workspace, - output, LDC, op); + arm_common::ConvBiasMatmul:: + postprocess(bias, workspace, output, LDC, op); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4, - is_first_k, std::min(N - n, 4), - zp_A, zp_B, zAB); + matmul_8x8x4::kern_8x4( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(N - n, 4), zp_A, zp_B, zAB); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); @@ -84,9 +83,9 @@ struct KernCaller { const dt_uint8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8, - is_first_k, std::min(M - m, 4), - zp_A, zp_B, zAB); + matmul_8x8x4::kern_4x8( + packA, cur_packB, K, workspace, 8, is_first_k, + std::min(M - m, 4), zp_A, zp_B, zAB); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); @@ -98,15 +97,14 @@ struct KernCaller { } for (; n < N; n += 4) { - matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4, - is_first_k, std::min(M - m, 4), - std::min(N - n, 4), zp_A, zp_B, - zAB); + matmul_8x8x4::kern_4x4( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(M - m, 4), std::min(N - n, 4), zp_A, + zp_B, zAB); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); - DISPATCH_M(cb, std::min(M - m, 4), - std::min(N - n, 4)); + DISPATCH_M(cb, std::min(M - m, 4), std::min(N - n, 4)); #undef cb output += 4; @@ -124,10 +122,10 @@ struct KernCaller { template struct KernCaller { - static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, - size_t N, size_t K, dt_uint8* C, size_t LDC, - bool is_first_k, Op op, const dt_int32* bias, - dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { + static void run( + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, + dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { megdnn_assert(is_first_k); constexpr size_t A_INTERLEAVE = 8; @@ -144,27 +142,25 @@ struct KernCaller { size_t n = 0; const dt_uint8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8, - is_first_k, zp_A, zp_B); + matmul_8x8x8::kern_8x8( + packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B); - arm_common::ConvBiasMatmul::postprocess(bias, workspace, - output, LDC, op); + arm_common::ConvBiasMatmul:: + 
postprocess(bias, workspace, output, LDC, op); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4, - is_first_k, std::min(N - n, 4), - zp_A, zp_B); + matmul_8x8x8::kern_8x4( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(N - n, 4), zp_A, zp_B); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); DISPATCH_N(cb, 8, std::min(N - n, 4)); #undef cb - output += 4; cur_packB += K4; } @@ -179,9 +175,9 @@ struct KernCaller { const dt_uint8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, - is_first_k, std::min(M - m, 4), - zp_A, zp_B); + matmul_8x8x8::kern_4x8( + packA, cur_packB, K, workspace, 8, is_first_k, + std::min(M - m, 4), zp_A, zp_B); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); @@ -193,17 +189,16 @@ struct KernCaller { } for (; n < N; n += 4) { - matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, - is_first_k, std::min(M - m, 4), - std::min(N - n, 4), zp_A, zp_B); + matmul_8x8x8::kern_4x4( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(M - m, 4), std::min(N - n, 4), zp_A, + zp_B); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); - DISPATCH_M(cb, std::min(M - m, 4), - std::min(N - n, 4)); + DISPATCH_M(cb, std::min(M - m, 4), std::min(N - n, 4)); #undef cb - output += 4; cur_packB += K4; } @@ -219,27 +214,27 @@ struct KernCaller { #if MGB_ENABLE_DOT MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) -void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, - int ldin, int y0, int ymax, int k0, - int kmax, bool transpose) const { +void gemm_u8_8x8_dot_nobias_identity::pack_A( + uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { if (transpose) { - matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, - ymax, k0, kmax); + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( + outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, - y0, ymax, k0, kmax); + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( + outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool transpose) const { +void gemm_u8_8x8_dot_nobias_identity::pack_B( + uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, - xmax, k0, kmax); + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( + out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, - k0, kmax); + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( + out, in, ldin, x0, xmax, k0, kmax); } } @@ -249,30 +244,27 @@ size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const { #endif MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) -void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr, - const dt_uint8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_u8_8x8_nodot_nobias_identity::pack_A( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, 
int k0, + int kmax, bool transpose) const { uint8_t zA = A_dtype.param().zero_point; if (transpose) { - matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, - ymax, k0, kmax, zA); + matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n( + outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } else { - matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax, zA); + matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } } -void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool transpose) const { +void gemm_u8_8x8_nodot_nobias_identity::pack_B( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { uint8_t zB = B_dtype.param().zero_point; if (transpose) { - matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, - k0, kmax, zB); + matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax, zB); } else { - matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, - zB); + matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); } } @@ -280,22 +272,21 @@ size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const { return 8 * 8 * sizeof(dt_int32); } -#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \ - _OP) \ - void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \ - kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \ - size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ - const dt_int32* bias, dt_int32* workspace) const { \ - float scale_A = A_dtype.param().scale; \ - uint8_t zp_A = A_dtype.param().zero_point; \ - float scale_B = B_dtype.param().scale; \ - uint8_t zp_B = B_dtype.param().zero_point; \ - float scale_C = C_dtype.param().scale; \ - uint8_t zp_C = C_dtype.param().zero_point; \ - DEFINE_OP(_OP); \ - impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ - packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ - workspace, zp_A, zp_B); \ +#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, _OP) \ + void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::kern( \ + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ + size_t K, dt_uint8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ + dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + uint8_t zp_A = A_dtype.param().zero_point; \ + float scale_B = B_dtype.param().scale; \ + uint8_t zp_B = B_dtype.param().zero_point; \ + float scale_C = C_dtype.param().scale; \ + uint8_t zp_C = C_dtype.param().zero_point; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \ + zp_B); \ } #define DEFINE_OP(_Op) \ @@ -311,17 +302,22 @@ KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp) KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) #undef DEFINE_OP -#define DEFINE_OP(_Op) \ - arm_common::_Op op(scale_A* scale_B, \ - scale_A* scale_B, scale_C, zp_C); +#define DEFINE_OP(_Op) \ + arm_common::_Op op( \ + scale_A* scale_B, scale_A* scale_B, scale_C, zp_C); #if MGB_ENABLE_DOT KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) -KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) -KERN(8, 8, true, _dot, 
bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) +KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, + FuseAddReluOp) +KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, + FuseAddHSwishOp) #endif -KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) -KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) -KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) +KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, + AddOp) +KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, + FuseAddReluOp) +KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, + FuseAddHSwishOp) #undef DEFINE_OP #undef KERN diff --git a/dnn/src/aarch64/conv_bias/quint8/strategy.h b/dnn/src/aarch64/conv_bias/quint8/strategy.h index 4562c84b..31050c25 100644 --- a/dnn/src/aarch64/conv_bias/quint8/strategy.h +++ b/dnn/src/aarch64/conv_bias/quint8/strategy.h @@ -16,46 +16,44 @@ namespace aarch64 { namespace matmul { #if MGB_ENABLE_DOT -MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, - false, true, - gemm_u8_8x8_dot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( + dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true, + gemm_u8_8x8_dot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu, - gemm_u8_8x8_dot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_dot_nobias_relu, gemm_u8_8x8_dot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish, - gemm_u8_8x8_dot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_dot_nobias_hswish, gemm_u8_8x8_dot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity, - gemm_u8_8x8_dot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_dot_bias_channel_identity, gemm_u8_8x8_dot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu, - gemm_u8_8x8_dot_nobias_identity); - -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish, - gemm_u8_8x8_dot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_dot_bias_channel_relu, gemm_u8_8x8_dot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_dot_bias_channel_hswish, gemm_u8_8x8_dot_nobias_identity); #endif -MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, - false, true, - gemm_u8_8x8_nodot_nobias_identity); - -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu, - gemm_u8_8x8_nodot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( + dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true, + gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish, - gemm_u8_8x8_nodot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_nodot_nobias_relu, gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity, - gemm_u8_8x8_nodot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_nodot_nobias_hswish, gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu, - gemm_u8_8x8_nodot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_nodot_bias_channel_identity, 
gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish, - gemm_u8_8x8_nodot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_nodot_bias_channel_relu, gemm_u8_8x8_nodot_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_8x8_nodot_bias_channel_hswish, gemm_u8_8x8_nodot_nobias_identity); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/handle.cpp b/dnn/src/aarch64/handle.cpp index b84e3001..cec1fc4b 100644 --- a/dnn/src/aarch64/handle.cpp +++ b/dnn/src/aarch64/handle.cpp @@ -11,11 +11,11 @@ #include "src/common/handle_impl.h" +#include "src/aarch64/conv_bias/opr_impl.h" #include "src/aarch64/handle.h" #include "src/aarch64/matrix_mul/opr_impl.h" -#include "src/aarch64/rotate/opr_impl.h" #include "src/aarch64/relayout/opr_impl.h" -#include "src/aarch64/conv_bias/opr_impl.h" +#include "src/aarch64/rotate/opr_impl.h" #include "src/aarch64/warp_perspective/opr_impl.h" namespace megdnn { @@ -38,7 +38,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) #pragma GCC diagnostic pop -} // namespace aarch64 -} // namespace megdnn +} // namespace aarch64 +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/handle.h b/dnn/src/aarch64/handle.h index 9e97a108..7db99c63 100644 --- a/dnn/src/aarch64/handle.h +++ b/dnn/src/aarch64/handle.h @@ -14,20 +14,18 @@ namespace megdnn { namespace aarch64 { -class HandleImpl: public arm_common::HandleImpl { - public: - HandleImpl(megcoreComputingHandle_t computing_handle, - HandleType type = HandleType::AARCH64): - arm_common::HandleImpl::HandleImpl(computing_handle, type) - {} +class HandleImpl : public arm_common::HandleImpl { +public: + HandleImpl( + megcoreComputingHandle_t computing_handle, + HandleType type = HandleType::AARCH64) + : arm_common::HandleImpl::HandleImpl(computing_handle, type) {} - template - std::unique_ptr create_operator(); + template + std::unique_ptr create_operator(); }; -} // namespace aarch64 -} // namespace megdnn +} // namespace aarch64 +} // namespace megdnn // vim: syntax=cpp.doxygen - - diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index 13104c8c..4e4d3c53 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -14,10 +14,10 @@ #include "src/aarch64/matrix_mul/fp16/strategy.h" #include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/aarch64/matrix_mul/int16/strategy.h" +#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" #include "src/aarch64/matrix_mul/int8/strategy.h" #include "src/aarch64/matrix_mul/int8_dot/strategy.h" #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" -#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" #include "src/aarch64/matrix_mul/quint8/strategy.h" #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" #include "src/aarch64/matrix_mul/quint8_dot/strategy.h" @@ -32,8 +32,7 @@ using namespace megdnn; using namespace aarch64; /* ===================== F32K8X12X1 algo ===================== */ -bool MatrixMulImpl::AlgoF32K8x12x1::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF32K8x12x1::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.B_type == kern_size_param.A_type && kern_size_param.C_type == kern_size_param.A_type && @@ -43,10 +42,10 @@ bool MatrixMulImpl::AlgoF32K8x12x1::usable( size_t 
MatrixMulImpl::AlgoF32K8x12x1::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -62,33 +61,29 @@ size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace( MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( const KernSizeParam&) const { auto f32_kern_8x12 = [](const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32K8x12x1::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K8x12x1::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::sgemm_8x12 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; return f32_kern_8x12; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, - "AlgoF32K8x12x1Impl"_hash, - aarch64::matmul::sgemm_8x12, float, float, - AlgoDataType::FLOAT32, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, "AlgoF32K8x12x1Impl"_hash, + aarch64::matmul::sgemm_8x12, float, float, AlgoDataType::FLOAT32, DEFAULT); /* ===================== F32_MK4_8X12X1 algo ===================== */ bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable( @@ -98,21 +93,20 @@ bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable( kern_size_param.C_type == kern_size_param.A_type && kern_size_param.A_type == dtype::Float32() && kern_size_param.format == param::MatrixMul::Format::MK4 && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 && + kern_size_param.K % 4 == 0; } size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_8x12x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, - 
C_type); + aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .get_workspace_size(); @@ -124,38 +118,32 @@ size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace( MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern( const KernSizeParam&) const { auto f32_kern_mk4_8x12 = [](const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_8x12x1::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::sgemm_mk4_8x12 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; return f32_kern_mk4_8x12; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1, - megdnn_aarch64_matmul_kern, - "AlgoF32MK4_8x12x1Impl"_hash, - aarch64::matmul::sgemm_mk4_8x12, float, - float, AlgoDataType::FLOAT32, MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF32MK4_8x12x1, megdnn_aarch64_matmul_kern, "AlgoF32MK4_8x12x1Impl"_hash, + aarch64::matmul::sgemm_mk4_8x12, float, float, AlgoDataType::FLOAT32, MK4); /* ===================== F32K4X16X1 algo ===================== */ -bool MatrixMulImpl::AlgoF32K4x16x1::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF32K4x16x1::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.B_type == kern_size_param.A_type && kern_size_param.C_type == kern_size_param.A_type && @@ -165,10 +153,10 @@ bool MatrixMulImpl::AlgoF32K4x16x1::usable( size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32K4x16x1::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K4x16x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -184,33 +172,29 @@ size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace( MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( const KernSizeParam&) const { auto f32_kern_4x16 = [](const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32K4x16x1::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32K4x16x1::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto 
LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::sgemm_4x16 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; return f32_kern_4x16; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, - "AlgoF32K4x16x1Impl"_hash, - aarch64::matmul::sgemm_4x16, float, float, - AlgoDataType::FLOAT32, MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, "AlgoF32K4x16x1Impl"_hash, + aarch64::matmul::sgemm_4x16, float, float, AlgoDataType::FLOAT32, MK4); /* ===================== F32MK4_4x16 algo ===================== */ @@ -226,17 +210,17 @@ bool MatrixMulImpl::AlgoF32MK4_4x16::usable( size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32MK4_4x16::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved< - aarch64::matmul::sgemm_nopack_4x16, false>(M, N, K, trA, - trB, strategy) + aarch64::matmul::sgemm_nopack_4x16, false>( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -246,23 +230,21 @@ size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace( MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern( const KernSizeParam&) const { auto f32_kern_mk4_4x16 = [](const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF32MK4_4x16::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF32MK4_4x16::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); aarch64::matmul::sgemm_nopack_4x16 strategy(A_type, B_type, C_type); - megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; @@ -279,22 +261,19 @@ void f16_kern(const MatrixMulImpl::KernParam& 
kern_param) { auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); aarch64::matmul::hgemm_8x24 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } } // anonymous namespace -bool MatrixMulImpl::AlgoF16K8x24x1::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF16K8x24x1::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::DEFAULT && kern_size_param.C_type == kern_size_param.A_type && @@ -304,10 +283,10 @@ bool MatrixMulImpl::AlgoF16K8x24x1::usable( size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF16K8x24x1::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF16K8x24x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -325,15 +304,13 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( return f16_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, - "AlogF16K8x24x1Impl"_hash, - aarch64::matmul::hgemm_8x24, dt_float16, - dt_float16, AlgoDataType::FLOAT16, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, "AlogF16K8x24x1Impl"_hash, + aarch64::matmul::hgemm_8x24, dt_float16, dt_float16, AlgoDataType::FLOAT16, + DEFAULT); /* ===================== F16_MK8_8x8 algo ===================== */ -bool MatrixMulImpl::AlgoF16MK8_8x8::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF16MK8_8x8::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.C_type == kern_size_param.A_type && kern_size_param.B_type == kern_size_param.A_type && @@ -344,10 +321,10 @@ bool MatrixMulImpl::AlgoF16MK8_8x8::usable( size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF16MK8_8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF16MK8_8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -364,25 +341,23 @@ size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace( MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( const KernSizeParam&) const { auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) { - 
MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoF16MK8_8x8::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoF16MK8_8x8::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, - C_type); + aarch64::matmul::gemm_nopack_f16_8x8 strategy(A_type, B_type, C_type); megdnn::matmul::GemmInterleaved< - aarch64::matmul::gemm_nopack_f16_8x8, false>(M, N, K, trA, - trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + aarch64::matmul::gemm_nopack_f16_8x8, false>( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; @@ -394,24 +369,22 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( #if MGB_ENABLE_DOT /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */ namespace { -void int8x8x32_k8x12x4_dotprod_kern( - const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x32_k8x12x4_dotprod_kern"_hash)) { +void int8x8x32_k8x12x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("int8x8x32_k8x12x4_dotprod_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); aarch64::matmul::gemm_s8_8x12 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -419,7 +392,7 @@ void int8x8x32_k8x12x4_dotprod_kern( bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable( const KernSizeParam& kern_size_param) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } return can_be_treated_as_int8x8x32(kern_size_param); @@ -427,10 +400,10 @@ bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable( size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x32K8x12x4DotProd::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32K8x12x4DotProd::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -449,34 +422,29 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern( return 
int8x8x32_k8x12x4_dotprod_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, - aarch64::matmul::gemm_s8_8x12, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32K8x12x4DotProd, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, aarch64::matmul::gemm_s8_8x12, int8_t, + int32_t, AlgoDataType::QINT8X8X32, DEFAULT); /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ namespace { -void int8x8x32_mk4_8x12x4_dotprod_kern( - const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) { +void int8x8x32_mk4_8x12x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("int8x8x32_mk4_8x12x4_dotprod_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -484,8 +452,7 @@ void int8x8x32_mk4_8x12x4_dotprod_kern( bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable( const KernSizeParam& kern_size_param) const { - - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } @@ -504,17 +471,14 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace( MIDOUT_BEGIN( megdnn_aarch64_matmul_kern, midout_iv("AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, - C_type); - return megdnn::matmul::GemmInterleaved< - aarch64::matmul::gemm_mk4_s8_8x12>(M, N, K, trA, trB, - strategy) + aarch64::matmul::gemm_mk4_s8_8x12 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -526,48 +490,42 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern( return int8x8x32_mk4_8x12x4_dotprod_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, - aarch64::matmul::gemm_mk4_s8_8x12, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - MK4_DOT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32MK4_8x12x4DotProd, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, aarch64::matmul::gemm_mk4_s8_8x12, + int8_t, int32_t, AlgoDataType::QINT8X8X32, MK4_DOT); #endif /* 
===================== Int8x8x32 MK4 4x4x16 algo ===================== */ namespace { void int8x8x32_mk4_4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x32_mk4_4x4x16_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, midout_iv("int8x8x32_mk4_4x4x16_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } } // anonymous namespace -bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::usable( - const KernSizeParam& param) const { +bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::usable(const KernSizeParam& param) const { return param.A_type.enumv() == param.B_type.enumv() && (param.A_type.enumv() == DTypeEnum::Int8 || param.A_type.enumv() == DTypeEnum::QuantizedS8) && (param.C_type.enumv() == DTypeEnum::Int32 || param.C_type.enumv() == DTypeEnum::QuantizedS32) && param.compute_mode == Param::ComputeMode::DEFAULT && - param.format == param::MatrixMul::Format::MK4 && !param.trA && - !param.trB && param.M % 4 == 0 && param.K % 4 == 0; + param.format == param::MatrixMul::Format::MK4 && !param.trA && !param.trB && + param.M % 4 == 0 && param.K % 4 == 0; } bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::preferred( @@ -577,18 +535,16 @@ bool MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::preferred( size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x32MK4_4x4x16::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32MK4_4x4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, - C_type); - return megdnn::matmul::GemmInterleaved< - aarch64::matmul::gemm_mk4_s8_4x4>(M, N, K, trA, trB, - strategy) + aarch64::matmul::gemm_mk4_s8_4x4 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -600,32 +556,27 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern( return int8x8x32_mk4_4x4x16_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x32MK4_4x4x16Impl"_hash, - aarch64::matmul::gemm_mk4_s8_4x4, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32MK4_4x4x16, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32MK4_4x4x16Impl"_hash, 
aarch64::matmul::gemm_mk4_s8_4x4, int8_t, + int32_t, AlgoDataType::QINT8X8X32, MK4); /* ===================== Int8x8x32 K4x4x16 algo ===================== */ namespace { void int8x8x32_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x32_k4x4x16_kern"_hash)) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x32_k4x4x16_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); aarch64::matmul::gemm_s8_4x4 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -643,10 +594,10 @@ bool MatrixMulImpl::AlgoInt8x8x32K4x4x16::preferred( size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x32K4x4x16::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32K4x4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -664,31 +615,26 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern( return int8x8x32_k4x4x16_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x32K4x4x16Impl"_hash, - aarch64::matmul::gemm_s8_4x4, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32K4x4x16, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x32K4x4x16Impl"_hash, aarch64::matmul::gemm_s8_4x4, int8_t, int32_t, + AlgoDataType::QINT8X8X32, DEFAULT); /* ===================== Int8x8x32 K8x8x8 algo ===================== */ namespace { void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x32_k8x8x8_kern"_hash)) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x32_k8x8x8_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); aarch64::matmul::gemm_s8_8x8 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -706,10 +652,10 @@ bool MatrixMulImpl::AlgoInt8x8x32K8x8x8::preferred( size_t 
MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x32K8x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x32K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -726,33 +672,27 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern( const KernSizeParam&) const { return int8x8x32_k8x8x8_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x32K8x8x8Impl"_hash, - aarch64::matmul::gemm_s8_8x8, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x32K8x8x8Impl"_hash, + aarch64::matmul::gemm_s8_8x8, int8_t, int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); /* ===================== Int8x8x16 K8x8x8 algo ===================== */ namespace { void int8x8x16_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x16_k8x8x8_kern"_hash)) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_k8x8x8_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -772,15 +712,14 @@ bool MatrixMulImpl::AlgoInt8x8x16K8x8x8::preferred( size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x16K8x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .get_workspace_size(); @@ -794,31 +733,26 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern( return int8x8x16_k8x8x8_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x16K8x8x8Impl"_hash, - 
aarch64::matmul::gemm_s8x8x16_8x8, int8_t, - int16_t, AlgoDataType::INT8X8X16, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x16K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x16K8x8x8Impl"_hash, + aarch64::matmul::gemm_s8x8x16_8x8, int8_t, int16_t, AlgoDataType::INT8X8X16, + DEFAULT); /* ===================== Int8x8x16 K4x4x16 algo ===================== */ namespace { void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x16_k4x4x16_kern"_hash)) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_k4x4x16_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -839,15 +773,14 @@ bool MatrixMulImpl::AlgoInt8x8x16K4x4x16::preferred( size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x16K4x4x16::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16K4x4x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .get_workspace_size(); @@ -861,33 +794,29 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern( return int8x8x16_k4x4x16_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x16K4x4x16Impl"_hash, - aarch64::matmul::gemm_s8x8x16_4x4, int8_t, - int16_t, AlgoDataType::INT8X8X16, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x16K4x4x16, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16K4x4x16Impl"_hash, aarch64::matmul::gemm_s8x8x16_4x4, int8_t, + int16_t, AlgoDataType::INT8X8X16, DEFAULT); /* ===================== Int8x8x16 K16x12x4 algo ===================== */ namespace { void int8x8x16_mk4_16x12x4_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_mk4_16x12x4_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = 
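// The matching get_workspace() builds the same strategy/GemmInterleaved pair, but
// only asks it how large the packing buffer must be; the kernel later receives
// exactly that buffer through kern_param.workspace_ptr. A minimal sketch with the
// s8x8x16 4x4 strategy from the hunk above ("example_get_workspace" is a
// placeholder free function, not part of this patch):
size_t example_get_workspace(const MatrixMulImpl::KernSizeParam& p) {
    auto M = p.M, N = p.N, K = p.K;
    aarch64::matmul::gemm_s8x8x16_4x4 strategy(M, N, K, p.A_type, p.B_type, p.C_type);
    return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_s8x8x16_4x4>(
                   M, N, K, p.trA, p.trB, strategy)
            .get_workspace_size();
}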
kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, - B_type, C_type); - megdnn::matmul::GemmInterleaved< - aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, - strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy( + M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -898,12 +827,11 @@ bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::usable( return can_be_treated_as_int8x8x16(kern_size_param) && kern_size_param.format == param::MatrixMul::Format::MK4 && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 && + kern_size_param.K % 4 == 0; } -bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred( - const KernSizeParam&) const { +bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred(const KernSizeParam&) const { #if !MGB_ENABLE_CPUINFO return false; #else @@ -911,26 +839,25 @@ bool MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::preferred( #ifdef __IN_TEE_ENV__ arch = cpuinfo_uarch_unknown; #endif - bool little_core = arch == cpuinfo_uarch_cortex_a53 || - arch == cpuinfo_uarch_cortex_a55; + bool little_core = + arch == cpuinfo_uarch_cortex_a53 || arch == cpuinfo_uarch_cortex_a55; return little_core; #endif } size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16MK4_16x12x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy(M, N, K, A_type, - B_type, C_type); - return megdnn::matmul::GemmInterleaved< - matmul::gemm_s8x8x16_mk4_16x12_a53>(M, N, K, trA, trB, - strategy) + aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53 strategy( + M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -951,24 +878,21 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( /* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ namespace { void int8x8x16_mk4_4x4x8_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_mk4_4x4x8_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = 
kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, - B_type, C_type); - megdnn::matmul::GemmInterleaved< - aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, - strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy( + M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -979,12 +903,11 @@ bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::usable( return can_be_treated_as_int8x8x16(kern_size_param) && kern_size_param.format == param::MatrixMul::Format::MK4 && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 && + kern_size_param.K % 4 == 0; } -bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred( - const KernSizeParam&) const { +bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred(const KernSizeParam&) const { #if !MGB_ENABLE_CPUINFO return false; #else @@ -992,26 +915,25 @@ bool MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::preferred( #ifdef __IN_TEE_ENV__ arch = cpuinfo_uarch_unknown; #endif - bool little_core = arch == cpuinfo_uarch_cortex_a53 || - arch == cpuinfo_uarch_cortex_a55; + bool little_core = + arch == cpuinfo_uarch_cortex_a53 || arch == cpuinfo_uarch_cortex_a55; return !little_core; #endif } size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt8x8x16MK4_4x4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy(M, N, K, A_type, - B_type, C_type); - return megdnn::matmul::GemmInterleaved< - matmul::gemm_s8x8x16_mk4_4x4_a72>(M, N, K, trA, trB, - strategy) + aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72 strategy( + M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -1023,33 +945,28 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern( return int8x8x16_mk4_4x4x8_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x16MK4_4x4x8_Impl"_hash, - aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, - int8_t, int16_t, AlgoDataType::INT8X8X16, - MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x16MK4_4x4x8, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16MK4_4x4x8_Impl"_hash, aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, + int8_t, int16_t, AlgoDataType::INT8X8X16, MK4); /* ===================== Int16x16x32 K12x8x1 algo ===================== */ namespace { void int16x16x32_k12x8x1_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int16x16x32_k12x8x1_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, 
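// The two int8x8x16 MK4 algos split work by micro-architecture: usable() enforces
// the MK4 layout constraints (no transposition, M and K multiples of 4), and
// preferred() consults cpuinfo so the 16x12x4 kernel wins on Cortex-A53/A55
// "little" cores while the 4x4x8 kernel wins elsewhere. A sketch of that check;
// only the uarch comparison is visible in these hunks, so the
// cpuinfo_get_current_core() query is an assumption:
bool prefers_little_core_kernel() {
#if !MGB_ENABLE_CPUINFO
    return false;
#else
    auto arch = cpuinfo_get_current_core()->uarch;
#ifdef __IN_TEE_ENV__
    arch = cpuinfo_uarch_unknown;
#endif
    return arch == cpuinfo_uarch_cortex_a53 || arch == cpuinfo_uarch_cortex_a55;
#endif
}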
midout_iv("int16x16x32_k12x8x1_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -1059,8 +976,7 @@ bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::usable( const KernSizeParam& kern_size_param) const { return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && kern_size_param.format == param::MatrixMul::Format::DEFAULT && - kern_size_param.compute_mode == - param::MatrixMul::ComputeMode::DEFAULT && + kern_size_param.compute_mode == param::MatrixMul::ComputeMode::DEFAULT && kern_size_param.A_type.enumv() == DTypeEnum::Int16 && kern_size_param.C_type.enumv() == DTypeEnum::Int32; } @@ -1073,15 +989,14 @@ bool MatrixMulImpl::AlgoInt16x16x32K12x8x1::preferred( size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt16x16x32K12x8x1::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt16x16x32K12x8x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s16_12x8x1 strategy(M, N, K, A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .get_workspace_size(); @@ -1095,12 +1010,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern( return int16x16x32_k12x8x1_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1, - megdnn_aarch64_matmul_kern, - "AlgoInt16x16x32K12x8x1Impl"_hash, - aarch64::matmul::gemm_s16_12x8x1, int16_t, - int32_t, AlgoDataType::INT16X16X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt16x16x32K12x8x1, megdnn_aarch64_matmul_kern, + "AlgoInt16x16x32K12x8x1Impl"_hash, aarch64::matmul::gemm_s16_12x8x1, int16_t, + int32_t, AlgoDataType::INT16X16X32, DEFAULT); /* ===================== Int16x16x32MK8_8x8 algo ===================== */ @@ -1116,10 +1029,10 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_8x8::usable( size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt16x16x32MK8_8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, 
trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -1136,25 +1049,22 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace( MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( const KernSizeParam&) const { auto kern_mk8_8x8 = [](const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt16x16x32MK8_8x8::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_8x8::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, - C_type); + aarch64::matmul::gemm_nopack_s16_8x8 strategy(A_type, B_type, C_type); megdnn::matmul::GemmInterleaved< - aarch64::matmul::gemm_nopack_s16_8x8, false>(M, N, K, trA, - trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + aarch64::matmul::gemm_nopack_s16_8x8, false>( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; @@ -1165,22 +1075,20 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( /* ==================== Quint8 K8x8x4 Dotprod algo ==================== */ namespace { void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("quint8_k8x8x4_dotprod_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, midout_iv("quint8_k8x8x4_dotprod_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -1188,7 +1096,7 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable( const KernSizeParam& kern_size_param) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && @@ -1200,10 +1108,10 @@ bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable( size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoQuint8K8x8x4DotProd::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + 
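// The Int16x16x32 MK8 path differs from the packed algos above: it uses a
// "nopack" strategy built from the dtypes alone, and instantiates GemmInterleaved
// with a second template argument of false, which, judging from the gemm_nopack_*
// naming, skips the operand-packing stage. Condensed sketch; the function name
// and the element types on the accessors are assumptions:
void int16x16x32_mk8_8x8_sketch(const MatrixMulImpl::KernParam& kern_param) {
    auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
    auto trA = kern_param.trA, trB = kern_param.trB;
    auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
    aarch64::matmul::gemm_nopack_s16_8x8 strategy(
            kern_param.A_type, kern_param.B_type, kern_param.C_type);
    megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_nopack_s16_8x8, false>(
            M, N, K, trA, trB, strategy)
            .execute(kern_param.A<dt_int16>(), LDA, kern_param.B<dt_int16>(), LDB,
                     kern_param.C<dt_int32>(), LDC, kern_param.workspace_ptr);
}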
midout_iv("AlgoQuint8K8x8x4DotProd::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -1222,21 +1130,18 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern( return quint8_k8x8x4_dotprod_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, - megdnn_aarch64_matmul_kern, - "AlgoQuint8K8x8x4DotProdImpl"_hash, - aarch64::matmul::gemm_u8_8x8_dot, uint8_t, - int32_t, AlgoDataType::QUINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoQuint8K8x8x4DotProd, megdnn_aarch64_matmul_kern, + "AlgoQuint8K8x8x4DotProdImpl"_hash, aarch64::matmul::gemm_u8_8x8_dot, uint8_t, + int32_t, AlgoDataType::QUINT8X8X32, DEFAULT); /* ===================== Quint8 Gemv DotProd algo ===================== */ namespace { void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("quint8_gemv_dotprod_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, midout_iv("quint8_gemv_dotprod_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto A_type = kern_param.A_type, B_type = kern_param.B_type; @@ -1251,7 +1156,7 @@ void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable( const KernSizeParam& kern_size_param) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && @@ -1259,8 +1164,8 @@ bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable( kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::DEFAULT && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.N == 1 && kern_size_param.LDB == 1; + !kern_size_param.trA && !kern_size_param.trB && kern_size_param.N == 1 && + kern_size_param.LDB == 1; } bool MatrixMulImpl::AlgoQuint8GemvDotProd::preferred( @@ -1278,22 +1183,20 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern( /* ===================== Quint8 K8x8x8 algo ===================== */ namespace { void quint8_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("quint8_gemv_dotprod_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, midout_iv("quint8_gemv_dotprod_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, 
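// Both dot-product algos gate themselves on the Armv8.2 dot-product extension at
// runtime before looking at dtypes, and the GEMV variant additionally requires a
// genuine matrix-vector shape. A condensed sketch of the GEMV usable() checks
// ("quint8_gemv_dotprod_usable" is a placeholder name for illustration):
bool quint8_gemv_dotprod_usable(const MatrixMulImpl::KernSizeParam& p) {
    if (!cpuinfo_has_arm_neon_dot())  // runtime ISA probe via the cpuinfo library
        return false;
    return p.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
           p.A_type.enumv() == p.B_type.enumv() &&
           p.C_type.enumv() == DTypeEnum::QuantizedS32 &&
           p.compute_mode == param::MatrixMul::ComputeMode::DEFAULT &&
           p.format == param::MatrixMul::Format::DEFAULT &&
           !p.trA && !p.trB && p.N == 1 && p.LDB == 1;
}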
kern_param.workspace_ptr); } MIDOUT_END(); } @@ -1310,10 +1213,10 @@ bool MatrixMulImpl::AlgoQuint8K8x8x8::usable( size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoQuint8K8x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoQuint8K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -1332,34 +1235,29 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern( return quint8_k8x8x8_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, - megdnn_aarch64_matmul_kern, - "AlgoQuint8K8x8x8Impl"_hash, - aarch64::matmul::gemm_u8_8x8, uint8_t, - int32_t, AlgoDataType::QUINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoQuint8K8x8x8, megdnn_aarch64_matmul_kern, "AlgoQuint8K8x8x8Impl"_hash, + aarch64::matmul::gemm_u8_8x8, uint8_t, int32_t, AlgoDataType::QUINT8X8X32, + DEFAULT); /* ===================== Int8x8x16 K8x8x8 algo ===================== */ namespace { void int8x8x16_mk4_8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, midout_iv("int8x8x16_mk4_8x8x8_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, - B_type, C_type); - megdnn::matmul::GemmInterleaved< - aarch64::matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, - strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy( + M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -1370,29 +1268,27 @@ bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::usable( return can_be_treated_as_int8x8x16(kern_size_param) && kern_size_param.format == param::MatrixMul::Format::MK4 && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 && + kern_size_param.K % 4 == 0; } -bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred( - const KernSizeParam&) const { +bool MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::preferred(const KernSizeParam&) const { return true; } size_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + 
midout_iv("AlgoInt8x8x16_MK4_8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy(M, N, K, A_type, - B_type, C_type); - return megdnn::matmul::GemmInterleaved< - matmul::gemm_s8x8x16_mk4_8x8x8>(M, N, K, trA, trB, - strategy) + aarch64::matmul::gemm_s8x8x16_mk4_8x8x8 strategy( + M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -1404,33 +1300,28 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern( return int8x8x16_mk4_8x8x8_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, - megdnn_aarch64_matmul_kern, - "AlgoInt8x8x16MK4_K8x8x8Impl"_hash, - aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, - int8_t, int16_t, AlgoDataType::INT8X8X16, - MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x16MK4_K8x8x8, megdnn_aarch64_matmul_kern, + "AlgoInt8x8x16MK4_K8x8x8Impl"_hash, aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, + int8_t, int16_t, AlgoDataType::INT8X8X16, MK4); /* ===================== Int4x4x16 K8x8x8 algo ===================== */ namespace { void int4x4x16_k8x8x16_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("int4x4x16_k8x8x8_kern"_hash)) { + MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, midout_iv("int4x4x16_k8x8x8_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy( + M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -1454,15 +1345,15 @@ bool MatrixMulImpl::AlgoInt4x4x16K8x8x8::preferred( size_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_aarch64_matmul_kern, - midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_aarch64_matmul_kern, + midout_iv("AlgoInt4x4x16K8x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy(M, N, K, A_type, B_type, - C_type); + aarch64::matmul::gemm_s4x4x16_s4_8x8x8 strategy( + M, N, K, A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .get_workspace_size(); @@ -1475,10 +1366,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt4x4x16K8x8x8::get_kern( return int4x4x16_k8x8x16_kern; } 
-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt4x4x16K8x8x8, - megdnn_aarch64_matmul_kern, - "AlgoInt4x4x16K8x8x8Impl"_hash, - aarch64::matmul::gemm_s4x4x16_s4_8x8x8, - int8_t, int16_t, AlgoDataType::INT4X4X16, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt4x4x16K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt4x4x16K8x8x8Impl"_hash, + aarch64::matmul::gemm_s4x4x16_s4_8x8x8, int8_t, int16_t, + AlgoDataType::INT4X4X16, DEFAULT); // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 65652acd..23954315 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -21,9 +21,7 @@ namespace aarch64 { class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_F32K8X12X1"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -35,8 +33,7 @@ public: class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } bool usable(const KernSizeParam&) const override; @@ -48,9 +45,7 @@ public: class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_F32K4X16X1"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -61,9 +56,7 @@ public: class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_F32_MK4_4x16"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -73,8 +66,7 @@ public: MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) }; -class MatrixMulImpl::AlgoF32Gemv final - : public arm_common::MatrixMulImpl::AlgoF32Gemv { +class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv { public: AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { m_handle_type = Handle::HandleType::AARCH64; @@ -85,9 +77,7 @@ public: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_F16_K8X24X1"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -98,9 +88,7 @@ public: class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const 
char* name() const override { return "AARCH64_F16_MK8_8X8"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -115,12 +103,8 @@ public: #if MGB_ENABLE_DOT class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } - const char* name() const override { - return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; @@ -130,12 +114,8 @@ public: class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } - const char* name() const override { - return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; @@ -147,8 +127,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } bool usable(const KernSizeParam&) const override; @@ -163,9 +142,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -178,9 +155,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -192,9 +167,7 @@ public: class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -207,9 +180,7 @@ public: class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return 
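// Every class in this header is the same thin AlgoBase wrapper: an attribute()
// bitmask (REPRODUCIBLE, sometimes with USABLE_DEPEND_ON_SHAPE), a stable name()
// string used for algo selection, the usable()/preferred()/get_workspace()/
// get_kern() hooks defined in algos.cpp, and a MEGDNN_DECL_ALGO_TYPE tag.
// Skeleton of one such entry ("AlgoExample" and its name/tag are placeholders):
class MatrixMulImpl::AlgoExample final : public AlgoBase {
public:
    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
    const char* name() const override { return "AARCH64_EXAMPLE"; }
    bool usable(const KernSizeParam&) const override;
    bool preferred(const KernSizeParam&) const override;
    size_t get_workspace(const KernSizeParam&) const override;
    kern_t get_kern(const KernSizeParam&) const override;
    MEGDNN_DECL_ALGO_TYPE(AARCH64_EXAMPLE)
};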
"AARCH64_INT8X8X16_K4X4X16"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -222,8 +193,7 @@ public: class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } bool usable(const KernSizeParam&) const override; @@ -238,12 +208,9 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; - } - const char* name() const override { - return "AARCH64_INT8X8X16_MK4_16X12X4"; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } + const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -257,12 +224,9 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; - } - const char* name() const override { - return "AARCH64_INT8X8X16_MK4_K8X8X8"; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } + const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -276,8 +240,7 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } bool usable(const KernSizeParam&) const override; @@ -292,9 +255,7 @@ public: class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -306,9 +267,7 @@ public: class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -321,12 +280,8 @@ public: #if MGB_ENABLE_DOT class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } - const char* name() const override { - return "AARCH64_QUINT8_K8X8X4_DOTPROD"; - } + AlgoAttribute attribute() const override { return 
AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "AARCH64_QUINT8_K8X8X4_DOTPROD"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; @@ -336,8 +291,7 @@ public: class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } bool usable(const KernSizeParam&) const override; @@ -352,9 +306,7 @@ public: #endif class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; diff --git a/dnn/src/aarch64/matrix_mul/asm/common.h b/dnn/src/aarch64/matrix_mul/asm/common.h index 98e829a8..df332025 100644 --- a/dnn/src/aarch64/matrix_mul/asm/common.h +++ b/dnn/src/aarch64/matrix_mul/asm/common.h @@ -21,9 +21,9 @@ namespace megdnn { namespace aarch64 { /* ======================== Prefetch ======================== */ -#define ASM_PREFETCH(address) "PRFM PLDL1KEEP, " address "\n" -#define ASM_PREFETCHL2(address) "PRFM PLDL2KEEP, " address "\n" -#define ASM_PREFETCHW(address) "PRFM PSTL1KEEP, " address "\n" +#define ASM_PREFETCH(address) "PRFM PLDL1KEEP, " address "\n" +#define ASM_PREFETCHL2(address) "PRFM PLDL2KEEP, " address "\n" +#define ASM_PREFETCHW(address) "PRFM PSTL1KEEP, " address "\n" #define ASM_PREFETCHWL2(address) "PRFM PSTL2KEEP, " address "\n" static inline void prefetch_6x(const void* pfp) { @@ -267,11 +267,10 @@ static inline void interleave_16x1_8_h_helper( } template -static inline void interleave_8x1_8_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr, int skippf = 0) { +static inline void interleave_8x1_8_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr, int skippf = 0) { static_assert(sizeof(T) == 2, "only support size == 2"); asm volatile( // Load up 8 elements (1 vector) from each of 8 sources. 
@@ -347,9 +346,9 @@ static inline void interleave_8x1_8_h(const T*& inptr0, const T*& inptr1, } template -static inline void interleave_4x1_4_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x1_4_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 2, "only support size == 2"); asm volatile( "ldr d0, [%[inptr0]], #8\n" // d0 = A0A1A2A3 @@ -368,19 +367,16 @@ static inline void interleave_4x1_4_h(const T*& inptr0, const T*& inptr1, "zip1 v10.4h, v8.4h, v9.4h\n" // d10 = A2B2C2D2 "zip2 v11.4h, v8.4h, v9.4h\n" // d11 = A3B3C3D3 "stp d10, d11, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "memory"); } -static inline void interleave_4x1_2_d(const int64_t*& inptr0, - const int64_t*& inptr1, - const int64_t*& inptr2, - const int64_t*& inptr3, - int64_t*& outptr) { +static inline void interleave_4x1_2_d( + const int64_t*& inptr0, const int64_t*& inptr1, const int64_t*& inptr2, + const int64_t*& inptr3, int64_t*& outptr) { asm volatile( "ld1 {v0.2d}, [%[inptr0]], #16\n" // d0 = A0A1 "ld1 {v1.2d}, [%[inptr1]], #16\n" // d1 = B0B1 @@ -396,18 +392,15 @@ static inline void interleave_4x1_2_d(const int64_t*& inptr0, "st1 {v6.2d}, [%[outptr]], #16\n" "st1 {v5.2d}, [%[outptr]], #16\n" "st1 {v7.2d}, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); } -static inline void interleave_4x2_2_d(const int64_t*& inptr0, - const int64_t*& inptr1, - const int64_t*& inptr2, - const int64_t*& inptr3, - int64_t*& outptr) { +static inline void interleave_4x2_2_d( + const int64_t*& inptr0, const int64_t*& inptr1, const int64_t*& inptr2, + const int64_t*& inptr3, int64_t*& outptr) { asm volatile( "ld1 {v0.2d}, [%[inptr0]], #16\n" // d0 = A0 "ld1 {v1.2d}, [%[inptr0]], #16\n" // d1 = A1 @@ -426,9 +419,8 @@ static inline void interleave_4x2_2_d(const int64_t*& inptr0, "st1 {v3.2d}, [%[outptr]], #16\n" "st1 {v5.2d}, [%[outptr]], #16\n" "st1 {v7.2d}, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); } @@ -437,8 +429,8 @@ static inline void interleave_12x1_4_s( const int32_t*& inptr0, const int32_t*& inptr1, const int32_t*& inptr2, const int32_t*& inptr3, const int32_t*& inptr4, const int32_t*& inptr5, const int32_t*& inptr6, const int32_t*& inptr7, const int32_t*& inptr8, - const int32_t*& inptr9, const int32_t*& inptr10, - const int32_t*& inptr11, int32_t*& outptr) { + const int32_t*& inptr9, const int32_t*& inptr10, const int32_t*& inptr11, + int32_t*& outptr) { asm volatile( "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 "ld1 {v1.4s}, [%[inptr1]], 
#16\n" // d1 = B0B1B2B3 @@ -492,23 +484,22 @@ static inline void interleave_12x1_4_s( "st1 {v7.4s}, [%[outptr]], #16\n" "st1 {v11.4s}, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11), + [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), - [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), - [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "cc", "memory"); } template static inline void interleave_12x1_4_h( - const T*& in0, const T*& in1, const T*& in2, const T*& in3, - const T*& in4, const T*& in5, const T*& in6, const T*& in7, - const T*& in8, const T*& in9, const T*& in10, const T*& in11, T*& out) { + const T*& in0, const T*& in1, const T*& in2, const T*& in3, const T*& in4, + const T*& in5, const T*& in6, const T*& in7, const T*& in8, const T*& in9, + const T*& in10, const T*& in11, T*& out) { static_assert( std::is_same::value || std::is_same::value, "interleave_12x1_4_h only support uint16_t and int16_t"); @@ -578,54 +569,50 @@ static inline void interleave_12x1_4_h( "st1 {v7.4h}, [%[outptr]], #8\n" // d7 = E3F3G3H3 "st1 {v11.4h}, [%[outptr]], #8\n" // d11 = I3J3K3L3 + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11), + [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), - [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), - [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "cc", "memory"); } template -static inline void interleave_12x4_4_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - const T*& inptr8, const T*& inptr9, - const T*& inptr10, const T*& inptr11, - T*& outptr) { +static inline void interleave_12x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, const T*& inptr10, const T*& inptr11, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_12x4_4_b only support 
uint8_t and int8_t"); - interleave_12x1_4_s(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(inptr4), - reinterpret_cast(inptr5), - reinterpret_cast(inptr6), - reinterpret_cast(inptr7), - reinterpret_cast(inptr8), - reinterpret_cast(inptr9), - reinterpret_cast(inptr10), - reinterpret_cast(inptr11), - reinterpret_cast(outptr)); -} - -static inline void interleave_2x1_4_s(const int32_t*& inptr0, - const int32_t*& inptr1, - int32_t*& outptr) { + interleave_12x1_4_s( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(inptr8), + reinterpret_cast(inptr9), + reinterpret_cast(inptr10), + reinterpret_cast(inptr11), + reinterpret_cast(outptr)); +} + +static inline void interleave_2x1_4_s( + const int32_t*& inptr0, const int32_t*& inptr1, int32_t*& outptr) { asm volatile( "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 "ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3 "st1 {v0.4s}, [%[outptr]], #16\n" "st1 {v1.4s}, [%[outptr]], #16\n" - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) : : "v0", "v1", "cc", "memory"); } @@ -670,14 +657,13 @@ static inline void interleave_8x1_4_s( "st1 {v15.4s}, [%[outptr]], #16\n" "st1 {v23.4s}, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "cc", "memory"); } static inline void interleave_8x1_2_d( @@ -711,13 +697,12 @@ static inline void interleave_8x1_2_d( "st1 {v11.2d}, [%[outptr]], #16\n" "st1 {v13.2d}, [%[outptr]], #16\n" "st1 {v15.2d}, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "cc", "memory"); } static inline void interleave_8x2_2_d( @@ -758,49 +743,48 @@ static inline void interleave_8x2_2_d( "st1 {v11.2d}, [%[outptr]], #16\n" "st1 {v13.2d}, [%[outptr]], #16\n" "st1 {v15.2d}, [%[outptr]], #16\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] 
"+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "cc", "memory"); } template -static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1, - T*& outptr) { +static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1, T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_2x4_4_b only support uint8_t and int8_t"); - interleave_2x1_4_s(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(outptr)); + interleave_2x1_4_s( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(outptr)); } template -static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_8x4_4_b only support uint8_t and int8_t"); - interleave_8x1_4_s(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(inptr4), - reinterpret_cast(inptr5), - reinterpret_cast(inptr6), - reinterpret_cast(inptr7), - reinterpret_cast(outptr)); + interleave_8x1_4_s( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(outptr)); } template -static inline void interleave_8x4_1_h(const T*& in0, const T*& in1, - const T*& in2, const T*& in3, T* out) { +static inline void interleave_8x4_1_h( + const T*& in0, const T*& in1, const T*& in2, const T*& in3, T* out) { static_assert(sizeof(T) == 2, "only support size == 2"); asm volatile( "ldr q0, [%[in0]], #16\n" // A1A2A3A4A5A6A7A8 @@ -827,79 +811,78 @@ static inline void interleave_8x4_1_h(const T*& in0, const T*& in1, "st1 {v13.2d}, [%[out]], #16\n" "st1 {v14.2d}, [%[out]], #16\n" "st1 {v15.2d}, [%[out]], #16\n" - : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), - [in3] "+r"(in3), [out] "+r"(out) + : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), [in3] "+r"(in3), + [out] "+r"(out) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "memory"); } template -static inline void interleave_8x8_2_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x8_2_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( 
std::is_same::value || std::is_same::value, "interleave_8x8_2_b only support uint8_t and int8_t"); - interleave_8x1_2_d(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(inptr4), - reinterpret_cast(inptr5), - reinterpret_cast(inptr6), - reinterpret_cast(inptr7), - reinterpret_cast(outptr)); + interleave_8x1_2_d( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(outptr)); } template -static inline void interleave_8x8_2_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x8_2_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_8x8_2_h only support uint16_t and int16_t"); - interleave_8x2_2_d(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(inptr4), - reinterpret_cast(inptr5), - reinterpret_cast(inptr6), - reinterpret_cast(inptr7), - reinterpret_cast(outptr)); + interleave_8x2_2_d( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(outptr)); } template -static inline void interleave_8x2_8_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x2_8_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_8x2_8_b only support uint8_t and int8_t"); - interleave_8x1_8_h(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(inptr4), - reinterpret_cast(inptr5), - reinterpret_cast(inptr6), - reinterpret_cast(inptr7), - reinterpret_cast(outptr)); + interleave_8x1_8_h( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(inptr4), + reinterpret_cast(inptr5), + reinterpret_cast(inptr6), + reinterpret_cast(inptr7), + reinterpret_cast(outptr)); } template -static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x8_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_8x8_1_b only support uint8_t and int8_t"); @@ -917,15 +900,13 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, "st1 {v1.2d}, [%[outptr]], 16\n" // C1C2C3C4C5C6C7C8D1D2D3D4D5D6D7D8 "st1 {v2.2d}, [%[outptr]], 16\n" // 
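// The byte and halfword interleave helpers carry no assembly of their own: they
// reinterpret each group of small elements as one int32_t or int64_t lane and
// forward to the word-sized routine, so a single asm body serves several element
// types. Condensed sketch of that delegation for the two-row case; the template
// arguments shown on std::is_same and reinterpret_cast follow the usual form and
// are assumed here:
template <typename T>
static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1, T*& outptr) {
    static_assert(
            std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value,
            "interleave_2x4_4_b only support uint8_t and int8_t");
    // Four 8-bit values form one 32-bit lane, so the int32_t routine moves them.
    interleave_2x1_4_s(
            reinterpret_cast<const int32_t*&>(inptr0),
            reinterpret_cast<const int32_t*&>(inptr1),
            reinterpret_cast<int32_t*&>(outptr));
}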
E1E2E3E4E5E6E7E8F1F2F3F4F5F6F7F8 "st1 {v3.2d}, [%[outptr]], 16\n" // G1G2G3G4G5G6G7G8H1H2H3H4H5H6H7H8 - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "memory"); } - template static inline void interleave_8x4_1_b_with_shift( const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, @@ -952,21 +933,19 @@ static inline void interleave_8x4_1_b_with_shift( "zip1 v10.16b, v7.16b, v6.16b\n" "zip2 v11.16b, v7.16b, v6.16b\n" "st1 {v8.16b-v11.16b},[%[outptr]],#64" - : [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), - [ inptr2 ] "+r"(inptr2), [ inptr3 ] "+r"(inptr3), - [ inptr4 ] "+r"(inptr4), [ inptr5 ] "+r"(inptr5), - [ inptr6 ] "+r"(inptr6), [ inptr7 ] "+r"(inptr7), - [ outptr ] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : - : "v0", "v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "memory"); } template -static inline void interleave_8x8_1_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x8_1_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_8x8_1_h only support uint16_t and int16_t"); @@ -988,19 +967,16 @@ static inline void interleave_8x8_1_h(const T*& inptr0, const T*& inptr1, "st1 {v5.8h}, [%[outptr]], #16\n" // F1F2F3F4F5F6F7F8 "st1 {v6.8h}, [%[outptr]], #16\n" // G1G2G3G4G5G6G7G8 "st1 {v7.8h}, [%[outptr]], #16\n" // H1H2H3H4H5H6H7H8 - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); } -static inline void interleave_4x1_4_s(const int32_t*& inptr0, - const int32_t*& inptr1, - const int32_t*& inptr2, - const int32_t*& inptr3, - int32_t*& outptr) { +static inline void interleave_4x1_4_s( + const int32_t*& inptr0, const int32_t*& inptr1, const int32_t*& inptr2, + const int32_t*& inptr3, int32_t*& outptr) { asm volatile( "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 "ld1 {v1.4s}, [%[inptr1]], #16\n" // d1 = B0B1B2B3 @@ -1020,18 +996,17 @@ static inline void interleave_4x1_4_s(const int32_t*& inptr0, "st1 {v14.4s}, [%[outptr]], #16\n" "st1 {v15.4s}, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] 
"+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "cc", "memory"); } template -static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x8_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 4, "only support size == 4"); asm volatile( "ld1 {v0.4s, v1.4s}, [%[inptr0]], #32\n" @@ -1043,17 +1018,16 @@ static inline void interleave_4x8_1_s(const T*& inptr0, const T*& inptr1, "st1 {v4.4s, v5.4s}, [%[outptr]], #32\n" "st1 {v6.4s, v7.4s}, [%[outptr]], #32\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "cc", "memory"); } template -static inline void interleave_4x12_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x12_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 4, "only support size == 4"); asm volatile( "ld1 {v0.4s, v1.4s, v2.4s}, [%[inptr0]], #48\n" @@ -1065,18 +1039,17 @@ static inline void interleave_4x12_1_s(const T*& inptr0, const T*& inptr1, "st1 {v8.4s, v9.4s, v10.4s}, [%[outptr]], #48\n" "st1 {v12.4s, v13.4s, v14.4s}, [%[outptr]], #48\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v4", "v5", "v6", "v8", "v9", "v10", "v12", - "v13", "v14", "cc", "memory"); + : "v0", "v1", "v2", "v4", "v5", "v6", "v8", "v9", "v10", "v12", "v13", + "v14", "cc", "memory"); } template -static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x16_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 1, "only support size == 1"); asm volatile( "ld1 {v0.4s}, [%[inptr0]], #16\n" // d0 = A0A1A2A3 @@ -1088,18 +1061,16 @@ static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, "st1 {v2.4s}, [%[outptr]], #16\n" "st1 {v3.4s}, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "cc", "memory"); } - template -static inline void interleave_4x16_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x16_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 4, "only support size == 4"); asm volatile( "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" @@ -1111,46 +1082,47 @@ static inline void 
interleave_4x16_1_s(const T*& inptr0, const T*& inptr1, "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[outptr]], #64\n" "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[outptr]], #64\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "cc", "memory"); } template -static inline void interleave_4x2_4_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x2_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_4x2_4_b only support uint8_t and int8_t"); - interleave_4x1_4_h(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(outptr)); + interleave_4x1_4_h( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); } template -static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_4x4_4_b only support uint8_t and int8_t"); - interleave_4x1_4_s(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(outptr)); + interleave_4x1_4_s( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); } template -static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 4, "interleave_4x4_1_s only support size == 4"); asm volatile( "ld1 {v0.4s}, [%[inptr0]], #16\n" @@ -1159,16 +1131,14 @@ static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, "ld1 {v3.4s}, [%[inptr3]], #16\n" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[outptr]], #64\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "cc", "memory"); } template -static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, - T* outptr) { +static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, T* outptr) { static_assert(sizeof(T) == 4, "interleave_2x4_4_s only support size == 4"); asm volatile( "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" @@ -1178,8 +1148,7 @@ static inline void interleave_2x4_4_s(const T*& inptr0, const T*& inptr1, "stp q2, q6, [%[outptr], #64]\n" "stp q3, q7, [%[outptr], #96]\n" - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] 
"+r"(inptr1), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); } @@ -1197,31 +1166,33 @@ static inline void interleave_1x4_4_s(const T*& inptr0, T* outptr) { } template -static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x8_2_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_4x8_2_b only support uint8_t and int8_t"); - interleave_4x1_2_d(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(outptr)); + interleave_4x1_2_d( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); } template -static inline void interleave_4x8_2_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x8_2_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_4x8_2_h only support uint16_t and int16_t"); - interleave_4x2_2_d(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(outptr)); + interleave_4x2_2_d( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); } template @@ -1272,8 +1243,8 @@ static inline void interleave_1x4_1_s(const T*& inptr0, T*& outptr) { } template -static inline void interleave_helper(const T*& inptr, T*& outptr, int unroll_k, - int ksize, T val = 0) { +static inline void interleave_helper( + const T*& inptr, T*& outptr, int unroll_k, int ksize, T val = 0) { int k = 0; for (; k < ksize; k++) { *outptr++ = *inptr++; @@ -1284,8 +1255,8 @@ static inline void interleave_helper(const T*& inptr, T*& outptr, int unroll_k, } template -static inline void interleave_1(const T*& inptr0, T*& outptr, int unroll_k, - int ksize, T val = 0) { +static inline void interleave_1( + const T*& inptr0, T*& outptr, int unroll_k, int ksize, T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -1293,9 +1264,9 @@ static inline void interleave_1(const T*& inptr0, T*& outptr, int unroll_k, } template -static inline void interleave_4(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, T*& outptr, - int unroll_k, int ksize, T val = 0) { +static inline void interleave_4( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr, int unroll_k, int ksize, T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -1306,11 +1277,10 @@ static inline void interleave_4(const T*& inptr0, const T*& inptr1, } template -static inline void interleave_8(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, T*& outptr, - int unroll_k, int ksize, T val = 0) { +static inline void interleave_8( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& 
outptr, int unroll_k, int ksize, T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -1325,13 +1295,11 @@ static inline void interleave_8(const T*& inptr0, const T*& inptr1, } template -static inline void interleave_12(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - const T*& inptr8, const T*& inptr9, - const T*& inptr10, const T*& inptr11, - T*& outptr, int unroll_k, int ksize) { +static inline void interleave_12( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, const T*& inptr10, const T*& inptr11, + T*& outptr, int unroll_k, int ksize) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size); @@ -1359,8 +1327,8 @@ static inline void interleave_12(const T*& inptr0, const T*& inptr1, * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[i, j] */ template -static inline void transpose_24x4_1_h(const T*& in0, const T*& in1, - const T*& in2, const T*& in3, T* out) { +static inline void transpose_24x4_1_h( + const T*& in0, const T*& in1, const T*& in2, const T*& in3, T* out) { static_assert(sizeof(T) == 2, "only support size == 2"); asm volatile( "ldp q0, q1, [%[in0]], #32\n" @@ -1390,8 +1358,8 @@ static inline void transpose_24x4_1_h(const T*& in0, const T*& in1, } template -static inline void transpose_16x4_1_h(const T*& in0, const T*& in1, - const T*& in2, const T*& in3, T* out) { +static inline void transpose_16x4_1_h( + const T*& in0, const T*& in1, const T*& in2, const T*& in3, T* out) { static_assert(sizeof(T) == 2, "only support size == 2"); asm volatile( "ldp q0, q1, [%[in0]], #32\n" @@ -1402,15 +1370,15 @@ static inline void transpose_16x4_1_h(const T*& in0, const T*& in1, "stp q4, q5, [%[out], #64]\n" "ldp q6, q7, [%[in3]], #32\n" "stp q6, q7, [%[out], #96]\n" - : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), - [in3] "+r"(in3), [out] "+r"(out) + : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), [in3] "+r"(in3), + [out] "+r"(out) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); } template -static inline void transpose_8x4_1_h(const T*& in0, const T*& in1, - const T*& in2, const T*& in3, T* out) { +static inline void transpose_8x4_1_h( + const T*& in0, const T*& in1, const T*& in2, const T*& in3, T* out) { static_assert(sizeof(T) == 2, "only support size == 2"); asm volatile( "ldr q0, [%[in0]], #16\n" @@ -1421,8 +1389,8 @@ static inline void transpose_8x4_1_h(const T*& in0, const T*& in1, "str q2, [%[out], #32]\n" "ldr q3, [%[in3]], #16\n" "str q3, [%[out], #48]\n" - : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), - [in3] "+r"(in3), [out] "+r"(out) + : [in0] "+r"(in0), [in1] "+r"(in1), [in2] "+r"(in2), [in3] "+r"(in3), + [out] "+r"(out) : : "v0", "v1", "v2", "v3", "memory"); } @@ -1538,11 +1506,10 @@ static inline void transpose_4x1_1_h(const T*& in0, T* out) { } template -static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T* outptr, int stride = 16) { - static_assert(sizeof(T) == 4, - "transpose_4x4_1_s only support sizeof(T) == 4"); +static inline void transpose_4x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr, int stride 
= 16) { + static_assert(sizeof(T) == 4, "transpose_4x4_1_s only support sizeof(T) == 4"); asm volatile( "ld1 {v0.4s}, [%[inptr0]], 16\n" // A0A1A2A3 @@ -1564,18 +1531,16 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, "st1 {v10.4s}, [%[outptr]], %x[stride]\n" "st1 {v9.4s}, [%[outptr]], %x[stride]\n" "st1 {v11.4s}, [%[outptr]], %x[stride]\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr), [stride] "+r"(stride) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr), [stride] "+r"(stride) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "memory"); } template static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { - static_assert(sizeof(T) == 4, - "transpose_1x12_4_s only support sizeof(T) == 4"); + static_assert(sizeof(T) == 4, "transpose_1x12_4_s only support sizeof(T) == 4"); asm volatile( "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" @@ -1590,14 +1555,13 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { "stp q7, q11, [%[outptr], #160] \n" : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "memory"); } template static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { - static_assert(sizeof(T) == 4, - "transpose_1x4_4_s only support sizeof(T) == 4"); + static_assert(sizeof(T) == 4, "transpose_1x4_4_s only support sizeof(T) == 4"); asm volatile( "ld4 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[inptr0]], #64\n" @@ -1608,13 +1572,11 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { } template -static inline void transpose_8x4_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T* outptr) { - static_assert(sizeof(T) == 4, - "transpose_8x4_1_s only support sizeof(T) == 4"); +static inline void transpose_8x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { + static_assert(sizeof(T) == 4, "transpose_8x4_1_s only support sizeof(T) == 4"); asm volatile( "ld1 {v0.4s}, [%[inptr0]], 16\n" // A0A1A2A3 @@ -1649,25 +1611,21 @@ static inline void transpose_8x4_1_s(const T*& inptr0, const T*& inptr1, "st1 {v0.4s,v1.4s,v2.4s,v3.4s}, [%[outptr]], #64\n" "st1 {v4.4s,v5.4s,v6.4s,v7.4s}, [%[outptr]], #64\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "memory"); } template -static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& 
inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - const T*& inptr8, const T*& inptr9, - const T*& inptr10, const T*& inptr11, - T* outptr) { - static_assert(sizeof(T) == 4, - "transpose_12x4_1_s only support sizeof(T) == 4"); +static inline void transpose_12x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, const T*& inptr10, const T*& inptr11, + T* outptr) { + static_assert(sizeof(T) == 4, "transpose_12x4_1_s only support sizeof(T) == 4"); asm volatile( "ld1 {v0.4s}, [%[inptr0]], 16\n" // A0A1A2A3 "ld1 {v1.4s}, [%[inptr1]], 16\n" // B0B1B2B3 @@ -1719,22 +1677,21 @@ static inline void transpose_12x4_1_s(const T*& inptr0, const T*& inptr1, "st1 {v3.4s,v4.4s,v5.4s}, [%[outptr]], #48\n" "st1 {v6.4s,v7.4s,v8.4s}, [%[outptr]], #48\n" "st1 {v24.4s,v25.4s,v26.4s}, [%[outptr]], #48\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), [inptr11] "+r"(inptr11), + [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), - [inptr9] "+r"(inptr9), [inptr10] "+r"(inptr10), - [inptr11] "+r"(inptr11), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "memory"); } template -static inline void transpose_12x4_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T* outptr) { +static inline void transpose_12x4_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr) { static_assert( std::is_same::value || std::is_same::value, "transpose_12x4_1_b only support uint8_t and int8_t"); @@ -1772,19 +1729,17 @@ static inline void transpose_12x4_1_b(const T*& inptr0, const T*& inptr1, "stp q17, q18, [%[outptr]], #32\n" "str q19, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "w1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "memory"); + : "w1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "memory"); } template -static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T* outptr) { +static inline void transpose_8x4_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr) { static_assert( std::is_same::value || std::is_same::value, "transpose_8x4_1_b only support uint8_t and int8_t"); @@ -1802,20 +1757,17 @@ static inline void transpose_8x4_1_b(const T*& inptr0, 
const T*& inptr1, "st1 {v4.2d}, [%[outptr]], #16\n" "st1 {v5.2d}, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "v0", "v1", "v2", "v3", "v4", "v5", "memory"); } template -static inline void transpose_4x8_1_b_with_shift(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { - +static inline void transpose_4x8_1_b_with_shift( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static int8x16_t shuffle_idx = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; static_assert( @@ -1831,8 +1783,8 @@ static inline void transpose_4x8_1_b_with_shift(const T*& inptr0, const T*& inpt "ld1 {v1.s}[2], [%[inptr6]], #4\n" // G1G2G3G4 "ld1 {v1.s}[3], [%[inptr7]], #4\n" // H1H2H3H4 - "tbl v2.16b, {v0.16b}, %[shuffle_idx].16b \n" // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4 - "tbl v3.16b, {v1.16b}, %[shuffle_idx].16b \n" // E1F1G1H1E2F2G2H2E3F3G3H3E4F4G4H4 + "tbl v2.16b, {v0.16b}, %[shuffle_idx].16b \n" // A1B1C1D1A2B2C2D2A3B3C3D3A4B4C4D4 + "tbl v3.16b, {v1.16b}, %[shuffle_idx].16b \n" // E1F1G1H1E2F2G2H2E3F3G3H3E4F4G4H4 "zip1 v4.4s, v2.4s, v3.4s\n" // A1B1C1D1E1F1G1H1 A2B2C2D2E2F2G2H2 "zip2 v5.4s, v2.4s, v3.4s\n" // A3B3C3D3E3F3G3H3 A4B4C4D4E4F4G4H4 @@ -1849,20 +1801,19 @@ static inline void transpose_4x8_1_b_with_shift(const T*& inptr0, const T*& inpt "zip2 v3.2d,v11.2d,v10.2d\n" "st1 {v0.2d-v3.2d},[%[outptr]],#64\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [shuffle_idx]"+w"(shuffle_idx), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), + [shuffle_idx] "+w"(shuffle_idx), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","v8","v9","v10","v11","memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "memory"); } template -static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T* outptr) { +static inline void transpose_8x8_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { static_assert( std::is_same::value || std::is_same::value, "transpose_8x8_1_b only support uint8_t and int8_t"); @@ -1911,22 +1862,19 @@ static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, // A6B6C6D6E6F6G6H6 "st1 {v19.16b}, [%[outptr]], #16\n" // A7B7C7D7E7F7G7H7 // A8B8C8D8E8F8G8H8 + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] 
"+r"(inptr7), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "memory"); } template -static inline void transpose_4x16_1_b_helper(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T* outptr) { +static inline void transpose_4x16_1_b_helper( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { static_assert(sizeof(T) == 1, "only support size == 1"); static int8x16_t shuffle_idx = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; @@ -1954,19 +1902,18 @@ static inline void transpose_4x16_1_b_helper(const T*& inptr0, const T*& inptr1, "str d5, [%[outptr]], #16\n" "str d7, [%[outptr]], #16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), - [outptr] "+r"(outptr), [shuffle_idx] "+w"(shuffle_idx) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr), + [shuffle_idx] "+w"(shuffle_idx) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); } template -static inline void transpose_4(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, T* outptr, - int interleave, int size, T val = 0) { +static inline void transpose_4( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr, int interleave, int size, T val = 0) { megdnn_assert(size <= interleave); int i = 0; for (; i < size; i++) { @@ -1984,11 +1931,10 @@ static inline void transpose_4(const T*& inptr0, const T*& inptr1, } template -static inline void transpose_8(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, T* outptr, - int interleave, int size, T val = 0) { +static inline void transpose_8( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr, int interleave, int size, T val = 0) { megdnn_assert(size <= interleave); int i = 0; for (; i < size; i++) { @@ -2016,13 +1962,11 @@ static inline void transpose_8(const T*& inptr0, const T*& inptr1, //! 
pack form {1, 4(icb), 4(ic), 4(oc)} to {1, 1, 4(oc), 16(ic)} template -static inline void transpose_interleave_4x4_4_b(const T*& inptr0, - const T*& inptr1, - const T*& inptr2, - const T*& inptr3, T* outptr, - int stride = 64) { - static_assert(sizeof(T) == 1, - "transpose_interleave_4x4_4_b only support sizeof(T) == 1"); +static inline void transpose_interleave_4x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr, int stride = 64) { + static_assert( + sizeof(T) == 1, "transpose_interleave_4x4_4_b only support sizeof(T) == 1"); asm volatile( "ld4 {v0.16b, v1.16b, v2.16b, v3.16b},[%[inptr0]], 64\n" @@ -2034,34 +1978,30 @@ static inline void transpose_interleave_4x4_4_b(const T*& inptr0, "st1 {v4.16b, v5.16b, v6.16b, v7.16b},[%[outptr]], %x[stride]\n" "st1 {v8.16b, v9.16b, v10.16b, v11.16b},[%[outptr]], %x[stride]\n" "st1 {v12.16b, v13.16b, v14.16b, v15.16b},[%[outptr]], %x[stride]\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr), [stride] "+r"(stride) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr), [stride] "+r"(stride) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v14", "v15", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v14", "v15", "memory"); } template -static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, - int stride = 64) { - static_assert(sizeof(T) == 1, - "transpose_interleave_1x4_4_b only support sizeof(T) == 1"); +static inline void transpose_interleave_1x4_4_b( + const T*& inptr0, T* outptr, int stride = 64) { + static_assert( + sizeof(T) == 1, "transpose_interleave_1x4_4_b only support sizeof(T) == 1"); asm volatile( "ld4 {v0.16b, v1.16b, v2.16b, v3.16b},[%[inptr0]], 64\n" "st1 {v0.16b, v1.16b, v2.16b, v3.16b},[%[outptr]], %x[stride]\n" - : - [inptr0] "+r"(inptr0), [outptr] "+r"(outptr), [stride] "+r"(stride) + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr), [stride] "+r"(stride) : : "v0", "v1", "v2", "v3", "v4", "memory"); } -static inline void interleave_4x4_16x4_s8_s16(const int8_t* inptr0, - const int8_t* inptr1, - const int8_t* inptr2, - const int8_t* inptr3, - int16_t* outptr) { +static inline void interleave_4x4_16x4_s8_s16( + const int8_t* inptr0, const int8_t* inptr1, const int8_t* inptr2, + const int8_t* inptr3, int16_t* outptr) { int8x16_t row0 = vld1q_s8(inptr0); int16x8_t row0_01 = vmovl_low_s8(row0); int16x8_t row0_23 = vmovl_high_s8(row0); @@ -2111,9 +2051,8 @@ static inline void interleave_4x4_16x4_s8_s16(const int8_t* inptr0, vst1_s16(outptr + 14 * 4, row2_3); vst1_s16(outptr + 15 * 4, row3_3); }; -static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, - const int8_t* inptr1, - int16_t* outptr) { +static inline void interleave_4x4_8x4_s8_s16( + const int8_t* inptr0, const int8_t* inptr1, int16_t* outptr) { int8x16_t row0 = vld1q_s8(inptr0); int16x8_t row0_01 = vmovl_low_s8(row0); int16x8_t row0_23 = vmovl_high_s8(row0); @@ -2140,8 +2079,7 @@ static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, vst1_s16(outptr + 7 * 4, row1_3); }; -static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, - int count) { +static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, int count) { for (; count >= 32; count -= 32) { int8x8_t in0 = vld1_s8(inptr); int8x8_t in1 = vld1_s8(inptr + 1 * 8); @@ -2173,24 +2111,25 @@ 
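[Editor's note] The generic fallbacks above (interleave_helper, interleave_1/4/8/12) define the packing contract that the hand-written NEON paths implement: each group of rows is cut into chunks of unroll_k elements, the rows' chunks are emitted back to back, and a trailing partial chunk is padded with val. Below is a minimal standalone sketch of that layout, assuming only the loop structure visible in the hunks; ref_interleave_helper, ref_interleave_4 and the driver in main are hypothetical illustrations, not MegEngine code.

#include <algorithm>
#include <cstdio>
#include <vector>

// Scalar model of interleave_helper: copy the valid part of one chunk and pad
// the rest of the chunk up to unroll_k with `val`.
template <typename T>
static void ref_interleave_helper(const T*& in, T*& out, int unroll_k, int size, T val) {
    int k = 0;
    for (; k < size; ++k)
        *out++ = *in++;
    for (; k < unroll_k; ++k)
        *out++ = val;
}

// Scalar model of interleave_4: pack four rows chunk by chunk.
template <typename T>
static void ref_interleave_4(
        const T*& r0, const T*& r1, const T*& r2, const T*& r3, T*& out,
        int unroll_k, int ksize, T val = 0) {
    for (int k = 0; k < ksize; k += unroll_k) {
        int size = std::min(unroll_k, ksize - k);
        ref_interleave_helper(r0, out, unroll_k, size, val);
        ref_interleave_helper(r1, out, unroll_k, size, val);
        ref_interleave_helper(r2, out, unroll_k, size, val);
        ref_interleave_helper(r3, out, unroll_k, size, val);
    }
}

int main() {
    // Four rows of K = 6 packed with unroll_k = 4: two chunks per row, the
    // second chunk of each row padded with two zeros -> 4 * 2 * 4 = 32 values.
    const int r0[] = {0, 1, 2, 3, 4, 5}, r1[] = {10, 11, 12, 13, 14, 15},
              r2[] = {20, 21, 22, 23, 24, 25}, r3[] = {30, 31, 32, 33, 34, 35};
    const int *p0 = r0, *p1 = r1, *p2 = r2, *p3 = r3;
    std::vector<int> packed(32);
    int* out = packed.data();
    ref_interleave_4(p0, p1, p2, p3, out, 4, 6, 0);
    for (int v : packed)
        std::printf("%d ", v);  // 0 1 2 3 10 11 12 13 ... 4 5 0 0 14 15 0 0 ...
    std::printf("\n");
}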
static inline void transpos_12x4_s8(const int8_t* inptr0, int8_t* outptr) { int8x16_t input2 = vqtbl1q_s8(vld1q_s8(inptr0 + 4 * 8), vtbl); vst1_s8(outptr, input.val[0]); - vst1q_lane_s32(reinterpret_cast(outptr + 8), - vreinterpretq_s32_s8(input2), 0); + vst1q_lane_s32( + reinterpret_cast(outptr + 8), vreinterpretq_s32_s8(input2), 0); vst1_s8(outptr + 1 * 12, input.val[1]); - vst1q_lane_s32(reinterpret_cast(outptr + 1 * 12 + 8), - vreinterpretq_s32_s8(input2), 1); + vst1q_lane_s32( + reinterpret_cast(outptr + 1 * 12 + 8), + vreinterpretq_s32_s8(input2), 1); vst1_s8(outptr + 2 * 12, input.val[2]); - vst1q_lane_s32(reinterpret_cast(outptr + 2 * 12 + 8), - vreinterpretq_s32_s8(input2), 2); + vst1q_lane_s32( + reinterpret_cast(outptr + 2 * 12 + 8), + vreinterpretq_s32_s8(input2), 2); vst1_s8(outptr + 3 * 12, input.val[3]); - vst1q_lane_s32(reinterpret_cast(outptr + 3 * 12 + 8), - vreinterpretq_s32_s8(input2), 3); + vst1q_lane_s32( + reinterpret_cast(outptr + 3 * 12 + 8), + vreinterpretq_s32_s8(input2), 3); } - template -static inline void interleave_8x8_mk4_b(const T*& inptr0, const T*& inptr1, - T*& outptr) { - +static inline void interleave_8x8_mk4_b( + const T*& inptr0, const T*& inptr1, T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "transpose_8x4_1_b only support uint8_t and int8_t"); @@ -2211,16 +2150,13 @@ static inline void interleave_8x8_mk4_b(const T*& inptr0, const T*& inptr1, "st1 {v6.4s},[%[outptr]],#16\n" "st1 {v7.4s},[%[outptr]],#16\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); } template -static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1, - T* outptr) { - +static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1, T* outptr) { static_assert( std::is_same::value || std::is_same::value, "transpose_8x4_1_b only support uint8_t and int8_t"); @@ -2236,10 +2172,9 @@ static inline void transpose_8x8_mk4_b(const T*& inptr0, const T*& inptr1, "st1 {v6.2s},[%[outptr]],#8\n" "st1 {v7.2s},[%[outptr]],#8\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5","v6","v7","memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "memory"); } } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/fp16/strategy.cpp b/dnn/src/aarch64/matrix_mul/fp16/strategy.cpp index bdb02eb3..6cac980f 100644 --- a/dnn/src/aarch64/matrix_mul/fp16/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/fp16/strategy.cpp @@ -21,8 +21,8 @@ using namespace aarch64::matmul; namespace { -void interleave_8x1(__fp16* out, const __fp16* in, int ldin, int y0, int ymax, - int k0, int kmax) { +void interleave_8x1( + __fp16* out, const __fp16* in, int ldin, int y0, int ymax, int k0, int kmax) { __fp16* outptr = out; const __fp16* inptr = in; __fp16 zerobuff[24]; @@ -51,8 +51,9 @@ void interleave_8x1(__fp16* out, const __fp16* in, int ldin, int y0, int ymax, int x = (kmax - k0); for (; x > 7; x -= 8) { int skippf = (x & 31); - interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr, skippf); + interleave_8x1_8_h( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, skippf); } for (; x > 0; x--) { @@ -109,8 +110,9 @@ void interleave_8x1(__fp16* 
out, const __fp16* in, int ldin, int y0, int ymax, } } -void interleave_24x1(__fp16* out, const __fp16* in, const int ldin, const int y0, - const int ymax, const int k0, const int kmax) { +void interleave_24x1( + __fp16* out, const __fp16* in, const int ldin, const int y0, const int ymax, + const int k0, const int kmax) { __fp16* outptr = out; const __fp16* inptr = in; __fp16 zerobuff[24]; @@ -144,9 +146,9 @@ void interleave_24x1(__fp16* out, const __fp16* in, const int ldin, const int y0 int x = (kmax - k0); for (; x > 7; x -= 8) { int skippf = (x & 31); - interleave_24x1_8_h_helper(inptr0, inptr1, inptr2, inptr3, - inptr4, inptr5, inptr6, inptr7, - outptr_inner, skippf); + interleave_24x1_8_h_helper( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_inner, skippf); } for (; x > 0; x--) { *outptr_inner++ = *inptr0++; @@ -188,9 +190,9 @@ void interleave_24x1(__fp16* out, const __fp16* in, const int ldin, const int y0 int x = (kmax - k0); for (; x > 7; x -= 8) { int skippf = (x & 31); - interleave_16x1_8_h_helper(inptr0, inptr1, inptr2, inptr3, - inptr4, inptr5, inptr6, inptr7, - outptr_inner, skippf); + interleave_16x1_8_h_helper( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_inner, skippf); } for (; x > 0; x--) { *outptr_inner++ = *inptr0++; @@ -229,8 +231,9 @@ void interleave_24x1(__fp16* out, const __fp16* in, const int ldin, const int y0 int x = (kmax - k0); for (; x > 7; x -= 8) { int skippf = (x & 31); - interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr, skippf); + interleave_8x1_8_h( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, skippf); } for (; x > 0; x--) { @@ -287,8 +290,8 @@ void interleave_24x1(__fp16* out, const __fp16* in, const int ldin, const int y0 } } -void transpose_1x8(__fp16* out, const __fp16* in, int ldin, int x0, int xmax, - int k0, int kmax) { +void transpose_1x8( + __fp16* out, const __fp16* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize8 = (ksize << 3); int ksize4 = (ksize << 2); @@ -421,8 +424,9 @@ void transpose_1x8(__fp16* out, const __fp16* in, int ldin, int x0, int xmax, } } -void transpose_1x24(__fp16* out, const __fp16* in, const int ldin, const int x0, - const int xmax, const int k0, const int kmax) { +void transpose_1x24( + __fp16* out, const __fp16* in, const int ldin, const int x0, const int xmax, + const int k0, const int kmax) { int ksize = kmax - k0; int ksize24 = ksize * 24; int ksize16 = (ksize << 4); @@ -644,9 +648,9 @@ void transpose_1x24(__fp16* out, const __fp16* in, const int ldin, const int x0, // // Accumulator -void aarch64_hgemm_assembly_kernel_24x8(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, int type) { +void aarch64_hgemm_assembly_kernel_24x8( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -999,9 +1003,9 @@ void aarch64_hgemm_assembly_kernel_24x8(const __fp16* a_ptr, // +--+--+ - - - - +--------+--------+ // // Accumulator -void aarch64_hgemm_assembly_kernel_16x8(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, int type) { +void aarch64_hgemm_assembly_kernel_16x8( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -1224,16 +1228,14 @@ void aarch64_hgemm_assembly_kernel_16x8(const __fp16* a_ptr, "3:\n" "str q23, 
[%[outptr7], #16]\n" - : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [b1] "+w"(b1), - [k] "+r"(k), [b0a] "+w"(b0a), [b1a] "+w"(b1a), - [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), - [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), - [outptr7] "+r"(outptr7) + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [b1] "+w"(b1), [k] "+r"(k), + [b0a] "+w"(b0a), [b1a] "+w"(b1a), [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), + [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7) : [oddk] "r"(oddk), [type] "r"(type) - : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "cc", "memory"); + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "cc", "memory"); } // Overview of register layout: @@ -1264,9 +1266,9 @@ void aarch64_hgemm_assembly_kernel_16x8(const __fp16* a_ptr, // +--+--+ - - - - +--------+ // // Accumulator -void aarch64_hgemm_assembly_kernel_8x8(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, int type) { +void aarch64_hgemm_assembly_kernel_8x8( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -1420,13 +1422,11 @@ void aarch64_hgemm_assembly_kernel_8x8(const __fp16* a_ptr, "3:\n" : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [k] "+r"(k), [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), - [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7) : [oddk] "r"(oddk), [type] "r"(type) - : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", - "memory"); + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", "memory"); } // Overview of register layout: @@ -1457,10 +1457,9 @@ void aarch64_hgemm_assembly_kernel_8x8(const __fp16* a_ptr, // +--+--+ - - - - +--------+ // // Accumulator -void aarch64_hgemm_assembly_kernel_4x8(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, int x_remain, - int type) { +void aarch64_hgemm_assembly_kernel_4x8( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int x_remain, int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -1654,9 +1653,8 @@ void aarch64_hgemm_assembly_kernel_4x8(const __fp16* a_ptr, "3:\n" STORE_C : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [k] "+r"(k), [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), - [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), [outptr6] "+r"(outptr6), [outptr7] "+r"(outptr7) : [oddk] "r"(oddk), [x_remain] "r"(x_remain), [type] "r"(type) : "x0", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", @@ -1692,10 +1690,9 @@ void 
aarch64_hgemm_assembly_kernel_4x8(const __fp16* a_ptr, // // Accumulator //! cannot load %[a0] and %[a0a] at same time! -void aarch64_hgemm_assembly_kernel_24x4(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, - int y_remain, int type) { +void aarch64_hgemm_assembly_kernel_24x4( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int y_remain, int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -1743,157 +1740,154 @@ void aarch64_hgemm_assembly_kernel_24x4(const __fp16* a_ptr, STORE_LINE("10", "18", "26", "2") \ STORE_LINE("11", "19", "27", "3") \ "STORE_24x4_C_END:\n" -// clang-format on - - asm volatile( - ".arch armv8.2-a+fp16\n" - - // load accumulator C - "cmp %w[type], #0\n" - "beq 5f\n" - LOAD_C - "b 6f\n" - "5:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "eor v11.16b, v11.16b, v11.16b\n" - - "eor v16.16b, v16.16b, v16.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - "eor v18.16b, v18.16b, v18.16b\n" - "eor v19.16b, v19.16b, v19.16b\n" - - "eor v24.16b, v24.16b, v24.16b\n" - "eor v25.16b, v25.16b, v25.16b\n" - "eor v26.16b, v26.16b, v26.16b\n" - "eor v27.16b, v27.16b, v27.16b\n" - - "6:\n" - "ldr %d[a0], [%[a_ptr]]\n" - "ldr %q[b0], [%[b_ptr]]\n" - "ldr %q[b1], [%[b_ptr], #16]\n" - "ldr %q[b2], [%[b_ptr], #32]\n" - "ldr %q[b0a], [%[b_ptr], #48]\n" - "ldr %q[b1a], [%[b_ptr], #64]\n" - - "cbz %w[k], 4f\n" - - "1:\n" - "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - "ldr %q[b2a], [%[b_ptr], #80]\n" - "ldr %q[b0], [%[b_ptr], #96]\n" - - "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" - "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" - "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" - "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" - "add %[b_ptr], %[b_ptr], #96\n" - "ldr %q[b1], [%[b_ptr], #16]\n" - - "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" - "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" - "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" - "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" - "ldr %d[a0], [%[a_ptr], #16]\n" - - "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" - "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" - "ldr %q[b2], [%[b_ptr], #32]\n" - - "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" - "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" - "ldr %q[b0a], [%[b_ptr], #48]\n" - - "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" - "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" - "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" - "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" - "ldr %q[b1a], [%[b_ptr], #64]\n" - - "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" - "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" - "add %[a_ptr], %[a_ptr], #16\n" - "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" - "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" - "subs %w[k], %w[k], #1\n" - - "bne 1b\n" - "4:\n" - // Jump to odd tail if necessary. 
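// ---------------------------------------------------------------------------
// [Editor's note, not part of the patch] Every kernel in this file unrolls the
// K loop by two: with `oddk = K & 1` and `k = ((K + 1) / 2) - 1`, the main loop
// ("1:" ... "bne 1b") runs k times over pairs of K-steps, and either an even
// tail (two steps, label "4:") or an odd tail (one step, label "2:") finishes
// the reduction.  A scalar model of that control flow, valid for K >= 1:
//
//     static inline int count_k_steps(int K) {
//         int oddk = K & 1;
//         int k = ((K + 1) / 2) - 1;
//         int steps = 0;
//         for (; k > 0; --k) steps += 2;  // main loop: two fmla groups per pass
//         steps += oddk ? 1 : 2;          // "2:" odd tail vs "4:" even tail
//         return steps;                   // always equals K
//     }
// ---------------------------------------------------------------------------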
- "cbnz %w[oddk], 2f\n" - - // Even tail - "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - "ldr %q[b2a], [%[b_ptr], #80]\n" - - "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" - "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" - "add %[b_ptr], %[b_ptr], #96\n" - "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" - "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" - "add %[a_ptr], %[a_ptr], #16\n" - - "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" - "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" - "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" - "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" - - "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" - "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" - "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" - "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" - - "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" - "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" - "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" - "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" - - "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" - "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" - "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" - "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" - "b 3f\n" - - // Odd tail - "2:\n" - "add %[a_ptr], %[a_ptr], #8\n" - "add %[b_ptr], %[b_ptr], #48\n" - - "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - - "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" - "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" - "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" - "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" - - "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" - "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" - "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" - "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" - - "3:\n" STORE_C - : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), - [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), - [b0a] "+w"(b0a), [b1a] "+w"(b1a), [b2a] "+w"(b2a), - [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : - [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) - : "w0", "v8", "v9", "v10", "v11", "v16", "v17", "v18", - "v19", "v24", "v25", "v26", "v27", "cc", "memory"); + // clang-format on + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "eor v24.16b, v24.16b, v24.16b\n" + "eor v25.16b, v25.16b, v25.16b\n" + "eor v26.16b, v26.16b, v26.16b\n" + "eor v27.16b, v27.16b, v27.16b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], [%[b_ptr], #80]\n" + "ldr %q[b0], [%[b_ptr], #96]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "ldr %q[b1], 
[%[b_ptr], #16]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "ldr %q[b2], [%[b_ptr], #32]\n" + + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "ldr %q[b0a], [%[b_ptr], #48]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "ldr %q[b1a], [%[b_ptr], #64]\n" + + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "add %[a_ptr], %[a_ptr], #16\n" + "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. + "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b2a], [%[b_ptr], #80]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "add %[b_ptr], %[b_ptr], #96\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + + "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + + "fmla v24.8h, %[b2a].8h, %[a0a].h[0]\n" + "fmla v25.8h, %[b2a].8h, %[a0a].h[1]\n" + "fmla v26.8h, %[b2a].8h, %[a0a].h[2]\n" + "fmla v27.8h, %[b2a].8h, %[a0a].h[3]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #48\n" + + "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + + "fmla v24.8h, %[b2].8h, %[a0].h[0]\n" + "fmla v25.8h, %[b2].8h, %[a0].h[1]\n" + "fmla v26.8h, %[b2].8h, %[a0].h[2]\n" + "fmla v27.8h, %[b2].8h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [b1] "+w"(b1), + [b2] "+w"(b2), [k] "+r"(k), [b0a] "+w"(b0a), [b1a] "+w"(b1a), + [b2a] "+w"(b2a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3) + : [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) + : "w0", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "v24", "v25", + "v26", "v27", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -1924,10 +1918,9 @@ void aarch64_hgemm_assembly_kernel_24x4(const __fp16* a_ptr, // +--+--+ - - - - +--------+--------+ // // Accumulator -void aarch64_hgemm_assembly_kernel_16x4(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, - int y_remain, int type) { +void 
aarch64_hgemm_assembly_kernel_16x4( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int y_remain, int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -1942,7 +1935,7 @@ void aarch64_hgemm_assembly_kernel_16x4(const __fp16* a_ptr, __fp16* outptr2 = outptr1 + ldout; __fp16* outptr3 = outptr2 + ldout; -// clang-format off + // clang-format off #define LOAD_LINE(v1, v2, n) \ "cbz w0, LOAD_16x4_C_END\n" \ @@ -1970,123 +1963,121 @@ void aarch64_hgemm_assembly_kernel_16x4(const __fp16* a_ptr, STORE_LINE("11", "19", "3") \ "STORE_16x4_C_END:\n" -// clang-format on - - asm volatile( - ".arch armv8.2-a+fp16\n" - - // load accumulator C - "cmp %w[type], #0\n" - "beq 5f\n" LOAD_C - "b 6f\n" - - "5:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "eor v11.16b, v11.16b, v11.16b\n" - - "eor v16.16b, v16.16b, v16.16b\n" - "eor v17.16b, v17.16b, v17.16b\n" - "eor v18.16b, v18.16b, v18.16b\n" - "eor v19.16b, v19.16b, v19.16b\n" - - "6:\n" - "ldr %d[a0], [%[a_ptr]]\n" - "ldr %q[b0], [%[b_ptr]]\n" - "ldr %q[b1], [%[b_ptr], #16]\n" - "ldr %q[b0a], [%[b_ptr], #32]\n" - "ldr %q[b1a], [%[b_ptr], #48]\n" - - "cbz %w[k], 4f\n" - - "1:\n" - "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - "ldr %q[b0], [%[b_ptr], #64]\n" - - "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" - "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" - "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" - "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" - "add %[b_ptr], %[b_ptr], #64\n" - "ldr %q[b1], [%[b_ptr], #16]\n" - - "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" - "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" - "ldr %d[a0], [%[a_ptr], #16]\n" - "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" - "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" - "ldr %q[b0a], [%[b_ptr], #32]\n" - - "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" - "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" - "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" - "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" - "ldr %q[b1a], [%[b_ptr], #48]\n" - - "add %[a_ptr], %[a_ptr], #16\n" - "subs %w[k], %w[k], #1\n" - - "bne 1b\n" - "4:\n" - // Jump to odd tail if necessary. 
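// ---------------------------------------------------------------------------
// [Editor's note, not part of the patch] The shared prologue
// "cmp %w[type], #0 / beq 5f / LOAD_C / b 6f / 5: eor ..." selects the
// accumulator initialisation: a non-zero `type` loads the existing C tile so
// the kernel accumulates into it, otherwise the accumulator registers are
// zeroed.  A scalar equivalent, with element type and names chosen only for
// illustration:
//
//     static inline void init_acc(float* acc, const float* c_tile, int n, int type) {
//         for (int i = 0; i < n; ++i)
//             acc[i] = type ? c_tile[i] : 0.0f;  // accumulate onto C vs start from zero
//     }
// ---------------------------------------------------------------------------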
- "cbnz %w[oddk], 2f\n" - - // Even tail - "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - - "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" - "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" - "add %[b_ptr], %[b_ptr], #64\n" - "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" - "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" - "add %[a_ptr], %[a_ptr], #16\n" - - "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" - "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" - "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" - "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" - - "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" - "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" - "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" - "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" - - "b 3f\n" - - // Odd tail - "2:\n" - "add %[a_ptr], %[a_ptr], #8\n" - "add %[b_ptr], %[b_ptr], #32\n" - - "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - - "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" - "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" - "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" - "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" - - "3:\n" STORE_C - : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), - [b1] "+w"(b1), [k] "+r"(k), [b0a] "+w"(b0a), - [b1a] "+w"(b1a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) - : - [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) - : "w0", "v8", "v9", "v10", "v11", "v16", "v17", "v18", - "v19", "cc", "memory"); + // clang-format on + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "eor v16.16b, v16.16b, v16.16b\n" + "eor v17.16b, v17.16b, v17.16b\n" + "eor v18.16b, v18.16b, v18.16b\n" + "eor v19.16b, v19.16b, v19.16b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + "ldr %q[b0a], [%[b_ptr], #32]\n" + "ldr %q[b1a], [%[b_ptr], #48]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b0], [%[b_ptr], #64]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[b_ptr], %[b_ptr], #64\n" + "ldr %q[b1], [%[b_ptr], #16]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "ldr %q[b0a], [%[b_ptr], #32]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + "ldr %q[b1a], [%[b_ptr], #48]\n" + + "add %[a_ptr], %[a_ptr], #16\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. 
+ "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "add %[b_ptr], %[b_ptr], #64\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + + "fmla v16.8h, %[b1a].8h, %[a0a].h[0]\n" + "fmla v17.8h, %[b1a].8h, %[a0a].h[1]\n" + "fmla v18.8h, %[b1a].8h, %[a0a].h[2]\n" + "fmla v19.8h, %[b1a].8h, %[a0a].h[3]\n" + + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #32\n" + + "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "fmla v16.8h, %[b1].8h, %[a0].h[0]\n" + "fmla v17.8h, %[b1].8h, %[a0].h[1]\n" + "fmla v18.8h, %[b1].8h, %[a0].h[2]\n" + "fmla v19.8h, %[b1].8h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [b1] "+w"(b1), [k] "+r"(k), + [b0a] "+w"(b0a), [b1a] "+w"(b1a), [a_ptr] "+r"(a_ptr), + [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), + [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) + : "w0", "v8", "v9", "v10", "v11", "v16", "v17", "v18", "v19", "cc", + "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -2117,10 +2108,9 @@ void aarch64_hgemm_assembly_kernel_16x4(const __fp16* a_ptr, // +--+--+ - - - - +--------+ // // Accumulator -void aarch64_hgemm_assembly_kernel_8x4(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, int y_remain, - int type) { +void aarch64_hgemm_assembly_kernel_8x4( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int y_remain, int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -2161,89 +2151,87 @@ void aarch64_hgemm_assembly_kernel_8x4(const __fp16* a_ptr, STORE_LINE("10", "2") \ STORE_LINE("11", "3") \ "STORE_8x4_C_END:\n" -// clang-format on - - asm volatile( - ".arch armv8.2-a+fp16\n" - - // load accumulator C - "cmp %w[type], #0\n" - "beq 5f\n" LOAD_C - "b 6f\n" - "5:\n" - "eor v8.16b, v8.16b, v8.16b\n" - "eor v9.16b, v9.16b, v9.16b\n" - "eor v10.16b, v10.16b, v10.16b\n" - "eor v11.16b, v11.16b, v11.16b\n" - - "6:\n" - "ldr %d[a0], [%[a_ptr]]\n" - "ldr %q[b0], [%[b_ptr]]\n" - "ldr %q[b0a], [%[b_ptr], #16]\n" - - "cbz %w[k], 4f\n" - - "1:\n" - "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - "ldr %q[b0], [%[b_ptr], #32]\n" - - "add %[b_ptr], %[b_ptr], #32\n" - "ldr %d[a0], [%[a_ptr], #16]\n" - - "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" - "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" - "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" - "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" - "ldr %q[b0a], [%[b_ptr], #16]\n" - - "add %[a_ptr], %[a_ptr], #16\n" - "subs %w[k], %w[k], #1\n" - - "bne 1b\n" - "4:\n" - // Jump to odd tail if necessary. 
- "cbnz %w[oddk], 2f\n" - - // Even tail - "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - - "add %[b_ptr], %[b_ptr], #32\n" - "add %[a_ptr], %[a_ptr], #16\n" - - "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" - "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" - "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" - "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" - - "b 3f\n" - - // Odd tail - "2:\n" - "add %[a_ptr], %[a_ptr], #8\n" - "add %[b_ptr], %[b_ptr], #16\n" - - "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" - "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" - "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" - "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" - - "3:\n" STORE_C - : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), - [k] "+r"(k), [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), - [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3) - : - [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) - : "w0", "v8", "v9", "v10", "v11", "cc", "memory"); + // clang-format on + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + "5:\n" + "eor v8.16b, v8.16b, v8.16b\n" + "eor v9.16b, v9.16b, v9.16b\n" + "eor v10.16b, v10.16b, v10.16b\n" + "eor v11.16b, v11.16b, v11.16b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + "ldr %q[b0], [%[b_ptr]]\n" + "ldr %q[b0a], [%[b_ptr], #16]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + "ldr %q[b0], [%[b_ptr], #32]\n" + + "add %[b_ptr], %[b_ptr], #32\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + + "fmla v8.8h , %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h , %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + "ldr %q[b0a], [%[b_ptr], #16]\n" + + "add %[a_ptr], %[a_ptr], #16\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. 
+ "cbnz %w[oddk], 2f\n" + + // Even tail + "fmla v8.8h , %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h , %[b0].8h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "add %[b_ptr], %[b_ptr], #32\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v8.8h, %[b0a].8h, %[a0a].h[0]\n" + "fmla v9.8h, %[b0a].8h, %[a0a].h[1]\n" + "fmla v10.8h, %[b0a].8h, %[a0a].h[2]\n" + "fmla v11.8h, %[b0a].8h, %[a0a].h[3]\n" + + "b 3f\n" + + // Odd tail + "2:\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #16\n" + + "fmla v8.8h, %[b0].8h, %[a0].h[0]\n" + "fmla v9.8h, %[b0].8h, %[a0].h[1]\n" + "fmla v10.8h, %[b0].8h, %[a0].h[2]\n" + "fmla v11.8h, %[b0].8h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [k] "+r"(k), + [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3) + : [oddk] "r"(oddk), [y_remain] "r"(y_remain), [type] "r"(type) + : "w0", "v8", "v9", "v10", "v11", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -2274,10 +2262,9 @@ void aarch64_hgemm_assembly_kernel_8x4(const __fp16* a_ptr, // +--+--+ - - - - +--------+ // // Accumulator -void aarch64_hgemm_assembly_kernel_4x4(const __fp16* a_ptr, - const __fp16*& b_ptr, int K, - __fp16* outptr0, int ldout, int x_remain, - int y_remain, int type) { +void aarch64_hgemm_assembly_kernel_4x4( + const __fp16* a_ptr, const __fp16*& b_ptr, int K, __fp16* outptr0, int ldout, + int x_remain, int y_remain, int type) { int oddk = (K & 1); int k = ((K + 1) / 2) - 1; @@ -2323,13 +2310,9 @@ void aarch64_hgemm_assembly_kernel_4x4(const __fp16* a_ptr, ":\n" \ "subs w1, w1, #1\n" -#define LOAD_C \ - "mov w1, %w[y_remain]\n" \ - LOAD_LINE("8", "0") \ - LOAD_LINE("9", "1") \ - LOAD_LINE("10", "2") \ - LOAD_LINE("11", "3") \ - "LOAD_4x4_C_END:\n" +#define LOAD_C \ + "mov w1, %w[y_remain]\n" LOAD_LINE("8", "0") LOAD_LINE("9", "1") \ + LOAD_LINE("10", "2") LOAD_LINE("11", "3") "LOAD_4x4_C_END:\n" #define STORE_LINE(reg_index, n) \ "cbz w1, STORE_4x4_C_END\n" \ @@ -2364,93 +2347,89 @@ void aarch64_hgemm_assembly_kernel_4x4(const __fp16* a_ptr, ":\n" \ "subs w1, w1, #1\n" -#define STORE_C "mov w1, %w[y_remain]\n" \ - STORE_LINE("8", "0") \ - STORE_LINE("9", "1") \ - STORE_LINE("10", "2") \ - STORE_LINE("11", "3") \ - "STORE_4x4_C_END:\n" - - asm volatile( - ".arch armv8.2-a+fp16\n" - - // load accumulator C - "cmp %w[type], #0\n" - "beq 5f\n" LOAD_C - "b 6f\n" - - "5:\n" - "eor v8.8b, v8.8b, v8.8b\n" - "eor v9.8b, v9.8b, v9.8b\n" - "eor v10.8b, v10.8b, v10.8b\n" - "eor v11.8b, v11.8b, v11.8b\n" - - "6:\n" - "ldr %d[a0], [%[a_ptr]]\n" - - "cbz %w[k], 4f\n" - - "1:\n" - "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" - "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" - "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" - "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" - - "add %[b_ptr], %[b_ptr], #16\n" - "ldr %d[a0], [%[a_ptr], #16]\n" - - "fmla v8.4h , %[b0a].4h, %[a0a].h[0]\n" - "fmla v9.4h , %[b0a].4h, %[a0a].h[1]\n" - "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" - "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" - - "add %[a_ptr], %[a_ptr], #16\n" - "subs %w[k], %w[k], #1\n" - - "bne 1b\n" - "4:\n" - // Jump to odd tail if necessary. 
- "cbnz %w[oddk], 2f\n" - - // Even tail - "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" - "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" - "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" - "ldr %d[a0a], [%[a_ptr], #8]\n" - "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" - "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" - - "add %[b_ptr], %[b_ptr], #16\n" - "add %[a_ptr], %[a_ptr], #16\n" - - "fmla v8.4h, %[b0a].4h, %[a0a].h[0]\n" - "fmla v9.4h, %[b0a].4h, %[a0a].h[1]\n" - "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" - "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" - "b 3f\n" - - // Odd tail - "2:\n" - "ldr %d[b0], [%[b_ptr]]\n" - "add %[a_ptr], %[a_ptr], #8\n" - "add %[b_ptr], %[b_ptr], #8\n" - - "fmla v8.4h, %[b0].4h, %[a0].h[0]\n" - "fmla v9.4h, %[b0].4h, %[a0].h[1]\n" - "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" - "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" - - "3:\n" STORE_C - : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), - [k] "+r"(k), [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), - [b_ptr] "+r"(b_ptr), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3) - : [oddk] "r"(oddk), [x_remain] "r"(x_remain), - [y_remain] "r"(y_remain), [type] "r"(type) - : "x0", "w1", "v8", "v9", "v10", "v11", "cc", "memory"); +#define STORE_C \ + "mov w1, %w[y_remain]\n" STORE_LINE("8", "0") STORE_LINE("9", "1") \ + STORE_LINE("10", "2") STORE_LINE("11", "3") "STORE_4x4_C_END:\n" + + asm volatile( + ".arch armv8.2-a+fp16\n" + + // load accumulator C + "cmp %w[type], #0\n" + "beq 5f\n" LOAD_C + "b 6f\n" + + "5:\n" + "eor v8.8b, v8.8b, v8.8b\n" + "eor v9.8b, v9.8b, v9.8b\n" + "eor v10.8b, v10.8b, v10.8b\n" + "eor v11.8b, v11.8b, v11.8b\n" + + "6:\n" + "ldr %d[a0], [%[a_ptr]]\n" + + "cbz %w[k], 4f\n" + + "1:\n" + "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" + "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + + "add %[b_ptr], %[b_ptr], #16\n" + "ldr %d[a0], [%[a_ptr], #16]\n" + + "fmla v8.4h , %[b0a].4h, %[a0a].h[0]\n" + "fmla v9.4h , %[b0a].4h, %[a0a].h[1]\n" + "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" + "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" + + "add %[a_ptr], %[a_ptr], #16\n" + "subs %w[k], %w[k], #1\n" + + "bne 1b\n" + "4:\n" + // Jump to odd tail if necessary. 
+ "cbnz %w[oddk], 2f\n" + + // Even tail + "ldp %d[b0], %d[b0a], [%[b_ptr]]\n" + "fmla v8.4h , %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h , %[b0].4h, %[a0].h[1]\n" + "ldr %d[a0a], [%[a_ptr], #8]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + + "add %[b_ptr], %[b_ptr], #16\n" + "add %[a_ptr], %[a_ptr], #16\n" + + "fmla v8.4h, %[b0a].4h, %[a0a].h[0]\n" + "fmla v9.4h, %[b0a].4h, %[a0a].h[1]\n" + "fmla v10.4h, %[b0a].4h, %[a0a].h[2]\n" + "fmla v11.4h, %[b0a].4h, %[a0a].h[3]\n" + "b 3f\n" + + // Odd tail + "2:\n" + "ldr %d[b0], [%[b_ptr]]\n" + "add %[a_ptr], %[a_ptr], #8\n" + "add %[b_ptr], %[b_ptr], #8\n" + + "fmla v8.4h, %[b0].4h, %[a0].h[0]\n" + "fmla v9.4h, %[b0].4h, %[a0].h[1]\n" + "fmla v10.4h, %[b0].4h, %[a0].h[2]\n" + "fmla v11.4h, %[b0].4h, %[a0].h[3]\n" + + "3:\n" STORE_C + : [a0] "+w"(a0), [a0a] "+w"(a0a), [b0] "+w"(b0), [k] "+r"(k), + [b0a] "+w"(b0a), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3) + : [oddk] "r"(oddk), [x_remain] "r"(x_remain), [y_remain] "r"(y_remain), + [type] "r"(type) + : "x0", "w1", "v8", "v9", "v10", "v11", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -2458,9 +2437,9 @@ void aarch64_hgemm_assembly_kernel_4x4(const __fp16* a_ptr, #undef STORE_C } -void aarch64_hgemm_asimd_8x24(const __fp16* Apanel, const __fp16* Bpanel, - __fp16* out, int ldout, int x0, int xmax, int y0, - int ymax, int K, bool is_first_k) { +void aarch64_hgemm_asimd_8x24( + const __fp16* Apanel, const __fp16* Bpanel, __fp16* out, int ldout, int x0, + int xmax, int y0, int ymax, int K, bool is_first_k) { const __fp16* a_ptr = Apanel; const int A_interleave = 8; const int B_transpose1xW = 24; @@ -2479,28 +2458,25 @@ void aarch64_hgemm_asimd_8x24(const __fp16* Apanel, const __fp16* Bpanel, for (; x + B_transpose1xW <= xmax; x += B_transpose1xW) { a_ptr = a_ptr0; - aarch64_hgemm_assembly_kernel_24x8(a_ptr, b_ptr, K, outptr0, ldout, - type); + aarch64_hgemm_assembly_kernel_24x8(a_ptr, b_ptr, K, outptr0, ldout, type); outptr0 += B_transpose1xW; } for (; x + 16 <= xmax; x += 16) { a_ptr = a_ptr0; - aarch64_hgemm_assembly_kernel_16x8(a_ptr, b_ptr, K, outptr0, ldout, - type); + aarch64_hgemm_assembly_kernel_16x8(a_ptr, b_ptr, K, outptr0, ldout, type); outptr0 += 16; } for (; x + 8 <= xmax; x += 8) { a_ptr = a_ptr0; - aarch64_hgemm_assembly_kernel_8x8(a_ptr, b_ptr, K, outptr0, ldout, - type); + aarch64_hgemm_assembly_kernel_8x8(a_ptr, b_ptr, K, outptr0, ldout, type); outptr0 += 8; } for (; x < xmax; x += 4) { int x_remain = xmax - x; a_ptr = a_ptr0; - aarch64_hgemm_assembly_kernel_4x8(a_ptr, b_ptr, K, outptr0, ldout, - x_remain, type); + aarch64_hgemm_assembly_kernel_4x8( + a_ptr, b_ptr, K, outptr0, ldout, x_remain, type); outptr0 += 4; } a_ptr = a_ptr0 + K8; @@ -2515,27 +2491,27 @@ void aarch64_hgemm_asimd_8x24(const __fp16* Apanel, const __fp16* Bpanel, int x = x0; for (; x + B_transpose1xW <= xmax; x += B_transpose1xW) { a_ptr = a_ptr0; - aarch64_hgemm_assembly_kernel_24x4(a_ptr, b_ptr, K, outptr0, ldout, - ymax - y, type); + aarch64_hgemm_assembly_kernel_24x4( + a_ptr, b_ptr, K, outptr0, ldout, ymax - y, type); outptr0 += B_transpose1xW; } for (; x + 16 <= xmax; x += 16) { a_ptr = a_ptr0; - aarch64_hgemm_assembly_kernel_16x4(a_ptr, b_ptr, K, outptr0, ldout, - ymax - y, type); + aarch64_hgemm_assembly_kernel_16x4( + a_ptr, b_ptr, K, outptr0, ldout, ymax - y, type); outptr0 += 16; } for (; x + 8 <= xmax; x += 8) { a_ptr = a_ptr0; - 
aarch64_hgemm_assembly_kernel_8x4(a_ptr, b_ptr, K, outptr0, ldout, - ymax - y, type); + aarch64_hgemm_assembly_kernel_8x4( + a_ptr, b_ptr, K, outptr0, ldout, ymax - y, type); outptr0 += 8; } for (; x < xmax; x += 4) { a_ptr = a_ptr0; - aarch64_hgemm_assembly_kernel_4x4(a_ptr, b_ptr, K, outptr0, ldout, - xmax - x, ymax - y, type); + aarch64_hgemm_assembly_kernel_4x4( + a_ptr, b_ptr, K, outptr0, ldout, xmax - x, ymax - y, type); outptr0 += 4; } a_ptr = a_ptr0 + K4; @@ -2545,45 +2521,48 @@ void aarch64_hgemm_asimd_8x24(const __fp16* Apanel, const __fp16* Bpanel, MEGDNN_REG_GEMM_STRATEGY_IMPL(hgemm_8x24); -void hgemm_8x24::pack_A(dt_float16* out, const dt_float16* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose_A) const { +void hgemm_8x24::pack_A( + dt_float16* out, const dt_float16* in, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose_A) const { if (transpose_A) { - transpose_1x8(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, y0, ymax, k0, - kmax); + transpose_1x8( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, y0, ymax, k0, kmax); } else { - interleave_8x1(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, y0, ymax, k0, - kmax); + interleave_8x1( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, y0, ymax, k0, kmax); } } -void hgemm_8x24::pack_B(dt_float16* out, const dt_float16* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose_B) const { +void hgemm_8x24::pack_B( + dt_float16* out, const dt_float16* in, int ldin, int x0, int xmax, int k0, + int kmax, bool transpose_B) const { if (transpose_B) { - interleave_24x1(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, x0, xmax, k0, - kmax); + interleave_24x1( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, x0, xmax, k0, kmax); } else { - transpose_1x24(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, x0, xmax, k0, - kmax); + transpose_1x24( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, x0, xmax, k0, kmax); } } -void hgemm_8x24::kern(const dt_float16* packA, const dt_float16* packB, - size_t M, size_t N, size_t K, dt_float16* C, size_t LDC, - bool is_first_k, const dt_float16*, dt_float16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == C_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Float16); +void hgemm_8x24::kern( + const dt_float16* packA, const dt_float16* packB, size_t M, size_t N, size_t K, + dt_float16* C, size_t LDC, bool is_first_k, const dt_float16*, + dt_float16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float16); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); - aarch64_hgemm_asimd_8x24(reinterpret_cast(packA), - reinterpret_cast(packB), - reinterpret_cast<__fp16*>(C), LDC, 0, N, 0, M, K, - is_first_k); + aarch64_hgemm_asimd_8x24( + reinterpret_cast(packA), + reinterpret_cast(packB), reinterpret_cast<__fp16*>(C), LDC, + 0, N, 0, M, K, is_first_k); } #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp16/strategy.h b/dnn/src/aarch64/matrix_mul/fp16/strategy.h index f2ff6688..3d9db2c0 100644 --- a/dnn/src/aarch64/matrix_mul/fp16/strategy.h +++ b/dnn/src/aarch64/matrix_mul/fp16/strategy.h @@ -16,11 +16,11 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false, - true, hgemm_8x24); 
+MEGDNN_REG_GEMM_STRATEGY( + dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24); -MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1, - false, true, gemm_nopack_f16_8x8); +MEGDNN_REG_GEMM_STRATEGY_NOPACK( + dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp b/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp index 38bc1695..0d973493 100644 --- a/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp +++ b/dnn/src/aarch64/matrix_mul/fp16/strategy_mk8_8x8.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/aarch64/matrix_mul/fp16/strategy.h" #include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/fp16/strategy.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" @@ -21,8 +21,9 @@ using namespace aarch64::matmul; namespace { -void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, - dt_float16* output) { +void kern_8x1( + const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { LDB *= sizeof(dt_float16); asm volatile( ".arch armv8.2-a+fp16\n" @@ -86,9 +87,8 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", - "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", - "memory"); + : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); } // Overview of register layout: @@ -115,8 +115,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, // |v23[0-7]| |v27[0-7]| // +--------+ +--------+ // Accumulator -void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, - dt_float16* output) { +void kern_8x4( + const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { //! LDB means number of elements in one block in B. we will read 24 numbers //! first. so minus 24 * 2 bytes here. LDB = (LDB - 24) * sizeof(dt_float16); @@ -263,8 +264,8 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "cc", "memory"); } // Overview of register layout: @@ -295,8 +296,9 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, // | v7[0-7]| |v31[0-7]| // +--------+ +--------+ // Accumulator -void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, - dt_float16* output) { +void kern_8x8( + const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { //! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 //! here. 
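    //! The same stride trick as kern_8x4 above: LDB is the element stride
    //! between consecutive blocks of B, and b_ptr has already moved past the
    //! elements consumed from the current block by the time the stride is
    //! applied, so that amount is subtracted up front and the remainder is
    //! converted to a byte stride.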
LDB = (LDB - 32) * sizeof(dt_float16); @@ -467,20 +469,19 @@ void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "cc", "memory"); } } // anonymous namespace MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); -void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, - const dt_float16* B, size_t LDB, dt_float16* C, - size_t LDC, size_t M, size_t K, size_t N, - const dt_float16*, void*, bool trA, - bool trB) const { +void gemm_nopack_f16_8x8::kern( + const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C, + size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA, + bool trB) const { constexpr static size_t MB = 8; constexpr static size_t KB = 8; constexpr static size_t NB = 8; diff --git a/dnn/src/aarch64/matrix_mul/fp32/common.h b/dnn/src/aarch64/matrix_mul/fp32/common.h index 9892a490..d6ace968 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/common.h +++ b/dnn/src/aarch64/matrix_mul/fp32/common.h @@ -17,21 +17,23 @@ namespace megdnn { namespace aarch64 { -MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M, - size_t K, size_t LDA, const float* alpha); +MEGDNN_NOINLINE void sgemm_packA_n( + const float* A, float* Apacked, size_t M, size_t K, size_t LDA, + const float* alpha); -MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M, - size_t K, size_t LDA, const float* alpha); +MEGDNN_NOINLINE void sgemm_packA_t( + const float* A, float* Apacked, size_t M, size_t K, size_t LDA, + const float* alpha); -MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K, - size_t N, size_t LDB); +MEGDNN_NOINLINE void sgemm_packB_n( + const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); -MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K, - size_t N, size_t LDB); +MEGDNN_NOINLINE void sgemm_packB_t( + const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); -MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C, - size_t LDC, size_t M, size_t N, size_t K, - int type, const float* beta); +MEGDNN_NOINLINE void sgemm_kernel12x8( + const float* A, const float* B, float* C, size_t LDC, size_t M, size_t N, + size_t K, int type, const float* beta); } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h index 0b7cc77d..9b8b0449 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_4x16.h @@ -12,7 +12,6 @@ #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" - namespace megdnn { namespace aarch64 { namespace matmul_general_4x16 { @@ -39,8 +38,9 @@ namespace matmul_general_4x16 { // +--+ - - - - +--------+--------+--------+--------+ // // Accumulator -void kern_4x16(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, int m_remain) { +void kern_4x16( + const float* packA, const float* packB, int K, float* 
output, int LDC, + bool is_first_k, int m_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -224,14 +224,14 @@ void kern_4x16(const float* packA, const float* packB, int K, "6:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", - "x10", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", "x10", "cc", + "memory"); #undef LOAD_LINE #undef LOAD_C @@ -263,8 +263,9 @@ void kern_4x16(const float* packA, const float* packB, int K, // +--+--+ - - - - +--------+ // // Accumulator -void kern_4x4(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int m_remain, int n_remain) { +void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -330,99 +331,100 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, STORE_LINE("6", "2") \ STORE_LINE("7", "3") \ "105:\n" -// clang-format on - - asm volatile( - // load accumulator C - "add x1, x0, %x[LDC]\n" - "add x2, x1, %x[LDC]\n" - "add x3, x2, %x[LDC]\n" - - "cmp %w[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "b 2f\n" - - "1:\n" - "eor v4.16b, v4.16b, v4.16b\n" - "eor v5.16b, v5.16b, v5.16b\n" - "eor v6.16b, v6.16b, v6.16b\n" - "eor v7.16b, v7.16b, v7.16b\n" - - "2: \n" - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "ld1 {v2.4s}, [%[b_ptr]], 16\n" - "cmp %w[K], #0\n" - "beq 4f\n" - - "3:\n" - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "ld1 {v3.4s}, [%[b_ptr]], 16\n" - "fmla v4.4s, v2.4s, v0.s[0]\n" - "fmla v5.4s, v2.4s, v0.s[1]\n" - "fmla v6.4s, v2.4s, v0.s[2]\n" - "fmla v7.4s, v2.4s, v0.s[3]\n" - - "ld1 {v0.4s}, [%[a_ptr]], 16\n" - "ld1 {v2.4s}, [%[b_ptr]], 16\n" - "fmla v4.4s, v3.4s, v1.s[0]\n" - "fmla v5.4s, v3.4s, v1.s[1]\n" - "fmla v6.4s, v3.4s, v1.s[2]\n" - "fmla v7.4s, v3.4s, v1.s[3]\n" - - "subs %w[K], %w[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %w[oddk], #1\n" - "beq 5f\n" - - // Even tail - "ld1 {v1.4s}, [%[a_ptr]], 16\n" - "ld1 {v3.4s}, [%[b_ptr]], 16\n" - "fmla v4.4s, v2.4s, v0.s[0]\n" - "fmla v5.4s, v2.4s, v0.s[1]\n" - "fmla v6.4s, v2.4s, v0.s[2]\n" - "fmla v7.4s, v2.4s, v0.s[3]\n" - - "fmla v4.4s, v3.4s, v1.s[0]\n" - "fmla v5.4s, v3.4s, v1.s[1]\n" - "fmla v6.4s, v3.4s, v1.s[2]\n" - "fmla v7.4s, v3.4s, v1.s[3]\n" - - "b 6f\n" - - // odd tail - "5:\n" - "fmla v4.4s, v2.4s, v0.s[0]\n" - "fmla v5.4s, v2.4s, v0.s[1]\n" - "fmla v6.4s, v2.4s, v0.s[2]\n" - "fmla v7.4s, v2.4s, v0.s[3]\n" - - "6:\n" STORE_C - - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) - : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", - "x2", "x3", "x10", "cc", "memory"); + // clang-format on + + asm volatile( + // load accumulator C + "add x1, x0, 
%x[LDC]\n" + "add x2, x1, %x[LDC]\n" + "add x3, x2, %x[LDC]\n" + + "cmp %w[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "eor v4.16b, v4.16b, v4.16b\n" + "eor v5.16b, v5.16b, v5.16b\n" + "eor v6.16b, v6.16b, v6.16b\n" + "eor v7.16b, v7.16b, v7.16b\n" + + "2: \n" + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "cmp %w[K], #0\n" + "beq 4f\n" + + "3:\n" + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v4.4s, v2.4s, v0.s[0]\n" + "fmla v5.4s, v2.4s, v0.s[1]\n" + "fmla v6.4s, v2.4s, v0.s[2]\n" + "fmla v7.4s, v2.4s, v0.s[3]\n" + + "ld1 {v0.4s}, [%[a_ptr]], 16\n" + "ld1 {v2.4s}, [%[b_ptr]], 16\n" + "fmla v4.4s, v3.4s, v1.s[0]\n" + "fmla v5.4s, v3.4s, v1.s[1]\n" + "fmla v6.4s, v3.4s, v1.s[2]\n" + "fmla v7.4s, v3.4s, v1.s[3]\n" + + "subs %w[K], %w[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %w[oddk], #1\n" + "beq 5f\n" + + // Even tail + "ld1 {v1.4s}, [%[a_ptr]], 16\n" + "ld1 {v3.4s}, [%[b_ptr]], 16\n" + "fmla v4.4s, v2.4s, v0.s[0]\n" + "fmla v5.4s, v2.4s, v0.s[1]\n" + "fmla v6.4s, v2.4s, v0.s[2]\n" + "fmla v7.4s, v2.4s, v0.s[3]\n" + + "fmla v4.4s, v3.4s, v1.s[0]\n" + "fmla v5.4s, v3.4s, v1.s[1]\n" + "fmla v6.4s, v3.4s, v1.s[2]\n" + "fmla v7.4s, v3.4s, v1.s[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "fmla v4.4s, v2.4s, v0.s[0]\n" + "fmla v5.4s, v2.4s, v0.s[1]\n" + "fmla v6.4s, v2.4s, v0.s[2]\n" + "fmla v7.4s, v2.4s, v0.s[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), + [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", "x2", "x3", "x10", + "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C } -void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, - int ymax, int k0, int kmax) { +void sgemm_4x16_pack_A_n( + float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { float zerobuff[4]; std::memset(zerobuff, 0, sizeof(float) * 4); - constexpr int PACK_SIZE = 4*4; + constexpr int PACK_SIZE = 4 * 4; int y = y0; for (; y + 3 < ymax; y += 4) { - // printf("main loop pack_a_n %p \n",outptr); + // printf("main loop pack_a_n %p \n",outptr); const float* inptr0 = inptr + y * ldin + k0; const float* inptr1 = inptr0 + ldin; const float* inptr2 = inptr1 + ldin; @@ -459,9 +461,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, switch ((y + 3) - ymax) { /* Everything falls through in here */ case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -478,9 +482,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -493,8 +499,8 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, } } -void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax) { +void sgemm_4x16_pack_A_t( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize4 = (ksize << 2); float* outptr_base = out; @@ -515,8 +521,7 @@ void 
sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, auto outptr = outptr_base; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } @@ -546,8 +551,8 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, } } -void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, - int x0, int xmax, int k0, int kmax) { +void sgemm_4x16_pack_B_n( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize16 = ksize * 16; int ksize4 = (ksize << 2); @@ -570,15 +575,13 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, auto outptr = outptr_base; for (; x + 16 <= xmax; x += 16) { auto outptr_interleave = outptr; - interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize16; } outptr = outptr_base4; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } @@ -616,8 +619,8 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, } } -void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, - int y0, int ymax, int k0, int kmax) { +void sgemm_4x16_pack_B_t( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { float* outptr = out; const float* inptr = in; float zerobuff[4]; @@ -642,8 +645,7 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, int x = (kmax - k0); for (; x > 3; x -= 4) { - transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, - 64); + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 64); outptr_inner += 64; } for (; x > 0; x--) { @@ -676,9 +678,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, switch ((y + 3) - ymax) { /* Everything falls through in here */ case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -696,9 +700,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, switch ((y + 3) - ymax) { /* Everything falls through in here */ case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -711,8 +717,8 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, } } -} // matmul_general_4x16 -} // aarch64 -} // megdnn +} // namespace matmul_general_4x16 +} // namespace aarch64 +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h index 9f76c80a..c69ce9cb 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12.h @@ -43,8 +43,9 @@ struct matmul_general_8x12 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_8x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_8x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool 
is_first_k) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -306,14 +307,13 @@ struct matmul_general_8x12 { "6:\n" : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", - "x6", "x7", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -348,9 +348,9 @@ struct matmul_general_8x12 { // +--+ --- - +--------+ // // Accumulator - static void kern_8x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_8x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -520,13 +520,12 @@ struct matmul_general_8x12 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [n_remain] "+r"(n_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", - "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", - "cc", "memory"); + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", + "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", + "memory"); #undef LOAD_LINE #undef LOAD_C @@ -557,9 +556,9 @@ struct matmul_general_8x12 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_4x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int m_remain) { + static void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -717,13 +716,12 @@ struct matmul_general_8x12 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [m_remain] "+r"(m_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "x1", "x2", "x3", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", + "x2", "x3", "x10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -754,9 +752,9 @@ struct matmul_general_8x12 { // +--+ --- - +--------+ // // Accumulator - static void kern_4x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, int m_remain, - int 
n_remain) { + static void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -895,20 +893,21 @@ struct matmul_general_8x12 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), + [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", - "x3", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", + "x10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C } - static void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, - int y0, int ymax, int k0, int kmax) { + static void sgemm_8x12_pack_A_n( + float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { float zerobuff[8]; std::memset(zerobuff, 0, sizeof(float) * 8); constexpr int PACK_SIZE_32 = 4 * 8; @@ -933,8 +932,9 @@ struct matmul_general_8x12 { prefetch_2x(inptr7); int x = (kmax - k0); for (; x > 3; x -= 4) { - transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, - inptr5, inptr6, inptr7, outptr); + transpose_8x4_1_s( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += PACK_SIZE_32; } for (; x > 0; x--) { @@ -1004,8 +1004,8 @@ struct matmul_general_8x12 { } } - static void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, - int x0, int xmax, int k0, int kmax) { + static void sgemm_8x12_pack_A_t( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize8 = (ksize << 3); int ksize4 = (ksize << 2); @@ -1028,20 +1028,17 @@ struct matmul_general_8x12 { auto outptr = outptr_base; for (; x + 8 <= xmax; x += 8) { auto outptr_interleave = outptr; - interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize8; } outptr = outptr_base4; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } if (x < xmax) { - interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, - xmax - x); + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); } outptr_base += 4 * 8; outptr_base4 += 4 * 4; @@ -1071,8 +1068,8 @@ struct matmul_general_8x12 { } } - static void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, - int x0, int xmax, int k0, int kmax) { + static void sgemm_8x12_pack_B_n( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize12 = ksize * 12; int ksize4 = (ksize << 2); @@ -1095,20 +1092,17 @@ struct matmul_general_8x12 { auto outptr = outptr_base; for (; x + 12 <= xmax; x += 12) { auto outptr_interleave = outptr; - interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize12; } outptr = outptr_base4; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, 
inptr3, - outptr_interleave); + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } if (x < xmax) { - interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, - xmax - x); + interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); } outptr_base += 12 * 4; outptr_base4 += 4 * 4; @@ -1138,8 +1132,8 @@ struct matmul_general_8x12 { } } - static void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, - int y0, int ymax, int k0, int kmax) { + static void sgemm_8x12_pack_B_t( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { float* outptr = out; const float* inptr = in; float zerobuff[12]; @@ -1172,9 +1166,9 @@ struct matmul_general_8x12 { prefetch_2x(inptr11); int x = (kmax - k0); for (; x > 3; x -= 4) { - transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, - inptr5, inptr6, inptr7, inptr8, inptr9, - inptr10, inptr11, outptr); + transpose_12x4_1_s( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + inptr8, inptr9, inptr10, inptr11, outptr); outptr += 48; } for (; x > 0; x--) { diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h index 880c997f..6a237c5a 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h @@ -43,8 +43,9 @@ struct matmul_general_8x12_a53 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_8x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_8x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -575,15 +576,14 @@ struct matmul_general_8x12_a53 { "6:\n" : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", - "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", + "x11", "x12", "x13", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C } @@ -615,9 +615,9 @@ struct matmul_general_8x12_a53 { // +--+ --- - +--------+ // // Accumulator - static void kern_8x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_8x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -856,13 +856,12 @@ struct matmul_general_8x12_a53 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [n_remain] "+r"(n_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + 
[outptr] "+r"(outptr), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", - "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", - "x8", "x9", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", + "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", + "x10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -893,9 +892,9 @@ struct matmul_general_8x12_a53 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_4x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int m_remain) { + static void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -1133,14 +1132,12 @@ struct matmul_general_8x12_a53 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [m_remain] "+r"(m_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21", - "x22", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", + "x2", "x3", "x8", "x9", "x10", "x20", "x21", "x22", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -1171,9 +1168,9 @@ struct matmul_general_8x12_a53 { // +--+ --- - +--------+ // // Accumulator - static void kern_4x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { + static void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -1312,12 +1309,12 @@ struct matmul_general_8x12_a53 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), + [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", - "x3", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", + "x10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h index 21837d73..1b12e424 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_general_8x12_a55.h @@ -43,8 +43,9 @@ struct matmul_general_8x12_a55 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_8x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_8x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); 
@@ -525,15 +526,14 @@ struct matmul_general_8x12_a55 { "6:\n" : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", - "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", + "x11", "x12", "x13", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C } @@ -565,9 +565,9 @@ struct matmul_general_8x12_a55 { // +--+ --- - +--------+ // // Accumulator - static void kern_8x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_8x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -742,13 +742,12 @@ struct matmul_general_8x12_a55 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [n_remain] "+r"(n_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", - "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", - "x10", "cc", "memory"); + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", + "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x10", "cc", + "memory"); #undef LOAD_LINE #undef LOAD_C @@ -779,9 +778,9 @@ struct matmul_general_8x12_a55 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_4x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int m_remain) { + static void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -972,14 +971,12 @@ struct matmul_general_8x12_a55 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [m_remain] "+r"(m_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", + "x2", "x3", "x10", "x20", "x21", "x22", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -1010,9 +1007,9 @@ struct matmul_general_8x12_a55 { // +--+ --- - +--------+ // // Accumulator - static void kern_4x4(const float* packA, const float* 
packB, int K, - float* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { + static void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -1151,12 +1148,12 @@ struct matmul_general_8x12_a55 { "6:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [outptr] "+r"(outptr), - [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), + [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", - "x3", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", + "x10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h index 4360d502..46fbefec 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h @@ -44,8 +44,9 @@ struct matmul_mk4_8x12 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_8x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_8x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { const float* a_ptr = packA; const float* b_ptr = packB; float* output0 = output; @@ -307,10 +308,10 @@ struct matmul_mk4_8x12 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0), [output1] "+r"(output1) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x1", "x2", "cc", "memory"); } // Overview of register layout: @@ -340,9 +341,9 @@ struct matmul_mk4_8x12 { // +--+ --- - +--------+ // // Accumulator - static void kern_8x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_8x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; float* output0 = output; @@ -500,8 +501,8 @@ struct matmul_mk4_8x12 { [output0] "+r"(output0), [output1] "+r"(output1), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "cc", "memory"); #undef LOAD_C #undef STORE_C @@ -531,8 +532,9 @@ struct matmul_mk4_8x12 { // // Accumulator - static void kern_4x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -669,9 +671,9 @@ struct 
matmul_mk4_8x12 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "x1", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", + "cc", "memory"); } // Overview of register layout: @@ -697,9 +699,9 @@ struct matmul_mk4_8x12 { // +--+ --- - +--------+ // // Accumulator - static void kern_4x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -818,15 +820,15 @@ struct matmul_mk4_8x12 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); #undef LOAD_C #undef STORE_C } - static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, - int y0, int ymax, int k0, int kmax) { + static void sgemm_8x12_pack_A( + float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int PACK_SIZE_32 = 4 * 8; @@ -855,8 +857,8 @@ struct matmul_mk4_8x12 { } } - static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax) { + static void sgemm_8x12_pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); float tmpbuff[16] = {0.0f}; @@ -886,8 +888,7 @@ struct matmul_mk4_8x12 { outptr += ksize4; } if (x < xmax) { - std::memcpy(tmpbuff, inptr, - sizeof(float) * (xmax - x) * PACK_C_SIZE); + std::memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); auto outptr_interleave = outptr; const float* tmp_ptr = &tmpbuff[0]; transpose_1x4_4_s(tmp_ptr, outptr_interleave); diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h index 5f438c57..faae5718 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h @@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a53 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_8x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_8x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { const float* a_ptr = packA; const float* b_ptr = packB; float* output0 = output; @@ -553,11 +554,11 @@ struct matmul_mk4_8x12_a53 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0), [output1] "+r"(output1) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", - "x11", "x12", "x13", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + 
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory"); } // Overview of register layout: @@ -587,9 +588,9 @@ struct matmul_mk4_8x12_a53 { // +--+ --- - +--------+ // // Accumulator - static void kern_8x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_8x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; float* output0 = output; @@ -831,8 +832,8 @@ struct matmul_mk4_8x12_a53 { [output0] "+r"(output0), [output1] "+r"(output1), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "x8", "x9", "x10", "cc", "memory"); #undef LOAD_C #undef STORE_C @@ -862,8 +863,9 @@ struct matmul_mk4_8x12_a53 { // // Accumulator - static void kern_4x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -1098,9 +1100,9 @@ struct matmul_mk4_8x12_a53 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", + "x8", "x9", "x10", "x11", "x12", "cc", "memory"); } // Overview of register layout: @@ -1126,9 +1128,9 @@ struct matmul_mk4_8x12_a53 { // +--+ --- - +--------+ // // Accumulator - static void kern_4x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -1246,8 +1248,7 @@ struct matmul_mk4_8x12_a53 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); #undef LOAD_C #undef STORE_C diff --git a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h index a6de35f0..7fe5c8b2 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h +++ b/dnn/src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h @@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a55 { // +--+ --- - +--------+--------+--------+ // // Accumulator - static void kern_8x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_8x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { const float* a_ptr = packA; const float* b_ptr = packB; float* output0 = output; @@ -519,11 +520,11 @@ struct matmul_mk4_8x12_a55 { [is_first_k] 
"+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0), [output1] "+r"(output1) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", - "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", - "x11", "x12", "x13", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory"); } // Overview of register layout: @@ -553,9 +554,9 @@ struct matmul_mk4_8x12_a55 { // +--+ --- - +--------+ // // Accumulator - static void kern_8x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_8x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; float* output0 = output; @@ -749,8 +750,8 @@ struct matmul_mk4_8x12_a55 { [output0] "+r"(output0), [output1] "+r"(output1), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "x8", "x9", "x10", "cc", "memory"); #undef LOAD_C #undef STORE_C @@ -780,8 +781,9 @@ struct matmul_mk4_8x12_a55 { // // Accumulator - static void kern_4x12(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k) { + static void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -997,9 +999,9 @@ struct matmul_mk4_8x12_a55 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", + "x8", "x9", "x10", "x11", "x12", "cc", "memory"); } // Overview of register layout: @@ -1025,9 +1027,9 @@ struct matmul_mk4_8x12_a55 { // +--+ --- - +--------+ // // Accumulator - static void kern_4x4(const float* packA, const float* packB, int K, - float* output, int LDC, bool is_first_k, - int n_remain) { + static void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -1146,8 +1148,7 @@ struct matmul_mk4_8x12_a55 { [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); #undef LOAD_C #undef STORE_C diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp index 9b5d5138..a369abd6 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp @@ -10,6 +10,7 @@ * implied. 
*/ +#include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" @@ -17,44 +18,40 @@ #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" -#include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/common/utils.h" - using namespace megdnn; using namespace aarch64; using namespace aarch64::matmul; MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); -void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, - int k0, int kmax, bool transpose_A) const { +void sgemm_4x16::pack_A( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose_A) const { if (transpose_A) { - matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, - kmax); + matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, - int k0, int kmax, bool transpose_B) const { +void sgemm_4x16::pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose_B) const { if (transpose_B) { - matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, - kmax); + matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, - size_t N, size_t K, float* C, size_t LDC, bool is_first_k, - const float*, float*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == C_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Float32); +void sgemm_4x16::kern( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k, const float*, float*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -71,9 +68,9 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, size_t n = 0; const float* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, - is_first_k, - std::min(M - m, 4)); + matmul_general_4x16::kern_4x16( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K16; } @@ -92,32 +89,30 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); -void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, - int k0, int kmax, bool transpose_A) const { +void sgemm_8x12::pack_A( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose_A) const { if (transpose_A) { - matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, - kmax); + matmul_general_8x12::sgemm_8x12_pack_A_t(out, 
in, ldin, y0, ymax, k0, kmax); } else { - matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, - int k0, int kmax, bool transpose_B) const { +void sgemm_8x12::pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose_B) const { if (transpose_B) { - matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, - kmax); + matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } template -static inline void sgemm_8x12_helper(const float* packA, const float* packB, - size_t M, size_t N, size_t K, float* C, - size_t LDC, bool is_first_k) { +static inline void sgemm_8x12_helper( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k) { constexpr size_t A_INTERLEAVE = 8; constexpr size_t A_INTERLEAVE4 = 4; constexpr size_t B_INTERLEAVE = 12; @@ -138,8 +133,9 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, } for (; n < N; n += 4) { - gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 4)); + gemm_class::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -150,16 +146,17 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, size_t n = 0; const float* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + gemm_class::kern_4x12( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K12; } for (; n < N; n += 4) { - gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4)); + gemm_class::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -167,56 +164,55 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, } } -void sgemm_8x12::kern(const float* packA, const float* packB, size_t M, - size_t N, size_t K, float* C, size_t LDC, bool is_first_k, - const float*, float*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == C_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Float32); +void sgemm_8x12::kern( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k, const float*, float*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); #if !MGB_ENABLE_CPUINFO - sgemm_8x12_helper(packA, packB, M, N, K, C, LDC, - is_first_k); + sgemm_8x12_helper(packA, packB, M, N, K, C, LDC, is_first_k); #else auto arch = cpuinfo_get_current_core()->uarch; #ifdef __IN_TEE_ENV__ arch = cpuinfo_uarch_unknown; #endif if (arch == cpuinfo_uarch_cortex_a53) { - sgemm_8x12_helper(packA, packB, M, N, K, C, - LDC, is_first_k); + sgemm_8x12_helper( + packA, packB, M, N, K, C, LDC, is_first_k); } else if (arch 
== cpuinfo_uarch_cortex_a55) { - sgemm_8x12_helper(packA, packB, M, N, K, C, - LDC, is_first_k); + sgemm_8x12_helper( + packA, packB, M, N, K, C, LDC, is_first_k); } else { - sgemm_8x12_helper(packA, packB, M, N, K, C, LDC, - is_first_k); + sgemm_8x12_helper( + packA, packB, M, N, K, C, LDC, is_first_k); } #endif } MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); -void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, - int ymax, int k0, int kmax, - bool transpose_A) const { +void sgemm_mk4_8x12::pack_A( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose_A) const { megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); } -void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose_B) const { +void sgemm_mk4_8x12::pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose_B) const { megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); } template -static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, - size_t M, size_t N, size_t K, float* C, - size_t LDC, bool is_first_k) { +static inline void sgemm_mk4_8x12_helper( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k) { const int K12 = K * 12; const int K8 = K * 8; const int K4 = K * 4; @@ -237,8 +233,9 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, } for (; n < N; n += 4) { - gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 4)); + gemm_name::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4 * PACK_C_SIZE; cur_packB += K4; } @@ -254,41 +251,41 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, cur_packB += K12; } for (; n < N; n += 4) { - gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 4)); + gemm_name::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4 * PACK_C_SIZE; cur_packB += K4; } packA += K4; } } -void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M, - size_t N, size_t K, float* C, size_t LDC, - bool is_first_k, const float*, float*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == C_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Float32); +void sgemm_mk4_8x12::kern( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k, const float*, float*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); #if !MGB_ENABLE_CPUINFO - sgemm_mk4_8x12_helper(packA, packB, M, N, K, C, LDC, - is_first_k); + sgemm_mk4_8x12_helper(packA, packB, M, N, K, C, LDC, is_first_k); #else auto arch = cpuinfo_get_current_core()->uarch; #ifdef __IN_TEE_ENV__ arch = cpuinfo_uarch_unknown; #endif if (arch == cpuinfo_uarch_cortex_a53) { - sgemm_mk4_8x12_helper(packA, packB, M, N, K, C, - LDC, is_first_k); + sgemm_mk4_8x12_helper( + packA, packB, M, N, K, C, 
LDC, is_first_k); } else if (arch == cpuinfo_uarch_cortex_a55) { - sgemm_mk4_8x12_helper(packA, packB, M, N, K, C, - LDC, is_first_k); + sgemm_mk4_8x12_helper( + packA, packB, M, N, K, C, LDC, is_first_k); } else { - sgemm_mk4_8x12_helper(packA, packB, M, N, K, C, LDC, - is_first_k); + sgemm_mk4_8x12_helper( + packA, packB, M, N, K, C, LDC, is_first_k); } #endif } diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.h b/dnn/src/aarch64/matrix_mul/fp32/strategy.h index 3b7980b8..097277c9 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy.h +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.h @@ -15,17 +15,14 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, - sgemm_8x12); +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); -MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, - sgemm_4x16); +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16); -MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, - sgemm_mk4_8x12); +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, sgemm_mk4_8x12); -MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, - sgemm_nopack_4x16); +MEGDNN_REG_GEMM_STRATEGY_NOPACK( + float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp index 253178b5..433ae4de 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy_mk4_4x16.cpp @@ -20,8 +20,8 @@ using namespace aarch64::matmul; namespace { -void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, - float* output) { +void kern_4x1( + const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { LDB *= sizeof(float); asm volatile( "subs %w[K], %w[K], #4\n" @@ -64,8 +64,7 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", - "memory"); + : "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", "memory"); } // Overview of register layout: @@ -89,8 +88,8 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, // +--------+ - - - - -+--------+ // Accumulator -void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, - float* output) { +void kern_4x4( + const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { //! As each load 16 number from B, but the pos add 12 * 4, so we minus 12 //! here. 
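    //! (Equivalently: each main-loop step reads 16 floats of B while its
    //! post-indexed loads advance b_ptr by only 12 floats, so the pre-shrunk
    //! stride added afterwards lands exactly 12 + (LDB - 12) = LDB floats
    //! past where the step started.)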
LDB = (LDB - 12) * sizeof(float); @@ -165,8 +164,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", - "v18", "v19", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", + "v19", "cc", "memory"); } // Overview of register layout: @@ -195,8 +194,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, // +--------+ - - - - -+--------+ // Accumulator -void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, - float* output) { +void kern_4x8( + const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { //! As each load 32 number from B, but the pos add 24 * 4, so we minus 24 //! here. LDB = (LDB - 24) * sizeof(float); @@ -304,9 +303,9 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", - "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", + "memory"); } // Overview of register layout: @@ -342,8 +341,7 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, // +--------+ // Accumulator -void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, - float* output) { +void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, float* output) { //! As each load 64 number from B, but the pos add 56 * 4, so we minus 56 //! here. 
LDB = (LDB - 56) * sizeof(float); @@ -565,20 +563,18 @@ void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", - "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); } } // namespace MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); -void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, - size_t LDB, float* C, size_t LDC, size_t M, - size_t K, size_t N, const float*, void*, bool trA, - bool trB) const { +void sgemm_nopack_4x16::kern( + const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, + size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { constexpr static size_t MB = 4; constexpr static size_t KB = 4; constexpr static size_t NB = 16; diff --git a/dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h b/dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h index 3c4d9605..256d55fd 100644 --- a/dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h +++ b/dnn/src/aarch64/matrix_mul/int16/kernel_12x8x1.h @@ -46,8 +46,9 @@ namespace matmul_12x8x1 { * Accumulator */ -static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_12x8( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -155,15 +156,13 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, "stp q25, q26, [x9]\n" "stp q27, q28, [x10]\n" "stp q29, q30, [x11]\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [output] "+r"(output) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) : - : "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1", - "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", - "cc", "memory"); + : "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -196,8 +195,9 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, * Accumulator */ -static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_8x8( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -276,13 +276,12 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, "stp q17, q18, [x5]\n" "stp q19, q20, [x6]\n" "stp q21, q22, [x7]\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - 
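The fp32 strategy.cpp hunks above rewrap, but do not change, the runtime micro-architecture dispatch: with MGB_ENABLE_CPUINFO set, sgemm_8x12::kern and sgemm_mk4_8x12::kern pick the Cortex-A53, Cortex-A55 or generic 8x12 kernel from cpuinfo_get_current_core()->uarch. A minimal standalone sketch of that pattern (assuming the cpuinfo library is linked; the kernel bodies here are placeholders, not MegDNN code):

#include <cpuinfo.h>

using KernFn = void (*)();  // stand-in for the real micro-kernel signature

static void kern_generic() {}
static void kern_a53() {}
static void kern_a55() {}

// Choose a kernel for the core the calling thread is currently scheduled on.
static KernFn select_kernel() {
    if (!cpuinfo_initialize())  // fall back if cpuinfo cannot probe the CPU
        return kern_generic;
    const struct cpuinfo_core* core = cpuinfo_get_current_core();
    if (!core)
        return kern_generic;
    switch (core->uarch) {
        case cpuinfo_uarch_cortex_a53:
            return kern_a53;
        case cpuinfo_uarch_cortex_a55:
            return kern_a55;
        default:
            return kern_generic;
    }
}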
[output] "+r"(output) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) : - : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", - "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); + : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", "x2", "x3", "x4", + "x5", "x6", "x7", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -311,9 +310,9 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, * Accumulator */ -static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t m_remain) { +static void kern_4x8( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -388,14 +387,13 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), - [m_remain] "+r"(m_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [x0] "+r"(x0), [m_remain] "+r"(m_remain) : - : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "cc", "memory"); + : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "cc", + "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -432,9 +430,9 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, * Accumulator */ -static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t n_remain) { +static void kern_12x4( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t n_remain) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -573,18 +571,16 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), - [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), - [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), - [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), - [outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), - [x0] "+r"(x0), [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7), [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), + [outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), [x0] "+r"(x0), + [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "cc", "memory"); + : "v0", "v1", "v2", 
"v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -618,9 +614,9 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, * Accumulator */ -static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t n_remain) { +static void kern_8x4( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t n_remain) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -734,16 +730,14 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), - [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), - [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), - [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) : - : "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", - "cc", "memory"); + : "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", + "memory"); #undef LOAD_LINE #undef LOAD_C @@ -773,9 +767,9 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, * Accumulator */ -static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain) { +static void kern_4x4( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -874,11 +868,10 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), - [m_remain] "+r"(m_remain), [x1] "+r"(x1), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [x0] "+r"(x0), [m_remain] "+r"(m_remain), [x1] "+r"(x1), [n_remain] "+r"(n_remain) : : "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); @@ -889,9 +882,9 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, #undef STORE_C } -static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s16_12x8x1_pack_A_n( + int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int16_t zerobuff[4]; std::memset(zerobuff, 0, sizeof(int16_t) * 4); @@ -925,15 +918,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, int K = kmax - k0; for (; K > 3; K -= 4) { - interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, inptr8, inptr9, 
inptr10, - inptr11, outptr); + interleave_12x1_4_h( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + inptr8, inptr9, inptr10, inptr11, outptr); } if (K > 0) { - interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, - outptr, 1, K); + interleave_12( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + inptr8, inptr9, inptr10, inptr11, outptr, 1, K); } } @@ -949,13 +942,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, int K = kmax - k0; for (; K > 7; K -= 8) { - interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x1_8_h( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 1, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 1, K); } } @@ -975,9 +970,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -992,9 +989,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1007,9 +1006,8 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, } } -static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, - int ldin, int x0, int xmax, - int k0, int kmax) { +static void gemm_s16_12x8x1_transpose_pack_A_n( + int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { const int ksize = kmax - k0; const int ksize4 = ksize * 4; const int ksize8 = ksize4 * 2; @@ -1054,8 +1052,8 @@ static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, } } -static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s16_12x8x1_pack_B_n( + int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { const int ksize = kmax - k0; const int ksize4 = ksize * 4; const int ksize8 = ksize4 * 2; @@ -1090,10 +1088,9 @@ static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, } } -static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, - const int16_t* inptr, int ldin, - int y0, int ymax, int k0, - int kmax) { +static void gemm_s16_12x8x1_transpose_pack_B_n( + int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int16_t zerobuff[4]; std::memset(zerobuff, 0, sizeof(int16_t) * 4); @@ -1110,13 +1107,15 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, int K = kmax - k0; for (; K > 7; K -= 8) { - interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x1_8_h( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 1, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, 
inptr7, + outptr, 1, K); } } @@ -1136,9 +1135,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1153,9 +1154,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; diff --git a/dnn/src/aarch64/matrix_mul/int16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int16/strategy.cpp index 29c150cf..c31d7d1e 100644 --- a/dnn/src/aarch64/matrix_mul/int16/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int16/strategy.cpp @@ -22,39 +22,37 @@ using namespace aarch64::matmul; ///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); -void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s16_12x8x1::pack_A( + dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { if (transpose) { - matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, - y0, ymax, k0, kmax); + matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n( + outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax, - k0, kmax); + matmul_12x8x1::gemm_s16_12x8x1_pack_A_n( + outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin, - int x0, int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s16_12x8x1::pack_B( + dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, - xmax, k0, kmax); + matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, - size_t M, size_t N, size_t K, dt_int32* C, - size_t LDC, bool is_first_k, const dt_int32*, - dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - (A_dtype.enumv() == DTypeEnum::Int16 && - C_dtype.enumv() == DTypeEnum::Int32), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s16_12x8x1::kern( + const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + (A_dtype.enumv() == DTypeEnum::Int16 && + C_dtype.enumv() == DTypeEnum::Int32), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -72,15 +70,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, size_t n = 0; const dt_int16* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - 
matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC, - is_first_k, std::min(N - n, 4)); + matmul_12x8x1::kern_12x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -92,15 +90,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, const dt_int16* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC, - is_first_k, std::min(N - n, 4)); + matmul_12x8x1::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -112,16 +110,17 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, const dt_int16* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC, - is_first_k, std::min(M - m, 4)); + matmul_12x8x1::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, std::min(M - m, 4), - std::min(N - n, 4)); + matmul_12x8x1::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } diff --git a/dnn/src/aarch64/matrix_mul/int16/strategy.h b/dnn/src/aarch64/matrix_mul/int16/strategy.h index 22f5476a..c0881b0a 100644 --- a/dnn/src/aarch64/matrix_mul/int16/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int16/strategy.h @@ -16,11 +16,11 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, - gemm_s16_12x8x1); +MEGDNN_REG_GEMM_STRATEGY( + dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, gemm_s16_12x8x1); -MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false, - true, gemm_nopack_s16_8x8); +MEGDNN_REG_GEMM_STRATEGY_NOPACK( + dt_int16, dt_int32, dt_int32, 8, 8, 1, false, true, gemm_nopack_s16_8x8); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp b/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp index 644e1231..879f4c86 100644 --- a/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp +++ b/dnn/src/aarch64/matrix_mul/int16/strategy_mk8_8x8.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/aarch64/matrix_mul/int16/strategy.h" #include "src/aarch64/matrix_mul/asm/common.h" +#include "src/aarch64/matrix_mul/int16/strategy.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" @@ -20,8 +20,9 @@ using namespace aarch64::matmul; namespace { -void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, - dt_int32* output) { +void kern_8x1( + const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 //! here. 
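The gemm_s16_12x8x1::kern hunks above only rewrap call sites; the tiling itself is untouched: M is consumed in panels of 12, then 8, then 4 rows, and within each panel N is consumed in blocks of 8 columns with a 4-column tail, remainders clamped via std::min. A compact sketch of that loop nest (loop bounds not shown in the diff are filled in as plausible placeholders; micro-kernel calls and packed-pointer bookkeeping are reduced to a stub):

#include <algorithm>
#include <cstddef>

void tile_like_gemm_s16_12x8x1(size_t M, size_t N) {
    // stands in for kern_12x8 / kern_12x4 / kern_8x8 / kern_8x4 / kern_4x8 / kern_4x4
    auto micro_kernel = [](size_t mb, size_t nb) {
        (void)mb;
        (void)nb;
    };
    size_t m = 0;
    for (; m + 12 <= M; m += 12) {  // full 12-row panels
        size_t n = 0;
        for (; n + 8 <= N; n += 8)
            micro_kernel(12, 8);
        for (; n < N; n += 4)
            micro_kernel(12, std::min<size_t>(N - n, 4));
    }
    for (; m + 8 <= M; m += 8) {  // at most one 8-row panel
        size_t n = 0;
        for (; n + 8 <= N; n += 8)
            micro_kernel(8, 8);
        for (; n < N; n += 4)
            micro_kernel(8, std::min<size_t>(N - n, 4));
    }
    for (; m < M; m += 4) {  // 4-row tail, rows clamped like columns
        size_t n = 0;
        for (; n + 8 <= N; n += 8)
            micro_kernel(std::min<size_t>(M - m, 4), 8);
        for (; n < N; n += 4)
            micro_kernel(std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4));
    }
}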
LDB *= sizeof(dt_int16); @@ -91,9 +92,8 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); + : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); } // Overview of register layout: @@ -120,8 +120,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, // | v31[0-7]| |v23[0-3]| // +---------+ +--------+ // Accumulator -void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, - dt_int32* output) { +void kern_8x4( + const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 //! here. LDB = (LDB - 24) * sizeof(dt_int16); @@ -349,9 +350,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", + "memory"); } // Overview of register layout: @@ -382,8 +383,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, // | v7[0-7]| |v30[0-3]|v31[0-3]| // +--------+ +--------+--------+ // Accumulator -void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, - dt_int32* output) { +void kern_8x8( + const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 //! here. 
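The smlal chains in kern_8x8 above accumulate an 8x8 tile of int32 results from int16 operands. As a reading aid only, a scalar model of the arithmetic such a micro-kernel performs; it ignores the MK8 blocked storage, the pointer post-increments, and the is_first_k handling of the packed variants:

#include <cstddef>
#include <cstdint>

// Models the accumulate path (C += A * B); int16 inputs widened to int32.
// Plain row-major operands with explicit leading dimensions; the real kernel
// reads blocked/packed data and keeps the whole tile in v-registers.
void gemm_8x8_ref(
        const int16_t* A, const int16_t* B, int32_t* C, size_t K, size_t lda,
        size_t ldb, size_t ldc) {
    for (size_t i = 0; i < 8; ++i) {
        for (size_t j = 0; j < 8; ++j) {
            int32_t acc = C[i * ldc + j];
            for (size_t k = 0; k < K; ++k)
                acc += int32_t(A[i * lda + k]) * int32_t(B[k * ldb + j]);
            C[i * ldc + j] = acc;
        }
    }
}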
LDB = (LDB - 48) * sizeof(dt_int16); @@ -693,20 +695,20 @@ void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "cc", "memory"); } } // anonymous namespace MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); -void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, - size_t LDB, dt_int32* C, size_t LDC, size_t M, - size_t K, size_t N, const dt_int32*, void*, - bool trA, bool trB) const { +void gemm_nopack_s16_8x8::kern( + const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C, + size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA, + bool trB) const { constexpr static size_t MB = 8; constexpr static size_t KB = 8; constexpr static size_t NB = 8; diff --git a/dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h b/dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h index 1a5f5119..0b5d4b89 100644 --- a/dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h +++ b/dnn/src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h @@ -36,9 +36,9 @@ namespace matmul_s4_4x4x16 { * Accumulator */ -static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void s4_kern_8x8_remain( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 8; LDC = LDC * sizeof(int16_t); const int8_t* a_ptr = packA; @@ -170,7 +170,7 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, "dup v5.8b,v20.b[5]\n" "dup v6.8b,v20.b[6]\n" "dup v7.8b,v20.b[7]\n" - + "ld1 {v17.8b}, [%[b_ptr]], 8\n" "dup v8.8b,v20.b[8]\n" @@ -318,16 +318,16 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, STORE_C : - [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), - [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), - [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), - [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), + [m_remain] "+r"(m_remain), + [n_remain] "+r"( + n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) : - : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31"); + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", + "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); #undef LOAD_LINE #undef LOAD_C @@ -335,14 +335,14 @@ static void 
s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void s4_kern_8x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 8; LDC = LDC * sizeof(int16_t); const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; -// clang-format off + // clang-format off #define LOAD_C_8 \ "ld1 {v24.8h}, [x0], #16\n" \ @@ -363,9 +363,9 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, "st1 {v28.8h}, [x4], #16\n" \ "st1 {v29.8h}, [x5], #16\n" \ "st1 {v30.8h}, [x6], #16\n" \ - "st1 {v31.8h}, [x7], #16\n" \ + "st1 {v31.8h}, [x7], #16\n" -// clang-format on + // clang-format on register int16_t* outptr asm("x0") = output; asm volatile( "add x1, x0, %x[LDC]\n" @@ -395,8 +395,8 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" "1:\n" - // "ld1 {v20.16b}, [%[a_ptr]],#16\n" - // "ld1 {v21.16b}, [%[a_ptr]],#16\n" + // "ld1 {v20.16b}, [%[a_ptr]],#16\n" + // "ld1 {v21.16b}, [%[a_ptr]],#16\n" "dup v0.8b,v20.b[0]\n" "ld1 {v22.16b}, [%[a_ptr]],#16\n" "dup v1.8b,v20.b[1]\n" @@ -409,7 +409,6 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, "dup v5.8b,v20.b[5]\n" "dup v6.8b,v20.b[6]\n" "dup v7.8b,v20.b[7]\n" - "dup v8.8b,v20.b[8]\n" "smlal v24.8h, v0.8b, v16.8b\n" @@ -560,26 +559,26 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, STORE_C_8 : - [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), - [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), - [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), - [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) + [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), + [m_remain] "+r"(m_remain), + [n_remain] "+r"( + n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) : - : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31"); + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", + "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C } -//packa -static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +// packa +static void gemm_s4x4x16_8x8x8_transpose_pack( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[8]; int8_t tmpbuff0[8]; int8_t tmpbuff1[8]; @@ -617,22 +616,23 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in prefetch_2x(inptr5); prefetch_2x(inptr6); prefetch_2x(inptr7); - int K = (kmax - k0)/2; + int K = (kmax - k0) / 2; //! 
read 4 * 16 in each row for (; K > 3; K -= 4) { - transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, - inptr5, inptr6, inptr7, outptr); + transpose_4x8_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - std::memcpy(tmpbuff0,inptr0,K); - std::memcpy(tmpbuff1,inptr1,K); - std::memcpy(tmpbuff2,inptr2,K); - std::memcpy(tmpbuff3,inptr3,K); - std::memcpy(tmpbuff4,inptr4,K); - std::memcpy(tmpbuff5,inptr5,K); - std::memcpy(tmpbuff6,inptr6,K); - std::memcpy(tmpbuff7,inptr7,K); + std::memcpy(tmpbuff0, inptr0, K); + std::memcpy(tmpbuff1, inptr1, K); + std::memcpy(tmpbuff2, inptr2, K); + std::memcpy(tmpbuff3, inptr3, K); + std::memcpy(tmpbuff4, inptr4, K); + std::memcpy(tmpbuff5, inptr5, K); + std::memcpy(tmpbuff6, inptr6, K); + std::memcpy(tmpbuff7, inptr7, K); inptr0 = tmpbuff0; inptr1 = tmpbuff1; inptr2 = tmpbuff2; @@ -641,8 +641,9 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in inptr5 = tmpbuff5; inptr6 = tmpbuff6; inptr7 = tmpbuff7; - transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, - inptr5, inptr6, inptr7, outptr); + transpose_4x8_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } } for (; y < ymax; y += 8) { @@ -655,23 +656,29 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in const int8_t* inptr6 = inptr5 + ldin; const int8_t* inptr7 = inptr6 + ldin; - int K = (kmax - k0)/2; + int K = (kmax - k0) / 2; //! read 4 * 16 in each row for (; K > 3; K -= 4) { if (y + 7 >= ymax) { switch (y + 7 - ymax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -679,24 +686,31 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in megdnn_assert(0); } } - transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, - inptr5, inptr6, inptr7, outptr); + transpose_4x8_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { if (y + 7 >= ymax) { switch (y + 7 - ymax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -705,14 +719,14 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in } } - std::memcpy(tmpbuff0,inptr0,K); - std::memcpy(tmpbuff1,inptr1,K); - std::memcpy(tmpbuff2,inptr2,K); - std::memcpy(tmpbuff3,inptr3,K); - std::memcpy(tmpbuff4,inptr4,K); - std::memcpy(tmpbuff5,inptr5,K); - std::memcpy(tmpbuff6,inptr6,K); - std::memcpy(tmpbuff7,inptr7,K); + std::memcpy(tmpbuff0, inptr0, K); + std::memcpy(tmpbuff1, inptr1, K); 
+ std::memcpy(tmpbuff2, inptr2, K); + std::memcpy(tmpbuff3, inptr3, K); + std::memcpy(tmpbuff4, inptr4, K); + std::memcpy(tmpbuff5, inptr5, K); + std::memcpy(tmpbuff6, inptr6, K); + std::memcpy(tmpbuff7, inptr7, K); inptr0 = tmpbuff0; inptr1 = tmpbuff1; inptr2 = tmpbuff2; @@ -721,14 +735,15 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in inptr5 = tmpbuff5; inptr6 = tmpbuff6; inptr7 = tmpbuff7; - transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, - inptr5, inptr6, inptr7, outptr); + transpose_4x8_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } } } -//packb -static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +// packb +static void gemm_s4x4x16_8x8x8_interleave_pack( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[8]; int8_t tmpbuff0[8]; int8_t tmpbuff1[8]; @@ -748,7 +763,7 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); const int ksize = kmax - k0; - const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4 + const int ksize8 = round_up(ksize, 8) * 8; // pack to int8 *8 packto s4 *4 int8_t* outptr = out; int8_t* outptr_interleave = nullptr; @@ -776,21 +791,22 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int8_t* outptr_inner = outptr; for (; x + 3 < xmax; x += 4) { outptr_interleave = outptr_inner; - interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x4_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr_inner += ksize8; } if (x < xmax) { int remainx = xmax - x; - std::memcpy(tmpbuff0,inptr0,remainx); - std::memcpy(tmpbuff1,inptr1,remainx); - std::memcpy(tmpbuff2,inptr2,remainx); - std::memcpy(tmpbuff3,inptr3,remainx); - std::memcpy(tmpbuff4,inptr4,remainx); - std::memcpy(tmpbuff5,inptr5,remainx); - std::memcpy(tmpbuff6,inptr6,remainx); - std::memcpy(tmpbuff7,inptr7,remainx); + std::memcpy(tmpbuff0, inptr0, remainx); + std::memcpy(tmpbuff1, inptr1, remainx); + std::memcpy(tmpbuff2, inptr2, remainx); + std::memcpy(tmpbuff3, inptr3, remainx); + std::memcpy(tmpbuff4, inptr4, remainx); + std::memcpy(tmpbuff5, inptr5, remainx); + std::memcpy(tmpbuff6, inptr6, remainx); + std::memcpy(tmpbuff7, inptr7, remainx); inptr0 = tmpbuff0; inptr1 = tmpbuff1; inptr2 = tmpbuff2; @@ -801,8 +817,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, inptr7 = tmpbuff7; outptr_interleave = outptr_inner; - interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x4_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr_inner += ksize8; } outptr += 64; @@ -847,8 +864,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, break; } outptr_interleave = outptr_inner; - interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x4_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr_inner += ksize8; } if (x < xmax) { @@ -880,14 +898,14 @@ static void 
gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, } int remainx = xmax - x; outptr_interleave = outptr_inner; - std::memcpy(tmpbuff0,inptr0,remainx); - std::memcpy(tmpbuff1,inptr1,remainx); - std::memcpy(tmpbuff2,inptr2,remainx); - std::memcpy(tmpbuff3,inptr3,remainx); - std::memcpy(tmpbuff4,inptr4,remainx); - std::memcpy(tmpbuff5,inptr5,remainx); - std::memcpy(tmpbuff6,inptr6,remainx); - std::memcpy(tmpbuff7,inptr7,remainx); + std::memcpy(tmpbuff0, inptr0, remainx); + std::memcpy(tmpbuff1, inptr1, remainx); + std::memcpy(tmpbuff2, inptr2, remainx); + std::memcpy(tmpbuff3, inptr3, remainx); + std::memcpy(tmpbuff4, inptr4, remainx); + std::memcpy(tmpbuff5, inptr5, remainx); + std::memcpy(tmpbuff6, inptr6, remainx); + std::memcpy(tmpbuff7, inptr7, remainx); inptr0 = tmpbuff0; inptr1 = tmpbuff1; inptr2 = tmpbuff2; @@ -898,16 +916,16 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, inptr7 = tmpbuff7; outptr_interleave = outptr_inner; - interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x4_1_b_with_shift( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr_inner += ksize8; } } } -} // namespace matmul_4x4x16 +} // namespace matmul_s4_4x4x16 } // namespace aarch64 } // namespace megdnn - // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp index 4c06d5ed..6c1f5c78 100644 --- a/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.cpp @@ -10,9 +10,9 @@ * implied. */ +#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" #include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" -#include "src/aarch64/matrix_mul/int4x4x16/strategy.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_common.h" @@ -23,39 +23,38 @@ using namespace aarch64::matmul; // ===========================gemm_s4x4x16_s4_8x8x8================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); -void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s4x4x16_s4_8x8x8::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0, - kmax); + matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( + out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0, - kmax); + matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( + out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s4x4x16_s4_8x8x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0, - kmax); + matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( + out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0, - kmax); + 
matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( + out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - (A_dtype.enumv() == DTypeEnum::QuantizedS4 && - C_dtype.enumv() == DTypeEnum::QuantizedS16), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s4x4x16_s4_8x8x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + (A_dtype.enumv() == DTypeEnum::QuantizedS4 && + C_dtype.enumv() == DTypeEnum::QuantizedS16), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -72,16 +71,17 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC, - is_first_k, A_INTERLEAVE, B_INTERLEAVE); + matmul_s4_4x4x16::s4_kern_8x8( + packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, + B_INTERLEAVE); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += B_INTERLEAVE) { - matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, - is_first_k, A_INTERLEAVE, - std::min(N - n, B_INTERLEAVE)); + matmul_s4_4x4x16::s4_kern_8x8_remain( + packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, + std::min(N - n, B_INTERLEAVE)); output += B_INTERLEAVE; cur_packB += K8; } @@ -94,10 +94,10 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n < N; n += B_INTERLEAVE) { - matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, - is_first_k, - std::min(M - m, A_INTERLEAVE), - std::min(N - n, B_INTERLEAVE)); + matmul_s4_4x4x16::s4_kern_8x8_remain( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, A_INTERLEAVE), + std::min(N - n, B_INTERLEAVE)); output += B_INTERLEAVE; cur_packB += K8; } @@ -105,5 +105,4 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, } } - // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h index a23acbec..098e9215 100644 --- a/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int4x4x16/strategy.h @@ -17,8 +17,8 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, - gemm_s4x4x16_s4_8x8x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h index de5c0771..79aca414 100644 --- a/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h @@ -51,8 +51,9 @@ namespace matmul_4x4x16 { * Accumulator */ -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_4x4( + const int8_t* packA, const 
int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { K /= 16; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -472,9 +473,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, ); } -static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - int m_remain, int n_remain) { +static void kern_4x4_remain( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { megdnn_assert(K > 0); K /= 16; const int8_t* a_ptr = packA; @@ -655,16 +656,14 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [output] "+r"(output), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -672,8 +671,9 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, int kmax) { +static void gemm_s8_4x4_pack_A_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -716,9 +716,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -734,9 +736,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -749,8 +753,8 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8_4x4_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -777,19 +781,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, if (remain >= 0) { switch (remain) { case 7: - inptr0 = zerobuff; MEGDNN_FALLTHRU + inptr0 = zerobuff; + MEGDNN_FALLTHRU case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; 
MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -798,9 +809,9 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } - transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, - inptr4, inptr5, inptr6, inptr7, - outptr_inner); + transpose_4x16_1_b_helper( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_inner); outptr_inner += ksize4; } @@ -808,19 +819,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, if (remain >= 0) { switch (remain) { case 7: - inptr0 = zerobuff; MEGDNN_FALLTHRU + inptr0 = zerobuff; + MEGDNN_FALLTHRU case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h index 0b0f5575..9b6983ae 100644 --- a/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h @@ -42,8 +42,9 @@ namespace matmul_8x8x8 { * Accumulator */ -static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_8x8( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -272,14 +273,13 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, "stp q18, q19, [x5]\n" "stp q20, q21, [x6]\n" "stp q22, q23, [x7]\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [output] "+r"(output) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v26", "v27", "x1", - "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v26", "v27", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "cc", "memory"); } /** @@ -309,9 +309,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, * Accumulator */ -static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t n_remain) { +static void kern_8x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t n_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -520,16 +520,14 @@ static void kern_8x4(const 
int8_t* packA, const int8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), - [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), - [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), - [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -559,9 +557,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, * Accumulator */ -static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t m_remain) { +static void kern_4x8( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -724,14 +722,13 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), - [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [x0] "+r"(x0), [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -762,9 +759,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, * Accumulator */ -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -922,11 +919,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), - [x1] "+r"(x1), [m_remain] "+r"(m_remain), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [x0] "+r"(x0), 
[x1] "+r"(x1), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", @@ -938,8 +934,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, - int y0, int ymax, int k0, int kmax) { +static void gemm_s8_8x8_pack_A_n( + int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -965,13 +962,15 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, int K = kmax - k0; for (; K > 15; K -= 16) { - interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x8_2_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K); } } @@ -991,9 +990,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1009,9 +1010,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1024,9 +1027,8 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, } } -static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, - int ldin, int x0, int xmax, int k0, - int kmax) { +static void gemm_s8_8x8_transpose_pack_A_n( + int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -1063,17 +1065,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1081,8 +1089,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, megdnn_assert(0); } } - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += ksize8; } @@ -1091,17 +1100,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + 
inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1110,8 +1125,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, 4); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, 4); outptr += ksize4; } @@ -1119,17 +1135,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1138,8 +1160,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, xmax - x); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, xmax - x); } outptr_base += 8 * 8; @@ -1147,8 +1170,8 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, } } -static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8_8x8_pack_B_n( + int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -1186,17 +1209,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1205,8 +1234,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, } } outptr_interleave = outptr; - interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr += ksize8; } @@ -1215,17 +1245,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; 
MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1235,8 +1271,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, 4); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, 4); outptr += ksize4; } @@ -1244,17 +1281,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1264,8 +1307,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, xmax - x); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, xmax - x); } outptr_base += 8 * 8; @@ -1273,9 +1317,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, } } -static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8_8x8_transpose_pack_B_n( + int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); constexpr int interleave4 = 32; @@ -1303,14 +1347,16 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, int K = kmax - k0; for (; K > 7; K -= 8) { - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += interleave8; } if (K > 0) { - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K); outptr += interleave8; } } @@ -1331,9 +1377,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1350,9 +1398,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; diff --git 
a/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h index d4f91b94..652a4b83 100644 --- a/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h @@ -50,8 +50,9 @@ namespace matmul_mk4_4x4x16 { * Accumulator */ -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, bool is_first_k) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, + bool is_first_k) { K = div_ceil(K, 16); const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -366,17 +367,18 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "6:\n" "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [k] "+r"(K), [output] "+r"(output) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "cc", "memory"); } -static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, bool is_first_k, size_t remain_n) { +static void kern_4x4_remain( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, + bool is_first_k, size_t remain_n) { K = div_ceil(K, 16); const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -718,26 +720,27 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, "7:\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), - [k] "+r"(K), [output] "+r"(output) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n), + [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "cc", "memory"); } -static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_mk4_s8_4x4_pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { //! 
pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} int8_t zerobuff[4][64]; std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); - megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, - "mk4 matmul with m is not times of 4"); - megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, - "mk4 matmul with k is not times of 4"); + megdnn_assert( + ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, + "mk4 matmul with m is not times of 4"); + megdnn_assert( + kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, + "mk4 matmul with k is not times of 4"); size_t roundk = round_up(kmax - k0, 16); size_t out_offset = roundk * 4; int y = y0; @@ -754,8 +757,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, prefetch_2x(inptr3); int K = kmax - k0; for (; K > 15; K -= 16) { - transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, - out_offset); + transpose_interleave_4x4_4_b( + inptr0, inptr1, inptr2, inptr3, output, out_offset); output += 64; } if (K > 0) { @@ -767,8 +770,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, inptr1 = zerobuff[1]; inptr2 = zerobuff[2]; inptr3 = zerobuff[3]; - transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, - out_offset); + transpose_interleave_4x4_4_b( + inptr0, inptr1, inptr2, inptr3, output, out_offset); output += 64; } } @@ -790,21 +793,21 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_mk4_s8_4x4_pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int32_t zerobuff[4]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; const int ICB = (ksize) / 4; const int ksize4 = round_up(ICB, 4) * 4; int32_t* outptr = reinterpret_cast(out); - megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, - "mk4 matmul with k is not times of 4"); + megdnn_assert( + kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, + "mk4 matmul with k is not times of 4"); int k = k0 / 4; for (; k + 3 < ICB; k += 4) { - const int32_t* inptr0 = - reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr0 = reinterpret_cast(in + k * ldin + x0); const int32_t* inptr1 = reinterpret_cast(in + (k + 1) * ldin + x0); const int32_t* inptr2 = @@ -829,8 +832,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, outptr += 4 * 4; } if (k < ICB) { - const int32_t* inptr0 = - reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr0 = reinterpret_cast(in + k * ldin + x0); const int32_t* inptr1 = reinterpret_cast(in + (k + 1) * ldin + x0); const int32_t* inptr2 = @@ -844,9 +846,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, if (k + 3 >= ICB) { switch (k + 3 - ICB) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -861,9 +865,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, if (k + 3 >= ICB) { switch (k + 3 - ICB) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -882,7 +888,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* 
in, int ldin, } } -} // namespace matmul_4x4x16 +} // namespace matmul_mk4_4x4x16 } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/int8/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8/strategy.cpp index a5a37a49..a96edf5c 100644 --- a/dnn/src/aarch64/matrix_mul/int8/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8/strategy.cpp @@ -24,20 +24,19 @@ using namespace aarch64::matmul; ///////////////////////// gemm_s8_4x4 //////////////////////////////////// MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); -void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8_4x4::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { if (transpose) { - matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_s8_4x4::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); } else { @@ -45,16 +44,16 @@ void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, } } -void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8_4x4::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -72,16 +71,15 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K4; } for (; n < N; n += B_INTERLEAVE) { - matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC, - is_first_k, 4, - std::min(N - n, 4)); + matmul_4x4x16::kern_4x4_remain( + packA, cur_packB, K, output, LDC, is_first_k, 4, + std::min(N - n, 4)); output += B_INTERLEAVE; cur_packB += K4; } @@ -107,33 +105,32 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, ///////////////////////// gemm_mk4_s8_4x4 
//////////////////////////////////// MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); -void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { - megdnn_assert(!transpose, - "the gemm_mk4_s8_4x4 strategy is not support transpose A"); - matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, - kmax); +void gemm_mk4_s8_4x4::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { + megdnn_assert( + !transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose A"); + matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); } -void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { - megdnn_assert(!transpose, - "the gemm_mk4_s8_4x4 strategy is not support transpose B"); - matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, - kmax); +void gemm_mk4_s8_4x4::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { + megdnn_assert( + !transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose B"); + matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, kmax); } -void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_mk4_s8_4x4::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -151,57 +148,54 @@ void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, - is_first_k); + matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, is_first_k); output += B_INTERLEAVE * 4; cur_packB += K4; } if (n < N) { - matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, - is_first_k, N - n); + matmul_mk4_4x4x16::kern_4x4_remain( + packA, cur_packB, K, output, is_first_k, N - n); } packA += K4; } } - ///////////////////////// gemm_s8_8x8 //////////////////////////////////// MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); -void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8_8x8::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { if (transpose) { - matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, - ymax, k0, kmax); + 
matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n( + outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_s8_8x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, - k0, kmax); + matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } else { matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8_8x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -220,15 +214,15 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 4)); + matmul_8x8x8::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -240,16 +234,17 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + matmul_8x8x8::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4)); + matmul_8x8x8::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } diff --git a/dnn/src/aarch64/matrix_mul/int8/strategy.h b/dnn/src/aarch64/matrix_mul/int8/strategy.h index 26b755e3..9fbecabd 100644 --- a/dnn/src/aarch64/matrix_mul/int8/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8/strategy.h @@ -16,14 +16,14 @@ namespace megdnn { namespace aarch64 { namespace matmul { 
-MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, - gemm_s8_4x4); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, - gemm_mk4_s8_4x4); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, gemm_mk4_s8_4x4); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, - gemm_s8_8x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_s8_8x8); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h index 61649fb2..14d37464 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h +++ b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h @@ -52,8 +52,9 @@ namespace matmul_8x12x4 { #if 1 MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_8x12( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -410,8 +411,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, } #else MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_8x12( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -612,18 +614,17 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, "stp q15, q23, [%[outptr7]]\n" "str q31, [%[outptr7], #32]\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), - [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), - [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), - [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), - [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), - [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1), + [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1), + [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7) : - : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", "cc", "memory"); } #endif @@ -653,8 +654,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, // // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int m_remain) { +static void kern_4x12( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool 
is_first_k, int m_remain) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -796,15 +798,15 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, "4:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), - [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), - [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), - [LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), - [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), - [b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), - [outptr3] "=r"(outptr3), [x0] "=r"(x0) + [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0), + [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), + [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [x0] "=r"(x0) : - : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "memory", "cc"); + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "memory", "cc"); #undef LOAD_LINE #undef LOAD_C @@ -840,8 +842,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, // // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int n_remain) { +static void kern_8x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -1004,12 +1007,11 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), - [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), - [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0) + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), + [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), + [x0] "=r"(x0) : - : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", - "cc"); + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); #undef LOAD_LINE #undef LOAD_C @@ -1041,9 +1043,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, // // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 4; const int32_t* a_ptr = reinterpret_cast(packA); const int32_t* b_ptr = reinterpret_cast(packB); @@ -1172,10 +1174,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "4:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), - [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), - [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [k] "+r"(k), [a0] 
"=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), + [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) : : "v4", "v5", "v6", "v7", "memory", "cc"); @@ -1186,9 +1187,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8_8x12_pack_A_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -1215,13 +1216,15 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, int K = kmax - k0; //! read 8 * 4 in each row for (; K > 15; K -= 16) { - interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x4_4_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, K); } } for (; y < ymax; y += 4) { @@ -1274,8 +1277,8 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8_8x12_pack_A_t( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -1361,8 +1364,8 @@ static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8_8x12_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -1448,9 +1451,9 @@ static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8_8x12_pack_B_t( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -1485,15 +1488,15 @@ static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, int K = kmax - k0; //! 
read 12 * 4 in each row for (; K > 15; K -= 16) { - interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, inptr8, inptr9, inptr10, - inptr11, outptr); + interleave_12x4_4_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + inptr8, inptr9, inptr10, inptr11, outptr); } if (K > 0) { - interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, - outptr, 4, K); + interleave_12( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + inptr8, inptr9, inptr10, inptr11, outptr, 4, K); } } for (; y < ymax; y += 4) { diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h index 79898861..54d61e53 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h +++ b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h @@ -40,8 +40,9 @@ namespace matmul_mk4_8x12x4 { // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_8x12( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -397,8 +398,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_4x12( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -514,13 +516,12 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, "stp q16, q17, [%[outptr0], #128]\n" "stp q18, q19, [%[outptr0], #160]\n" : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), - [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), - [is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a), - [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), - [b1a] "=w"(b1a), [b2a] "=w"(b2a) + [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), + [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a) : - : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "memory", "cc"); + : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "memory", "cc"); } // Overview of register layout: @@ -544,8 +545,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int n_remain) { +static void kern_8x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -689,11 +691,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), - [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), - [x0] "=r"(x0) + [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [x0] "=r"(x0) : - : 
"v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", - "cc"); + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); #undef LOAD_LINE #undef LOAD_C @@ -720,8 +720,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain) { K /= 4; const int32_t* a_ptr = reinterpret_cast(packA); const int32_t* b_ptr = reinterpret_cast(packB); @@ -834,10 +835,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "4:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), - [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), - [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k), - [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), - [x0] "=r"(x0) + [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), + [b0] "=w"(b0), [b0a] "=w"(b0a), [x0] "=r"(x0) : : "v4", "v5", "v6", "v7", "memory", "cc"); @@ -847,13 +847,11 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { - megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, - "mk4 matmul with m is not times of 4"); - megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, - "mk4 matmul with k is not times of 4"); +static void gemm_mk4_s8_8x12_pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { + megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, "mk4 matmul with m is not times of 4"); + megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 matmul with k is not times of 4"); int y = y0; int start_y = y0 / 4; for (; y + 7 < ymax; y += 8, start_y += 2) { @@ -869,15 +867,15 @@ static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, interleave_2x4_4_b(inptr0, inptr1, outptr); } } - for (; y + 3 < ymax; y += 4, start_y ++) { + for (; y + 3 < ymax; y += 4, start_y++) { int K = kmax - k0; const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); } } -static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_mk4_s8_8x12_pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { const int ksize = kmax - k0; const int ksize12 = ksize * 12; const int ksize4 = ksize * 4; diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp index fadf4215..1470180d 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp @@ -12,10 +12,10 @@ #include "src/aarch64/matrix_mul/int8_dot/strategy.h" #if MGB_ENABLE_DOT #include "src/aarch64/matrix_mul/asm/common.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/utils.h" #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" using namespace megdnn; using namespace aarch64; @@ -24,20 +24,19 @@ using namespace 
aarch64::matmul; /* ====================== gemm_s8_8x12 ===========================*/ MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); -void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8_8x12::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { if (transpose) { - matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_s8_8x12::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { @@ -45,16 +44,16 @@ void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, } } -void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8_8x12::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -75,15 +74,15 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K12; } for (; n < N; n += 4) { - matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, - is_first_k, std::min(N - n, 4)); + matmul_8x12x4::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -95,16 +94,17 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, - is_first_k, std::min(M - m, 4)); + matmul_8x12x4::kern_4x12( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K12; } for (; n < N; n += 4) { - matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, std::min(M - m, 4), - std::min(N - n, 4)); + matmul_8x12x4::kern_4x4( + packA, cur_packB, K, 
output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -115,32 +115,32 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, /* ====================== gemm_mk4_s8_8x12 ===========================*/ MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); -void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { - megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported"); - matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, - kmax); +void gemm_mk4_s8_8x12::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { + megdnn_assert( + !transpose, "matrix mul mk4 with transposed matrix A is not supported"); + matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); } -void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose) const { - megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported"); +void gemm_mk4_s8_8x12::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { + megdnn_assert( + !transpose, "matrix mul mk4 with transposed matrix B is not supported"); matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); } -void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int32* C, - size_t LDC, bool is_first_k, const dt_int32*, - dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_mk4_s8_8x12::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -161,15 +161,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); output += (B_INTERLEAVE << 2); cur_packB += K12; } for (; n < N; n += 4) { - matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, - is_first_k, std::min(N - n, 4)); + matmul_mk4_8x12x4::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 16; cur_packB += K4; } @@ -181,15 +181,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, 
output, LDC, is_first_k); output += (B_INTERLEAVE << 2); cur_packB += K12; } for (; n < N; n += 4) { - matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, std::min(N - n, 4)); + matmul_mk4_8x12x4::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 16; cur_packB += K4; } diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h index f413fed8..e5a07164 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h @@ -16,14 +16,14 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, - gemm_s8_8x12); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, - gemm_mk4_s8_8x12); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_mk4_s8_8x12); -} // namespace aarch64 } // namespace matmul +} // namespace aarch64 } // namespace megdnn #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h index 90f70c75..6bfdd817 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h @@ -34,9 +34,9 @@ namespace matmul_4x4x16 { * * Accumulator */ -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 16; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -230,16 +230,14 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, // Store back into memory STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr] "+r"(outptr), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", - "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", - "v31"); + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", "v3", + "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31"); #undef LOAD_LINE #undef LOAD_C @@ -247,9 +245,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8x8x16_4x4_pack_A_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -292,9 +290,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = 
zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -309,9 +309,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -324,8 +326,8 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8x8x16_4x4_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -362,19 +364,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, if (remain >= 0) { switch (remain) { case 7: - inptr0 = zerobuff; MEGDNN_FALLTHRU + inptr0 = zerobuff; + MEGDNN_FALLTHRU case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -383,9 +392,9 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } - transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, - inptr4, inptr5, inptr6, inptr7, - outptr_inner); + transpose_4x16_1_b_helper( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_inner); outptr_inner += ksize4; } @@ -393,19 +402,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, if (remain >= 0) { switch (remain) { case 7: - inptr0 = zerobuff; MEGDNN_FALLTHRU + inptr0 = zerobuff; + MEGDNN_FALLTHRU case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h index b5bd58ab..c0e7a8db 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h @@ -42,8 +42,9 @@ namespace matmul_8x8x8 { * * Accumulator */ -static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k) { +static void kern_8x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -217,13 +218,12 @@ static void kern_8x8(const int8_t* packA, const 
int8_t* packB, int K, "bne 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [outptr] "+r"(outptr) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", - "x4", "x5", "x6", "x7", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -258,9 +258,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, * Accumulator */ -static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, - size_t n_remain) { +static void kern_8x4( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, size_t n_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -471,16 +471,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), - [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), - [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), - [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), + [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -514,9 +512,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, * Accumulator */ -static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, - size_t m_remain) { +static void kern_4x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, size_t m_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -646,11 +644,10 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), - [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), - [m_remain] "+r"(m_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), + [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), + [x0] "+r"(x0), [m_remain] "+r"(m_remain) : : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", "memory"); @@ -686,9 +683,9 @@ static void kern_4x8(const int8_t* 
packA, const int8_t* packB, int K, * * Accumulator */ -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -853,11 +850,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "3:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), - [x0] "+r"(x0), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain) + [x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", - "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", "cc", + "memory"); #undef LOAD_LINE #undef LOAD_C @@ -865,9 +861,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8x8x16_8x8_pack_A_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -893,13 +889,15 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, int K = kmax - k0; for (; K > 15; K -= 16) { - interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x8_2_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K); } } for (; y < ymax; y += 4) { @@ -918,9 +916,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -936,9 +936,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -951,9 +953,8 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, - int k0, int kmax) { +static void gemm_s8x8x16_8x8_transpose_pack_A_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -991,17 +992,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + 
MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1009,8 +1016,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, megdnn_assert(0); } } - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += ksize8; } @@ -1019,17 +1027,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1038,8 +1052,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, 4); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, 4); outptr += ksize4; } @@ -1047,17 +1062,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1066,8 +1087,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, xmax - x); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, xmax - x); } outptr_base += 8 * 8; @@ -1075,8 +1097,8 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, } } -static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8x8x16_8x8_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -1113,17 +1135,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; 
MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1132,8 +1160,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } outptr_interleave = outptr; - interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr += ksize8; } @@ -1142,17 +1171,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1162,8 +1197,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, 4); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, 4); outptr += ksize4; } @@ -1171,17 +1207,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1191,8 +1233,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, xmax - x); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, xmax - x); } outptr_base += 8 * 8; @@ -1200,10 +1243,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, - const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, - int kmax) { +static void gemm_s8x8x16_8x8_transpose_pack_B_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); constexpr int interleave4 = 32; @@ -1231,14 +1273,16 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, int K = kmax - k0; for (; K > 7; K -= 8) { - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, 
inptr6, inptr7, + outptr); outptr += interleave8; } if (K > 0) { - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K); outptr += interleave8; } } @@ -1259,9 +1303,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1278,9 +1324,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h index 99c79a79..a951637a 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h @@ -40,11 +40,9 @@ namespace matmul_mk4_16x12x4_a53 { * Accumulator */ // clang-format on -static __attribute__((noinline)) void kern_16x12(const int16_t* packA, - const int8_t* packB, int K, - int16_t* output, int LDC, - bool is_first_k, - int remain_n) { +static __attribute__((noinline)) void kern_16x12( + const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int remain_n) { K /= 4; const int16_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -521,15 +519,15 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, "6:\n" STORE_C "101:\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), + [remain_n] "+r"(remain_n) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", - "x8", "x9", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", + "memory"); #undef STORE_C #undef STORE_LINE @@ -554,10 +552,9 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, * Accumulator */ // clang-format on -static __attribute__((noinline)) void kern_8x12(const int16_t* packA, - const int8_t* packB, int K, - int16_t* output, int LDC, - bool is_first_k, int remain_n) { +static __attribute__((noinline)) void kern_8x12( + const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int remain_n) { K /= 4; const int16_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -858,14 +855,13 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, "6:\n" STORE_C "101:\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), 
[is_first_k] "+r"(is_first_k), - [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), + [remain_n] "+r"(remain_n) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", + "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); #undef STORE_C #undef STORE_LINE @@ -890,10 +886,9 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, * Accumulator */ // clang-format on -static __attribute__((noinline)) void kern_4x12(const int16_t* packA, - const int8_t* packB, int K, - int16_t* output, int LDC, - bool is_first_k, int remain_n) { +static __attribute__((noinline)) void kern_4x12( + const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int remain_n) { K /= 4; const int16_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -1162,22 +1157,21 @@ static __attribute__((noinline)) void kern_4x12(const int16_t* packA, "6:\n" STORE_C "101:\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), + [remain_n] "+r"(remain_n) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", - "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", + "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); #undef STORE_C #undef STORE_LINE } -static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, - const dt_int8* inptr, int ldin, - int m0, int mmax, int k0, int kmax) { +static void gemm_s8x8x16_mk4_16x12_pack_A( + dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, + int kmax) { megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int pack_m = 16; @@ -1224,9 +1218,8 @@ static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, } } -static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in, - int ldin, int n0, int nmax, int k0, - int kmax) { +static void gemm_s8x8x16_mk4_16x12_pack_B( + dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int pack_n = 12; diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h index 88d2c153..65a2238c 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h @@ -43,8 +43,9 @@ namespace matmul_mk4_4x4x8_a72 { */ // clang-format on -static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool, int remain_n) { +static inline void kern_4x4( + const 
int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, bool, + int remain_n) { K = div_ceil(K, 8); int oddk = (K & 1); K = ((K + 1) / 2) - 1; @@ -261,15 +262,14 @@ static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "7:\n" STORE_C "101:\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr), - [remain_n] "+r"(remain_n) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [oddk] "+r"(oddk), + [LDC] "+r"(LDC), [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", - "x8", "x9", "x10", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", + "memory"); #undef STORE_C #undef STORE_LINE @@ -282,26 +282,23 @@ static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) { vst1_s8(outptr + 3 * 8, in0.val[3]); } -static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2, - dt_int8* outptr) { +static inline void interleve_8x4_b( + const dt_int8* inptr, const dt_int8* inptr2, dt_int8* outptr) { int8x16_t in0 = vld1q_s8(inptr); int8x16_t in1 = vld1q_s8(inptr2); - int32x4x2_t in_x2 = { - {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; + int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; vst2q_s32(reinterpret_cast(outptr), in_x2); } static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { int8x16_t in0 = vld1q_s8(inptr); int8x16_t in1 = vdupq_n_s8(0); - int32x4x2_t in_x2 = { - {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; + int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; vst2q_s32(reinterpret_cast(outptr), in_x2); } -static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, - int ldin, int m0, int mmax, int k0, - int kmax) { +static void gemm_s8x8x16_mk4_4x4x8_pack_A( + dt_int8* out, const dt_int8* in, int ldin, int m0, int mmax, int k0, int kmax) { megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int pack_m = 4; @@ -330,9 +327,8 @@ static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, } } -static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in, - int ldin, int n0, int nmax, int k0, - int kmax) { +static void gemm_s8x8x16_mk4_4x4x8_pack_B( + dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int pack_n = 4; diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h index 184f4d96..535bca8d 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h @@ -18,7 +18,6 @@ namespace megdnn { namespace aarch64 { namespace matmul_mk4_8x8x8 { - /** * Overview of register layout: * @@ -39,18 +38,18 @@ namespace matmul_mk4_8x8x8 { * | v16 | | v28 | * | v17 | | v29 | * | v16 | 
| v30 | - * | v17 | | v31 | + * | v17 | | v31 | * +--------+ - - - - +---------------------------------+ * * Accumulator */ -static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_8x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 8; LDC = LDC * sizeof(int16_t); - const int8_t* a_ptr = packB;//packA; - const int8_t* b_ptr = packA;//packB; + const int8_t* a_ptr = packB; // packA; + const int8_t* b_ptr = packA; // packB; // clang-format off #define LOAD_C_8 \ "ld1 {v0.8h}, [x0], #16\n" \ @@ -291,17 +290,17 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -// clang-format on + // clang-format on } -static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_8x8_remain( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 8; LDC = LDC * sizeof(int16_t); const int8_t* a_ptr = packB; const int8_t* b_ptr = packA; -// clang-format off + // clang-format off register int16_t* outptr asm("x0") = output; asm volatile( "add x1, x0, %x[LDC]\n" @@ -476,7 +475,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, "cbnz %w[K], 1b\n" "cmp %w[is_first_k], #1\n" - "beq 2f\n" + "beq 2f\n" "cmp %x[m_remain], #8 \n" "beq 8f \n" "cmp %x[m_remain], #4 \n" @@ -633,7 +632,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, "zip2 v15.2d, v30.2d, v31.2d \n" "add v6.8h, v6.8h, v13.8h \n" "add v7.8h, v7.8h, v15.8h \n" -//save to memory + // save to memory "cmp %x[m_remain], #8 \n" "beq 4f \n" "cmp %x[m_remain], #4 \n" @@ -766,31 +765,27 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, "b 1000f \n" "1000: \n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), - [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), - [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), - [ n_remain ] "+r"(n_remain) - : - : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31"); -// clang-format on + : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", + "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + // clang-format on #undef LOAD_C_8 #undef STORE_C_8 } - -static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_4x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 8; LDC = LDC * sizeof(int16_t); - const int8_t* 
a_ptr = packB;//packA; - const int8_t* b_ptr = packA;//packB; + const int8_t* a_ptr = packB; // packA; + const int8_t* b_ptr = packA; // packB; // clang-format off #define LOAD_C_4 \ "ld1 {v0.8h}, [x0], #16\n" \ @@ -1018,14 +1013,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, #undef LOAD_C_4 #undef STORE_C_4 } -static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_4x8_remain( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { K /= 8; LDC = LDC * sizeof(int16_t); - const int8_t* a_ptr = packB;//packA; - const int8_t* b_ptr = packA;//packB; -// clang-format off + const int8_t* a_ptr = packB; // packA; + const int8_t* b_ptr = packA; // packB; + // clang-format off register int16_t* outptr asm("x0") = output; asm volatile( @@ -1324,13 +1319,12 @@ static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C_4 } - //! pack to icxoc //! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7)) -//! if M K is not times of 8,pack 0 instead -static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, - const dt_int8* inptr, int ldin, - int m0, int mmax, int k0, int kmax) { +//! if M K is not times of 8,pack 0 instead +static void gemm_s8x8x16_mk4_8x8x8_pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, + int kmax) { megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int pack_m = 8; @@ -1349,8 +1343,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, prefetch_2x(inptr0); prefetch_2x(inptr1); int k_idx = k0; - for ( ; k_idx + 7 < kmax; k_idx += pack_k) { - interleave_8x8_mk4_b(inptr0,inptr1,outptr); + for (; k_idx + 7 < kmax; k_idx += pack_k) { + interleave_8x8_mk4_b(inptr0, inptr1, outptr); } if (k_idx < kmax) { @@ -1368,9 +1362,9 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, prefetch_2x(inptr0); prefetch_2x(inptr1); int k_idx = k0; - for ( ; k_idx + 7 < kmax; k_idx += pack_k) { + for (; k_idx + 7 < kmax; k_idx += pack_k) { inptr1 = zerobuff; - interleave_8x8_mk4_b(inptr0,inptr1,outptr); + interleave_8x8_mk4_b(inptr0, inptr1, outptr); } if (k_idx < kmax) { @@ -1383,9 +1377,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, } //! pack to nxic //! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead. 
-static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, - int ldin, int n0, int nmax, int k0, - int kmax) { +static void gemm_s8x8x16_mk4_8x8x8_pack_B( + dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int pack_n = 8; @@ -1394,14 +1387,14 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, int8_t tmpbuff0[pack_n * pack_size] = {0}; int8_t tmpbuff1[pack_n * pack_size] = {0}; int8_t zerobuff[pack_n * pack_size] = {0}; - const int ksize = round_up((kmax - k0),8); + const int ksize = round_up((kmax - k0), 8); const int nsize = nmax - n0; const int n_end = nsize / pack_n * pack_n + n0; const int remain_n = nsize % pack_n; int output_stride = ksize * pack_n; int8_t* outptr_base = out; int k_idx = k0; - for ( ; k_idx + 7 < kmax; k_idx += pack_k) { + for (; k_idx + 7 < kmax; k_idx += pack_k) { const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; const int8_t* inptr1 = inptr0 + ldin; prefetch_3x(inptr0); @@ -1410,7 +1403,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, auto outptr = outptr_base; for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { transpose_8x8_mk4_b(inptr0, inptr1, outptr); - outptr += output_stride; + outptr += output_stride; } if (remain_n > 0) { memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); @@ -1422,8 +1415,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, } outptr_base += pack_n * pack_k; } - - if(k_idx < kmax){ + + if (k_idx < kmax) { const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; const int8_t* inptr1 = nullptr; prefetch_3x(inptr0); @@ -1444,7 +1437,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, } } -} // namespace matmul_mk4_16x12x4_a53 +} // namespace matmul_mk4_8x8x8 } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp index 08ae289f..78d64a14 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.cpp @@ -10,13 +10,13 @@ * implied. 
*/ +#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" #include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" -#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" -#include "src/aarch64/matrix_mul/int8x8x16/strategy.h" +#include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_common.h" @@ -28,39 +28,35 @@ using namespace aarch64::matmul; // ===========================gemm_s8x8x16_4x4================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); -void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_8x8::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0, - ymax, k0, kmax); + matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n( + out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_8x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0, - xmax, k0, kmax); + matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - (A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int16), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8x8x16_8x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -79,15 +75,15 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 4)); + 
matmul_8x8x8::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -99,16 +95,17 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + matmul_8x8x8::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4)); + matmul_8x8x8::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -119,39 +116,33 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, // ===========================gemm_s8x8x16_4x4================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); -void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_4x4::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_4x4::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - (A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int16), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8x8x16_4x4::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -169,16 +160,17 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, A_INTERLEAVE, B_INTERLEAVE); + matmul_4x4x16::kern_4x4( + packA, cur_packB, K, 
output, LDC, is_first_k, A_INTERLEAVE, + B_INTERLEAVE); output += B_INTERLEAVE; cur_packB += K4; } for (; n < N; n += B_INTERLEAVE) { - matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, A_INTERLEAVE, - std::min(N - n, B_INTERLEAVE)); + matmul_4x4x16::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, + std::min(N - n, B_INTERLEAVE)); output += B_INTERLEAVE; cur_packB += K4; } @@ -191,10 +183,10 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n < N; n += B_INTERLEAVE) { - matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, - std::min(M - m, A_INTERLEAVE), - std::min(N - n, B_INTERLEAVE)); + matmul_4x4x16::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, A_INTERLEAVE), + std::min(N - n, B_INTERLEAVE)); output += B_INTERLEAVE; cur_packB += K4; } @@ -205,28 +197,26 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, // ===========================gemm_s8x8x16_mk4_16x12================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); -void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, - int ldin, int y0, int ymax, int k0, - int kmax, bool) const { - matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, - ymax, k0, kmax); +void gemm_s8x8x16_mk4_16x12_a53::pack_A( + dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool) const { + matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A( + out, in, ldin, y0, ymax, k0, kmax); } -void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool) const { - matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, - xmax, k0, kmax); +void gemm_s8x8x16_mk4_16x12_a53::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool) const { + matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B( + out, in, ldin, x0, xmax, k0, kmax); } -void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, - const dt_int8* packB, size_t M, size_t N, - size_t K, dt_int16* C, size_t LDC, - bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - C_dtype.enumv() == DTypeEnum::Int16 && - A_dtype.enumv() == DTypeEnum::Int8); +void gemm_s8x8x16_mk4_16x12_a53::kern( + const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); megdnn_assert(is_first_k == true, "only impl is_first_k"); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -246,14 +236,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, size_t n_idx = 0; const int8_t* cur_packB = packB; for (; n_idx + pack_n <= N; n_idx += pack_n) { - matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, - is_first_k, pack_n); + matmul_mk4_16x12x4_a53::kern_16x12( + packA, cur_packB, K, output, LDC, is_first_k, pack_n); output += pack_n * pack_size; cur_packB += pack_n * K; } if (remain_n > 0) { - matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, - is_first_k, remain_n); + matmul_mk4_16x12x4_a53::kern_16x12( + packA, cur_packB, K, output, LDC, is_first_k, remain_n); output += remain_n * 
pack_size; cur_packB += pack_n * K; } @@ -265,14 +255,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, size_t n_idx = 0; const int8_t* cur_packB = packB; for (; n_idx + pack_n <= N; n_idx += pack_n) { - matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, - is_first_k, pack_n); + matmul_mk4_16x12x4_a53::kern_8x12( + packA, cur_packB, K, output, LDC, is_first_k, pack_n); output += pack_n * pack_size; cur_packB += pack_n * K; } if (remain_n > 0) { - matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, - is_first_k, remain_n); + matmul_mk4_16x12x4_a53::kern_8x12( + packA, cur_packB, K, output, LDC, is_first_k, remain_n); output += remain_n * pack_size; cur_packB += pack_n * K; } @@ -286,14 +276,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, size_t n_idx = 0; const int8_t* cur_packB = packB; for (; n_idx + pack_n <= N; n_idx += pack_n) { - matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, - is_first_k, pack_n); + matmul_mk4_16x12x4_a53::kern_4x12( + packA, cur_packB, K, output, LDC, is_first_k, pack_n); output += pack_n * pack_size; cur_packB += pack_n * K; } if (remain_n > 0) { - matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, - is_first_k, remain_n); + matmul_mk4_16x12x4_a53::kern_4x12( + packA, cur_packB, K, output, LDC, is_first_k, remain_n); output += remain_n * pack_size; cur_packB += pack_n * K; } @@ -303,27 +293,26 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, // ===========================gemm_s8x8x16_mk4_4x4_a72================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); -void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, - int y0, int ymax, int k0, int kmax, - bool) const { - matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, - k0, kmax); +void gemm_s8x8x16_mk4_4x4_a72::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool) const { + matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A( + out, in, ldin, y0, ymax, k0, kmax); } -void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax, - bool) const { - matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, - k0, kmax); +void gemm_s8x8x16_mk4_4x4_a72::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool) const { + matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B( + out, in, ldin, x0, xmax, k0, kmax); } -void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, - const dt_int16*, dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - C_dtype.enumv() == DTypeEnum::Int16 && - A_dtype.enumv() == DTypeEnum::Int8); +void gemm_s8x8x16_mk4_4x4_a72::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); megdnn_assert(is_first_k == true, "only impl is_first_k"); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -343,14 +332,14 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, const int8_t* cur_packB = packB; for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { - 
matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, pack_n); + matmul_mk4_4x4x8_a72::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, pack_n); output += pack_n * pack_size; cur_packB += pack_n * packed_k; } if (remain_n > 0) { - matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, remain_n); + matmul_mk4_4x4x8_a72::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, remain_n); output += remain_n * pack_size; cur_packB += pack_n * packed_k; } @@ -361,27 +350,24 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, // ===========================gemm_s8x8x16_mk4_8x8x8================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); -void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, - int ldin, int y0, int ymax, int k0, - int kmax, bool) const { - matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, - ymax, k0, kmax); +void gemm_s8x8x16_mk4_8x8x8::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool) const { + matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax); } -void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool) const { - matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, - xmax, k0, kmax); +void gemm_s8x8x16_mk4_8x8x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool) const { + matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax); } -void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - C_dtype.enumv() == DTypeEnum::Int16 && - A_dtype.enumv() == DTypeEnum::Int8); +void gemm_s8x8x16_mk4_8x8x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); megdnn_assert(is_first_k == true, "only impl is_first_k"); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -402,14 +388,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t n_idx = 0; const int8_t* cur_packB = packB; for (; n_idx + pack_n <= N; n_idx += pack_n) { - matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, - is_first_k, pack_m, pack_n); + matmul_mk4_8x8x8::kern_8x8( + packA, cur_packB, K, output, LDC, is_first_k, pack_m, pack_n); output += pack_n * pack_size; cur_packB += KSIZE8; } if (remain_n > 0) { - matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, - is_first_k, pack_m, remain_n); + matmul_mk4_8x8x8::kern_8x8_remain( + packA, cur_packB, K, output, LDC, is_first_k, pack_m, remain_n); output += remain_n * pack_size; cur_packB += KSIZE8; } @@ -421,14 +407,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t n_idx = 0; const int8_t* cur_packB = packB; for (; n_idx + pack_n <= N; n_idx += pack_n) { - matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, - is_first_k, 4, pack_n); + matmul_mk4_8x8x8::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, 4, pack_n); output += pack_n * pack_size; 
cur_packB += pack_n * K; } if (remain_n > 0) { - matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, - is_first_k, 4, remain_n); + matmul_mk4_8x8x8::kern_4x8_remain( + packA, cur_packB, K, output, LDC, is_first_k, 4, remain_n); output += remain_n * pack_size; cur_packB += pack_n * K; } diff --git a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h index a303a26f..d4b40f29 100644 --- a/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8x8x16/strategy.h @@ -17,17 +17,17 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, - gemm_s8x8x16_8x8); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, - gemm_s8x8x16_4x4); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, - gemm_s8x8x16_mk4_4x4_a72); -MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(dt_int8, dt_int16, dt_int16, dt_int16, - 16, 12, 4, false, false, - gemm_s8x8x16_mk4_16x12_a53); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, false, - gemm_s8x8x16_mk4_8x8x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s8x8x16_8x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int16, dt_int16, 4, 4, 16, false, true, gemm_s8x8x16_4x4); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int16, dt_int16, 4, 4, 8, false, false, gemm_s8x8x16_mk4_4x4_a72); +MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE( + dt_int8, dt_int16, dt_int16, dt_int16, 16, 12, 4, false, false, + gemm_s8x8x16_mk4_16x12_a53); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int16, dt_int16, 8, 8, 8, false, false, gemm_s8x8x16_mk4_8x8x8); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index 470e5b5d..f16fbda0 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/aarch64/matrix_mul/algos.h" #include "src/aarch64/matrix_mul/opr_impl.h" +#include "src/aarch64/matrix_mul/algos.h" #include "src/common/metahelper.h" #include "src/common/utils.h" @@ -52,8 +52,8 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { SmallVector m_all_algos; fallback::MatrixMulImpl::AlgoBase::Mapper m_all_algos_map; -public: +public: AlgoPack() { m_all_algos.emplace_back(&f32_gemv); m_all_algos.emplace_back(&f32K8x12x1); @@ -104,11 +104,11 @@ const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) -SmallVector -MatrixMulImpl::get_all_packed_algo() { +SmallVector MatrixMulImpl::get_all_packed_algo() { auto&& algos = arm_common::MatrixMulImpl::get_all_packed_algo(); - algos.insert(algos.begin(), algo_pack().all_algos().begin(), - algo_pack().all_algos().end()); + algos.insert( + algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index 39557bb3..c65d2133 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -25,8 +25,7 @@ public: } }; - SmallVector get_all_packed_algo() - override; + SmallVector get_all_packed_algo() override; MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); @@ -47,9 +46,9 @@ private: class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel // 8x12x4 DotProduct #endif - class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 - class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 - class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 + class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 + class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 + class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x16 @@ -67,6 +66,7 @@ private: class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel 4x4x16 class AlgoInt4x4x16K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 class AlgoPack; + public: static const AlgoPack& algo_pack(); }; diff --git a/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h index 1c217ca6..d81ccca0 100644 --- a/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h +++ b/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h @@ -44,9 +44,9 @@ namespace matmul_8x8x8 { * Accumulator */ -static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, uint8_t za, - uint8_t zb) { +static void kern_8x8( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, uint8_t za, uint8_t zb) { K /= 8; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -277,14 +277,14 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, "stp q18, q19, [x5]\n" "stp q20, q21, [x6]\n" "stp q22, q23, [x7]\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [output] "+r"(output), [za] "+r"(za), [zb] "+r"(zb) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output), [za] "+r"(za), + [zb] "+r"(zb) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - 
"v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "x1", - "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "x1", "x2", "x3", "x4", "x5", + "x6", "x7", "cc", "memory"); } /** @@ -316,9 +316,9 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, * Accumulator */ -static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t n_remain, - uint8_t za, uint8_t zb) { +static void kern_8x4( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t n_remain, uint8_t za, uint8_t zb) { K /= 8; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -529,17 +529,15 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [za] "+r"(za), [zb] "+r"(zb), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), - [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), - [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), - [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [za] "+r"(za), + [zb] "+r"(zb), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), + [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -571,9 +569,9 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, * Accumulator */ -static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t m_remain, - uint8_t za, uint8_t zb) { +static void kern_4x8( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain, uint8_t za, uint8_t zb) { K /= 8; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -738,14 +736,13 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [za] "+r"(za), [zb] "+r"(zb), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [za] "+r"(za), + [zb] "+r"(zb), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", 
"v15", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -778,9 +775,9 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, * Accumulator */ -static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain, uint8_t za, uint8_t zb) { +static void kern_4x4( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain, uint8_t za, uint8_t zb) { K /= 8; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -940,15 +937,14 @@ static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, "cbnz %w[K], 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), - [outptr0] "+r"(outptr0), [za] "+r"(za), [zb] "+r"(zb), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [za] "+r"(za), + [zb] "+r"(zb), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), [x1] "+r"(x1), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", - "v11", "cc", "memory"); + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -956,9 +952,9 @@ static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, #undef STORE_C } -static void gemm_u8_8x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, - int ldin, int y0, int ymax, int k0, int kmax, - uint8_t zero_point) { +static void gemm_u8_8x8_pack_A_n( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); @@ -984,13 +980,15 @@ static void gemm_u8_8x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, int K = kmax - k0; for (; K > 15; K -= 16) { - interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x8_2_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K, zero_point); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K, zero_point); } } @@ -1010,9 +1008,11 @@ static void gemm_u8_8x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1028,9 +1028,11 @@ static void gemm_u8_8x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1038,15 +1040,14 @@ static void gemm_u8_8x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, megdnn_assert(0); } } - interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, - zero_point); + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, zero_point); } } } -static void 
gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, - int ldin, int x0, int xmax, int k0, - int kmax, uint8_t zero_point) { +static void gemm_u8_8x8_transpose_pack_A_n( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); const int ksize = kmax - k0; @@ -1083,17 +1084,23 @@ static void gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1101,8 +1108,9 @@ static void gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, megdnn_assert(0); } } - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += ksize8; } @@ -1111,17 +1119,23 @@ static void gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1130,8 +1144,9 @@ static void gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, 4, zero_point); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, 4, zero_point); outptr += ksize4; } @@ -1139,17 +1154,23 @@ static void gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1158,8 +1179,9 @@ static void gemm_u8_8x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, xmax - x, zero_point); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, xmax - x, zero_point); } outptr_base += 8 * 8; @@ -1167,9 +1189,9 @@ static void gemm_u8_8x8_transpose_pack_A_n(dt_uint8* 
out, const dt_uint8* in, } } -static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, - int x0, int xmax, int k0, int kmax, - uint8_t zero_point) { +static void gemm_u8_8x8_pack_B_n( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); const int ksize = kmax - k0; @@ -1207,17 +1229,23 @@ static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1226,8 +1254,9 @@ static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } } outptr_interleave = outptr; - interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr += ksize8; } @@ -1236,17 +1265,23 @@ static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1256,8 +1291,9 @@ static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, 4, zero_point); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, 4, zero_point); outptr += ksize4; } @@ -1265,17 +1301,23 @@ static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, if (k + 7 >= kmax) { switch (k + 7 - kmax) { case 6: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 5: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 4: - inptr3 = zerobuff; MEGDNN_FALLTHRU + inptr3 = zerobuff; + MEGDNN_FALLTHRU case 3: - inptr4 = zerobuff; MEGDNN_FALLTHRU + inptr4 = zerobuff; + MEGDNN_FALLTHRU case 2: - inptr5 = zerobuff; MEGDNN_FALLTHRU + inptr5 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr6 = zerobuff; MEGDNN_FALLTHRU + inptr6 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr7 = zerobuff; break; @@ -1285,8 +1327,9 @@ static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, xmax - x, zero_point); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, 
inptr6, inptr7, + outptr_interleave, 4, xmax - x, zero_point); } outptr_base += 8 * 8; @@ -1294,10 +1337,9 @@ static void gemm_u8_8x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } } -static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, - const dt_uint8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - uint8_t zero_point) { +static void gemm_u8_8x8_transpose_pack_B_n( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); constexpr int interleave4 = 32; @@ -1325,14 +1367,16 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, int K = kmax - k0; for (; K > 7; K -= 8) { - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += interleave8; } if (K > 0) { - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K, zero_point); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K, zero_point); outptr += interleave8; } } @@ -1353,9 +1397,11 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1372,9 +1418,11 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, if (y + 3 >= ymax) { switch (y + 3 - ymax) { case 2: - inptr1 = zerobuff; MEGDNN_FALLTHRU + inptr1 = zerobuff; + MEGDNN_FALLTHRU case 1: - inptr2 = zerobuff; MEGDNN_FALLTHRU + inptr2 = zerobuff; + MEGDNN_FALLTHRU case 0: inptr3 = zerobuff; break; @@ -1382,8 +1430,7 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, megdnn_assert(0); } } - transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, - zero_point); + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, zero_point); outptr += interleave4; } } diff --git a/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp b/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp index 0e37f986..96da0baa 100644 --- a/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp @@ -21,39 +21,38 @@ using namespace aarch64::matmul; MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); -void gemm_u8_8x8::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_u8_8x8::pack_A( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { uint8_t zA = A_dtype.param().zero_point; if (transpose) { - matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, - ymax, k0, kmax, zA); + matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n( + outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } else { - matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax, zA); + matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } } -void gemm_u8_8x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_u8_8x8::pack_B( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { uint8_t zB = B_dtype.param().zero_point; if (transpose) { - 
matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, - k0, kmax, zB); + matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax, zB); } else { - matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, - zB); + matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); } } -void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Quantized8Asymm && - C_dtype.enumv() == DTypeEnum::QuantizedS32, - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_u8_8x8::kern( + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32, + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); uint8_t zA = A_dtype.param().zero_point; uint8_t zB = B_dtype.param().zero_point; @@ -71,15 +70,16 @@ void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t n = 0; const dt_uint8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k, - zA, zB); + matmul_8x8x8::kern_8x8( + packA, cur_packB, K, output, LDC, is_first_k, zA, zB); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 4), zA, zB); + matmul_8x8x8::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4), zA, zB); output += 4; cur_packB += K4; } @@ -91,16 +91,17 @@ void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, const dt_uint8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), zA, zB); + matmul_8x8x8::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), zA, zB); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4), zA, zB); + matmul_8x8x8::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4), zA, zB); output += 4; cur_packB += K4; } diff --git a/dnn/src/aarch64/matrix_mul/quint8/strategy.h b/dnn/src/aarch64/matrix_mul/quint8/strategy.h index 67d67f78..2e242920 100644 --- a/dnn/src/aarch64/matrix_mul/quint8/strategy.h +++ b/dnn/src/aarch64/matrix_mul/quint8/strategy.h @@ -16,8 +16,8 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true, - gemm_u8_8x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_u8_8x8); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp index 207c65eb..b62d8aef 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp @@ -12,17 +12,17 @@ #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" #if 
MGB_ENABLE_DOT #include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/utils.h" #include "src/common/unroll_macro.h" +#include "src/common/utils.h" namespace { MEGDNN_ATTRIBUTE_TARGET("dotprod") -void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride, - uint8_t zero_point_A, uint8_t zero_point_B) { - int32_t zAB = static_cast(zero_point_A) * - static_cast(zero_point_B) * K; +void gemv_naive_n( + const uint8_t* __restrict A, const uint8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, + uint8_t zero_point_A, uint8_t zero_point_B) { + int32_t zAB = + static_cast(zero_point_A) * static_cast(zero_point_B) * K; uint8x16_t zAq = vdupq_n_u8(zero_point_A); uint8x16_t zBq = vdupq_n_u8(zero_point_B); uint8x8_t zA = vdup_n_u8(zero_point_A); @@ -92,8 +92,7 @@ void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, acc[3] += static_cast(A[(m + 1) * Astride + k]) * B[k]; acc_zA += static_cast(B[k]) * zero_point_A; acc_zB += static_cast(A[m * Astride + k]) * zero_point_B; - acc_zB2 += static_cast(A[(m + 1) * Astride + k]) * - zero_point_B; + acc_zB2 += static_cast(A[(m + 1) * Astride + k]) * zero_point_B; } C[m * Cstride] = acc[0] + acc[1] + zAB - acc_zA - acc_zB; C[(m + 1) * Cstride] = acc[2] + acc[3] + zAB - acc_zA - acc_zB2; @@ -140,8 +139,7 @@ void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, acc_zA += static_cast(B[k]) * zero_point_A; acc_zB += static_cast(A[m * Astride + k]) * zero_point_B; } - C[m * Cstride] = - acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB; + C[m * Cstride] = acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB; } } } // namespace @@ -160,13 +158,12 @@ bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8( } void megdnn::aarch64::matmul::gemv_like_quint8( - const uint8_t* __restrict A, const uint8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, - size_t Bstride, size_t Cstride, uint8_t zero_point_A, - uint8_t zero_point_B) { + const uint8_t* __restrict A, const uint8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, + uint8_t zero_point_A, uint8_t zero_point_B) { megdnn_assert(N == 1); - return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride, - zero_point_A, zero_point_B); + return gemv_naive_n( + A, B, C, M, N, K, Astride, Bstride, Cstride, zero_point_A, zero_point_B); } #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h index 3c2ef596..2f148441 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h @@ -17,14 +17,14 @@ namespace megdnn { namespace aarch64 { namespace matmul { -bool is_gemv_like_preferred_quint8(bool transposeA, bool transposeB, size_t M, - size_t N, size_t K, size_t LDA, size_t LDB, - size_t LDC); +bool is_gemv_like_preferred_quint8( + bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t LDA, + size_t LDB, size_t LDC); -void gemv_like_quint8(const uint8_t* __restrict A, const uint8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride, - uint8_t zero_point_A, uint8_t zero_point_B); +void gemv_like_quint8( + const uint8_t* __restrict A, const 
uint8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, + uint8_t zero_point_A, uint8_t zero_point_B); } // namespace matmul } // namespace aarch64 diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h b/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h index 6e46a97f..5864938d 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h @@ -56,9 +56,9 @@ namespace matmul_8x8x4 { // zB * k // A -> v27, v28 | B -> v29, v30 | zA * zB * k -> v26 MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { +static void kern_8x8( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { K /= 4; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -247,20 +247,18 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, "stp q12, q20, [%[outptr6]]\n" "stp q13, q21, [%[outptr7]]\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), - [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), - [b1] "+w"(b1), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1), + [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1), + [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), - [zero_point_A] "+r"(zero_point_A), - [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), - [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), - [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), - [outptr7] "=r"(outptr7) + [zero_point_A] "+r"(zero_point_A), [zero_point_B] "+r"(zero_point_B), + [zAB] "+r"(zAB), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), + [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), + [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) : - : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", - "v16", "v17", "v18", "v19", "v20", "v21", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "cc", "memory"); + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "cc", "memory"); } // Overview of register layout: @@ -293,9 +291,10 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, // A -> v28 | B -> v29, v30 | zA * zB * k -> v26 MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int m_remain, - uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { +static void kern_4x8( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int m_remain, uint8_t zero_point_A, uint8_t zero_point_B, + uint32_t zAB) { K /= 4; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -445,17 +444,15 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), - [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), - [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), - [zero_point_A] "+r"(zero_point_A), - 
[zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), - [LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), - [b1] "=w"(b1), [b0a] "=w"(b0a), [b1a] "=w"(b1a), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), - [outptr3] "=r"(outptr3), [x0] "=r"(x0) + [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), + [m_remain] "+r"(m_remain), [zero_point_A] "+r"(zero_point_A), + [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), [LDC] "+r"(LDC), + [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), + [b0a] "=w"(b0a), [b1a] "=w"(b1a), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "=r"(x0) : - : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", - "v25", "v26", "v28", "v29", "v30", "memory", "cc"); + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", "v25", + "v26", "v28", "v29", "v30", "memory", "cc"); #undef LOAD_LINE #undef LOAD_C @@ -496,9 +493,10 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, // A -> v27, v28 | B -> v29 | zA * zB * k -> v26 MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int n_remain, - uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { +static void kern_8x4( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain, uint8_t zero_point_A, uint8_t zero_point_B, + uint32_t zAB) { K /= 4; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -689,16 +687,15 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [LDC] "+r"(LDC), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), - [zero_point_A] "+r"(zero_point_A), - [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), [a0] "=w"(a0), - [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), [b0] "=w"(b0), - [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), - [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), - [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), - [outptr7] "=r"(outptr7), [x0] "=r"(x0) + [zero_point_A] "+r"(zero_point_A), [zero_point_B] "+r"(zero_point_B), + [zAB] "+r"(zAB), [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), + [a1a] "=w"(a1a), [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), + [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), + [x0] "=r"(x0) : - : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "memory", "cc"); + : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "memory", "cc"); #undef LOAD_LINE #undef LOAD_C @@ -735,10 +732,10 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, // A -> v28 | B -> v29 | zA * zB * k -> v26 MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain, uint8_t zero_point_A, uint8_t zero_point_B, - uint32_t zAB) { +static void kern_4x4( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain, uint8_t zero_point_A, + uint8_t zero_point_B, uint32_t zAB) { K /= 4; const int32_t* a_ptr = reinterpret_cast(packA); 
const int32_t* b_ptr = reinterpret_cast(packB); @@ -890,12 +887,11 @@ static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), - [zero_point_A] "+r"(zero_point_A), - [zero_point_B] "+r"(zero_point_B), [zAB] "+r"(zAB), - [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), - [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), - [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), - [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) + [zero_point_A] "+r"(zero_point_A), [zero_point_B] "+r"(zero_point_B), + [zAB] "+r"(zAB), [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), + [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), + [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "=r"(x0), + [x1] "=r"(x1) : : "v4", "v5", "v6", "v7", "v23", "v24", "v25", "v26", "v28", "v29", "memory", "cc"); @@ -908,9 +904,8 @@ static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, #undef SUB_LANE -static void gemm_u8_8x8_transpose_pack_helper(uint8_t* out, const uint8_t* in, - int ldin, int x0, int xmax, - int k0, int kmax) { +static void gemm_u8_8x8_transpose_pack_helper( + uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { uint8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(uint8_t) * 16); const int ksize = kmax - k0; @@ -997,10 +992,9 @@ static void gemm_u8_8x8_transpose_pack_helper(uint8_t* out, const uint8_t* in, } } -static void gemm_u8_8x8_interleave_pack_helper(uint8_t* outptr, - const uint8_t* inptr, int ldin, - int y0, int ymax, int k0, - int kmax) { +static void gemm_u8_8x8_interleave_pack_helper( + uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { uint8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(uint8_t) * 16); @@ -1027,13 +1021,15 @@ static void gemm_u8_8x8_interleave_pack_helper(uint8_t* outptr, int K = kmax - k0; //! 
read 8 * 4 in each row for (; K > 15; K -= 16) { - interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x4_4_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, K); } } for (; y < ymax; y += 4) { diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp index e8edf8ea..206beafa 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp @@ -23,35 +23,37 @@ using namespace aarch64::matmul; MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot); -void gemm_u8_8x8_dot::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_u8_8x8_dot::pack_A( + uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { if (transpose) { - matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, - ymax, k0, kmax); + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( + outptr, inptr, ldin, y0, ymax, k0, kmax); } else { - matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, - y0, ymax, k0, kmax); + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( + outptr, inptr, ldin, y0, ymax, k0, kmax); } } -void gemm_u8_8x8_dot::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_u8_8x8_dot::pack_B( + uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, - xmax, k0, kmax); + matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( + out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, - k0, kmax); + matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( + out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_u8_8x8_dot::kern(const uint8_t* packA, const uint8_t* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Quantized8Asymm && - C_dtype.enumv() == DTypeEnum::QuantizedS32); +void gemm_u8_8x8_dot::kern( + const uint8_t* packA, const uint8_t* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32); MEGDNN_MARK_USED_VAR(C_dtype); size_t zero_point_A = A_dtype.param().zero_point; size_t zero_point_B = B_dtype.param().zero_point; @@ -71,16 +73,17 @@ void gemm_u8_8x8_dot::kern(const uint8_t* packA, const uint8_t* packB, size_t M, size_t n = 0; const dt_uint8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k, - zero_point_A, zero_point_B, zAB); + matmul_8x8x4::kern_8x8( + packA, cur_packB, K, output, LDC, is_first_k, zero_point_A, + zero_point_B, zAB); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x4::kern_8x4(packA, cur_packB, K, output, 
LDC, is_first_k, - std::min(N - n, 4), zero_point_A, - zero_point_B, zAB); + matmul_8x8x4::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4), zero_point_A, zero_point_B, zAB); output += 4; cur_packB += K4; } @@ -92,18 +95,18 @@ void gemm_u8_8x8_dot::kern(const uint8_t* packA, const uint8_t* packB, size_t M, const dt_uint8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), zero_point_A, - zero_point_B, zAB); + matmul_8x8x4::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), zero_point_A, zero_point_B, zAB); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_8x8x4::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4), zero_point_A, - zero_point_B, zAB); + matmul_8x8x4::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4), + zero_point_A, zero_point_B, zAB); output += 4; cur_packB += K4; } diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h index 1ed9c474..965945d8 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h @@ -16,11 +16,11 @@ namespace megdnn { namespace aarch64 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true, - gemm_u8_8x8_dot); +MEGDNN_REG_GEMM_STRATEGY( + uint8_t, int32_t, int32_t, 8, 8, 4, false, true, gemm_u8_8x8_dot); -} // namespace aarch64 } // namespace matmul +} // namespace aarch64 } // namespace megdnn #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/relayout/opr_impl.cpp b/dnn/src/aarch64/relayout/opr_impl.cpp index 8af0b8cc..9031d2c7 100644 --- a/dnn/src/aarch64/relayout/opr_impl.cpp +++ b/dnn/src/aarch64/relayout/opr_impl.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "src/common/utils.h" #include "src/common/relayout_helper.h" +#include "src/common/utils.h" #include "src/aarch64/handle.h" #include "src/aarch64/relayout/opr_impl.h" @@ -24,121 +24,114 @@ struct TransposeByte { uint8_t v; }; -void trans_16x16_u8(const void* src, void* dst, const size_t src_step, - const size_t dst_step) { +void trans_16x16_u8( + const void* src, void* dst, const size_t src_step, const size_t dst_step) { asm volatile( - "\n" - "ld1 {v0.16b}, [%[src]], %[src_step] \n" - "ld1 {v1.16b}, [%[src]], %[src_step] \n" - "ld1 {v2.16b}, [%[src]], %[src_step] \n" - "ld1 {v3.16b}, [%[src]], %[src_step] \n" - "ld1 {v4.16b}, [%[src]], %[src_step] \n" - "ld1 {v5.16b}, [%[src]], %[src_step] \n" - "ld1 {v6.16b}, [%[src]], %[src_step] \n" - "ld1 {v7.16b}, [%[src]], %[src_step] \n" - "ld1 {v8.16b}, [%[src]], %[src_step] \n" - "ld1 {v9.16b}, [%[src]], %[src_step] \n" - "ld1 {v10.16b}, [%[src]], %[src_step] \n" - "ld1 {v11.16b}, [%[src]], %[src_step] \n" - "ld1 {v12.16b}, [%[src]], %[src_step] \n" - "ld1 {v13.16b}, [%[src]], %[src_step] \n" - "ld1 {v14.16b}, [%[src]], %[src_step] \n" - "ld1 {v15.16b}, [%[src]], %[src_step] \n" - "trn1 v16.16b, v0.16b, v1.16b \n" - "trn2 v17.16b, v0.16b, v1.16b \n" - "trn1 v18.16b, v2.16b, v3.16b \n" - "trn2 v19.16b, v2.16b, v3.16b \n" - "trn1 v20.16b, v4.16b, v5.16b \n" - "trn2 v21.16b, v4.16b, v5.16b \n" - "trn1 v22.16b, v6.16b, v7.16b \n" - "trn2 v23.16b, v6.16b, v7.16b \n" - "trn1 v24.16b, v8.16b, v9.16b \n" - "trn2 v25.16b, v8.16b, v9.16b \n" - "trn1 v26.16b, v10.16b, v11.16b \n" - "trn2 v27.16b, v10.16b, v11.16b \n" - "trn1 v28.16b, v12.16b, v13.16b \n" - "trn2 v29.16b, v12.16b, v13.16b \n" - "trn1 v30.16b, v14.16b, v15.16b \n" - "trn2 v31.16b, v14.16b, v15.16b \n" - "trn1 v0.8h, v16.8h, v18.8h \n" - "trn2 v2.8h, v16.8h, v18.8h \n" - "trn1 v4.8h, v20.8h, v22.8h \n" - "trn2 v6.8h, v20.8h, v22.8h \n" - "trn1 v8.8h, v24.8h, v26.8h \n" - "trn2 v10.8h, v24.8h, v26.8h \n" - "trn1 v12.8h, v28.8h, v30.8h \n" - "trn2 v14.8h, v28.8h, v30.8h \n" - "trn1 v1.8h, v17.8h, v19.8h \n" - "trn2 v3.8h, v17.8h, v19.8h \n" - "trn1 v5.8h, v21.8h, v23.8h \n" - "trn2 v7.8h, v21.8h, v23.8h \n" - "trn1 v9.8h, v25.8h, v27.8h \n" - "trn2 v11.8h, v25.8h, v27.8h \n" - "trn1 v13.8h, v29.8h, v31.8h \n" - "trn2 v15.8h, v29.8h, v31.8h \n" - "trn1 v16.4s, v0.4s, v4.4s \n" - "trn2 v20.4s, v0.4s, v4.4s \n" - "trn1 v24.4s, v8.4s, v12.4s \n" - "trn2 v28.4s, v8.4s, v12.4s \n" - "trn1 v17.4s, v1.4s, v5.4s \n" - "trn2 v21.4s, v1.4s, v5.4s \n" - "trn1 v25.4s, v9.4s, v13.4s \n" - "trn2 v29.4s, v9.4s, v13.4s \n" - "trn1 v18.4s, v2.4s, v6.4s \n" - "trn2 v22.4s, v2.4s, v6.4s \n" - "trn1 v26.4s, v10.4s, v14.4s \n" - "trn2 v30.4s, v10.4s, v14.4s \n" - "trn1 v19.4s, v3.4s, v7.4s \n" - "trn2 v23.4s, v3.4s, v7.4s \n" - "trn1 v27.4s, v11.4s, v15.4s \n" - "trn2 v31.4s, v11.4s, v15.4s \n" - "trn1 v0.2d, v16.2d, v24.2d \n" - "trn2 v8.2d, v16.2d, v24.2d \n" - "trn1 v1.2d, v17.2d, v25.2d \n" - "trn2 v9.2d, v17.2d, v25.2d \n" - "trn1 v2.2d, v18.2d, v26.2d \n" - "trn2 v10.2d, v18.2d, v26.2d \n" - "trn1 v3.2d, v19.2d, v27.2d \n" - "trn2 v11.2d, v19.2d, v27.2d \n" - "trn1 v4.2d, v20.2d, v28.2d \n" - "trn2 v12.2d, v20.2d, v28.2d \n" - "trn1 v5.2d, v21.2d, v29.2d \n" - "trn2 v13.2d, v21.2d, v29.2d \n" - "trn1 v6.2d, v22.2d, v30.2d \n" - "trn2 v14.2d, v22.2d, v30.2d \n" - "trn1 v7.2d, v23.2d, v31.2d \n" - "trn2 v15.2d, v23.2d, v31.2d \n" - "st1 {v0.16b}, [%[dst]], %[dst_step] \n" - "st1 {v1.16b}, [%[dst]], %[dst_step] \n" - "st1 {v2.16b}, [%[dst]], %[dst_step] \n" - "st1 {v3.16b}, [%[dst]], %[dst_step] \n" - 
"st1 {v4.16b}, [%[dst]], %[dst_step] \n" - "st1 {v5.16b}, [%[dst]], %[dst_step] \n" - "st1 {v6.16b}, [%[dst]], %[dst_step] \n" - "st1 {v7.16b}, [%[dst]], %[dst_step] \n" - "st1 {v8.16b}, [%[dst]], %[dst_step] \n" - "st1 {v9.16b}, [%[dst]], %[dst_step] \n" - "st1 {v10.16b}, [%[dst]], %[dst_step] \n" - "st1 {v11.16b}, [%[dst]], %[dst_step] \n" - "st1 {v12.16b}, [%[dst]], %[dst_step] \n" - "st1 {v13.16b}, [%[dst]], %[dst_step] \n" - "st1 {v14.16b}, [%[dst]], %[dst_step] \n" - "st1 {v15.16b}, [%[dst]], %[dst_step] \n" - : - [src] "+r" (src), - [dst] "+r" (dst) - : - [src_step] "r" (src_step), - [dst_step] "r" (dst_step) - : - "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", - "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", - "d31"); - + "\n" + "ld1 {v0.16b}, [%[src]], %[src_step] \n" + "ld1 {v1.16b}, [%[src]], %[src_step] \n" + "ld1 {v2.16b}, [%[src]], %[src_step] \n" + "ld1 {v3.16b}, [%[src]], %[src_step] \n" + "ld1 {v4.16b}, [%[src]], %[src_step] \n" + "ld1 {v5.16b}, [%[src]], %[src_step] \n" + "ld1 {v6.16b}, [%[src]], %[src_step] \n" + "ld1 {v7.16b}, [%[src]], %[src_step] \n" + "ld1 {v8.16b}, [%[src]], %[src_step] \n" + "ld1 {v9.16b}, [%[src]], %[src_step] \n" + "ld1 {v10.16b}, [%[src]], %[src_step] \n" + "ld1 {v11.16b}, [%[src]], %[src_step] \n" + "ld1 {v12.16b}, [%[src]], %[src_step] \n" + "ld1 {v13.16b}, [%[src]], %[src_step] \n" + "ld1 {v14.16b}, [%[src]], %[src_step] \n" + "ld1 {v15.16b}, [%[src]], %[src_step] \n" + "trn1 v16.16b, v0.16b, v1.16b \n" + "trn2 v17.16b, v0.16b, v1.16b \n" + "trn1 v18.16b, v2.16b, v3.16b \n" + "trn2 v19.16b, v2.16b, v3.16b \n" + "trn1 v20.16b, v4.16b, v5.16b \n" + "trn2 v21.16b, v4.16b, v5.16b \n" + "trn1 v22.16b, v6.16b, v7.16b \n" + "trn2 v23.16b, v6.16b, v7.16b \n" + "trn1 v24.16b, v8.16b, v9.16b \n" + "trn2 v25.16b, v8.16b, v9.16b \n" + "trn1 v26.16b, v10.16b, v11.16b \n" + "trn2 v27.16b, v10.16b, v11.16b \n" + "trn1 v28.16b, v12.16b, v13.16b \n" + "trn2 v29.16b, v12.16b, v13.16b \n" + "trn1 v30.16b, v14.16b, v15.16b \n" + "trn2 v31.16b, v14.16b, v15.16b \n" + "trn1 v0.8h, v16.8h, v18.8h \n" + "trn2 v2.8h, v16.8h, v18.8h \n" + "trn1 v4.8h, v20.8h, v22.8h \n" + "trn2 v6.8h, v20.8h, v22.8h \n" + "trn1 v8.8h, v24.8h, v26.8h \n" + "trn2 v10.8h, v24.8h, v26.8h \n" + "trn1 v12.8h, v28.8h, v30.8h \n" + "trn2 v14.8h, v28.8h, v30.8h \n" + "trn1 v1.8h, v17.8h, v19.8h \n" + "trn2 v3.8h, v17.8h, v19.8h \n" + "trn1 v5.8h, v21.8h, v23.8h \n" + "trn2 v7.8h, v21.8h, v23.8h \n" + "trn1 v9.8h, v25.8h, v27.8h \n" + "trn2 v11.8h, v25.8h, v27.8h \n" + "trn1 v13.8h, v29.8h, v31.8h \n" + "trn2 v15.8h, v29.8h, v31.8h \n" + "trn1 v16.4s, v0.4s, v4.4s \n" + "trn2 v20.4s, v0.4s, v4.4s \n" + "trn1 v24.4s, v8.4s, v12.4s \n" + "trn2 v28.4s, v8.4s, v12.4s \n" + "trn1 v17.4s, v1.4s, v5.4s \n" + "trn2 v21.4s, v1.4s, v5.4s \n" + "trn1 v25.4s, v9.4s, v13.4s \n" + "trn2 v29.4s, v9.4s, v13.4s \n" + "trn1 v18.4s, v2.4s, v6.4s \n" + "trn2 v22.4s, v2.4s, v6.4s \n" + "trn1 v26.4s, v10.4s, v14.4s \n" + "trn2 v30.4s, v10.4s, v14.4s \n" + "trn1 v19.4s, v3.4s, v7.4s \n" + "trn2 v23.4s, v3.4s, v7.4s \n" + "trn1 v27.4s, v11.4s, v15.4s \n" + "trn2 v31.4s, v11.4s, v15.4s \n" + "trn1 v0.2d, v16.2d, v24.2d \n" + "trn2 v8.2d, v16.2d, v24.2d \n" + "trn1 v1.2d, v17.2d, v25.2d \n" + "trn2 v9.2d, v17.2d, v25.2d \n" + "trn1 v2.2d, v18.2d, v26.2d \n" + "trn2 v10.2d, v18.2d, v26.2d \n" + "trn1 v3.2d, v19.2d, v27.2d \n" + "trn2 v11.2d, v19.2d, v27.2d \n" + "trn1 v4.2d, v20.2d, v28.2d \n" + "trn2 
v12.2d, v20.2d, v28.2d \n" + "trn1 v5.2d, v21.2d, v29.2d \n" + "trn2 v13.2d, v21.2d, v29.2d \n" + "trn1 v6.2d, v22.2d, v30.2d \n" + "trn2 v14.2d, v22.2d, v30.2d \n" + "trn1 v7.2d, v23.2d, v31.2d \n" + "trn2 v15.2d, v23.2d, v31.2d \n" + "st1 {v0.16b}, [%[dst]], %[dst_step] \n" + "st1 {v1.16b}, [%[dst]], %[dst_step] \n" + "st1 {v2.16b}, [%[dst]], %[dst_step] \n" + "st1 {v3.16b}, [%[dst]], %[dst_step] \n" + "st1 {v4.16b}, [%[dst]], %[dst_step] \n" + "st1 {v5.16b}, [%[dst]], %[dst_step] \n" + "st1 {v6.16b}, [%[dst]], %[dst_step] \n" + "st1 {v7.16b}, [%[dst]], %[dst_step] \n" + "st1 {v8.16b}, [%[dst]], %[dst_step] \n" + "st1 {v9.16b}, [%[dst]], %[dst_step] \n" + "st1 {v10.16b}, [%[dst]], %[dst_step] \n" + "st1 {v11.16b}, [%[dst]], %[dst_step] \n" + "st1 {v12.16b}, [%[dst]], %[dst_step] \n" + "st1 {v13.16b}, [%[dst]], %[dst_step] \n" + "st1 {v14.16b}, [%[dst]], %[dst_step] \n" + "st1 {v15.16b}, [%[dst]], %[dst_step] \n" + : [src] "+r"(src), [dst] "+r"(dst) + : [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"); } -} // anonymous namespace +} // anonymous namespace namespace megdnn { namespace relayout { @@ -149,9 +142,9 @@ struct transpose_traits { }; template <> -void transpose_block(const TransposeByte* src, - TransposeByte* dst, const size_t src_stride, - const size_t dst_stride) { +void transpose_block( + const TransposeByte* src, TransposeByte* dst, const size_t src_stride, + const size_t dst_stride) { trans_16x16_u8(src, dst, src_stride, dst_stride); } @@ -159,9 +152,8 @@ void transpose_block(const TransposeByte* src, } // namespace relayout } // namespace megdnn -void aarch64::RelayoutForwardImpl::exec(_megdnn_tensor_in src0, - _megdnn_tensor_out dst0, - Handle* src_handle) { +void aarch64::RelayoutForwardImpl::exec( + _megdnn_tensor_in src0, _megdnn_tensor_out dst0, Handle* src_handle) { check_cpu_handle(src_handle); TensorND src = src0, dst = dst0; check_layout_and_canonize(src.layout, dst.layout); @@ -178,10 +170,8 @@ void aarch64::RelayoutForwardImpl::exec(_megdnn_tensor_in src0, if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { auto sptr = static_cast(src.raw_ptr), dptr = static_cast(dst.raw_ptr); - MEGDNN_DISPATCH_CPU_KERN_OPR( - transpose_fallback::transpose( - trans_param.batch, trans_param.m, trans_param.n, sptr, - dptr)); + MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose( + trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); return; } exec_after_preprocess(src, dst, trans ? 
&trans_param : nullptr); diff --git a/dnn/src/aarch64/relayout/opr_impl.h b/dnn/src/aarch64/relayout/opr_impl.h index 61035b27..e897c342 100644 --- a/dnn/src/aarch64/relayout/opr_impl.h +++ b/dnn/src/aarch64/relayout/opr_impl.h @@ -16,11 +16,11 @@ namespace megdnn { namespace aarch64 { class RelayoutForwardImpl final : public fallback::RelayoutForwardImpl { - public: +public: using fallback::RelayoutForwardImpl::RelayoutForwardImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - Handle *src_handle) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, Handle* src_handle) override; bool is_thread_safe() const override { return true; } }; diff --git a/dnn/src/aarch64/rotate/opr_impl.cpp b/dnn/src/aarch64/rotate/opr_impl.cpp index 1ad25d88..b80783e2 100644 --- a/dnn/src/aarch64/rotate/opr_impl.cpp +++ b/dnn/src/aarch64/rotate/opr_impl.cpp @@ -10,21 +10,19 @@ */ #include -#include "src/aarch64/rotate/opr_impl.h" #include "src/aarch64/handle.h" +#include "src/aarch64/rotate/opr_impl.h" #include "src/common/cv/common.h" #include "src/common/cv/helper.h" #include "src/common/utils.h" - namespace megdnn { namespace megcv { -void rotate_8uc1_clockwise_16x16(const uchar *src, - uchar *dst, - size_t src_step, size_t dst_step) -{ - asm volatile ("\n" +void rotate_8uc1_clockwise_16x16( + const uchar* src, uchar* dst, size_t src_step, size_t dst_step) { + asm volatile( + "\n" "ld1 {v0.16b}, [%[src]], %[src_step] \n" "ld1 {v1.16b}, [%[src]], %[src_step] \n" "ld1 {v2.16b}, [%[src]], %[src_step] \n" @@ -109,7 +107,7 @@ void rotate_8uc1_clockwise_16x16(const uchar *src, "trn2 v14.2d, v22.2d, v30.2d \n" "trn1 v7.2d, v23.2d, v31.2d \n" "trn2 v15.2d, v23.2d, v31.2d \n" -// There is no rev128 instruction, so we use rev64 and ext to simulate it. + // There is no rev128 instruction, so we use rev64 and ext to simulate it. 
"rev64 v0.16b, v0.16b \n" "rev64 v1.16b, v1.16b \n" "rev64 v2.16b, v2.16b \n" @@ -159,23 +157,16 @@ void rotate_8uc1_clockwise_16x16(const uchar *src, "st1 {v13.16b}, [%[dst]], %[dst_step] \n" "st1 {v14.16b}, [%[dst]], %[dst_step] \n" "st1 {v15.16b}, [%[dst]], %[dst_step] \n" - : - [src] "+r" (src), - [dst] "+r" (dst) - : - [src_step] "r" (src_step), - [dst_step] "r" (dst_step) - : - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", - "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - ); + : [src] "+r"(src), [dst] "+r"(dst) + : [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15"); } -void rotate_8uc1_counterclockwise_16x16(const uchar *src, - uchar *dst, - size_t src_step, size_t dst_step) -{ - asm volatile ("\n" +void rotate_8uc1_counterclockwise_16x16( + const uchar* src, uchar* dst, size_t src_step, size_t dst_step) { + asm volatile( + "\n" "ld1 {v0.16b}, [%[src]], %[src_step] \n" "ld1 {v1.16b}, [%[src]], %[src_step] \n" "ld1 {v2.16b}, [%[src]], %[src_step] \n" @@ -277,21 +268,15 @@ void rotate_8uc1_counterclockwise_16x16(const uchar *src, "st1 {v2.16b}, [%[dst]], %[dst_step] \n" "st1 {v1.16b}, [%[dst]], %[dst_step] \n" "st1 {v0.16b}, [%[dst]], %[dst_step] \n" - : - [src] "+r" (src), - [dst] "+r" (dst) - : - [src_step] "r" (src_step), - [dst_step] "r" (dst_step) - : - "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", - "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - ); + : [src] "+r"(src), [dst] "+r"(dst) + : [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15"); } -void rotate_8uc1_clockwise(const uchar* src, uchar* dst, const size_t rows, - const size_t cols, const size_t src_step, - const size_t dst_step) { +void rotate_8uc1_clockwise( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { const size_t block = 16; (void)block; size_t i = 0; @@ -300,14 +285,12 @@ void rotate_8uc1_clockwise(const uchar* src, uchar* dst, const size_t rows, size_t j = 0; for (; j + block <= cols; j += block) { rotate_8uc1_clockwise_16x16( - src + i * src_step + j, - dst + j * dst_step + (rows - (i + block)), src_step, - dst_step); + src + i * src_step + j, dst + j * dst_step + (rows - (i + block)), + src_step, dst_step); } for (; j < cols; ++j) { for (size_t k = 0; k < block; ++k) { - dst[j * dst_step + (rows - 1 - (i + k))] = - src[(i + k) * src_step + j]; + dst[j * dst_step + (rows - 1 - (i + k))] = src[(i + k) * src_step + j]; } } } @@ -319,10 +302,9 @@ void rotate_8uc1_clockwise(const uchar* src, uchar* dst, const size_t rows, } } -void rotate_8uc1_counterclockwise(const uchar* src, uchar* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +void rotate_8uc1_counterclockwise( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { const size_t block = 16; (void)block; size_t i = 0; @@ -331,14 +313,12 @@ void rotate_8uc1_counterclockwise(const uchar* src, uchar* dst, size_t j = 0; for (; j + block <= cols; j += block) { rotate_8uc1_counterclockwise_16x16( - src + i * src_step + j, - dst + (cols - (j + block)) * dst_step + i, src_step, - dst_step); + src + i * src_step + j, dst + (cols - (j + block)) * dst_step + i, + src_step, dst_step); } for (; j < cols; ++j) { for (size_t k = 0; k < block; ++k) { - dst[(cols - 1 - 
j) * dst_step + (i + k)] = - src[(i + k) * src_step + j]; + dst[(cols - 1 - j) * dst_step + (i + k)] = src[(i + k) * src_step + j]; } } } @@ -356,11 +336,11 @@ void rotate(const Mat& src, Mat& dst, bool clockwise) { megdnn_assert(src.channels() == dst.channels()); megdnn_assert(src.channels() == 1_z); if (clockwise) { - rotate_8uc1_clockwise(src.ptr(), dst.ptr(), src.rows(), src.cols(), - src.step(), dst.step()); + rotate_8uc1_clockwise( + src.ptr(), dst.ptr(), src.rows(), src.cols(), src.step(), dst.step()); } else { - rotate_8uc1_counterclockwise(src.ptr(), dst.ptr(), src.rows(), - src.cols(), src.step(), dst.step()); + rotate_8uc1_counterclockwise( + src.ptr(), dst.ptr(), src.rows(), src.cols(), src.step(), dst.step()); } } @@ -368,8 +348,8 @@ void rotate(const Mat& src, Mat& dst, bool clockwise) { namespace aarch64 { -void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { +void RotateImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { using namespace megcv; check_exec(src.layout, dst.layout, workspace.size); diff --git a/dnn/src/aarch64/rotate/opr_impl.h b/dnn/src/aarch64/rotate/opr_impl.h index eba709fc..6a95d267 100644 --- a/dnn/src/aarch64/rotate/opr_impl.h +++ b/dnn/src/aarch64/rotate/opr_impl.h @@ -17,16 +17,16 @@ namespace megdnn { namespace aarch64 { class RotateImpl : public fallback::RotateImpl { - public: - using fallback::RotateImpl::RotateImpl; +public: + using fallback::RotateImpl::RotateImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, - const TensorLayout&) override { - return 0; - } + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { + return 0; + } }; } // namespace aarch64 diff --git a/dnn/src/aarch64/warp_perspective/opr_impl.cpp b/dnn/src/aarch64/warp_perspective/opr_impl.cpp index defa9d05..c2c62bda 100644 --- a/dnn/src/aarch64/warp_perspective/opr_impl.cpp +++ b/dnn/src/aarch64/warp_perspective/opr_impl.cpp @@ -19,18 +19,15 @@ namespace megdnn { namespace aarch64 { -void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, - _megdnn_tensor_in dst, - _megdnn_workspace workspace) -{ - check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout, - workspace.size); - if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, - param().format)) { - warp_perspective_cv_exec(src, mat, mat_idx, dst, param().border_val, - param().bmode, param().imode, handle()); +void WarpPerspectiveImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, + _megdnn_tensor_in dst, _megdnn_workspace workspace) { + check_exec(src.layout, mat.layout, mat_idx.layout, dst.layout, workspace.size); + if (warp::is_cv_available( + src.layout, mat.layout, dst.layout, param().imode, param().format)) { + warp_perspective_cv_exec( + src, mat, mat_idx, dst, param().border_val, param().bmode, + param().imode, handle()); } else { //! 
Use arm_common implementation arm_common::WarpPerspectiveImpl::exec(src, mat, mat_idx, dst, workspace); diff --git a/dnn/src/aarch64/warp_perspective/opr_impl.h b/dnn/src/aarch64/warp_perspective/opr_impl.h index 23227613..d25b6586 100644 --- a/dnn/src/aarch64/warp_perspective/opr_impl.h +++ b/dnn/src/aarch64/warp_perspective/opr_impl.h @@ -19,9 +19,9 @@ class WarpPerspectiveImpl : public arm_common::WarpPerspectiveImpl { public: using arm_common::WarpPerspectiveImpl::WarpPerspectiveImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; }; } // namespace aarch64 diff --git a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp index 743524ca..91108e23 100644 --- a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp +++ b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.cpp @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/aarch64/handle.h" #include "src/aarch64/warp_perspective/warp_perspective_cv.h" +#include "src/aarch64/handle.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/cv/common.h" #include "src/common/cv/helper.h" @@ -25,8 +25,9 @@ namespace { constexpr size_t BLOCK_SZ = 32u; template -void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, - const float border_value, size_t task_id) { +void warp_perspective_cv( + const Mat& src, Mat& dst, const float* trans, const float border_value, + size_t task_id) { // no extra padding double M[9]; rep(i, 9) M[i] = trans[i]; @@ -127,10 +128,8 @@ void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, float64x2_t vw0 = vaddq_f64(vW0, vmulq_f64(vM6, vidx)); float64x2_t vw1 = vaddq_f64(vw0, v2M6); - vw0 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw0), v0f, - vceqq_f64(vw0, v0f)); - vw1 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw1), v0f, - vceqq_f64(vw1, v0f)); + vw0 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw0), v0f, vceqq_f64(vw0, v0f)); + vw1 = vbitq_f64(vdivq_f64(vTABLE_SIZE, vw1), v0f, vceqq_f64(vw1, v0f)); float64x2_t vtmp0 = vmlaq_f64(vX0, vM0, vidx); float64x2_t vtmp1 = vaddq_f64(vtmp0, v2M0); @@ -154,8 +153,8 @@ void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, int32x4_t vx = vcombine_s32(vx0, vx1); int32x4_t vy = vcombine_s32(vy0, vy1); - int16x4x2_t ret = {{vqshrn_n_s32(vx, INTER_BITS), - vqshrn_n_s32(vy, INTER_BITS)}}; + int16x4x2_t ret = { + {vqshrn_n_s32(vx, INTER_BITS), vqshrn_n_s32(vy, INTER_BITS)}}; vst2_s16(xy + x1 * 2, ret); vidx = vaddq_f64(vidx, v4f); @@ -163,8 +162,7 @@ void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, vx = vandq_s32(vx, vtabmask); vy = vandq_s32(vy, vtabmask); - vst1_s16(&alpha[x1], - vqmovn_s32(vmlaq_n_s32(vx, vy, INTER_TAB_SIZE))); + vst1_s16(&alpha[x1], vqmovn_s32(vmlaq_n_s32(vx, vy, INTER_TAB_SIZE))); } for (; x1 < bw; x1++) { double W = W0 + M[6] * x1; @@ -190,9 +188,9 @@ void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, } } // anonymous namespace void megdnn::aarch64::warp_perspective_cv_exec( - _megdnn_tensor_in src, _megdnn_tensor_in trans, - _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value, - BorderMode bmode, InterpolationMode 
imode, Handle* handle) { + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in mat_idx, + _megdnn_tensor_in dst, float border_value, BorderMode bmode, + InterpolationMode imode, Handle* handle) { size_t ch = dst.layout[3]; size_t width = dst.layout[2]; size_t height = dst.layout[1]; @@ -202,11 +200,11 @@ void megdnn::aarch64::warp_perspective_cv_exec( size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); - size_t parallelism_batch = div_ceil(height, BLOCK_SZ_H) * - div_ceil(width, BLOCK_SZ_W); - megdnn_assert(ch == 1 || ch == 3 || ch == 2, - "unsupported src channel: %zu, avaiable channel size: 1/2/3", - ch); + size_t parallelism_batch = + div_ceil(height, BLOCK_SZ_H) * div_ceil(width, BLOCK_SZ_W); + megdnn_assert( + ch == 1 || ch == 3 || ch == 2, + "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); const float* trans_ptr = trans.ptr(); const int* midx_ptr = nullptr; if (mat_idx.raw_ptr) { @@ -214,63 +212,61 @@ void megdnn::aarch64::warp_perspective_cv_exec( midx_ptr = mat_idx.ptr(); } if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { -#define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ - parallelism_batch](size_t index, size_t) { \ - size_t batch_id = index / parallelism_batch; \ - size_t task_id = index % parallelism_batch; \ - size_t src_id = batch_id; \ - if (midx_ptr) { \ - src_id = midx_ptr[batch_id]; \ - megdnn_assert( \ - src_id < src.layout.shape[0], \ - "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ - batch_id, src_id, src.layout.shape[0]); \ - } \ - Mat src_mat = TensorND2Mat(src, src_id); \ - Mat dst_mat = TensorND2Mat(dst, batch_id); \ - const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ - warp_perspective_cv( \ - src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ - MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ - task_id); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast(handle), batch* parallelism_batch, \ - task); - DISPATCH_IMODE(imode, bmode, ch, cb) +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ + src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv< \ + float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode MEGDNN_COMMA _ch>( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, task); + DISPATCH_IMODE(imode, bmode, ch, cb) #undef cb - } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { -#define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ - parallelism_batch](size_t index, size_t) { \ - size_t batch_id = index / parallelism_batch; \ - size_t task_id = index % parallelism_batch; \ - size_t src_id = batch_id; \ - if (midx_ptr) { \ - src_id = midx_ptr[batch_id]; \ - megdnn_assert( \ - src_id < src.layout.shape[0], \ 
- "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ - batch_id, src_id, src.layout.shape[0]); \ - } \ - Mat src_mat = TensorND2Mat(src, src_id); \ - Mat dst_mat = TensorND2Mat(dst, batch_id); \ - const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ - warp_perspective_cv( \ - src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ - MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ - task_id); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast(handle), batch* parallelism_batch, \ - task); - DISPATCH_IMODE(imode, bmode, ch, cb) + } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ + src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv< \ + uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode MEGDNN_COMMA _ch>( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, task); + DISPATCH_IMODE(imode, bmode, ch, cb) #undef cb - } else { - megdnn_throw("Unsupported datatype of WarpPerspective optr."); - } + } else { + megdnn_throw("Unsupported datatype of WarpPerspective optr."); + } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.h b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.h index 39071b6b..6918ed1d 100644 --- a/dnn/src/aarch64/warp_perspective/warp_perspective_cv.h +++ b/dnn/src/aarch64/warp_perspective/warp_perspective_cv.h @@ -20,12 +20,11 @@ namespace aarch64 { * \fn warp_perspective_cv * \brief Used if the format is NHWC, transfer from megcv */ -void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, - _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, - float border_value, - param::WarpPerspective::BorderMode border_mode, - param::WarpPerspective::InterpolationMode imode, - Handle* handle); +void warp_perspective_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in mat_idx, + _megdnn_tensor_in dst, float border_value, + param::WarpPerspective::BorderMode border_mode, + param::WarpPerspective::InterpolationMode imode, Handle* handle); } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/block_helper.h b/dnn/src/arm_common/conv_bias/block_helper.h index 9eea2a2d..d733ba0b 100644 --- a/dnn/src/arm_common/conv_bias/block_helper.h +++ b/dnn/src/arm_common/conv_bias/block_helper.h @@ -14,12 +14,12 @@ namespace megdnn { namespace { // block_helper is used to calculate oh block size -static inline int l2_block_helper(const int nthread, const int amount, - const int size_per_unit) { +static inline int l2_block_helper( + const int nthread, const int amount, const int size_per_unit) { constexpr int l2_cache_size = 256 * 1024; const int block_per_thread = div_ceil(amount, nthread); - const int best_block = std::min( - amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); 
+ const int best_block = + std::min(amount, (l2_cache_size + size_per_unit / 2) / size_per_unit); const int max_block_num = div_ceil(block_per_thread, best_block); const int min_block_num = std::max(max_block_num - 1, 1); const int max_block = div_ceil(block_per_thread, max_block_num); diff --git a/dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp index e68a6cb4..ebf43f6e 100644 --- a/dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp +++ b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.cpp @@ -17,22 +17,20 @@ using namespace megdnn; using namespace arm_common; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { auto align = param.src_type.enumv() == DTypeEnum::Float32 ? 4 : 8; return param.osz[1] % align; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, - size_t PH, size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t IH, + size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, size_t PW, + size_t& IH2, size_t& IW2, size_t& OW2) { MEGDNN_MARK_USED_VAR(PW); MEGDNN_MARK_USED_VAR(PH); auto&& fm = param.filter_meta; @@ -75,8 +73,8 @@ WorkspaceBundle MultithreadDirectConvCommon::get_bundle if (param.filter_meta.should_flip) { if (m_large_group) { //! Serial in group, each thread has own workspace and then reuse - part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * - nr_threads * sizeof(io_ctype); + part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * nr_threads * + sizeof(io_ctype); } else { part1 = fm.spatial[0] * fm.spatial[1] * fm.ocpg * fm.icpg * group * sizeof(io_ctype); @@ -87,8 +85,7 @@ WorkspaceBundle MultithreadDirectConvCommon::get_bundle return {nullptr, {part0, part1}}; } template -WorkspaceBundle -MultithreadDirectConvCommon::get_bundle_stride( +WorkspaceBundle MultithreadDirectConvCommon::get_bundle_stride( const ConvBiasImpl::NCBKernSizeParam& param, bool m_large_group) { UNPACK_CONV_F32_NCB_KERN_SIZES(param); MEGDNN_MARK_USED_VAR(N); @@ -105,9 +102,8 @@ MultithreadDirectConvCommon::get_bundle_stride( // src_size: copied src // dst_size: copied dst if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(io_ctype) * nr_threads - : IC * IH2 * IW2 * sizeof(io_ctype) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(io_ctype) * nr_threads + : IC * IH2 * IW2 * sizeof(io_ctype) * group * batch; }; if (need_dst_copy(param)) { //! add 16 Byte extra space in case of invalid read and write @@ -119,10 +115,8 @@ MultithreadDirectConvCommon::get_bundle_stride( //! 
Process one output channel weight flip template void MultithreadDirectConvCommon::weight_flip_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t FH = kern_param.filter_meta.spatial[0]; size_t FW = kern_param.filter_meta.spatial[1]; size_t IC = kern_param.filter_meta.icpg; @@ -132,9 +126,8 @@ void MultithreadDirectConvCommon::weight_flip_kern( group_id = ncb_index.ndrange_id[0]; const io_ctype* filter = kern_param.filter(group_id) + channel_id * FH * FW * IC; - io_ctype* filter_flip = - static_cast(bundle.get(1)) + - (workspace_group_id * IC * OC + channel_id * IC) * FH * FW; + io_ctype* filter_flip = static_cast(bundle.get(1)) + + (workspace_group_id * IC * OC + channel_id * IC) * FH * FW; rep(ic, IC) { const io_ctype* filter_plane = filter + ic * FH * FW; io_ctype* filter_flip_plane = filter_flip + ic * FH * FW; @@ -148,10 +141,8 @@ void MultithreadDirectConvCommon::weight_flip_kern( //! Process one input channel copy padding template void MultithreadDirectConvCommon::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -164,10 +155,9 @@ void MultithreadDirectConvCommon::copy_padding_kern( size_t GROUP = kern_param.filter_meta.group; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; - size_t batch_id = ncb_index.ndrange_id[1], - group_id = ncb_index.ndrange_id[0]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; + size_t batch_id = ncb_index.ndrange_id[1], group_id = ncb_index.ndrange_id[0]; const io_ctype* sptr = static_cast( kern_param.src(batch_id, group_id, channel_id)); if (PH > 0 || PW > 0) { @@ -178,11 +168,11 @@ void MultithreadDirectConvCommon::copy_padding_kern( channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(io_ctype) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(io_ctype) * IW); } - } else if (batch_id + 1 == N && channel_id + 1 == IC && - group_id + 1 == GROUP) { + } else if (batch_id + 1 == N && channel_id + 1 == IC && group_id + 1 == GROUP) { //! copy last plane io_ctype* sptr_last_c = static_cast(bundle.get(0)); std::memcpy(sptr_last_c, sptr, sizeof(io_ctype) * IH2 * IW2); @@ -190,11 +180,9 @@ void MultithreadDirectConvCommon::copy_padding_kern( }; //! 
Process one input channel copy padding template -void MultithreadDirectConvCommon:: - copy_padding_kern_stride(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +void MultithreadDirectConvCommon::copy_padding_kern_stride( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -210,8 +198,7 @@ void MultithreadDirectConvCommon:: size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1]; size_t channel_id = workspace_ids[2], batch_id = ncb_index.ndrange_id[1], group_id = ncb_index.ndrange_id[0]; @@ -225,8 +212,9 @@ void MultithreadDirectConvCommon:: channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(io_ctype) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(io_ctype) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(io_ctype) * IW); } } }; @@ -234,10 +222,9 @@ void MultithreadDirectConvCommon:: //! compute one output channel template void MultithreadDirectConvCommon::do_conv_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const kern_direct_conv_f32& fun, const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const kern_direct_conv_f32& fun, + const CpuNDRange& workspace_ids) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t FH = kern_param.filter_meta.spatial[0]; @@ -252,8 +239,7 @@ void MultithreadDirectConvCommon::do_conv_kern( size_t N = kern_param.n; size_t GROUP = kern_param.filter_meta.group; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; size_t channel_id = workspace_ids[2]; const io_ctype* sptr = kern_param.src(batch_id, group_id); @@ -271,8 +257,7 @@ void MultithreadDirectConvCommon::do_conv_kern( auto fptr = kern_param.filter_meta.should_flip ? static_cast(bundle.get(1)) + - (workspace_group_id * OC * IC + channel_id * IC) * - FH * FW + (workspace_group_id * OC * IC + channel_id * IC) * FH * FW : filter + channel_id * FH * FW * IC; if (PH > 0 || PW > 0) { sptr_base = static_cast(bundle.get(0)) + @@ -289,26 +274,23 @@ void MultithreadDirectConvCommon::do_conv_kern( } std::memset(dptr, 0, sizeof(io_ctype) * (OH * OW)); rep(ic, IC) { - io_ctype* sptr_cur = - (ic + 1 == IC ? sptr_last_c : sptr_base + ic * IH2 * IW2); + io_ctype* sptr_cur = (ic + 1 == IC ? sptr_last_c : sptr_base + ic * IH2 * IW2); fun(reinterpret_cast(sptr_cur), reinterpret_cast(fptr + ic * FH * FW), reinterpret_cast(dptr), IH2, IW2, OH, OW, FH, FW); } - PostProcess::run(dptr, const_cast(bias_ptr), dptr, - kern_param.bias_mode, kern_param.nonlineMode, - kern_param.bias_type, kern_param.dst_type, 1_z, - 1_z, OH, OW); + PostProcess::run( + dptr, const_cast(bias_ptr), dptr, kern_param.bias_mode, + kern_param.nonlineMode, kern_param.bias_type, kern_param.dst_type, 1_z, 1_z, + OH, OW); }; //! 
compute one output channel template void MultithreadDirectConvCommon::do_conv_kern_stride( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, const ConvBiasImpl::NCBKernIndex& ncb_index, - const kern_direct_conv_f32_stride& fun, - const CpuNDRange& workspace_ids) { + const kern_direct_conv_f32_stride& fun, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t OH = kern_param.osz[0]; @@ -325,8 +307,7 @@ void MultithreadDirectConvCommon::do_conv_kern_stride( size_t GROUP = kern_param.filter_meta.group; //! Used for get the workspace offset - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; size_t channel_id = workspace_ids[2]; const io_ctype* sptr = kern_param.src(batch_id, group_id); @@ -349,8 +330,8 @@ void MultithreadDirectConvCommon::do_conv_kern_stride( sptr_base = const_cast(sptr); } if (need_dst_copy(kern_param)) { - dptr_base = static_cast(bundle.get(1)) + - ncb_index.thread_id * OH * OW2; + dptr_base = + static_cast(bundle.get(1)) + ncb_index.thread_id * OH * OW2; } else { dptr_base = dptr; } @@ -359,18 +340,19 @@ void MultithreadDirectConvCommon::do_conv_kern_stride( fun(reinterpret_cast(sptr_base), reinterpret_cast(fptr), reinterpret_cast(dptr_base), IH2, IW2, OH, OW2, IC); - copy_plane_in_bytes(dptr, dptr_base, OH, OW * sizeof(io_ctype), - OW * sizeof(io_ctype), OW2 * sizeof(io_ctype)); + copy_plane_in_bytes( + dptr, dptr_base, OH, OW * sizeof(io_ctype), OW * sizeof(io_ctype), + OW2 * sizeof(io_ctype)); } else { std::memset(dptr_base, 0, sizeof(io_ctype) * (OH * OW)); fun(reinterpret_cast(sptr_base), reinterpret_cast(fptr), reinterpret_cast(dptr_base), IH2, IW2, OH, OW, IC); } - PostProcess::run(dptr, const_cast(bias_ptr), dptr, - kern_param.bias_mode, kern_param.nonlineMode, - kern_param.bias_type, kern_param.dst_type, 1_z, - 1_z, OH, OW); + PostProcess::run( + dptr, const_cast(bias_ptr), dptr, kern_param.bias_mode, + kern_param.nonlineMode, kern_param.bias_type, kern_param.dst_type, 1_z, 1_z, + OH, OW); }; template class megdnn::arm_common::MultithreadDirectConvCommon; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/conv_bias/direct/multi_thread_common.h b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.h index 921ed6ef..09d47fe5 100644 --- a/dnn/src/arm_common/conv_bias/direct/multi_thread_common.h +++ b/dnn/src/arm_common/conv_bias/direct/multi_thread_common.h @@ -23,40 +23,34 @@ public: using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; - using kern_direct_conv_f32 = - std::function; + using kern_direct_conv_f32 = std::function; using kern_direct_conv_f32_stride = std::function; + const compute_ctype* src, const compute_ctype* filter, compute_ctype* dst, + size_t, size_t, size_t, size_t, size_t)>; - static WorkspaceBundle get_bundle(const NCBKernSizeParam& param, - bool m_large_group); - static WorkspaceBundle get_bundle_stride(const NCBKernSizeParam& param, - bool m_large_group); - static void weight_flip_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); - static void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); - 
static void copy_padding_kern_stride(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); - static void do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const kern_direct_conv_f32& fun, - const CpuNDRange& workspace_ids); - static void do_conv_kern_stride(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const kern_direct_conv_f32_stride& fun, - const CpuNDRange& workspace_ids); + static WorkspaceBundle get_bundle( + const NCBKernSizeParam& param, bool m_large_group); + static WorkspaceBundle get_bundle_stride( + const NCBKernSizeParam& param, bool m_large_group); + static void weight_flip_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); + static void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); + static void copy_padding_kern_stride( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); + static void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const kern_direct_conv_f32& fun, + const CpuNDRange& workspace_ids); + static void do_conv_kern_stride( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const kern_direct_conv_f32_stride& fun, + const CpuNDRange& workspace_ids); }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/f16/algos.cpp b/dnn/src/arm_common/conv_bias/f16/algos.cpp index 9302b29b..d26912aa 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.cpp +++ b/dnn/src/arm_common/conv_bias/f16/algos.cpp @@ -33,9 +33,9 @@ bool ConvBiasImpl::AlgoFP16WinogradF23::usable( MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 0, 0) { using Strategy = winograd::winograd_2x3_4x4_f16; Strategy strategy(param.src_type, param.filter_type, param.dst_type); - auto&& matmul_param = megdnn::winograd::ConvBias( - strategy, m_tile_size, param) - .get_matmul_kern_param(param); + auto&& matmul_param = + megdnn::winograd::ConvBias(strategy, m_tile_size, param) + .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && param.filter_meta.format == param::ConvBias::Format::NCHW && !param.filter_meta.should_flip && @@ -43,22 +43,19 @@ bool ConvBiasImpl::AlgoFP16WinogradF23::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float16 && - param.filter_meta.icpg % 4 == 0 && - param.filter_meta.ocpg % 4 == 0; + param.filter_meta.icpg % 4 == 0 && param.filter_meta.ocpg % 4 == 0; } MIDOUT_END(); return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF23, - winograd::winograd_2x3_4x4_f16, - megdnn_arm_common_winograd_fp16, - param::MatrixMul::Format::DEFAULT); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP16WinogradF23, winograd::winograd_2x3_4x4_f16, + megdnn_arm_common_winograd_fp16, param::MatrixMul::Format::DEFAULT); /* 
======================= AlgoFP16WinogradF45 ======================== */ @@ -69,9 +66,9 @@ bool ConvBiasImpl::AlgoFP16WinogradF45::usable( MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 1, 0) { using Strategy = winograd::winograd_4x5_1x1_f16; Strategy strategy(param.src_type, param.filter_type, param.dst_type); - auto&& matmul_param = megdnn::winograd::ConvBias( - strategy, m_tile_size, param) - .get_matmul_kern_param(param); + auto&& matmul_param = + megdnn::winograd::ConvBias(strategy, m_tile_size, param) + .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && param.filter_meta.format == param::ConvBias::Format::NCHW && !param.filter_meta.should_flip && @@ -79,8 +76,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF45::usable( param.filter_meta.spatial[0] == 5) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float16; @@ -89,10 +85,9 @@ bool ConvBiasImpl::AlgoFP16WinogradF45::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF45, - winograd::winograd_4x5_1x1_f16, - megdnn_arm_common_winograd_fp16, - param::MatrixMul::Format::DEFAULT); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP16WinogradF45, winograd::winograd_4x5_1x1_f16, + megdnn_arm_common_winograd_fp16, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP16WinogradF63 ======================== */ @@ -103,9 +98,9 @@ bool ConvBiasImpl::AlgoFP16WinogradF63::usable( MIDOUT_BEGIN(megdnn_arm_common_winograd_fp16, 2, 0) { using Strategy = winograd::winograd_6x3_1x1_f16; Strategy strategy(param.src_type, param.filter_type, param.dst_type); - auto&& matmul_param = megdnn::winograd::ConvBias( - strategy, m_tile_size, param) - .get_matmul_kern_param(param); + auto&& matmul_param = + megdnn::winograd::ConvBias(strategy, m_tile_size, param) + .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && param.filter_meta.format == param::ConvBias::Format::NCHW && !param.filter_meta.should_flip && @@ -113,8 +108,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF63::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float16; @@ -123,10 +117,9 @@ bool ConvBiasImpl::AlgoFP16WinogradF63::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF63, - winograd::winograd_6x3_1x1_f16, - megdnn_arm_common_winograd_fp16, - param::MatrixMul::Format::DEFAULT); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP16WinogradF63, winograd::winograd_6x3_1x1_f16, + megdnn_arm_common_winograd_fp16, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP16WinogradF23_8x8 ======================== */ @@ -141,8 +134,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( 
strategy, m_tile_size, param) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && @@ -153,8 +145,7 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float16; @@ -163,17 +154,16 @@ bool ConvBiasImpl::AlgoFP16WinogradF23_8x8::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP16WinogradF23_8x8, - winograd::winograd_2x3_8x8_f16, - megdnn_arm_common_winograd_fp16, - param::MatrixMul::Format::MK8); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP16WinogradF23_8x8, winograd::winograd_2x3_8x8_f16, + megdnn_arm_common_winograd_fp16, param::MatrixMul::Format::MK8); /*========================from Convolution=============================*/ MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_kimpl) -bool ConvBiasImpl::AlgoF16Direct::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF16Direct::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -185,23 +175,20 @@ bool ConvBiasImpl::AlgoF16Direct::usable(const NCBKernSizeParam& param, return fm.format == param::ConvBias::Format::NCHW && param.src_type.enumv() == DTypeEnum::Float16 && param.filter_type.enumv() == DTypeEnum::Float16 && - param.dst_type.enumv() == DTypeEnum::Float16 && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 8 && - param.osz[0] * param.osz[1] >= 8 && FH <= 7 && SH == 1 && - SW == 1; + param.dst_type.enumv() == DTypeEnum::Float16 && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + param.isz[0] * param.isz[1] >= 8 && param.osz[0] * param.osz[1] >= 8 && + FH <= 7 && SH == 1 && SW == 1; } MIDOUT_END(); return false; } -size_t ConvBiasImpl::AlgoF16Direct::get_workspace( - const NCBKernSizeParam& param) const { +size_t ConvBiasImpl::AlgoF16Direct::get_workspace(const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 0, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; - auto wbundle = - MultithreadDirectConvCommon::get_bundle( - param, large_group); + auto wbundle = MultithreadDirectConvCommon::get_bundle( + param, large_group); return wbundle.total_size_in_bytes(); } MIDOUT_END(); @@ -224,56 +211,57 @@ SmallVector ConvBiasImpl::AlgoF16Direct::get_kimpls( //! one group for better performance if (large_group) { //! 
Channel wise conv and big groups - auto exec_one_group = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto exec_one_group = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { auto fm = kern_param.filter_meta; size_t IC = fm.icpg; size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); if (fm.should_flip) { for (size_t oc = 0; oc < OC; oc++) { - MultithreadDirectConvCommon:: - weight_flip_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + MultithreadDirectConvCommon::weight_flip_kern( + bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, oc}); } } for (size_t ic = 0; ic < IC; ic++) { - MultithreadDirectConvCommon:: - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + MultithreadDirectConvCommon::copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { MultithreadDirectConvCommon::do_conv_kern( - bundle, kern_param, ncb_index, - fp16::conv_bias::kern_direct_f16, + bundle, kern_param, ncb_index, fp16::conv_bias::kern_direct_f16, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { if (fm.should_flip) { - auto weight_flip = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto weight_flip = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - MultithreadDirectConvCommon:: - weight_flip_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + MultithreadDirectConvCommon::weight_flip_kern( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({weight_flip, {group, 1_z, OC}}); } - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); MultithreadDirectConvCommon::copy_padding_kern( bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); - auto do_conv = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_conv = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); MultithreadDirectConvCommon::do_conv_kern( - bundle, kern_param, ncb_index, - fp16::conv_bias::kern_direct_f16, ncb_index.ndrange_id); + bundle, kern_param, ncb_index, fp16::conv_bias::kern_direct_f16, + ncb_index.ndrange_id); }; ret_kerns.push_back({do_conv, {group, N, OC}}); } @@ -291,25 +279,24 @@ SmallVector ConvBiasImpl::AlgoF16Direct::dispatch_kerns( /* ===================== stride-1 algo ===================== */ -bool ConvBiasImpl::AlgoF16DirectStride1::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF16DirectStride1::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.filter_meta.format == param::ConvBias::Format::NCHW && param.src_type.enumv() == DTypeEnum::Float16 && param.filter_type.enumv() == DTypeEnum::Float16 && - param.dst_type.enumv() == DTypeEnum::Float16 && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] 
== 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); + param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5); } MIDOUT_END(); return false; } -SmallVector -ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( +SmallVector ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -318,8 +305,9 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( size_t OC = param.filter_meta.ocpg; size_t group = fm.group; bool large_group = group >= param.nr_threads; - using Func = std::function; + using Func = std::function; Func conv_kern_function = nullptr; #define SWITCH_KERN() \ @@ -353,34 +341,33 @@ ConvBiasImpl::AlgoF16DirectStride1::get_kimpls( bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern_stride( + bundle, kern_param, ncb_index, + {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - MultithreadDirectConvCommon:: - do_conv_kern_stride(bundle, kern_param, ncb_index, - conv_kern_function, - {ncb_index.thread_id, 0, oc}); + MultithreadDirectConvCommon::do_conv_kern_stride( + bundle, kern_param, ncb_index, conv_kern_function, + {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + MultithreadDirectConvCommon::copy_padding_kern_stride( + bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, conv_kern_function]( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - MultithreadDirectConvCommon:: - do_conv_kern_stride(bundle, kern_param, ncb_index, - conv_kern_function, - ncb_index.ndrange_id); + MultithreadDirectConvCommon::do_conv_kern_stride( + bundle, kern_param, ncb_index, conv_kern_function, + ncb_index.ndrange_id); }; ret_kerns.push_back({do_conv, {group, N, OC}}); } @@ -391,16 +378,16 @@ size_t ConvBiasImpl::AlgoF16DirectStride1::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; - auto bundle = MultithreadDirectConvCommon< - dt_float16, __fp16>::get_bundle_stride(param, large_group); + auto bundle = + MultithreadDirectConvCommon::get_bundle_stride( + param, large_group); return bundle.total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoF16DirectStride1::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoF16DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_kimpl, 1, 2) { return get_kimpls(param); diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index 23a5d9cc..4f8f9d7f 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ 
b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -20,8 +20,8 @@ namespace arm_common { class ConvBiasImpl::AlgoFP16WinogradF23 final : public AlgoBase { public: - AlgoFP16WinogradF23(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP16WinogradF23( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -39,8 +39,8 @@ public: class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { public: - AlgoFP16WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP16WinogradF45( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -57,8 +57,8 @@ public: }; class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { public: - AlgoFP16WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP16WinogradF63( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -75,8 +75,8 @@ public: }; class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { public: - AlgoFP16WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP16WinogradF23_8x8( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -85,9 +85,7 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16) }; @@ -96,12 +94,11 @@ class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F16DIRECT"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; @@ -118,12 +115,11 @@ class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F16STRD1"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -137,12 +133,11 @@ class ConvBiasImpl::AlgoF16ChannelWiseNCHW88 final : public AlgoBase { SmallVector 
get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F16_CHANNEL_WISE_NCHW88"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -158,12 +153,11 @@ class ConvBiasImpl::AlgoF16DirectNCHW88 final : public AlgoBase { public: AlgoF16DirectNCHW88() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F16_CONV_NCHW88_DIRECT"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp index fbd5fa7a..aa9a072d 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.cpp @@ -60,8 +60,7 @@ static inline void shift_src(float16x8_t rsrc[3][4]) { } template -static inline float16x8_t load_bias(const float16_t* bias, - const float16x8_t& init) { +static inline float16x8_t load_bias(const float16_t* bias, const float16x8_t& init) { if (bias_mode == BiasMode::BIAS) { return vld1q_f16(bias); } else { @@ -72,11 +71,10 @@ static inline float16x8_t load_bias(const float16_t* bias, template struct compute_element { template - static inline void call(const float16_t*& src0, const float16_t*& src1, - const float16_t*& src2, float16_t*& dst, - const float16_t*& bias, const float16x8_t& init, - float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], - const Op& op) { + static inline void call( + const float16_t*& src0, const float16_t*& src1, const float16_t*& src2, + float16_t*& dst, const float16_t*& bias, const float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], const Op& op) { #define RSRC(i, j) rsrc[i][((j) + bw) % 4] float16x8_t rdst = load_bias(bias, init); if (has_top) { @@ -123,19 +121,18 @@ struct compute_element { template struct compute_element { template - static inline void call(const float16_t*& src0, const float16_t*& src1, - const float16_t*& src2, float16_t*& dst, - const float16_t*& bias, const float16x8_t& init, - float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], - const Op& op) {} + static inline void call( + const float16_t*& src0, const float16_t*& src1, const float16_t*& src2, + float16_t*& dst, const float16_t*& bias, const float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], const Op& op) {} }; template struct compute_element_right { template - static inline void call(float16_t*& dst, const float16_t*& bias, - const float16x8_t& init, float16x8_t rsrc[3][4], - float16x8_t rfilter[3][3], const Op& op) { + static inline void call( + float16_t*& dst, const float16_t*& bias, const 
float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], const Op& op) { float16x8_t rdst = load_bias(bias, init); if (has_top) { @@ -164,9 +161,9 @@ struct compute_element_right { template struct compute_element_right_pad { template - static inline void call(float16_t*& dst, const float16_t*& bias, - const float16x8_t& init, float16x8_t rsrc[3][4], - float16x8_t rfilter[3][3], const Op& op) { + static inline void call( + float16_t*& dst, const float16_t*& bias, const float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], const Op& op) { float16x8_t rdst = load_bias(bias, init); if (has_top) { @@ -191,11 +188,10 @@ struct compute_element_right_pad { template struct compute_row { template - static inline void call(const float16_t*& src0, const float16_t*& src1, - const float16_t*& src2, float16_t*& dst, - const float16_t*& bias, const float16x8_t& init, - float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], - int W, const Op& op) { + static inline void call( + const float16_t*& src0, const float16_t*& src1, const float16_t*& src2, + float16_t*& dst, const float16_t*& bias, const float16x8_t& init, + float16x8_t rsrc[3][4], float16x8_t rfilter[3][3], int W, const Op& op) { if (has_top) { rsrc[0][0] = vdupq_n_f16(0); rsrc[0][1] = vld1q_f16(src0 + 0); @@ -283,23 +279,22 @@ void channel_wise_nchw88::do_conv_kern_3x3_stride1_padding1( float16x8_t rsrc[3][4]; - compute_row::call(src0, src1, src2, dst, bias, init, - rsrc, rfilter, W, op); + compute_row::call( + src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); for (int h = 1; h < H - 1; h += 1) { - compute_row::call(src0, src1, src2, dst, bias, - init, rsrc, rfilter, W, op); + compute_row::call( + src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); } - compute_row::call(src0, src1, src2, dst, bias, init, - rsrc, rfilter, W, op); + compute_row::call( + src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); } -#define INSTANTIATION(bias, Op) \ - template void \ - channel_wise_nchw88::do_conv_kern_3x3_stride1_padding1( \ - const float16_t*, float16_t*, const float16_t*, const float16_t*, \ - int, int); +#define INSTANTIATION(bias, Op) \ + template void channel_wise_nchw88::do_conv_kern_3x3_stride1_padding1( \ + const float16_t*, float16_t*, const float16_t*, const float16_t*, int, \ + int); #define FOR_OP(bias) \ INSTANTIATION(bias, SigmoidOp<__fp16>) \ diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h index 0a4aa5ee..7e532b7d 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_3x3_s1p1_nchw88_kern.h @@ -23,9 +23,9 @@ namespace fp16 { namespace channel_wise_nchw88 { template -void do_conv_kern_3x3_stride1_padding1(const __fp16* src, __fp16* dst, - const __fp16* filter, const __fp16* bias, - int H, int W); +void do_conv_kern_3x3_stride1_padding1( + const __fp16* src, __fp16* dst, const __fp16* filter, const __fp16* bias, int H, + int W); } // namespace channel_wise_nchw88 } // namespace fp16 diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp index face6042..f532a683 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_algo.cpp @@ -23,9 +23,9 @@ using namespace arm_common; using namespace fp16; using conv_fun = std::function; + const __fp16* src, const 
__fp16* filter, const __fp16* bias, __fp16* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const size_t PH, size_t PW)>; MIDOUT_DECL(conv_bias_fp16_channel_wise_nchw88) @@ -36,10 +36,11 @@ bool ConvBiasImpl::AlgoF16ChannelWiseNCHW88::usable( size_t OC = fm.ocpg; size_t IC = fm.icpg; size_t GROUP = fm.group; - bool ok_type = (param.src_type.enumv() == DTypeEnum::Float16 && - param.filter_type.enumv() == DTypeEnum::Float16 && - param.bias_type.enumv() == DTypeEnum::Float16 && - param.dst_type.enumv() == DTypeEnum::Float16); + bool ok_type = + (param.src_type.enumv() == DTypeEnum::Float16 && + param.filter_type.enumv() == DTypeEnum::Float16 && + param.bias_type.enumv() == DTypeEnum::Float16 && + param.dst_type.enumv() == DTypeEnum::Float16); bool ok_format = OC == 1 && IC == 1 && GROUP % 8 == 0 && fm.format == param::Convolution::Format::NCHW88; bool ok_filter = fm.spatial_ndim == 2 && FH == fm.spatial[1] && @@ -57,9 +58,8 @@ size_t ConvBiasImpl::AlgoF16ChannelWiseNCHW88::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoF16ChannelWiseNCHW88::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::AlgoF16ChannelWiseNCHW88:: + dispatch_kerns(const NCBKernSizeParam& param) const { const constexpr size_t pack_group_size = 8_z; auto fm = param.filter_meta; const int batch = param.n; @@ -69,12 +69,14 @@ ConvBiasImpl::AlgoF16ChannelWiseNCHW88::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime -#define DO_CONV_KERN_FUN(_stride, filter, bias_mode, op) \ - MIDOUT_BEGIN(conv_bias_fp16_channel_wise_nchw88, \ - midout_iv(#_stride #filter #bias_mode #op##_hash)) { \ - do_conv_fun = channel_wise_nchw88:: \ - do_conv_kern_##_stride##_##filter##x##filter; \ - } \ +#define DO_CONV_KERN_FUN(_stride, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + conv_bias_fp16_channel_wise_nchw88, \ + midout_iv(#_stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = \ + channel_wise_nchw88::do_conv_kern_##_stride##_##filter##x##filter< \ + bias_mode, op>; \ + } \ MIDOUT_END(); #define GET_OP_PARAM(_stride, filter, bias_mode) \ @@ -148,10 +150,11 @@ ConvBiasImpl::AlgoF16ChannelWiseNCHW88::dispatch_kerns( SmallVector ret_kerns; - CpuNDRange ncb_range = {static_cast(batch), - static_cast(group / pack_group_size)}; - auto do_conv = [do_conv_fun](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { + CpuNDRange ncb_range = { + static_cast(batch), static_cast(group / pack_group_size)}; + auto do_conv = [do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { size_t PH = kern_param.filter_meta.padding[0]; size_t PW = kern_param.filter_meta.padding[1]; size_t OH = kern_param.osz[0]; @@ -161,16 +164,14 @@ ConvBiasImpl::AlgoF16ChannelWiseNCHW88::dispatch_kerns( size_t batch_id = ncb_index.ndrange_id[0]; size_t group_id = ncb_index.ndrange_id[1]; - const __fp16* sptr = - reinterpret_cast(kern_param.src( - batch_id, group_id, 0, pack_group_size)); + const __fp16* sptr = reinterpret_cast( + kern_param.src(batch_id, group_id, 0, pack_group_size)); const __fp16* fptr = reinterpret_cast( kern_param.filter(group_id, pack_group_size)); - __fp16* dst = reinterpret_cast<__fp16*>(kern_param.dst( - batch_id, group_id, 0, pack_group_size)); - const __fp16* bptr = - reinterpret_cast(kern_param.bias( - batch_id, group_id, 0, pack_group_size)); + __fp16* dst = reinterpret_cast<__fp16*>( + kern_param.dst(batch_id, group_id, 0, pack_group_size)); + const 
__fp16* bptr = reinterpret_cast( + kern_param.bias(batch_id, group_id, 0, pack_group_size)); do_conv_fun(sptr, fptr, bptr, dst, IH, IW, OH, OW, PH, PW); }; diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp index 9457c2fc..1dd0934d 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.cpp @@ -52,11 +52,11 @@ template void compute_vec(float16x8_t& dst, float16x8_t* src, float16x8_t* filter); #define cb(i) dst = vfmaq_f16(dst, src[i], filter[i]); -#define COMPUTE_MACRO(n) \ - template <> \ - inline void compute_vec(float16x8_t & dst, float16x8_t * src, \ - float16x8_t * filter) { \ - UNROLL_CALL_NOWRAPPER(n, cb); \ +#define COMPUTE_MACRO(n) \ + template <> \ + inline void compute_vec( \ + float16x8_t & dst, float16x8_t * src, float16x8_t * filter) { \ + UNROLL_CALL_NOWRAPPER(n, cb); \ } COMPUTE_MACRO(2); COMPUTE_MACRO(3); @@ -70,17 +70,17 @@ struct load_bias_vec; #define cb_bias(i) dst[i] = vld1q_f16((bptr) + i * 8); #define cb_init(i) dst[i] = init; -#define INIT_BIAS_MACRO(n) \ - template \ - struct load_bias_vec { \ - static void impl(float16x8_t* dst, const float16x8_t& init, \ - const __fp16* bptr) { \ - if (bias_mode == BiasMode::BIAS) { \ - UNROLL_CALL_NOWRAPPER(n, cb_bias); \ - } else { \ - UNROLL_CALL_NOWRAPPER(n, cb_init); \ - } \ - } \ +#define INIT_BIAS_MACRO(n) \ + template \ + struct load_bias_vec { \ + static void impl( \ + float16x8_t* dst, const float16x8_t& init, const __fp16* bptr) { \ + if (bias_mode == BiasMode::BIAS) { \ + UNROLL_CALL_NOWRAPPER(n, cb_bias); \ + } else { \ + UNROLL_CALL_NOWRAPPER(n, cb_init); \ + } \ + } \ }; INIT_BIAS_MACRO(1); @@ -91,25 +91,23 @@ INIT_BIAS_MACRO(4); #undef INIT_BIAS_MACRO } // namespace -#define COMPUTE_PADDING_KERNEL(oh) \ - do { \ - int iw = ow * stride - PW; \ - float16x8_t result; \ - load_bias_vec::impl(&result, init, \ - bias + (oh)*OW * 8 + ow * 8); \ - for (int kh = 0; kh < fh; kh++) { \ - if (kh + ih < 0 || kh + ih >= static_cast(IH)) \ - continue; \ - for (int kw = 0; kw < fh; kw++) { \ - if (kw + iw < 0 || kw + iw >= static_cast(IW)) \ - continue; \ - const __fp16* sptr = src + (kh + ih) * IW * 8 + (kw + iw) * 8; \ - result = vfmaq_f16(result, kernel[kh * fh + kw], \ - vld1q_f16(sptr)); \ - } \ - } \ - __fp16* output = dst + (oh)*OW * 8 + ow * 8; \ - op(result, output); \ +#define COMPUTE_PADDING_KERNEL(oh) \ + do { \ + int iw = ow * stride - PW; \ + float16x8_t result; \ + load_bias_vec::impl(&result, init, bias + (oh)*OW * 8 + ow * 8); \ + for (int kh = 0; kh < fh; kh++) { \ + if (kh + ih < 0 || kh + ih >= static_cast(IH)) \ + continue; \ + for (int kw = 0; kw < fh; kw++) { \ + if (kw + iw < 0 || kw + iw >= static_cast(IW)) \ + continue; \ + const __fp16* sptr = src + (kh + ih) * IW * 8 + (kw + iw) * 8; \ + result = vfmaq_f16(result, kernel[kh * fh + kw], vld1q_f16(sptr)); \ + } \ + } \ + __fp16* output = dst + (oh)*OW * 8 + ow * 8; \ + op(result, output); \ } while (0) #define COMPUTE_PADDING_TOP() \ @@ -158,9 +156,9 @@ INIT_BIAS_MACRO(4); template void channel_wise_nchw88::do_conv_kern_stride1_2x2( - const __fp16* src, const __fp16* filter, const __fp16* bias, - __fp16* dst, const size_t IH, const size_t IW, const size_t OH, - const size_t OW, const size_t PH, const size_t PW) { + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const size_t PH, const 
size_t PW) { float16x8_t kernel[4]; load_vec<4>(kernel, filter); Op op; @@ -194,8 +192,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_2x2( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2][4]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t src_v[3][5]; @@ -217,8 +215,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_2x2( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t src_v[3][2]; @@ -246,8 +244,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_2x2( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[1][4]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[2][5]; load_vec<5>(src_v[0], input); COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); @@ -262,8 +260,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_2x2( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[2][2]; load_vec<2>(src_v[0], input); compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); @@ -281,12 +279,12 @@ void channel_wise_nchw88::do_conv_kern_stride1_2x2( template void channel_wise_nchw88::do_conv_kern_stride1_3x3( - const __fp16* src, const __fp16* filter, const __fp16* bias, - __fp16* dst, const size_t IH, const size_t IW, const size_t OH, - const size_t OW, const size_t PH, const size_t PW) { + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const size_t PH, const size_t PW) { if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) { - do_conv_kern_3x3_stride1_padding1(src, dst, filter, bias, - OH, OW); + do_conv_kern_3x3_stride1_padding1( + src, dst, filter, bias, OH, OW); return; } @@ -318,8 +316,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_3x3( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[1][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[3][4]; load_vec<4>(src_v[0], input); load_vec<4>(src_v[1], input + IW * 8); @@ -338,8 +336,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_3x3( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[1]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[3][3]; load_vec<3>(src_v[0], input); load_vec<3>(src_v[1], input + IW * 8); @@ -358,9 +356,9 @@ void channel_wise_nchw88::do_conv_kern_stride1_3x3( template void channel_wise_nchw88::do_conv_kern_stride1_5x5( - const __fp16* src, const __fp16* filter, 
const __fp16* bias, - __fp16* dst, const size_t IH, const size_t IW, const size_t OH, - const size_t OW, const size_t PH, const size_t PW) { + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const size_t PH, const size_t PW) { float16x8_t kernel[25]; load_vec<25>(kernel, filter); Op op; @@ -390,8 +388,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_5x5( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t kernel[2][5]; @@ -429,8 +427,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_5x5( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2][1]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t kernel[2][5]; @@ -471,8 +469,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_5x5( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[1][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); float16x8_t kernel[2][5]; float16x8_t src_v[2][6]; #define COMPUTE_5X5_2(i, dst, src, kernel) \ @@ -498,8 +496,8 @@ void channel_wise_nchw88::do_conv_kern_stride1_5x5( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 8 + ow * 8); float16x8_t kernel[2][5]; float16x8_t src_v[2][5]; #define COMPUTE_5X5_1(i, dst, src, kernel) \ @@ -526,9 +524,9 @@ void channel_wise_nchw88::do_conv_kern_stride1_5x5( template void channel_wise_nchw88::do_conv_kern_stride2_2x2( - const __fp16* src, const __fp16* filter, const __fp16* bias, - __fp16* dst, const size_t IH, const size_t IW, const size_t OH, - const size_t OW, const size_t PH, const size_t PW) { + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const size_t PH, const size_t PW) { float16x8_t kernel[4]; load_vec<4>(kernel, filter); Op op; @@ -561,8 +559,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_2x2( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[4]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[2][8]; load_vec<8>(src_v[0], input); COMPUTE_2X2(dst_v, src_v[0], &kernel[0]); @@ -577,8 +575,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_2x2( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[2][2]; load_vec<2>(src_v[0], input); compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); @@ -595,9 +593,9 @@ void 
channel_wise_nchw88::do_conv_kern_stride2_2x2( template void channel_wise_nchw88::do_conv_kern_stride2_3x3( - const __fp16* src, const __fp16* filter, const __fp16* bias, - __fp16* dst, const size_t IH, const size_t IW, const size_t OH, - const size_t OW, const size_t PH, const size_t PW) { + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const size_t PH, const size_t PW) { float16x8_t kernel[9]; load_vec<9>(kernel, filter); Op op; @@ -625,8 +623,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_3x3( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t src_v[2][5]; @@ -656,8 +654,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_3x3( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t src_v[2][3]; @@ -687,8 +685,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_3x3( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[3][5]; load_vec<5>(src_v[0], input); compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); @@ -706,8 +704,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_3x3( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 8 + ow * 8); float16x8_t src_v[3][3]; load_vec<3>(src_v[0], input); compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); @@ -724,9 +722,9 @@ void channel_wise_nchw88::do_conv_kern_stride2_3x3( template void channel_wise_nchw88::do_conv_kern_stride2_5x5( - const __fp16* src, const __fp16* filter, const __fp16* bias, - __fp16* dst, const size_t IH, const size_t IW, const size_t OH, - const size_t OW, const size_t PH, const size_t PW) { + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const size_t PH, const size_t PW) { float16x8_t kernel[25]; load_vec<25>(kernel, filter); Op op; @@ -754,8 +752,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_5x5( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t kernel[3][5]; @@ -798,8 +796,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_5x5( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v[0], 
init, bias + oh * OW * 8 + ow * 8); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 8 + ow * 8); float16x8_t kernel[3][5]; @@ -845,8 +843,8 @@ void channel_wise_nchw88::do_conv_kern_stride2_5x5( const __fp16* input = src + ih * IW * 8 + iw * 8; __fp16* output = dst + oh * OW * 8 + ow * 8; float16x8_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 8 + ow * 8); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 8 + ow * 8); float16x8_t kernel[2][5]; float16x8_t src_v[2][5]; #define COMPUTE_5X5_1(i, dst, src, kernel) \ @@ -871,12 +869,10 @@ void channel_wise_nchw88::do_conv_kern_stride2_5x5( COMPUTE_PADDING_BOTTOM(); } -#define INSTANTIATION(stride, i, bias, Op) \ - template void \ - channel_wise_nchw88::do_conv_kern_##stride##_##i##x##i( \ - const __fp16*, const __fp16*, const __fp16*, __fp16*, \ - const size_t, const size_t, const size_t, const size_t, \ - const size_t, const size_t); +#define INSTANTIATION(stride, i, bias, Op) \ + template void channel_wise_nchw88::do_conv_kern_##stride##_##i##x##i( \ + const __fp16*, const __fp16*, const __fp16*, __fp16*, const size_t, \ + const size_t, const size_t, const size_t, const size_t, const size_t); #define FOR_OP(stride, i, bias) \ INSTANTIATION(stride, i, bias, SigmoidOp<__fp16>) \ diff --git a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h index 033e8d56..d0f8b459 100644 --- a/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h +++ b/dnn/src/arm_common/conv_bias/f16/channel_wise_nchw88_kern.h @@ -22,12 +22,12 @@ namespace arm_common { namespace fp16 { namespace channel_wise_nchw88 { -#define KERN(stride, i) \ - template \ - void do_conv_kern_##stride##_##i##x##i( \ - const __fp16* src, const __fp16* filter, const __fp16* bias, \ - __fp16* dst, const size_t IH, const size_t IW, const size_t OH, \ - const size_t OW, const size_t PH, const size_t PW); +#define KERN(stride, i) \ + template \ + void do_conv_kern_##stride##_##i##x##i( \ + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, \ + const size_t IH, const size_t IW, const size_t OH, const size_t OW, \ + const size_t PH, const size_t PW); KERN(stride1, 2) KERN(stride1, 3) diff --git a/dnn/src/arm_common/conv_bias/f16/direct.cpp b/dnn/src/arm_common/conv_bias/f16/direct.cpp index feaacd20..8400393c 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct.cpp +++ b/dnn/src/arm_common/conv_bias/f16/direct.cpp @@ -158,16 +158,18 @@ namespace { template struct do_pixel_proxy { - static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow); + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow); }; template struct do_pixel_proxy<1, height, width> { - static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); const int ih = oh, iw = ow; @@ -197,9 +199,10 @@ struct do_pixel_proxy<1, height, width> { template struct do_pixel_proxy<2, height, width> { - static void exec(const __fp16* src, const 
__fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); const int ih = oh, iw = ow; @@ -231,9 +234,10 @@ struct do_pixel_proxy<2, height, width> { template struct do_pixel_proxy<3, height, width> { - static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); const int ih = oh, iw = ow; @@ -268,9 +272,10 @@ struct do_pixel_proxy<3, height, width> { template struct do_pixel_proxy<4, height, width> { - static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); const int ih = oh, iw = ow; @@ -308,9 +313,10 @@ struct do_pixel_proxy<4, height, width> { template struct do_pixel_proxy<5, height, width> { - static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); const int ih = oh, iw = ow; @@ -350,9 +356,10 @@ struct do_pixel_proxy<5, height, width> { template struct do_pixel_proxy<6, height, width> { - static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); const int ih = oh, iw = ow; @@ -395,9 +402,10 @@ struct do_pixel_proxy<6, height, width> { template struct do_pixel_proxy<7, height, width> { - static void exec(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); const int ih = oh, iw = ow; @@ -445,18 +453,18 @@ struct do_pixel_proxy<7, height, width> { #undef LOAD_RESULT_VAL template -void do_pixel(const __fp16* src, const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { - do_pixel_proxy::exec(src, filter, dst, IH, IW, OH, OW, - FW, oh, ow); +void do_pixel( + const __fp16* src, const __fp16* filter, 
__fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { + do_pixel_proxy::exec( + src, filter, dst, IH, IW, OH, OW, FW, oh, ow); } template -void do_conv_tpl_enable_prefetch(const __fp16* src, - const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, - const int OW, const int FW) { +void do_conv_tpl_enable_prefetch( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW) { const int hbeg = 0, hend = OH; const int wbeg = 0, wend = OW; int i, j; @@ -464,13 +472,11 @@ void do_conv_tpl_enable_prefetch(const __fp16* src, for (j = wbeg; j + 8 <= wend; j += 8) { // do prefetch const int prefetch_index_input = - (j + 16) < wend - ? i * IW + j + 16 - : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); + (j + 16) < wend ? i * IW + j + 16 + : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); const int prefetch_index_output = - (j + 16) < wend - ? i * OW + j + 16 - : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); + (j + 16) < wend ? i * OW + j + 16 + : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); const __fp16* src_prefetch = src + prefetch_index_input; const __fp16* dst_prefetch = dst + prefetch_index_output; for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { @@ -478,21 +484,19 @@ void do_conv_tpl_enable_prefetch(const __fp16* src, } #define unroll_prefetch_cb(i) __builtin_prefetch(dst_prefetch + i * OW, 1, 3); UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, - j); + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); } -#define DISPATCH(width) \ - do { \ - const int prefetch_index_input = (i + 8) * IW + 12; \ - const int prefetch_index_output = (i + 8) * OW + 12; \ - const __fp16* src_prefetch = src + prefetch_index_input; \ - const __fp16* dst_prefetch = dst + prefetch_index_output; \ - for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ - __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ - } \ - UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ +#define DISPATCH(width) \ + do { \ + const int prefetch_index_input = (i + 8) * IW + 12; \ + const int prefetch_index_output = (i + 8) * OW + 12; \ + const __fp16* src_prefetch = src + prefetch_index_input; \ + const __fp16* dst_prefetch = dst + prefetch_index_output; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ } while (0) switch (wend - j) { case 1: @@ -520,60 +524,56 @@ void do_conv_tpl_enable_prefetch(const __fp16* src, #undef DISPATCH } -#define DISPATCH2(height, width) \ - do { \ - const int prefetch_index_input = IH * IW + 12; \ - const __fp16* src_prefetch = src + prefetch_index_input; \ - for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ - __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ - } \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ +#define DISPATCH2(height, width) \ + do { \ + const int prefetch_index_input = IH * IW + 12; \ + const __fp16* src_prefetch = src + prefetch_index_input; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ } while (0) -#define DISPATCH1(height) \ - do { \ - for (j = wbeg; j + 8 <= wend; j += 8) { \ - const int 
prefetch_index_input = \ - (j + 16) < wend \ - ? i * IW + j + 16 \ - : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); \ - const int prefetch_index_output = \ - (j + 16) < wend \ - ? i * OW + j + 16 \ - : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); \ - const __fp16* src_prefetch = src + prefetch_index_input; \ - const __fp16* dst_prefetch = dst + prefetch_index_output; \ - for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ - __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ - } \ - UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ - } \ - switch (wend - j) { \ - case 1: \ - DISPATCH2(height, 1); \ - break; \ - case 2: \ - DISPATCH2(height, 2); \ - break; \ - case 3: \ - DISPATCH2(height, 3); \ - break; \ - case 4: \ - DISPATCH2(height, 4); \ - break; \ - case 5: \ - DISPATCH2(height, 5); \ - break; \ - case 6: \ - DISPATCH2(height, 6); \ - break; \ - case 7: \ - DISPATCH2(height, 7); \ - break; \ - } \ +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 8 <= wend; j += 8) { \ + const int prefetch_index_input = \ + (j + 16) < wend ? i * IW + j + 16 \ + : (i + 8) * IW + (((j + 16 - wend) >> 2) << 2); \ + const int prefetch_index_output = \ + (j + 16) < wend ? i * OW + j + 16 \ + : (i + 8) * OW + (((j + 16 - wend) >> 2) << 2); \ + const __fp16* src_prefetch = src + prefetch_index_input; \ + const __fp16* dst_prefetch = dst + prefetch_index_output; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + UNROLL_CALL_NOWRAPPER(BLOCK_H, unroll_prefetch_cb); \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + case 4: \ + DISPATCH2(height, 4); \ + break; \ + case 5: \ + DISPATCH2(height, 5); \ + break; \ + case 6: \ + DISPATCH2(height, 6); \ + break; \ + case 7: \ + DISPATCH2(height, 7); \ + break; \ + } \ } while (0) switch (hend - i) { case 1: @@ -605,22 +605,19 @@ void do_conv_tpl_enable_prefetch(const __fp16* src, #undef unroll_prefetch_cb } template -void do_conv_tpl_disable_prefetch(const __fp16* src, - const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, - const int OW, const int FW) { +void do_conv_tpl_disable_prefetch( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FW) { const int hbeg = 0, hend = OH; const int wbeg = 0, wend = OW; int i, j; for (i = hbeg; i + BLOCK_H <= hend; i += BLOCK_H) { for (j = wbeg; j + 8 <= wend; j += 8) { - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, - j); + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); } -#define DISPATCH(width) \ - do { \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ +#define DISPATCH(width) \ + do { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ } while (0) switch (wend - j) { case 1: @@ -647,40 +644,38 @@ void do_conv_tpl_disable_prefetch(const __fp16* src, } #undef DISPATCH } -#define DISPATCH2(height, width) \ - do { \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ +#define DISPATCH2(height, width) \ + do { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ } while (0) -#define DISPATCH1(height) \ - do { \ - for (j = wbeg; j + 8 <= wend; j += 8) { \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ - } \ - switch (wend - j) { \ - case 1: \ - 
DISPATCH2(height, 1); \ - break; \ - case 2: \ - DISPATCH2(height, 2); \ - break; \ - case 3: \ - DISPATCH2(height, 3); \ - break; \ - case 4: \ - DISPATCH2(height, 4); \ - break; \ - case 5: \ - DISPATCH2(height, 5); \ - break; \ - case 6: \ - DISPATCH2(height, 6); \ - break; \ - case 7: \ - DISPATCH2(height, 7); \ - break; \ - } \ +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 8 <= wend; j += 8) { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + case 4: \ + DISPATCH2(height, 4); \ + break; \ + case 5: \ + DISPATCH2(height, 5); \ + break; \ + case 6: \ + DISPATCH2(height, 6); \ + break; \ + case 7: \ + DISPATCH2(height, 7); \ + break; \ + } \ } while (0) switch (hend - i) { case 1: @@ -712,16 +707,14 @@ void do_conv_tpl_disable_prefetch(const __fp16* src, } } // anonymous namespace -void conv_bias::kern_direct_f16(const __fp16* src, - const __fp16* filter, __fp16* dst, - const int IH, const int IW, const int OH, - const int OW, const int FH, const int FW) { +void conv_bias::kern_direct_f16( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FH, const int FW) { megdnn_assert_internal(FH <= 7); if (IH > 100 && IW > 100) { -#define GAO(FH) \ - do { \ - return do_conv_tpl_enable_prefetch(src, filter, dst, IH, IW, OH, \ - OW, FW); \ +#define GAO(FH) \ + do { \ + return do_conv_tpl_enable_prefetch(src, filter, dst, IH, IW, OH, OW, FW); \ } while (0) switch (FH) { case 1: @@ -755,10 +748,9 @@ void conv_bias::kern_direct_f16(const __fp16* src, } #undef GAO } else { -#define GAO(FH) \ - do { \ - return do_conv_tpl_disable_prefetch(src, filter, dst, IH, IW, OH, \ - OW, FW); \ +#define GAO(FH) \ + do { \ + return do_conv_tpl_disable_prefetch(src, filter, dst, IH, IW, OH, OW, FW); \ } while (0) switch (FH) { case 1: diff --git a/dnn/src/arm_common/conv_bias/f16/direct.h b/dnn/src/arm_common/conv_bias/f16/direct.h index 7949d310..e96cf913 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct.h +++ b/dnn/src/arm_common/conv_bias/f16/direct.h @@ -16,14 +16,14 @@ namespace megdnn { namespace arm_common { -namespace fp16{ +namespace fp16 { namespace conv_bias { -void kern_direct_f16(const __fp16* src, const __fp16* filter, - __fp16* dst, const int IH, const int IW, const int OH, - const int OW, const int FH, const int FW); +void kern_direct_f16( + const __fp16* src, const __fp16* filter, __fp16* dst, const int IH, + const int IW, const int OH, const int OW, const int FH, const int FW); -} // namespace convolution +} // namespace conv_bias } // namespace fp16 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp index 06aaf80b..12d10f82 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_algo.cpp @@ -23,11 +23,9 @@ using namespace megdnn; using namespace arm_common; -using conv_fun = - std::function; +using conv_fun = std::function; MIDOUT_DECL(megdnn_arm_common_conv_bias_fp16_nchw88) namespace { @@ -47,10 +45,9 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { return {nullptr, {s}}; } -void copy_padding_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const 
ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +void copy_padding_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { auto fm = kern_param.filter_meta; size_t group = fm.group; size_t IH = kern_param.isz[0]; @@ -82,16 +79,16 @@ void copy_padding_kern(const WorkspaceBundle& bundle, channel_id * IH2 * IW2 * 8; std::memset(sptr_base, 0, IH2 * IW2 * 8 * sizeof(dt_float16)); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 * 8 + PW * 8, - sptr + ih * IW * 8, IW * 8 * sizeof(dt_float16)); + std::memcpy( + sptr_base + (ih + PH) * IW2 * 8 + PW * 8, sptr + ih * IW * 8, + IW * 8 * sizeof(dt_float16)); } }; template -static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { auto fm = kern_param.filter_meta; size_t group = fm.group; size_t OH = kern_param.osz[0]; @@ -136,8 +133,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, } // namespace /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoF16DirectNCHW88::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF16DirectNCHW88::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { auto&& fm = param.filter_meta; auto fh = fm.spatial[0]; int oc = fm.ocpg; @@ -159,16 +156,16 @@ bool ConvBiasImpl::AlgoF16DirectNCHW88::usable(const NCBKernSizeParam& param, size_t ConvBiasImpl::AlgoF16DirectNCHW88::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88_stride1, - midout_iv("AlgoF16DirectNCHW88::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_fp16_nchw88_stride1, + midout_iv("AlgoF16DirectNCHW88::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoF16DirectNCHW88::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoF16DirectNCHW88::dispatch_kerns( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; size_t batch = param.n; @@ -178,11 +175,12 @@ ConvBiasImpl::AlgoF16DirectNCHW88::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime -#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp16_nchw88, \ - midout_iv(#filter #bias_mode #stride #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_fp16_nchw88, \ + midout_iv(#filter #bias_mode #stride #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); #define GET_STRIDE_PARAM(filter, bias_mode, op) \ @@ -277,12 +275,11 @@ ConvBiasImpl::AlgoF16DirectNCHW88::dispatch_kerns( size_t OC = fm.ocpg / 8; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - 
{ncb_index.thread_id, 0, oc}); + do_conv_fun(bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; // TODO: large group only, further multithread optimization required diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp index 354ce3b1..7c1f75f2 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.cpp @@ -24,13 +24,12 @@ using namespace arm_common; template struct compute_fma { - static inline void call(const float16x8_t* ri, const float16x8_t* rf, - float16x8_t* rdst) { + static inline void call( + const float16x8_t* ri, const float16x8_t* rf, float16x8_t* rdst) { #if defined(__aarch64__) rdst[bw] = vfmaq_laneq_f16(rdst[bw], rf[pc], ri[bw], pc); #else - rdst[bw] = vfmaq_f16(rdst[bw], rf[pc], - vdupq_n_f16(vgetq_lane_f16(ri[bw], pc))); + rdst[bw] = vfmaq_f16(rdst[bw], rf[pc], vdupq_n_f16(vgetq_lane_f16(ri[bw], pc))); #endif compute_fma::call(ri, rf, rdst); } @@ -38,16 +37,16 @@ struct compute_fma { template struct compute_fma { - static inline void call(const float16x8_t* ri, const float16x8_t* rf, - float16x8_t* rdst) { + static inline void call( + const float16x8_t* ri, const float16x8_t* rf, float16x8_t* rdst) { compute_fma::call(ri, rf, rdst); } }; template struct compute_fma { - static inline void call(const float16x8_t* ri, const float16x8_t* rf, - float16x8_t* rdst) {} + static inline void call( + const float16x8_t* ri, const float16x8_t* rf, float16x8_t* rdst) {} }; template @@ -103,9 +102,9 @@ struct store_dst { }; template -static inline void do_conv_kern_1xBW(const float16_t*& src, float16_t*& dst, - const float16_t* filter, int IW, int OW, - int& ow) { +static inline void do_conv_kern_1xBW( + const float16_t*& src, float16_t*& dst, const float16_t* filter, int IW, int OW, + int& ow) { constexpr int PC = 8; constexpr int FW = FH; constexpr int SW = SH; @@ -125,8 +124,7 @@ static inline void do_conv_kern_1xBW(const float16_t*& src, float16_t*& dst, load_src::call(ri, src + (fh * IW + fw) * PC); if (FH > 1 || FW > 1) { - load_filter::call(rf, - filter + (fh * FW + fw) * PC * PC); + load_filter::call(rf, filter + (fh * FW + fw) * PC * PC); } compute_fma::call(ri, rf, rdst); @@ -141,8 +139,7 @@ static inline void do_conv_kern_1xBW(const float16_t*& src, float16_t*& dst, } template -static void do_load_bias_kern(float16_t* dst, const float16_t* bias, int OH, - int OW) { +static void do_load_bias_kern(float16_t* dst, const float16_t* bias, int OH, int OW) { constexpr int PC = 8; if (bias_mode == BiasMode::NO_BIAS) { @@ -198,9 +195,9 @@ static void do_op_kern(float16_t* dst, int OH, int OW) { } template -static void do_conv_kern(const float16_t* src, float16_t* dst, - const float16_t* filter, int IC, int IH, int IW, - int OH, int OW) { +static void do_conv_kern( + const float16_t* src, float16_t* dst, const float16_t* filter, int IC, int IH, + int IW, int OH, int OW) { constexpr int PC = 8; constexpr int FW = FH; @@ -213,13 +210,10 @@ static void do_conv_kern(const float16_t* src, float16_t* dst, float16_t* dst_ptr_w = dst_ptr_h; int ow = 0; - do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, OW, - ow); + do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, OW, ow); if (OW & 3) { - do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, - OW, ow); - do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, - OW, ow); + do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, OW, ow); + do_conv_kern_1xBW(src_ptr_w, dst_ptr_w, filter, IW, OW, 
ow); } src_ptr_h += SH * IW * PC; dst_ptr_h += OW * PC; @@ -229,8 +223,9 @@ static void do_conv_kern(const float16_t* src, float16_t* dst, } } -static void do_conv_kern_1x1(const float16_t* src, float16_t* dst, - const float16_t* filter, int IC, int OH, int OW) { +static void do_conv_kern_1x1( + const float16_t* src, float16_t* dst, const float16_t* filter, int IC, int OH, + int OW) { constexpr int PC = 8; const int IH = OH; const int IW = OW; @@ -242,21 +237,18 @@ static void do_conv_kern_1x1(const float16_t* src, float16_t* dst, float16_t* dst_ptr_hw = dst; int ohw = 0; - do_conv_kern_1xBW<1, 1, 8>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, - ohw); - do_conv_kern_1xBW<1, 1, 4>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, - ohw); - do_conv_kern_1xBW<1, 1, 1>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, - ohw); + do_conv_kern_1xBW<1, 1, 8>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, ohw); + do_conv_kern_1xBW<1, 1, 4>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, ohw); + do_conv_kern_1xBW<1, 1, 1>(src_ptr_hw, dst_ptr_hw, filter, IHW, OHW, ohw); src += IHW * PC; filter += PC * PC; } } template -void conv_bias::conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter, - const __fp16* bias, __fp16* dst, int IC, - int IH, int IW, int OH, int OW) { +void conv_bias::conv_direct_fp16_nchw88( + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + int IC, int IH, int IW, int OH, int OW) { do_load_bias_kern(dst, bias, OH, OW); if (FH == 1 && SH == 1 && IH == OH && IW == OW) { do_conv_kern_1x1(src, dst, filter, IC, OH, OW); @@ -266,11 +258,10 @@ void conv_bias::conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter, do_op_kern(dst, OH, OW); } -#define INSTANTIATION(stride, filter, bias, Op) \ - template void \ - conv_bias::conv_direct_fp16_nchw88( \ - const __fp16*, const __fp16*, const __fp16*, __fp16*, int, int, \ - int, int, int); +#define INSTANTIATION(stride, filter, bias, Op) \ + template void conv_bias::conv_direct_fp16_nchw88( \ + const __fp16*, const __fp16*, const __fp16*, __fp16*, int, int, int, int, \ + int); #define FOR_OP(stride, filter, bias) \ INSTANTIATION(stride, filter, bias, SigmoidOp<__fp16>) \ diff --git a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h index dafedc3e..84d49ef6 100644 --- a/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h +++ b/dnn/src/arm_common/conv_bias/f16/direct_nchw88_kern.h @@ -21,9 +21,9 @@ namespace arm_common { namespace conv_bias { template -void conv_direct_fp16_nchw88(const __fp16* src, const __fp16* filter, - const __fp16* bias, __fp16* dst, int IC, int IH, - int IW, int OH, int OW); +void conv_direct_fp16_nchw88( + const __fp16* src, const __fp16* filter, const __fp16* bias, __fp16* dst, + int IC, int IH, int IW, int OH, int OW); } // namespace conv_bias } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp index 2768ba9c..b72babc7 100644 --- a/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp +++ b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.cpp @@ -13,8 +13,8 @@ #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #include "./do_conv_stride1.h" -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; using namespace arm_common; @@ -24,10 +24,9 @@ using namespace conv_stride1; using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; 
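The do_conv_kern_1xBW template being reformatted above packs 8 channels per vector (NCHW88) and unrolls BW output columns, with compute_fma broadcasting one input-channel lane of ri[bw] against a filter vector rf[pc] holding 8 output channels. A plain scalar sketch of that blocking, for reference only — conv_1xBW_ref and the filter layout are illustrative assumptions, not MegDNN code:

// Scalar reference for the 1xBW blocking unrolled by do_conv_kern_1xBW
// (illustrative sketch; float instead of __fp16, filter layout assumed
// to be [fh][fw][ic_pack][oc_pack]).
template <int FH, int SH, int BW>
void conv_1xBW_ref(
        const float*& src, float*& dst, const float* filter, int IW, int OW,
        int& ow) {
    constexpr int PC = 8;  // channel pack size of NCHW88
    constexpr int FW = FH, SW = SH;
    for (; ow + BW <= OW; ow += BW) {
        for (int bw = 0; bw < BW; ++bw) {          // BW adjacent output columns
            for (int ocp = 0; ocp < PC; ++ocp) {   // 8 output channels in the pack
                float acc = dst[bw * PC + ocp];    // dst already holds bias / prior IC packs
                for (int fh = 0; fh < FH; ++fh)
                    for (int fw = 0; fw < FW; ++fw)
                        for (int icp = 0; icp < PC; ++icp)
                            acc += src[(fh * IW + fw + bw * SW) * PC + icp] *
                                   filter[((fh * FW + fw) * PC + icp) * PC + ocp];
                dst[bw * PC + ocp] = acc;
            }
        }
        src += BW * SW * PC;  // step to the next block of BW outputs, as the real kernel does
        dst += BW * PC;
    }
}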
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; - -void conv_stride1::do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride1::do_conv_2x2_stride1( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; //! unroll of 2 size_t ic = 0; @@ -130,9 +129,9 @@ void conv_stride1::do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, } } -void conv_stride1::do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride1::do_conv_3x3_stride1( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { @@ -271,9 +270,9 @@ void conv_stride1::do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, } } -void conv_stride1::do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride1::do_conv_5x5_stride1( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { diff --git a/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h index ede394ee..d4fe891f 100644 --- a/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h +++ b/dnn/src/arm_common/conv_bias/f16/do_conv_stride1.h @@ -17,12 +17,15 @@ namespace megdnn { namespace arm_common { namespace fp16 { namespace conv_stride1 { -void do_conv_2x2_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_3x3_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_5x5_stride1(const __fp16* src, const __fp16* filter, __fp16* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_2x2_stride1( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_3x3_stride1( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_5x5_stride1( + const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); } // namespace conv_stride1 } // namespace fp16 } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/f16/helper.h b/dnn/src/arm_common/conv_bias/f16/helper.h index 6d4f253e..c5dbc975 100644 --- a/dnn/src/arm_common/conv_bias/f16/helper.h +++ b/dnn/src/arm_common/conv_bias/f16/helper.h @@ -34,306 +34,266 @@ #if MEGDNN_AARCH64 -#define TRANSPOSE_4x4(a, ret) \ - do { \ - auto b00 = vzip1_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2*/ \ - auto b01 = vzip2_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a3b3a4b4*/ \ - auto b10 = vzip1_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2*/ \ - auto b11 = vzip2_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c3d3c4d4*/ \ - auto s32b00 = vreinterpret_s32_f16(b00); \ - auto s32b01 = vreinterpret_s32_f16(b01); \ - auto s32b10 = vreinterpret_s32_f16(b10); \ - auto s32b11 = vreinterpret_s32_f16(b11); \ - CONCAT(ret, 0).value = \ - vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)); \ - 
CONCAT(ret, 1).value = \ - vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)); \ - CONCAT(ret, 2).value = \ - vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)); \ - CONCAT(ret, 3).value = \ - vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)); \ +#define TRANSPOSE_4x4(a, ret) \ + do { \ + auto b00 = vzip1_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); /*a1b1a2b2*/ \ + auto b01 = vzip2_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); /*a3b3a4b4*/ \ + auto b10 = vzip1_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); /*c1d1c2d2*/ \ + auto b11 = vzip2_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); /*c3d3c4d4*/ \ + auto s32b00 = vreinterpret_s32_f16(b00); \ + auto s32b01 = vreinterpret_s32_f16(b01); \ + auto s32b10 = vreinterpret_s32_f16(b10); \ + auto s32b11 = vreinterpret_s32_f16(b11); \ + CONCAT(ret, 0).value = vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)); \ + CONCAT(ret, 1).value = vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)); \ + CONCAT(ret, 2).value = vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)); \ + CONCAT(ret, 3).value = vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)); \ } while (0); -#define TRANSPOSE_4x8(a, ret) \ - do { \ - auto b00 = vzip1q_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4*/ \ - auto b01 = vzip2q_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a5b5a6b6a7b7a8b8*/ \ - auto b10 = vzip1q_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4*/ \ - auto b11 = vzip2q_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c5d5c6d6c7d7c8d8*/ \ - auto s32b00 = vreinterpretq_s32_f16(b00); \ - auto s32b01 = vreinterpretq_s32_f16(b01); \ - auto s32b10 = vreinterpretq_s32_f16(b10); \ - auto s32b11 = vreinterpretq_s32_f16(b11); \ - auto f16b00 = vreinterpretq_f16_s32( \ - vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1a2b2c2d2*/ \ - auto f16b01 = vreinterpretq_f16_s32( \ - vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3a4b4a4d4*/ \ - auto f16b10 = vreinterpretq_f16_s32( \ - vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5a6b6c6d6*/ \ - auto f16b11 = vreinterpretq_f16_s32( \ - vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7a8b8c8d8*/ \ - CONCAT(ret, 0).value = vget_low_f16(f16b00); \ - CONCAT(ret, 1).value = vget_high_f16(f16b00); \ - CONCAT(ret, 2).value = vget_low_f16(f16b01); \ - CONCAT(ret, 3).value = vget_high_f16(f16b01); \ - CONCAT(ret, 4).value = vget_low_f16(f16b10); \ - CONCAT(ret, 5).value = vget_high_f16(f16b10); \ - CONCAT(ret, 6).value = vget_low_f16(f16b11); \ - CONCAT(ret, 7).value = vget_high_f16(f16b11); \ +#define TRANSPOSE_4x8(a, ret) \ + do { \ + auto b00 = vzip1q_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4*/ \ + auto b01 = vzip2q_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a5b5a6b6a7b7a8b8*/ \ + auto b10 = vzip1q_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4*/ \ + auto b11 = vzip2q_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c5d5c6d6c7d7c8d8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b00); \ + auto s32b01 = vreinterpretq_s32_f16(b01); \ + auto s32b10 = vreinterpretq_s32_f16(b10); \ + auto s32b11 = vreinterpretq_s32_f16(b11); \ + auto f16b00 = vreinterpretq_f16_s32( \ + vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1a2b2c2d2*/ \ + auto f16b01 = vreinterpretq_f16_s32( \ + vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3a4b4a4d4*/ \ + auto f16b10 = vreinterpretq_f16_s32( \ + vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5a6b6c6d6*/ \ + auto f16b11 = vreinterpretq_f16_s32( \ + vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7a8b8c8d8*/ \ + CONCAT(ret, 0).value = vget_low_f16(f16b00); \ + CONCAT(ret, 1).value = vget_high_f16(f16b00); \ + CONCAT(ret, 
2).value = vget_low_f16(f16b01); \ + CONCAT(ret, 3).value = vget_high_f16(f16b01); \ + CONCAT(ret, 4).value = vget_low_f16(f16b10); \ + CONCAT(ret, 5).value = vget_high_f16(f16b10); \ + CONCAT(ret, 6).value = vget_low_f16(f16b11); \ + CONCAT(ret, 7).value = vget_high_f16(f16b11); \ } while (0); -#define TRANSPOSE_8x4(a, ret) \ - do { \ - auto b00 = vzip1_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2*/ \ - auto b01 = vzip2_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a3b3a4b4*/ \ - auto b10 = vzip1_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2*/ \ - auto b11 = vzip2_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c3d3c4d4*/ \ - auto b20 = vzip1_f16(CONCAT(a, 4).value, \ - CONCAT(a, 5).value); /*e1f1e2f2*/ \ - auto b21 = vzip2_f16(CONCAT(a, 4).value, \ - CONCAT(a, 5).value); /*e3f3e4f4*/ \ - auto b30 = vzip1_f16(CONCAT(a, 6).value, \ - CONCAT(a, 7).value); /*g1h1g2h2*/ \ - auto b31 = vzip2_f16(CONCAT(a, 6).value, \ - CONCAT(a, 7).value); /*g3h3g4h4*/ \ - auto s32b00 = vreinterpret_s32_f16(b00); \ - auto s32b01 = vreinterpret_s32_f16(b01); \ - auto s32b10 = vreinterpret_s32_f16(b10); \ - auto s32b11 = vreinterpret_s32_f16(b11); \ - auto s32b20 = vreinterpret_s32_f16(b20); \ - auto s32b21 = vreinterpret_s32_f16(b21); \ - auto s32b30 = vreinterpret_s32_f16(b30); \ - auto s32b31 = vreinterpret_s32_f16(b31); \ - CONCAT(ret, 0).value = \ - vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)), \ - vreinterpret_f16_s32(vzip1_s32(s32b20, s32b30))); \ - CONCAT(ret, 1).value = \ - vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)), \ - vreinterpret_f16_s32(vzip2_s32(s32b20, s32b30))); \ - CONCAT(ret, 2).value = \ - vcombine_f16(vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)), \ - vreinterpret_f16_s32(vzip1_s32(s32b21, s32b31))); \ - CONCAT(ret, 3).value = \ - vcombine_f16(vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)), \ - vreinterpret_f16_s32(vzip2_s32(s32b21, s32b31))); \ +#define TRANSPOSE_8x4(a, ret) \ + do { \ + auto b00 = vzip1_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); /*a1b1a2b2*/ \ + auto b01 = vzip2_f16(CONCAT(a, 0).value, CONCAT(a, 1).value); /*a3b3a4b4*/ \ + auto b10 = vzip1_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); /*c1d1c2d2*/ \ + auto b11 = vzip2_f16(CONCAT(a, 2).value, CONCAT(a, 3).value); /*c3d3c4d4*/ \ + auto b20 = vzip1_f16(CONCAT(a, 4).value, CONCAT(a, 5).value); /*e1f1e2f2*/ \ + auto b21 = vzip2_f16(CONCAT(a, 4).value, CONCAT(a, 5).value); /*e3f3e4f4*/ \ + auto b30 = vzip1_f16(CONCAT(a, 6).value, CONCAT(a, 7).value); /*g1h1g2h2*/ \ + auto b31 = vzip2_f16(CONCAT(a, 6).value, CONCAT(a, 7).value); /*g3h3g4h4*/ \ + auto s32b00 = vreinterpret_s32_f16(b00); \ + auto s32b01 = vreinterpret_s32_f16(b01); \ + auto s32b10 = vreinterpret_s32_f16(b10); \ + auto s32b11 = vreinterpret_s32_f16(b11); \ + auto s32b20 = vreinterpret_s32_f16(b20); \ + auto s32b21 = vreinterpret_s32_f16(b21); \ + auto s32b30 = vreinterpret_s32_f16(b30); \ + auto s32b31 = vreinterpret_s32_f16(b31); \ + CONCAT(ret, 0).value = vcombine_f16( \ + vreinterpret_f16_s32(vzip1_s32(s32b00, s32b10)), \ + vreinterpret_f16_s32(vzip1_s32(s32b20, s32b30))); \ + CONCAT(ret, 1).value = vcombine_f16( \ + vreinterpret_f16_s32(vzip2_s32(s32b00, s32b10)), \ + vreinterpret_f16_s32(vzip2_s32(s32b20, s32b30))); \ + CONCAT(ret, 2).value = vcombine_f16( \ + vreinterpret_f16_s32(vzip1_s32(s32b01, s32b11)), \ + vreinterpret_f16_s32(vzip1_s32(s32b21, s32b31))); \ + CONCAT(ret, 3).value = vcombine_f16( \ + vreinterpret_f16_s32(vzip2_s32(s32b01, s32b11)), \ + 
vreinterpret_f16_s32(vzip2_s32(s32b21, s32b31))); \ } while (0); -#define TRANSPOSE_8x8(a, ret) \ - do { \ - auto b00 = vzip1q_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ - auto b01 = vzip2q_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ - auto b10 = vzip1q_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ - auto b11 = vzip2q_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ - auto b20 = vzip1q_f16(CONCAT(a, 4).value, \ - CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ - auto b21 = vzip2q_f16(CONCAT(a, 4).value, \ - CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ - auto b30 = vzip1q_f16(CONCAT(a, 6).value, \ - CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ - auto b31 = vzip2q_f16(CONCAT(a, 6).value, \ - CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ - auto s32b00 = vreinterpretq_s32_f16(b00); \ - auto s32b01 = vreinterpretq_s32_f16(b01); \ - auto s32b10 = vreinterpretq_s32_f16(b10); \ - auto s32b11 = vreinterpretq_s32_f16(b11); \ - auto s32b20 = vreinterpretq_s32_f16(b20); \ - auto s32b21 = vreinterpretq_s32_f16(b21); \ - auto s32b30 = vreinterpretq_s32_f16(b30); \ - auto s32b31 = vreinterpretq_s32_f16(b31); \ - auto s64b00 = vreinterpretq_s64_s32( \ - vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1 a2b2c2d2*/ \ - auto s64b01 = vreinterpretq_s64_s32( \ - vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3 a4b4c4d4*/ \ - auto s64b10 = vreinterpretq_s64_s32( \ - vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5 a6b6c6d6*/ \ - auto s64b11 = vreinterpretq_s64_s32( \ - vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7 a8b8c8d8*/ \ - auto s64b20 = vreinterpretq_s64_s32( \ - vzip1q_s32(s32b20, s32b30)); /*e1f1g1h1 e2f2g2h2*/ \ - auto s64b21 = vreinterpretq_s64_s32( \ - vzip2q_s32(s32b20, s32b30)); /*e3f3g3h3 e4f4g4h4*/ \ - auto s64b30 = vreinterpretq_s64_s32( \ - vzip1q_s32(s32b21, s32b31)); /*e5f5g5h5 e6f6g6h6*/ \ - auto s64b31 = vreinterpretq_s64_s32( \ - vzip2q_s32(s32b21, s32b31)); /*e7f7g7h7 e8f8g8h8*/ \ - CONCAT(ret, 0).value = \ - vreinterpretq_f16_s64(vzip1q_s64(s64b00, s64b20)); \ - CONCAT(ret, 1).value = \ - vreinterpretq_f16_s64(vzip2q_s64(s64b00, s64b20)); \ - CONCAT(ret, 2).value = \ - vreinterpretq_f16_s64(vzip1q_s64(s64b01, s64b21)); \ - CONCAT(ret, 3).value = \ - vreinterpretq_f16_s64(vzip2q_s64(s64b01, s64b21)); \ - CONCAT(ret, 4).value = \ - vreinterpretq_f16_s64(vzip1q_s64(s64b10, s64b30)); \ - CONCAT(ret, 5).value = \ - vreinterpretq_f16_s64(vzip2q_s64(s64b10, s64b30)); \ - CONCAT(ret, 6).value = \ - vreinterpretq_f16_s64(vzip1q_s64(s64b11, s64b31)); \ - CONCAT(ret, 7).value = \ - vreinterpretq_f16_s64(vzip2q_s64(s64b11, s64b31)); \ +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b00 = vzip1q_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b01 = vzip2q_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ + auto b10 = vzip1q_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto b11 = vzip2q_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ + auto b20 = vzip1q_f16( \ + CONCAT(a, 4).value, CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ + auto b21 = vzip2q_f16( \ + CONCAT(a, 4).value, CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ + auto b30 = vzip1q_f16( \ + CONCAT(a, 6).value, CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ + auto b31 = vzip2q_f16( \ + CONCAT(a, 6).value, CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b00); \ + auto s32b01 = vreinterpretq_s32_f16(b01); \ + auto s32b10 = 
vreinterpretq_s32_f16(b10); \ + auto s32b11 = vreinterpretq_s32_f16(b11); \ + auto s32b20 = vreinterpretq_s32_f16(b20); \ + auto s32b21 = vreinterpretq_s32_f16(b21); \ + auto s32b30 = vreinterpretq_s32_f16(b30); \ + auto s32b31 = vreinterpretq_s32_f16(b31); \ + auto s64b00 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b00, s32b10)); /*a1b1c1d1 a2b2c2d2*/ \ + auto s64b01 = vreinterpretq_s64_s32( \ + vzip2q_s32(s32b00, s32b10)); /*a3b3c3d3 a4b4c4d4*/ \ + auto s64b10 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b01, s32b11)); /*a5b5c5d5 a6b6c6d6*/ \ + auto s64b11 = vreinterpretq_s64_s32( \ + vzip2q_s32(s32b01, s32b11)); /*a7b7c7d7 a8b8c8d8*/ \ + auto s64b20 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b20, s32b30)); /*e1f1g1h1 e2f2g2h2*/ \ + auto s64b21 = vreinterpretq_s64_s32( \ + vzip2q_s32(s32b20, s32b30)); /*e3f3g3h3 e4f4g4h4*/ \ + auto s64b30 = vreinterpretq_s64_s32( \ + vzip1q_s32(s32b21, s32b31)); /*e5f5g5h5 e6f6g6h6*/ \ + auto s64b31 = vreinterpretq_s64_s32( \ + vzip2q_s32(s32b21, s32b31)); /*e7f7g7h7 e8f8g8h8*/ \ + CONCAT(ret, 0).value = vreinterpretq_f16_s64(vzip1q_s64(s64b00, s64b20)); \ + CONCAT(ret, 1).value = vreinterpretq_f16_s64(vzip2q_s64(s64b00, s64b20)); \ + CONCAT(ret, 2).value = vreinterpretq_f16_s64(vzip1q_s64(s64b01, s64b21)); \ + CONCAT(ret, 3).value = vreinterpretq_f16_s64(vzip2q_s64(s64b01, s64b21)); \ + CONCAT(ret, 4).value = vreinterpretq_f16_s64(vzip1q_s64(s64b10, s64b30)); \ + CONCAT(ret, 5).value = vreinterpretq_f16_s64(vzip2q_s64(s64b10, s64b30)); \ + CONCAT(ret, 6).value = vreinterpretq_f16_s64(vzip1q_s64(s64b11, s64b31)); \ + CONCAT(ret, 7).value = vreinterpretq_f16_s64(vzip2q_s64(s64b11, s64b31)); \ } while (0); #else -#define TRANSPOSE_4x4(a, ret) \ - do { \ - auto b0_01 = vzip_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ - auto b1_01 = vzip_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ - auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ - auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ - auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \ - auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ - auto s32b00b10 = vzip_s32(s32b00, s32b10); /*a1b1c1d1 a2b2c2d2*/ \ - auto s32b01b11 = vzip_s32(s32b01, s32b11); /*a3b3c3d3 a4b4c4d4*/ \ - CONCAT(ret, 0).value = vreinterpret_f16_s32(s32b00b10.val[0]); \ - CONCAT(ret, 1).value = vreinterpret_f16_s32(s32b00b10.val[1]); \ - CONCAT(ret, 2).value = vreinterpret_f16_s32(s32b01b11.val[0]); \ - CONCAT(ret, 3).value = vreinterpret_f16_s32(s32b01b11.val[1]); \ +#define TRANSPOSE_4x4(a, ret) \ + do { \ + auto b0_01 = vzip_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b1_01 = vzip_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ + auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ + auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \ + auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ + auto s32b00b10 = vzip_s32(s32b00, s32b10); /*a1b1c1d1 a2b2c2d2*/ \ + auto s32b01b11 = vzip_s32(s32b01, s32b11); /*a3b3c3d3 a4b4c4d4*/ \ + CONCAT(ret, 0).value = vreinterpret_f16_s32(s32b00b10.val[0]); \ + CONCAT(ret, 1).value = vreinterpret_f16_s32(s32b00b10.val[1]); \ + CONCAT(ret, 2).value = vreinterpret_f16_s32(s32b01b11.val[0]); \ + CONCAT(ret, 3).value = vreinterpret_f16_s32(s32b01b11.val[1]); \ } while (0); -#define TRANSPOSE_4x8(a, ret) \ - do { \ - auto b0_01 = vzipq_f16( \ - CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4 a5b5a6b6a7b7a8b8*/ \ - auto 
b1_01 = vzipq_f16( \ - CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4 c5d6c6d6c7d7c8d8*/ \ - auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]); \ - auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]); \ - auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]); \ - auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]); \ - auto s32b00b10 = vzipq_s32( \ - s32b00, s32b10); /*a1b1c1d1a2b2c2d2 a3b3c3d3a4b4c4d4*/ \ - auto s32b01b11 = vzipq_s32( \ - s32b01, s32b11); /*a5b5c5d5a6b6c6d6 a7b7c7d7a8b8c8d8*/ \ - CONCAT(ret, 0).value = \ - vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[0])); \ - CONCAT(ret, 1).value = \ - vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[0])); \ - CONCAT(ret, 2).value = \ - vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[1])); \ - CONCAT(ret, 3).value = \ - vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[1])); \ - CONCAT(ret, 4).value = \ - vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[0])); \ - CONCAT(ret, 5).value = \ - vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[0])); \ - CONCAT(ret, 6).value = \ - vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[1])); \ - CONCAT(ret, 7).value = \ - vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[1])); \ +#define TRANSPOSE_4x8(a, ret) \ + do { \ + auto b0_01 = vzipq_f16( \ + CONCAT(a, 0).value, \ + CONCAT(a, 1).value); /*a1b1a2b2a3b3a4b4 a5b5a6b6a7b7a8b8*/ \ + auto b1_01 = vzipq_f16( \ + CONCAT(a, 2).value, \ + CONCAT(a, 3).value); /*c1d1c2d2c3d3c4d4 c5d6c6d6c7d7c8d8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b0_01.val[0]); \ + auto s32b01 = vreinterpretq_s32_f16(b0_01.val[1]); \ + auto s32b10 = vreinterpretq_s32_f16(b1_01.val[0]); \ + auto s32b11 = vreinterpretq_s32_f16(b1_01.val[1]); \ + auto s32b00b10 = \ + vzipq_s32(s32b00, s32b10); /*a1b1c1d1a2b2c2d2 a3b3c3d3a4b4c4d4*/ \ + auto s32b01b11 = \ + vzipq_s32(s32b01, s32b11); /*a5b5c5d5a6b6c6d6 a7b7c7d7a8b8c8d8*/ \ + CONCAT(ret, 0).value = vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[0])); \ + CONCAT(ret, 1).value = vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[0])); \ + CONCAT(ret, 2).value = vreinterpret_f16_s32(vget_low_f16(s32b00b10.val[1])); \ + CONCAT(ret, 3).value = vreinterpret_f16_s32(vget_high_f16(s32b00b10.val[1])); \ + CONCAT(ret, 4).value = vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[0])); \ + CONCAT(ret, 5).value = vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[0])); \ + CONCAT(ret, 6).value = vreinterpret_f16_s32(vget_low_f16(s32b01b11.val[1])); \ + CONCAT(ret, 7).value = vreinterpret_f16_s32(vget_high_f16(s32b01b11.val[1])); \ } while (0); -#define TRANSPOSE_8x4(a, ret) \ - do { \ - auto b0_01 = vzip_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ - auto b1_01 = vzip_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ - auto b2_01 = vzip_f16(CONCAT(a, 4).value, \ - CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ - auto b3_01 = vzip_f16(CONCAT(a, 6).value, \ - CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ - auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ - auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ - auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \ - auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ - auto s32b20 = vreinterpret_s32_f16(b2_01.val[0]); \ - auto s32b21 = vreinterpret_s32_f16(b2_01.val[1]); \ - auto s32b30 = vreinterpret_s32_f16(b3_01.val[0]); \ - auto s32b31 = vreinterpret_s32_f16(b3_01.val[1]); \ - auto s32b00b10 = vzip_s32(s32b00, s32b10); \ - auto s32b01b11 = vzip_s32(s32b01, s32b11); \ - auto s32b20b30 = vzip_s32(s32b20, s32b30); \ - auto s32b21b31 = vzip_s32(s32b21, 
s32b31); \ - CONCAT(ret, 0).value = \ - vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[0]), \ - vreinterpret_f16_s32(s32b20b30.val[0])); \ - CONCAT(ret, 1).value = \ - vcombine_f16(vreinterpret_f16_s32(s32b00b10.val[1]), \ - vreinterpret_f16_s32(s32b20b30.val[1])); \ - CONCAT(ret, 2).value = \ - vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[0]), \ - vreinterpret_f16_s32(s32b21b31.val[0])); \ - CONCAT(ret, 3).value = \ - vcombine_f16(vreinterpret_f16_s32(s32b01b11.val[1]), \ - vreinterpret_f16_s32(s32b21b31.val[1])); \ +#define TRANSPOSE_8x4(a, ret) \ + do { \ + auto b0_01 = vzip_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b1_01 = vzip_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto b2_01 = vzip_f16( \ + CONCAT(a, 4).value, CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ + auto b3_01 = vzip_f16( \ + CONCAT(a, 6).value, CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ + auto s32b00 = vreinterpret_s32_f16(b0_01.val[0]); \ + auto s32b01 = vreinterpret_s32_f16(b0_01.val[1]); \ + auto s32b10 = vreinterpret_s32_f16(b1_01.val[0]); \ + auto s32b11 = vreinterpret_s32_f16(b1_01.val[1]); \ + auto s32b20 = vreinterpret_s32_f16(b2_01.val[0]); \ + auto s32b21 = vreinterpret_s32_f16(b2_01.val[1]); \ + auto s32b30 = vreinterpret_s32_f16(b3_01.val[0]); \ + auto s32b31 = vreinterpret_s32_f16(b3_01.val[1]); \ + auto s32b00b10 = vzip_s32(s32b00, s32b10); \ + auto s32b01b11 = vzip_s32(s32b01, s32b11); \ + auto s32b20b30 = vzip_s32(s32b20, s32b30); \ + auto s32b21b31 = vzip_s32(s32b21, s32b31); \ + CONCAT(ret, 0).value = vcombine_f16( \ + vreinterpret_f16_s32(s32b00b10.val[0]), \ + vreinterpret_f16_s32(s32b20b30.val[0])); \ + CONCAT(ret, 1).value = vcombine_f16( \ + vreinterpret_f16_s32(s32b00b10.val[1]), \ + vreinterpret_f16_s32(s32b20b30.val[1])); \ + CONCAT(ret, 2).value = vcombine_f16( \ + vreinterpret_f16_s32(s32b01b11.val[0]), \ + vreinterpret_f16_s32(s32b21b31.val[0])); \ + CONCAT(ret, 3).value = vcombine_f16( \ + vreinterpret_f16_s32(s32b01b11.val[1]), \ + vreinterpret_f16_s32(s32b21b31.val[1])); \ } while (0); -#define TRANSPOSE_8x8(a, ret) \ - do { \ - auto b00 = vzipq_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ - auto b01 = vzipq_f16(CONCAT(a, 0).value, \ - CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ - auto b10 = vzipq_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ - auto b11 = vzipq_f16(CONCAT(a, 2).value, \ - CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ - auto b20 = vzipq_f16(CONCAT(a, 4).value, \ - CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ - auto b21 = vzipq_f16(CONCAT(a, 4).value, \ - CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ - auto b30 = vzipq_f16(CONCAT(a, 6).value, \ - CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ - auto b31 = vzipq_f16(CONCAT(a, 6).value, \ - CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ - auto s32b00 = vreinterpretq_s32_f16(b00.val[0]); \ - auto s32b01 = vreinterpretq_s32_f16(b01.val[1]); \ - auto s32b10 = vreinterpretq_s32_f16(b10.val[0]); \ - auto s32b11 = vreinterpretq_s32_f16(b11.val[1]); \ - auto s32b20 = vreinterpretq_s32_f16(b20.val[0]); \ - auto s32b21 = vreinterpretq_s32_f16(b21.val[1]); \ - auto s32b30 = vreinterpretq_s32_f16(b30.val[0]); \ - auto s32b31 = vreinterpretq_s32_f16(b31.val[1]); \ - auto s32b00b10 = vzipq_s32(s32b00, s32b10); \ - auto s32b01b11 = vzipq_s32(s32b01, s32b11); \ - auto s32b20b30 = vzipq_s32(s32b20, s32b30); \ - auto s32b21b31 = vzipq_s32(s32b21, s32b31); \ - CONCAT(ret, 0).value = vreinterpretq_f16_s32( \ - 
vcombine_s32(vget_low_s32(s32b00b10.val[0]), \ - vget_low_s32(s32b20b30.val[0]))); \ - CONCAT(ret, 1).value = vreinterpretq_f16_s32( \ - vcombine_s32(vget_high_s32(s32b00b10.val[0]), \ - vget_high_s32(s32b20b30.val[0]))); \ - CONCAT(ret, 2).value = vreinterpretq_f16_s32( \ - vcombine_s32(vget_low_s32(s32b00b10.val[1]), \ - vget_low_s32(s32b20b30.val[1]))); \ - CONCAT(ret, 3).value = vreinterpretq_f16_s32( \ - vcombine_s32(vget_high_s32(s32b00b10.val[1]), \ - vget_high_s32(s32b20b30.val[1]))); \ - CONCAT(ret, 4).value = vreinterpretq_f16_s32( \ - vcombine_s32(vget_low_s32(s32b01b11.val[0]), \ - vget_low_s32(s32b21b31.val[0]))); \ - CONCAT(ret, 5).value = vreinterpretq_f16_s32( \ - vcombine_s32(vget_high_s32(s32b01b11.val[0]), \ - vget_high_s32(s32b21b31.val[0]))); \ - CONCAT(ret, 6).value = vreinterpretq_f16_s32( \ - vcombine_s32(vget_low_s32(s32b01b11.val[1]), \ - vget_low_s32(s32b21b31.val[1]))); \ - CONCAT(ret, 7).value = vreinterpretq_f16_s32( \ - vcombine_s32(vget_high_s32(s32b01b11.val[1]), \ - vget_high_s32(s32b21b31.val[1]))); \ +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b00 = vzipq_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a1b1a2b2 a3b3a4b4*/ \ + auto b01 = vzipq_f16( \ + CONCAT(a, 0).value, CONCAT(a, 1).value); /*a5b5a6b6 a7b7a8b8*/ \ + auto b10 = vzipq_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c1d1c2d2 c3d3c4d4*/ \ + auto b11 = vzipq_f16( \ + CONCAT(a, 2).value, CONCAT(a, 3).value); /*c5d5c6d6 c7d7c8d8*/ \ + auto b20 = vzipq_f16( \ + CONCAT(a, 4).value, CONCAT(a, 5).value); /*e1f1e2f2 e3f3e4f4*/ \ + auto b21 = vzipq_f16( \ + CONCAT(a, 4).value, CONCAT(a, 5).value); /*e5f5e6f6 e7f7e8f8*/ \ + auto b30 = vzipq_f16( \ + CONCAT(a, 6).value, CONCAT(a, 7).value); /*g1h1g2h2 g3h3g4h4*/ \ + auto b31 = vzipq_f16( \ + CONCAT(a, 6).value, CONCAT(a, 7).value); /*g5h5g6h6 g7h7g8h8*/ \ + auto s32b00 = vreinterpretq_s32_f16(b00.val[0]); \ + auto s32b01 = vreinterpretq_s32_f16(b01.val[1]); \ + auto s32b10 = vreinterpretq_s32_f16(b10.val[0]); \ + auto s32b11 = vreinterpretq_s32_f16(b11.val[1]); \ + auto s32b20 = vreinterpretq_s32_f16(b20.val[0]); \ + auto s32b21 = vreinterpretq_s32_f16(b21.val[1]); \ + auto s32b30 = vreinterpretq_s32_f16(b30.val[0]); \ + auto s32b31 = vreinterpretq_s32_f16(b31.val[1]); \ + auto s32b00b10 = vzipq_s32(s32b00, s32b10); \ + auto s32b01b11 = vzipq_s32(s32b01, s32b11); \ + auto s32b20b30 = vzipq_s32(s32b20, s32b30); \ + auto s32b21b31 = vzipq_s32(s32b21, s32b31); \ + CONCAT(ret, 0).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_low_s32(s32b00b10.val[0]), vget_low_s32(s32b20b30.val[0]))); \ + CONCAT(ret, 1).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_high_s32(s32b00b10.val[0]), vget_high_s32(s32b20b30.val[0]))); \ + CONCAT(ret, 2).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_low_s32(s32b00b10.val[1]), vget_low_s32(s32b20b30.val[1]))); \ + CONCAT(ret, 3).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_high_s32(s32b00b10.val[1]), vget_high_s32(s32b20b30.val[1]))); \ + CONCAT(ret, 4).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_low_s32(s32b01b11.val[0]), vget_low_s32(s32b21b31.val[0]))); \ + CONCAT(ret, 5).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_high_s32(s32b01b11.val[0]), vget_high_s32(s32b21b31.val[0]))); \ + CONCAT(ret, 6).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_low_s32(s32b01b11.val[1]), vget_low_s32(s32b21b31.val[1]))); \ + CONCAT(ret, 7).value = vreinterpretq_f16_s32(vcombine_s32( \ + vget_high_s32(s32b01b11.val[1]), vget_high_s32(s32b21b31.val[1]))); \ } while (0); #endif diff 
--git a/dnn/src/arm_common/conv_bias/f16/strategy.h b/dnn/src/arm_common/conv_bias/f16/strategy.h index 4a9fd473..53f89b7a 100644 --- a/dnn/src/arm_common/conv_bias/f16/strategy.h +++ b/dnn/src/arm_common/conv_bias/f16/strategy.h @@ -18,14 +18,18 @@ namespace megdnn { namespace arm_common { namespace winograd { -MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2, - 3, 4, 4, winograd_2x3_4x4_f16) -MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 4, - 5, 1, 1, winograd_4x5_1x1_f16) -MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 6, - 3, 1, 1, winograd_6x3_1x1_f16) -MEGDNN_REG_WINOGRAD_STRATEGY(dt_float16, dt_float16, dt_float16, dt_float16, 2, - 3, 8, 8, winograd_2x3_8x8_f16) +MEGDNN_REG_WINOGRAD_STRATEGY( + dt_float16, dt_float16, dt_float16, dt_float16, 2, 3, 4, 4, + winograd_2x3_4x4_f16) +MEGDNN_REG_WINOGRAD_STRATEGY( + dt_float16, dt_float16, dt_float16, dt_float16, 4, 5, 1, 1, + winograd_4x5_1x1_f16) +MEGDNN_REG_WINOGRAD_STRATEGY( + dt_float16, dt_float16, dt_float16, dt_float16, 6, 3, 1, 1, + winograd_6x3_1x1_f16) +MEGDNN_REG_WINOGRAD_STRATEGY( + dt_float16, dt_float16, dt_float16, dt_float16, 2, 3, 8, 8, + winograd_2x3_8x8_f16) } // namespace winograd } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp index 8ec31629..2230adf1 100644 --- a/dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp +++ b/dnn/src/arm_common/conv_bias/f16/strategy_2x3.cpp @@ -46,17 +46,16 @@ void transpose_4x4(const __fp16* src, __fp16* dst, int lda, int ldb) { struct InputTransform2X3 { template - static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC) { + static void prepare( + const __fp16* input, __fp16* patch, __fp16* patchT, int ih_start, + int iw_start, size_t IH, size_t IW, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; constexpr size_t alpha4 = alpha * 4; if (!(inner && ic + 4 < IC)) { memset(patch, 0, sizeof(__fp16) * 4 * alpha * alpha); } if (inner) { - const __fp16* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const __fp16* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; for (size_t ico = 0; ico < 4; ++ico) { if (ic + ico < IC) { auto v0 = vld1_f16(input_ptr); @@ -96,14 +95,13 @@ struct InputTransform2X3 { transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4); } - static void transform(const __fp16* patchT, __fp16* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const __fp16* patchT, __fp16* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; // BT * d * B -#define cb(m, n) \ - Vector<__fp16, 4> d##m##n = \ - Vector<__fp16, 4>::load(patchT + m * 4 * 4 + n * 4); +#define cb(m, n) \ + Vector<__fp16, 4> d##m##n = Vector<__fp16, 4>::load(patchT + m * 4 * 4 + n * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -130,10 +128,10 @@ struct InputTransform2X3 { UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * nr_units_in_tile * IC + unit_idx * IC + \ - ic); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + (m * alpha + n) * nr_units_in_tile * IC + \ + unit_idx * IC + ic); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -141,20 +139,18 @@ struct 
InputTransform2X3 { template struct OutputTransform2X3 { - static void transform(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t oc_index, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& dst_dtype) { + static void transform( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); const __fp16* output_transform_ptr = reinterpret_cast(output_transform_buf); const __fp16* bias_ptr = reinterpret_cast(bias); __fp16* output_ptr = reinterpret_cast<__fp16*>(output); - __fp16* transform_mid_ptr = - reinterpret_cast<__fp16*>(transform_mid_buf); + __fp16* transform_mid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); //! AT * m * A constexpr size_t alpha = 2 + 3 - 1; @@ -225,20 +221,19 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f16) -void winograd_2x3_4x4_f16::filter(const dt_float16* filter, - dt_float16* filter_transform_buf, - dt_float16* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end) { +void winograd_2x3_4x4_f16::filter( + const dt_float16* filter, dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { constexpr int alpha = 2 + 3 - 1; //! G * g * GT - __fp16* filter_transbuf_ptr = - reinterpret_cast<__fp16*>(filter_transform_buf); + __fp16* filter_transbuf_ptr = reinterpret_cast<__fp16*>(filter_transform_buf); __fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); for (size_t oc = oc_start; oc < oc_end; oc++) { rep(ic, IC) { - const __fp16* filter_ptr = reinterpret_cast(filter) + - (oc * IC + ic) * 3 * 3; + const __fp16* filter_ptr = + reinterpret_cast(filter) + (oc * IC + ic) * 3 * 3; /** * origin: (4x3) * (3 x 3) * (3 x 4) * pack to G and g to times of 4 @@ -290,12 +285,10 @@ void winograd_2x3_4x4_f16::filter(const dt_float16* filter, } } -void winograd_2x3_4x4_f16::input(const dt_float16* input, - dt_float16* input_transform_buf, - dt_float16* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_2x3_4x4_f16::input( + const dt_float16* input, dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { megdnn_assert(IC % 4 == 0); constexpr int alpha = 3 + 2 - 1; @@ -316,36 +309,35 @@ void winograd_2x3_4x4_f16::input(const dt_float16* input, InputTransform2X3::prepare( reinterpret_cast(input), reinterpret_cast<__fp16*>(patch), - reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, - IH, IW, ic, IC); + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, IH, IW, + ic, IC); InputTransform2X3::transform( reinterpret_cast(patchT), - reinterpret_cast<__fp16*>(input_transform_buf), - unit_idx, nr_units_in_tile, ic, IC); + reinterpret_cast<__fp16*>(input_transform_buf), unit_idx, + nr_units_in_tile, ic, IC); } else { InputTransform2X3::prepare( reinterpret_cast(input), reinterpret_cast<__fp16*>(patch), - reinterpret_cast<__fp16*>(patchT), ih_start, 
iw_start, - IH, IW, ic, IC); + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, IH, IW, + ic, IC); InputTransform2X3::transform( reinterpret_cast(patchT), - reinterpret_cast<__fp16*>(input_transform_buf), - unit_idx, nr_units_in_tile, ic, IC); + reinterpret_cast<__fp16*>(input_transform_buf), unit_idx, + nr_units_in_tile, ic, IC); } } } } -void winograd_2x3_4x4_f16::output(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_2x3_4x4_f16::output( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); - + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); for (size_t oc = oc_start; oc < oc_end; oc += 4) { diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp index a4855b42..62a3642b 100644 --- a/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp +++ b/dnn/src/arm_common/conv_bias/f16/strategy_2x3_8x8.cpp @@ -13,16 +13,16 @@ #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" -#include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/conv_bias/f16/strategy.h" #include "src/arm_common/conv_bias/f16/helper.h" +#include "src/arm_common/conv_bias/f16/strategy.h" +#include "src/arm_common/elemwise_helper/op_unary.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" -#include "src/common/winograd/winograd_generator.h" +#include "midout.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" -#include "midout.h" +#include "src/common/winograd/winograd_generator.h" #include "src/common/winograd/winograd_helper.h" @@ -63,16 +63,15 @@ void transpose_8x4(const __fp16* src, __fp16* dst, int lda, int ldb) { struct InputTransform2X3_8x8 { template - static void prepare(const __fp16* input, __fp16* patch, __fp16* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC) { + static void prepare( + const __fp16* input, __fp16* patch, __fp16* patchT, int ih_start, + int iw_start, size_t IH, size_t IW, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; if (!(inner && ic + 8 < IC)) { memset(patch, 0, sizeof(__fp16) * 8 * alpha * alpha); } if (inner) { - const __fp16* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const __fp16* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; for (size_t ico = 0; ico < 8; ++ico) { if (ic + ico < IC) { auto v0 = vld1_f16(input_ptr); @@ -112,14 +111,13 @@ struct InputTransform2X3_8x8 { transpose_8x4(patch + 4 * 3, patchT + 32 * 3, 16, 4); } - static void transform(const __fp16* patchT, __fp16* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const __fp16* patchT, __fp16* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; // BT * d * B -#define cb(m, n) \ - Vector<__fp16, 8> d##m##n = \ - Vector<__fp16, 8>::load(patchT + 8 * (m * 4 + 
n)); +#define cb(m, n) \ + Vector<__fp16, 8> d##m##n = Vector<__fp16, 8>::load(patchT + 8 * (m * 4 + n)); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -148,10 +146,10 @@ struct InputTransform2X3_8x8 { size_t ICB = IC / 8; size_t icb = ic / 8; -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * nr_units_in_tile * ICB * 8 + \ - icb * nr_units_in_tile * 8 + unit_idx * 8); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + (m * alpha + n) * nr_units_in_tile * ICB * 8 + \ + icb * nr_units_in_tile * 8 + unit_idx * 8); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -159,20 +157,18 @@ struct InputTransform2X3_8x8 { template struct OutputTransform2X3_8x8 { - static void transform(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t oc_index, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& dst_dtype) { + static void transform( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); const __fp16* output_transform_ptr = reinterpret_cast(output_transform_buf); const __fp16* bias_ptr = reinterpret_cast(bias); __fp16* output_ptr = reinterpret_cast<__fp16*>(output); - __fp16* transform_mid_ptr = - reinterpret_cast<__fp16*>(transform_mid_buf); + __fp16* transform_mid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); //! AT * m * A constexpr size_t alpha = 2 + 3 - 1; @@ -180,11 +176,10 @@ struct OutputTransform2X3_8x8 { size_t oc = oc_start + oc_index; size_t OCB = (oc_end - oc_start) / 8; size_t ocb = oc_index / 8; - -#define cb(m, n) \ - auto v##m##n = Vector<__fp16, 8>::load( \ - output_transform_ptr + \ - (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ + +#define cb(m, n) \ + auto v##m##n = Vector<__fp16, 8>::load( \ + output_transform_ptr + (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ ocb * nr_units_in_tile * 8 + unit_idx * 8); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -249,22 +244,21 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_f16) -void winograd_2x3_8x8_f16::filter(const dt_float16* filter, - dt_float16* filter_transform_buf, - dt_float16* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end) { +void winograd_2x3_8x8_f16::filter( + const dt_float16* filter, dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { constexpr int alpha = 2 + 3 - 1; //! 
G * g * GT - __fp16* filter_transbuf_ptr = - reinterpret_cast<__fp16*>(filter_transform_buf); + __fp16* filter_transbuf_ptr = reinterpret_cast<__fp16*>(filter_transform_buf); __fp16* filter_transmid_ptr = reinterpret_cast<__fp16*>(transform_mid_buf); size_t OCB = OC / 8; size_t ICB = IC / 8; for (size_t oc = oc_start; oc < oc_end; oc++) { rep(ic, IC) { - const __fp16* filter_ptr = reinterpret_cast(filter) + - (oc * IC + ic) * 3 * 3; + const __fp16* filter_ptr = + reinterpret_cast(filter) + (oc * IC + ic) * 3 * 3; /** * origin: (4x3) * (3 x 3) * (3 x 4) * pack to G and g to times of 4 @@ -313,19 +307,19 @@ void winograd_2x3_8x8_f16::filter(const dt_float16* filter, size_t icb = (ic) / 8; size_t ic8 = (ic) % 8; rep(i, alpha) rep(j, alpha) { - filter_transbuf_ptr[(i * alpha + j) * OCB * ICB * 8 * 8 + - ocb * ICB * 8 * 8 + icb * 8 * 8 + ic8 * 8 + - oc8] = filter_transmid_ptr[i * alpha + j]; + filter_transbuf_ptr + [(i * alpha + j) * OCB * ICB * 8 * 8 + ocb * ICB * 8 * 8 + + icb * 8 * 8 + ic8 * 8 + oc8] = + filter_transmid_ptr[i * alpha + j]; } } } } -void winograd_2x3_8x8_f16::input(const dt_float16* input, - dt_float16* input_transform_buf, - dt_float16* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_2x3_8x8_f16::input( + const dt_float16* input, dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { megdnn_assert(IC % 8 == 0); constexpr int alpha = 3 + 2 - 1; @@ -346,38 +340,35 @@ void winograd_2x3_8x8_f16::input(const dt_float16* input, InputTransform2X3_8x8::prepare( reinterpret_cast(input), reinterpret_cast<__fp16*>(patch), - reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, - IH, IW, ic, IC); + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, IH, IW, + ic, IC); InputTransform2X3_8x8::transform( reinterpret_cast(patchT), - reinterpret_cast<__fp16*>(input_transform_buf), - unit_idx, nr_units_in_tile, ic, IC); + reinterpret_cast<__fp16*>(input_transform_buf), unit_idx, + nr_units_in_tile, ic, IC); } else { InputTransform2X3_8x8::prepare( reinterpret_cast(input), reinterpret_cast<__fp16*>(patch), - reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, - IH, IW, ic, IC); + reinterpret_cast<__fp16*>(patchT), ih_start, iw_start, IH, IW, + ic, IC); InputTransform2X3_8x8::transform( reinterpret_cast(patchT), - reinterpret_cast<__fp16*>(input_transform_buf), - unit_idx, nr_units_in_tile, ic, IC); + reinterpret_cast<__fp16*>(input_transform_buf), unit_idx, + nr_units_in_tile, ic, IC); } } } } -void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t unit_start_idx, - size_t nr_units_in_tile) { -#define cb(_bmode, _nonline_op, ...) \ - OutputTransform2X3_8x8<_bmode MEGDNN_COMMA _nonline_op>::transform( \ - __VA_ARGS__); +void winograd_2x3_8x8_f16::output( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + OutputTransform2X3_8x8<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); @@ -390,10 +381,10 @@ void winograd_2x3_8x8_f16::output(const dt_float16* output_transform_buf, size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_f16_F23_8x8, cb, __fp16, __fp16, - bmode, nonline_mode, output_transform_buf, bias, output, - transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, - oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, dst_dtype); + megdnn_arm_common_winograd_f16_F23_8x8, cb, __fp16, __fp16, bmode, + nonline_mode, output_transform_buf, bias, output, transform_mid_buf, + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, + nr_units_in_tile, src_dtype, dst_dtype); } } #undef cb diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp index 01cecbdd..bcda5712 100644 --- a/dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp +++ b/dnn/src/arm_common/conv_bias/f16/strategy_4x5.cpp @@ -86,9 +86,10 @@ struct FilterTransform4X5 { wd##6 = tmp0 - tmp1; \ wd##7 = d##4; \ } while (0); - static void transform(const __fp16* filter, __fp16* filter_transform_buf, - __fp16* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { + static void transform( + const __fp16* filter, __fp16* filter_transform_buf, + __fp16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { // Gg * GT // G //[[ 1. 0. 0. 0. 0. ] @@ -129,8 +130,7 @@ struct FilterTransform4X5 { #undef cb FILTER_TRANSFORM(g, Gg) #if MEGDNN_AARCH64 - float16x8_t vgr = {Ggr0, Ggr1, Ggr2, Ggr3, - Ggr4, Ggr5, Ggr6, Ggr7}; + float16x8_t vgr = {Ggr0, Ggr1, Ggr2, Ggr3, Ggr4, Ggr5, Ggr6, Ggr7}; Vector<__fp16, 8> Ggt4(vgr); TRANSPOSE_8x4(Gg, Ggt); FILTER_TRANSFORM_FINAL(Ggt, result); @@ -138,42 +138,42 @@ struct FilterTransform4X5 { UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + - oc] = transform_mid_buf[j * alpha + i]; + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[j * alpha + i]; } #else #define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx) -#define cb(i) \ - do { \ - mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ - auto tmp024 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ - GET_VECTOR_FP16D_ELEM(Gg, i, 2) + Ggr##i; \ - auto tmp13 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) + \ - GET_VECTOR_FP16D_ELEM(Gg, i, 3); \ - mid_buf1[1] = (tmp024 + tmp13) * -0.2222222; \ - mid_buf1[2] = (tmp024 - tmp13) * -0.2222222; \ - auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ - auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ - auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ - auto tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ - auto tmp4 = Ggr##i * 0.0444444; \ - tmp024 = tmp0 + tmp2 + tmp4; \ - tmp13 = tmp1 + tmp3; \ - mid_buf1[3] = tmp024 + tmp13; \ - mid_buf1[4] = tmp024 - tmp13; \ - tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ - tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ - tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ - tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ - tmp4 = Ggr##i * 0.1777778; \ - tmp024 = tmp0 + tmp2 + tmp4; \ - tmp13 = tmp1 + tmp3; \ - mid_buf1[5] = tmp024 + tmp13; \ - mid_buf1[6] = tmp024 - tmp13; \ - mid_buf1[7] = Ggr##i; \ - mid_buf1 += 8; \ +#define cb(i) \ + do { \ + 
mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ + auto tmp024 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ + GET_VECTOR_FP16D_ELEM(Gg, i, 2) + Ggr##i; \ + auto tmp13 = \ + GET_VECTOR_FP16D_ELEM(Gg, i, 1) + GET_VECTOR_FP16D_ELEM(Gg, i, 3); \ + mid_buf1[1] = (tmp024 + tmp13) * -0.2222222; \ + mid_buf1[2] = (tmp024 - tmp13) * -0.2222222; \ + auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ + auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ + auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ + auto tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ + auto tmp4 = Ggr##i * 0.0444444; \ + tmp024 = tmp0 + tmp2 + tmp4; \ + tmp13 = tmp1 + tmp3; \ + mid_buf1[3] = tmp024 + tmp13; \ + mid_buf1[4] = tmp024 - tmp13; \ + tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ + tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ + tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ + tmp3 = GET_VECTOR_FP16D_ELEM(Gg, i, 3) * 0.0888889; \ + tmp4 = Ggr##i * 0.1777778; \ + tmp024 = tmp0 + tmp2 + tmp4; \ + tmp13 = tmp1 + tmp3; \ + mid_buf1[5] = tmp024 + tmp13; \ + mid_buf1[6] = tmp024 - tmp13; \ + mid_buf1[7] = Ggr##i; \ + mid_buf1 += 8; \ } while (0); __fp16* mid_buf1 = transform_mid_buf; UNROLL_CALL_NOWRAPPER(8, cb); @@ -181,8 +181,8 @@ struct FilterTransform4X5 { #undef cb #undef GET_VECTOR_FP16D_ELEM rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + - oc] = transform_mid_buf[i * alpha + j]; + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[i * alpha + j]; } #endif } @@ -214,10 +214,10 @@ struct InputTransform4X5 { #define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) template - static void transform(const __fp16* input, __fp16* input_transform_buf, - __fp16* transform_mid_buf, int ih_start, int iw_start, - size_t ic, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { + static void transform( + const __fp16* input, __fp16* input_transform_buf, __fp16* transform_mid_buf, + int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { // BTd * B //([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], // [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. 
], @@ -238,8 +238,7 @@ struct InputTransform4X5 { #undef cb if (inner) { - const __fp16* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const __fp16* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; #define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -273,33 +272,30 @@ struct InputTransform4X5 { #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[j * alpha + i]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; } } }; #undef INPUT_TRANSFORM -#define OUTPUT_TRANSFORM(m, s) \ - do { \ - s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6; \ - s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 - \ - m##6 * 2.0; \ - s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 + \ - m##6 * 4.0; \ - s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 - \ - m##6 * 8.0 + m##7; \ +#define OUTPUT_TRANSFORM(m, s) \ + do { \ + s##0 = m##0 + m##1 + m##2 + m##3 + m##4 + m##5 + m##6; \ + s##1 = m##1 - m##2 + m##3 * 0.5 - m##4 * 0.5 + m##5 * 2.0 - m##6 * 2.0; \ + s##2 = m##1 + m##2 + m##3 * 0.25 + m##4 * 0.25 + m##5 * 4.0 + m##6 * 4.0; \ + s##3 = m##1 - m##2 + m##3 * 0.125 - m##4 * 0.125 + m##5 * 8.0 - m##6 * 8.0 + \ + m##7; \ } while (0) template struct OutputTransform4X5 { - static void transform(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t oc_index, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& dst_dtype) { + static void transform( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); //! 
AT * m * A // AT f45 @@ -312,24 +308,20 @@ struct OutputTransform4X5 { reinterpret_cast(output_transform_buf); const __fp16* fp16_bias = reinterpret_cast(bias); __fp16* fp16_output = reinterpret_cast<__fp16*>(output); - __fp16* fp16_transform_mid_buf = - reinterpret_cast<__fp16*>(transform_mid_buf); + __fp16* fp16_transform_mid_buf = reinterpret_cast<__fp16*>(transform_mid_buf); __fp16* mid_buf1 = fp16_transform_mid_buf; size_t OC = oc_end - oc_start; size_t oc = oc_start + oc_index; -#define cb(m, n) \ - fp16_transform_mid_buf[m * alpha + n] = \ - fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \ - OC + \ - unit_idx * OC + oc_index]; +#define cb(m, n) \ + fp16_transform_mid_buf[m * alpha + n] = fp16_output_transform_buf \ + [(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index]; UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb -#define cb(i) \ - auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); +#define cb(i) auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb #define cb(i) Vector<__fp16, 8> s##i; @@ -406,23 +398,20 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f16) -void winograd_4x5_1x1_f16::filter(const dt_float16* filter, - dt_float16* filter_transform_buf, - dt_float16* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end) { +void winograd_4x5_1x1_f16::filter( + const dt_float16* filter, dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { FilterTransform4X5::transform( reinterpret_cast(filter), reinterpret_cast<__fp16*>(filter_transform_buf), - reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, - oc_end); + reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, oc_end); } -void winograd_4x5_1x1_f16::input(const dt_float16* input, - dt_float16* input_transform_buf, - dt_float16* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_4x5_1x1_f16::input( + const dt_float16* input, dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { constexpr int alpha = 4 + 5 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); @@ -452,12 +441,11 @@ void winograd_4x5_1x1_f16::input(const dt_float16* input, } } -void winograd_4x5_1x1_f16::output(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_4x5_1x1_f16::output( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) 
\ OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); diff --git a/dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp b/dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp index 68362eea..2a3258cc 100644 --- a/dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp +++ b/dnn/src/arm_common/conv_bias/f16/strategy_6x3.cpp @@ -66,9 +66,10 @@ struct FilterTransform6X3 { wd##6 = tmp0 - tmp1; \ wd##7 = d##2; \ } while (0); - static void transform(const __fp16* filter, __fp16* filter_transform_buf, - __fp16* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { + static void transform( + const __fp16* filter, __fp16* filter_transform_buf, + __fp16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { // Gg * GT // G // 1.0000000 0.0000000 0.0000000 @@ -115,8 +116,8 @@ struct FilterTransform6X3 { UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + - oc] = transform_mid_buf[j * alpha + i]; + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[j * alpha + i]; } #else /* 1.0000000 -0.2222222 -0.2222222 0.0111111 0.0111111 @@ -126,35 +127,35 @@ struct FilterTransform6X3 { 0.0444444 0.1777778 0.1777778 1.0000000*/ #define GET_VECTOR_FP16D_ELEM(s, i, idx) vget_lane_f16(CONCAT(s, i).value, idx) -#define cb(i) \ - do { \ - mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ - auto tmp02 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) + \ - GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ - mid_buf1[1] = (tmp02 + GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ - mid_buf1[2] = (tmp02 - GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ - auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ - auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ - auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ - tmp02 = tmp0 + tmp2; \ - mid_buf1[3] = tmp02 + tmp1; \ - mid_buf1[4] = tmp02 - tmp1; \ - tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ - tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ - tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ - tmp02 = tmp0 + tmp2; \ - mid_buf1[5] = tmp02 + tmp1; \ - mid_buf1[6] = tmp02 - tmp1; \ - mid_buf1[7] = GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ - mid_buf1 += 8; \ +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_FP16D_ELEM(Gg, i, 0); \ + auto tmp02 = \ + GET_VECTOR_FP16D_ELEM(Gg, i, 0) + GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ + mid_buf1[1] = (tmp02 + GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ + mid_buf1[2] = (tmp02 - GET_VECTOR_FP16D_ELEM(Gg, i, 1)) * -0.2222222; \ + auto tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.0111111; \ + auto tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.0222222; \ + auto tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.0444444; \ + tmp02 = tmp0 + tmp2; \ + mid_buf1[3] = tmp02 + tmp1; \ + mid_buf1[4] = tmp02 - tmp1; \ + tmp0 = GET_VECTOR_FP16D_ELEM(Gg, i, 0) * 0.7111111; \ + tmp1 = GET_VECTOR_FP16D_ELEM(Gg, i, 1) * 0.3555556; \ + tmp2 = GET_VECTOR_FP16D_ELEM(Gg, i, 2) * 0.1777778; \ + tmp02 = tmp0 + tmp2; \ + mid_buf1[5] = tmp02 + tmp1; \ + mid_buf1[6] = tmp02 - tmp1; \ + mid_buf1[7] = GET_VECTOR_FP16D_ELEM(Gg, i, 2); \ + mid_buf1 += 8; \ } while (0); __fp16* mid_buf1 = transform_mid_buf; UNROLL_CALL_NOWRAPPER(8, cb); mid_buf1 = transform_mid_buf; #undef cb rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + - oc] = transform_mid_buf[i * alpha + j]; + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[i * alpha + j]; 
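(Reading aid, not part of the patch: the FilterTransform6X3 hunk above computes the Winograd F(6,3) filter transform g' = G * g * GT with hard-coded fractions such as -0.2222222 and 0.0111111. The scalar C++ sketch below reproduces those same coefficients and the filter_transform_buf layout used above, purely as a cross-check; the helper name ref_filter_transform_f63 and the standalone G table are illustrative assumptions, not MegDNN API.)

#include <cstddef>

static const float G_F63[8][3] = {
        {1.f, 0.f, 0.f},
        {-2.f / 9, -2.f / 9, -2.f / 9},
        {-2.f / 9, 2.f / 9, -2.f / 9},
        {1.f / 90, 1.f / 45, 2.f / 45},
        {1.f / 90, -1.f / 45, 2.f / 45},
        {32.f / 45, 16.f / 45, 8.f / 45},
        {32.f / 45, -16.f / 45, 8.f / 45},
        {0.f, 0.f, 1.f}};

//! scalar reference: t = G * g (8x3), then out(i, j) = dot(row_i(t), row_j(G)),
//! stored as filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc]
static void ref_filter_transform_f63(
        const float* g /* 3x3 row major */, float* filter_transform_buf, size_t OC,
        size_t IC, size_t oc, size_t ic) {
    constexpr int alpha = 6 + 3 - 1;
    float t[8][3];
    for (int i = 0; i < 8; ++i)
        for (int k = 0; k < 3; ++k)
            t[i][k] = G_F63[i][0] * g[0 * 3 + k] + G_F63[i][1] * g[1 * 3 + k] +
                      G_F63[i][2] * g[2 * 3 + k];
    for (int i = 0; i < alpha; ++i)
        for (int j = 0; j < alpha; ++j) {
            float v = t[i][0] * G_F63[j][0] + t[i][1] * G_F63[j][1] +
                      t[i][2] * G_F63[j][2];
            filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = v;
        }
}

The patch itself only re-wraps these NEON lines to the new column limit; the arithmetic and the buffer layout are untouched.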
} #undef GET_VECTOR_FP16D_ELEM #endif @@ -197,10 +198,10 @@ struct FilterTransform6X3 { #define GET_VECTOR_FP16Q_ELEM(s, i, idx) vgetq_lane_f16(CONCAT(s, i).value, idx) struct InputTransform6x3 { template - static void transform(const __fp16* input, __fp16* input_transform_buf, - __fp16* transform_mid_buf, int ih_start, int iw_start, - size_t ic, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { + static void transform( + const __fp16* input, __fp16* input_transform_buf, __fp16* transform_mid_buf, + int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { // BTd * B // 1.000 0.000 -5.25 0.000 5.250 0.000 -1.0 0.00 // -0.00 1.000 1.000 -4.25 -4.25 1.000 1.00 -0.0 @@ -220,8 +221,7 @@ struct InputTransform6x3 { #undef cb if (inner) { - const __fp16* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const __fp16* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; #define cb(i) d##i = Vector<__fp16, 8>::load(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -257,9 +257,9 @@ struct InputTransform6x3 { #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[j * alpha + i]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; } #else //! 1 0 0 0 0 0 0 0 @@ -313,9 +313,9 @@ struct InputTransform6x3 { #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[i * alpha + j]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; } #endif } @@ -340,13 +340,12 @@ struct InputTransform6x3 { } while (0) template struct OutputTransform6X3 { - static void transform(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t oc_index, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& dst_dtype) { + static void transform( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, size_t oh_start, + size_t ow_start, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, + const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); //! 
AT * m * A // AT f45 @@ -361,24 +360,20 @@ struct OutputTransform6X3 { reinterpret_cast(output_transform_buf); const __fp16* fp16_bias = reinterpret_cast(bias); __fp16* fp16_output = reinterpret_cast<__fp16*>(output); - __fp16* fp16_transform_mid_buf = - reinterpret_cast<__fp16*>(transform_mid_buf); + __fp16* fp16_transform_mid_buf = reinterpret_cast<__fp16*>(transform_mid_buf); __fp16* mid_buf1 = fp16_transform_mid_buf; size_t OC = oc_end - oc_start; size_t oc = oc_start + oc_index; -#define cb(m, n) \ - fp16_transform_mid_buf[m * alpha + n] = \ - fp16_output_transform_buf[(m * alpha + n) * nr_units_in_tile * \ - OC + \ - unit_idx * OC + oc_index]; +#define cb(m, n) \ + fp16_transform_mid_buf[m * alpha + n] = fp16_output_transform_buf \ + [(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index]; UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb -#define cb(i) \ - auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); +#define cb(i) auto m##i = Vector<__fp16, 8>::load(fp16_transform_mid_buf + alpha * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb #define cb(i) Vector<__fp16, 8> s##i; @@ -396,29 +391,28 @@ struct OutputTransform6X3 { OUTPUT_TRANSFORM(m, s); mid_buf1 = fp16_transform_mid_buf; -#define cb(i) \ - do { \ - auto m1addm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) + \ - GET_VECTOR_FP16Q_ELEM(s, i, 2); \ - auto m1subm2 = GET_VECTOR_FP16Q_ELEM(s, i, 1) - \ - GET_VECTOR_FP16Q_ELEM(s, i, 2); \ - auto m3addm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) + \ - GET_VECTOR_FP16Q_ELEM(s, i, 4); \ - auto m3subm4 = GET_VECTOR_FP16Q_ELEM(s, i, 3) - \ - GET_VECTOR_FP16Q_ELEM(s, i, 4); \ - auto m5addm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) + \ - GET_VECTOR_FP16Q_ELEM(s, i, 6); \ - auto m5subm6 = GET_VECTOR_FP16Q_ELEM(s, i, 5) - \ - GET_VECTOR_FP16Q_ELEM(s, i, 6); \ - mid_buf1[0] = \ - GET_VECTOR_FP16Q_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ - mid_buf1[1] = m1subm2 + m3subm4 * 2 + m5subm6 * 0.5; \ - mid_buf1[2] = m1addm2 + m3addm4 * 4 + m5addm6 * 0.25; \ - mid_buf1[3] = m1subm2 + m3subm4 * 8 + m5subm6 * 0.125; \ - mid_buf1[4] = m1addm2 + m3addm4 * 16 + m5addm6 * 0.0625; \ - mid_buf1[5] = m1subm2 + m3subm4 * 32 + m5subm6 * 0.03125 + \ - GET_VECTOR_FP16Q_ELEM(s, i, 7); \ - mid_buf1 += 6; \ +#define cb(i) \ + do { \ + auto m1addm2 = \ + GET_VECTOR_FP16Q_ELEM(s, i, 1) + GET_VECTOR_FP16Q_ELEM(s, i, 2); \ + auto m1subm2 = \ + GET_VECTOR_FP16Q_ELEM(s, i, 1) - GET_VECTOR_FP16Q_ELEM(s, i, 2); \ + auto m3addm4 = \ + GET_VECTOR_FP16Q_ELEM(s, i, 3) + GET_VECTOR_FP16Q_ELEM(s, i, 4); \ + auto m3subm4 = \ + GET_VECTOR_FP16Q_ELEM(s, i, 3) - GET_VECTOR_FP16Q_ELEM(s, i, 4); \ + auto m5addm6 = \ + GET_VECTOR_FP16Q_ELEM(s, i, 5) + GET_VECTOR_FP16Q_ELEM(s, i, 6); \ + auto m5subm6 = \ + GET_VECTOR_FP16Q_ELEM(s, i, 5) - GET_VECTOR_FP16Q_ELEM(s, i, 6); \ + mid_buf1[0] = GET_VECTOR_FP16Q_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ + mid_buf1[1] = m1subm2 + m3subm4 * 2 + m5subm6 * 0.5; \ + mid_buf1[2] = m1addm2 + m3addm4 * 4 + m5addm6 * 0.25; \ + mid_buf1[3] = m1subm2 + m3subm4 * 8 + m5subm6 * 0.125; \ + mid_buf1[4] = m1addm2 + m3addm4 * 16 + m5addm6 * 0.0625; \ + mid_buf1[5] = m1subm2 + m3subm4 * 32 + m5subm6 * 0.03125 + \ + GET_VECTOR_FP16Q_ELEM(s, i, 7); \ + mid_buf1 += 6; \ } while (0); mid_buf1 = fp16_transform_mid_buf; UNROLL_CALL_NOWRAPPER(6, cb); @@ -436,8 +430,8 @@ struct OutputTransform6X3 { float16x8_t vr0123_45 = {mid_buf1[4], mid_buf1[5], mid_buf1[10], mid_buf1[11], mid_buf1[16], mid_buf1[17], mid_buf1[22], mid_buf1[23]}; - float16x4_t vr45_45 = {mid_buf1[28], mid_buf1[29], mid_buf1[34], - 
mid_buf1[35]}; + float16x4_t vr45_45 = { + mid_buf1[28], mid_buf1[29], mid_buf1[34], mid_buf1[35]}; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { float16x4_t bias0 = vdup_n_f16(fp16_bias[oc]); @@ -458,18 +452,14 @@ struct OutputTransform6X3 { UNROLL_CALL_NOWRAPPER(6, cb); #undef cb - float16x8_t vb0123_45 = {fp16_bias[index + 0 * OW + 4], - fp16_bias[index + 0 * OW + 5], - fp16_bias[index + 1 * OW + 4], - fp16_bias[index + 1 * OW + 5], - fp16_bias[index + 2 * OW + 4], - fp16_bias[index + 2 * OW + 5], - fp16_bias[index + 3 * OW + 4], - fp16_bias[index + 3 * OW + 5]}; - float16x4_t vb45_45 = {fp16_bias[index + 4 * OW + 4], - fp16_bias[index + 4 * OW + 5], - fp16_bias[index + 5 * OW + 4], - fp16_bias[index + 5 * OW + 5]}; + float16x8_t vb0123_45 = { + fp16_bias[index + 0 * OW + 4], fp16_bias[index + 0 * OW + 5], + fp16_bias[index + 1 * OW + 4], fp16_bias[index + 1 * OW + 5], + fp16_bias[index + 2 * OW + 4], fp16_bias[index + 2 * OW + 5], + fp16_bias[index + 3 * OW + 4], fp16_bias[index + 3 * OW + 5]}; + float16x4_t vb45_45 = { + fp16_bias[index + 4 * OW + 4], fp16_bias[index + 4 * OW + 5], + fp16_bias[index + 5 * OW + 4], fp16_bias[index + 5 * OW + 5]}; vr45_45 = vadd_f16(vr45_45, vb45_45); vr0123_45 = vaddq_f16(vr0123_45, vb0123_45); } @@ -527,23 +517,20 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f16) -void winograd_6x3_1x1_f16::filter(const dt_float16* filter, - dt_float16* filter_transform_buf, - dt_float16* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end) { +void winograd_6x3_1x1_f16::filter( + const dt_float16* filter, dt_float16* filter_transform_buf, + dt_float16* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { FilterTransform6X3::transform( reinterpret_cast(filter), reinterpret_cast<__fp16*>(filter_transform_buf), - reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, - oc_end); + reinterpret_cast<__fp16*>(transform_mid_buf), OC, IC, oc_start, oc_end); } -void winograd_6x3_1x1_f16::input(const dt_float16* input, - dt_float16* input_transform_buf, - dt_float16* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_6x3_1x1_f16::input( + const dt_float16* input, dt_float16* input_transform_buf, + dt_float16* transform_mid_buf, size_t IH, size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, size_t nr_units_in_tile) { constexpr int alpha = 6 + 3 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); @@ -573,12 +560,11 @@ void winograd_6x3_1x1_f16::input(const dt_float16* input, } } -void winograd_6x3_1x1_f16::output(const dt_float16* output_transform_buf, - const dt_float16* bias, dt_float16* output, - dt_float16* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_6x3_1x1_f16::output( + const dt_float16* output_transform_buf, const dt_float16* bias, + dt_float16* output, dt_float16* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, size_t OW, size_t oc_start, size_t oc_end, + size_t unit_start_idx, size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) 
\ OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.cpp b/dnn/src/arm_common/conv_bias/fp32/algos.cpp index 5edcb53f..ce886905 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/algos.cpp @@ -41,8 +41,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && @@ -53,8 +52,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32; @@ -63,10 +61,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF23_4x4, - winograd::winograd_2x3_4x4_f, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::MK4); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF23_4x4, winograd::winograd_2x3_4x4_f, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); /* ======================= AlgoFP32WinogradF63 ======================== */ @@ -77,9 +74,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 1, 0) { using Strategy = winograd::winograd_6x3_1x1_f; Strategy strategy(param.src_type, param.filter_type, param.dst_type); - auto&& matmul_param = megdnn::winograd::ConvBias( - strategy, m_tile_size, param) - .get_matmul_kern_param(param); + auto&& matmul_param = + megdnn::winograd::ConvBias(strategy, m_tile_size, param) + .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && param.filter_meta.format == param::ConvBias::Format::NCHW && !param.filter_meta.should_flip && @@ -87,8 +84,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32; @@ -97,10 +93,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF63::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63, - winograd::winograd_6x3_1x1_f, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::DEFAULT); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF63, winograd::winograd_6x3_1x1_f, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP32WinogradF54 ======================== */ @@ -111,9 +106,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 2, 0) { using Strategy = winograd::winograd_5x4_1x1_f; Strategy strategy(param.src_type, param.filter_type, param.dst_type); - auto&& matmul_param = 
megdnn::winograd::ConvBias( - strategy, m_tile_size, param) - .get_matmul_kern_param(param); + auto&& matmul_param = + megdnn::winograd::ConvBias(strategy, m_tile_size, param) + .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && param.filter_meta.format == param::ConvBias::Format::NCHW && !param.filter_meta.should_flip && @@ -121,8 +116,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( param.filter_meta.spatial[0] == 4) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32; @@ -131,10 +125,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF54::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF54, - winograd::winograd_5x4_1x1_f, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::DEFAULT); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF54, winograd::winograd_5x4_1x1_f, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP32WinogradF45 ======================== */ @@ -145,9 +138,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, 3, 0) { using Strategy = winograd::winograd_4x5_1x1_f; Strategy strategy(param.src_type, param.filter_type, param.dst_type); - auto&& matmul_param = megdnn::winograd::ConvBias( - strategy, m_tile_size, param) - .get_matmul_kern_param(param); + auto&& matmul_param = + megdnn::winograd::ConvBias(strategy, m_tile_size, param) + .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && param.filter_meta.format == param::ConvBias::Format::NCHW && !param.filter_meta.should_flip && @@ -155,8 +148,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( param.filter_meta.spatial[0] == 5) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32; @@ -165,10 +157,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF45::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF45, - winograd::winograd_4x5_1x1_f, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::DEFAULT); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF45, winograd::winograd_4x5_1x1_f, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::DEFAULT); /* ======================= AlgoFP32WinogradF63_4x4 ======================== */ @@ -183,8 +174,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && @@ -195,22 +185,19 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - 
(param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_meta.icpg % 4 == 0 && - param.filter_meta.ocpg % 4 == 0; + param.filter_meta.icpg % 4 == 0 && param.filter_meta.ocpg % 4 == 0; } MIDOUT_END(); return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4, - winograd::winograd_6x3_4x4_f, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::MK4); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF63_4x4, winograd::winograd_6x3_4x4_f, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); /* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */ @@ -218,15 +205,15 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, - midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; using Strategy = winograd::winograd_F23_mk4_f_nchw44; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && @@ -238,8 +225,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32; @@ -248,10 +234,9 @@ bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF23_4x4_NCHW44, - winograd::winograd_F23_mk4_f_nchw44, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::MK4); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF23_4x4_NCHW44, winograd::winograd_F23_mk4_f_nchw44, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); /* =================== AlgoFP32WinogradF63_4x4_NCHW44 ===================== */ @@ -259,15 +244,15 @@ bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, - midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; using Strategy = winograd::winograd_F63_mk4_f_nchw44; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && @@ -279,37 +264,34 @@ bool 
ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_meta.icpg % 4 == 0 && - param.filter_meta.ocpg % 4 == 0; + param.filter_meta.icpg % 4 == 0 && param.filter_meta.ocpg % 4 == 0; } MIDOUT_END(); return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4_NCHW44, - winograd::winograd_F63_mk4_f_nchw44, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::MK4); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF63_4x4_NCHW44, winograd::winograd_F63_mk4_f_nchw44, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); /* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { - MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, - midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; using Strategy = winograd::winograd_F73_mk4_f_nchw44; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && @@ -321,8 +303,7 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::Float32; @@ -331,16 +312,15 @@ bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF73_4x4_NCHW44, - winograd::winograd_F73_mk4_f_nchw44, - megdnn_arm_common_winograd_fp32, - param::MatrixMul::Format::MK4); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoFP32WinogradF73_4x4_NCHW44, winograd::winograd_F73_mk4_f_nchw44, + megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); /* ===================== direct algo ===================== */ MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); -bool ConvBiasImpl::AlgoF32Direct::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF32Direct::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -352,17 +332,15 @@ bool ConvBiasImpl::AlgoF32Direct::usable(const NCBKernSizeParam& param, return fm.format == param::ConvBias::Format::NCHW && param.src_type.enumv() == DTypeEnum::Float32 && param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - fm.spatial_ndim 
== 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && param.isz[0] * param.isz[1] >= 4 && - param.osz[0] * param.osz[1] >= 4 && FH <= 7 && SH == 1 && - SW == 1; + param.dst_type.enumv() == DTypeEnum::Float32 && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && + param.isz[0] * param.isz[1] >= 4 && param.osz[0] * param.osz[1] >= 4 && + FH <= 7 && SH == 1 && SW == 1; } MIDOUT_END(); return false; } -size_t ConvBiasImpl::AlgoF32Direct::get_workspace( - const NCBKernSizeParam& param) const { +size_t ConvBiasImpl::AlgoF32Direct::get_workspace(const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 0, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; auto wbundle = MultithreadDirectConvCommon::get_bundle( @@ -381,15 +359,15 @@ SmallVector ConvBiasImpl::AlgoF32Direct::get_kimpls( size_t group = fm.group; bool large_group = group >= param.nr_threads; WorkspaceBundle bundle = - MultithreadDirectConvCommon::get_bundle( - param, large_group); + MultithreadDirectConvCommon::get_bundle(param, large_group); SmallVector ret_kerns; //! When group >= nr_threads, treat it as large_group, each thread process //! one group for better performance if (large_group) { //! Channel wise conv and big groups - auto exec_one_group = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto exec_one_group = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { auto fm = kern_param.filter_meta; size_t IC = fm.icpg; size_t OC = fm.ocpg; @@ -403,36 +381,37 @@ SmallVector ConvBiasImpl::AlgoF32Direct::get_kimpls( } for (size_t ic = 0; ic < IC; ic++) { MultithreadDirectConvCommon::copy_padding_kern( - bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { MultithreadDirectConvCommon::do_conv_kern( - bundle, kern_param, ncb_index, - fp32::conv_bias::kern_direct, + bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { if (fm.should_flip) { - auto weight_flip = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto weight_flip = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); MultithreadDirectConvCommon::weight_flip_kern( bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({weight_flip, {group, 1_z, OC}}); } - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); MultithreadDirectConvCommon::copy_padding_kern( bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); - auto do_conv = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_conv = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); MultithreadDirectConvCommon::do_conv_kern( bundle, kern_param, ncb_index, fp32::conv_bias::kern_direct, @@ -452,18 +431,17 @@ SmallVector ConvBiasImpl::AlgoF32Direct::dispatch_kerns( return {}; } /* ===================== stride-1 algo ===================== */ -bool 
ConvBiasImpl::AlgoF32DirectStride1::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF32DirectStride1::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.filter_meta.format == param::ConvBias::Format::NCHW && param.src_type.enumv() == DTypeEnum::Float32 && param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && + param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); } MIDOUT_END(); @@ -474,17 +452,15 @@ size_t ConvBiasImpl::AlgoF32DirectStride1::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; - auto bundle = - MultithreadDirectConvCommon::get_bundle_stride( - param, large_group); + auto bundle = MultithreadDirectConvCommon::get_bundle_stride( + param, large_group); return bundle.total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( +SmallVector ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -493,8 +469,9 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( size_t OC = param.filter_meta.ocpg; size_t group = fm.group; bool large_group = group >= param.nr_threads; - using Func = std::function; + using Func = std::function; Func conv_kern_function = nullptr; #define SWITCH_KERN_STR1() \ @@ -530,9 +507,8 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + MultithreadDirectConvCommon::copy_padding_kern_stride( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { MultithreadDirectConvCommon::do_conv_kern_stride( @@ -542,8 +518,9 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); MultithreadDirectConvCommon::copy_padding_kern_stride( bundle, kern_param, ncb_index, ncb_index.ndrange_id); @@ -562,8 +539,7 @@ ConvBiasImpl::AlgoF32DirectStride1::get_kimpls( return ret_kerns; } -SmallVector -ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 1, 2) { return get_kimpls(param); @@ -574,18 +550,17 @@ ConvBiasImpl::AlgoF32DirectStride1::dispatch_kerns( /* ===================== stride-2 algo ===================== */ -bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, - 
AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF32DirectStride2::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 0) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.filter_meta.format == param::ConvBias::Format::NCHW && param.src_type.enumv() == DTypeEnum::Float32 && param.filter_type.enumv() == DTypeEnum::Float32 && - param.dst_type.enumv() == DTypeEnum::Float32 && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && + param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); } MIDOUT_END(); @@ -595,16 +570,14 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 1) { bool large_group = param.filter_meta.group >= param.nr_threads; - auto bundle = - MultithreadDirectConvCommon::get_bundle_stride( - param, large_group); + auto bundle = MultithreadDirectConvCommon::get_bundle_stride( + param, large_group); return bundle.total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( +SmallVector ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; auto FH = fm.spatial[0]; @@ -613,8 +586,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( size_t OC = param.filter_meta.ocpg; size_t group = fm.group; bool large_group = group >= param.nr_threads; - using Func = std::function; + using Func = std::function; Func conv_kern_function = nullptr; #define SWITCH_KERN_STR2() \ @@ -650,9 +624,8 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - MultithreadDirectConvCommon:: - copy_padding_kern_stride(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + MultithreadDirectConvCommon::copy_padding_kern_stride( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { MultithreadDirectConvCommon::do_conv_kern_stride( @@ -662,8 +635,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); MultithreadDirectConvCommon::copy_padding_kern_stride( bundle, kern_param, ncb_index, ncb_index.ndrange_id); @@ -682,8 +656,7 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( return ret_kerns; } -SmallVector -ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { MIDOUT_BEGIN(megdnn_arm_common_conv_bias_f32_kimpl, 2, 2) { return get_kimpls(param); diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index f7935bd8..b2c7d53e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -19,8 +19,8 @@ namespace megdnn { namespace arm_common { class 
ConvBiasImpl::AlgoFP32WinogradF23_4x4 final : public AlgoBase { public: - AlgoFP32WinogradF23_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP32WinogradF23_4x4( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -29,17 +29,15 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { public: - AlgoFP32WinogradF63(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP32WinogradF63( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -57,8 +55,8 @@ public: class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { public: - AlgoFP32WinogradF63_4x4(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP32WinogradF63_4x4( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -67,17 +65,15 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) }; class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { public: - AlgoFP32WinogradF54(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP32WinogradF54( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -95,8 +91,8 @@ public: class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { public: - AlgoFP32WinogradF45(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoFP32WinogradF45( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -126,9 +122,7 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) }; @@ -146,9 +140,7 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) }; @@ -166,9 +158,7 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } 
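(Reading aid, not from the patch: the usable() predicates declared by these algo classes, and reformatted in fp32/algos.cpp above, all gate on the same filter/stride/dilation/dtype checks, with the MK4/NCHW44 variants additionally requiring channel counts divisible by 4. The standalone sketch below restates that gate in one place; struct WinogradGateInfo and winograd_f63_mk4_gate are invented names for illustration, not MegDNN types.)

#include <cstddef>

struct WinogradGateInfo {  // illustrative mirror of the filter_meta fields consulted above
    size_t spatial[2], stride[2], dilation[2];
    size_t icpg, ocpg;
    bool is_nchw, should_flip, is_float32, default_compute_mode;
};

//! roughly what AlgoFP32WinogradF63_4x4::usable() checks: 3x3 filter, unit
//! stride and dilation, fp32 NCHW, and (for the MK4 packed format) IC/OC
//! per group divisible by 4
static bool winograd_f63_mk4_gate(const WinogradGateInfo& p) {
    return p.is_nchw && !p.should_flip &&
           p.spatial[0] == p.spatial[1] && p.spatial[0] == 3 &&
           p.stride[0] == p.stride[1] && p.stride[0] == 1 &&
           p.dilation[0] == p.dilation[1] && p.dilation[0] == 1 &&
           p.default_compute_mode && p.is_float32 &&
           p.icpg % 4 == 0 && p.ocpg % 4 == 0;
}

In the real code the same expression is AND-ed with m_matmul_algo->usable(matmul_param), so the final answer also depends on the selected matrix-mul backend and tile size.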
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) }; @@ -178,12 +168,11 @@ class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F32DIRECT"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -198,12 +187,11 @@ class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F32STRD1"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -218,12 +206,11 @@ class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F32STRD2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -239,12 +226,11 @@ class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { public: AlgoF32DirectNCHW44() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -260,12 +246,11 @@ class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { public: AlgoF32DirectNCHWNCHW44() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F32_CONV_NCHW_NCHW44"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t 
get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -280,12 +265,11 @@ class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "F32_CHANNEL_WISE_NCHW44"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp index dc5a8798..5d9f1328 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.cpp @@ -63,8 +63,7 @@ static inline void shift_src(float32x4_t rsrc[3][4]) { } template -static inline float32x4_t load_bias(const float* bias, - const float32x4_t& init) { +static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { if (bias_mode == BiasMode::BIAS) { return vld1q_f32(bias); } else { @@ -75,10 +74,10 @@ static inline float32x4_t load_bias(const float* bias, template struct compute_element { template - static inline void call(const float*& src0, const float*& src1, - const float*& src2, float*& dst, const float*& bias, - const float32x4_t& init, float32x4_t rsrc[3][4], - float32x4_t rfilter[3][3], const Op& op) { + static inline void call( + const float*& src0, const float*& src1, const float*& src2, float*& dst, + const float*& bias, const float32x4_t& init, float32x4_t rsrc[3][4], + float32x4_t rfilter[3][3], const Op& op) { #define RSRC(i, j) rsrc[i][((j) + bw) % 4] float32x4_t rdst = load_bias(bias, init); if (has_top) { @@ -131,9 +130,9 @@ struct compute_element { template struct compute_element_right { template - static inline void call(float*& dst, const float*& bias, - const float32x4_t& init, float32x4_t rsrc[3][4], - float32x4_t rfilter[3][3], const Op& op) { + static inline void call( + float*& dst, const float*& bias, const float32x4_t& init, + float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { float32x4_t rdst = load_bias(bias, init); if (has_top) { @@ -162,9 +161,9 @@ struct compute_element_right { template struct compute_element_right_pad { template - static inline void call(float*& dst, const float*& bias, - const float32x4_t& init, float32x4_t rsrc[3][4], - float32x4_t rfilter[3][3], const Op& op) { + static inline void call( + float*& dst, const float*& bias, const float32x4_t& init, + float32x4_t rsrc[3][4], float32x4_t rfilter[3][3], const Op& op) { float32x4_t rdst = load_bias(bias, init); if (has_top) { @@ -189,10 +188,10 @@ struct compute_element_right_pad { template struct compute_row { template - static inline void call(const float*& src0, const float*& src1, - const float*& src2, float*& dst, const float*& bias, - const float32x4_t& init, float32x4_t rsrc[3][4], - float32x4_t rfilter[3][3], int W, const Op& op) { + static inline void call( + const float*& src0, const float*& src1, const float*& src2, float*& dst, + const float*& bias, const float32x4_t& init, 
float32x4_t rsrc[3][4], + float32x4_t rfilter[3][3], int W, const Op& op) { if (has_top) { rsrc[0][0] = vdupq_n_f32(0); rsrc[0][1] = vld1q_f32(src0 + 0); @@ -253,8 +252,8 @@ struct compute_row { template void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( - const float* src, float* dst, const float* filter, const float* bias, - int H, int W) { + const float* src, float* dst, const float* filter, const float* bias, int H, + int W) { Op op; float32x4_t init = vdupq_n_f32(0); @@ -279,16 +278,16 @@ void channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( float32x4_t rsrc[3][4]; - compute_row::call(src0, src1, src2, dst, bias, init, - rsrc, rfilter, W, op); + compute_row::call( + src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); for (int h = 1; h < H - 1; h += 1) { - compute_row::call(src0, src1, src2, dst, bias, - init, rsrc, rfilter, W, op); + compute_row::call( + src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); } - compute_row::call(src0, src1, src2, dst, bias, init, - rsrc, rfilter, W, op); + compute_row::call( + src0, src1, src2, dst, bias, init, rsrc, rfilter, W, op); } #define INSTANTIATION(bias, Op) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h index 51669ec2..77e193d8 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_3x3_s1p1_nchw44_kern.h @@ -20,9 +20,9 @@ namespace arm_common { namespace channel_wise_nchw44_float { template -void do_conv_kern_3x3_stride1_padding1(const float* src, float* dst, - const float* filter, const float* bias, - int H, int W); +void do_conv_kern_3x3_stride1_padding1( + const float* src, float* dst, const float* filter, const float* bias, int H, + int W); } // namespace channel_wise_nchw44_float } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp index cb886600..48b65f8a 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.cpp @@ -57,8 +57,7 @@ static inline void load_filter(const float* filter, float32x4_t rfilter[5]) { } template -static inline float32x4_t load_bias(const float* bias, - const float32x4_t& init) { +static inline float32x4_t load_bias(const float* bias, const float32x4_t& init) { if (bias_mode == BiasMode::BIAS) { return vld1q_f32(bias); } else { @@ -66,13 +65,12 @@ static inline float32x4_t load_bias(const float* bias, } } -template +template struct compute_element { template - static inline void call(const float*& src, float*& dst, const float*& bias, - const float32x4_t& init, float32x4_t rsrc[6], - float32x4_t rfilter[5], const Op& op) { + static inline void call( + const float*& src, float*& dst, const float*& bias, const float32x4_t& init, + float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { #define RSRC(i) rsrc[((i) + bw) % 6] float32x4_t rdst; if (need_load_bias) { @@ -96,9 +94,8 @@ struct compute_element { src += 4; dst += 4; bias += 4; - compute_element::call(src, dst, bias, init, rsrc, rfilter, - op); + compute_element::call( + src, dst, bias, init, rsrc, rfilter, op); #undef RSRC } }; @@ -109,13 +106,12 @@ struct compute_element { static inline void call(Types... 
args) {} }; -template +template struct compute_element_right { template - static inline void call(float*& dst, const float*& bias, - const float32x4_t& init, float32x4_t rsrc[6], - float32x4_t rfilter[5], const Op& op) { + static inline void call( + float*& dst, const float*& bias, const float32x4_t& init, + float32x4_t rsrc[6], float32x4_t rfilter[5], const Op& op) { float32x4_t rdst; if (need_load_bias) { rdst = load_bias(bias, init); @@ -146,9 +142,9 @@ struct compute_element_right { template struct compute_row_src_1x5 { template - static inline void call(const float* src, float* dst, const float* bias, - const float32x4_t& init, float32x4_t rsrc[6], - float32x4_t rfilter[5], int W, const Op& op) { + static inline void call( + const float* src, float* dst, const float* bias, const float32x4_t& init, + float32x4_t rsrc[6], float32x4_t rfilter[5], int W, const Op& op) { rsrc[0] = vdupq_n_f32(0); rsrc[1] = vdupq_n_f32(0); rsrc[2] = vld1q_f32(src + 0); @@ -192,10 +188,10 @@ struct compute_row_src_1x5 { template struct compute_row { template - static inline void call(const float*& src, float*& dst, const float* filter, - const float*& bias, const float32x4_t& init, - float32x4_t rsrc[6], float32x4_t rfilter[5], int W, - const Op& op) { + static inline void call( + const float*& src, float*& dst, const float* filter, const float*& bias, + const float32x4_t& init, float32x4_t rsrc[6], float32x4_t rfilter[5], int W, + const Op& op) { if (top_padding < 1) { load_filter(filter + 0, rfilter); compute_row_src_1x5::call( @@ -210,10 +206,8 @@ struct compute_row { { load_filter(filter + 40, rfilter); - compute_row_src_1x5::call(src, dst, bias, init, - rsrc, rfilter, W, - op); + compute_row_src_1x5::call( + src, dst, bias, init, rsrc, rfilter, W, op); } if (bottom_padding < 2) { @@ -237,8 +231,8 @@ struct compute_row { template void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( - const float* src, float* dst, const float* filter, const float* bias, - int H, int W) { + const float* src, float* dst, const float* filter, const float* bias, int H, + int W) { Op op; float32x4_t init = vdupq_n_f32(0); @@ -249,18 +243,18 @@ void channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( float32x4_t rsrc[6]; float32x4_t rfilter[5]; - compute_row<2, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc, - rfilter, W, op); - compute_row<1, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc, - rfilter, W, op); + compute_row<2, 0, bias_mode>::call( + src, dst, filter, bias, init, rsrc, rfilter, W, op); + compute_row<1, 0, bias_mode>::call( + src, dst, filter, bias, init, rsrc, rfilter, W, op); for (int h = 2; h < H - 2; h += 1) { - compute_row<0, 0, bias_mode>::call(src, dst, filter, bias, init, rsrc, - rfilter, W, op); + compute_row<0, 0, bias_mode>::call( + src, dst, filter, bias, init, rsrc, rfilter, W, op); } - compute_row<0, 1, bias_mode>::call(src, dst, filter, bias, init, rsrc, - rfilter, W, op); - compute_row<0, 2, bias_mode>::call(src, dst, filter, bias, init, rsrc, - rfilter, W, op); + compute_row<0, 1, bias_mode>::call( + src, dst, filter, bias, init, rsrc, rfilter, W, op); + compute_row<0, 2, bias_mode>::call( + src, dst, filter, bias, init, rsrc, rfilter, W, op); } #define INSTANTIATION(bias, Op) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h index 28b04380..d3bd5fc3 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h +++ 
b/dnn/src/arm_common/conv_bias/fp32/channel_wise_5x5_s1p2_nchw44_kern.h @@ -20,9 +20,9 @@ namespace arm_common { namespace channel_wise_nchw44_float { template -void do_conv_kern_5x5_stride1_padding2(const float* src, float* dst, - const float* filter, const float* bias, - int H, int W); +void do_conv_kern_5x5_stride1_padding2( + const float* src, float* dst, const float* filter, const float* bias, int H, + int W); } // namespace channel_wise_nchw44_float } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp index 13235461..6f717cba 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_algo.cpp @@ -32,9 +32,10 @@ bool ConvBiasImpl::AlgoF32ChannelWiseNCHW44::usable( size_t OC = fm.ocpg; size_t IC = fm.icpg; size_t GROUP = fm.group; - bool ok_type = (param.src_type.enumv() == DTypeEnum::Float32 && - param.filter_type.enumv() == DTypeEnum::Float32 && - (param.dst_type.enumv() == DTypeEnum::Float32)); + bool ok_type = + (param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_type.enumv() == DTypeEnum::Float32 && + (param.dst_type.enumv() == DTypeEnum::Float32)); bool ok_format = OC == 1 && IC == 1 && GROUP % 4 == 0 && fm.format == param::Convolution::Format::NCHW44; bool ok_filter = fm.spatial_ndim == 2 && FH == fm.spatial[1] && @@ -52,9 +53,8 @@ size_t ConvBiasImpl::AlgoF32ChannelWiseNCHW44::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoF32ChannelWiseNCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::AlgoF32ChannelWiseNCHW44:: + dispatch_kerns(const NCBKernSizeParam& param) const { const constexpr size_t pack_group_size = 4_z; auto fm = param.filter_meta; const int batch = param.n; @@ -65,31 +65,31 @@ ConvBiasImpl::AlgoF32ChannelWiseNCHW44::dispatch_kerns( // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime #define DO_CONV_KERN_FUN(_stride, filter, bias_mode, op) \ - MIDOUT_BEGIN(conv_bias_fp32_channel_wise_nchw44, \ - midout_iv(#_stride #filter #bias_mode #op##_hash)) { \ + MIDOUT_BEGIN( \ + conv_bias_fp32_channel_wise_nchw44, \ + midout_iv(#_stride #filter #bias_mode #op##_hash)) { \ do_conv_fun = channel_wise_nchw44_float:: \ do_conv_kern_##_stride##_##filter##x##filter; \ } \ MIDOUT_END(); -#define GET_OP_PARAM(_stride, filter, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(_stride, filter, bias_mode, NoneOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(_stride, filter, bias_mode, ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::SIGMOID: \ - DO_CONV_KERN_FUN(_stride, filter, bias_mode, \ - SigmoidOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(_stride, filter, bias_mode, HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(_stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, NoneOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::SIGMOID: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, SigmoidOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(_stride, filter, bias_mode, 
HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(_stride, filter) \ @@ -144,10 +144,11 @@ ConvBiasImpl::AlgoF32ChannelWiseNCHW44::dispatch_kerns( SmallVector ret_kerns; - CpuNDRange ncb_range = {static_cast(batch), - static_cast(group / pack_group_size)}; - auto do_conv = [do_conv_fun](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { + CpuNDRange ncb_range = { + static_cast(batch), static_cast(group / pack_group_size)}; + auto do_conv = [do_conv_fun]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { size_t PH = kern_param.filter_meta.padding[0]; size_t PW = kern_param.filter_meta.padding[1]; size_t OH = kern_param.osz[0]; @@ -160,8 +161,7 @@ ConvBiasImpl::AlgoF32ChannelWiseNCHW44::dispatch_kerns( const float* sptr = kern_param.src(batch_id, group_id, 0, pack_group_size); const float* fptr = kern_param.filter(group_id, pack_group_size); - float* dst = - kern_param.dst(batch_id, group_id, 0, pack_group_size); + float* dst = kern_param.dst(batch_id, group_id, 0, pack_group_size); const float* bptr = kern_param.bias(batch_id, group_id, 0, pack_group_size); //! copy in case of illegal read src when padding is zero @@ -171,4 +171,4 @@ ConvBiasImpl::AlgoF32ChannelWiseNCHW44::dispatch_kerns( return ret_kerns; } -//vim: syntax=cpp.doxygen +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp index bb849a40..4acda6fc 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.cpp @@ -49,11 +49,11 @@ template void compute_vec(float32x4_t& dst, float32x4_t* src, float32x4_t* filter); #define cb(i) dst = vmlaq_f32(dst, src[i], filter[i]); -#define COMPUTE_MACRO(n) \ - template <> \ - inline void compute_vec(float32x4_t & dst, float32x4_t * src, \ - float32x4_t * filter) { \ - UNROLL_CALL_NOWRAPPER(n, cb); \ +#define COMPUTE_MACRO(n) \ + template <> \ + inline void compute_vec( \ + float32x4_t & dst, float32x4_t * src, float32x4_t * filter) { \ + UNROLL_CALL_NOWRAPPER(n, cb); \ } COMPUTE_MACRO(2); COMPUTE_MACRO(3); @@ -67,17 +67,17 @@ struct load_bias_vec; #define cb_bias(i) dst[i] = vld1q_f32((bptr) + i * 4); #define cb_init(i) dst[i] = init; -#define INIT_BIAS_MACRO(n) \ - template \ - struct load_bias_vec { \ - static void impl(float32x4_t* dst, const float32x4_t& init, \ - const float* bptr) { \ - if (bias_mode == BiasMode::BIAS) { \ - UNROLL_CALL_NOWRAPPER(n, cb_bias); \ - } else { \ - UNROLL_CALL_NOWRAPPER(n, cb_init); \ - } \ - } \ +#define INIT_BIAS_MACRO(n) \ + template \ + struct load_bias_vec { \ + static void impl( \ + float32x4_t* dst, const float32x4_t& init, const float* bptr) { \ + if (bias_mode == BiasMode::BIAS) { \ + UNROLL_CALL_NOWRAPPER(n, cb_bias); \ + } else { \ + UNROLL_CALL_NOWRAPPER(n, cb_init); \ + } \ + } \ }; INIT_BIAS_MACRO(1); @@ -88,34 +88,32 @@ INIT_BIAS_MACRO(4); #undef INIT_BIAS_MACRO } // namespace -#define COMPUTE_PADDING_KERNEL() \ - do { \ - int iw = ow * stride - PW; \ - float32x4_t result; \ - load_bias_vec::impl(&result, init, \ - bias + oh * OW * 4 + ow * 4); \ - for (int kh = 0; kh < fh; kh++) { \ - if (kh + ih < 0 || kh + ih >= static_cast(IH)) \ - continue; \ - for (int kw = 0; kw < fh; kw++) { \ - if (kw + iw < 0 || kw + iw >= static_cast(IW)) \ - continue; \ - const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \ - result = vmlaq_f32(result, kernel[kh * fh + kw], \ - 
vld1q_f32(sptr)); \ - } \ - } \ - float* output = dst + oh * OW * 4 + ow * 4; \ - op(result, output); \ +#define COMPUTE_PADDING_KERNEL() \ + do { \ + int iw = ow * stride - PW; \ + float32x4_t result; \ + load_bias_vec::impl(&result, init, bias + oh * OW * 4 + ow * 4); \ + for (int kh = 0; kh < fh; kh++) { \ + if (kh + ih < 0 || kh + ih >= static_cast(IH)) \ + continue; \ + for (int kw = 0; kw < fh; kw++) { \ + if (kw + iw < 0 || kw + iw >= static_cast(IW)) \ + continue; \ + const float* sptr = src + (kh + ih) * IW * 4 + (kw + iw) * 4; \ + result = vmlaq_f32(result, kernel[kh * fh + kw], vld1q_f32(sptr)); \ + } \ + } \ + float* output = dst + oh * OW * 4 + ow * 4; \ + op(result, output); \ } while (0) template struct PaddingCompute { - static void compute(const float* src, const float* bias, float* dst, - const int fh, const int stride, const size_t IH, - const size_t IW, const size_t OH, const size_t OW, - const size_t PH, const size_t PW, - const float32x4_t* kernel, const float32x4_t& init) { + static void compute( + const float* src, const float* bias, float* dst, const int fh, + const int stride, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const size_t PH, const size_t PW, + const float32x4_t* kernel, const float32x4_t& init) { size_t oh_start = (PH + stride - 1) / stride; size_t ow_start = (PW + stride - 1) / stride; size_t oh_end = (IH + PH - fh) / stride + 1; @@ -147,17 +145,18 @@ struct PaddingCompute { template struct PaddingComputeK3P1 { - static void compute(const float* src, const float* bias, float* dst, - const size_t stride, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, - const float32x4_t* kernel, const float32x4_t& init) { + static void compute( + const float* src, const float* bias, float* dst, const size_t stride, + const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const float32x4_t* kernel, const float32x4_t& init) { constexpr size_t PH = 1, PW = 1, FH = 3; size_t oh_start = (PH + stride - 1) / stride; size_t ow_start = (PW + stride - 1) / stride; size_t oh_end = (IH + PH - FH) / stride + 1; size_t ow_end = (IW + PW - FH) / stride + 1; - megdnn_assert(oh_start == ow_start && oh_start == 1, - "channel wise padding param error"); + megdnn_assert( + oh_start == ow_start && oh_start == 1, + "channel wise padding param error"); megdnn_assert(ow_end == OW - 1 || ow_end == OW, "padding PW error"); megdnn_assert(oh_end == OH - 1 || oh_end == OH, "padding PH error"); Op op; @@ -190,8 +189,7 @@ struct PaddingComputeK3P1 { // line one right if (OW != ow_end) { float32x4_t result; - load_bias_vec::impl(&result, init, - bias + (OW - 1) * 4); + load_bias_vec::impl(&result, init, bias + (OW - 1) * 4); const float* sptr = src + (ow_end * stride - PW) * 4; result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr)); result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + 4)); @@ -206,18 +204,14 @@ struct PaddingComputeK3P1 { // left { float32x4_t result; - load_bias_vec::impl(&result, init, - bias + oh * OW * 4); + load_bias_vec::impl(&result, init, bias + oh * OW * 4); const float* sptr = src + ih * IW * 4; result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr)); result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[5], - vld1q_f32(sptr + IW * 4 + 4)); - result = vmlaq_f32(result, kernel[7], - vld1q_f32(sptr + 2 * IW * 4)); - result = vmlaq_f32(result, kernel[8], - vld1q_f32(sptr + 2 * IW * 4 + 4)); + result = 
vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); + result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4)); + result = vmlaq_f32(result, kernel[8], vld1q_f32(sptr + 2 * IW * 4 + 4)); float* output = dst + oh * OW * 4; op(result, output); } @@ -226,17 +220,13 @@ struct PaddingComputeK3P1 { float32x4_t result; load_bias_vec::impl( &result, init, bias + oh * OW * 4 + (OW - 1) * 4); - const float* sptr = - src + ih * IW * 4 + (ow_end * stride - PW) * 4; + const float* sptr = src + ih * IW * 4 + (ow_end * stride - PW) * 4; result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[4], - vld1q_f32(sptr + IW * 4 + 4)); - result = vmlaq_f32(result, kernel[6], - vld1q_f32(sptr + 2 * IW * 4)); - result = vmlaq_f32(result, kernel[7], - vld1q_f32(sptr + 2 * IW * 4 + 4)); + result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); + result = vmlaq_f32(result, kernel[6], vld1q_f32(sptr + 2 * IW * 4)); + result = vmlaq_f32(result, kernel[7], vld1q_f32(sptr + 2 * IW * 4 + 4)); float* output = dst + oh * OW * 4 + ow_end * 4; op(result, output); } @@ -246,14 +236,12 @@ struct PaddingComputeK3P1 { size_t oh = OH - 1; { float32x4_t result; - load_bias_vec::impl(&result, init, - bias + oh * OW * 4); + load_bias_vec::impl(&result, init, bias + oh * OW * 4); const float* sptr = src + (oh_end * stride - PH) * IW * 4; result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr)); result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 4)); result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[5], - vld1q_f32(sptr + IW * 4 + 4)); + result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 4)); float* output = dst + oh_end * OW * 4; op(result, output); } @@ -261,18 +249,15 @@ struct PaddingComputeK3P1 { for (size_t ow = ow_start; ow < ow_end; ow++) { int iw = ow * stride - PW; float32x4_t result; - load_bias_vec::impl(&result, init, - bias + oh * OW * 4 + ow * 4); - const float* sptr = - src + (oh_end * stride - PH) * IW * 4 + iw * 4; + load_bias_vec::impl( + &result, init, bias + oh * OW * 4 + ow * 4); + const float* sptr = src + (oh_end * stride - PH) * IW * 4 + iw * 4; result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); result = vmlaq_f32(result, kernel[2], vld1q_f32(sptr + 8)); result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[4], - vld1q_f32(sptr + IW * 4 + 4)); - result = vmlaq_f32(result, kernel[5], - vld1q_f32(sptr + IW * 4 + 8)); + result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); + result = vmlaq_f32(result, kernel[5], vld1q_f32(sptr + IW * 4 + 8)); float* output = dst + oh_end * OW * 4 + ow * 4; op(result, output); } @@ -286,8 +271,7 @@ struct PaddingComputeK3P1 { result = vmlaq_f32(result, kernel[0], vld1q_f32(sptr)); result = vmlaq_f32(result, kernel[1], vld1q_f32(sptr + 4)); result = vmlaq_f32(result, kernel[3], vld1q_f32(sptr + IW * 4)); - result = vmlaq_f32(result, kernel[4], - vld1q_f32(sptr + IW * 4 + 4)); + result = vmlaq_f32(result, kernel[4], vld1q_f32(sptr + IW * 4 + 4)); float* output = dst + oh_end * OW * 4 + ow_end * 4; op(result, output); } @@ -314,8 +298,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( size_t oh_end = IH + PH - 1; size_t ow_end = IW + PW - 1; if (PH || PW) { - PaddingCompute::compute(src, bias, 
dst, 2, 1, IH, IW, OH, - OW, PH, PW, kernel, init); + PaddingCompute::compute( + src, bias, dst, 2, 1, IH, IW, OH, OW, PH, PW, kernel, init); } #define COMPUTE_2X2(dst, src, kernel) \ compute_vec<2>(dst[0], &src[0], kernel); \ @@ -332,8 +316,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2][4]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t src_v[3][5]; @@ -355,8 +339,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t src_v[3][2]; @@ -380,8 +364,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[1][4]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[2][5]; load_vec<5>(src_v[0], input); COMPUTE_2X2(dst_v[0], src_v[0], &kernel[0]); @@ -396,8 +380,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_2x2( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[2][2]; load_vec<2>(src_v[0], input); compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); @@ -416,8 +400,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( const size_t IH, const size_t IW, const size_t OH, const size_t OW, const size_t PH, const size_t PW) { if (IH == OH && IW == OW && IH >= 3 && IW >= 3 && PH == 1 && PW == 1) { - channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( + channel_wise_nchw44_float::do_conv_kern_3x3_stride1_padding1( src, dst, filter, bias, OH, OW); return; } @@ -434,8 +417,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( size_t oh_end = IH + PH - 2; size_t ow_end = IW + PW - 2; if (PH || PW) { - PaddingCompute::compute(src, bias, dst, 3, 1, IH, IW, OH, - OW, PH, PW, kernel, init); + PaddingCompute::compute( + src, bias, dst, 3, 1, IH, IW, OH, OW, PH, PW, kernel, init); } size_t oh = oh_start; for (; oh + 1 < oh_end; oh += 2) { @@ -446,8 +429,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2][4]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t src_v[2][6]; @@ -490,8 +473,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 4 + ow * 4); 
load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t src_v[2][3]; @@ -518,8 +501,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[4]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[2][6]; load_vec<6>(src_v[0], input); compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); @@ -544,8 +527,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[3][3]; load_vec<3>(src_v[0], input); compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); @@ -564,8 +547,7 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( const size_t IH, const size_t IW, const size_t OH, const size_t OW, const size_t PH, const size_t PW) { if (IH == OH && IW == OW && IH >= 5 && IW >= 5 && PH == 2 && PW == 2) { - channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( + channel_wise_nchw44_float::do_conv_kern_5x5_stride1_padding2( src, dst, filter, bias, OH, OW); return; } @@ -593,8 +575,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t kernel[2][5]; @@ -632,8 +614,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2][1]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t kernel[2][5]; @@ -671,8 +653,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[1][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); float32x4_t kernel[2][5]; float32x4_t src_v[2][6]; #define COMPUTE_5X5_2(i, dst, src, kernel) \ @@ -698,8 +680,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride1_5x5( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 4 + ow * 4); float32x4_t kernel[2][5]; float32x4_t src_v[2][5]; #define COMPUTE_5X5_1(i, dst, src, kernel) \ @@ -739,8 +721,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( size_t oh_end = (IH + PH) / 2; size_t ow_end = (IW + PW) / 2; if (PH || PW) { - PaddingCompute::compute(src, bias, dst, 2, 2, IH, IW, OH, - OW, PH, PW, kernel, init); + PaddingCompute::compute( + src, bias, dst, 2, 2, IH, IW, OH, OW, PH, PW, kernel, init); } #define COMPUTE_2X2(dst, src, kernel) \ compute_vec<2>(dst[0], &src[0], kernel); \ @@ 
-756,8 +738,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[4]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[2][8]; load_vec<8>(src_v[0], input); COMPUTE_2X2(dst_v, src_v[0], &kernel[0]); @@ -772,8 +754,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_2x2( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[2][2]; load_vec<2>(src_v[0], input); compute_vec<2>(dst_v, &src_v[0][0], &kernel[0]); @@ -803,11 +785,11 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( size_t oh_end = (IH + PH - 3) / 2 + 1; size_t ow_end = (IW + PW - 3) / 2 + 1; if (PH == 1 && PW == 1) { - PaddingComputeK3P1::compute(src, bias, dst, 2, IH, IW, - OH, OW, kernel, init); + PaddingComputeK3P1::compute( + src, bias, dst, 2, IH, IW, OH, OW, kernel, init); } else if (PH || PW) { - PaddingCompute::compute(src, bias, dst, 3, 2, IH, IW, OH, - OW, PH, PW, kernel, init); + PaddingCompute::compute( + src, bias, dst, 3, 2, IH, IW, OH, OW, PH, PW, kernel, init); } size_t oh = oh_start; for (; oh + 1 < oh_end; oh += 2) { @@ -818,8 +800,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t src_v[2][5]; @@ -849,8 +831,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t src_v[2][3]; @@ -878,8 +860,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[3][5]; load_vec<5>(src_v[0], input); compute_vec<3>(dst_v[0], &src_v[0][0], &kernel[0]); @@ -897,8 +879,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_3x3( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 4 + ow * 4); float32x4_t src_v[3][3]; load_vec<3>(src_v[0], input); compute_vec<3>(dst_v, &src_v[0][0], &kernel[0]); @@ -940,8 +922,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2][2]; - load_bias_vec::impl(dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + dst_v[0], init, bias + oh * OW * 4 + ow * 4); 
load_bias_vec::impl( dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t kernel[3][5]; @@ -984,8 +966,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v[2]; - load_bias_vec::impl(&dst_v[0], init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v[0], init, bias + oh * OW * 4 + ow * 4); load_bias_vec::impl( &dst_v[1], init, bias + (oh + 1) * OW * 4 + ow * 4); float32x4_t kernel[3][5]; @@ -1029,8 +1011,8 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( const float* input = src + ih * IW * 4 + iw * 4; float* output = dst + oh * OW * 4 + ow * 4; float32x4_t dst_v; - load_bias_vec::impl(&dst_v, init, - bias + oh * OW * 4 + ow * 4); + load_bias_vec::impl( + &dst_v, init, bias + oh * OW * 4 + ow * 4); float32x4_t kernel[2][5]; float32x4_t src_v[2][5]; #define COMPUTE_5X5_1(i, dst, src, kernel) \ @@ -1053,13 +1035,12 @@ void channel_wise_nchw44_float::do_conv_kern_stride2_5x5( } } -#define INSTANTIATION(stride, i, bias, Op) \ - template void \ - channel_wise_nchw44_float::do_conv_kern_##stride##_##i##x##i( \ - const float*, const float*, const float*, float*, \ - const size_t, const size_t, const size_t, const size_t, \ - const size_t, const size_t); +#define INSTANTIATION(stride, i, bias, Op) \ + template void \ + channel_wise_nchw44_float::do_conv_kern_##stride##_##i##x##i( \ + const float*, const float*, const float*, float*, const size_t, \ + const size_t, const size_t, const size_t, const size_t, \ + const size_t); #define FOR_OP(stride, i, bias) \ INSTANTIATION(stride, i, bias, SigmoidOp) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h index 01a581a8..617241bd 100644 --- a/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/fp32/channel_wise_nchw44_kern.h @@ -19,12 +19,12 @@ namespace megdnn { namespace arm_common { namespace channel_wise_nchw44_float { -#define KERN(stride, i) \ - template \ - void do_conv_kern_##stride##_##i##x##i( \ - const float* src, const float* filter, const float* bias, \ - float* dst, const size_t IH, const size_t IW, const size_t OH, \ - const size_t OW, const size_t PH, const size_t PW); +#define KERN(stride, i) \ + template \ + void do_conv_kern_##stride##_##i##x##i( \ + const float* src, const float* filter, const float* bias, float* dst, \ + const size_t IH, const size_t IW, const size_t OH, const size_t OW, \ + const size_t PH, const size_t PW); KERN(stride1, 2) KERN(stride1, 3) diff --git a/dnn/src/arm_common/conv_bias/fp32/direct.cpp b/dnn/src/arm_common/conv_bias/fp32/direct.cpp index 374c31f9..c4300f1f 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct.cpp @@ -9,14 +9,14 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ +#include "src/arm_common/conv_bias/fp32/direct.h" #include #include "include/megdnn/oprs.h" #include "midout.h" -#include "src/arm_common/conv_bias/fp32/direct.h" -#include "src/arm_common/simd_macro/marm_neon.h" -#include "src/common/utils.h" #include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" +#include "src/common/utils.h" MIDOUT_DECL(megdnn_arm_conv_f32) using namespace megdnn; @@ -28,9 +28,10 @@ namespace { template struct do_pixel_proxy { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow); + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow); }; #define cb_load(i) data = vld1q_lane_f32(dst + i, data, i); @@ -96,9 +97,10 @@ struct do_pixel_proxy { template struct do_pixel_proxy<1, height, width> { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { (void)IH; (void)OH; const int ih = oh, iw = ow; @@ -136,9 +138,10 @@ struct do_pixel_proxy<1, height, width> { template struct do_pixel_proxy<2, height, width> { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { (void)IH; (void)OH; const int ih = oh, iw = ow; @@ -188,9 +191,10 @@ struct do_pixel_proxy<2, height, width> { template struct do_pixel_proxy<3, height, width> { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { (void)IH; (void)OH; const int ih = oh, iw = ow; @@ -252,9 +256,10 @@ struct do_pixel_proxy<3, height, width> { template struct do_pixel_proxy<4, height, width> { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { (void)IH; (void)OH; const int ih = oh, iw = ow; @@ -328,14 +333,14 @@ struct do_pixel_proxy<4, height, width> { template struct do_pixel_proxy<5, height, width> { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, 
kr0, kr1, kr2, kr3, kr4, - inp; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; @@ -417,14 +422,15 @@ struct do_pixel_proxy<5, height, width> { template struct do_pixel_proxy<6, height, width> { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, - kr5, inp; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, + inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; @@ -518,14 +524,15 @@ struct do_pixel_proxy<6, height, width> { template struct do_pixel_proxy<7, height, width> { - static void exec(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, const int OW, - const int FW, const int oh, const int ow) { + static void exec( + const float* src, const float* filter, float* dst, const int IH, + const int IW, const int OH, const int OW, const int FW, const int oh, + const int ow) { (void)IH; (void)OH; const int ih = oh, iw = ow; - float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, - kr5, kr6, inp; + float32x4_t out0{0}, out1{0}, out2{0}, out3{0}, kr0, kr1, kr2, kr3, kr4, kr5, + kr6, inp; src += ih * IW + iw; dst += oh * OW + ow; LOAD_OUT; @@ -634,17 +641,17 @@ struct do_pixel_proxy<7, height, width> { #undef STORE_OUT template -void do_pixel(const float* src, const float* filter, float* dst, const int IH, - const int IW, const int OH, const int OW, const int FW, - const int oh, const int ow) { - do_pixel_proxy::exec(src, filter, dst, IH, IW, OH, OW, - FW, oh, ow); +void do_pixel( + const float* src, const float* filter, float* dst, const int IH, const int IW, + const int OH, const int OW, const int FW, const int oh, const int ow) { + do_pixel_proxy::exec( + src, filter, dst, IH, IW, OH, OW, FW, oh, ow); } template -void do_conv_tpl_enable_prefetch(const float* src, const float* filter, - float* dst, const int IH, const int IW, - const int OH, const int OW, const int FW) { +void do_conv_tpl_enable_prefetch( + const float* src, const float* filter, float* dst, const int IH, const int IW, + const int OH, const int OW, const int FW) { const int hbeg = 0, hend = OH; const int wbeg = 0, wend = OW; int i, j; @@ -652,13 +659,11 @@ void do_conv_tpl_enable_prefetch(const float* src, const float* filter, for (j = wbeg; j + 4 <= wend; j += 4) { // do prefetch const int prefetch_index_input = - (j + 16) < wend - ? i * IW + j + 16 - : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); + (j + 16) < wend ? i * IW + j + 16 + : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); const int prefetch_index_output = - (j + 16) < wend - ? i * OW + j + 16 - : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); + (j + 16) < wend ? 
i * OW + j + 16 + : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); const float* src_prefetch = src + prefetch_index_input; const float* dst_prefetch = dst + prefetch_index_output; for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { @@ -699,51 +704,47 @@ void do_conv_tpl_enable_prefetch(const float* src, const float* filter, #undef DISPATCH } -#define DISPATCH2(height, width) \ - do { \ - const int prefetch_index_input = IH * IW + 12; \ - const float* src_prefetch = src + prefetch_index_input; \ - for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ - __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ - } \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ +#define DISPATCH2(height, width) \ + do { \ + const int prefetch_index_input = IH * IW + 12; \ + const float* src_prefetch = src + prefetch_index_input; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ } while (0) -#define DISPATCH1(height) \ - do { \ - for (j = wbeg; j + 4 <= wend; j += 4) { \ - const int prefetch_index_input = \ - (j + 16) < wend \ - ? i * IW + j + 16 \ - : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); \ - const int prefetch_index_output = \ - (j + 16) < wend \ - ? i * OW + j + 16 \ - : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); \ - const float* src_prefetch = src + prefetch_index_input; \ - const float* dst_prefetch = dst + prefetch_index_output; \ - for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ - __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ - } \ - __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \ - __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \ - __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \ - __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ - } \ - switch (wend - j) { \ - case 1: \ - DISPATCH2(height, 1); \ - break; \ - case 2: \ - DISPATCH2(height, 2); \ - break; \ - case 3: \ - DISPATCH2(height, 3); \ - break; \ - } \ +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 4 <= wend; j += 4) { \ + const int prefetch_index_input = \ + (j + 16) < wend ? i * IW + j + 16 \ + : (i + 4) * IW + (((j + 16 - wend) >> 2) << 2); \ + const int prefetch_index_output = \ + (j + 16) < wend ? 
i * OW + j + 16 \ + : (i + 4) * OW + (((j + 16 - wend) >> 2) << 2); \ + const float* src_prefetch = src + prefetch_index_input; \ + const float* dst_prefetch = dst + prefetch_index_output; \ + for (int iw_id = 0; iw_id < FH + 3; ++iw_id) { \ + __builtin_prefetch(src_prefetch + iw_id * IW, 0, 3); \ + } \ + __builtin_prefetch(dst_prefetch + 0 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 1 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 2 * OW, 1, 3); \ + __builtin_prefetch(dst_prefetch + 3 * OW, 1, 3); \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + } \ } while (0) switch (hend - i) { case 1: @@ -760,9 +761,9 @@ void do_conv_tpl_enable_prefetch(const float* src, const float* filter, #undef DISPATCH2 } template -void do_conv_tpl_disable_prefetch(const float* src, const float* filter, - float* dst, const int IH, const int IW, - const int OH, const int OW, const int FW) { +void do_conv_tpl_disable_prefetch( + const float* src, const float* filter, float* dst, const int IH, const int IW, + const int OH, const int OW, const int FW) { const int hbeg = 0, hend = OH; const int wbeg = 0, wend = OW; int i, j; @@ -787,28 +788,26 @@ void do_conv_tpl_disable_prefetch(const float* src, const float* filter, } #undef DISPATCH } -#define DISPATCH2(height, width) \ - do { \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ +#define DISPATCH2(height, width) \ + do { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ } while (0) -#define DISPATCH1(height) \ - do { \ - for (j = wbeg; j + 4 <= wend; j += 4) { \ - do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, \ - j); \ - } \ - switch (wend - j) { \ - case 1: \ - DISPATCH2(height, 1); \ - break; \ - case 2: \ - DISPATCH2(height, 2); \ - break; \ - case 3: \ - DISPATCH2(height, 3); \ - break; \ - } \ +#define DISPATCH1(height) \ + do { \ + for (j = wbeg; j + 4 <= wend; j += 4) { \ + do_pixel(src, filter, dst, IH, IW, OH, OW, FW, i, j); \ + } \ + switch (wend - j) { \ + case 1: \ + DISPATCH2(height, 1); \ + break; \ + case 2: \ + DISPATCH2(height, 2); \ + break; \ + case 3: \ + DISPATCH2(height, 3); \ + break; \ + } \ } while (0) switch (hend - i) { case 1: @@ -826,15 +825,14 @@ void do_conv_tpl_disable_prefetch(const float* src, const float* filter, } } // anonymous namespace -void conv_bias::kern_direct(const float* src, const float* filter, float* dst, - const int IH, const int IW, const int OH, - const int OW, const int FH, const int FW) { +void conv_bias::kern_direct( + const float* src, const float* filter, float* dst, const int IH, const int IW, + const int OH, const int OW, const int FH, const int FW) { megdnn_assert_internal(FH <= 7); if (IH > 100 && IW > 100) { -#define GAO(FH) \ - do { \ - return do_conv_tpl_enable_prefetch(src, filter, dst, IH, IW, OH, \ - OW, FW); \ +#define GAO(FH) \ + do { \ + return do_conv_tpl_enable_prefetch(src, filter, dst, IH, IW, OH, OW, FW); \ } while (0) switch (FH) { case 1: @@ -868,10 +866,9 @@ void conv_bias::kern_direct(const float* src, const float* filter, float* dst, } #undef GAO } else { -#define GAO(FH) \ - do { \ - return do_conv_tpl_disable_prefetch(src, filter, dst, IH, IW, OH, \ - OW, FW); \ +#define GAO(FH) \ + do { \ + return do_conv_tpl_disable_prefetch(src, filter, dst, IH, IW, OH, OW, FW); \ } while (0) switch (FH) { case 1: diff --git a/dnn/src/arm_common/conv_bias/fp32/direct.h 
b/dnn/src/arm_common/conv_bias/fp32/direct.h index fe95b647..57941be8 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct.h @@ -14,16 +14,16 @@ namespace megdnn { namespace arm_common { -namespace fp32{ +namespace fp32 { namespace conv_bias { -void kern_direct(const float *src, const float *filter, float *dst, - const int IH, const int IW, const int OH, const int OW, - const int FH, const int FW); +void kern_direct( + const float* src, const float* filter, float* dst, const int IH, const int IW, + const int OH, const int OW, const int FH, const int FW); -} // namespace convolution -} // namespace fp32 -} // namespace arm_common -} // namespace megdnn +} // namespace conv_bias +} // namespace fp32 +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp index 08f5e4bd..16e30e2c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern.cpp @@ -19,11 +19,10 @@ namespace megdnn { namespace arm_common { namespace conv_bias { template <> -void pack_src_fp32_nchw44<1>(float* sptr_base, const float* sptr_origin, - const int, const int pw, const int pad_right, - const int ih, const int iw, const int iw2, - const int pad_top, const int pad_bottom, - const int ic, const int ic_stride) { +void pack_src_fp32_nchw44<1>( + float* sptr_base, const float* sptr_origin, const int, const int pw, + const int pad_right, const int ih, const int iw, const int iw2, + const int pad_top, const int pad_bottom, const int ic, const int ic_stride) { constexpr int ic_step = 4; rep_step(ic_idx, ic, ic_step) { const float* sptr = sptr_origin + ic_idx * ic_stride; @@ -45,10 +44,9 @@ void pack_src_fp32_nchw44<1>(float* sptr_base, const float* sptr_origin, namespace { -static inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr, - const int odd_start, - const int src_idx, - const int iw_idx) { +static inline void odd_even_split_iw8_even( + float* sptr_base, const float* sptr, const int odd_start, const int src_idx, + const int iw_idx) { constexpr int ic_step = 4; const int src_offset = src_idx * ic_step; const int even_offset = iw_idx / 2 * ic_step; @@ -72,9 +70,9 @@ static inline void odd_even_split_iw8_even(float* sptr_base, const float* sptr, vst1q_f32(sptr_base + odd_offset + 3 * ic_step, temp[7]); } -static inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr, - const int odd_start, - const int src_idx, const int iw_idx) { +static inline void odd_even_split_iw8_odd( + float* sptr_base, const float* sptr, const int odd_start, const int src_idx, + const int iw_idx) { constexpr int ic_step = 4; const int src_offset = src_idx * ic_step; const int even_offset = (iw_idx + 1) / 2 * ic_step; @@ -100,11 +98,10 @@ static inline void odd_even_split_iw8_odd(float* sptr_base, const float* sptr, } // namespace template <> -void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin, - const int ph, const int pw, const int pad_right, - const int ih, const int iw, const int iw2, - const int pad_top, const int pad_bottom, - const int ic, const int ic_stride) { +void pack_src_fp32_nchw44<2>( + float* sptr_base, const float* sptr_origin, const int ph, const int pw, + const int pad_right, const int ih, const int iw, const int iw2, + const int pad_top, 
const int pad_bottom, const int ic, const int ic_stride) { constexpr int ic_step = 4; int odd_start = megdnn::div_ceil(iw2, 2); float32x4_t zero_v = vdupq_n_f32(0.f); @@ -120,32 +117,32 @@ void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin, if (iw_idx % 2 == 0) { vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, - zero_v); + vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); } ++iw_idx; } int src_idx = 0; if (even_start) { for (; src_idx + 7 < iw; src_idx += 8) { - odd_even_split_iw8_even(sptr_base, sptr, odd_start, src_idx, - iw_idx); + odd_even_split_iw8_even( + sptr_base, sptr, odd_start, src_idx, iw_idx); iw_idx += 8; } } else { for (; src_idx + 7 < iw; src_idx += 8) { - odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx, - iw_idx); + odd_even_split_iw8_odd(sptr_base, sptr, odd_start, src_idx, iw_idx); iw_idx += 8; } } for (; src_idx < iw; ++src_idx) { if (iw_idx % 2 == 0) { - vst1q_f32(sptr_base + iw_idx / 2 * ic_step, - vld1q_f32(sptr + src_idx * ic_step)); + vst1q_f32( + sptr_base + iw_idx / 2 * ic_step, + vld1q_f32(sptr + src_idx * ic_step)); } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, - vld1q_f32(sptr + src_idx * ic_step)); + vst1q_f32( + sptr_base + (odd_start + iw_idx / 2) * ic_step, + vld1q_f32(sptr + src_idx * ic_step)); } ++iw_idx; } @@ -153,8 +150,7 @@ void pack_src_fp32_nchw44<2>(float* sptr_base, const float* sptr_origin, if (iw_idx % 2 == 0) { vst1q_f32(sptr_base + iw_idx / 2 * ic_step, zero_v); } else { - vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, - zero_v); + vst1q_f32(sptr_base + (odd_start + iw_idx / 2) * ic_step, zero_v); } ++iw_idx; } diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h index 0ddce83a..bf9418b4 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h @@ -24,50 +24,52 @@ using namespace megdnn; using namespace arm_common; namespace { -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, typename T, + typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} }; -#define cb2(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % ow_block], lane); \ - c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ - src[(step + src_idx) % ow_block], lane); - -#define cb(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % ow_block], lane); - -#define SHIFT_CAL_HELPER(ow_block, remain_w) \ - template \ - struct ShiftCalHelper { \ - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ - UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ - } \ - }; \ - template \ - struct ShiftCalHelper { \ - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ - 
UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ - } \ +#define cb2(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32( \ + c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ + c[1][step] = vfmaq_laneq_f32( \ + c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); + +#define cb(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32( \ + c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); + +#define SHIFT_CAL_HELPER(ow_block, remain_w) \ + template < \ + int src_idx, int weight_idx, typename T, typename T2, typename T3, \ + typename T4> \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ + } \ + }; \ + template < \ + int src_idx, int weight_idx, typename T, typename T2, typename T3, \ + typename T4> \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ + } \ }; SHIFT_CAL_HELPER(8, 1); @@ -88,11 +90,12 @@ SHIFT_CAL_HELPER(4, 4); #undef cb #undef cb2 -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, typename T, + typename T2, typename T3> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper:: + impl(c, src, weight); }; template struct OCHelper { @@ -116,21 +119,22 @@ public: /** * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel * */ -template +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, + int ow_block> struct KerNeonXXs1Nchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op); + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op); }; -template +template struct KerNeonXXs1Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int ic_step = 4; constexpr int filter_size = 2; constexpr int oc_step = 4; @@ -151,8 +155,7 @@ struct KerNeonXXs1Nchw44FP32 { for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { float32x4_t src[ow_block]; float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, - 0); + load_helper(src, src_ptr, 0); load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -164,18 +167,16 @@ struct KerNeonXXs1Nchw44FP32 { weight_ptr += ld_weight_fh; } } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template 
+template struct KerNeonXXs1Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int ic_step = 4; constexpr int filter_size = 3; constexpr int oc_step = 4; @@ -196,8 +197,7 @@ struct KerNeonXXs1Nchw44FP32 { for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { float32x4_t src[ow_block]; float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, - 0); + load_helper(src, src_ptr, 0); load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -213,17 +213,15 @@ struct KerNeonXXs1Nchw44FP32 { weight_ptr += ld_weight_fh; } } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template +template struct KerNeonXXs1Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int ic_step = 4; constexpr int filter_size = 5; constexpr int oc_step = 4; @@ -244,8 +242,7 @@ struct KerNeonXXs1Nchw44FP32 { for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { float32x4_t src[ow_block]; float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, - 0); + load_helper(src, src_ptr, 0); load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -273,18 +270,16 @@ struct KerNeonXXs1Nchw44FP32 { weight_ptr += ld_weight_fh; } } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template +template struct KerNeonXXs1Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int ic_step = 4; constexpr int filter_size = 7; constexpr int oc_step = 4; @@ -305,8 +300,7 @@ struct KerNeonXXs1Nchw44FP32 { for (int fh_idx = 0; fh_idx < filter_size; ++fh_idx) { float32x4_t src[ow_block]; float32x4_t weight[c_dim][ic_step]; - load_helper(src, src_ptr, - 0); + load_helper(src, src_ptr, 0); load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -344,21 +338,17 @@ struct KerNeonXXs1Nchw44FP32 { weight_ptr += ld_weight_fh; } } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; } // namespace template -void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, - const float* bias, float*, float* dst, - const int oc, const int ic, - const int ih, const int iw, - const int oh, const int oh_block, - const int ow, const Op& op, const int, - const int) { +void conv_bias::conv_direct_fp32_nchw44( + const float* src, 
const float* filter, const float* bias, float*, float* dst, + const int oc, const int ic, const int ih, const int iw, const int oh, + const int oh_block, const int ow, const Op& op, const int, const int) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -383,20 +373,18 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, using remain_fun = std::function; + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs1Nchw44FP32::impl; \ - kern_small_oc_remain = \ - KerNeonXXs1Nchw44FP32::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonXXs1Nchw44FP32< \ + bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ + kern_small_oc_remain = KerNeonXXs1Nchw44FP32< \ + bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -414,13 +402,11 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - KerNeonXXs1Nchw44FP32::impl(src + src_offset, - filter + weight_offset, - bias + bias_offset, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + KerNeonXXs1Nchw44FP32< + bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: + impl(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); } if (ow_remain > 0) { const int src_offset = @@ -430,9 +416,9 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + bias_offset, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + bias_offset, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -448,13 +434,11 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - KerNeonXXs1Nchw44FP32::impl(src + src_offset, - filter + weight_offset, - bias + bias_offset, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + KerNeonXXs1Nchw44FP32< + bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: + impl(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); } if (ow_remain > 0) { const int src_offset = @@ -464,21 +448,20 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? 
dst_offset : oc_idx; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + bias_offset, dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + bias_offset, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } } -#define INSTANTIATION(filter_size, bias_mode, Op) \ - template void \ - conv_bias::conv_direct_fp32_nchw44( \ - const float* src, const float* filter, const float* bias, float*, \ - float* dst, const int oc, const int ic, const int ih, \ - const int iw, const int oh, const int oh_block, const int ow, \ - const Op& op, const int, const int); +#define INSTANTIATION(filter_size, bias_mode, Op) \ + template void conv_bias::conv_direct_fp32_nchw44( \ + const float* src, const float* filter, const float* bias, float*, \ + float* dst, const int oc, const int ic, const int ih, const int iw, \ + const int oh, const int oh_block, const int ow, const Op& op, const int, \ + const int); #define FOR_OP(filter_size, bias) \ INSTANTIATION(filter_size, bias, NoneOp) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h index fd7b47ad..b31cd438 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h @@ -24,50 +24,52 @@ using namespace megdnn; using namespace arm_common; namespace { -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, typename T, + typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} }; -#define cb2(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % ow_block], lane); \ - c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ - src[(step + src_idx) % ow_block], lane); - -#define cb(step, lane, ow_block) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % ow_block], lane); - -#define SHIFT_CAL_HELPER(ow_block, remain_w) \ - template \ - struct ShiftCalHelper { \ - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ - UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ - } \ - }; \ - template \ - struct ShiftCalHelper { \ - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ - UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ - UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ - } \ +#define cb2(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32( \ + c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); \ + c[1][step] = vfmaq_laneq_f32( \ + c[1][step], weight[1][lane], src[(step + src_idx) % ow_block], lane); + +#define cb(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32( \ + c[0][step], weight[0][lane], src[(step + src_idx) % ow_block], lane); + +#define SHIFT_CAL_HELPER(ow_block, remain_w) \ + template < \ + int src_idx, int 
weight_idx, typename T, typename T2, typename T3, \ + typename T4> \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ + } \ + }; \ + template < \ + int src_idx, int weight_idx, typename T, typename T2, typename T3, \ + typename T4> \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ + } \ }; SHIFT_CAL_HELPER(8, 1); @@ -88,11 +90,12 @@ SHIFT_CAL_HELPER(4, 4); #undef cb #undef cb2 -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, int remain_w, typename T, + typename T2, typename T3> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper:: + impl(c, src, weight); }; template @@ -116,22 +119,22 @@ public: /** * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel * */ -template +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, + int ow_block> struct KerNeonXXs2Nchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op, - const float32_t* src_ptr_odd); + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op, const float32_t* src_ptr_odd); }; -template +template struct KerNeonXXs2Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op, const float32_t* src_ptr_odd_origin) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op, const float32_t* src_ptr_odd_origin) { constexpr int loop_ic_step = 4; constexpr int filter_size = 2; constexpr int oc_step = 4; @@ -155,12 +158,11 @@ struct KerNeonXXs2Nchw44FP32 { float32x4_t weight[c_dim][4]; /////////row 0///////////// load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, - 0); + load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -169,12 +171,11 @@ struct KerNeonXXs2Nchw44FP32 { weight_ptr += ld_weight_fh; /////////row 1///////////// load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, - 0); + load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, 
weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -182,18 +183,16 @@ struct KerNeonXXs2Nchw44FP32 { src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template +template struct KerNeonXXs2Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op, const float32_t* src_ptr_odd_origin) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op, const float32_t* src_ptr_odd_origin) { constexpr int loop_ic_step = 4; constexpr int filter_size = 3; constexpr int oc_step = 4; @@ -216,8 +215,8 @@ struct KerNeonXXs2Nchw44FP32 { float32x4_t weight[c_dim][4]; /////////row 0///////////// load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); @@ -225,8 +224,7 @@ struct KerNeonXXs2Nchw44FP32 { weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, - 0); + load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -235,16 +233,15 @@ struct KerNeonXXs2Nchw44FP32 { weight_ptr += ld_weight_fh; /////////row 1///////////// load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, - 0); + load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -253,8 +250,8 @@ struct KerNeonXXs2Nchw44FP32 { weight_ptr += ld_weight_fh; //////////row 2///////////// load_helper(src, src_ptr, 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); @@ -262,8 +259,7 @@ struct KerNeonXXs2Nchw44FP32 { weight, weight_ptr, ld_weight_oc); cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); - load_helper(src, src_ptr_odd, - 0); + load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -271,18 +267,16 @@ struct KerNeonXXs2Nchw44FP32 { src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template 
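Editor's note: the doc comment just above ("src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9") is the key to the stride-2 kernels in this hunk: each input row is repacked so that all even columns come first and all odd columns after, which lets every filter tap read unit-stride data. A minimal scalar sketch of that idea, under assumed names (pack_even_odd and conv1d_s2_packed are illustrative only, not the library's packing routine):

#include <cstddef>
#include <vector>

// Pack one input row of width iw as [even columns..., odd columns...].
static std::vector<float> pack_even_odd(const float* row, size_t iw) {
    std::vector<float> out(iw);
    const size_t half = (iw + 1) / 2;  // number of even-indexed columns
    for (size_t i = 0; i < iw; ++i)
        out[(i % 2) ? half + i / 2 : i / 2] = row[i];
    return out;
}

// Stride-2 1D convolution on the packed row: even taps walk the even half,
// odd taps walk the odd half, both with unit stride
// (dst[o] = sum_k filter[k] * row[2 * o + k]).
static void conv1d_s2_packed(
        const float* packed, size_t half, const float* filter, size_t fw, float* dst,
        size_t ow) {
    for (size_t o = 0; o < ow; ++o) {
        float acc = 0.f;
        for (size_t k = 0; k < fw; ++k) {
            const float* base = (k % 2) ? packed + half : packed;
            acc += filter[k] * base[o + k / 2];
        }
        dst[o] = acc;
    }
}

In the NEON kernels below, the same split shows up as the separate src_ptr / src_ptr_odd arguments: the "even element" block consumes the first half, the "odd element" block the second, and both can use plain vld1q_f32 loads instead of strided gathers.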
+template struct KerNeonXXs2Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op, const float32_t* src_ptr_odd_origin) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op, const float32_t* src_ptr_odd_origin) { constexpr int loop_ic_step = 4; constexpr int filter_size = 5; constexpr int oc_step = 4; @@ -306,10 +300,9 @@ struct KerNeonXXs2Nchw44FP32 { float32x4_t src[ow_block]; float32x4_t weight[c_dim][4]; // even element - load_helper(src, src_ptr, - 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( @@ -320,8 +313,7 @@ struct KerNeonXXs2Nchw44FP32 { weight, weight_ptr, ld_weight_oc); cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); // odd element - load_helper( - src, src_ptr_odd, 0); + load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -335,8 +327,7 @@ struct KerNeonXXs2Nchw44FP32 { weight_ptr += ld_weight_fh; } } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; @@ -345,13 +336,12 @@ struct KerNeonXXs2Nchw44FP32 { * kernel[6], kernel[1], kernel[3], kernel[5] * src is packed like 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9 **/ -template +template struct KerNeonXXs2Nchw44FP32 { - static void impl(const float32_t* src_ptr_origin, - const float32_t* weight_ptr, const float32_t* bias_ptr, - float32_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, - const Op& op, const float32_t* src_ptr_odd_origin) { + static void impl( + const float32_t* src_ptr_origin, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op, const float32_t* src_ptr_odd_origin) { constexpr int loop_ic_step = 4; constexpr int filter_size = 7; constexpr int oc_step = 4; @@ -375,10 +365,9 @@ struct KerNeonXXs2Nchw44FP32 { float32x4_t src[ow_block]; float32x4_t weight[c_dim][4]; // even element - load_helper(src, src_ptr, - 0); - load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper(src, src_ptr, 0); + load_helper<4, 0, oc_step, c_dim, Vld1q_f32>( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( @@ -393,8 +382,7 @@ struct KerNeonXXs2Nchw44FP32 { weight, weight_ptr, ld_weight_oc); cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); // odd element - load_helper( - src, src_ptr_odd, 0); + load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); @@ -412,21 +400,17 @@ struct KerNeonXXs2Nchw44FP32 { weight_ptr += ld_weight_fh; } } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + 
store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; } // namespace template -void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, - const float* bias, float*, float* dst, - const int oc, const int ic, - const int ih, const int iw, - const int oh, const int oh_block, - const int ow, const Op& op, const int, - const int) { +void conv_bias::conv_direct_fp32_nchw44( + const float* src, const float* filter, const float* bias, float*, float* dst, + const int oc, const int ic, const int ih, const int iw, const int oh, + const int oh_block, const int ow, const Op& op, const int, const int) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -452,21 +436,18 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, using remain_fun = std::function; + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op, const float32_t* src_ptr_odd_origin)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2Nchw44FP32::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2Nchw44FP32::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonXXs2Nchw44FP32< \ + bias_mode, Op, step, filter_size, big_oc_step, ow_step>::impl; \ + kern_small_oc_remain = KerNeonXXs2Nchw44FP32< \ + bias_mode, Op, step, filter_size, oc_step, ow_step>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -477,41 +458,39 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, const int weight_offset = oc_idx * ic * fh * fw; for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = (oh_idx * stride_h * iw + - ow_idx / 2 * stride_w * ih_step) * - ic_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_idx / 2 * stride_w * ih_step) * + ic_step; const int src_offset_odd = - (oh_idx * stride_h * iw + - ow_idx / 2 * stride_w * ih_step + odd_start) * + (oh_idx * stride_h * iw + ow_idx / 2 * stride_w * ih_step + + odd_start) * ic_step; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - KerNeonXXs2Nchw44FP32::impl(src + src_offset, - filter + weight_offset, - bias + bias_offset, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op, - src + src_offset_odd); + KerNeonXXs2Nchw44FP32< + bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step>:: + impl(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op, src + src_offset_odd); } if (ow_remain > 0) { - const int src_offset = (oh_idx * stride_h * iw + - ow_end / 2 * stride_w * ih_step) * - ic_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_end / 2 * stride_w * ih_step) * + ic_step; const int src_offset_odd = - (oh_idx * stride_h * iw + - ow_end / 2 * stride_w * ih_step + odd_start) * + (oh_idx * stride_h * iw + ow_end / 2 * stride_w * ih_step + + odd_start) * ic_step; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? 
dst_offset : oc_idx; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + bias_offset, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op, src + src_offset_odd); + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + bias_offset, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op, + src + src_offset_odd); } } } @@ -520,54 +499,50 @@ void conv_bias::conv_direct_fp32_nchw44(const float* src, const float* filter, const int weight_offset = oc_idx * ic * fh * fw; for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = (oh_idx * stride_h * iw + - ow_idx / 2 * stride_w * ih_step) * - ic_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_idx / 2 * stride_w * ih_step) * + ic_step; const int src_offset_odd = - (oh_idx * stride_h * iw + - ow_idx / 2 * stride_w * ih_step + odd_start) * + (oh_idx * stride_h * iw + ow_idx / 2 * stride_w * ih_step + + odd_start) * ic_step; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? dst_offset : oc_idx; - KerNeonXXs2Nchw44FP32::impl(src + src_offset, - filter + weight_offset, - bias + bias_offset, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op, - src + src_offset_odd); + KerNeonXXs2Nchw44FP32< + bias_mode, Op, ow_step, filter_size, oc_step, ow_step>:: + impl(src + src_offset, filter + weight_offset, + bias + bias_offset, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op, src + src_offset_odd); } if (ow_remain > 0) { - const int src_offset = (oh_idx * stride_h * iw + - ow_end / 2 * stride_w * ih_step) * - ic_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_end / 2 * stride_w * ih_step) * + ic_step; const int src_offset_odd = - (oh_idx * stride_h * iw + - ow_end / 2 * stride_w * ih_step + odd_start) * + (oh_idx * stride_h * iw + ow_end / 2 * stride_w * ih_step + + odd_start) * ic_step; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; const int bias_offset = bias_mode == BiasMode::BIAS ? 
dst_offset : oc_idx; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + bias_offset, dst + dst_offset, ic, - ih, iw, ld_dst_oc, op, - src + src_offset_odd); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + bias_offset, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op, + src + src_offset_odd); } } } } -#define INSTANTIATION(filter_size, bias_mode, Op) \ - template void \ - conv_bias::conv_direct_fp32_nchw44( \ - const float* src, const float* filter, const float* bias, float*, \ - float* dst, const int oc, const int ic, const int ih, \ - const int iw, const int oh, const int oh_block, const int ow, \ - const Op& op, const int, const int); +#define INSTANTIATION(filter_size, bias_mode, Op) \ + template void conv_bias::conv_direct_fp32_nchw44( \ + const float* src, const float* filter, const float* bias, float*, \ + float* dst, const int oc, const int ic, const int ih, const int iw, \ + const int oh, const int oh_block, const int ow, const Op& op, const int, \ + const int); #define FOR_OP(filter_size, bias) \ INSTANTIATION(filter_size, bias, NoneOp) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h index 8090c5e5..36daabc8 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h @@ -23,7 +23,6 @@ #include "src/armv7/matrix_mul/asm/common.h" #endif - using namespace megdnn; using namespace arm_common; @@ -36,47 +35,49 @@ namespace { *\tparam T2 is type of src regs *\tparam T3 is type of weight regs */ -template +template < + int src_idx, int weight_idx, int c_dim, int stride, int remain_w, typename T, + typename T2, typename T3> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template +template < + int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, + typename T3> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} }; -#define cb(step) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ - src[(step * stride + src_idx) / 4], \ - (step * stride + src_idx) % 4); \ - c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][weight_idx], \ - src[(step * stride + src_idx) / 4], \ - (step * stride + src_idx) % 4); - -#define cb2(step) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ - src[(step * stride + src_idx) / 4], \ - (step * stride + src_idx) % 4); - -#define SHIFT_CAL_HELPER(ow_remain) \ - template \ - struct ShiftCalHelper { \ - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ - UNROLL_CALL_RAW(ow_remain, cb); \ - } \ - }; \ - template \ - struct ShiftCalHelper { \ - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ - UNROLL_CALL_RAW(ow_remain, cb2); \ - } \ +#define cb(step) \ + c[0][step] = vfmaq_laneq_f32( \ + c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ + (step * stride + src_idx) % 4); \ + c[1][step] = vfmaq_laneq_f32( \ + c[1][step], weight[1][weight_idx], src[(step * stride + src_idx) / 4], \ + (step * stride + src_idx) % 4); + +#define cb2(step) \ + c[0][step] = vfmaq_laneq_f32( \ + c[0][step], weight[0][weight_idx], src[(step * stride + src_idx) / 4], \ + (step * stride + src_idx) % 4); + +#define SHIFT_CAL_HELPER(ow_remain) \ + template < \ + int src_idx, int weight_idx, int 
stride, typename T, typename T2, \ + typename T3> \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(ow_remain, cb); \ + } \ + }; \ + template < \ + int src_idx, int weight_idx, int stride, typename T, typename T2, \ + typename T3> \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(ow_remain, cb2); \ + } \ }; SHIFT_CAL_HELPER(1) @@ -92,11 +93,12 @@ SHIFT_CAL_HELPER(8) #undef cb #undef cb2 -template +template < + int src_idx, int weight_idx, int c_dim, int stride, int remain_w, typename T, + typename T2, typename T3> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper::impl( + c, src, weight); }; enum CpuTag { DEFAULT_CPU_TAG = 0, @@ -122,28 +124,30 @@ public: /** * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel **/ -template +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, + int stride, int ow_block, int tag = CpuTag::DEFAULT_CPU_TAG> struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op); + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op); }; -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { +template < + BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, + int ow_block> +struct KerNeonXXs2NchwNchw44FP32< + bias_mode, Op, remain_w, 7, oc_block, stride, ow_block> { + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int loop_ic_step = 1; constexpr int filter_size = 7; constexpr int oc_step = 4; constexpr int simd_len = 4; constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; + (ow_block * stride + filter_size - stride + simd_len - 1) / simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -157,17 +161,16 @@ struct KerNeonXXs2NchwNchw44FP32( \ - src, src_ptr + step * iw, 0); \ - load_helper( \ - weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<5, 5, c_dim, stride, remain_w>(c, src, weight); \ +#define KERNEL_CB(step) \ + load_helper(src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); \ + 
cal_helper<5, 5, c_dim, stride, remain_w>(c, src, weight); \ cal_helper<6, 6, c_dim, stride, remain_w>(c, src, weight); UNROLL_CALL_RAW(7, KERNEL_CB) @@ -176,25 +179,25 @@ struct KerNeonXXs2NchwNchw44FP32(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { +template < + BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, + int ow_block> +struct KerNeonXXs2NchwNchw44FP32< + bias_mode, Op, remain_w, 5, oc_block, stride, ow_block> { + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int loop_ic_step = 1; constexpr int filter_size = 5; constexpr int oc_step = 4; constexpr int simd_len = 4; constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; + (ow_block * stride + filter_size - stride + simd_len - 1) / simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -208,15 +211,14 @@ struct KerNeonXXs2NchwNchw44FP32( \ - src, src_ptr + step * iw, 0); \ - load_helper( \ - weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ - cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ +#define KERNEL_CB(step) \ + load_helper(src, src_ptr + step * iw, 0); \ + load_helper( \ + weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); UNROLL_CALL_RAW(5, KERNEL_CB) #undef KERNEL_CB @@ -224,25 +226,25 @@ struct KerNeonXXs2NchwNchw44FP32(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { +template < + BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, + int ow_block> +struct KerNeonXXs2NchwNchw44FP32< + bias_mode, Op, remain_w, 3, oc_block, stride, ow_block> { + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int loop_ic_step = 1; constexpr int filter_size = 3; constexpr int oc_step = 4; constexpr int simd_len = 4; constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; + (ow_block * stride + filter_size - stride + simd_len - 1) / simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -256,8 +258,7 @@ struct KerNeonXXs2NchwNchw44FP32(src, src_ptr, - 0); + load_helper(src, src_ptr, 0); load_helper( weight, weight_ptr, 
ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); @@ -265,8 +266,7 @@ struct KerNeonXXs2NchwNchw44FP32(c, src, weight); // row 1 - load_helper( - src, src_ptr + iw, 0); + load_helper(src, src_ptr + iw, 0); load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); @@ -285,8 +285,7 @@ struct KerNeonXXs2NchwNchw44FP32(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; @@ -294,9 +293,10 @@ struct KerNeonXXs2NchwNchw44FP32 struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int oc_block = 4; constexpr int stride = 2; constexpr int remain_w = 8; @@ -306,8 +306,7 @@ struct KerNeonXXs2NchwNchw44FP32 { constexpr int oc_step = 4; constexpr int src_line_block = ow_block * stride + filter_size - stride; - const int iw_skip_bytes = - (iw - round_up(src_line_block, 2)) * sizeof(float); + const int iw_skip_bytes = (iw - round_up(src_line_block, 2)) * sizeof(float); const int ld_src_ic_skip_bytes = iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes; constexpr int c_dim = OCHelper::val; @@ -450,29 +449,27 @@ struct KerNeonXXs2NchwNchw44FP32 { "vmla.f32 %q[c7], q7, d8[0]\n" "6:\n" - : [c0] "+w"(c[0][0]), [c1] "+w"(c[0][1]), - [c2] "+w"(c[0][2]), [c3] "+w"(c[0][3]), - [c4] "+w"(c[0][4]), [c5] "+w"(c[0][5]), - [c6] "+w"(c[0][6]), [c7] "+w"(c[0][7]), - [src_ptr] "+r"(src_ptr), [weight_ptr] "+r"(weight_ptr) + : [c0] "+w"(c[0][0]), [c1] "+w"(c[0][1]), [c2] "+w"(c[0][2]), + [c3] "+w"(c[0][3]), [c4] "+w"(c[0][4]), [c5] "+w"(c[0][5]), + [c6] "+w"(c[0][6]), [c7] "+w"(c[0][7]), [src_ptr] "+r"(src_ptr), + [weight_ptr] "+r"(weight_ptr) : [ld_src_ic_skip_bytes] "r"(ld_src_ic_skip_bytes), [iw_skip_bytes] "r"(iw_skip_bytes) - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", - "d9", "d10", "d11", "d12", "d13", "d14", "d15", "r1", - "r2", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "r1", "r2", "cc", "memory"); } - store_ocx_ow8_remain_static(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { +struct KerNeonXXs2NchwNchw44FP32< + bias_mode, Op, 8, 3, 4, 2, 8, CpuTag::DEFAULT_CPU_TAG> { + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int oc_block = 4; constexpr int stride = 2; constexpr int remain_w = 8; @@ -482,8 +479,7 @@ struct KerNeonXXs2NchwNchw44FP32::val; @@ -612,38 +608,36 @@ struct KerNeonXXs2NchwNchw44FP32(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; #endif -template -struct KerNeonXXs2NchwNchw44FP32 { - static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, - const float32_t* bias_ptr, float32_t* dst_ptr, int ic, - int ih, int iw, int ld_dst_oc, const Op& op) { +template < + 
BiasMode bias_mode, typename Op, int remain_w, int oc_block, int stride, + int ow_block> +struct KerNeonXXs2NchwNchw44FP32< + bias_mode, Op, remain_w, 2, oc_block, stride, ow_block> { + static void impl( + const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op) { constexpr int loop_ic_step = 1; constexpr int filter_size = 2; constexpr int oc_step = 4; constexpr int simd_len = 4; constexpr int src_reg_size = - (ow_block * stride + filter_size - stride + simd_len - 1) / - simd_len; + (ow_block * stride + filter_size - stride + simd_len - 1) / simd_len; constexpr int ld_weight_fw = oc_step * filter_size; const int ld_weight_oc = oc_step * filter_size * filter_size * ic; @@ -657,16 +651,14 @@ struct KerNeonXXs2NchwNchw44FP32(src, src_ptr, - 0); + load_helper(src, src_ptr, 0); load_helper( weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); // row 1 - load_helper( - src, src_ptr + iw, 0); + load_helper(src, src_ptr + iw, 0); load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); @@ -675,18 +667,17 @@ struct KerNeonXXs2NchwNchw44FP32(c, op, dst_ptr, - ld_dst_oc); + store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc); } }; template struct ConvDirectFp32NchwNchw44 { static MEGDNN_ALWAYS_INLINE void impl( - const float32_t* src, const float32_t* filter, - const float32_t* bias, float32_t*, float32_t* dst, const int oc, - const int ic, const int ih, const int iw, const int oh, - const int oh_block, const int ow, const Op& op) { + const float32_t* src, const float32_t* filter, const float32_t* bias, + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, + const Op& op) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 1; @@ -712,20 +703,18 @@ struct ConvDirectFp32NchwNchw44 { using remain_fun = std::function; + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44FP32::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2NchwNchw44FP32::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ + bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ + kern_small_oc_remain = KerNeonXXs2NchwNchw44FP32< \ + bias_mode, Op, step, filter_size, oc_step, stride, ow_step>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -737,28 +726,27 @@ struct ConvDirectFp32NchwNchw44 { const int weight_offset = oc_idx * ic * fh * fw; for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, ow_step, filter_size, big_oc_step, - stride, ow_step>::impl(src + src_offset, - 
filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + bias_mode, Op, ow_step, filter_size, big_oc_step, stride, + ow_step>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); } if (ow_remain > 0) { - const int src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -767,29 +755,27 @@ struct ConvDirectFp32NchwNchw44 { const int weight_offset = oc_idx * ic * fh * fw; for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, ow_step, filter_size, oc_step, - stride, ow_step>::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + bias_mode, Op, ow_step, filter_size, oc_step, stride, + ow_step>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); } if (ow_remain > 0) { - const int src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, - filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -800,10 +786,10 @@ struct ConvDirectFp32NchwNchw44 { template struct ConvDirectFp32NchwNchw44 { static MEGDNN_ALWAYS_INLINE void impl( - const float32_t* src, const float32_t* filter, - const float32_t* bias, float32_t*, float32_t* dst, const int oc, - const int ic, const int ih, const int iw, const int oh, - const int oh_block, const int ow, const Op& op) { + const float32_t* src, const float32_t* filter, const float32_t* bias, + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, + const Op& op) { constexpr int filter_size = 3; constexpr int stride = 2; constexpr int fh = filter_size; @@ -826,16 +812,15 @@ struct ConvDirectFp32NchwNchw44 { using remain_fun = std::function; + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op)>; remain_fun kern_big_oc_remain = nullptr; switch 
(ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44FP32::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonXXs2NchwNchw44FP32< \ + bias_mode, Op, step, filter_size, big_oc_step, stride, ow_step>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -844,10 +829,9 @@ struct ConvDirectFp32NchwNchw44 { } #undef cb #if MGB_ENABLE_CPUINFO - auto arch_tag = - cpuinfo_get_current_core()->uarch == cpuinfo_uarch_cortex_a7 - ? CpuTag::A7_TAG - : CpuTag::DEFAULT_CPU_TAG; + auto arch_tag = cpuinfo_get_current_core()->uarch == cpuinfo_uarch_cortex_a7 + ? CpuTag::A7_TAG + : CpuTag::DEFAULT_CPU_TAG; #else auto arch_tag = CpuTag::A7_TAG; #endif @@ -860,30 +844,27 @@ struct ConvDirectFp32NchwNchw44 { const int weight_offset = oc_idx * ic * fh * fw; for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, ow_step, filter_size, - big_oc_step, stride, ow_step, - CpuTag::A7_TAG>::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + bias_mode, Op, ow_step, filter_size, big_oc_step, + stride, ow_step, CpuTag::A7_TAG>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); } if (ow_remain > 0) { - const int src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, - filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -892,29 +873,27 @@ struct ConvDirectFp32NchwNchw44 { const int weight_offset = oc_idx * ic * fh * fw; for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) { for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const int src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const int src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; KerNeonXXs2NchwNchw44FP32< - bias_mode, Op, ow_step, filter_size, - big_oc_step, stride, - ow_step>::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, dst + dst_offset, - ic, ih, iw, ld_dst_oc, op); + bias_mode, Op, ow_step, filter_size, big_oc_step, + stride, ow_step>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, + ld_dst_oc, op); } if (ow_remain > 0) { - const int src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; 
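Editor's note: the dispatch bodies in this file all follow the same pattern: the kernel instantiated with ow_step handles complete output-width tiles, and a switch over the runtime ow_remain (the cb()/UNROLL_CALL_RAW(8, cb) macro above) binds the matching template instantiation once so the tail of every row reuses it without re-dispatching. A minimal sketch of that pattern with hypothetical names (tail_kernel / select_tail_kernel; the real code also keys on bias_mode, Op, filter_size, the oc block and, for the 3x3 stride-2 path, on the cpuinfo-detected micro-architecture):

#include <functional>

// A templated tail kernel: handles the last remain_w output columns of a row.
template <int remain_w>
static void tail_kernel(const float* src, float* dst) {
    for (int i = 0; i < remain_w; ++i)
        dst[i] = src[i];  // stand-in for the real convolution tail
}

using remain_fun = std::function<void(const float*, float*)>;

// Bind the right instantiation once per call, outside the oh/ow hot loops.
static remain_fun select_tail_kernel(int ow_remain) {
    switch (ow_remain) {
#define cb(step) case step: return tail_kernel<step>;
        cb(1) cb(2) cb(3) cb(4) cb(5) cb(6) cb(7)
#undef cb
        default:
            return nullptr;  // ow_remain == 0: the ow_step loop already covers the row
    }
}

Keeping the switch outside the per-row loops is why kern_big_oc_remain / kern_small_oc_remain are bound up front as callables rather than branched on inside the inner loop.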
- const int dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, - filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + const int src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const int dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -930,21 +909,19 @@ template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44( const float32_t* src, const float32_t* filter, const float32_t* bias, float32_t*, float32_t* dst, const int oc, const int ic, const int ih, - const int iw, const int oh, const int oh_block, const int ow, - const Op& op, const int, const int) { + const int iw, const int oh, const int oh_block, const int ow, const Op& op, + const int, const int) { ConvDirectFp32NchwNchw44::impl( - src, filter, bias, nullptr, dst, oc, ic, ih, iw, oh, oh_block, ow, - op); + src, filter, bias, nullptr, dst, oc, ic, ih, iw, oh, oh_block, ow, op); } -#define INSTANTIATION(stride, filter_size, bias_mode, Op) \ - template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44< \ - bias_mode, Op, filter_size, stride>( \ - const float32_t* src, const float32_t* filter, \ - const float32_t* bias, float32_t*, float32_t* dst, const int oc, \ - const int ic, const int ih, const int iw, const int oh, \ - const int oh_block, const int ow, const Op& op, const int, \ - const int); +#define INSTANTIATION(stride, filter_size, bias_mode, Op) \ + template void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44< \ + bias_mode, Op, filter_size, stride>( \ + const float32_t* src, const float32_t* filter, const float32_t* bias, \ + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op, const int, const int); #define FOR_OP(stride, filter, bias) \ INSTANTIATION(stride, filter, bias, NoneOp) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp index 57347971..3f822c22 100644 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp @@ -28,9 +28,9 @@ using namespace conv_stride1; using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; -void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { +void conv_stride1::do_conv_2x2_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; //! 
unroll of 2 size_t ic = 0; @@ -65,23 +65,19 @@ void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, MEGDNN_SIMD_TYPE _sum = MEGDNN_SIMD_LOADU(outptr); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, - MEGDNN_SIMD_GET_LOW(_k0), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, - MEGDNN_SIMD_GET_LOW(_k0), 1); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r010, - MEGDNN_SIMD_GET_HIGH(_k0), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r011, - MEGDNN_SIMD_GET_HIGH(_k0), 1); - - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, - MEGDNN_SIMD_GET_LOW(_k1), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, - MEGDNN_SIMD_GET_LOW(_k1), 1); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r110, - MEGDNN_SIMD_GET_HIGH(_k1), 0); - _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r111, - MEGDNN_SIMD_GET_HIGH(_k1), 1); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r000, MEGDNN_SIMD_GET_LOW(_k0), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r001, MEGDNN_SIMD_GET_LOW(_k0), 1); + _sum = MEGDNN_SIMD_VMLAQ_LANE( + _sum, _r010, MEGDNN_SIMD_GET_HIGH(_k0), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE( + _sum, _r011, MEGDNN_SIMD_GET_HIGH(_k0), 1); + + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r100, MEGDNN_SIMD_GET_LOW(_k1), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE(_sum, _r101, MEGDNN_SIMD_GET_LOW(_k1), 1); + _sum = MEGDNN_SIMD_VMLAQ_LANE( + _sum, _r110, MEGDNN_SIMD_GET_HIGH(_k1), 0); + _sum = MEGDNN_SIMD_VMLAQ_LANE( + _sum, _r111, MEGDNN_SIMD_GET_HIGH(_k1), 1); MEGDNN_SIMD_STOREU(outptr, _sum); @@ -143,9 +139,9 @@ void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, } } -void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { +void conv_stride1::do_conv_3x3_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { @@ -290,9 +286,9 @@ void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, } } -void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { +void conv_stride1::do_conv_5x5_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { @@ -530,9 +526,9 @@ void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, } } -void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, - float* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t IC) { +void conv_stride1::do_conv_7x7_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { @@ -564,20 +560,14 @@ void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, MEGDNN_SIMD_TYPE _k0123 = MEGDNN_SIMD_LOADU(k0); MEGDNN_SIMD_TYPE _k4567 = MEGDNN_SIMD_LOADU(k0 + 4); - MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 - MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 - MEGDNN_SIMD_TYPE _r00n = - MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 - MEGDNN_SIMD_TYPE _r01 = - MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 - MEGDNN_SIMD_TYPE _r02 = - MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 - MEGDNN_SIMD_TYPE _r03 = - MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 - MEGDNN_SIMD_TYPE _r05 = - MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 - MEGDNN_SIMD_TYPE _r06 = - MEGDNN_SIMD_EXT(_r04, _r00n, 
2); // 6 7 8 9 + MEGDNN_SIMD_TYPE _r00 = MEGDNN_SIMD_LOADU(r0); // 0 1 2 3 + MEGDNN_SIMD_TYPE _r04 = MEGDNN_SIMD_LOADU(r0 + 4); // 4 5 6 7 + MEGDNN_SIMD_TYPE _r00n = MEGDNN_SIMD_LOADU(r0 + 8); // 8 9 10 11 + MEGDNN_SIMD_TYPE _r01 = MEGDNN_SIMD_EXT(_r00, _r04, 1); // 1 2 3 4 + MEGDNN_SIMD_TYPE _r02 = MEGDNN_SIMD_EXT(_r00, _r04, 2); // 2 3 4 5 + MEGDNN_SIMD_TYPE _r03 = MEGDNN_SIMD_EXT(_r00, _r04, 3); // 3 4 5 6 + MEGDNN_SIMD_TYPE _r05 = MEGDNN_SIMD_EXT(_r04, _r00n, 1); // 5 6 7 8 + MEGDNN_SIMD_TYPE _r06 = MEGDNN_SIMD_EXT(_r04, _r00n, 2); // 6 7 8 9 _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r00, _k0123, 0); _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r01, _k0123, 1); diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h index 66de74fa..fadf8c32 100644 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.h @@ -17,18 +17,21 @@ namespace arm_common { namespace fp32 { namespace conv_stride1 { -void do_conv_2x2_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_3x3_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_5x5_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_7x7_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_2x2_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_3x3_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_5x5_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_7x7_stride1( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); } // namespace conv_stride1 } // namespace fp32 } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp index 5682a658..24856f54 100644 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.cpp @@ -13,8 +13,8 @@ #include "./do_conv_stride2.h" #include "midout.h" -#include "src/arm_common/simd_macro/neon_helper.h" #include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/arm_common/simd_macro/neon_helper.h" MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_convs2) @@ -26,10 +26,9 @@ using namespace conv_stride2; using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; - -void conv_stride2::do_conv_2x2_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride2::do_conv_2x2_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; rep(ic, IC) { @@ -79,9 +78,9 @@ void conv_stride2::do_conv_2x2_stride2(const float* src, const float* filter, fl } } -void conv_stride2::do_conv_3x3_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, 
size_t OW, - size_t IC) { +void conv_stride2::do_conv_3x3_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; rep(ic, IC) { @@ -157,9 +156,9 @@ void conv_stride2::do_conv_3x3_stride2(const float* src, const float* filter, fl } } -void conv_stride2::do_conv_5x5_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride2::do_conv_5x5_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; rep(ic, IC) { @@ -290,9 +289,9 @@ void conv_stride2::do_conv_5x5_stride2(const float* src, const float* filter, fl } } -void conv_stride2::do_conv_7x7_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride2::do_conv_7x7_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - 2 * OW + IW; rep(ic, IC) { diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h index acd53821..74c28586 100644 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride2.h @@ -16,14 +16,18 @@ namespace megdnn { namespace arm_common { namespace fp32 { namespace conv_stride2 { -void do_conv_2x2_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_3x3_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_5x5_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); -void do_conv_7x7_stride2(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC); +void do_conv_2x2_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_3x3_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_5x5_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); +void do_conv_7x7_stride2( + const float* src, const float* filter, float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC); } // namespace conv_stride2 } // namespace fp32 } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp index 3dc28844..f454da13 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp @@ -22,15 +22,14 @@ using namespace megdnn; using namespace arm_common; using conv_fun = std::function; + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range)>; MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw44_stride1) namespace { -static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, - const int iw2) { +static inline size_t get_perthread_cache_bytes( + const int ic, const int ih2, 
const int iw2) { // border_size is used to avoid read illegal memory int border_size = 64 * 2; return ic * ih2 * iw2 * sizeof(float) + border_size; @@ -50,11 +49,10 @@ static void get_rectified_size( oh2 = oh; ow2 = ow; - int block_oh = l2_block_helper(param.nr_threads, oh, - ic * iw * sizeof(float) * stride_h); + int block_oh = + l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * stride_h); ih2 = block_oh * stride_h + filter_h - stride_h; - iw2 = round_up(iw + 2 * static_cast(fm.padding[1]), - nr_elements_in_cacheline); + iw2 = round_up(iw + 2 * static_cast(fm.padding[1]), nr_elements_in_cacheline); } static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { @@ -68,10 +66,10 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { }; template -static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange&, const CpuNDRange&) { +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange&, + const CpuNDRange&) { const int oh = kern_param.osz[0]; const int ow = kern_param.osz[1]; const int fh = kern_param.filter_meta.spatial[0]; @@ -94,37 +92,33 @@ static void do_conv_kern(const WorkspaceBundle& bundle, const int group_id = ncb_index.ndrange_id[1]; constexpr int oc_idx = 0; int oc_block = oc; - int oh_block = l2_block_helper(kern_param.nr_threads, oh2, - ic * iw * sizeof(float) * stride_h); + int oh_block = l2_block_helper( + kern_param.nr_threads, oh2, ic * iw * sizeof(float) * stride_h); const int oh_idx = ncb_index.ndrange_id[2]; const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); const int ih_real = oh_block_real * stride_h + fh - stride_h; const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); const int src_bottom_pad = std::max( - (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, - 0); + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, 0); const int remain_right_pad = std::max(iw2 - iw - pw, 0); - const int src_offset = - std::max(oh_idx * oh_block * stride_h - ph, 0) * iw * pack_c; + const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw * pack_c; const float* origin_sptr = static_cast(kern_param.src( batch_id, group_id, 0, 1, 1)) + src_offset; const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); - float* sptr = reinterpret_cast((int8_t*)bundle.get(0) + - ncb_index.thread_id * src_size); + float* sptr = reinterpret_cast( + (int8_t*)bundle.get(0) + ncb_index.thread_id * src_size); conv_bias::pack_src_fp32_nchw44( sptr, origin_sptr, ph, pw, remain_right_pad, ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, src_bottom_pad, ic, ih * iw); - const float* fptr = - kern_param.filter(group_id) + oc_idx * fh * fw * ic; + const float* fptr = kern_param.filter(group_id) + oc_idx * fh * fw * ic; float_t* dst = kern_param.dst(batch_id, group_id) + oh_idx * oh_block * ow * pack_c; - const int bias_offset = bias_mode == BiasMode::BIAS - ? oh_idx * oh_block * ow * pack_c - : oc_idx; + const int bias_offset = + bias_mode == BiasMode::BIAS ? 
oh_idx * oh_block * ow * pack_c : oc_idx; const float* bptr = kern_param.bias(batch_id, group_id, oc_idx, 1, pack_c) + bias_offset; @@ -138,8 +132,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, } // namespace /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoF32DirectNCHW44::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { auto&& fm = param.filter_meta; auto fh = fm.spatial[0]; int oc = fm.ocpg; @@ -161,16 +155,16 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(const NCBKernSizeParam& param, size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride1, - midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_fp32_nchw44_stride1, + midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; const int batch = param.n; @@ -179,11 +173,12 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime -#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride1, \ - midout_iv(#filter #bias_mode #stride #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(filter, bias_mode, op, stride) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_fp32_nchw44_stride1, \ + midout_iv(#filter #bias_mode #stride #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); #define GET_STRIDE_PARAM(filter, bias_mode, op) \ @@ -268,17 +263,16 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns( int ic = param.filter_meta.icpg; int iw = param.isz[1]; int stride_h = param.filter_meta.stride[0]; - int oh_block = l2_block_helper(param.nr_threads, oh, - ic * iw * sizeof(float) * stride_h); - CpuNDRange ncb_range = {static_cast(batch), - static_cast(group), - static_cast(div_ceil(oh, oh_block))}; + int oh_block = + l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * stride_h); + CpuNDRange ncb_range = { + static_cast(batch), static_cast(group), + static_cast(div_ceil(oh, oh_block))}; auto do_conv = [wbundle, do_conv_fun, ncb_range]( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); + do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id, ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); return ret_kerns; diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h index 50e77262..c25fd850 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_kern.h @@ -17,17 +17,15 @@ namespace arm_common { namespace conv_bias { template -void conv_direct_fp32_nchw44(const float* src, const float* filter, - const float* bias, float*, float* dst, - const int oc, const int ic, const int ih, - const int iw, const int oh, const 
int oh_block, - const int ow, const Op& op, const int, const int); +void conv_direct_fp32_nchw44( + const float* src, const float* filter, const float* bias, float*, float* dst, + const int oc, const int ic, const int ih, const int iw, const int oh, + const int oh_block, const int ow, const Op& op, const int, const int); template -void pack_src_fp32_nchw44(float* sptr_base, const float* sptr_origin, const int, - const int pw, const int pad_right, const int ih, - const int iw, const int iw2, const int pad_top, - const int pad_bottom, const int ic, - const int ic_stride); +void pack_src_fp32_nchw44( + float* sptr_base, const float* sptr_origin, const int, const int pw, + const int pad_right, const int ih, const int iw, const int iw2, + const int pad_top, const int pad_bottom, const int ic, const int ic_stride); } // namespace conv_bias } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp index 53eba711..fbfa91fa 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_algo.cpp @@ -23,14 +23,13 @@ using namespace megdnn; using namespace arm_common; using conv_fun = std::function; + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range)>; MIDOUT_DECL(megdnn_arm_common_conv_bias_fp32_nchw_nchw44) namespace { -static inline int block_helper(const int nthread, const int amount, - const int per_unit_bytes) { +static inline int block_helper( + const int nthread, const int amount, const int per_unit_bytes) { MEGDNN_MARK_USED_VAR(per_unit_bytes); const int block_per_thread = div_ceil(amount, nthread); const int best_block = 16; @@ -43,8 +42,8 @@ static inline int block_helper(const int nthread, const int amount, int block = max_loss > min_loss ? 
min_block : max_block; return block; } -static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, - const int iw2) { +static inline size_t get_perthread_cache_bytes( + const int ic, const int ih2, const int iw2) { // border_size is used to avoid read illegal memory int border_size = 64 * 2; return ic * ih2 * iw2 * sizeof(float) + border_size; @@ -84,10 +83,9 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { return {nullptr, {src_size * param.nr_threads, weight_size}}; }; -static inline void copy_pad_src(float* sptr_base, const float* sptr_origin, - int ph, int pw, int pad_right, int ih, int iw, - int iw2, int pad_top, int pad_bottom, int ic, - int ic_stride) { +static inline void copy_pad_src( + float* sptr_base, const float* sptr_origin, int ph, int pw, int pad_right, + int ih, int iw, int iw2, int pad_top, int pad_bottom, int ic, int ic_stride) { MEGDNN_MARK_USED_VAR(ph); rep(ic_idx, ic) { const float* sptr = sptr_origin + ic_idx * ic_stride; @@ -106,9 +104,9 @@ static inline void copy_pad_src(float* sptr_base, const float* sptr_origin, sptr_base += iw2 * pad_bottom; } } -static void pack_weight(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index) { +static void pack_weight( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { const int group_id = ncb_index.ndrange_id[0]; int fh = kern_param.filter_meta.spatial[0]; int fw = kern_param.filter_meta.spatial[1]; @@ -116,19 +114,18 @@ static void pack_weight(const WorkspaceBundle& bundle, int ic = kern_param.filter_meta.icpg; int oc_block = oc; int oc_idx = 0; - const float* fptr = - kern_param.filter(group_id) + oc_idx * fh * fw * ic; + const float* fptr = kern_param.filter(group_id) + oc_idx * fh * fw * ic; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; - fp32_direct_nchw_nchw44::pack_weight_fp32_nchw_nchw44(fptr, packed_weight, - oc_block, fh, fw, ic); + fp32_direct_nchw_nchw44::pack_weight_fp32_nchw_nchw44( + fptr, packed_weight, oc_block, fh, fw, ic); } template -static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange&, const CpuNDRange&) { +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange&, + const CpuNDRange&) { const int oh = kern_param.osz[0]; const int ow = kern_param.osz[1]; const int fh = kern_param.filter_meta.spatial[0]; @@ -157,34 +154,33 @@ static void do_conv_kern(const WorkspaceBundle& bundle, const int ih_real = oh_block_real * stride_h + fh - stride_h; const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); const int src_bottom_pad = std::max( - (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, - 0); + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, 0); const int remain_right_pad = std::max(iw2 - iw - pw, 0); const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; const float* origin_sptr = static_cast(kern_param.src( batch_id, group_id, 0, 1, 1)) + src_offset; const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); - float* sptr = reinterpret_cast((int8_t*)bundle.get(0) + - ncb_index.thread_id * src_size); + float* sptr = reinterpret_cast( + 
(int8_t*)bundle.get(0) + ncb_index.thread_id * src_size); - copy_pad_src(sptr, origin_sptr, ph, pw, remain_right_pad, - ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, - src_bottom_pad, ic, ih * iw); + copy_pad_src( + sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); // pack weight auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; // get param float_t* dst = kern_param.dst(batch_id, group_id) + oh_idx * oh_block * ow * pack_c; - const float* bptr = - kern_param.bias(batch_id, group_id) + oc_idx; + const float* bptr = kern_param.bias(batch_id, group_id) + oc_idx; Op op; - fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44( - sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, - oh, oh_block_real, ow, op, ph, pw); + fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44< + bias_mode, Op, filter_size, stride>( + sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, + oh_block_real, ow, op, ph, pw); } } // namespace @@ -192,24 +188,23 @@ static void do_conv_kern(const WorkspaceBundle& bundle, bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return nchw_nchwxx_valid( - param.src_type.enumv(), param.filter_type.enumv(), - param.dst_type.enumv(), param.filter_meta, param.bias_mode, - param.nonlineMode); + param.src_type.enumv(), param.filter_type.enumv(), param.dst_type.enumv(), + param.filter_meta, param.bias_mode, param.nonlineMode); } size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44, - midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_fp32_nchw_nchw44, + midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoF32DirectNCHWNCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::AlgoF32DirectNCHWNCHW44:: + dispatch_kerns(const NCBKernSizeParam& param) const { auto fm = param.filter_meta; const int batch = param.n; const int group = fm.group; @@ -217,11 +212,12 @@ ConvBiasImpl::AlgoF32DirectNCHWNCHW44::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime -#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \ - midout_iv(#stride #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_fp32_nchw_nchw44, \ + midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); #define GET_OP_PARAM(stride, filter, bias_mode) \ @@ -279,8 +275,9 @@ ConvBiasImpl::AlgoF32DirectNCHWNCHW44::dispatch_kerns( DISPATCH_CONV_KERN(2); break; default: - megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", - param.filter_meta.stride[0]) + megdnn_throw(ssprintf( + "Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) .c_str()); break; } @@ -296,21 +293,21 @@ ConvBiasImpl::AlgoF32DirectNCHWNCHW44::dispatch_kerns( SmallVector ret_kerns; int oh = param.osz[0]; int oh_block = 
block_helper(param.nr_threads, oh, 0); - auto do_pack_weight = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_pack_weight = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); pack_weight(bundle, kern_param, ncb_index); }; ret_kerns.push_back({do_pack_weight, {static_cast(group)}}); - CpuNDRange ncb_range = {static_cast(batch), - static_cast(group), - static_cast(div_ceil(oh, oh_block))}; + CpuNDRange ncb_range = { + static_cast(batch), static_cast(group), + static_cast(div_ceil(oh, oh_block))}; auto do_conv = [bundle, do_conv_fun, ncb_range]( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h index 1c69bbad..dc26275c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h @@ -22,10 +22,9 @@ namespace megdnn { namespace arm_common { namespace fp32_direct_nchw_nchw44 { -static inline void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, - float32_t* dst_ptr, - const int oc, const int kh, - const int kw, const int ic) { +static inline void pack_weight_fp32_nchw_nchw44( + const float32_t* in_ptr, float32_t* dst_ptr, const int oc, const int kh, + const int kw, const int ic) { constexpr int oc_step = 4; const int filter_oc_stride = kh * kw * ic; const int filter_ic_stride = kh * kw * oc_step; @@ -45,12 +44,11 @@ static inline void pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, } } template -void conv_direct_fp32_nchw_nchw44(const float32_t* src, const float32_t* filter, - const float32_t* bias, float32_t*, - float32_t* dst, const int oc, const int ic, - const int ih, const int iw, const int oh, - const int oh_block, const int ow, - const Op& op, const int, const int); +void conv_direct_fp32_nchw_nchw44( + const float32_t* src, const float32_t* filter, const float32_t* bias, + float32_t*, float32_t* dst, const int oc, const int ic, const int ih, + const int iw, const int oh, const int oh_block, const int ow, const Op& op, + const int, const int); } // namespace fp32_direct_nchw_nchw44 } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/fp32/filter_transform.h b/dnn/src/arm_common/conv_bias/fp32/filter_transform.h index cce7a55e..f263afbb 100644 --- a/dnn/src/arm_common/conv_bias/fp32/filter_transform.h +++ b/dnn/src/arm_common/conv_bias/fp32/filter_transform.h @@ -10,17 +10,17 @@ */ #pragma once +#include "megdnn/opr_param_defs.h" +#include "src/arm_common/conv_bias/fp32/helper.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" -#include "src/arm_common/conv_bias/fp32/helper.h" -#include "megdnn/opr_param_defs.h" namespace megdnn { namespace arm_common { -template +template struct FilterTransform6X3 { #define FILTER_TRANSFORM(d, wd) \ do { \ @@ -40,9 +40,9 @@ struct FilterTransform6X3 { wd##7 = d##2; \ } while (0); - static void transform(const float* filter, float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, 
size_t oc_end) { + static void transform( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { // Gg * GT // G // 1.0000000 0.0000000 0.0000000 @@ -95,40 +95,38 @@ struct FilterTransform6X3 { #undef cb rep(i, alpha) rep(j, alpha) { if (format == param::MatrixMul::Format::DEFAULT) { - filter_transform_buf[(i * alpha + j) * OC * IC + - ic * OC + oc] = + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = transform_mid_buf[j * alpha + i]; } else { - filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * - 4 + - ocb * ICB * 4 * 4 + icb * 4 * 4 + - ic4 * 4 + oc4] = - transform_mid_buf[j * alpha + i]; + filter_transform_buf + [(i * alpha + j) * OCB * ICB * 4 * 4 + + ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] = + transform_mid_buf[j * alpha + i]; } } #else -#define cb(i) \ - do { \ - mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ - auto tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \ - -0.2222222f; \ - auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.2222222f; \ - mid_buf1[1] = tmp0 + tmp1; \ - mid_buf1[2] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ - GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ - tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f; \ - mid_buf1[3] = tmp0 + tmp1; \ - mid_buf1[4] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ - GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ - tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f; \ - mid_buf1[5] = tmp0 + tmp1; \ - mid_buf1[6] = tmp0 - tmp1; \ - mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ - mid_buf1 += 8; \ +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ + auto tmp0 = \ + (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * -0.2222222f; \ + auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * -0.2222222f; \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ + tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ + tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 2); \ + mid_buf1 += 8; \ } while (0); #define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) @@ -139,15 +137,13 @@ struct FilterTransform6X3 { rep(i, alpha) rep(j, alpha) { if (format == param::MatrixMul::Format::DEFAULT) { - filter_transform_buf[(i * alpha + j) * OC * IC + - ic * OC + oc] = + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = transform_mid_buf[i * alpha + j]; } else { - filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * - 4 + - ocb * ICB * 4 * 4 + icb * 4 * 4 + - ic4 * 4 + oc4] = - transform_mid_buf[i * alpha + j]; + filter_transform_buf + [(i * alpha + j) * OCB * ICB * 4 * 4 + + ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + oc4] = + transform_mid_buf[i * alpha + j]; } } #endif diff --git a/dnn/src/arm_common/conv_bias/fp32/helper.h b/dnn/src/arm_common/conv_bias/fp32/helper.h index 6d13a44c..3c2b2f70 100644 --- a/dnn/src/arm_common/conv_bias/fp32/helper.h +++ b/dnn/src/arm_common/conv_bias/fp32/helper.h @@ -10,8 +10,8 @@ */ #pragma once -#include "src/common/unroll_macro.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" namespace megdnn { namespace arm_common { @@ -55,127 +55,89 
@@ inline void transpose_4x4(const float* src, float* dst, int lda, int ldb) { #if MEGDNN_AARCH64 //! ret and a are type Vector -#define TRANSPOSE_8x8(a, ret) \ - do { \ - auto b0 = vzipq_f32(CONCAT(a, 0).value.val[0], \ - CONCAT(a, 1).value.val[0]); \ - auto b1 = vzipq_f32(CONCAT(a, 0).value.val[1], \ - CONCAT(a, 1).value.val[1]); \ - auto b2 = vzipq_f32(CONCAT(a, 2).value.val[0], \ - CONCAT(a, 3).value.val[0]); \ - auto b3 = vzipq_f32(CONCAT(a, 2).value.val[1], \ - CONCAT(a, 3).value.val[1]); \ - auto b4 = vzipq_f32(CONCAT(a, 4).value.val[0], \ - CONCAT(a, 5).value.val[0]); \ - auto b5 = vzipq_f32(CONCAT(a, 4).value.val[1], \ - CONCAT(a, 5).value.val[1]); \ - auto b6 = vzipq_f32(CONCAT(a, 6).value.val[0], \ - CONCAT(a, 7).value.val[0]); \ - auto b7 = vzipq_f32(CONCAT(a, 6).value.val[1], \ - CONCAT(a, 7).value.val[1]); \ - CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ - vreinterpretq_s64_f32(b2.val[0]))); \ - CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b4.val[0]), \ - vreinterpretq_s64_f32(b6.val[0]))); \ - CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ - vreinterpretq_s64_f32(b2.val[0]))); \ - CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b4.val[0]), \ - vreinterpretq_s64_f32(b6.val[0]))); \ - CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ - vreinterpretq_s64_f32(b2.val[1]))); \ - CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b4.val[1]), \ - vreinterpretq_s64_f32(b6.val[1]))); \ - CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \ - vreinterpretq_s64_f32(b2.val[1]))); \ - CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b4.val[1]), \ - vreinterpretq_s64_f32(b6.val[1]))); \ - CONCAT(ret, 4).value.val[0] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b1.val[0]), \ - vreinterpretq_s64_f32(b3.val[0]))); \ - CONCAT(ret, 4).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b5.val[0]), \ - vreinterpretq_s64_f32(b7.val[0]))); \ - CONCAT(ret, 5).value.val[0] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b1.val[0]), \ - vreinterpretq_s64_f32(b3.val[0]))); \ - CONCAT(ret, 5).value.val[1] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b5.val[0]), \ - vreinterpretq_s64_f32(b7.val[0]))); \ - CONCAT(ret, 6).value.val[0] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b1.val[1]), \ - vreinterpretq_s64_f32(b3.val[1]))); \ - CONCAT(ret, 6).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b5.val[1]), \ - vreinterpretq_s64_f32(b7.val[1]))); \ - CONCAT(ret, 7).value.val[0] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b1.val[1]), \ - vreinterpretq_s64_f32(b3.val[1]))); \ - CONCAT(ret, 7).value.val[1] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b5.val[1]), \ - vreinterpretq_s64_f32(b7.val[1]))); \ +#define TRANSPOSE_8x8(a, ret) \ + do { \ + auto b0 = vzipq_f32(CONCAT(a, 0).value.val[0], CONCAT(a, 1).value.val[0]); \ + auto b1 = vzipq_f32(CONCAT(a, 0).value.val[1], CONCAT(a, 1).value.val[1]); \ + auto b2 = vzipq_f32(CONCAT(a, 2).value.val[0], CONCAT(a, 3).value.val[0]); \ + auto b3 = vzipq_f32(CONCAT(a, 2).value.val[1], CONCAT(a, 3).value.val[1]); \ + auto b4 = 
vzipq_f32(CONCAT(a, 4).value.val[0], CONCAT(a, 5).value.val[0]); \ + auto b5 = vzipq_f32(CONCAT(a, 4).value.val[1], CONCAT(a, 5).value.val[1]); \ + auto b6 = vzipq_f32(CONCAT(a, 6).value.val[0], CONCAT(a, 7).value.val[0]); \ + auto b7 = vzipq_f32(CONCAT(a, 6).value.val[1], CONCAT(a, 7).value.val[1]); \ + CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b0.val[0]), vreinterpretq_s64_f32(b2.val[0]))); \ + CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b4.val[0]), vreinterpretq_s64_f32(b6.val[0]))); \ + CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b0.val[0]), vreinterpretq_s64_f32(b2.val[0]))); \ + CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b4.val[0]), vreinterpretq_s64_f32(b6.val[0]))); \ + CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b0.val[1]), vreinterpretq_s64_f32(b2.val[1]))); \ + CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b4.val[1]), vreinterpretq_s64_f32(b6.val[1]))); \ + CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b0.val[1]), vreinterpretq_s64_f32(b2.val[1]))); \ + CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b4.val[1]), vreinterpretq_s64_f32(b6.val[1]))); \ + CONCAT(ret, 4).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b1.val[0]), vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 4).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b5.val[0]), vreinterpretq_s64_f32(b7.val[0]))); \ + CONCAT(ret, 5).value.val[0] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b1.val[0]), vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 5).value.val[1] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b5.val[0]), vreinterpretq_s64_f32(b7.val[0]))); \ + CONCAT(ret, 6).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b1.val[1]), vreinterpretq_s64_f32(b3.val[1]))); \ + CONCAT(ret, 6).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b5.val[1]), vreinterpretq_s64_f32(b7.val[1]))); \ + CONCAT(ret, 7).value.val[0] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b1.val[1]), vreinterpretq_s64_f32(b3.val[1]))); \ + CONCAT(ret, 7).value.val[1] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b5.val[1]), vreinterpretq_s64_f32(b7.val[1]))); \ } while (0); -#define TRANSPOSE_8x3(a, ret) \ - auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ - auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ - auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ - auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ - CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ - vreinterpretq_s64_f32(b1.val[0]))); \ - CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \ - vreinterpretq_s64_f32(b3.val[0]))); \ - CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ - vreinterpretq_s64_f32(b1.val[0]))); \ - CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \ - vreinterpretq_s64_f32(b3.val[0]))); \ - CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ - 
vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ - vreinterpretq_s64_f32(b1.val[1]))); \ - CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \ - vreinterpretq_s64_f32(b3.val[1]))); +#define TRANSPOSE_8x3(a, ret) \ + auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ + auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b0.val[0]), vreinterpretq_s64_f32(b1.val[0]))); \ + CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b2.val[0]), vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b0.val[0]), vreinterpretq_s64_f32(b1.val[0]))); \ + CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b2.val[0]), vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b0.val[1]), vreinterpretq_s64_f32(b1.val[1]))); \ + CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b2.val[1]), vreinterpretq_s64_f32(b3.val[1]))); -#define TRANSPOSE_8x4(a, ret) \ - auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ - auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ - auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ - auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ - CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b0.val[0]), \ - vreinterpretq_s64_f32(b1.val[0]))); \ - CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b2.val[0]), \ - vreinterpretq_s64_f32(b3.val[0]))); \ - CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b0.val[0]), \ - vreinterpretq_s64_f32(b1.val[0]))); \ - CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b2.val[0]), \ - vreinterpretq_s64_f32(b3.val[0]))); \ - CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b0.val[1]), \ - vreinterpretq_s64_f32(b1.val[1]))); \ - CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64( \ - vzip1q_s64(vreinterpretq_s64_f32(b2.val[1]), \ - vreinterpretq_s64_f32(b3.val[1]))); \ - CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b0.val[1]), \ - vreinterpretq_s64_f32(b1.val[1]))); \ - CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64( \ - vzip2q_s64(vreinterpretq_s64_f32(b2.val[1]), \ - vreinterpretq_s64_f32(b3.val[1]))); +#define TRANSPOSE_8x4(a, ret) \ + auto b0 = vzipq_f32(CONCAT(a, 0).value, CONCAT(a, 1).value); \ + auto b1 = vzipq_f32(CONCAT(a, 2).value, CONCAT(a, 3).value); \ + auto b2 = vzipq_f32(CONCAT(a, 4).value, CONCAT(a, 5).value); \ + auto b3 = vzipq_f32(CONCAT(a, 6).value, CONCAT(a, 7).value); \ + CONCAT(ret, 0).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b0.val[0]), vreinterpretq_s64_f32(b1.val[0]))); \ + CONCAT(ret, 0).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b2.val[0]), vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 1).value.val[0] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b0.val[0]), vreinterpretq_s64_f32(b1.val[0]))); \ + 
CONCAT(ret, 1).value.val[1] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b2.val[0]), vreinterpretq_s64_f32(b3.val[0]))); \ + CONCAT(ret, 2).value.val[0] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b0.val[1]), vreinterpretq_s64_f32(b1.val[1]))); \ + CONCAT(ret, 2).value.val[1] = vreinterpretq_f32_s64(vzip1q_s64( \ + vreinterpretq_s64_f32(b2.val[1]), vreinterpretq_s64_f32(b3.val[1]))); \ + CONCAT(ret, 3).value.val[0] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b0.val[1]), vreinterpretq_s64_f32(b1.val[1]))); \ + CONCAT(ret, 3).value.val[1] = vreinterpretq_f32_s64(vzip2q_s64( \ + vreinterpretq_s64_f32(b2.val[1]), vreinterpretq_s64_f32(b3.val[1]))); #elif MEGDNN_ARMV7 #define TRANSPOSE_8x4(a, ret) \ diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy.h b/dnn/src/arm_common/conv_bias/fp32/strategy.h index 5a92f2de..e942074b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy.h +++ b/dnn/src/arm_common/conv_bias/fp32/strategy.h @@ -18,29 +18,24 @@ namespace megdnn { namespace arm_common { namespace winograd { -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, - winograd_2x3_4x4_f) +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, winograd_2x3_4x4_f) -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, - winograd_6x3_1x1_f) +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 1, 1, winograd_6x3_1x1_f) -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, - winograd_6x3_4x4_f) +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, winograd_6x3_4x4_f) -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, - winograd_5x4_1x1_f) +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, winograd_5x4_1x1_f) -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1, - winograd_4x5_1x1_f) +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1, winograd_4x5_1x1_f) -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, - winograd_F23_mk4_f_nchw44) +MEGDNN_REG_WINOGRAD_STRATEGY( + float, float, float, float, 2, 3, 4, 4, winograd_F23_mk4_f_nchw44) -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, - winograd_F63_mk4_f_nchw44) +MEGDNN_REG_WINOGRAD_STRATEGY( + float, float, float, float, 6, 3, 4, 4, winograd_F63_mk4_f_nchw44) -MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 7, 3, 4, 4, - winograd_F73_mk4_f_nchw44) +MEGDNN_REG_WINOGRAD_STRATEGY( + float, float, float, float, 7, 3, 4, 4, winograd_F73_mk4_f_nchw44) } // namespace winograd } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp index 3e0fb6f1..4b211e30 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_2x3_4x4.cpp @@ -16,9 +16,9 @@ #include "src/common/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" -#include "src/naive/matrix_mul/matrix_mul_helper.h" -#include "src/arm_common/elemwise_helper/op_unary.h" #include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F23) @@ -29,16 +29,15 @@ namespace { struct InputTransform2X3 { template - static void prepare(const float* input, float* patch, float* patchT, - int ih_start, int iw_start, size_t IH, 
size_t IW, - size_t ic, size_t IC) { + static void prepare( + const float* input, float* patch, float* patchT, int ih_start, int iw_start, + size_t IH, size_t IW, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; if (!(inner && ic + 4 < IC)) { memset(patch, 0, sizeof(float) * 4 * alpha * alpha); } if (inner) { - const float* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; for (size_t ico = 0; ico < 4; ++ico) { if (ic + ico < IC) { auto v0 = vld1q_f32(input_ptr); @@ -78,14 +77,13 @@ struct InputTransform2X3 { transpose_4x4(patch + 12 * 1, patchT + 12 * 4, 16, 4); } - static void transform(const float* patchT, float* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const float* patchT, float* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; // BT * d * B #define cb(m, n) \ - Vector d##m##n = \ - Vector::load(patchT + m * 4 * 4 + n * 4); + Vector d##m##n = Vector::load(patchT + m * 4 * 4 + n * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -103,9 +101,9 @@ struct InputTransform2X3 { UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(m) \ - d##m##0 = t##m##0 - t##m##2; \ - d##m##1 = t##m##1 + t##m##2; \ +#define cb(m) \ + d##m##0 = t##m##0 - t##m##2; \ + d##m##1 = t##m##1 + t##m##2; \ d##m##2 = t##m##2 - t##m##1; \ d##m##3 = t##m##3 - t##m##1; @@ -114,10 +112,10 @@ struct InputTransform2X3 { size_t ICB = IC / 4; size_t icb = ic / 4; -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ - icb * nr_units_in_tile * 4 + unit_idx * 4); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ + icb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -125,13 +123,11 @@ struct InputTransform2X3 { template struct OutputTransform2X3 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); //! 
AT * m * A constexpr size_t alpha = 2 + 3 - 1; @@ -139,11 +135,10 @@ struct OutputTransform2X3 { size_t oc = oc_start + oc_index; size_t OCB = (oc_end - oc_start) / 4; size_t ocb = oc_index / 4; - -#define cb(m, n) \ - auto v##m##n = Vector::load( \ - output_transform_buf + \ - (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ + +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ ocb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -206,10 +201,9 @@ namespace arm_common { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_f) -void winograd_2x3_4x4_f::filter(const float* filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { +void winograd_2x3_4x4_f::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { constexpr int alpha = 2 + 3 - 1; //! G * g * GT float32x4_t g0{1.f, 0, 0, 0}, g1{0.5, 0.5, 0.5, 0}, g2{0.5, -0.5, 0.5, 0}, @@ -264,17 +258,18 @@ void winograd_2x3_4x4_f::filter(const float* filter, size_t icb = ic / 4; size_t ic4 = ic % 4; rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OCB * ICB * 4 * 4 + - ocb * ICB * 4 * 4 + icb * 4 * 4 + ic4 * 4 + - oc4] = transform_mid_buf[i * alpha + j]; + filter_transform_buf + [(i * alpha + j) * OCB * ICB * 4 * 4 + ocb * ICB * 4 * 4 + + icb * 4 * 4 + ic4 * 4 + oc4] = + transform_mid_buf[i * alpha + j]; } } } -void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, size_t IH, size_t IW, - size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_2x3_4x4_f::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { megdnn_assert(IC % 4 == 0); constexpr int alpha = 3 + 2 - 1; @@ -292,30 +287,28 @@ void winograd_2x3_4x4_f::input(const float* input, float* input_transform_buf, int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransform2X3::prepare(input, patch, patchT, ih_start, - iw_start, IH, IW, ic, IC); - InputTransform2X3::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransform2X3::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } else { - InputTransform2X3::prepare(input, patch, patchT, - ih_start, iw_start, IH, IW, - ic, IC); - InputTransform2X3::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransform2X3::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } } } } -void winograd_2x3_4x4_f::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_2x3_4x4_f::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, 
BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp index 8c977c36..12093fd0 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_4x5.cpp @@ -83,9 +83,9 @@ struct FilterTransform4X5 { wd##6 = tmp0 - tmp1; \ wd##7 = d##4; \ } while (0); - static void transform(const float* filter, float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { + static void transform( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { // Gg * GT // G //[[ 1. 0. 0. 0. 0. ] @@ -139,8 +139,8 @@ struct FilterTransform4X5 { UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + - oc] = transform_mid_buf[j * alpha + i]; + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[j * alpha + i]; } } } @@ -167,25 +167,23 @@ struct InputTransform4X5 { wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ } while (0) -#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ - vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) -#define GET_VECTOR_LOW_ELEM(s, i, idx) \ - vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) +#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) +#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) template - static void transform(const float* input, float* input_transform_buf, - float* transform_mid_buf, int ih_start, int iw_start, - size_t ic, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { - // BTd * B - //([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], - // [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ], - // [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ], - // [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ], - // [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ], - // [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ], - // [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ], - // [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. ]])) + static void transform( + const float* input, float* input_transform_buf, float* transform_mid_buf, + int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { + // BTd * B + //([[ 1. , 0. , -5.25, 0. , 5.25, 0. , -1. , 0. ], + // [ 0. , 1. , 1. , -4.25, -4.25, 1. , 1. , 0. ], + // [ 0. , -1. , 1. , 4.25, -4.25, -1. , 1. , 0. ], + // [ 0. , 2. , 4. , -2.5 , -5. , 0.5 , 1. , 0. ], + // [ 0. , -2. , 4. , 2.5 , -5. , -0.5 , 1. , 0. ], + // [ 0. , 0.5 , 0.25, -2.5 , -1.25, 2. , 1. , 0. ], + // [ 0. , -0.5 , 0.25, 2.5 , -1.25, -2. , 1. , 0. ], + // [ 0. , -1. , 0. , 5.25, 0. , -5.25, 0. , 1. 
]])) constexpr size_t alpha = 4 + 5 - 1; if (!inner) { @@ -197,8 +195,7 @@ struct InputTransform4X5 { #undef cb if (inner) { - const float* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; #define cb(i) d##i = Vector::load(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -232,46 +229,42 @@ struct InputTransform4X5 { UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[j * alpha + i]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; } #else -#define cb(i) \ - do { \ - mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ - 5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ - GET_VECTOR_LOW_ELEM(wd, i, 2)); \ - mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ - GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - 5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ - auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ - -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ - GET_VECTOR_HIGH_ELEM(wd, i, 2); \ - auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - -2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ - 0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ - mid_buf1[3] = tmp0 + tmp1; \ - mid_buf1[4] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ - -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ - GET_VECTOR_HIGH_ELEM(wd, i, 2); \ - tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 1); \ - mid_buf1[1] = tmp0 + tmp1; \ - mid_buf1[2] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 2); \ - tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ - GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ - mid_buf1[5] = tmp0 + tmp1; \ - mid_buf1[6] = tmp0 - tmp1; \ - mid_buf1 += 8; \ +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ + GET_VECTOR_LOW_ELEM(wd, i, 2)); \ + mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + 5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ + auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + -2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ + 0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1 += 8; \ } while (0); float* mid_buf1 = transform_mid_buf; @@ -280,9 +273,9 @@ struct InputTransform4X5 { #undef cb rep(i, alpha) rep(j, 
alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[i * alpha + j]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; } #endif } @@ -298,13 +291,11 @@ struct InputTransform4X5 { } while (0) template struct OutputTransform4X5 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); //! AT * m * A // AT f45 @@ -318,10 +309,9 @@ struct OutputTransform4X5 { size_t OC = oc_end - oc_start; size_t oc = oc_start + oc_index; -#define cb(m, n) \ - transform_mid_buf[m * alpha + n] = \ - output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ - unit_idx * OC + oc_index]; +#define cb(m, n) \ + transform_mid_buf[m * alpha + n] = output_transform_buf \ + [(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index]; UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb @@ -333,26 +323,20 @@ struct OutputTransform4X5 { #undef cb OUTPUT_TRANSFORM(m, s); -#define cb(i) \ - do { \ - auto add12 = \ - GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ - auto add34 = \ - GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ - auto add56 = \ - GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ - auto sub12 = \ - GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ - auto sub34 = \ - GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ - auto sub56 = \ - GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ - mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + add12 + add34 + add56; \ - mid_buf1[1] = sub12 + sub34 * 0.5 + sub56 * 2.0; \ - mid_buf1[2] = add12 + add34 * 0.25 + add56 * 4.0; \ - mid_buf1[3] = sub12 + sub34 * 0.125 + sub56 * 8.0 + \ - GET_VECTOR_HIGH_ELEM(s, i, 3); \ - mid_buf1 += 4; \ +#define cb(i) \ + do { \ + auto add12 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto add34 = GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto add56 = GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ + auto sub12 = GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto sub34 = GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto sub56 = GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + add12 + add34 + add56; \ + mid_buf1[1] = sub12 + sub34 * 0.5 + sub56 * 2.0; \ + mid_buf1[2] = add12 + add34 * 0.25 + add56 * 4.0; \ + mid_buf1[3] = \ + sub12 + sub34 * 0.125 + sub56 * 8.0 + GET_VECTOR_HIGH_ELEM(s, i, 3); \ + mid_buf1 += 4; \ } while (0); mid_buf1 = transform_mid_buf; @@ -409,18 +393,17 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_4x5_1x1_f) -void winograd_4x5_1x1_f::filter(const float* filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { - 
FilterTransform4X5::transform(filter, filter_transform_buf, - transform_mid_buf, OC, IC, oc_start, oc_end); +void winograd_4x5_1x1_f::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform4X5::transform( + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end); } -void winograd_4x5_1x1_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, size_t IH, size_t IW, - size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_4x5_1x1_f::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { constexpr int alpha = 4 + 5 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 @@ -447,12 +430,11 @@ void winograd_4x5_1x1_f::input(const float* input, float* input_transform_buf, } } -void winograd_4x5_1x1_f::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_4x5_1x1_f::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform4X5<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp index b502602e..dac78b2e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_5x4.cpp @@ -46,9 +46,9 @@ struct FilterTransform5X4 { wd##7 = d##3; \ } while (0) - static void transform(const float* filter, float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { + static void transform( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { // Gg * GT // G // 1 0 0 0 @@ -90,34 +90,32 @@ struct FilterTransform5X4 { UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + - oc] = transform_mid_buf[j * alpha + i]; + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[j * alpha + i]; } #else -#define cb(i) \ - do { \ - mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ - auto tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ - GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ - auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f + \ - GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ - mid_buf1[1] = tmp0 + tmp1; \ - mid_buf1[2] = tmp0 - tmp1; \ - tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * \ - -0.2222222f; \ - tmp1 = (GET_VECTOR_ELEM(wd, i, 1) + GET_VECTOR_ELEM(wd, i, 3)) * \ - -0.2222222f; \ - mid_buf1[3] = tmp0 + tmp1; \ - mid_buf1[4] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ - GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ - tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f + \ - GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ - mid_buf1[5] = tmp0 + tmp1; \ - mid_buf1[6] = tmp0 - tmp1; \ - mid_buf1[7] = 
GET_VECTOR_ELEM(wd, i, 3); \ - mid_buf1 += 8; \ +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_ELEM(wd, i, 0); \ + auto tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.7111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.1777778f; \ + auto tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.3555556f + \ + GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = (GET_VECTOR_ELEM(wd, i, 0) + GET_VECTOR_ELEM(wd, i, 2)) * -0.2222222f; \ + tmp1 = (GET_VECTOR_ELEM(wd, i, 1) + GET_VECTOR_ELEM(wd, i, 3)) * -0.2222222f; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_ELEM(wd, i, 0) * 0.0111111f + \ + GET_VECTOR_ELEM(wd, i, 2) * 0.0444444f; \ + tmp1 = GET_VECTOR_ELEM(wd, i, 1) * 0.0222222f + \ + GET_VECTOR_ELEM(wd, i, 3) * 0.0888889f; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1[7] = GET_VECTOR_ELEM(wd, i, 3); \ + mid_buf1 += 8; \ } while (0); #define GET_VECTOR_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value, idx) @@ -126,8 +124,8 @@ struct FilterTransform5X4 { mid_buf1 = transform_mid_buf; #undef cb rep(i, alpha) rep(j, alpha) { - filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + - oc] = transform_mid_buf[i * alpha + j]; + filter_transform_buf[(i * alpha + j) * OC * IC + ic * OC + oc] = + transform_mid_buf[i * alpha + j]; } #endif } @@ -156,16 +154,14 @@ struct InputTransform5X4 { wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ } while (0) -#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ - vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) -#define GET_VECTOR_LOW_ELEM(s, i, idx) \ - vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) +#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) +#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) template - static void transform(const float* input, float* input_transform_buf, - float* transform_mid_buf, int ih_start, int iw_start, - size_t ic, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { + static void transform( + const float* input, float* input_transform_buf, float* transform_mid_buf, + int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { // BTd * B // BT // 1 0 -5.25 0 5.25 0 -1 0 @@ -187,8 +183,7 @@ struct InputTransform5X4 { #undef cb if (inner) { - const float* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; #define cb(i) d##i = Vector::load(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -222,46 +217,42 @@ struct InputTransform5X4 { UNROLL_CALL_NOWRAPPER(8, cb); #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[j * alpha + i]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; } #else -#define cb(i) \ - do { \ - mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ - 5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ - GET_VECTOR_LOW_ELEM(wd, i, 2)); \ - mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ - GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - 5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ - auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ - -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ - GET_VECTOR_HIGH_ELEM(wd, i, 2); \ - auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - 
-2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ - 0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ - mid_buf1[1] = tmp0 + tmp1; \ - mid_buf1[2] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ - -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ - GET_VECTOR_HIGH_ELEM(wd, i, 2); \ - tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 1); \ - mid_buf1[3] = tmp0 + tmp1; \ - mid_buf1[4] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 2); \ - tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ - GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ - GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ - mid_buf1[5] = tmp0 + tmp1; \ - mid_buf1[6] = tmp0 - tmp1; \ - mid_buf1 += 8; \ +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 5.25 * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ + GET_VECTOR_LOW_ELEM(wd, i, 2)); \ + mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + 5.25 * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ + auto tmp0 = 4 * GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -5 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + auto tmp1 = 2 * GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + -2.5 * GET_VECTOR_LOW_ELEM(wd, i, 3) + \ + 0.5 * GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ + -4.25 * GET_VECTOR_HIGH_ELEM(wd, i, 0) + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + GET_VECTOR_LOW_ELEM(wd, i, 3) * -4.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1); \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) * 0.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * -1.25 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 2); \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5 + \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * -2.5 + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1 += 8; \ } while (0); float* mid_buf1 = transform_mid_buf; @@ -270,9 +261,9 @@ struct InputTransform5X4 { #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[i * alpha + j]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; } #endif } @@ -301,13 +292,11 @@ struct InputTransform5X4 { template struct OutputTransform5X4 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); //! 
AT * m * A // AT @@ -322,10 +311,9 @@ struct OutputTransform5X4 { size_t OC = oc_end - oc_start; size_t oc = oc_start + oc_index; -#define cb(m, n) \ - transform_mid_buf[m * alpha + n] = \ - output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ - unit_idx * OC + oc_index]; +#define cb(m, n) \ + transform_mid_buf[m * alpha + n] = output_transform_buf \ + [(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index]; UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb @@ -337,28 +325,21 @@ struct OutputTransform5X4 { #undef cb OUTPUT_TRANSFORM(m, s); -#define cb(i) \ - do { \ - auto m1addm2 = \ - GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ - auto m1subm2 = \ - GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ - auto m3addm4 = \ - GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ - auto m3subm4 = \ - GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ - auto m5addm6 = \ - GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ - auto m5subm6 = \ - GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ - mid_buf1[0] = \ - GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ - mid_buf1[1] = 0.5f * m1subm2 + m3subm4 + 2.0f * m5subm6; \ - mid_buf1[2] = 0.25f * m1addm2 + m3addm4 + 4.0f * m5addm6; \ - mid_buf1[3] = 0.125f * m1subm2 + m3subm4 + 8.0f * m5subm6; \ - mid_buf1[4] = 0.0625f * m1addm2 + m3addm4 + 16.0f * m5addm6 + \ - GET_VECTOR_HIGH_ELEM(s, i, 3); \ - mid_buf1 += 5; \ +#define cb(i) \ + do { \ + auto m1addm2 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m1subm2 = GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m3addm4 = GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m3subm4 = GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m5addm6 = GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ + auto m5subm6 = GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ + mid_buf1[1] = 0.5f * m1subm2 + m3subm4 + 2.0f * m5subm6; \ + mid_buf1[2] = 0.25f * m1addm2 + m3addm4 + 4.0f * m5addm6; \ + mid_buf1[3] = 0.125f * m1subm2 + m3subm4 + 8.0f * m5subm6; \ + mid_buf1[4] = 0.0625f * m1addm2 + m3addm4 + 16.0f * m5addm6 + \ + GET_VECTOR_HIGH_ELEM(s, i, 3); \ + mid_buf1 += 5; \ } while (0); mid_buf1 = transform_mid_buf; @@ -424,18 +405,17 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_5x4_1x1_f) -void winograd_5x4_1x1_f::filter(const float* filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { - FilterTransform5X4::transform(filter, filter_transform_buf, - transform_mid_buf, OC, IC, oc_start, oc_end); +void winograd_5x4_1x1_f::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform5X4::transform( + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end); } -void winograd_5x4_1x1_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, size_t IH, size_t IW, - size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_5x4_1x1_f::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { constexpr int alpha = 5 + 4 
- 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 @@ -463,13 +443,11 @@ void winograd_5x4_1x1_f::input(const float* input, float* input_transform_buf, } } -void winograd_5x4_1x1_f::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_5x4_1x1_f::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform5X4<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp index 15489b35..f0d958d5 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3.cpp @@ -57,16 +57,14 @@ namespace { wd##7 = (d##7 - d##1) + (d##3 - d##5) * 5.25f; \ } while (0); -#define GET_VECTOR_HIGH_ELEM(s, i, idx) \ - vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) -#define GET_VECTOR_LOW_ELEM(s, i, idx) \ - vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) +#define GET_VECTOR_HIGH_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[1], idx) +#define GET_VECTOR_LOW_ELEM(s, i, idx) vgetq_lane_f32(CONCAT(s, i).value.val[0], idx) struct InputTransform6X3 { template - static void transform(const float* input, float* input_transform_buf, - float* transform_mid_buf, int ih_start, int iw_start, - size_t ic, size_t IH, size_t IW, size_t IC, - size_t unit_idx, size_t nr_units_in_tile) { + static void transform( + const float* input, float* input_transform_buf, float* transform_mid_buf, + int ih_start, int iw_start, size_t ic, size_t IH, size_t IW, size_t IC, + size_t unit_idx, size_t nr_units_in_tile) { constexpr size_t alpha = 6 + 3 - 1; if (!inner) { memset(transform_mid_buf, 0, sizeof(float) * alpha * alpha); @@ -77,8 +75,7 @@ struct InputTransform6X3 { #undef cb if (inner) { - const float* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; #define cb(i) d##i = Vector::load(input_ptr + IW * i); UNROLL_CALL_NOWRAPPER(8, cb); #undef cb @@ -114,9 +111,9 @@ struct InputTransform6X3 { #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[j * alpha + i]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[j * alpha + i]; } #else //! 1 0 0 0 0 0 0 0 @@ -127,42 +124,38 @@ struct InputTransform6X3 { //! 0 1 -1 2 -2 0.5 -0.5 -5.25 //! -1 1 1 1 1 1 1 0 //! 
0 0 0 0 0 0 0 1 -#define cb(i) \ - do { \ - mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ - 5.25f * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ - GET_VECTOR_LOW_ELEM(wd, i, 2)); \ - mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - \ - GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - 5.25f * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ - auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + \ - GET_VECTOR_HIGH_ELEM(wd, i, 2) - \ - 4.25f * GET_VECTOR_HIGH_ELEM(wd, i, 0); \ - auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + \ - GET_VECTOR_HIGH_ELEM(wd, i, 1) - \ - 4.25f * GET_VECTOR_LOW_ELEM(wd, i, 3); \ - mid_buf1[1] = tmp0 + tmp1; \ - mid_buf1[2] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ - 0.25f * GET_VECTOR_LOW_ELEM(wd, i, 2) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f; \ - tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5f - \ - GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ - GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2.f; \ - mid_buf1[3] = tmp0 + tmp1; \ - mid_buf1[4] = tmp0 - tmp1; \ - tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ - (GET_VECTOR_LOW_ELEM(wd, i, 2) - \ - GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f) * \ - 4; \ - tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 2.f - \ - GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ - GET_VECTOR_HIGH_ELEM(wd, i, 1) * 0.5f; \ - mid_buf1[5] = tmp0 + tmp1; \ - mid_buf1[6] = tmp0 - tmp1; \ - mid_buf1 += 8; \ +#define cb(i) \ + do { \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(wd, i, 0) - GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 5.25f * (GET_VECTOR_HIGH_ELEM(wd, i, 0) - \ + GET_VECTOR_LOW_ELEM(wd, i, 2)); \ + mid_buf1[7] = GET_VECTOR_HIGH_ELEM(wd, i, 3) - GET_VECTOR_LOW_ELEM(wd, i, 1) + \ + 5.25f * (GET_VECTOR_LOW_ELEM(wd, i, 3) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 1)); \ + auto tmp0 = GET_VECTOR_LOW_ELEM(wd, i, 2) + GET_VECTOR_HIGH_ELEM(wd, i, 2) - \ + 4.25f * GET_VECTOR_HIGH_ELEM(wd, i, 0); \ + auto tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) + GET_VECTOR_HIGH_ELEM(wd, i, 1) - \ + 4.25f * GET_VECTOR_LOW_ELEM(wd, i, 3); \ + mid_buf1[1] = tmp0 + tmp1; \ + mid_buf1[2] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + 0.25f * GET_VECTOR_LOW_ELEM(wd, i, 2) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f; \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 0.5f - \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 2.f; \ + mid_buf1[3] = tmp0 + tmp1; \ + mid_buf1[4] = tmp0 - tmp1; \ + tmp0 = GET_VECTOR_HIGH_ELEM(wd, i, 2) + \ + (GET_VECTOR_LOW_ELEM(wd, i, 2) - \ + GET_VECTOR_HIGH_ELEM(wd, i, 0) * 1.25f) * \ + 4; \ + tmp1 = GET_VECTOR_LOW_ELEM(wd, i, 1) * 2.f - \ + GET_VECTOR_LOW_ELEM(wd, i, 3) * 2.5f + \ + GET_VECTOR_HIGH_ELEM(wd, i, 1) * 0.5f; \ + mid_buf1[5] = tmp0 + tmp1; \ + mid_buf1[6] = tmp0 - tmp1; \ + mid_buf1 += 8; \ } while (0); float* mid_buf1 = transform_mid_buf; @@ -171,9 +164,9 @@ struct InputTransform6X3 { #undef cb rep(i, alpha) rep(j, alpha) { - input_transform_buf[(i * alpha + j) * nr_units_in_tile * IC + - unit_idx * IC + ic] = - transform_mid_buf[i * alpha + j]; + input_transform_buf + [(i * alpha + j) * nr_units_in_tile * IC + unit_idx * IC + ic] = + transform_mid_buf[i * alpha + j]; } #endif } @@ -215,13 +208,11 @@ struct InputTransform6X3 { template struct OutputTransform6X3 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void 
transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { constexpr size_t alpha = 6 + 3 - 1; Op op(src_dtype, dst_dtype); float* mid_buf1 = transform_mid_buf; @@ -230,10 +221,9 @@ struct OutputTransform6X3 { size_t OC = oc_end - oc_start; size_t oc = oc_start + oc_index; -#define cb(m, n) \ - transform_mid_buf[m * alpha + n] = \ - output_transform_buf[(m * alpha + n) * nr_units_in_tile * OC + \ - unit_idx * OC + oc_index]; +#define cb(m, n) \ + transform_mid_buf[m * alpha + n] = output_transform_buf \ + [(m * alpha + n) * nr_units_in_tile * OC + unit_idx * OC + oc_index]; UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb @@ -257,29 +247,22 @@ struct OutputTransform6X3 { * 1 -0.5 0.25 -0.125 0.0625 -0.03125 * 0 0.0 0 0 0 1 */ -#define cb(i) \ - do { \ - auto m1addm2 = \ - GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ - auto m1subm2 = \ - GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ - auto m3addm4 = \ - GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ - auto m3subm4 = \ - GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ - auto m5addm6 = \ - GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ - auto m5subm6 = \ - GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ - mid_buf1[0] = \ - GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ - mid_buf1[1] = m1subm2 + 2.f * m3subm4 + 0.5f * m5subm6; \ - mid_buf1[2] = m1addm2 + 4.f * m3addm4 + 0.25f * m5addm6; \ - mid_buf1[3] = m1subm2 + 8.f * m3subm4 + 0.125f * m5subm6; \ - mid_buf1[4] = m1addm2 + 16.f * m3addm4 + 0.0625f * m5addm6; \ - mid_buf1[5] = m1subm2 + 32.f * m3subm4 + 0.03125f * m5subm6 + \ - GET_VECTOR_HIGH_ELEM(s, i, 3); \ - mid_buf1 += 6; \ +#define cb(i) \ + do { \ + auto m1addm2 = GET_VECTOR_LOW_ELEM(s, i, 1) + GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m1subm2 = GET_VECTOR_LOW_ELEM(s, i, 1) - GET_VECTOR_LOW_ELEM(s, i, 2); \ + auto m3addm4 = GET_VECTOR_LOW_ELEM(s, i, 3) + GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m3subm4 = GET_VECTOR_LOW_ELEM(s, i, 3) - GET_VECTOR_HIGH_ELEM(s, i, 0); \ + auto m5addm6 = GET_VECTOR_HIGH_ELEM(s, i, 1) + GET_VECTOR_HIGH_ELEM(s, i, 2); \ + auto m5subm6 = GET_VECTOR_HIGH_ELEM(s, i, 1) - GET_VECTOR_HIGH_ELEM(s, i, 2); \ + mid_buf1[0] = GET_VECTOR_LOW_ELEM(s, i, 0) + m1addm2 + m3addm4 + m5addm6; \ + mid_buf1[1] = m1subm2 + 2.f * m3subm4 + 0.5f * m5subm6; \ + mid_buf1[2] = m1addm2 + 4.f * m3addm4 + 0.25f * m5addm6; \ + mid_buf1[3] = m1subm2 + 8.f * m3subm4 + 0.125f * m5subm6; \ + mid_buf1[4] = m1addm2 + 16.f * m3addm4 + 0.0625f * m5addm6; \ + mid_buf1[5] = m1subm2 + 32.f * m3subm4 + 0.03125f * m5subm6 + \ + GET_VECTOR_HIGH_ELEM(s, i, 3); \ + mid_buf1 += 6; \ } while (0); mid_buf1 = transform_mid_buf; @@ -304,8 +287,7 @@ struct OutputTransform6X3 { item1 = vadd_f32(item1, bias1); } else if (bmode == BiasMode::BIAS) { bias0 = vld1q_f32(bias + oc * OH * OW + oh * OW + ow_start); - bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + - 4); + bias1 = vld1_f32(bias + oc * OH * OW + oh * OW + ow_start + 4); item0 = vaddq_f32(item0, bias0); item1 = vadd_f32(item1, bias1); } @@ -348,19 +330,17 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_1x1_f) -void winograd_6x3_1x1_f::filter(const float* filter, - float* filter_transform_buf, - float* 
transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { +void winograd_6x3_1x1_f::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { FilterTransform6X3::transform( - filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, - oc_end); + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end); } -void winograd_6x3_1x1_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, size_t IH, size_t IW, - size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_6x3_1x1_f::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { constexpr int alpha = 3 + 6 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 @@ -387,13 +367,11 @@ void winograd_6x3_1x1_f::input(const float* input, float* input_transform_buf, } } -void winograd_6x3_1x1_f::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_6x3_1x1_f::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp index 6d1011c6..15bba172 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_6x3_4x4.cpp @@ -30,16 +30,15 @@ namespace { struct InputTransform6X3 { template - static void prepare(const float* input, float* patch, float* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC) { + static void prepare( + const float* input, float* patch, float* patchT, int ih_start, int iw_start, + size_t IH, size_t IW, size_t ic, size_t IC) { constexpr size_t alpha = 6 + 3 - 1; if (!(inner && ic + 4 < IC)) { memset(patch, 0, sizeof(float) * 4 * alpha * alpha); } if (inner) { - const float* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const float* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; for (size_t ico = 0; ico < 4; ++ico) { if (ic + ico < IC) { #define cb(i) \ @@ -84,14 +83,13 @@ struct InputTransform6X3 { #undef cb } - static void transform(const float* patchT, float* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const float* patchT, float* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 6 + 3 - 1; // BT * d * B #define cb(m, n) \ - Vector d##m##n = \ - Vector::load(patchT + m * 8 * 4 + n * 4); + Vector d##m##n = Vector::load(patchT + m * 8 * 4 + n * 4); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb @@ -105,37 +103,35 @@ struct InputTransform6X3 { //! 0 1 -1 2 -2 0.5 -0.5 -5.25 //! -1 1 1 1 1 1 1 0 //! 
0 0 0 0 0 0 0 1 -#define cb(m) \ - auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ - auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ - auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ - auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ - d5##m * 2.f + d6##m; \ - auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \ - d4##m * 1.25f - d5##m * 2.f + d6##m; \ - auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ - d5##m * 0.5f + d6##m; \ - auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ - d5##m * 0.5f + d6##m; \ +#define cb(m) \ + auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ + auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ + auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ + auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ + d5##m * 2.f + d6##m; \ + auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - d4##m * 1.25f - \ + d5##m * 2.f + d6##m; \ + auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ + d5##m * 0.5f + d6##m; \ + auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ + d5##m * 0.5f + d6##m; \ auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; UNROLL_CALL_NOWRAPPER(8, cb); #undef cb -#define cb(m) \ - d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ - d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \ - (t##m##3 + t##m##4) * 4.25f; \ - d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \ - (t##m##3 - t##m##4) * 4.25f; \ - d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \ - t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \ - d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \ - t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \ - d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ - t##m##5 * 0.5f + t##m##6; \ - d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \ - t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \ +#define cb(m) \ + d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ + d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - (t##m##3 + t##m##4) * 4.25f; \ + d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + (t##m##3 - t##m##4) * 4.25f; \ + d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - t##m##4 * 1.25f + \ + t##m##5 * 2.f + t##m##6; \ + d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - t##m##4 * 1.25f - \ + t##m##5 * 2.f + t##m##6; \ + d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ + t##m##5 * 0.5f + t##m##6; \ + d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - t##m##4 * 5.f - \ + t##m##5 * 0.5f + t##m##6; \ d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; UNROLL_CALL_NOWRAPPER(8, cb); @@ -143,10 +139,10 @@ struct InputTransform6X3 { size_t ICB = IC / 4; size_t icb = ic / 4; -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ - icb * nr_units_in_tile * 4 + unit_idx * 4); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ + icb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) #undef cb } @@ -154,13 +150,11 @@ struct InputTransform6X3 { template struct OutputTransform6X3 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t 
oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { Op op(src_dtype, dst_dtype); //! AT * m * A constexpr size_t alpha = 6 + 3 - 1; @@ -169,10 +163,9 @@ struct OutputTransform6X3 { size_t OCB = (oc_end - oc_start) / 4; size_t ocb = oc_index / 4; -#define cb(m, n) \ - auto v##m##n = Vector::load( \ - output_transform_buf + \ - (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ ocb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); #undef cb @@ -267,19 +260,17 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_6x3_4x4_f) -void winograd_6x3_4x4_f::filter(const float* filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { +void winograd_6x3_4x4_f::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { FilterTransform6X3::transform( - filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, - oc_end); + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end); } -void winograd_6x3_4x4_f::input(const float* input, float* input_transform_buf, - float* transform_mid_buf, size_t IH, size_t IW, - size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, size_t nr_units_in_tile) { +void winograd_6x3_4x4_f::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { megdnn_assert(IC % 4 == 0); constexpr int alpha = 3 + 6 - 1; @@ -297,30 +288,28 @@ void winograd_6x3_4x4_f::input(const float* input, float* input_transform_buf, int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransform6X3::prepare(input, patch, patchT, ih_start, - iw_start, IH, IW, ic, IC); - InputTransform6X3::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform6X3::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransform6X3::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } else { - InputTransform6X3::prepare(input, patch, patchT, - ih_start, iw_start, IH, IW, - ic, IC); - InputTransform6X3::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform6X3::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransform6X3::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } } } } -void winograd_6x3_4x4_f::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t unit_start_idx, - size_t nr_units_in_tile) { +void 
winograd_6x3_4x4_f::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform6X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp index 2e8deab8..f2c13301 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp @@ -16,9 +16,9 @@ #include "src/common/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" -#include "src/naive/matrix_mul/matrix_mul_helper.h" -#include "src/arm_common/elemwise_helper/op_unary.h" #include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" MIDOUT_DECL(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4) @@ -32,11 +32,10 @@ constexpr size_t pack_size = 4; struct InputTransformF23_NCHW44 { template - static void transform(float* patchT, const float* input, - float* input_transform_buf, size_t ih_start, - size_t iw_start, size_t IH, size_t IW, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + float* patchT, const float* input, float* input_transform_buf, + size_t ih_start, size_t iw_start, size_t IH, size_t IW, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { size_t IW4 = IW * pack_size; size_t icb = ic / pack_size; size_t iw4_start = iw_start * pack_size; @@ -74,14 +73,11 @@ struct InputTransformF23_NCHW44 { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); - vst1q_f32( - patchT + iho * alpha * pack_size + iwo * pack_size, - src); + vst1q_f32(patchT + iho * alpha * pack_size + iwo * pack_size, src); } } -#define cb(m, n) \ - d##m##n = Vector::load(patchT + m * alpha * pack_size + \ - n * pack_size); +#define cb(m, n) \ + d##m##n = Vector::load(patchT + m * alpha * pack_size + n * pack_size); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb } @@ -107,10 +103,11 @@ struct InputTransformF23_NCHW44 { UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + unit_idx * pack_size); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + \ + (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -119,13 +116,11 @@ struct InputTransformF23_NCHW44 { #define CONCAT(a, idx) a##idx template struct OutputTransformF23_NCHW44 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, 
size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { MEGDNN_MARK_USED_VAR(transform_mid_buf); Op op(src_dtype, dst_dtype); //! AT * m * A @@ -173,20 +168,19 @@ struct OutputTransformF23_NCHW44 { UNROLL_CALL_RAW_D2(2, 2, cb); #undef cb } -#define out_save(oho, owo) \ - do { \ - size_t oh = oh_start + oho; \ - size_t ow = ow_start + owo; \ - if (oh < OH && ow < OW) { \ - if (bmode == BiasMode::BIAS) { \ - v##oho##owo += Vector::load(bias + oc * OH * OW + \ - oh * OW * pack_size + \ - ow * pack_size); \ - v##oho##owo = op(v##oho##owo.value); \ - } \ - v##oho##owo.save(output + oc * OH * OW + oh * OW * pack_size + \ - ow * pack_size); \ - } \ +#define out_save(oho, owo) \ + do { \ + size_t oh = oh_start + oho; \ + size_t ow = ow_start + owo; \ + if (oh < OH && ow < OW) { \ + if (bmode == BiasMode::BIAS) { \ + v##oho##owo += Vector::load( \ + bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + v##oho##owo = op(v##oho##owo.value); \ + } \ + v##oho##owo.save( \ + output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + } \ } while (0); UNROLL_CALL_RAW_D2(2, 2, out_save); #undef out_save @@ -200,11 +194,9 @@ namespace arm_common { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk4_f_nchw44) -void winograd_F23_mk4_f_nchw44::filter(const float* filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, - size_t oc_end) { +void winograd_F23_mk4_f_nchw44::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { //! 1 0 0 v00 v01 v02 1 0.5 0.5 0 //! 0.5 0.5 0.5 v10 v11 v12 0 0.5 -0.5 0 //! 0.5 -0.5 0.5 v20 v21 v22 0 0.5 0.5 1 @@ -213,12 +205,12 @@ void winograd_F23_mk4_f_nchw44::filter(const float* filter, constexpr size_t pack_size = 4; MEGDNN_MARK_USED_VAR(transform_mid_buf); - megdnn_assert((oc_end - oc_start) % pack_size == 0 && - oc_start % pack_size == 0 && - oc_end % pack_size == 0 && IC % pack_size == 0 && - OC % pack_size == 0, - "NCHW44 Winograd filter transform requires both OC and IC " - "are times of 4"); + megdnn_assert( + (oc_end - oc_start) % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW44 Winograd filter transform requires both OC and IC " + "are times of 4"); size_t OCB = OC / pack_size; size_t ICB = IC / pack_size; @@ -226,9 +218,8 @@ void winograd_F23_mk4_f_nchw44::filter(const float* filter, for (size_t icb = 0; icb < ICB; icb++) { for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { const float* fptr = filter + - (ocb * ICB + icb) * KERNEL_SIZE * - KERNEL_SIZE * pack_size * - pack_size + + (ocb * ICB + icb) * KERNEL_SIZE * KERNEL_SIZE * + pack_size * pack_size + ic_inner * pack_size; #define cb(m, n) \ @@ -248,11 +239,12 @@ void winograd_F23_mk4_f_nchw44::filter(const float* filter, UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM -#define cb_save(m, n) \ - ret##m##n.save(filter_transform_buf + \ - (m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ - ocb * ICB * pack_size * pack_size + \ - icb * pack_size * pack_size + ic_inner * pack_size); +#define cb_save(m, n) \ + ret##m##n.save( \ + filter_transform_buf + \ + (m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ + ocb * ICB * pack_size * pack_size + icb * pack_size * pack_size + \ + ic_inner * pack_size); 
UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) #undef cb_save } @@ -260,18 +252,15 @@ void winograd_F23_mk4_f_nchw44::filter(const float* filter, } } -void winograd_F23_mk4_f_nchw44::input(const float* input, - float* input_transform_buf, - float* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, - size_t PW, size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_F23_mk4_f_nchw44::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { megdnn_assert(IC % 4 == 0); constexpr int alpha = 3 + 2 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 - auto units_w = - div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); float* patchT = transform_mid_buf + 4 * alpha * alpha; for (size_t ic = 0; ic < IC; ic += 4) { @@ -284,50 +273,49 @@ void winograd_F23_mk4_f_nchw44::input(const float* input, if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { InputTransformF23_NCHW44::transform( - patchT, input, input_transform_buf, ih_start, iw_start, - IH, IW, unit_idx, nr_units_in_tile, ic, IC); + patchT, input, input_transform_buf, ih_start, iw_start, IH, IW, + unit_idx, nr_units_in_tile, ic, IC); } else { InputTransformF23_NCHW44::transform( - patchT, input, input_transform_buf, ih_start, iw_start, - IH, IW, unit_idx, nr_units_in_tile, ic, IC); + patchT, input, input_transform_buf, ih_start, iw_start, IH, IW, + unit_idx, nr_units_in_tile, ic, IC); } } } } -void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, - size_t OW, size_t oc_start, - size_t oc_end, size_t unit_start_idx, - size_t nr_units_in_tile) { -#define cb(_bmode, _nonline_op, ...) \ - for (size_t oc = oc_start; oc < oc_end; oc += 4) { \ - size_t oc_index = oc - oc_start; \ - rep(unit_idx, nr_units_in_tile) { \ - size_t index = unit_start_idx + unit_idx; \ - auto nh = index / units_w; \ - auto nw = index % units_w; \ - size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ - size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ - OutputTransformF23_NCHW44<_bmode, _nonline_op>::transform( \ - output_transform_buf, bias, output, transform_mid_buf, \ - oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, \ - unit_idx, nr_units_in_tile, src_dtype, dst_dtype); \ - } \ +void winograd_F23_mk4_f_nchw44::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + for (size_t oc = oc_start; oc < oc_end; oc += 4) { \ + size_t oc_index = oc - oc_start; \ + rep(unit_idx, nr_units_in_tile) { \ + size_t index = unit_start_idx + unit_idx; \ + auto nh = index / units_w; \ + auto nw = index % units_w; \ + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ + OutputTransformF23_NCHW44<_bmode, _nonline_op>::transform( \ + output_transform_buf, bias, output, transform_mid_buf, oh_start, \ + ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, \ + nr_units_in_tile, src_dtype, dst_dtype); \ + } \ } auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); constexpr size_t pack_size = 4; size_t OC = oc_end - oc_start; - megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && - oc_end % pack_size == 0, - "NCHW44 Winograd filter transform requires OC is times of 4"); + megdnn_assert( + OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, + "NCHW44 Winograd filter transform requires OC is times of 4"); - DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, - cb, float, float, bmode, nonline_mode); + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, float, bmode, + nonline_mode); #undef cb } diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp index 444cbc9d..dd0e1e9c 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp @@ -35,9 +35,9 @@ constexpr float input_parameters[12] = {5.25f, 4.25f, 0.5f, 0.25f, 2.5f, 1.25f, struct InputTransformF63_NCHW44 { template - static void prepare(const float* input, float* patch, float* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC) { + static void prepare( + const float* input, float* patch, float* patchT, int ih_start, int iw_start, + size_t IH, size_t IW, size_t ic, size_t IC) { MEGDNN_MARK_USED_VAR(patch); size_t IW4 = IW * pack_size; size_t iw4_start = iw_start * pack_size; @@ -69,17 +69,15 @@ struct InputTransformF63_NCHW44 { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); - vst1q_f32( - patchT + iho * pack_size * alpha + iwo * pack_size, - src); + vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); } } } } - static void transform(const float* patchT, float* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const float* patchT, float* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { // BT * d * B size_t ICB = IC / pack_size; @@ -154,91 +152,91 @@ struct InputTransformF63_NCHW44 { UNROLL_CALL_RAW(8, cb); #undef cb -#define cb(i) \ - d0 = t0##i; \ - d1 = t6##i; \ - d2 = t6##i; \ - d3 = t6##i; \ - d4 = t6##i; \ - d5 = t6##i; \ - d6 = t6##i; \ - d7 = t7##i; \ - d0 = d0 - t6##i; \ - d1 = d1 + t1##i; \ - d2 = d2 - t1##i; \ - d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ - d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ - d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ - d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ - d7 = d7 - t1##i; \ - d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ - d1 = d1 + t2##i; \ - d2 = d2 + t2##i; \ - d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ - d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ - d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ - d6 = vfmaq_laneq_f32(d6, t2##i, v1, 
3); \ - d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ - d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ - d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ - d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ - d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ - d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ - d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ - d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ - d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ - d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ - d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ - d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ - d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ - d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ - d1 = d1 + t5##i; \ - d2 = d2 - t5##i; \ - d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ - d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ - d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ - d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ - d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ - vst1q_f32(input_transform_buf + \ - (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d0); \ - vst1q_f32(input_transform_buf + \ - (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d1); \ - vst1q_f32(input_transform_buf + \ - (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d2); \ - vst1q_f32(input_transform_buf + \ - (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d3); \ - vst1q_f32(input_transform_buf + \ - (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d4); \ - vst1q_f32(input_transform_buf + \ - (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d5); \ - vst1q_f32(input_transform_buf + \ - (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d6); \ - vst1q_f32(input_transform_buf + \ - (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d7); +#define cb(i) \ + d0 = t0##i; \ + d1 = t6##i; \ + d2 = t6##i; \ + d3 = t6##i; \ + d4 = t6##i; \ + d5 = t6##i; \ + d6 = t6##i; \ + d7 = t7##i; \ + d0 = d0 - t6##i; \ + d1 = d1 + t1##i; \ + d2 = d2 - t1##i; \ + d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ + d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ + d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ + d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ + d7 = d7 - t1##i; \ + d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ + d1 = d1 + t2##i; \ + d2 = d2 + t2##i; \ + d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ + d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ + d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ + d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \ + d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ + d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ + d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ + d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ + d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ + d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ + d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ + d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ + d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ + d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ + d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ + d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ + d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ + d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ + d1 = d1 + t5##i; \ 
+ d2 = d2 - t5##i; \ + d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ + d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ + d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ + d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ + d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ + vst1q_f32( \ + input_transform_buf + \ + (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d0); \ + vst1q_f32( \ + input_transform_buf + \ + (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d1); \ + vst1q_f32( \ + input_transform_buf + \ + (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d2); \ + vst1q_f32( \ + input_transform_buf + \ + (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d3); \ + vst1q_f32( \ + input_transform_buf + \ + (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d4); \ + vst1q_f32( \ + input_transform_buf + \ + (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d5); \ + vst1q_f32( \ + input_transform_buf + \ + (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d6); \ + vst1q_f32( \ + input_transform_buf + \ + (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d7); UNROLL_CALL_RAW(8, cb); #undef cb } @@ -246,13 +244,11 @@ struct InputTransformF63_NCHW44 { template struct OutputTransformF63_NCHW44 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { MEGDNN_MARK_USED_VAR(transform_mid_buf); Op op(src_dtype, dst_dtype); //! 
AT * m * A @@ -330,20 +326,19 @@ struct OutputTransformF63_NCHW44 { UNROLL_CALL_RAW_D2(6, 6, cb); #undef cb } -#define out_save(oho, owo) \ - do { \ - size_t oh = oh_start + oho; \ - size_t ow = ow_start + owo; \ - if (oh < OH && ow < OW) { \ - if (bmode == BiasMode::BIAS) { \ - v##oho##owo += Vector::load(bias + oc * OH * OW + \ - oh * OW * pack_size + \ - ow * pack_size); \ - v##oho##owo = op(v##oho##owo.value); \ - } \ - v##oho##owo.save(output + oc * OH * OW + oh * OW * pack_size + \ - ow * pack_size); \ - } \ +#define out_save(oho, owo) \ + do { \ + size_t oh = oh_start + oho; \ + size_t ow = ow_start + owo; \ + if (oh < OH && ow < OW) { \ + if (bmode == BiasMode::BIAS) { \ + v##oho##owo += Vector::load( \ + bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + v##oho##owo = op(v##oho##owo.value); \ + } \ + v##oho##owo.save( \ + output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + } \ } while (0); UNROLL_CALL_RAW_D2(6, 6, out_save); } @@ -357,11 +352,9 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk4_f_nchw44) -void winograd_F63_mk4_f_nchw44::filter(const float* filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, - size_t oc_end) { +void winograd_F63_mk4_f_nchw44::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { constexpr size_t pack_size = 4; // Gg * GT // G @@ -374,12 +367,12 @@ void winograd_F63_mk4_f_nchw44::filter(const float* filter, // 0.7111111 -0.3555556 0.1777778 // 0.0000000 0.0000000 1.0000000 MEGDNN_MARK_USED_VAR(transform_mid_buf); - megdnn_assert((oc_end - oc_start) % pack_size == 0 && - oc_start % pack_size == 0 && - oc_end % pack_size == 0 && IC % pack_size == 0 && - OC % pack_size == 0, - "NCHW44 Winograd filter transform requires both OC and IC " - "are times of 4"); + megdnn_assert( + (oc_end - oc_start) % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW44 Winograd filter transform requires both OC and IC " + "are times of 4"); size_t ICB = IC / pack_size; @@ -387,9 +380,8 @@ void winograd_F63_mk4_f_nchw44::filter(const float* filter, for (size_t icb = 0; icb < ICB; icb++) { for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { const float* fptr = filter + - (ocb * ICB + icb) * KERNEL_SIZE * - KERNEL_SIZE * pack_size * - pack_size + + (ocb * ICB + icb) * KERNEL_SIZE * KERNEL_SIZE * + pack_size * pack_size + ic_inner * pack_size; #define cb(m, n) \ @@ -417,10 +409,10 @@ void winograd_F63_mk4_f_nchw44::filter(const float* filter, UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM -#define cb_save(m, n) \ - ret##m##n.save(filter_transform_buf + (m * alpha + n) * OC * IC + \ - ocb * IC * pack_size + icb * pack_size * pack_size + \ - ic_inner * pack_size); +#define cb_save(m, n) \ + ret##m##n.save( \ + filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ + icb * pack_size * pack_size + ic_inner * pack_size); UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save) #undef cb_save } @@ -428,19 +420,16 @@ void winograd_F63_mk4_f_nchw44::filter(const float* filter, } } -void winograd_F63_mk4_f_nchw44::input(const float* input, - float* input_transform_buf, - float* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, - size_t PW, size_t unit_start_idx, - size_t nr_units_in_tile) { +void 
winograd_F63_mk4_f_nchw44::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { constexpr size_t pack_size = 4; megdnn_assert(IC % pack_size == 0); constexpr int alpha = 3 + 6 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 - auto units_w = - div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); float* patch = transform_mid_buf; float* patchT = transform_mid_buf + pack_size * alpha * alpha; @@ -453,59 +442,55 @@ void winograd_F63_mk4_f_nchw44::input(const float* input, int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransformF63_NCHW44::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC); - InputTransformF63_NCHW44::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, - ic, IC); + InputTransformF63_NCHW44::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransformF63_NCHW44::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } else { - InputTransformF63_NCHW44::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC); - InputTransformF63_NCHW44::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, - ic, IC); + InputTransformF63_NCHW44::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransformF63_NCHW44::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } } } } -void winograd_F63_mk4_f_nchw44::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, - size_t OW, size_t oc_start, - size_t oc_end, size_t unit_start_idx, - size_t nr_units_in_tile) { -#define cb(_bmode, _nonline_op, ...) \ - for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ - size_t oc_index = oc - oc_start; \ - rep(unit_idx, nr_units_in_tile) { \ - size_t index = unit_start_idx + unit_idx; \ - auto nh = index / units_w; \ - auto nw = index % units_w; \ - size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ - size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ - OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \ - transform(output_transform_buf, bias, output, \ - transform_mid_buf, oh_start, ow_start, OH, OW, \ - oc_start, oc_end, oc_index, unit_idx, \ - nr_units_in_tile, src_dtype, dst_dtype); \ - } \ +void winograd_F63_mk4_f_nchw44::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ + size_t oc_index = oc - oc_start; \ + rep(unit_idx, nr_units_in_tile) { \ + size_t index = unit_start_idx + unit_idx; \ + auto nh = index / units_w; \ + auto nw = index % units_w; \ + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ + OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ + output_transform_buf, bias, output, transform_mid_buf, oh_start, \ + ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, \ + nr_units_in_tile, src_dtype, dst_dtype); \ + } \ } auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); constexpr size_t pack_size = 4; size_t OC = oc_end - oc_start; - megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && - oc_end % pack_size == 0, - "NCHW44 Winograd filter transform requires OC is times of 4"); + megdnn_assert( + OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, + "NCHW44 Winograd filter transform requires OC is times of 4"); - DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F63_mk4, cb, - float, float, bmode, nonline_mode); + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, bmode, + nonline_mode); #undef cb } diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp index b106a810..d009a90d 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp @@ -31,16 +31,15 @@ namespace { constexpr size_t alpha = 7 + 3 - 1; constexpr size_t pack_size = 4; constexpr float input_parameters[28] = { - 1.5f, 0.75f, 3.0f, 7.875f, 0.5f, 2.5f, 0.125f, - 0.875f, 4.0f, 8.0f, 5.25f, 7.375f, 5.375f, 3.5f, - 7.75f, 0.25f, 2.125f, 10.625f, 0.625f, 4.375f, 5.0f, - 10.0f, 5.75f, 2.75f, 4.25f, 1.75f, 2.0f, 0.0f}; + 1.5f, 0.75f, 3.0f, 7.875f, 0.5f, 2.5f, 0.125f, 0.875f, 4.0f, 8.0f, + 5.25f, 7.375f, 5.375f, 3.5f, 7.75f, 0.25f, 2.125f, 10.625f, 0.625f, 4.375f, + 5.0f, 10.0f, 5.75f, 2.75f, 4.25f, 1.75f, 2.0f, 0.0f}; struct InputTransformF73_NCHW44 { template - static void prepare(const float* input, float* patch, float* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC) { + static void prepare( + const float* input, float* patch, float* patchT, int ih_start, int iw_start, + size_t IH, size_t IW, size_t ic, size_t IC) { MEGDNN_MARK_USED_VAR(patch); size_t IW4 = IW * pack_size; size_t iw4_start = iw_start * pack_size; @@ -72,17 +71,15 @@ struct InputTransformF73_NCHW44 { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); - vst1q_f32( - patchT + iho * pack_size * alpha + iwo * pack_size, - src); + vst1q_f32(patchT + iho * pack_size * alpha + iwo * pack_size, src); } } } } - static void transform(const float* patchT, float* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const float* patchT, float* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { // BT * d * B size_t ICB = IC / pack_size; @@ -191,114 +188,114 @@ struct InputTransformF73_NCHW44 { UNROLL_CALL_RAW(9, cb); #undef cb -#define cb(i) \ - d8 = t8##i; \ - d0 = t7##i; \ - d1 = t7##i; \ - d2 = t7##i; \ - d3 = t7##i; \ - d4 = t7##i; \ - d5 = t7##i; \ - d6 = t7##i; \ - d7 = t7##i; \ - d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0); 
\ - d0 = d0 - t1##i; \ - d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0); \ - d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0); \ - d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1); \ - d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1); \ - d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2); \ - d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2); \ - d7 = d7 - t1##i; \ - d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0); \ - d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3); \ - d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0); \ - d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1); \ - d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2); \ - d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3); \ - d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0); \ - d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1); \ - d8 = d8 - t2##i; \ - d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2); \ - d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3); \ - d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0); \ - d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0); \ - d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1); \ - d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2); \ - d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3); \ - d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2); \ - d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3); \ - d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3); \ - d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0); \ - d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1); \ - d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2); \ - d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3); \ - d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0); \ - d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1); \ - d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2); \ - d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2); \ - d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2); \ - d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3); \ - d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0); \ - d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1); \ - d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2); \ - d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0); \ - d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2); \ - d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3); \ - d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0); \ - d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0); \ - d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1); \ - d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0); \ - d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1); \ - d5 = d5 - t6##i; \ - d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2); \ - d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2); \ - d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0); \ - vst1q_f32(input_transform_buf + \ - (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d0); \ - vst1q_f32(input_transform_buf + \ - (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d1); \ - vst1q_f32(input_transform_buf + \ - (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d2); \ - vst1q_f32(input_transform_buf + \ - (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d3); \ - vst1q_f32(input_transform_buf + \ - (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d4); \ - vst1q_f32(input_transform_buf + \ - (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d5); \ - vst1q_f32(input_transform_buf + \ - (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d6); \ - vst1q_f32(input_transform_buf + \ - (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, 
\ - d7); \ - vst1q_f32(input_transform_buf + \ - (8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + \ - unit_idx * pack_size, \ - d8); +#define cb(i) \ + d8 = t8##i; \ + d0 = t7##i; \ + d1 = t7##i; \ + d2 = t7##i; \ + d3 = t7##i; \ + d4 = t7##i; \ + d5 = t7##i; \ + d6 = t7##i; \ + d7 = t7##i; \ + d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0); \ + d0 = d0 - t1##i; \ + d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0); \ + d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0); \ + d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1); \ + d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1); \ + d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2); \ + d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2); \ + d7 = d7 - t1##i; \ + d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0); \ + d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3); \ + d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0); \ + d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1); \ + d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2); \ + d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3); \ + d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0); \ + d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1); \ + d8 = d8 - t2##i; \ + d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2); \ + d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3); \ + d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0); \ + d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0); \ + d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1); \ + d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2); \ + d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3); \ + d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2); \ + d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3); \ + d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3); \ + d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0); \ + d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1); \ + d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2); \ + d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3); \ + d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0); \ + d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1); \ + d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2); \ + d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2); \ + d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2); \ + d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3); \ + d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0); \ + d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1); \ + d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2); \ + d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0); \ + d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2); \ + d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3); \ + d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0); \ + d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0); \ + d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1); \ + d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0); \ + d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1); \ + d5 = d5 - t6##i; \ + d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2); \ + d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2); \ + d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0); \ + vst1q_f32( \ + input_transform_buf + \ + (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d0); \ + vst1q_f32( \ + input_transform_buf + \ + (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d1); \ + vst1q_f32( \ + input_transform_buf + \ + (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d2); \ + vst1q_f32( \ + input_transform_buf + \ + (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d3); \ + vst1q_f32( \ + input_transform_buf + \ + (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d4); \ + vst1q_f32( \ + input_transform_buf + \ + (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * 
nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d5); \ + vst1q_f32( \ + input_transform_buf + \ + (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d6); \ + vst1q_f32( \ + input_transform_buf + \ + (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d7); \ + vst1q_f32( \ + input_transform_buf + \ + (8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size, \ + d8); UNROLL_CALL_RAW(9, cb); #undef cb @@ -307,13 +304,11 @@ struct InputTransformF73_NCHW44 { template struct OutputTransformF73_NCHW44 { - static void transform(const float* output_transform_buf, const float* bias, - float* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& dst_dtype) { MEGDNN_MARK_USED_VAR(transform_mid_buf); Op op(src_dtype, dst_dtype); //! AT * m * A @@ -346,44 +341,40 @@ struct OutputTransformF73_NCHW44 { */ Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; -#define cb(m) \ - v1addv2 = v1##m + v2##m; \ - v1subv2 = v1##m - v2##m; \ - v3addv4 = v3##m + v4##m; \ - v3subv4 = v3##m - v4##m; \ - v5addv6 = v5##m + v6##m; \ - v5subv6 = v5##m - v6##m; \ - auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; \ - auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; \ - auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; \ - auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; \ - auto t4##m = \ - v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; \ - auto t5##m = \ - v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f; \ - auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + \ - v7##m * 11.390625f + v8##m; +#define cb(m) \ + v1addv2 = v1##m + v2##m; \ + v1subv2 = v1##m - v2##m; \ + v3addv4 = v3##m + v4##m; \ + v3subv4 = v3##m - v4##m; \ + v5addv6 = v5##m + v6##m; \ + v5subv6 = v5##m - v6##m; \ + auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; \ + auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; \ + auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; \ + auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; \ + auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; \ + auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f; \ + auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + v7##m * 11.390625f + \ + v8##m; UNROLL_CALL_NOWRAPPER(9, cb); #undef cb -#define cb(m) \ - v1addv2 = t##m##1 + t##m##2; \ - v1subv2 = t##m##1 - t##m##2; \ - v3addv4 = t##m##3 + t##m##4; \ - v3subv4 = t##m##3 - t##m##4; \ - v5addv6 = t##m##5 + t##m##6; \ - v5subv6 = t##m##5 - t##m##6; \ - v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; \ - v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; \ - v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; \ - v##m##3 = v1subv2 + 
v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; \ - v##m##4 = \ - v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; \ - v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + \ - t##m##7 * 7.59375f; \ - v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + \ - t##m##7 * 11.390625f + t##m##8; +#define cb(m) \ + v1addv2 = t##m##1 + t##m##2; \ + v1subv2 = t##m##1 - t##m##2; \ + v3addv4 = t##m##3 + t##m##4; \ + v3subv4 = t##m##3 - t##m##4; \ + v5addv6 = t##m##5 + t##m##6; \ + v5subv6 = t##m##5 - t##m##6; \ + v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; \ + v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; \ + v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; \ + v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; \ + v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; \ + v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7 * 7.59375f; \ + v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + t##m##7 * 11.390625f + \ + t##m##8; UNROLL_CALL_NOWRAPPER(7, cb); #undef cb @@ -401,20 +392,19 @@ struct OutputTransformF73_NCHW44 { UNROLL_CALL_RAW_D2(7, 7, cb); #undef cb } -#define out_save(oho, owo) \ - do { \ - size_t oh = oh_start + oho; \ - size_t ow = ow_start + owo; \ - if (oh < OH && ow < OW) { \ - if (bmode == BiasMode::BIAS) { \ - v##oho##owo += Vector::load(bias + oc * OH * OW + \ - oh * OW * pack_size + \ - ow * pack_size); \ - v##oho##owo = op(v##oho##owo.value); \ - } \ - v##oho##owo.save(output + oc * OH * OW + oh * OW * pack_size + \ - ow * pack_size); \ - } \ +#define out_save(oho, owo) \ + do { \ + size_t oh = oh_start + oho; \ + size_t ow = ow_start + owo; \ + if (oh < OH && ow < OW) { \ + if (bmode == BiasMode::BIAS) { \ + v##oho##owo += Vector::load( \ + bias + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + v##oho##owo = op(v##oho##owo.value); \ + } \ + v##oho##owo.save( \ + output + oc * OH * OW + oh * OW * pack_size + ow * pack_size); \ + } \ } while (0); UNROLL_CALL_RAW_D2(7, 7, out_save); } @@ -428,11 +418,9 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44) -void winograd_F73_mk4_f_nchw44::filter(const float* filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, - size_t oc_end) { +void winograd_F73_mk4_f_nchw44::filter( + const float* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { constexpr size_t pack_size = 4; // Gg * GT // G @@ -446,12 +434,12 @@ void winograd_F73_mk4_f_nchw44::filter(const float* filter, //-0.1523810 -0.2285714 -0.3428572 // 0.0000000 0.0000000 1.0000000 MEGDNN_MARK_USED_VAR(transform_mid_buf); - megdnn_assert((oc_end - oc_start) % pack_size == 0 && - oc_start % pack_size == 0 && - oc_end % pack_size == 0 && IC % pack_size == 0 && - OC % pack_size == 0, - "NCHW44 Winograd filter transform requires both OC and IC " - "are times of 4"); + megdnn_assert( + (oc_end - oc_start) % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW44 Winograd filter transform requires both OC and IC " + "are times of 4"); size_t ICB = IC / pack_size; @@ -459,9 +447,8 @@ void winograd_F73_mk4_f_nchw44::filter(const float* filter, for (size_t icb = 0; icb < ICB; icb++) { for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { const float* fptr = filter + - (ocb * ICB + icb) * KERNEL_SIZE 
* - KERNEL_SIZE * pack_size * - pack_size + + (ocb * ICB + icb) * KERNEL_SIZE * KERNEL_SIZE * + pack_size * pack_size + ic_inner * pack_size; #define cb(m, n) \ @@ -470,28 +457,28 @@ void winograd_F73_mk4_f_nchw44::filter(const float* filter, UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) #undef cb -#define FILTER_TRANSFORM(n, wd, g) \ - auto wd##n##0 = g##0##n * 0.6666667f; \ - auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; \ - auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; \ - auto wd##n##3 = g##0##n * 0.0222222f + g##1##n * 0.0444444f + \ - g##2##n * 0.0888889f; \ - auto wd##n##4 = g##0##n * -0.0031746f + g##1##n * 0.0063492f + \ - g##2##n * -0.0126984f; \ - auto wd##n##5 = g##0##n * -0.7111111f + g##1##n * -0.3555556f + \ - g##2##n * -0.1777778f; \ - auto wd##n##6 = g##0##n * -0.3555556f + g##1##n * 0.1777778f + \ - g##2##n * -0.0888889f; \ - auto wd##n##7 = g##0##n * -0.1523810f + g##1##n * -0.2285714f + \ - g##2##n * -0.3428572f; \ +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = g##0##n * 0.6666667f; \ + auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; \ + auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; \ + auto wd##n##3 = \ + g##0##n * 0.0222222f + g##1##n * 0.0444444f + g##2##n * 0.0888889f; \ + auto wd##n##4 = \ + g##0##n * -0.0031746f + g##1##n * 0.0063492f + g##2##n * -0.0126984f; \ + auto wd##n##5 = \ + g##0##n * -0.7111111f + g##1##n * -0.3555556f + g##2##n * -0.1777778f; \ + auto wd##n##6 = \ + g##0##n * -0.3555556f + g##1##n * 0.1777778f + g##2##n * -0.0888889f; \ + auto wd##n##7 = \ + g##0##n * -0.1523810f + g##1##n * -0.2285714f + g##2##n * -0.3428572f; \ auto wd##n##8 = g##2##n; UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM -#define cb_save(m, n) \ - ret##m##n.save(filter_transform_buf + (m * alpha + n) * OC * IC + \ - ocb * IC * pack_size + icb * pack_size * pack_size + \ - ic_inner * pack_size); +#define cb_save(m, n) \ + ret##m##n.save( \ + filter_transform_buf + (m * alpha + n) * OC * IC + ocb * IC * pack_size + \ + icb * pack_size * pack_size + ic_inner * pack_size); UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save) #undef cb_save } @@ -499,19 +486,16 @@ void winograd_F73_mk4_f_nchw44::filter(const float* filter, } } -void winograd_F73_mk4_f_nchw44::input(const float* input, - float* input_transform_buf, - float* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, - size_t PW, size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_F73_mk4_f_nchw44::input( + const float* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { constexpr size_t pack_size = 4; megdnn_assert(IC % pack_size == 0); constexpr int alpha = 3 + 7 - 1; // OW = IW + 2 * PW - KERNEL_SIZE + 1 - auto units_w = - div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); float* patch = transform_mid_buf; float* patchT = transform_mid_buf + pack_size * alpha * alpha; @@ -524,59 +508,55 @@ void winograd_F73_mk4_f_nchw44::input(const float* input, int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransformF73_NCHW44::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC); - InputTransformF73_NCHW44::transform(patchT, input_transform_buf, - unit_idx, 
nr_units_in_tile, - ic, IC); + InputTransformF73_NCHW44::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransformF73_NCHW44::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } else { - InputTransformF73_NCHW44::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC); - InputTransformF73_NCHW44::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, - ic, IC); + InputTransformF73_NCHW44::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransformF73_NCHW44::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } } } } -void winograd_F73_mk4_f_nchw44::output(const float* output_transform_buf, - const float* bias, float* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, - size_t OW, size_t oc_start, - size_t oc_end, size_t unit_start_idx, - size_t nr_units_in_tile) { -#define cb(_bmode, _nonline_op, ...) \ - for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ - size_t oc_index = oc - oc_start; \ - rep(unit_idx, nr_units_in_tile) { \ - size_t index = unit_start_idx + unit_idx; \ - auto nh = index / units_w; \ - auto nw = index % units_w; \ - size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ - size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ - OutputTransformF73_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \ - transform(output_transform_buf, bias, output, \ - transform_mid_buf, oh_start, ow_start, OH, OW, \ - oc_start, oc_end, oc_index, unit_idx, \ - nr_units_in_tile, src_dtype, dst_dtype); \ - } \ +void winograd_F73_mk4_f_nchw44::output( + const float* output_transform_buf, const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ + size_t oc_index = oc - oc_start; \ + rep(unit_idx, nr_units_in_tile) { \ + size_t index = unit_start_idx + unit_idx; \ + auto nh = index / units_w; \ + auto nw = index % units_w; \ + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ + OutputTransformF73_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ + output_transform_buf, bias, output, transform_mid_buf, oh_start, \ + ow_start, OH, OW, oc_start, oc_end, oc_index, unit_idx, \ + nr_units_in_tile, src_dtype, dst_dtype); \ + } \ } auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); constexpr size_t pack_size = 4; size_t OC = oc_end - oc_start; - megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && - oc_end % pack_size == 0, - "NCHW44 Winograd filter transform requires OC is times of 4"); + megdnn_assert( + OC % pack_size == 0 && oc_start % pack_size == 0 && oc_end % pack_size == 0, + "NCHW44 Winograd filter transform requires OC is times of 4"); - DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F73_mk4, cb, - float, float, bmode, nonline_mode); + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F73_mk4, cb, float, float, bmode, + nonline_mode); #undef cb } diff --git a/dnn/src/arm_common/conv_bias/img2col_helper.h b/dnn/src/arm_common/conv_bias/img2col_helper.h index 2af4ac7c..ce757d07 100644 --- a/dnn/src/arm_common/conv_bias/img2col_helper.h +++ b/dnn/src/arm_common/conv_bias/img2col_helper.h @@ -14,10 +14,10 @@ namespace { template -void img2col_stride(const dtype* __restrict src, - dtype* __restrict dst, const int OC, const int OH, - const int OW, const int IC, const int IH, const int IW, - const int FH, const int FW, const int SH, const int SW) { +void img2col_stride( + const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, + const int OW, const int IC, const int IH, const int IW, const int FH, + const int FW, const int SH, const int SW) { (void)OC; size_t i = 0; rep(ic, IC) { @@ -33,8 +33,9 @@ void img2col_stride(const dtype* __restrict src, fh2 = FH - fh - 1; fw2 = FW - fw - 1; } - dst[i++] = src[ic * IH * IW + (oh * SH + fh2) * IW + - (ow * SW + fw2)]; + dst[i++] = + src[ic * IH * IW + (oh * SH + fh2) * IW + + (ow * SW + fw2)]; } } } @@ -43,8 +44,9 @@ void img2col_stride(const dtype* __restrict src, } template -void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH, - size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) { +void img2col( + const dtype* src, dtype* dst, size_t /* OC */, size_t OH, size_t OW, size_t IC, + size_t IH, size_t IW, size_t FH, size_t FW) { size_t offset = (4 - OW % 4) % 4; size_t i = 0; rep(ic, IC) { @@ -61,14 +63,10 @@ void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH, fh2 = FH - fh - 1; fw2 = FW - fw - 1; } - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 0]; - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 1]; - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 2]; - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 3]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 0]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 1]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 2]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 3]; } i -= offset; } diff --git a/dnn/src/arm_common/conv_bias/int8/algos.cpp b/dnn/src/arm_common/conv_bias/int8/algos.cpp index d4be9237..a74ef427 100644 
--- a/dnn/src/arm_common/conv_bias/int8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/int8/algos.cpp @@ -28,14 +28,15 @@ using namespace arm_common; MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoS8DirectStride1::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return direct_int8_stride1::can_conv_direct_stride1_int8(param); } bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8DirectStride1::is_preferred"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8DirectStride1::is_preferred"_hash)) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; auto OC = fm.ocpg; @@ -51,8 +52,9 @@ bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred( size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8DirectStride1::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8DirectStride1::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_int8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -61,11 +63,11 @@ size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8DirectStride1::dispatch_kerns"_hash)) { +SmallVector ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8DirectStride1::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_int8_stride1::get_kimpls(param, large_group); } @@ -75,14 +77,14 @@ ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns( /* ===================== stride1 algo ===================== */ bool ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return channel_wise_nchw44::stride1::is_available(param); } size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, midout_iv("AlgoS8ChanWiseStride1NCHW44::get_workspace"_hash)) { auto bundle = channel_wise_nchw44::stride1::get_bundle(param); return bundle.total_size_in_bytes(); @@ -91,11 +93,11 @@ size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8ChanWiseStride1NCHW44::dispatch_kerns"_hash)) { +SmallVector ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44:: + dispatch_kerns(const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8ChanWiseStride1NCHW44::dispatch_kerns"_hash)) { return channel_wise_nchw44::stride1::get_kimpls(param); } MIDOUT_END(); 
@@ -104,15 +106,15 @@ ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns( /* ===================== stride2 algo ===================== */ bool ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return channel_wise_nchw44::stride2::is_available(param); } size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8ChanWiseStride2NCHW44::get_workspace"_hash)) { + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8ChanWiseStride2NCHW44::get_workspace"_hash)) { auto bundle = channel_wise_nchw44::stride2::get_bundle(param); return bundle.total_size_in_bytes(); } @@ -120,11 +122,11 @@ size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8ChanWiseStride2NCHW44::dispatch_kerns"_hash)) { +SmallVector ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44:: + dispatch_kerns(const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8ChanWiseStride2NCHW44::dispatch_kerns"_hash)) { return channel_wise_nchw44::stride2::get_kimpls(param); } MIDOUT_END(); @@ -132,15 +134,16 @@ ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns( } /* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoS8DirectStride2::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return direct_int8_stride2::can_conv_direct_stride2_int8(param); } size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8DirectStride2::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8DirectStride2::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_int8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -149,11 +152,11 @@ size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8DirectStride2::dispatch_kerns"_hash)) { +SmallVector ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8DirectStride2::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_int8_stride2::get_kimpls(param, large_group); } @@ -163,8 +166,8 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( #if MGB_ENABLE_DOT /* ===================== dot stride1 algo ======================== */ -bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoDotS8DirectStride1::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { if (!cpuinfo_has_arm_neon_dot()) { return false; } @@ -173,8 +176,9 @@ bool 
ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoDotS8DirectStride1::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoDotS8DirectStride1::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -183,11 +187,11 @@ size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoDotS8DirectStride1::dispatch_kerns"_hash)) { +SmallVector ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoDotS8DirectStride1::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_dotprod_int8_stride1::get_kimpls(param, large_group); } @@ -196,9 +200,9 @@ ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( } /* ===================== dot stride2 algo ======================== */ -bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { - if (!cpuinfo_has_arm_neon_dot()){ +bool ConvBiasImpl::AlgoDotS8DirectStride2::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()) { return false; } return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); @@ -206,8 +210,9 @@ bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoDotS8DirectStride2::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoDotS8DirectStride2::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -216,11 +221,11 @@ size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoDotS8DirectStride2::dispatch_kerns"_hash)) { +SmallVector ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoDotS8DirectStride2::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_dotprod_int8_stride2::get_kimpls(param, large_group); } @@ -234,29 +239,28 @@ ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns( bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("AlgoS8WinogradF23_8x8::usable"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("AlgoS8WinogradF23_8x8::usable"_hash)) { if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0) return 
false; using Strategy = winograd::winograd_2x3_8x8_s8; using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param); return m_matmul_algo->usable(matmul_param) && m_matmul_algo->packmode() == PackMode::NO_PACK && (param.filter_meta.format == param::ConvBias::Format::NCHW && - param.filter_type.enumv() == DTypeEnum::QuantizedS8) && + param.filter_type.enumv() == DTypeEnum::QuantizedS8) && !param.filter_meta.should_flip && (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::QuantizedS8 && @@ -267,18 +271,18 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8WinogradF23_8x8, - winograd::winograd_2x3_8x8_s8, - megdnn_arm_common_conv_bias_int8, - param::MatrixMul::Format::MK8); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoS8WinogradF23_8x8, winograd::winograd_2x3_8x8_s8, + megdnn_arm_common_conv_bias_int8, param::MatrixMul::Format::MK8); //=========================== input int8 compute float32 ========= bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { MEGDNN_MARK_USED_VAR(param); - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("arm_common_AlgoS8CF32WinogradF23_4x4::usable"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("arm_common_AlgoS8CF32WinogradF23_4x4::usable"_hash)) { if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) return false; bool is_matmul_usable = false; @@ -287,21 +291,18 @@ bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable( using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; Strategy strategy(param.src_type, param.filter_type, param.dst_type); is_matmul_usable = m_matmul_algo->usable( - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param)); - return is_matmul_usable && - m_matmul_algo->packmode() == PackMode::NO_PACK && + return is_matmul_usable && m_matmul_algo->packmode() == PackMode::NO_PACK && (param.filter_meta.format == param::ConvBias::Format::NCHW44 && - param.filter_type.enumv() == DTypeEnum::QuantizedS8) && + param.filter_type.enumv() == DTypeEnum::QuantizedS8) && !param.filter_meta.should_flip && (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && (param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 || param.compute_mode == param::ConvBias::ComputeMode::DEFAULT) && @@ -313,10 +314,9 @@ bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable( return false; } 
-MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8CF32WinogradF23_4x4_NCHW44, - winograd::winograd_2x3_4x4_s8_f32_nchw44, - megdnn_arm_common_conv_bias_int8, - param::MatrixMul::Format::MK4); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoS8CF32WinogradF23_4x4_NCHW44, winograd::winograd_2x3_4x4_s8_f32_nchw44, + megdnn_arm_common_conv_bias_int8, param::MatrixMul::Format::MK4); /* ======================= AlgoS8WinogradF23_8x8_NCHW44 ======================== */ bool ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44::usable( @@ -330,8 +330,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44::usable( using Strategy = winograd::winograd_2x3_8x8_s8_nchw44; Strategy strategy(param.src_type, param.filter_type, param.dst_type); auto&& matmul_param = - megdnn::winograd::ConvBias( + megdnn::winograd::ConvBias( strategy, m_tile_size, param) .get_matmul_kern_param(param); bool is_matmul_usable = m_matmul_algo->usable(matmul_param); @@ -343,8 +342,7 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44::usable( param.filter_meta.spatial[0] == 3) && (param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == 1) && - (param.filter_meta.dilation[0] == - param.filter_meta.dilation[1] && + (param.filter_meta.dilation[0] == param.filter_meta.dilation[1] && param.filter_meta.dilation[0] == 1) && param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && param.src_type.enumv() == DTypeEnum::QuantizedS8 && @@ -355,9 +353,8 @@ bool ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44::usable( return false; } -MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8WinogradF23_8x8_NCHW44, - winograd::winograd_2x3_8x8_s8_nchw44, - megdnn_arm_common_conv_bias_int8, - param::MatrixMul::Format::MK8); +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL( + AlgoS8WinogradF23_8x8_NCHW44, winograd::winograd_2x3_8x8_s8_nchw44, + megdnn_arm_common_conv_bias_int8, param::MatrixMul::Format::MK8); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index 3addf2a0..fbf971b5 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -18,14 +18,12 @@ namespace megdnn { namespace arm_common { class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8STRD1"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -39,14 +37,12 @@ public: }; class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8STRD2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -60,12 +56,11 @@ public: 
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { public: AlgoS8DirectNCHW44() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8_NCHW44_DIRECT"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -79,12 +74,11 @@ public: class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { public: AlgoS8DirectNCHWNCHW44() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8_CONV_NCHW_NCHW44"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -97,12 +91,11 @@ public: class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8_CHAN_WISE_STRD1_NCHW44"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -114,12 +107,11 @@ public: class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8_CHAN_WISE_STRD2_NCHW44"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -133,12 +125,10 @@ public: class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; } - bool usable(const NCBKernSizeParam&, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) + const override; size_t get_workspace(const NCBKernSizeParam&) const override; virtual SmallVector 
dispatch_kerns( @@ -150,14 +140,11 @@ public: }; class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMDOTS8STRD1"; } - bool usable(const NCBKernSizeParam&, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) + const override; size_t get_workspace(const NCBKernSizeParam&) const override; virtual SmallVector dispatch_kerns( @@ -169,15 +156,12 @@ public: }; class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMDOTS8STRD2"; } - bool usable(const NCBKernSizeParam&, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) + const override; size_t get_workspace(const NCBKernSizeParam&) const override; virtual SmallVector dispatch_kerns( @@ -192,17 +176,14 @@ class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { public: AlgoDotS8Direct_NCHW44() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMDOTS8DIRECT_NCHW44"; } - bool usable(const NCBKernSizeParam&, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable(const NCBKernSizeParam&, AlgoSelectionStrategy algo_selection_strategy) + const override; size_t get_workspace(const NCBKernSizeParam&) const override; - SmallVector dispatch_kerns( - const NCBKernSizeParam& param) const override; + SmallVector dispatch_kerns(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override; @@ -215,8 +196,8 @@ public: class ConvBiasImpl::AlgoS8WinogradF23_8x8 final : public AlgoBase { public: - AlgoS8WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t tile_size) + AlgoS8WinogradF23_8x8( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -225,9 +206,7 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8) }; @@ -246,9 +225,7 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32) }; @@ -256,8 +233,8 @@ public: //=======================input int8 compute int16 output int8============ class ConvBiasImpl::AlgoS8WinogradF23_8x8_NCHW44 final : public AlgoBase { public: - AlgoS8WinogradF23_8x8_NCHW44(fallback::MatrixMulImpl::AlgoBase* matmul_algo, - uint32_t 
tile_size) + AlgoS8WinogradF23_8x8_NCHW44( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} const char* name() const override { if (m_name.empty()) { @@ -267,9 +244,7 @@ public: } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8) }; diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp index fb533e63..3c98e3a3 100644 --- a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.cpp @@ -20,9 +20,9 @@ using namespace megdnn; using namespace arm_common; -static inline void accumulate_2_q_vector(int8x16_t& src0, int8x16_t& kern0, - int8x16_t& src1, int8x16_t& kern1, - int32x4_t* sum) { +static inline void accumulate_2_q_vector( + int8x16_t& src0, int8x16_t& kern0, int8x16_t& src1, int8x16_t& kern1, + int32x4_t* sum) { int16x8_t tmp_sum0 = vmull_s8(vget_low_s8(src0), vget_low_s8(kern0)); int16x8_t tmp_sum1 = vmull_high_s8(src0, kern0); tmp_sum0 = vmlal_s8(tmp_sum0, vget_low_s8(src1), vget_low_s8(kern1)); @@ -33,8 +33,8 @@ static inline void accumulate_2_q_vector(int8x16_t& src0, int8x16_t& kern0, sum[3] = vaddw_s16(sum[3], vget_high_s16(tmp_sum1)); } -static inline void accumulate_1_q_vector(int8x16_t& src0, int8x16_t& kern0, - int32x4_t* sum) { +static inline void accumulate_1_q_vector( + int8x16_t& src0, int8x16_t& kern0, int32x4_t* sum) { int16x8_t tmp_sum0 = vmull_s8(vget_low_s8(src0), vget_low_s8(kern0)); int16x8_t tmp_sum1 = vmull_high_s8(src0, kern0); sum[0] = vaddw_s16(sum[0], vget_low_s16(tmp_sum0)); @@ -43,9 +43,9 @@ static inline void accumulate_1_q_vector(int8x16_t& src0, int8x16_t& kern0, sum[3] = vaddw_s16(sum[3], vget_high_s16(tmp_sum1)); } -static inline void accumulate_2_d_vector(int8x16_t& src0, int8x8_t& kern0, - int8x16_t& src1, int8x8_t& kern1, - int32x4_t& sum0, int32x4_t& sum1) { +static inline void accumulate_2_d_vector( + int8x16_t& src0, int8x8_t& kern0, int8x16_t& src1, int8x8_t& kern1, + int32x4_t& sum0, int32x4_t& sum1) { int16x8_t tmp_sum0 = vmull_s8(vget_low_s8(src0), kern0); int16x8_t tmp_sum1 = vmull_s8(vget_high_s8(src0), kern0); tmp_sum0 = vmlal_s8(tmp_sum0, vget_low_s8(src1), kern1); @@ -56,20 +56,17 @@ static inline void accumulate_2_d_vector(int8x16_t& src0, int8x8_t& kern0, sum1 = vaddw_s16(sum1, vget_high_s16(tmp_sum1)); } -static inline void accumulate_1_line_horizon(const int8x8_t& src0, - const int8x8_t& kern0, - const int8x8_t& src1, - const int8x8_t& kern1, - int32x4_t& sum) { +static inline void accumulate_1_line_horizon( + const int8x8_t& src0, const int8x8_t& kern0, const int8x8_t& src1, + const int8x8_t& kern1, int32x4_t& sum) { int16x8_t tmp_sum = vmull_s8(src0, kern0); tmp_sum = vmlal_s8(tmp_sum, src1, kern1); sum = vaddw_s16(sum, vget_low_s16(tmp_sum)); sum = vaddw_s16(sum, vget_high_s16(tmp_sum)); } -static inline void accumulate_1_d_vector(const int8x8_t& src0, - const int8x8_t& kern0, - int32x4_t& sum) { +static inline void accumulate_1_d_vector( + const int8x8_t& src0, const int8x8_t& kern0, int32x4_t& sum) { int16x8_t tmp_sum = vmull_s8(src0, kern0); sum = vaddw_s16(sum, vget_low_s16(tmp_sum)); sum = vaddw_s16(sum, vget_high_s16(tmp_sum)); @@ -79,46 +76,42 @@ static 
inline void accumulate_1_d_vector(const int8x8_t& src0, sum = vaddw_s16(sum, vget_low_s16(tmp_sum)); \ sum = vaddw_s16(sum, vget_high_s16(tmp_sum)); -#define STORE_1_LINE(dst, oh, ow, OW, sum) \ - if (quantized) { \ - dt_qint8* dptr = \ - reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ - op({{sum[0], sum[1]}}, dptr); \ - op({{sum[2], sum[3]}}, dptr + 8); \ - } else { \ - dt_int32* dptr = \ - reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ - vst1q_s32(dptr, sum[0]); \ - vst1q_s32(dptr + 4, sum[1]); \ - vst1q_s32(dptr + 8, sum[2]); \ - vst1q_s32(dptr + 12, sum[3]); \ +#define STORE_1_LINE(dst, oh, ow, OW, sum) \ + if (quantized) { \ + dt_qint8* dptr = reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + op({{sum[0], sum[1]}}, dptr); \ + op({{sum[2], sum[3]}}, dptr + 8); \ + } else { \ + dt_int32* dptr = reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + vst1q_s32(dptr, sum[0]); \ + vst1q_s32(dptr + 4, sum[1]); \ + vst1q_s32(dptr + 8, sum[2]); \ + vst1q_s32(dptr + 12, sum[3]); \ } -#define STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum, remain) \ - if (quantized) { \ - dt_qint8* dptr = \ - reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ - if (remain == 1) { \ - op(sum[0], dptr); \ - } else if (remain == 2) { \ - op({{sum[0], sum[1]}}, dptr); \ - } else if (remain == 3) { \ - op({{sum[0], sum[1]}}, dptr); \ - op(sum[2], dptr + 8); \ - } \ - } else { \ - dt_int32* dptr = \ - reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ - if (remain == 1) { \ - vst1q_s32(dptr, sum[0]); \ - } else if (remain == 2) { \ - vst1q_s32(dptr, sum[0]); \ - vst1q_s32(dptr + 4, sum[1]); \ - } else if (remain == 3) { \ - vst1q_s32(dptr, sum[0]); \ - vst1q_s32(dptr + 4, sum[1]); \ - vst1q_s32(dptr + 8, sum[2]); \ - } \ +#define STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum, remain) \ + if (quantized) { \ + dt_qint8* dptr = reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + if (remain == 1) { \ + op(sum[0], dptr); \ + } else if (remain == 2) { \ + op({{sum[0], sum[1]}}, dptr); \ + } else if (remain == 3) { \ + op({{sum[0], sum[1]}}, dptr); \ + op(sum[2], dptr + 8); \ + } \ + } else { \ + dt_int32* dptr = reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + if (remain == 1) { \ + vst1q_s32(dptr, sum[0]); \ + } else if (remain == 2) { \ + vst1q_s32(dptr, sum[0]); \ + vst1q_s32(dptr + 4, sum[1]); \ + } else if (remain == 3) { \ + vst1q_s32(dptr, sum[0]); \ + vst1q_s32(dptr + 4, sum[1]); \ + vst1q_s32(dptr + 8, sum[2]); \ + } \ } template @@ -145,9 +138,8 @@ void channel_wise_nchw44::direct_stride1_2x2_int8( } else { sum00 = vdupq_n_s32(0); } - int32x4_t sum01 = sum00, sum02 = sum00, sum03 = sum00, - sum10 = sum00, sum11 = sum00, sum12 = sum00, - sum13 = sum00; + int32x4_t sum01 = sum00, sum02 = sum00, sum03 = sum00, sum10 = sum00, + sum11 = sum00, sum12 = sum00, sum13 = sum00; int8x16_t src0 = vld1q_s8(sptr0); int8x8_t src03 = vld1_s8(sptr0 + 3 * 4), src00 = vget_low_s8(src0), src02 = vget_high_s8(src0); @@ -552,22 +544,19 @@ void channel_wise_nchw44::direct_stride1_3x3_int8( "saddw2 %[sum21].4s, %[sum21].4s, v27.8h\n" "saddw %[sum22].4s, %[sum22].4s, v28.4h\n" "saddw2 %[sum23].4s, %[sum23].4s, v28.8h\n" - : [k0] "+w"(kern[0]), [k1] "+w"(kern[1]), - [k2] "+w"(kern[2]), [k3] "+w"(kern[3]), - [k4] "+w"(kern[4]), [k5] "+w"(kern[5]), - [k6] "+w"(kern[6]), [k7] "+w"(kern[7]), - [k8] "+w"(kern[8]), [sum00] "+w"(sum0[0]), - [sum01] "+w"(sum0[1]), [sum02] "+w"(sum0[2]), - [sum03] "+w"(sum0[3]), [sum10] "+w"(sum1[0]), - [sum11] "+w"(sum1[1]), [sum12] "+w"(sum1[2]), - [sum13] "+w"(sum1[3]), [sum20] "+w"(sum2[0]), - [sum21] "+w"(sum2[1]), [sum22] 
"+w"(sum2[2]), - [sum23] "+w"(sum2[3]), [sptr0] "+r"(sptr0), - [sptr1] "+r"(sptr1), [sptr2] "+r"(sptr2), - [sptr3] "+r"(sptr3), [sptr4] "+r"(sptr4) : - : "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); + [k0] "+w"(kern[0]), [k1] "+w"(kern[1]), [k2] "+w"(kern[2]), + [k3] "+w"(kern[3]), [k4] "+w"(kern[4]), [k5] "+w"(kern[5]), + [k6] "+w"(kern[6]), [k7] "+w"(kern[7]), [k8] "+w"(kern[8]), + [sum00] "+w"(sum0[0]), [sum01] "+w"(sum0[1]), [sum02] "+w"(sum0[2]), + [sum03] "+w"(sum0[3]), [sum10] "+w"(sum1[0]), [sum11] "+w"(sum1[1]), + [sum12] "+w"(sum1[2]), [sum13] "+w"(sum1[3]), [sum20] "+w"(sum2[0]), + [sum21] "+w"(sum2[1]), [sum22] "+w"(sum2[2]), [sum23] "+w"(sum2[3]), + [sptr0] "+r"(sptr0), [sptr1] "+r"(sptr1), [sptr2] "+r"(sptr2), + [sptr3] "+r"(sptr3), [sptr4] "+r"(sptr4) + : + : "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", "cc", "memory"); STORE_1_LINE(dst, (oh), ow, OW, sum0); STORE_1_LINE(dst, (oh + 1), ow, OW, sum1); @@ -991,10 +980,8 @@ void channel_wise_nchw44::direct_stride2_2x2_int8( int8x16_t src10 = vld1q_s8(sptr1); int8x16_t src11 = vld1q_s8(sptr1 + 16); - accumulate_2_d_vector(src00, kern01, src10, kern23, sum[0][0], - sum[0][1]); - accumulate_2_d_vector(src01, kern01, src11, kern23, sum[0][2], - sum[0][3]); + accumulate_2_d_vector(src00, kern01, src10, kern23, sum[0][0], sum[0][1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum[0][2], sum[0][3]); int8x16_t src20 = vld1q_s8(sptr2); int8x16_t src21 = vld1q_s8(sptr2 + 16); @@ -1002,10 +989,8 @@ void channel_wise_nchw44::direct_stride2_2x2_int8( int8x16_t src30 = vld1q_s8(sptr3); int8x16_t src31 = vld1q_s8(sptr3 + 16); - accumulate_2_d_vector(src20, kern01, src30, kern23, sum[1][0], - sum[1][1]); - accumulate_2_d_vector(src21, kern01, src31, kern23, sum[1][2], - sum[1][3]); + accumulate_2_d_vector(src20, kern01, src30, kern23, sum[1][0], sum[1][1]); + accumulate_2_d_vector(src21, kern01, src31, kern23, sum[1][2], sum[1][3]); STORE_1_LINE(dst, oh, ow, OW, sum[0]); STORE_1_LINE(dst, (oh + 1), ow, OW, sum[1]); @@ -1031,10 +1016,8 @@ void channel_wise_nchw44::direct_stride2_2x2_int8( int8x16_t src10 = vld1q_s8(sptr1); int8x16_t src11 = vld1q_s8(sptr1 + 16); - accumulate_2_d_vector(src00, kern01, src10, kern23, sum[0][0], - sum[0][1]); - accumulate_2_d_vector(src01, kern01, src11, kern23, sum[0][2], - sum[0][3]); + accumulate_2_d_vector(src00, kern01, src10, kern23, sum[0][0], sum[0][1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum[0][2], sum[0][3]); int8x16_t src20 = vld1q_s8(sptr2); int8x16_t src21 = vld1q_s8(sptr2 + 16); @@ -1042,10 +1025,8 @@ void channel_wise_nchw44::direct_stride2_2x2_int8( int8x16_t src30 = vld1q_s8(sptr3); int8x16_t src31 = vld1q_s8(sptr3 + 16); - accumulate_2_d_vector(src20, kern01, src30, kern23, sum[1][0], - sum[1][1]); - accumulate_2_d_vector(src21, kern01, src31, kern23, sum[1][2], - sum[1][3]); + accumulate_2_d_vector(src20, kern01, src30, kern23, sum[1][0], sum[1][1]); + accumulate_2_d_vector(src21, kern01, src31, kern23, sum[1][2], sum[1][3]); STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum[0], remain); STORE_1_LINE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); @@ -1069,10 +1050,8 @@ void channel_wise_nchw44::direct_stride2_2x2_int8( int8x16_t src10 = vld1q_s8(sptr1); int8x16_t src11 = vld1q_s8(sptr1 + 16); - accumulate_2_d_vector(src00, kern01, src10, kern23, sum0[0], - sum0[1]); - accumulate_2_d_vector(src01, kern01, src11, kern23, sum0[2], - sum0[3]); + accumulate_2_d_vector(src00, kern01, src10, 
kern23, sum0[0], sum0[1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum0[2], sum0[3]); STORE_1_LINE(dst, oh, ow, OW, sum0); } @@ -1092,10 +1071,8 @@ void channel_wise_nchw44::direct_stride2_2x2_int8( int8x16_t src10 = vld1q_s8(sptr1); int8x16_t src11 = vld1q_s8(sptr1 + 16); - accumulate_2_d_vector(src00, kern01, src10, kern23, sum0[0], - sum0[1]); - accumulate_2_d_vector(src01, kern01, src11, kern23, sum0[2], - sum0[3]); + accumulate_2_d_vector(src00, kern01, src10, kern23, sum0[0], sum0[1]); + accumulate_2_d_vector(src01, kern01, src11, kern23, sum0[2], sum0[3]); STORE_1_LINE_REMAIN(dst, oh, ow, OW, sum0, remain); } @@ -1126,15 +1103,14 @@ void channel_wise_nchw44::direct_stride2_3x3_int8( int8x8_t kern80 = vreinterpret_s8_s32( vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 28)), zero).val[1]); -#define COMPUTE_ONE_LINE(src00, src01, src02, kern01, kern20, sum) \ - accumulate_1_line_horizon(vget_low_s8(src00), kern01, vget_high_s8(src00), \ - kern20, sum[0]); \ - accumulate_1_line_horizon(vget_high_s8(src00), kern01, vget_low_s8(src01), \ - kern20, sum[1]); \ - accumulate_1_line_horizon(vget_low_s8(src01), kern01, vget_high_s8(src01), \ - kern20, sum[2]); \ - accumulate_1_line_horizon(vget_high_s8(src01), kern01, src02, kern20, \ - sum[3]); +#define COMPUTE_ONE_LINE(src00, src01, src02, kern01, kern20, sum) \ + accumulate_1_line_horizon( \ + vget_low_s8(src00), kern01, vget_high_s8(src00), kern20, sum[0]); \ + accumulate_1_line_horizon( \ + vget_high_s8(src00), kern01, vget_low_s8(src01), kern20, sum[1]); \ + accumulate_1_line_horizon( \ + vget_low_s8(src01), kern01, vget_high_s8(src01), kern20, sum[2]); \ + accumulate_1_line_horizon(vget_high_s8(src01), kern01, src02, kern20, sum[3]); size_t oh = 0_z; for (; oh + 2 <= OH; oh += 2) { @@ -1341,42 +1317,42 @@ void channel_wise_nchw44::direct_stride2_5x5_int8( kern4[2] = vreinterpret_s8_s32( vzip_s32(vreinterpret_s32_s8(vld1_s8(filter + 92)), zero).val[1]); -#define COMPUTE_ONE_VECTOR(src00, src01, src02, src10, src11, src12, kern0, \ - kern1, sum) \ - accumulate_1_line_horizon(src00, kern0[0], src10, kern1[0], sum); \ - accumulate_1_line_horizon(src01, kern0[1], src11, kern1[1], sum); \ +#define COMPUTE_ONE_VECTOR( \ + src00, src01, src02, src10, src11, src12, kern0, kern1, sum) \ + accumulate_1_line_horizon(src00, kern0[0], src10, kern1[0], sum); \ + accumulate_1_line_horizon(src01, kern0[1], src11, kern1[1], sum); \ accumulate_1_line_horizon(src02, kern0[2], src12, kern1[2], sum); -#define COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum) \ - COMPUTE_ONE_VECTOR(vget_low_s8(src0[0]), vget_high_s8(src0[0]), \ - vget_low_s8(src0[1]), vget_low_s8(src1[0]), \ - vget_high_s8(src1[0]), vget_low_s8(src1[1]), kern0, \ - kern1, sum[0]) \ - COMPUTE_ONE_VECTOR(vget_high_s8(src0[0]), vget_low_s8(src0[1]), \ - vget_high_s8(src0[1]), vget_high_s8(src1[0]), \ - vget_low_s8(src1[1]), vget_high_s8(src1[1]), kern0, \ - kern1, sum[1]) \ - COMPUTE_ONE_VECTOR(vget_low_s8(src0[1]), vget_high_s8(src0[1]), \ - vget_low_s8(src0[2]), vget_low_s8(src1[1]), \ - vget_high_s8(src1[1]), vget_low_s8(src1[2]), kern0, \ - kern1, sum[2]) \ - COMPUTE_ONE_VECTOR(vget_high_s8(src0[1]), vget_low_s8(src0[2]), \ - vget_high_s8(src0[2]), vget_high_s8(src1[1]), \ - vget_low_s8(src1[2]), vget_high_s8(src1[2]), kern0, \ - kern1, sum[3]) - -#define COMPUTE_ONE_LINE(src, kern, sum) \ - accumulate_1_line_horizon(vget_low_s8(src[0]), kern[0], \ - vget_high_s8(src[0]), kern[1], sum[0]); \ - accumulate_1_line_horizon(vget_high_s8(src[0]), kern[0], \ - vget_low_s8(src[1]), 
kern[1], sum[1]); \ - accumulate_1_line_horizon(vget_low_s8(src[1]), kern[0], \ - vget_high_s8(src[1]), kern[1], sum[2]); \ - accumulate_1_line_horizon(vget_high_s8(src[1]), kern[0], \ - vget_low_s8(src[2]), kern[1], sum[3]); \ - accumulate_1_d_vector(vget_low_s8(src[1]), kern[2], sum[0]); \ - accumulate_1_d_vector(vget_high_s8(src[1]), kern[2], sum[1]); \ - accumulate_1_d_vector(vget_low_s8(src[2]), kern[2], sum[2]); \ +#define COMPUTE_TWO_LINE(src0, src1, kern0, kern1, sum) \ + COMPUTE_ONE_VECTOR( \ + vget_low_s8(src0[0]), vget_high_s8(src0[0]), vget_low_s8(src0[1]), \ + vget_low_s8(src1[0]), vget_high_s8(src1[0]), vget_low_s8(src1[1]), kern0, \ + kern1, sum[0]) \ + COMPUTE_ONE_VECTOR( \ + vget_high_s8(src0[0]), vget_low_s8(src0[1]), vget_high_s8(src0[1]), \ + vget_high_s8(src1[0]), vget_low_s8(src1[1]), vget_high_s8(src1[1]), kern0, \ + kern1, sum[1]) \ + COMPUTE_ONE_VECTOR( \ + vget_low_s8(src0[1]), vget_high_s8(src0[1]), vget_low_s8(src0[2]), \ + vget_low_s8(src1[1]), vget_high_s8(src1[1]), vget_low_s8(src1[2]), kern0, \ + kern1, sum[2]) \ + COMPUTE_ONE_VECTOR( \ + vget_high_s8(src0[1]), vget_low_s8(src0[2]), vget_high_s8(src0[2]), \ + vget_high_s8(src1[1]), vget_low_s8(src1[2]), vget_high_s8(src1[2]), kern0, \ + kern1, sum[3]) + +#define COMPUTE_ONE_LINE(src, kern, sum) \ + accumulate_1_line_horizon( \ + vget_low_s8(src[0]), kern[0], vget_high_s8(src[0]), kern[1], sum[0]); \ + accumulate_1_line_horizon( \ + vget_high_s8(src[0]), kern[0], vget_low_s8(src[1]), kern[1], sum[1]); \ + accumulate_1_line_horizon( \ + vget_low_s8(src[1]), kern[0], vget_high_s8(src[1]), kern[1], sum[2]); \ + accumulate_1_line_horizon( \ + vget_high_s8(src[1]), kern[0], vget_low_s8(src[2]), kern[1], sum[3]); \ + accumulate_1_d_vector(vget_low_s8(src[1]), kern[2], sum[0]); \ + accumulate_1_d_vector(vget_high_s8(src[1]), kern[2], sum[1]); \ + accumulate_1_d_vector(vget_low_s8(src[2]), kern[2], sum[2]); \ accumulate_1_d_vector(vget_high_s8(src[2]), kern[2], sum[3]) size_t oh = 0_z; @@ -1606,19 +1582,15 @@ void channel_wise_nchw44::direct_stride2_5x5_int8( #define INSTANTIATION(quantized, stride, i, bias, Op) \ template void channel_wise_nchw44::direct_##stride##_##i##x##i##_int8< \ - quantized, bias, Op>(const int8_t*, const int8_t*, const int32_t*, \ - void*, const size_t, const size_t, \ - const size_t, const size_t, const Op&); - -#define FOR_OP(stride, i, bias) \ - INSTANTIATION(true, stride, i, bias, \ - TypeCvtOp) \ - INSTANTIATION(true, stride, i, bias, \ - ReluOp) \ - INSTANTIATION(true, stride, i, bias, \ - HSwishOp) \ - INSTANTIATION(false, stride, i, bias, \ - NoneOp) + quantized, bias, Op>( \ + const int8_t*, const int8_t*, const int32_t*, void*, const size_t, \ + const size_t, const size_t, const size_t, const Op&); + +#define FOR_OP(stride, i, bias) \ + INSTANTIATION(true, stride, i, bias, TypeCvtOp) \ + INSTANTIATION(true, stride, i, bias, ReluOp) \ + INSTANTIATION(true, stride, i, bias, HSwishOp) \ + INSTANTIATION(false, stride, i, bias, NoneOp) #define FOR_BIAS(stride, i) \ FOR_OP(stride, i, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h index 9c91e82e..5e009ebd 100644 --- a/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_kernel.h @@ -16,12 +16,12 @@ namespace megdnn { namespace arm_common { namespace channel_wise_nchw44 { -#define KERN(stride, i) \ - template \ - void direct_##stride##_##i##x##i##_int8( \ - const int8_t* src, const 
int8_t* filter, const int32_t* bias, \ - void* dst, const size_t IH, const size_t IW, const size_t OH, \ - const size_t OW, const Op& op); +#define KERN(stride, i) \ + template \ + void direct_##stride##_##i##x##i##_int8( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, void* dst, \ + const size_t IH, const size_t IW, const size_t OH, const size_t OW, \ + const Op& op); KERN(stride1, 2) KERN(stride1, 3) diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp index a3720fc7..85006e64 100644 --- a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.cpp @@ -23,8 +23,8 @@ using namespace channel_wise_nchw44; namespace { void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2) { auto&& fm = param.filter_meta; auto SW = fm.stride[1]; auto OH = param.osz[0]; @@ -54,16 +54,15 @@ bool stride1::is_available(const NCBKernSizeParam& param) { (param.src_type.enumv() == DTypeEnum::Int8 && param.filter_type.enumv() == DTypeEnum::Int8 && param.dst_type.enumv() == DTypeEnum::Int32)) && - fm.format == param::Convolution::Format::NCHW44 && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) && - fm.icpg == 1 && fm.ocpg == 1 && fm.group % 4 == 0; + fm.format == param::Convolution::Format::NCHW44 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5) && fm.icpg == 1 && fm.ocpg == 1 && + fm.group % 4 == 0; return avaible; } -WorkspaceBundle stride1::get_bundle( - const ConvBiasImpl::NCBKernSizeParam& param) { +WorkspaceBundle stride1::get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { size_t nr_threads = param.nr_threads; size_t IH2, IW2; get_rectified_size(param, IH2, IW2); @@ -76,9 +75,9 @@ WorkspaceBundle stride1::get_bundle( //! compute one output channel template -void stride1::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { +void stride1::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { size_t PH = kern_param.filter_meta.padding[0]; size_t PW = kern_param.filter_meta.padding[1]; size_t OH = kern_param.osz[0]; @@ -89,8 +88,7 @@ void stride1::do_conv_kern(const WorkspaceBundle& bundle, get_rectified_size(kern_param, IH2, IW2); Op op = Op(1.0f, 1.0f); if (quantized) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } @@ -110,9 +108,9 @@ void stride1::do_conv_kern(const WorkspaceBundle& bundle, //! 
copy in case of illegal read src when padding is zero std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size); rep(ih, IH) { - std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, - sptr + ih * IW * pack_ic_size, - sizeof(int8_t) * IW * pack_ic_size); + std::memcpy( + padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, + sptr + ih * IW * pack_ic_size, sizeof(int8_t) * IW * pack_ic_size); } sptr = padding_src; @@ -123,56 +121,57 @@ void stride1::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN } -SmallVector stride1::get_kimpls( - const NCBKernSizeParam& param) { +SmallVector stride1::get_kimpls(const NCBKernSizeParam& param) { auto fm = param.filter_meta; size_t N = param.n; size_t group = fm.group / 4; - megdnn_assert(fm.group % 4 == 0, - "nchw44 channel wise conv with group is not times of 4"); + megdnn_assert( + fm.group % 4 == 0, "nchw44 channel wise conv with group is not times of 4"); WorkspaceBundle wbundle = get_bundle(param); bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8; conv_fun do_conv_fun = nullptr; -#define DO_CONV_KERN_FUN(quantized, filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride1, \ - midout_iv(#quantized #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(quantized, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8_nchw44_stride1, \ + midout_iv(#quantized #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - if (quantized) { \ - DO_CONV_KERN_FUN(true, i, bias_mode, \ - TypeCvtOp) \ - } else { \ - DO_CONV_KERN_FUN(false, i, bias_mode, \ - NoneOp) \ - } \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - if (quantized) { \ - DO_CONV_KERN_FUN(true, i, bias_mode, \ - ReluOp) \ - } else { \ - DO_CONV_KERN_FUN(false, i, bias_mode, \ - NoneOp) \ - } \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - if (quantized) { \ - DO_CONV_KERN_FUN(true, i, bias_mode, \ - HSwishOp) \ - } else { \ - DO_CONV_KERN_FUN(false, i, bias_mode, \ - NoneOp) \ - } \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + if (quantized) { \ + DO_CONV_KERN_FUN( \ + true, i, bias_mode, \ + TypeCvtOp) \ + } else { \ + DO_CONV_KERN_FUN( \ + false, i, bias_mode, NoneOp) \ + } \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + if (quantized) { \ + DO_CONV_KERN_FUN( \ + true, i, bias_mode, ReluOp) \ + } else { \ + DO_CONV_KERN_FUN( \ + false, i, bias_mode, NoneOp) \ + } \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + if (quantized) { \ + DO_CONV_KERN_FUN( \ + true, i, bias_mode, HSwishOp) \ + } else { \ + DO_CONV_KERN_FUN( \ + false, i, bias_mode, NoneOp) \ + } \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -231,16 +230,15 @@ bool stride2::is_available(const NCBKernSizeParam& param) { (param.src_type.enumv() == DTypeEnum::Int8 && param.filter_type.enumv() == DTypeEnum::Int8 && param.dst_type.enumv() == DTypeEnum::Int32)) && - fm.format == param::Convolution::Format::NCHW44 && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) && - fm.icpg 
== 1 && fm.ocpg == 1 && fm.group % 4 == 0; + fm.format == param::Convolution::Format::NCHW44 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5) && fm.icpg == 1 && fm.ocpg == 1 && + fm.group % 4 == 0; return avaible; } -WorkspaceBundle stride2::get_bundle( - const ConvBiasImpl::NCBKernSizeParam& param) { +WorkspaceBundle stride2::get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { size_t nr_threads = param.nr_threads; size_t IH2, IW2; get_rectified_size(param, IH2, IW2); @@ -253,9 +251,9 @@ WorkspaceBundle stride2::get_bundle( //! compute one output channel template -void stride2::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { +void stride2::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { size_t PH = kern_param.filter_meta.padding[0]; size_t PW = kern_param.filter_meta.padding[1]; size_t OH = kern_param.osz[0]; @@ -266,8 +264,7 @@ void stride2::do_conv_kern(const WorkspaceBundle& bundle, get_rectified_size(kern_param, IH2, IW2); Op op = Op(1.0f, 1.0f); if (quantized) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } @@ -287,9 +284,9 @@ void stride2::do_conv_kern(const WorkspaceBundle& bundle, //! copy in case of illegal read src when padding is zero std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size); rep(ih, IH) { - std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, - sptr + ih * IW * pack_ic_size, - sizeof(int8_t) * IW * pack_ic_size); + std::memcpy( + padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, + sptr + ih * IW * pack_ic_size, sizeof(int8_t) * IW * pack_ic_size); } sptr = padding_src; @@ -300,22 +297,22 @@ void stride2::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN } -SmallVector stride2::get_kimpls( - const NCBKernSizeParam& param) { +SmallVector stride2::get_kimpls(const NCBKernSizeParam& param) { auto fm = param.filter_meta; size_t N = param.n; size_t group = fm.group / 4; - megdnn_assert(fm.group % 4 == 0, - "nchw44 channel wise conv with group is not times of 4"); + megdnn_assert( + fm.group % 4 == 0, "nchw44 channel wise conv with group is not times of 4"); WorkspaceBundle wbundle = get_bundle(param); bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8; conv_fun do_conv_fun = nullptr; -#define DO_CONV_KERN_FUN(quantized, filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_stride2, \ - midout_iv(#quantized #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(quantized, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8_nchw44_stride2, \ + midout_iv(#quantized #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); DISPATCH_CONV_KERN(); diff --git a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h index a3610dc3..73b62b91 100644 --- a/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h +++ b/dnn/src/arm_common/conv_bias/int8/channel_wise_nchw44.h @@ -21,9 +21,9 @@ using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; 
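// ---------------------------------------------------------------------------
// A minimal standalone sketch of the zero-padded source copy used by the
// stride1/stride2 do_conv_kern functions above (illustrative only: the helper
// name, the std::vector buffer and pack_ic_size being 4 are assumptions, not
// part of the patch). Each NCHW44 row of IW pixels, 4 packed channels per
// pixel, is memcpy'd into a zero-initialised IH2 x IW2 buffer at offset
// (PH, PW), which is exactly the ((ih + PH) * IW2 + PW) * pack_ic_size
// indexing shown in the diff.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

static std::vector<int8_t> pad_nchw44_block(
        const int8_t* src, size_t IH, size_t IW, size_t IH2, size_t IW2, size_t PH,
        size_t PW) {
    constexpr size_t pack_ic_size = 4;  // NCHW44: 4 channels packed per pixel
    std::vector<int8_t> padded(IH2 * IW2 * pack_ic_size, 0);
    for (size_t ih = 0; ih < IH; ++ih) {
        std::memcpy(
                padded.data() + ((ih + PH) * IW2 + PW) * pack_ic_size,
                src + ih * IW * pack_ic_size, sizeof(int8_t) * IW * pack_ic_size);
    }
    return padded;
}
// ---------------------------------------------------------------------------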
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; -using conv_fun = std::function; +using conv_fun = std::function; namespace stride1 { @@ -32,8 +32,9 @@ bool is_available(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); SmallVector get_kimpls(const NCBKernSizeParam& param); } // namespace stride1 @@ -44,13 +45,14 @@ bool is_available(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); SmallVector get_kimpls(const NCBKernSizeParam& param); } // namespace stride2 -} // namespace direct_int8_stride1 +} // namespace channel_wise_nchw44 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/direct.cpp b/dnn/src/arm_common/conv_bias/int8/direct.cpp index cee5429b..f62610fa 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct.cpp @@ -31,12 +31,10 @@ using namespace arm_common; } template -void conv_bias::conv_direct_stride1_2x2_int8_nchw(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride1_2x2_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); int8x8_t k00 = vdup_n_s8(filter[0]); int8x8_t k01 = vdup_n_s8(filter[1]); @@ -315,12 +313,10 @@ void conv_bias::conv_direct_stride1_2x2_int8_nchw(const int8_t* src, } template -void conv_bias::conv_direct_stride1_3x3_int8_nchw(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride1_3x3_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); int8x8_t k00 = vdup_n_s8(filter[0]); int8x8_t k01 = vdup_n_s8(filter[1]); @@ -474,12 +470,10 @@ void conv_bias::conv_direct_stride1_3x3_int8_nchw(const int8_t* src, } template -void conv_bias::conv_direct_stride1_5x5_int8_nchw(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride1_5x5_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); int8x8_t k00 = vdup_n_s8(filter[0]); int8x8_t k01 = vdup_n_s8(filter[1]); @@ -761,12 +755,10 @@ void conv_bias::conv_direct_stride1_5x5_int8_nchw(const int8_t* src, } template -void conv_bias::conv_direct_stride1_7x7_int8_nchw(const int8_t* src, - const 
int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride1_7x7_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); int8x8_t k00 = vdup_n_s8(filter[0]); int8x8_t k01 = vdup_n_s8(filter[1]); @@ -1244,12 +1236,10 @@ void conv_bias::conv_direct_stride1_7x7_int8_nchw(const int8_t* src, } template -void conv_bias::conv_direct_stride2_2x2_int8_nchw(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride2_2x2_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); #define GET_R2(sptr) \ _r00 = vld1_s8(sptr); \ @@ -1310,12 +1300,10 @@ void conv_bias::conv_direct_stride2_2x2_int8_nchw(const int8_t* src, } template -void conv_bias::conv_direct_stride2_3x3_int8_nchw(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride2_3x3_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); #define GET_R3(sptr) \ _r00 = vld1_s8(sptr); \ @@ -1467,12 +1455,10 @@ void conv_bias::conv_direct_stride2_3x3_int8_nchw(const int8_t* src, } template -void conv_bias::conv_direct_stride2_5x5_int8_nchw(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride2_5x5_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); #define GET_R5(sptr) \ _r00 = vld1_s8(sptr); \ @@ -1725,12 +1711,10 @@ void conv_bias::conv_direct_stride2_5x5_int8_nchw(const int8_t* src, } template -void conv_bias::conv_direct_stride2_7x7_int8_nchw(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t IH, - const size_t IW, const size_t OH, - const size_t OW, const Op& op) { +void conv_bias::conv_direct_stride2_7x7_int8_nchw( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); #define GET_R7(sptr) \ _r00 = vld1_s8(sptr); \ @@ -2138,16 +2122,18 @@ void conv_bias::conv_direct_stride2_7x7_int8_nchw(const int8_t* src, template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_nchw< \ first_ic, last_ic, bias, Op>( \ const int8_t*, const int8_t*, const int32_t*, int32_t*, int8_t*, \ - const size_t, const size_t, const size_t, const size_t, \ - const Op&); - -#define FOR_OP(stride, i, first_ic, last_ic, bias) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - TypeCvtOp) \ - INSTANTIATION(stride, 
i, first_ic, last_ic, bias, \ - ReluOp) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - HSwishOp) + const size_t, const size_t, const size_t, const size_t, const Op&); + +#define FOR_OP(stride, i, first_ic, last_ic, bias) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + TypeCvtOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + ReluOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + HSwishOp) #define FOR_BIAS(stride, i, first_ic, last_ic) \ FOR_OP(stride, i, first_ic, last_ic, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp index d467272f..d9b7421b 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp @@ -92,15 +92,13 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) { template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_2x2_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; - const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, - 2, 3, 16, 16, 3, 4, 16, 16}; - const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, - 6, 7, 16, 16, 7, 8, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, 6, 7, 16, 16, 7, 8, 16, 16}; int32_t* outptr = temp; int32_t* outptr2 = outptr + OW; int8_t* dstptr = dst; @@ -113,8 +111,8 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot( const int8_t* k0 = filter; - int8x16_t _k = vreinterpretq_s8_s32( - vdupq_n_s32(*reinterpret_cast(k0))); + int8x16_t _k = + vreinterpretq_s8_s32(vdupq_n_s32(*reinterpret_cast(k0))); uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; int8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; @@ -241,13 +239,13 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot( int8x16_t _r01 = vextq_s8(_r00, _r01_, 1); int8x16_t _r11 = vextq_s8(_r10, _r11_, 1); - int16x8x2_t r_20 = vzipq_s16(vreinterpretq_s16_s8(_r00), - vreinterpretq_s16_s8(_r10)); + int16x8x2_t r_20 = + vzipq_s16(vreinterpretq_s16_s8(_r00), vreinterpretq_s16_s8(_r10)); int8x16_t _r0 = r_20.val[0]; int8x16_t _r2 = r_20.val[1]; - int16x8x2_t r1_21 = vzipq_s16(vreinterpretq_s16_s8(_r01), - vreinterpretq_s16_s8(_r11)); + int16x8x2_t r1_21 = + vzipq_s16(vreinterpretq_s16_s8(_r01), vreinterpretq_s16_s8(_r11)); int8x16_t _r1 = r1_21.val[0]; int8x16_t _r3 = r1_21.val[1]; @@ -328,16 +326,14 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot( template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_3x3_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; - const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 
3, 16, - 2, 3, 4, 16, 3, 4, 5, 16}; - const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; int32_t* outptr = temp; @@ -564,14 +560,13 @@ void conv_bias::conv_direct_stride1_3x3_int8_dot( template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_2x2_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; - const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, - 4, 5, 16, 16, 6, 7, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, 4, 5, 16, 16, 6, 7, 16, 16}; int32_t* outptr = temp; int8_t* dstptr = dst; @@ -580,8 +575,8 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot( const int8_t* k0 = filter; const int32_t* __restrict bptr = bias; - int8x16_t _k = vreinterpretq_s8_s32( - vdupq_n_s32(*reinterpret_cast(k0))); + int8x16_t _k = + vreinterpretq_s8_s32(vdupq_n_s32(*reinterpret_cast(k0))); uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; int8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; @@ -611,8 +606,8 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot( //! here will not not read out of bound int8x16_t _r10 = vld1q_s8(r1); - int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), - vreinterpretq_s16_s8(_r10)); + int16x8x2_t r_00 = + vzipq_s16(vreinterpretq_s16_s8(_r00), vreinterpretq_s16_s8(_r10)); int8x16_t _r0 = r_00.val[0]; int8x16_t _r1 = r_00.val[1]; @@ -660,14 +655,13 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot( template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_3x3_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; - const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, - 4, 5, 6, 16, 6, 7, 8, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, 4, 5, 6, 16, 6, 7, 8, 16}; int32_t* outptr = temp; int32_t* outptr2 = outptr + OW; int8_t* dstptr = dst; @@ -816,9 +810,9 @@ void conv_bias::conv_direct_stride2_3x3_int8_dot( template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_5x5_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; @@ -1115,15 +1109,14 @@ void 
conv_bias::conv_direct_stride2_5x5_int8_dot( template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_7x7_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 6, 7, 8, 16, - 8, 9, 10, 16, 10, 11, 12, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 6, 7, 8, 16, 8, 9, 10, 16, 10, 11, 12, 16}; //! start from 8 const uint8x16_t& _idx10 = _idx00; const uint8x16_t& _idx11 = _idx01; @@ -1478,9 +1471,9 @@ void conv_bias::conv_direct_stride2_7x7_int8_dot( template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_5x5_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; @@ -1779,15 +1772,14 @@ void conv_bias::conv_direct_stride1_5x5_int8_dot( template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_7x7_int8_dot( - const int8_t* src, const int8_t* filter, const int32_t* bias, - int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const Op& op) { + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t IH, const size_t IW, const size_t OH, const size_t OW, + const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; @@ -2123,21 +2115,22 @@ void conv_bias::conv_direct_stride1_7x7_int8_dot( #undef ST1_S32X4 #undef ST2_S32X4X2 - #define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_dot< \ first_ic, last_ic, bias, Op>( \ const int8_t*, const int8_t*, const int32_t*, int32_t*, int8_t*, \ - const size_t, const size_t, const size_t, const size_t, \ - const Op&); - -#define FOR_OP(stride, i, first_ic, last_ic, bias) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - TypeCvtOp) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - ReluOp) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - HSwishOp) + const size_t, const size_t, const size_t, const size_t, const Op&); + +#define FOR_OP(stride, i, first_ic, last_ic, bias) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + TypeCvtOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + ReluOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + HSwishOp) #define FOR_BIAS(stride, i, first_ic, last_ic) 
\ FOR_OP(stride, i, first_ic, last_ic, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h index a19bae6d..accc78cd 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h @@ -35,7 +35,7 @@ KERN(stride2, 7) #undef KERN -} // namesapce conv_bias +} // namespace conv_bias } // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp index 6c20d4a6..efa2b57f 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp @@ -23,12 +23,10 @@ namespace arm_common { namespace direct_dotprod_nchw44 { template <> -void copy_packed_src_int8_nchw44<1>(int8_t* dst, const int dst_step, - const int8_t* src, const int src_step, - const int ic, const int ic_step, - const int ih, const int pad_left, - const int pad_right, const int pad_top, - const int pad_bottom) { +void copy_packed_src_int8_nchw44<1>( + int8_t* dst, const int dst_step, const int8_t* src, const int src_step, + const int ic, const int ic_step, const int ih, const int pad_left, + const int pad_right, const int pad_top, const int pad_bottom) { MEGDNN_MARK_USED_VAR(pad_right); constexpr int IC_PACK_SIZE = 4; rep_step(ic_idx, ic, IC_PACK_SIZE) { @@ -53,27 +51,23 @@ void copy_packed_src_int8_nchw44<1>(int8_t* dst, const int dst_step, i_src += bytes_copy / sizeof(int8_t); } //! pad bottom - int bytes_pad_bottom = - pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); + int bytes_pad_bottom = pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); memset(dst, 0, bytes_pad_bottom); dst += bytes_pad_bottom / sizeof(int8_t); } } template <> -void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, - const int8_t* src, const int src_step, - const int ic, const int ic_step, - const int ih, const int pad_left, - const int pad_right, const int pad_top, - const int pad_bottom) { +void copy_packed_src_int8_nchw44<2>( + int8_t* dst, const int dst_step, const int8_t* src, const int src_step, + const int ic, const int ic_step, const int ih, const int pad_left, + const int pad_right, const int pad_top, const int pad_bottom) { MEGDNN_MARK_USED_VAR(pad_right); constexpr int IC_PACK_SIZE = 4; int odd_start = megdnn::div_ceil(dst_step, 2); bool nochange = pad_left % 2 == 0; rep_step(ic_idx, ic, IC_PACK_SIZE) { - const int32_t* i_src = - reinterpret_cast(src + ic_idx * ic_step); + const int32_t* i_src = reinterpret_cast(src + ic_idx * ic_step); int bytes_pad_top = pad_top * dst_step * IC_PACK_SIZE * sizeof(int8_t); memset(dst, 0, bytes_pad_top); dst += bytes_pad_top / sizeof(int8_t); @@ -81,8 +75,8 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, int bytes_row_in_dst = dst_step * IC_PACK_SIZE * sizeof(int8_t); memset(dst, 0, bytes_row_in_dst); - int32_t* dst_even = reinterpret_cast(dst) + pad_left / 2 + - pad_left % 2; + int32_t* dst_even = + reinterpret_cast(dst) + pad_left / 2 + pad_left % 2; int32_t* dst_odd = reinterpret_cast(dst) + odd_start + pad_left / 2; int i_src_idx = 0; @@ -132,8 +126,7 @@ void copy_packed_src_int8_nchw44<2>(int8_t* dst, const int dst_step, i_src += src_step; } //! 
pad bottom - int bytes_pad_bottom = - pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); + int bytes_pad_bottom = pad_bottom * dst_step * IC_PACK_SIZE * sizeof(int8_t); memset(dst, 0, bytes_pad_bottom); dst += bytes_pad_bottom / sizeof(int8_t); } diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h index e1cde5f4..1b9789c4 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h @@ -42,13 +42,12 @@ using BiasMode = ConvBiasForward::BiasMode; * @return none */ -template -void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, - const int8_t* src, const int ih, const int iw, - const int8_t* filter, const int32_t* bias, - const int oh_size, const int oc, const int ic, - const Op& op); +template < + typename dst_type, int stride, BiasMode bias_mode, typename Op, int filter_size> +void conv_direct_sdot_int8_nchw44( + dst_type* dst, const int oh, const int ow, const int8_t* src, const int ih, + const int iw, const int8_t* filter, const int32_t* bias, const int oh_size, + const int oc, const int ic, const Op& op); /** * @brief : copy data from src to dst for direct conv with no side effect * @param : [output ptr] dst @@ -65,11 +64,10 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, * @return none */ template -void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step, - const int8_t* src, const int src_step, - const int ic, const int ic_step, const int ih, - const int pad_left, const int pad_right, - const int pad_top, const int pad_bottom); +void copy_packed_src_int8_nchw44( + int8_t* dst, const int dst_step, const int8_t* src, const int src_step, + const int ic, const int ic_step, const int ih, const int pad_left, + const int pad_right, const int pad_top, const int pad_bottom); } // namespace direct_dotprod_nchw44 } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp index 3909ed9f..c124ee36 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp @@ -23,16 +23,15 @@ using namespace arm_common; MIDOUT_DECL(megdnn_arm_common_conv_bias_int8) -using direct_fun = - std::function; +using direct_fun = std::function; namespace { static void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih, - int& iw, int& oh, int& ow) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, int& ih, int& iw, + int& oh, int& ow) { int IC = param.filter_meta.icpg; int IW = param.isz[1]; int OH = param.osz[0]; @@ -52,8 +51,7 @@ static void get_rectified_size( iw = round_up(IW + 2 * PW, cacheline); } -static inline int get_perthread_cache_bytes(const int ic, const int ih, - const int iw) { +static inline int get_perthread_cache_bytes(const int ic, const int ih, const int iw) { // border_size is used to avoid read illegal memory int border_size = 64 * 2; return ic * ih * iw * sizeof(int8_t) + border_size; @@ -69,11 +67,12 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { return {nullptr, {bytes_of_copy_per_thread * param.nr_threads}}; } -template -static void conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& ncb_param, - const ConvBiasImpl::NCBKernIndex& ncb_index) { +template < + typename 
dst_type, size_t filter_size, BiasMode bias_mode, typename Op, + int stride> +static void conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& ncb_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { const int OH = ncb_param.osz[0]; const int OW = ncb_param.osz[1]; const int FH = ncb_param.filter_meta.spatial[0]; @@ -99,8 +98,8 @@ static void conv_kern(const WorkspaceBundle& bundle, const int oh_tile_id = ncb_index.ndrange_id[2]; const int thread_id = ncb_index.thread_id; - const int oh_tile_size = l2_block_helper(ncb_param.nr_threads, OH, - IC * IW * sizeof(int8_t) * 2); + const int oh_tile_size = + l2_block_helper(ncb_param.nr_threads, OH, IC * IW * sizeof(int8_t) * 2); const int oh_start_row = oh_tile_id * oh_tile_size; const int ih_start_row = std::max(oh_start_row * SH - PH, 0); @@ -114,13 +113,11 @@ static void conv_kern(const WorkspaceBundle& bundle, const int cols_padding_at_right = std::max(iw2 - IW - PW, 0); //! src layout{IC/4, IH, IW, 4} - const int bytes_of_src_offset = - ih_start_row * IW * IC_PACK_SIZE * sizeof(int8_t); + const int bytes_of_src_offset = ih_start_row * IW * IC_PACK_SIZE * sizeof(int8_t); const int8_t* copy_src = static_cast( ncb_param.src(batch_id, group_id) + bytes_of_src_offset); - const int bytes_of_copy_per_thread = - get_perthread_cache_bytes(IC, ih2, iw2); + const int bytes_of_copy_per_thread = get_perthread_cache_bytes(IC, ih2, iw2); int8_t* copy_dst = reinterpret_cast(bundle.get(0)) + thread_id * bytes_of_copy_per_thread; @@ -142,15 +139,14 @@ static void conv_kern(const WorkspaceBundle& bundle, Op op = Op(1.0f, 4.0f); if (ncb_param.dst_type.enumv() == DTypeEnum::QuantizedS8) { - float scale_bias = - ncb_param.bias_type.param().scale; + float scale_bias = ncb_param.bias_type.param().scale; float scale_dst = ncb_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } direct_dotprod_nchw44::conv_direct_sdot_int8_nchw44< dst_type, stride, bias_mode, Op, filter_size>( - dst, OH, OW, copy_dst, ih_real_size, iw2, weights, bias, - oh_real_size, OC, IC, op); + dst, OH, OW, copy_dst, ih_real_size, iw2, weights, bias, oh_real_size, OC, + IC, op); } } // namespace @@ -158,7 +154,7 @@ static void conv_kern(const WorkspaceBundle& bundle, bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } MEGDNN_MARK_USED_VAR(algo_selection_strategy); @@ -187,9 +183,8 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( bool layout_ok = fm.format == param::Convolution::Format::NCHW44_DOT && IC % 4 == 0 && OC % 4 == 0; - bool param_ok = !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && FH == FW && - (FH >= 2 && FH <= 7); + bool param_ok = !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && FH == FW && (FH >= 2 && FH <= 7); bool stride_ok = SH == SW && (SH == 1 || SH == 2); @@ -204,19 +199,20 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred( size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("ALGODOTS8DIRECT_NCHW44::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("ALGODOTS8DIRECT_NCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector 
-ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, - midout_iv("ALGODOTS8DIRECT_NCHW44::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8, + midout_iv("ALGODOTS8DIRECT_NCHW44::dispatch_kerns"_hash)) { auto fm = param.filter_meta; size_t BATCH = param.n; size_t GROUP = fm.group; @@ -224,47 +220,48 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( direct_fun kernel = nullptr; bool quantized = param.dst_type.enumv() == DTypeEnum::QuantizedS8; -#define DO_CONV_KERN_FUN(dst_type, filter, bias_mode, op, stride) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, \ - midout_iv(#dst_type #filter #bias_mode #op##_hash)) { \ - kernel = conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(dst_type, filter, bias_mode, op, stride) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8, \ + midout_iv(#dst_type #filter #bias_mode #op##_hash)) { \ + kernel = conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(i, bias_mode, stride) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - if (quantized) { \ - DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ - TypeCvtOp, \ - stride) \ - } else { \ - DO_CONV_KERN_FUN(dt_int32, i, bias_mode, \ - NoneOp, \ - stride) \ - } \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - if (quantized) { \ - DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ - ReluOp, \ - stride) \ - } else { \ - megdnn_assert("No support NoQuantized RELU"); \ - } \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - if (quantized) { \ - DO_CONV_KERN_FUN(dt_int8, i, bias_mode, \ - HSwishOp, \ - stride) \ - } else { \ - megdnn_assert("No support NoQuantized H_SWISH"); \ - } \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode, stride) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + if (quantized) { \ + DO_CONV_KERN_FUN( \ + dt_int8, i, bias_mode, \ + TypeCvtOp, stride) \ + } else { \ + DO_CONV_KERN_FUN( \ + dt_int32, i, bias_mode, \ + NoneOp, stride) \ + } \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + if (quantized) { \ + DO_CONV_KERN_FUN( \ + dt_int8, i, bias_mode, \ + ReluOp, stride) \ + } else { \ + megdnn_assert("No support NoQuantized RELU"); \ + } \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + if (quantized) { \ + DO_CONV_KERN_FUN( \ + dt_int8, i, bias_mode, \ + HSwishOp, stride) \ + } else { \ + megdnn_assert("No support NoQuantized H_SWISH"); \ + } \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_STRIDE_PARAM(filter, bias_mode) \ @@ -325,8 +322,8 @@ ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns( int OH = param.osz[0]; int IC = param.filter_meta.icpg; int IW = param.isz[1]; - int oh_tile_size = l2_block_helper(param.nr_threads, OH, - IC * IW * sizeof(int8_t) * 2); + int oh_tile_size = + l2_block_helper(param.nr_threads, OH, IC * IW * sizeof(int8_t) * 2); size_t oh_tiles = static_cast(div_ceil(OH, oh_tile_size)); auto do_conv = [wbundle, kernel]( diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h index 3504b4c4..1050da31 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h @@ -30,8 +30,8 @@ constexpr int 
filter_next_col = IC_PACK_SIZE * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] template -MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8], - const int32_t* bias_ptr, int oc_step) { +MEGDNN_ALWAYS_INLINE void init_ocx_ow8( + int32x4_t c[][8], const int32_t* bias_ptr, int oc_step) { static_assert(row == 1 || row == 2 || row == 3, "Invalid OC number."); if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { #define BIAS_INIT(step, i) c[i][step] = vld1q_s32(bias_ptr + i * oc_step); @@ -58,20 +58,17 @@ MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8], } } -#define cb11(col) \ - op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); +#define cb11(col) op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); #define cb21(col) \ op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ - op(res[1][col], \ - reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); + op(res[1][col], reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); -#define cb31(col) \ - op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ - op(res[1][col], \ - reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); \ - op(res[2][col], reinterpret_cast(dst_ptr + ld_dst_oc + \ - ld_dst_oc + col / 2 * 8)); +#define cb31(col) \ + op(res[0][col], reinterpret_cast(dst_ptr + col / 2 * 8)); \ + op(res[1][col], reinterpret_cast(dst_ptr + ld_dst_oc + col / 2 * 8)); \ + op(res[2][col], \ + reinterpret_cast(dst_ptr + ld_dst_oc + ld_dst_oc + col / 2 * 8)); #define cb12(step) \ op({{res[0][2 * step], res[0][2 * step + 1]}}, \ @@ -93,14 +90,14 @@ MEGDNN_ALWAYS_INLINE void init_ocx_ow8(int32x4_t c[][8], template struct StoreOCxOWx { - static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, - T* dst_ptr, const int ld_dst_oc); + static MEGDNN_ALWAYS_INLINE void impl( + int32x4_t res[][8], const Op& op, T* dst_ptr, const int ld_dst_oc); }; template struct StoreOCxOWx<1, ow_remain, Op, T> { - static void impl(int32x4_t res[][8], const Op& op, T* dst_ptr, - const int ld_dst_oc) { + static void impl( + int32x4_t res[][8], const Op& op, T* dst_ptr, const int ld_dst_oc) { MEGDNN_MARK_USED_VAR(ld_dst_oc); switch (ow_remain) { case 8: @@ -131,8 +128,8 @@ struct StoreOCxOWx<1, ow_remain, Op, T> { template struct StoreOCxOWx<2, ow_remain, Op, T> { - static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, - T* dst_ptr, const int ld_dst_oc) { + static MEGDNN_ALWAYS_INLINE void impl( + int32x4_t res[][8], const Op& op, T* dst_ptr, const int ld_dst_oc) { switch (ow_remain) { case 8: UNROLL_CALL_RAW(4, cb22); @@ -162,8 +159,8 @@ struct StoreOCxOWx<2, ow_remain, Op, T> { template struct StoreOCxOWx<3, ow_remain, Op, T> { - static MEGDNN_ALWAYS_INLINE void impl(int32x4_t res[][8], const Op& op, - T* dst_ptr, const int ld_dst_oc) { + static MEGDNN_ALWAYS_INLINE void impl( + int32x4_t res[][8], const Op& op, T* dst_ptr, const int ld_dst_oc) { switch (ow_remain) { case 8: UNROLL_CALL_RAW(4, cb32); @@ -199,45 +196,47 @@ struct StoreOCxOWx<3, ow_remain, Op, T> { #undef cb32 template -MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8], - const Op& op, T* dst_ptr, - const int ld_dst_oc) { +MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static( + int32x4_t res[][8], const Op& op, T* dst_ptr, const int ld_dst_oc) { StoreOCxOWx::impl(res, op, dst_ptr, ld_dst_oc); } -template +template < + int res_row, int src_row, int src_start_idx, int weight_idx, typename T, + typename T2, typename T3> struct ShiftCalHelper { MEGDNN_ATTRIBUTE_TARGET("dotprod") static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& 
src, T3& weight) { -#define cb(step) \ - res[res_row][step] = \ - vdotq_laneq_s32(res[res_row][step], weight[weight_idx], \ - src[src_row][(src_start_idx + step) / 4], \ - (src_start_idx + step) % 4); +#define cb(step) \ + res[res_row][step] = vdotq_laneq_s32( \ + res[res_row][step], weight[weight_idx], \ + src[src_row][(src_start_idx + step) / 4], (src_start_idx + step) % 4); UNROLL_CALL_RAW(8, cb); #undef cb } }; -template +template < + int res_row, int src_row, int src_start_idx, int weight_idx, typename T, + typename T2, typename T3> MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { - ShiftCalHelper::impl(res, src, weight); + ShiftCalHelper::impl( + res, src, weight); }; /** * oc12_owx(m = 12, n = x) and oc8_owx(m = 8, n = x) and oc4_owx(m = 4, n = x) * gemm like kernel * */ -template +template < + typename dst_type, int stride, BiasMode bias_mode, typename Op, int ow_remain, + int filter_size, int oc_interval, int ow_interval> struct KernNeonSdotNCHW44 { - static void impl(dst_type* dst, const int dst_step, const int8_t* src, - const int ih, const int iw, const int8_t* filter, - const int32_t* bias, const int ic, const Op& op); + static void impl( + dst_type* dst, const int dst_step, const int8_t* src, const int ih, + const int iw, const int8_t* filter, const int32_t* bias, const int ic, + const Op& op); }; } // namespace direct_dotprod_nchw44 diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp index 1ae736b8..cc9624f3 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp @@ -16,19 +16,20 @@ namespace megdnn { namespace arm_common { namespace direct_dotprod_nchw44 { -template -struct KernNeonSdotNCHW44 { +template < + typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, + int filter_size, int oc_interval, int ow_interval> +struct KernNeonSdotNCHW44< + dst_type, 1, bias_mode, Op, ow_remain, filter_size, oc_interval, ow_interval> { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(dst_type* dst, const int dst_step, const int8_t* src, - const int ih, const int iw, const int8_t* filter, - const int32_t* bias, const int ic, const Op& op) { + static void impl( + dst_type* dst, const int dst_step, const int8_t* src, const int ih, + const int iw, const int8_t* filter, const int32_t* bias, const int ic, + const Op& op) { constexpr int FH = filter_size; constexpr int FW = filter_size; constexpr int filter_next_row = - FW * OC_PACK_SIZE * - IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + FW * OC_PACK_SIZE * IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] const int filter_next_4oc = FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] @@ -51,34 +52,34 @@ struct KernNeonSdotNCHW44(src, i_src, 0); //! do not use switch order 3,2,1 because it will slow the speed. 
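// ---------------------------------------------------------------------------
// A scalar model of the accumulation performed by one cb(step) expansion in
// ShiftCalHelper above (illustrative sketch under assumed ARMv8.2 dot-product
// semantics; the function and array names are not part of the patch). Each
// vdotq_laneq_s32 adds, to every one of the four int32 output lanes, the dot
// product of four int8 weights with the four int8 source values selected by
// the lane index; the kernel derives the source register and lane from
// (src_start_idx + step) / 4 and (src_start_idx + step) % 4.
//     res[i] += sum_{k=0..3} weight[4*i + k] * src[4*lane + k],  i = 0..3
#include <cstdint>

static void dot_lane_accumulate_model(
        int32_t res[4], const int8_t weight[16], const int8_t src[16], int lane) {
    for (int i = 0; i < 4; ++i) {
        int32_t acc = 0;
        for (int k = 0; k < 4; ++k) {
            acc += int32_t(weight[4 * i + k]) * int32_t(src[4 * lane + k]);
        }
        res[i] += acc;  // 4 output channels, each fed by 4 packed input channels
    }
}
// ---------------------------------------------------------------------------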
-#define CALC_PART(step) \ - switch (LOOP) { \ - case 1: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, 0, step, 0>(res, src, weight); \ - break; \ - case 2: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, 0, step, 0>(res, src, weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, 0, step, 1>(res, src, weight); \ - break; \ - case 3: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, 0, step, 0>(res, src, weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, 0, step, 1>(res, src, weight); \ - weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ - filter_next_col * step); \ - cal_helper<2, 0, step, 2>(res, src, weight); \ - break; \ - default: \ - break; \ +#define CALC_PART(step) \ + switch (LOOP) { \ + case 1: \ + weight[0] = \ + vld1q_s8(i_filter + filter_next_4oc * 0 + filter_next_col * step); \ + cal_helper<0, 0, step, 0>(res, src, weight); \ + break; \ + case 2: \ + weight[0] = \ + vld1q_s8(i_filter + filter_next_4oc * 0 + filter_next_col * step); \ + cal_helper<0, 0, step, 0>(res, src, weight); \ + weight[1] = \ + vld1q_s8(i_filter + filter_next_4oc * 1 + filter_next_col * step); \ + cal_helper<1, 0, step, 1>(res, src, weight); \ + break; \ + case 3: \ + weight[0] = \ + vld1q_s8(i_filter + filter_next_4oc * 0 + filter_next_col * step); \ + cal_helper<0, 0, step, 0>(res, src, weight); \ + weight[1] = \ + vld1q_s8(i_filter + filter_next_4oc * 1 + filter_next_col * step); \ + cal_helper<1, 0, step, 1>(res, src, weight); \ + weight[2] = \ + vld1q_s8(i_filter + filter_next_4oc * 2 + filter_next_col * step); \ + cal_helper<2, 0, step, 2>(res, src, weight); \ + break; \ + default: \ + break; \ } switch (filter_size) { @@ -103,19 +104,17 @@ struct KernNeonSdotNCHW44(res, op, dst, - dst_step); + store_ocx_owx_remain_static(res, op, dst, dst_step); } }; -template +template < + typename dst_type, int stride, BiasMode bias_mode, typename Op, int filter_size> MEGDNN_ATTRIBUTE_TARGET("dotprod") -void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, - const int8_t* src, const int ih, const int iw, - const int8_t* filter, const int32_t* bias, - const int oh_size, const int oc, const int ic, - const Op& op) { +void conv_direct_sdot_int8_nchw44( + dst_type* dst, const int oh, const int ow, const int8_t* src, const int ih, + const int iw, const int8_t* filter, const int32_t* bias, const int oh_size, + const int oc, const int ic, const Op& op) { constexpr int FH = filter_size; constexpr int FW = filter_size; constexpr int IC_PACK_SIZE = 4; @@ -137,36 +136,31 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, const int dst_numbers_per_channel = oh * ow; const int ow_remain = ow % OW_INTERVAL; const int ow_end_idx = ow - ow_remain; - const int oc_remain = - oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 + const int oc_remain = oc % OC_BIG_INTERVAL; //! 
NCHW44 means oc_remain = 4 or 8 const int oc_end_idx = oc - oc_remain; - const int dst_numbers_4channel_packed = - dst_numbers_per_channel * OC_PACK_SIZE; + const int dst_numbers_4channel_packed = dst_numbers_per_channel * OC_PACK_SIZE; using remain_fun = std::function; + const int iw, const int8_t* filter, const int32_t* bias, const int ic, + const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_mid_oc_remain = nullptr; remain_fun kern_sma_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KernNeonSdotNCHW44::impl; \ - kern_mid_oc_remain = \ - KernNeonSdotNCHW44::impl; \ - kern_sma_oc_remain = \ - KernNeonSdotNCHW44::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KernNeonSdotNCHW44< \ + dst_type, stride, bias_mode, Op, step, filter_size, OC_BIG_INTERVAL, \ + OW_INTERVAL>::impl; \ + kern_mid_oc_remain = KernNeonSdotNCHW44< \ + dst_type, stride, bias_mode, Op, step, filter_size, OC_MID_INTERVAL, \ + OW_INTERVAL>::impl; \ + kern_sma_oc_remain = KernNeonSdotNCHW44< \ + dst_type, stride, bias_mode, Op, step, filter_size, OC_SMA_INTERVAL, \ + OW_INTERVAL>::impl; \ break; UNROLL_CALL_RAW(8, cb); #undef cb @@ -184,14 +178,13 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { const int src_offset_in_element = (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int dst_offset_in_element = oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; - KernNeonSdotNCHW44:: - impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, filter_size, + OC_BIG_INTERVAL, OW_INTERVAL>:: + impl(dst + dst_offset_in_element, dst_numbers_4channel_packed, src + src_offset_in_element, ih, iw, filter + filter_offset_in_element, bias + bias_offset_in_element, ic, op); @@ -203,11 +196,11 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, oc_idx * dst_numbers_per_channel + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; - kern_big_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); + kern_big_oc_remain( + dst + dst_offset_in_element, dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } } } @@ -221,34 +214,27 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { const int src_offset_in_element = (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int dst_offset_in_element = oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; if (oc_remain == 8) { KernNeonSdotNCHW44< - dst_type, stride, bias_mode, Op, OW_INTERVAL, - filter_size, OC_MID_INTERVAL, - OW_INTERVAL>::impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, - iw, - filter + - filter_offset_in_element, - bias + 
bias_offset_in_element, - ic, op); + dst_type, stride, bias_mode, Op, OW_INTERVAL, filter_size, + OC_MID_INTERVAL, OW_INTERVAL>:: + impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } else { KernNeonSdotNCHW44< - dst_type, stride, bias_mode, Op, OW_INTERVAL, - filter_size, OC_SMA_INTERVAL, - OW_INTERVAL>::impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, - iw, - filter + - filter_offset_in_element, - bias + bias_offset_in_element, - ic, op); + dst_type, stride, bias_mode, Op, OW_INTERVAL, filter_size, + OC_SMA_INTERVAL, OW_INTERVAL>:: + impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } } if (ow_remain) { @@ -259,17 +245,17 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; if (oc_remain == 8) { - kern_mid_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); + kern_mid_oc_remain( + dst + dst_offset_in_element, dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } else { - kern_sma_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); + kern_sma_oc_remain( + dst + dst_offset_in_element, dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } } } @@ -277,23 +263,22 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, #endif } -#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ - template void conv_direct_sdot_int8_nchw44( \ - dst_type * dst, const int oh, const int ow, const int8_t* src, \ - const int ih, const int iw, const int8_t* weight, \ - const int32_t* bias, const int oh_size, const int oc, \ - const int ic, const Op& op); +#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ + template void \ + conv_direct_sdot_int8_nchw44( \ + dst_type * dst, const int oh, const int ow, const int8_t* src, \ + const int ih, const int iw, const int8_t* weight, const int32_t* bias, \ + const int oh_size, const int oc, const int ic, const Op& op); -#define FOR_OP(stride, i, bias_mode) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - TypeCvtOp) \ - INSTANTIATION(dt_int32, stride, i, bias_mode, \ - NoneOp) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - ReluOp) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - HSwishOp) +#define FOR_OP(stride, i, bias_mode) \ + INSTANTIATION( \ + dt_int8, stride, i, bias_mode, TypeCvtOp) \ + INSTANTIATION( \ + dt_int32, stride, i, bias_mode, NoneOp) \ + INSTANTIATION( \ + dt_int8, stride, i, bias_mode, ReluOp) \ + INSTANTIATION( \ + dt_int8, stride, i, bias_mode, HSwishOp) #define FOR_BIAS(stride, i) \ FOR_OP(stride, i, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp index 36b37778..841e868f 
100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp @@ -16,19 +16,20 @@ namespace megdnn { namespace arm_common { namespace direct_dotprod_nchw44 { -template -struct KernNeonSdotNCHW44 { +template < + typename dst_type, BiasMode bias_mode, typename Op, int ow_remain, + int filter_size, int oc_interval, int ow_interval> +struct KernNeonSdotNCHW44< + dst_type, 2, bias_mode, Op, ow_remain, filter_size, oc_interval, ow_interval> { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(dst_type* dst, const int dst_step, const int8_t* src, - const int ih, const int iw, const int8_t* filter, - const int32_t* bias, const int ic, const Op& op) { + static void impl( + dst_type* dst, const int dst_step, const int8_t* src, const int ih, + const int iw, const int8_t* filter, const int32_t* bias, const int ic, + const Op& op) { constexpr int FH = filter_size; constexpr int FW = filter_size; constexpr int filter_next_row = - FW * OC_PACK_SIZE * - IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] + FW * OC_PACK_SIZE * IC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] const int filter_next_4oc = FH * FW * ic * OC_PACK_SIZE; //! [OC/4, IC/4, FH, FW, 4OC, 4IC] @@ -52,34 +53,34 @@ struct KernNeonSdotNCHW44(src, i_src, offset); //! do not use switch order 3,2,1 because it will slow the speed. -#define CALC_PART(step) \ - switch (LOOP) { \ - case 1: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ - break; \ - case 2: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \ - break; \ - case 3: \ - weight[0] = vld1q_s8(i_filter + filter_next_4oc * 0 + \ - filter_next_col * step); \ - cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ - weight[1] = vld1q_s8(i_filter + filter_next_4oc * 1 + \ - filter_next_col * step); \ - cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \ - weight[2] = vld1q_s8(i_filter + filter_next_4oc * 2 + \ - filter_next_col * step); \ - cal_helper<2, step % 2, step / 2, 2>(res, src, weight); \ - break; \ - default: \ - break; \ +#define CALC_PART(step) \ + switch (LOOP) { \ + case 1: \ + weight[0] = \ + vld1q_s8(i_filter + filter_next_4oc * 0 + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ + break; \ + case 2: \ + weight[0] = \ + vld1q_s8(i_filter + filter_next_4oc * 0 + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ + weight[1] = \ + vld1q_s8(i_filter + filter_next_4oc * 1 + filter_next_col * step); \ + cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \ + break; \ + case 3: \ + weight[0] = \ + vld1q_s8(i_filter + filter_next_4oc * 0 + filter_next_col * step); \ + cal_helper<0, step % 2, step / 2, 0>(res, src, weight); \ + weight[1] = \ + vld1q_s8(i_filter + filter_next_4oc * 1 + filter_next_col * step); \ + cal_helper<1, step % 2, step / 2, 1>(res, src, weight); \ + weight[2] = \ + vld1q_s8(i_filter + filter_next_4oc * 2 + filter_next_col * step); \ + cal_helper<2, step % 2, step / 2, 2>(res, src, weight); \ + break; \ + default: \ + break; \ } switch (filter_size) { @@ -104,19 +105,17 @@ struct KernNeonSdotNCHW44(res, op, dst, - dst_step); + 
store_ocx_owx_remain_static(res, op, dst, dst_step); } }; -template +template < + typename dst_type, int stride, BiasMode bias_mode, typename Op, int filter_size> MEGDNN_ATTRIBUTE_TARGET("dotprod") -void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, - const int8_t* src, const int ih, const int iw, - const int8_t* filter, const int32_t* bias, - const int oh_size, const int oc, const int ic, - const Op& op) { +void conv_direct_sdot_int8_nchw44( + dst_type* dst, const int oh, const int ow, const int8_t* src, const int ih, + const int iw, const int8_t* filter, const int32_t* bias, const int oh_size, + const int oc, const int ic, const Op& op) { constexpr int FH = filter_size; constexpr int FW = filter_size; constexpr int IC_PACK_SIZE = 4; @@ -138,36 +137,31 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, const int dst_numbers_per_channel = oh * ow; const int ow_remain = ow % OW_INTERVAL; const int ow_end_idx = ow - ow_remain; - const int oc_remain = - oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 + const int oc_remain = oc % OC_BIG_INTERVAL; //! NCHW44 means oc_remain = 4 or 8 const int oc_end_idx = oc - oc_remain; - const int dst_numbers_4channel_packed = - dst_numbers_per_channel * OC_PACK_SIZE; + const int dst_numbers_4channel_packed = dst_numbers_per_channel * OC_PACK_SIZE; using remain_fun = std::function; + const int iw, const int8_t* filter, const int32_t* bias, const int ic, + const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_mid_oc_remain = nullptr; remain_fun kern_sma_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KernNeonSdotNCHW44::impl; \ - kern_mid_oc_remain = \ - KernNeonSdotNCHW44::impl; \ - kern_sma_oc_remain = \ - KernNeonSdotNCHW44::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KernNeonSdotNCHW44< \ + dst_type, stride, bias_mode, Op, step, filter_size, OC_BIG_INTERVAL, \ + OW_INTERVAL>::impl; \ + kern_mid_oc_remain = KernNeonSdotNCHW44< \ + dst_type, stride, bias_mode, Op, step, filter_size, OC_MID_INTERVAL, \ + OW_INTERVAL>::impl; \ + kern_sma_oc_remain = KernNeonSdotNCHW44< \ + dst_type, stride, bias_mode, Op, step, filter_size, OC_SMA_INTERVAL, \ + OW_INTERVAL>::impl; \ break; UNROLL_CALL_RAW(8, cb); #undef cb @@ -185,14 +179,13 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { const int src_offset_in_element = (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int dst_offset_in_element = oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; - KernNeonSdotNCHW44:: - impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, + KernNeonSdotNCHW44< + dst_type, stride, bias_mode, Op, OW_INTERVAL, filter_size, + OC_BIG_INTERVAL, OW_INTERVAL>:: + impl(dst + dst_offset_in_element, dst_numbers_4channel_packed, src + src_offset_in_element, ih, iw, filter + filter_offset_in_element, bias + bias_offset_in_element, ic, op); @@ -204,11 +197,11 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, oc_idx * dst_numbers_per_channel + (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; - kern_big_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + 
src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); + kern_big_oc_remain( + dst + dst_offset_in_element, dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } } } @@ -222,34 +215,27 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, for (int ow_idx = 0; ow_idx < ow_end_idx; ow_idx += OW_INTERVAL) { const int src_offset_in_element = (oh_idx * SH * iw + ow_idx) * IC_PACK_SIZE; - const int dst_offset_in_element = - oc_idx * dst_numbers_per_channel + - (oh_idx * ow + ow_idx) * OC_PACK_SIZE; + const int dst_offset_in_element = oc_idx * dst_numbers_per_channel + + (oh_idx * ow + ow_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; if (oc_remain == 8) { KernNeonSdotNCHW44< - dst_type, stride, bias_mode, Op, OW_INTERVAL, - filter_size, OC_MID_INTERVAL, - OW_INTERVAL>::impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, - iw, - filter + - filter_offset_in_element, - bias + bias_offset_in_element, - ic, op); + dst_type, stride, bias_mode, Op, OW_INTERVAL, filter_size, + OC_MID_INTERVAL, OW_INTERVAL>:: + impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } else { KernNeonSdotNCHW44< - dst_type, stride, bias_mode, Op, OW_INTERVAL, - filter_size, OC_SMA_INTERVAL, - OW_INTERVAL>::impl(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, - iw, - filter + - filter_offset_in_element, - bias + bias_offset_in_element, - ic, op); + dst_type, stride, bias_mode, Op, OW_INTERVAL, filter_size, + OC_SMA_INTERVAL, OW_INTERVAL>:: + impl(dst + dst_offset_in_element, + dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } } if (ow_remain) { @@ -260,17 +246,17 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, (oh_idx * ow + ow_end_idx) * OC_PACK_SIZE; const int bias_offset_in_element = oc_idx; if (oc_remain == 8) { - kern_mid_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); + kern_mid_oc_remain( + dst + dst_offset_in_element, dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } else { - kern_sma_oc_remain(dst + dst_offset_in_element, - dst_numbers_4channel_packed, - src + src_offset_in_element, ih, iw, - filter + filter_offset_in_element, - bias + bias_offset_in_element, ic, op); + kern_sma_oc_remain( + dst + dst_offset_in_element, dst_numbers_4channel_packed, + src + src_offset_in_element, ih, iw, + filter + filter_offset_in_element, + bias + bias_offset_in_element, ic, op); } } } @@ -278,23 +264,22 @@ void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, #endif } -#define INSTANTIATION(dst_type, stride, filter_size, bias_mode, Op) \ - template void conv_direct_sdot_int8_nchw44( \ - dst_type * dst, const int oh, const int ow, const int8_t* src, \ - const int ih, const int iw, const int8_t* weight, \ - const int32_t* bias, const int oh_size, const int oc, \ - const int ic, const Op& op); +#define INSTANTIATION(dst_type, stride, 
filter_size, bias_mode, Op) \ + template void \ + conv_direct_sdot_int8_nchw44( \ + dst_type * dst, const int oh, const int ow, const int8_t* src, \ + const int ih, const int iw, const int8_t* weight, const int32_t* bias, \ + const int oh_size, const int oc, const int ic, const Op& op); -#define FOR_OP(stride, i, bias_mode) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - TypeCvtOp) \ - INSTANTIATION(dt_int32, stride, i, bias_mode, \ - NoneOp) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - ReluOp) \ - INSTANTIATION(dt_int8, stride, i, bias_mode, \ - HSwishOp) +#define FOR_OP(stride, i, bias_mode) \ + INSTANTIATION( \ + dt_int8, stride, i, bias_mode, TypeCvtOp) \ + INSTANTIATION( \ + dt_int32, stride, i, bias_mode, NoneOp) \ + INSTANTIATION( \ + dt_int8, stride, i, bias_mode, ReluOp) \ + INSTANTIATION( \ + dt_int8, stride, i, bias_mode, HSwishOp) #define FOR_BIAS(stride, i) \ FOR_OP(stride, i, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp index 565e7cff..5e887f31 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp @@ -17,8 +17,9 @@ namespace megdnn { namespace arm_common { namespace dot_direct_nchw_nchw44 { -template +template < + int src_idx, int weight_idx, typename Func, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { @@ -33,8 +34,9 @@ struct ShiftCalHelper { } }; -template +template < + int src_idx, int weight_idx, typename Func, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { @@ -48,14 +50,12 @@ struct ShiftCalHelper { }; ////////////////////stride 1/////////////////// -template -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_hight = 1; constexpr int filter_width = 4; @@ -82,8 +82,7 @@ struct KerNeonDotXXs2Nchw44Int8( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); src_ptr += ic_stride; weight_ptr += filter_hight * filter_width * oc_step; @@ -93,14 +92,12 @@ struct KerNeonDotXXs2Nchw44Int8 -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_hight = 2; constexpr int filter_width = 4; @@ -127,13 +124,11 @@ struct KerNeonDotXXs2Nchw44Int8( weight, weight_ptr, 
ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); // row 1 load_helper( src, src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); src_ptr += ic_stride; weight_ptr += filter_hight * filter_width * oc_step; @@ -142,14 +137,12 @@ struct KerNeonDotXXs2Nchw44Int8 -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_hight = 3; constexpr int filter_width = 4; @@ -176,18 +169,15 @@ struct KerNeonDotXXs2Nchw44Int8( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); // row 1 load_helper( src, src_ptr + 1 * iw * pack_iw_len, 0); - cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); // row 2 load_helper( src, src_ptr + 2 * iw * pack_iw_len, 0); - cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 2, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); src_ptr += ic_stride; weight_ptr += filter_hight * filter_width * oc_step; @@ -197,14 +187,12 @@ struct KerNeonDotXXs2Nchw44Int8 -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_hight = 5; constexpr int filter_width = 8; @@ -228,13 +216,12 @@ struct KerNeonDotXXs2Nchw44Int8( \ - src, src_ptr + step * iw * pack_iw_len, 0); \ - load_helper( \ - weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ +#define cb(step) \ + load_helper( \ + src, src_ptr + step * iw * pack_iw_len, 0); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); \ cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); UNROLL_CALL_RAW(5, cb); @@ -247,14 +234,12 @@ struct KerNeonDotXXs2Nchw44Int8 -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { 
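// --- Illustrative aside, not part of the patch: the stride-1 kernels in this file
// read a source that pack_src_int8_nchw_nchw44_dot<1> (further down) has expanded
// with pack_iw_len == 4, so column x holds a pixel together with its next three
// neighbours -- the pattern behind the reorder_idx table
// {0,1,2,3, 1,2,3,4, 2,3,4,5, 3,4,5,6} and the scalar tail loop there. A scalar
// model of one padded row; pack_row_iw4_ref is hypothetical and `row` is assumed
// to carry at least 3 trailing padding bytes.
static void pack_row_iw4_ref(int8_t* packed, const int8_t* row, int width_with_pad) {
    for (int x = 0; x < width_with_pad; ++x) {
        for (int k = 0; k < 4; ++k) {
            packed[x * 4 + k] = row[x + k];  // 4 consecutive pixels per output column
        }
    }
}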
constexpr int stride = 1; constexpr int filter_hight = 7; constexpr int filter_width = 8; @@ -277,13 +262,12 @@ struct KerNeonDotXXs2Nchw44Int8( \ - src, src_ptr + step * iw * pack_iw_len, 0); \ - load_helper( \ - weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ +#define cb(step) \ + load_helper( \ + src, src_ptr + step * iw * pack_iw_len, 0); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); \ cal_helper<4, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); UNROLL_CALL_RAW(7, cb); @@ -296,15 +280,11 @@ struct KerNeonDotXXs2Nchw44Int8 -void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, - const int8_t* sptr_origin, const int, - const int pw, const int, const int ih, - const int iw, const int iw2, - const int pad_top, const int pad_bottom, - const int ic, const int ic_stride, - int8_t* temp_ptr) { - static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4, - 2, 3, 4, 5, 3, 4, 5, 6}; +void pack_src_int8_nchw_nchw44_dot<1>( + int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw, + const int, const int ih, const int iw, const int iw2, const int pad_top, + const int pad_bottom, const int ic, const int ic_stride, int8_t* temp_ptr) { + static uint8_t reorder_idx[16] = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); constexpr int iw_step = 16; @@ -314,8 +294,7 @@ void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, rep(ic_idx, ic) { const int8_t* sptr = sptr_origin + ic_idx * ic_stride; memset(sptr_base, 0, - sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * - pack_iw_len); + sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * pack_iw_len); sptr_base += iw2 * pad_top * pack_iw_len; rep(ih_idx, ih) { memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); @@ -337,14 +316,10 @@ void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, vst1q_s8(sptr_base + iw_idx * pack_iw_len + 48, dst[3]); } for (int iw_idx = iw_with_pad_end; iw_idx < iw_with_pad; ++iw_idx) { - *(sptr_base + iw_idx * pack_iw_len + 0) = - *(temp_ptr + iw_idx + 0); - *(sptr_base + iw_idx * pack_iw_len + 1) = - *(temp_ptr + iw_idx + 1); - *(sptr_base + iw_idx * pack_iw_len + 2) = - *(temp_ptr + iw_idx + 2); - *(sptr_base + iw_idx * pack_iw_len + 3) = - *(temp_ptr + iw_idx + 3); + *(sptr_base + iw_idx * pack_iw_len + 0) = *(temp_ptr + iw_idx + 0); + *(sptr_base + iw_idx * pack_iw_len + 1) = *(temp_ptr + iw_idx + 1); + *(sptr_base + iw_idx * pack_iw_len + 2) = *(temp_ptr + iw_idx + 2); + *(sptr_base + iw_idx * pack_iw_len + 3) = *(temp_ptr + iw_idx + 3); } sptr_base += iw2 * pack_iw_len; sptr += iw; @@ -355,12 +330,10 @@ void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const int oc, const int ic, - const int ih, const int iw, const int oh, - const int oh_block, const int ow, - const Op& op) { +void conv_direct_int8_nchw_nchw44_dot( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const int oc, const int ic, const int ih, const int iw, + const int oh, const int oh_block, const int ow, const Op& op) { MEGDNN_MARK_USED_VAR(temp); constexpr int fh = filter_size; constexpr int fw = (filter_size + 3) / 4 * 4; @@ -384,21 +357,18 @@ void 
conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, const int oc_remain = oc - oc_end; const int ld_dst_oc = oc_step * img_stride; - using remain_fun = - std::function; + using remain_fun = std::function; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonDotXXs2Nchw44Int8::impl; \ - kern_small_oc_remain = \ - KerNeonDotXXs2Nchw44Int8::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonDotXXs2Nchw44Int8< \ + bias_mode, Op, step, filter_size, big_oc_step, ow_step, stride>::impl; \ + kern_small_oc_remain = KerNeonDotXXs2Nchw44Int8< \ + bias_mode, Op, step, filter_size, oc_step, ow_step, stride>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -415,13 +385,11 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + KerNeonDotXXs2Nchw44Int8< + bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step, + stride>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } if (ow_remain > 0) { const int src_offset = @@ -429,9 +397,9 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -445,13 +413,10 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + KerNeonDotXXs2Nchw44Int8< + bias_mode, Op, ow_step, filter_size, oc_step, ow_step, stride>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } if (ow_remain > 0) { const int src_offset = @@ -459,28 +424,28 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } } -#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ - template void \ - conv_direct_int8_nchw_nchw44_dot( \ - const int8_t* src, const int8_t* filter, const int32_t* bias, \ - int32_t* temp, int8_t* dst, const int oc, const int ic, \ - const int ih, const int iw, const int oh, const int oh_block, \ - const int ow, const Op& op); - -#define GET_OP_PARAM(stride, filter, bias_mode) \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - TypeCvtOp) \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - ReluOp) \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - HSwishOp) +#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + 
template void \ + conv_direct_int8_nchw_nchw44_dot( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, TypeCvtOp) \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, ReluOp) \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, HSwishOp) #define GET_BIAS_MODE_PARAM(stride, filter) \ GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp index 786cbc30..e1a0afa3 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp @@ -16,67 +16,60 @@ namespace megdnn { namespace arm_common { namespace dot_direct_nchw_nchw44 { -template +template < + int src_idx, int weight_idx, typename Func, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2], weight[0][weight_idx], \ - src[0][(src_idx + step) / 4]); \ - c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \ - c[1][step * 2], weight[1][weight_idx], \ - src[0][(src_idx + step) / 4]); \ - c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2 + 1], weight[0][weight_idx], \ - src[1][(src_idx + step) / 4]); \ - c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ - c[1][step * 2 + 1], weight[1][weight_idx], \ - src[1][(src_idx + step) / 4]); +#define cb(step) \ + c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2], weight[0][weight_idx], src[0][(src_idx + step) / 4]); \ + c[1][step * 2] = Func::template impl<(src_idx + step) % 4>( \ + c[1][step * 2], weight[1][weight_idx], src[0][(src_idx + step) / 4]); \ + c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2 + 1], weight[0][weight_idx], src[1][(src_idx + step) / 4]); \ + c[1][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ + c[1][step * 2 + 1], weight[1][weight_idx], src[1][(src_idx + step) / 4]); UNROLL_CALL_RAW(4, cb); #undef cb } }; -template +template < + int src_idx, int weight_idx, typename Func, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2], weight[0][weight_idx], \ - src[0][(src_idx + step) / 4]); \ - c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ - c[0][step * 2 + 1], weight[0][weight_idx], \ - src[1][(src_idx + step) / 4]); +#define cb(step) \ + c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2], weight[0][weight_idx], src[0][(src_idx + step) / 4]); \ + c[0][step * 2 + 1] = Func::template impl<(src_idx + step) % 4>( \ + c[0][step * 2 + 1], weight[0][weight_idx], src[1][(src_idx + step) / 4]); UNROLL_CALL_RAW(4, cb); #undef cb } }; -template -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void 
impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, - int, int, int, const Op&) { + static void impl( + const int8_t*, const int8_t*, const int32_t*, int8_t*, int, int, int, int, + const Op&) { megdnn_assert(0, "not impl"); } }; -template -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 2; constexpr int filter_hight = 2; constexpr int filter_width = 4; @@ -103,15 +96,13 @@ struct KerNeonDotXXs2Nchw44Int8( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); // row 1 load_helper( src, src_ptr + 1 * iw, stride); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); src_ptr += ic_stride; weight_ptr += filter_hight * filter_width * oc_step; @@ -121,14 +112,12 @@ struct KerNeonDotXXs2Nchw44Int8 -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 2; constexpr int filter_hight = 3; constexpr int filter_width = 4; @@ -155,22 +144,19 @@ struct KerNeonDotXXs2Nchw44Int8( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); // row 1 load_helper( src, src_ptr + 1 * iw, stride); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); // row 2 load_helper( src, src_ptr + 2 * iw, stride); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, - weight); + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); src_ptr += ic_stride; weight_ptr += filter_hight * filter_width * oc_step; @@ -180,14 +166,12 @@ struct KerNeonDotXXs2Nchw44Int8 -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 2; constexpr int filter_hight = 5; constexpr int filter_width = 8; @@ -210,13 +194,11 @@ struct KerNeonDotXXs2Nchw44Int8(src, src_ptr + step * iw, \ - stride); \ - load_helper( \ - weight, weight_ptr + 
step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ +#define cb(step) \ + load_helper(src, src_ptr + step * iw, stride); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); \ cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); UNROLL_CALL_RAW(5, cb); #undef cb @@ -236,14 +218,12 @@ struct KerNeonDotXXs2Nchw44Int8 -struct KerNeonDotXXs2Nchw44Int8 { +template +struct KerNeonDotXXs2Nchw44Int8 { MEGDNN_ATTRIBUTE_TARGET("dotprod") - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 2; constexpr int filter_hight = 7; constexpr int filter_width = 8; @@ -266,13 +246,11 @@ struct KerNeonDotXXs2Nchw44Int8(src, src_ptr + step * iw, \ - stride); \ - load_helper( \ - weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ - cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, \ - weight); \ +#define cb(step) \ + load_helper(src, src_ptr + step * iw, stride); \ + load_helper( \ + weight, weight_ptr + step * 2 * simd_len, ld_weight_oc); \ + cal_helper<0, 0, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); \ cal_helper<1, 1, c_dim, Vdotq_laneq_s32, ow_block, stride>(c, src, weight); UNROLL_CALL_RAW(7, cb); #undef cb @@ -296,8 +274,7 @@ void pack_src_int8_nchw_nchw44_dot<2>( sizeof(int8_t) * ic_step * iw2 * (ih + pad_top + pad_bottom)); sptr_base += iw2 * pad_top * ic_step; rep(ih_idx, ih) { - memcpy(sptr_base + pw * ic_step, sptr, - sizeof(int8_t) * iw * ic_step); + memcpy(sptr_base + pw * ic_step, sptr, sizeof(int8_t) * iw * ic_step); sptr_base += iw2 * ic_step; sptr += iw * ic_step; } @@ -307,12 +284,10 @@ void pack_src_int8_nchw_nchw44_dot<2>( template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const int oc, const int ic, - const int ih, const int iw, const int oh, - const int oh_block, const int ow, - const Op& op) { +void conv_direct_int8_nchw_nchw44_dot( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const int oc, const int ic, const int ih, const int iw, + const int oh, const int oh_block, const int ow, const Op& op) { MEGDNN_MARK_USED_VAR(temp); constexpr int fh = filter_size; constexpr int fw = (filter_size + 3) / 4 * 4; @@ -336,21 +311,18 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, const int oc_remain = oc - oc_end; const int ld_dst_oc = oc_step * img_stride; - using remain_fun = - std::function; + using remain_fun = std::function; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonDotXXs2Nchw44Int8::impl; \ - kern_small_oc_remain = \ - KerNeonDotXXs2Nchw44Int8::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonDotXXs2Nchw44Int8< \ + bias_mode, Op, step, filter_size, big_oc_step, ow_step, stride>::impl; \ + kern_small_oc_remain = KerNeonDotXXs2Nchw44Int8< \ + bias_mode, Op, step, filter_size, oc_step, ow_step, 
stride>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -367,13 +339,11 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + KerNeonDotXXs2Nchw44Int8< + bias_mode, Op, ow_step, filter_size, big_oc_step, ow_step, + stride>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } if (ow_remain > 0) { const int src_offset = @@ -381,9 +351,9 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -397,13 +367,10 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDotXXs2Nchw44Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + KerNeonDotXXs2Nchw44Int8< + bias_mode, Op, ow_step, filter_size, oc_step, ow_step, stride>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } if (ow_remain > 0) { const int src_offset = @@ -411,29 +378,29 @@ void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, pack_iw_len; const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } } -#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ - template void \ - conv_direct_int8_nchw_nchw44_dot( \ - const int8_t* src, const int8_t* filter, const int32_t* bias, \ - int32_t* temp, int8_t* dst, const int oc, const int ic, \ - const int ih, const int iw, const int oh, const int oh_block, \ - const int ow, const Op& op); - -#define GET_OP_PARAM(stride, filter, bias_mode) \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - TypeCvtOp) \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - ReluOp) \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - HSwishOp) +#define DO_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template void \ + conv_direct_int8_nchw_nchw44_dot( \ + const int8_t* src, const int8_t* filter, const int32_t* bias, \ + int32_t* temp, int8_t* dst, const int oc, const int ic, const int ih, \ + const int iw, const int oh, const int oh_block, const int ow, \ + const Op& op); + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, TypeCvtOp) \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, ReluOp) \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, HSwishOp) #define GET_BIAS_MODE_PARAM(stride, filter) \ GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp 
b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp index c1b112cf..43e4839b 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp @@ -21,14 +21,12 @@ namespace megdnn { namespace arm_common { namespace { -template -static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int c_dim, + typename DstType> +static void ker_neon_dirctconv_2x2s1_oc8_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -49,8 +47,8 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step * pack_iw_len; src[0] = vld1q_s8(src_ic_0_3); src[1] = vld1q_s8((src_ic_0_3 + 16)); @@ -63,8 +61,7 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); @@ -113,14 +110,12 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, c, op, dst_ptr, ld_dst_oc); } -template -static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int c_dim, + typename DstType> +static void ker_neon_dirctconv_2x2s1_oc4_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int oc_step = 4; @@ -139,8 +134,8 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step * pack_iw_len; src[0] = vld1q_s8(src_ic_0_3); src[1] = vld1q_s8((src_ic_0_3 + 16)); @@ -153,8 +148,7 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); @@ -186,12 +180,13 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, c, op, dst_ptr, ld_dst_oc); } -template +template < + BiasMode bias_mode, typename Op, int 
remain_w, int filter_size, int c_dim, + typename DstType> struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc); + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc); }; /** dot like impl. dot 4 ic to 1 oc, accumale to c @@ -203,12 +198,11 @@ high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> **/ //! TODO: can try oh = 2 impl, oc = 8 impl -template +template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc) { constexpr int filter_size = 3; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -285,12 +279,11 @@ struct KerNeonDirectStride1Int8 { } }; -template +template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc) { constexpr int filter_size = 5; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -388,12 +381,11 @@ struct KerNeonDirectStride1Int8 { } }; -template +template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc) { constexpr int filter_size = 7; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -513,13 +505,10 @@ struct KerNeonDirectStride1Int8 { }; template -void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - DstType* dst, const size_t oc, - const size_t ic, const size_t ih, - const size_t iw, const size_t oh, - const size_t ow, const Op& op) { +void conv_direct_stride1_2x2_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { MEGDNN_MARK_USED_VAR(temp); constexpr size_t filter_size = 2; constexpr size_t fh = filter_size; @@ -539,21 +528,18 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, const int ld_oc = oh * ow * oc_step; using remain_fun = std::function; + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - ker_neon_dirctconv_2x2s1_oc8_ow8; \ - kern_small_oc_remain = \ - ker_neon_dirctconv_2x2s1_oc4_ow8; \ +#define cb(step) \ + case 
step: \ + kern_big_oc_remain = ker_neon_dirctconv_2x2s1_oc8_ow8< \ + bias_mode, Op, step, filter_size, 2, DstType>; \ + kern_small_oc_remain = ker_neon_dirctconv_2x2s1_oc4_ow8< \ + bias_mode, Op, step, filter_size, 1, DstType>; \ break; UNROLL_CALL_RAW(8, cb); @@ -569,8 +555,8 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc8_ow8( + ker_neon_dirctconv_2x2s1_oc8_ow8< + bias_mode, Op, ow_step, filter_size, 2, DstType>( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc, op); } @@ -579,9 +565,9 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, (oh_idx * iw + ow_end) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_oc, op); + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); } } } @@ -594,8 +580,8 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc4_ow8( + ker_neon_dirctconv_2x2s1_oc4_ow8< + bias_mode, Op, ow_step, filter_size, 1, DstType>( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc, op); } @@ -604,21 +590,18 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, (oh_idx * iw + ow_end) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_oc, op); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc, op); } } } } template -void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - DstType* dst, const size_t oc, - const size_t ic, const size_t ih, - const size_t iw, const size_t oh, - const size_t ow, const Op& op) { +void conv_direct_stride1_int8_nchw44_kern( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { MEGDNN_MARK_USED_VAR(temp); constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -634,17 +617,15 @@ void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, const size_t ow_remain = ow - ow_end; using remain_fun = std::function; + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc)>; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_small_oc_remain = \ - KerNeonDirectStride1Int8::impl; \ +#define cb(step) \ + case step: \ + kern_small_oc_remain = KerNeonDirectStride1Int8< \ + bias_mode, Op, step, filter_size, 1, DstType>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -661,21 +642,19 @@ void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * 
oc_step; - KerNeonDirectStride1Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, op, ld_dst_oc); + KerNeonDirectStride1Int8< + bias_mode, Op, ow_step, filter_size, 1, DstType>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op, ld_dst_oc); } if (ow_remain > 0) { const size_t src_offset = (oh_idx * iw + ow_end) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, op, ld_dst_oc); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op, ld_dst_oc); } } } @@ -685,43 +664,43 @@ void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, namespace int8_direct_nchw44 { template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - conv_direct_stride1_int8_nchw44_kern( + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + conv_direct_stride1_int8_nchw44_kern( src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } }; template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { conv_direct_stride1_2x2_int8_nchw44( src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } }; -#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \ - template struct ConvDirectInt8Nchw44Choose; - -#define GET_OP_PARAM(stride, filter, bias_mode) \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - TypeCvtOp) \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - ReluOp) \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - HSwishOp) \ +#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \ + template struct ConvDirectInt8Nchw44Choose< \ + bias_mode, Op, filter_size, DstType, stride>; + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + TypeCvtOp) \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + ReluOp) \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + HSwishOp) \ DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp) #define GET_BIAS_MODE_PARAM(stride, filter) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp index 2b4ab1ec..b2a3cff5 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp @@ -22,22 +22,21 @@ namespace megdnn { 
namespace arm_common { namespace { -template +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int c_dim, + typename DstType> struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc); + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc); }; -template -static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int c_dim, + typename DstType> +static void ker_neon_dirctconv_2x2s2_oc8_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -58,8 +57,8 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step * pack_iw_len; src[0] = vld1q_s8(src_ic_0_3); src[1] = vld1q_s8(src_ic_0_3 + 16); @@ -72,8 +71,7 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, src[8] = vld1q_s8(src_ic_0_3 + 8 * 16); // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); @@ -129,14 +127,12 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, c, op, dst_ptr, ld_dst_oc); } -template -static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int32_t* bias_ptr, - DstType* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, - const Op& op) { +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int c_dim, + typename DstType> +static void ker_neon_dirctconv_2x2s2_oc4_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int oc_step = 4; @@ -155,8 +151,8 @@ static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; + const int8_t* src_ic_0_3 = + src_ptr + ic_idx * ic_stride + fh_idx * iw * ic_step * pack_iw_len; src[0] = vld1q_s8(src_ic_0_3); src[1] = vld1q_s8((src_ic_0_3 + 16)); @@ -169,8 +165,7 @@ static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, src[8] = vld1q_s8((src_ic_0_3 + 8 * 16)); // oc == 0 - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0] = vld1q_s8(read_weight_ptr); weight[1] = vld1q_s8(read_weight_ptr + 16); @@ 
-218,12 +213,11 @@ high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> dot: (<0, 0> + <0, 3>) + (<0, 1> + <0, 2>) -> <0> **/ // TODO: can try oh = 2 impl, oc = 8 impl -template +template struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc) { constexpr int filter_size = 3; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -306,12 +300,11 @@ struct KerNeonDirectStride2Int8 { c, op, dst_ptr, ld_dst_oc); } }; -template +template struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc) { constexpr int filter_size = 5; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -414,12 +407,11 @@ struct KerNeonDirectStride2Int8 { c, op, dst_ptr, ld_dst_oc); } }; -template +template struct KerNeonDirectStride2Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, DstType* dst_ptr, int ic, int ih, - int iw, const Op& op, int ld_dst_oc) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc) { constexpr int filter_size = 7; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -567,21 +559,18 @@ void conv_direct_stride2_2x2_int8_nchw44( const int ld_dst_oc = oh * ow * oc_step; using remain_fun = std::function; + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - ker_neon_dirctconv_2x2s2_oc8_ow8; \ - kern_small_oc_remain = \ - ker_neon_dirctconv_2x2s2_oc4_ow8; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = ker_neon_dirctconv_2x2s2_oc8_ow8< \ + bias_mode, Op, step, filter_size, 2, DstType>; \ + kern_small_oc_remain = ker_neon_dirctconv_2x2s2_oc4_ow8< \ + bias_mode, Op, step, filter_size, 1, DstType>; \ break; UNROLL_CALL_RAW(8, cb); @@ -594,25 +583,23 @@ void conv_direct_stride2_2x2_int8_nchw44( const size_t weight_offset = oc_idx * ic * fh * fw; for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc8_ow8( + const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8< + bias_mode, Op, ow_step, filter_size, 2, DstType>( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } if (ow_remain > 0) { - const 
size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + const size_t src_offset = (oh_idx * stride_h * iw + ow_end * stride_w) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -622,25 +609,23 @@ void conv_direct_stride2_2x2_int8_nchw44( const size_t weight_offset = oc_idx * ic * fh * fw; for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc4_ow8( + const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8< + bias_mode, Op, ow_step, filter_size, 1, DstType>( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + const size_t src_offset = (oh_idx * stride_h * iw + ow_end * stride_w) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -667,17 +652,15 @@ void conv_direct_stride2_int8_nchw44_kern( const int ld_dst_oc = oh * ow * oc_step; using remain_fun = std::function; + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + DstType* dst_ptr, int ic, int ih, int iw, const Op& op, int ld_dst_oc)>; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_small_oc_remain = \ - KerNeonDirectStride2Int8::impl; \ +#define cb(step) \ + case step: \ + kern_small_oc_remain = KerNeonDirectStride2Int8< \ + bias_mode, Op, step, filter_size, 1, DstType>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -690,27 +673,23 @@ void conv_direct_stride2_int8_nchw44_kern( const size_t weight_offset = oc_idx * ic * fh * fw; for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step * - pack_iw_len; + const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * + ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - KerNeonDirectStride2Int8::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, op, ld_dst_oc); + KerNeonDirectStride2Int8< + bias_mode, Op, ow_step, filter_size, 1, DstType>:: + impl(src + src_offset, 
filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op, ld_dst_oc); } if (ow_remain > 0) { - const size_t src_offset = - (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step * - pack_iw_len; + const size_t src_offset = (oh_idx * stride_h * iw + ow_end * stride_w) * + ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, op, ld_dst_oc); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, op, ld_dst_oc); } } } @@ -720,43 +699,43 @@ void conv_direct_stride2_int8_nchw44_kern( namespace int8_direct_nchw44 { template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - conv_direct_stride2_int8_nchw44_kern( + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + conv_direct_stride2_int8_nchw44_kern( src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } }; template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { conv_direct_stride2_2x2_int8_nchw44( src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } }; -#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \ - template struct ConvDirectInt8Nchw44Choose; - -#define GET_OP_PARAM(stride, filter, bias_mode) \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - TypeCvtOp) \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - ReluOp) \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - HSwishOp) \ +#define DO_CONV_KERN_FUN(stride, DstType, filter_size, bias_mode, Op) \ + template struct ConvDirectInt8Nchw44Choose< \ + bias_mode, Op, filter_size, DstType, stride>; + +#define GET_OP_PARAM(stride, filter, bias_mode) \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + TypeCvtOp) \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + ReluOp) \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + HSwishOp) \ DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, NoneOp) #define GET_BIAS_MODE_PARAM(stride, filter) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h index 2eaf00ef..74a11e4f 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_common.h @@ -16,12 +16,13 @@ namespace megdnn { namespace arm_common { namespace { -template +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, 
int oc_block, + int stride> struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op); + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op); }; template diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp index 6537b0f4..f238a1fc 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp @@ -30,93 +30,96 @@ namespace { * @tparam T4 temp regs type */ -template +template < + int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template +template < + int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, + typename T3, typename T4> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { ShiftCalHelper::impl( c, src, weight, temp); } -template +template < + int src_idx, int weight_idx, int c_dim, int stride, typename T, typename T2, + typename T3> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { ShiftCalHelper::impl( c, src, weight); }; -template +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx], - c[0][0], temp[0]); - c[1][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[1][weight_idx], - c[1][0], temp[1]); - c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx], - c[0][1], temp[2]); - c[1][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[1][weight_idx], - c[1][1], temp[3]); - c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx], - c[0][2], temp[0]); - c[1][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[1][weight_idx], - c[1][2], temp[1]); - c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx], - c[0][3], temp[2]); - c[1][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[1][weight_idx], - c[1][3], temp[3]); - - c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx], - c[0][4], temp[0]); - c[1][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[1][weight_idx], - c[1][4], temp[1]); - c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx], - c[0][5], temp[2]); - c[1][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[1][weight_idx], - c[1][5], temp[3]); - c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx], - c[0][6], temp[0]); - c[1][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[1][weight_idx], - c[1][6], temp[1]); - c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx], - c[0][7], temp[2]); - c[1][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[1][weight_idx], - c[1][7], temp[3]); + c[0][0] = vdotq_s32_h( + src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); + c[1][0] = vdotq_s32_h( + src[(0 + src_idx) % 8], weight[1][weight_idx], c[1][0], temp[1]); + c[0][1] = vdotq_s32_h( + src[(1 + src_idx) % 8], 
weight[0][weight_idx], c[0][1], temp[2]); + c[1][1] = vdotq_s32_h( + src[(1 + src_idx) % 8], weight[1][weight_idx], c[1][1], temp[3]); + c[0][2] = vdotq_s32_h( + src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[0]); + c[1][2] = vdotq_s32_h( + src[(2 + src_idx) % 8], weight[1][weight_idx], c[1][2], temp[1]); + c[0][3] = vdotq_s32_h( + src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[2]); + c[1][3] = vdotq_s32_h( + src[(3 + src_idx) % 8], weight[1][weight_idx], c[1][3], temp[3]); + + c[0][4] = vdotq_s32_h( + src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); + c[1][4] = vdotq_s32_h( + src[(4 + src_idx) % 8], weight[1][weight_idx], c[1][4], temp[1]); + c[0][5] = vdotq_s32_h( + src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[2]); + c[1][5] = vdotq_s32_h( + src[(5 + src_idx) % 8], weight[1][weight_idx], c[1][5], temp[3]); + c[0][6] = vdotq_s32_h( + src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[0]); + c[1][6] = vdotq_s32_h( + src[(6 + src_idx) % 8], weight[1][weight_idx], c[1][6], temp[1]); + c[0][7] = vdotq_s32_h( + src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[2]); + c[1][7] = vdotq_s32_h( + src[(7 + src_idx) % 8], weight[1][weight_idx], c[1][7], temp[3]); } static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); }; -template +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = vdotq_s32_h(src[(0 + src_idx) % 8], weight[0][weight_idx], - c[0][0], temp[0]); - c[0][1] = vdotq_s32_h(src[(1 + src_idx) % 8], weight[0][weight_idx], - c[0][1], temp[1]); - c[0][2] = vdotq_s32_h(src[(2 + src_idx) % 8], weight[0][weight_idx], - c[0][2], temp[2]); - c[0][3] = vdotq_s32_h(src[(3 + src_idx) % 8], weight[0][weight_idx], - c[0][3], temp[3]); - c[0][4] = vdotq_s32_h(src[(4 + src_idx) % 8], weight[0][weight_idx], - c[0][4], temp[0]); - c[0][5] = vdotq_s32_h(src[(5 + src_idx) % 8], weight[0][weight_idx], - c[0][5], temp[1]); - c[0][6] = vdotq_s32_h(src[(6 + src_idx) % 8], weight[0][weight_idx], - c[0][6], temp[2]); - c[0][7] = vdotq_s32_h(src[(7 + src_idx) % 8], weight[0][weight_idx], - c[0][7], temp[3]); + c[0][0] = vdotq_s32_h( + src[(0 + src_idx) % 8], weight[0][weight_idx], c[0][0], temp[0]); + c[0][1] = vdotq_s32_h( + src[(1 + src_idx) % 8], weight[0][weight_idx], c[0][1], temp[1]); + c[0][2] = vdotq_s32_h( + src[(2 + src_idx) % 8], weight[0][weight_idx], c[0][2], temp[2]); + c[0][3] = vdotq_s32_h( + src[(3 + src_idx) % 8], weight[0][weight_idx], c[0][3], temp[3]); + c[0][4] = vdotq_s32_h( + src[(4 + src_idx) % 8], weight[0][weight_idx], c[0][4], temp[0]); + c[0][5] = vdotq_s32_h( + src[(5 + src_idx) % 8], weight[0][weight_idx], c[0][5], temp[1]); + c[0][6] = vdotq_s32_h( + src[(6 + src_idx) % 8], weight[0][weight_idx], c[0][6], temp[2]); + c[0][7] = vdotq_s32_h( + src[(7 + src_idx) % 8], weight[0][weight_idx], c[0][7], temp[3]); } static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&); }; template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_height = 1; constexpr int filter_width = 4; @@ -152,12 +155,11 @@ struct 
KerNeonXXs2NchwNchw44 { } }; - template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_height = 2; constexpr int filter_width = 4; @@ -186,8 +188,7 @@ struct KerNeonXXs2NchwNchw44 { cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); load_helper( - dot4_weight, weight_ptr + 1 * filter_width * oc_step, - ld_weight_oc); + dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); load_helper( src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); @@ -202,9 +203,9 @@ struct KerNeonXXs2NchwNchw44 { template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_height = 3; constexpr int filter_width = 4; @@ -233,16 +234,14 @@ struct KerNeonXXs2NchwNchw44 { src, nchw_src_ptr + 0 * iw * pack_iw_len, 0); cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); load_helper( - dot4_weight, weight_ptr + 1 * filter_width * oc_step, - ld_weight_oc); + dot4_weight, weight_ptr + 1 * filter_width * oc_step, ld_weight_oc); load_helper( src, nchw_src_ptr + 1 * iw * pack_iw_len, 0); cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); load_helper( - dot4_weight, weight_ptr + 2 * filter_width * oc_step, - ld_weight_oc); + dot4_weight, weight_ptr + 2 * filter_width * oc_step, ld_weight_oc); load_helper( src, nchw_src_ptr + 2 * iw * pack_iw_len, 0); cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); @@ -256,9 +255,9 @@ struct KerNeonXXs2NchwNchw44 { template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_height = 5; constexpr int filter_width = 8; @@ -280,17 +279,14 @@ struct KerNeonXXs2NchwNchw44 { int8x16_t src[src_reg]; int8x16_t dot4_weight[c_dim][weight_reg]; int16x8_t temp_c[4]; -#define cb(step) \ - load_helper( \ - dot4_weight, weight_ptr + step * filter_width * oc_step, \ - ld_weight_oc); \ - load_helper( \ - src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ - load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ - src, \ - nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ - 0); \ +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ 
cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); UNROLL_CALL_RAW(5, cb); #undef cb @@ -303,9 +299,9 @@ struct KerNeonXXs2NchwNchw44 { template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 1; constexpr int filter_height = 7; constexpr int filter_width = 8; @@ -327,17 +323,14 @@ struct KerNeonXXs2NchwNchw44 { int8x16_t src[src_reg]; int8x16_t dot4_weight[c_dim][weight_reg]; int16x8_t temp_c[4]; -#define cb(step) \ - load_helper( \ - dot4_weight, weight_ptr + step * filter_width * oc_step, \ - ld_weight_oc); \ - load_helper( \ - src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ - cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ - load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ - src, \ - nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, \ - 0); \ +#define cb(step) \ + load_helper( \ + dot4_weight, weight_ptr + step * filter_width * oc_step, ld_weight_oc); \ + load_helper( \ + src, nchw_src_ptr + step * iw * pack_iw_len, 0); \ + cal_helper<0, 0, c_dim, stride>(c, src, dot4_weight, temp_c); \ + load_helper<4, 0, simd_len, 0, Vld1q_s8>( \ + src, nchw_src_ptr + step * iw * pack_iw_len + src_reg * pack_iw_len, 0); \ cal_helper<4, 1, c_dim, stride>(c, src, dot4_weight, temp_c); UNROLL_CALL_RAW(7, cb); @@ -356,9 +349,9 @@ namespace int8_direct_nchw_nchw44 { * pack interleave two adjacent row in filter to one row * */ template <> -void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, - const int ic, const int fh, - const int fw, const int oc) { +void pack_nchw44_weight_for_nchw_conv<1>( + const int8_t* src_ptr, int8_t* dst_ptr, const int ic, const int fh, + const int fw, const int oc) { constexpr int oc_step = 4; const int fw2 = round_up(fw, 4); const int fw_remain = fw2 - fw; @@ -370,8 +363,8 @@ void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, rep_step(oc_idx, oc, oc_step) { int32_t* dst_temp_ptr = reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); - const int32_t* src_temp_ptr = reinterpret_cast( - src_ptr + oc_idx * ic * fh * fw); + const int32_t* src_temp_ptr = + reinterpret_cast(src_ptr + oc_idx * ic * fh * fw); // transpose ic and pad rep(fh_idx, fh) { rep(fw_idx, fw) { @@ -393,8 +386,7 @@ void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, rep_step(idx, oc_step_stride, 16) { int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); - vst1q_s8(trans_dst_temp_ptr + idx, - vqtbl1q_s8(temp, tbl_transpose_4x4)); + vst1q_s8(trans_dst_temp_ptr + idx, vqtbl1q_s8(temp, tbl_transpose_4x4)); } } }; @@ -404,14 +396,11 @@ void pack_nchw44_weight_for_nchw_conv<1>(const int8_t* src_ptr, int8_t* dst_ptr, * pack interleave two adjacent row in src and repeat 4 times, store to one row * */ template <> -void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin, - int8_t* sptr_base, const int ic, - const int pad_top, const int pad_bottom, - const int, const int, const int ih, - const int iw, const int iw2, const int pw, - int8_t* temp_ptr) { - static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3}; +void pack_nchw_src_for_nchw44_conv<1>( + const int8_t* sptr_origin, int8_t* sptr_base, const int ic, const int pad_top, 
+ const int pad_bottom, const int, const int, const int ih, const int iw, + const int iw2, const int pw, int8_t* temp_ptr) { + static uint8_t reorder_idx[16] = {0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3}; uint8x16_t tbl_idx = vld1q_u8(&reorder_idx[0]); constexpr int iw_step = 4; @@ -422,8 +411,7 @@ void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin, rep(ic_idx, ic) { const int8_t* sptr = sptr_origin + ic_idx * ic_stride; memset(sptr_base, 0, - sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * - pack_iw_len); + sizeof(int8_t) * iw2 * (ih + pad_top + pad_bottom) * pack_iw_len); sptr_base += iw2 * pad_top * pack_iw_len; rep(ih_idx, ih) { memset(temp_ptr, 0, iw_with_pad * sizeof(int8_t)); @@ -458,11 +446,10 @@ void pack_nchw_src_for_nchw44_conv<1>(const int8_t* sptr_origin, template struct ConvDiectStrideInt8NchwNchw44 { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, int8_t* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { MEGDNN_MARK_USED_VAR(temp); constexpr int stride = 1; constexpr size_t fh = filter_size; @@ -486,19 +473,17 @@ struct ConvDiectStrideInt8NchwNchw44 { using remain_fun = std::function; + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonXXs2NchwNchw44< \ + bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \ + kern_small_oc_remain = KerNeonXXs2NchwNchw44< \ + bias_mode, Op, step, filter_size, oc_step, stride>::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -510,29 +495,27 @@ struct ConvDiectStrideInt8NchwNchw44 { const size_t weight_offset = oc_idx * ic * fh * fw; for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + + KerNeonXXs2NchwNchw44< + bias_mode, Op, ow_step, filter_size, big_oc_step, stride>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); } if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * 
ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -542,46 +525,42 @@ struct ConvDiectStrideInt8NchwNchw44 { const size_t weight_offset = oc_idx * ic * fh * fw; for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44< + bias_mode, Op, ow_step, filter_size, oc_step, stride>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); } if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, - filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } } }; -#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ - template struct ConvDiectStrideInt8NchwNchw44; - -#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ - INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ - TypeCvtOp) \ - INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ - ReluOp) \ - INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ - HSwishOp) +#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template struct ConvDiectStrideInt8NchwNchw44; + +#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, TypeCvtOp) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, ReluOp) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, HSwishOp) #define INSTANCE_BIAS_MODE_PARAM(stride, filter) \ INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp index 23345c0e..67bf9a15 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s2.cpp @@ -30,44 +30,40 @@ namespace { * @tparam T4 temp regs type */ -template +template < + int src_idx, int weight_idx, int c_dim, typename Func, int stride, typename T, + typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp); static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template 
+template < + int src_idx, int weight_idx, int c_dim, typename FUNC, int stride, typename T, + typename T2, typename T3, typename T4> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight, T4& temp) { - ShiftCalHelper::impl(c, src, weight, temp); + ShiftCalHelper::impl( + c, src, weight, temp); } -template +template < + int src_idx, int weight_idx, int c_dim, typename FUNC, int stride, typename T, + typename T2, typename T3> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper::impl( + c, src, weight); }; -template +template < + int src_idx, int weight_idx, typename Func, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], - temp[0]); - c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], - temp[1]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], - temp[2]); - c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], - temp[3]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], - temp[0]); - c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], - temp[1]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], - temp[2]); - c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], - temp[3]); + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], temp[0]); + c[1][0] = Func::impl(src[0 + src_idx], weight[1][weight_idx], c[1][0], temp[1]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], temp[2]); + c[1][1] = Func::impl(src[1 + src_idx], weight[1][weight_idx], c[1][1], temp[3]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], temp[0]); + c[1][2] = Func::impl(src[2 + src_idx], weight[1][weight_idx], c[1][2], temp[1]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], temp[2]); + c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3], temp[3]); } static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); @@ -80,18 +76,15 @@ struct ShiftCalHelper { c[1][3] = Func::impl(src[3 + src_idx], weight[1][weight_idx], c[1][3]); } }; -template +template < + int src_idx, int weight_idx, typename Func, typename T, typename T2, + typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight, T4& temp) { - c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], - temp[0]); - c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], - temp[2]); - c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], - temp[0]); - c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], - temp[2]); + c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0], temp[0]); + c[0][1] = Func::impl(src[1 + src_idx], weight[0][weight_idx], c[0][1], temp[2]); + c[0][2] = Func::impl(src[2 + src_idx], weight[0][weight_idx], c[0][2], temp[0]); + c[0][3] = Func::impl(src[3 + src_idx], weight[0][weight_idx], c[0][3], temp[2]); } static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { c[0][0] = Func::impl(src[0 + src_idx], weight[0][weight_idx], c[0][0]); @@ -129,9 +122,9 @@ struct ShiftCalHelper { **/ template struct KerNeonXXs2NchwNchw44 { - static void 
impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8}; constexpr int stride = 2; @@ -152,14 +145,13 @@ struct KerNeonXXs2NchwNchw44 { for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { - const int8_t* nchw_src_ptr = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; int8x16_t src[6]; int8x16_t dot4_weight[c_dim][3]; int16x8_t temp_c[4]; - load_helper<3, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_dot4_weight_oc); + load_helper<3, 0, 16, c_dim, Vld1q_s8>( + dot4_weight, weight_ptr, ld_dot4_weight_oc); load_helper<6, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( c, src, dot4_weight, temp_c); @@ -172,38 +164,34 @@ struct KerNeonXXs2NchwNchw44 { int8x8_t dot2_weight[c_dim][1]; load_helper<1, 3 * 16, 8, c_dim, Vld1_s8>( dot2_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); + load_helper<4, 3 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, 0); cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( c, src_dot2, dot2_weight, temp_c); weight_ptr += filter_size * pack_iw_len * fh_step; } - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - 6 * iw * ic_step * pack_iw_len; + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + 6 * iw * ic_step * pack_iw_len; int8x8_t dot2_weight[c_dim][3]; int16x8_t temp_c[4]; int8x8_t src_dot2[6]; uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<3, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, - 0, tbl); - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - cal_helper<2, 2, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); + load_helper<3, 0, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper_x<6, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, 0, tbl); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + cal_helper<2, 2, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); int16x8_t dot1_weight[c_dim][1]; int16x8_t src_dot1[4]; load_helper<1, 3 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( dot1_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, - dot1_weight); + load_helper<4, 3 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, dot1_weight); weight_ptr += filter_size * pack_iw_len; } store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); @@ -212,9 +200,9 @@ struct KerNeonXXs2NchwNchw44 { #if MEGDNN_AARCH64 template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, 
int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8}; uint8x16_t vtbl = vld1q_u8(src_idx_buffer); @@ -244,8 +232,7 @@ struct KerNeonXXs2NchwNchw44 { const int8_t* weight_ptr_oc = weight_ptr + ld_dot4_weight_oc; const int8_t* nchw_src_ptr_last_line = - src_ptr + ic_idx * ic_stride + - 6 * iw * ic_step * pack_iw_len; + src_ptr + ic_idx * ic_stride + 6 * iw * ic_step * pack_iw_len; /** * r0-r7 c * r24-r31 temp @@ -652,22 +639,20 @@ struct KerNeonXXs2NchwNchw44 { "smlal %[c13].4s, v12.4h, v19.4h\n" : - [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), - [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), - [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), + [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), [c01] "+w"(c[0][1]), + [c11] "+w"(c[1][1]), [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), - [nchw_src_ptr] "+r"(nchw_src_ptr), - [weight_ptr] "+r"(weight_ptr), + [nchw_src_ptr] "+r"(nchw_src_ptr), [weight_ptr] "+r"(weight_ptr), [weight_ptr_oc] "+r"(weight_ptr_oc) : [vtbl] "w"(vtbl), [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), [src_step] "r"(src_step), [weight_step] "r"(weight_step), [weight_step_small] "r"(weight_step_small) - : "x5", "x6", "x7", "x8", "x9", "x10", "v8", "v9", "v10", - "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", - "v19", "v20", "v21", "v24", "v25", "v26", "v27", "v28", - "v29", "v30", "v31", "cc", "memory"); + : "x5", "x6", "x7", "x8", "x9", "x10", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "cc", "memory"); } store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); } @@ -675,9 +660,9 @@ struct KerNeonXXs2NchwNchw44 { #endif template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 2; constexpr int filter_size = 5; static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, @@ -697,14 +682,13 @@ struct KerNeonXXs2NchwNchw44 { for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { for (int fh_idx = 0; fh_idx < fh_end; fh_idx += ih_step) { - const int8_t* nchw_src_ptr = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * pack_iw_len; + const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * pack_iw_len; int8x16_t src[5]; int8x16_t dot4_weight[c_dim][2]; int16x8_t temp_c[4]; - load_helper<2, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_dot4_weight_oc); + load_helper<2, 0, 16, c_dim, Vld1q_s8>( + dot4_weight, weight_ptr, ld_dot4_weight_oc); load_helper<5, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( c, src, dot4_weight, temp_c); @@ -715,38 +699,34 @@ struct KerNeonXXs2NchwNchw44 { int8x8_t dot2_weight[c_dim][1]; load_helper<1, 2 * 16, 8, c_dim, Vld1_s8>( dot2_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); + load_helper<4, 2 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, 
0); cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( c, src_dot2, dot2_weight, temp_c); weight_ptr += filter_size * pack_iw_len * ih_step; } - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - fh_end * iw * ic_step * pack_iw_len; + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + fh_end * iw * ic_step * pack_iw_len; int8x8_t dot2_weight[c_dim][2]; int16x8_t temp_c[4]; int8x8_t src_dot2[5]; uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<2, 0, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_dot4_weight_oc); - load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, - 0, tbl); + load_helper<2, 0, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_dot4_weight_oc); + load_helper_x<5, 0, 16, 0, Vldq_tbl_low_s8>(src_dot2, nchw_src_ptr, 0, tbl); - cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); - cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>(c, src_dot2, - dot2_weight, temp_c); + cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); + cal_helper<1, 1, c_dim, Vdot2_s32_h, stride>( + c, src_dot2, dot2_weight, temp_c); int16x8_t dot1_weight[c_dim][1]; int16x8_t src_dot1[4]; load_helper<1, 2 * 8, 8, c_dim, Vldq_dup_4s8_8s16>( dot1_weight, weight_ptr, ld_dot4_weight_oc); - load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); + load_helper<4, 2 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, - dot1_weight); + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, dot1_weight); weight_ptr += filter_size * pack_iw_len; } store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); @@ -769,9 +749,9 @@ struct KerNeonXXs2NchwNchw44 { **/ template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int stride = 2; constexpr int filter_size = 3; static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, @@ -795,8 +775,8 @@ struct KerNeonXXs2NchwNchw44 { int8x16_t src[4]; int8x16_t dot4_weight[c_dim][1]; int16x8_t temp_c[4]; - load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_weight_oc); + load_helper<1, 0, 16, c_dim, Vld1q_s8>( + dot4_weight, weight_ptr, ld_weight_oc); load_helper<4, 0, 16, 0, Vld1q_s8>(src, nchw_src_ptr, 0); cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>( c, src, dot4_weight, temp_c); @@ -805,21 +785,20 @@ struct KerNeonXXs2NchwNchw44 { int8x8_t dot2_weight[c_dim][1]; load_helper<1, 1 * 16, 8, c_dim, Vld1_s8>( dot2_weight, weight_ptr, ld_weight_oc); - load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, - 0); + load_helper<4, 1 * 16, 16, 0, Vld1_s8>(src_dot2, nchw_src_ptr, 0); cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( c, src_dot2, dot2_weight, temp_c); } // last line { - const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride + - 2 * iw * ic_step * pack_iw_len; + const int8_t* nchw_src_ptr = + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step * pack_iw_len; int16x8_t temp_c[4]; int8x8_t src_dot2[4]; int8x8_t dot2_weight[c_dim][1]; uint8x16_t tbl = vld1q_u8(src_idx_buffer); - load_helper<1, 24, 8, c_dim, Vld1_s8>(dot2_weight, weight_ptr, - ld_weight_oc); + load_helper<1, 24, 8, c_dim, Vld1_s8>( + dot2_weight, weight_ptr, ld_weight_oc); 
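[Editor's note] A scalar sketch, with hypothetical helper names, of what the Vdotq_s32_h, Vdot2_s32_h and Vmlal_s16 paths in the kernels above accumulate per output lane: int8 inputs are widened to 16 bits before multiplying, and groups of four, two, or one adjacent products are folded into a 32-bit accumulator (the pairwise grouping mirrors the "dot: (<0,0> + <0,3>) + (<0,1> + <0,2>)" layout comment earlier in the patch). This also suggests why a leftover filter tap is handled with a dot2 plus dot1 (vmlal) tail rather than the dot4 path; it is a reference illustration, not the library's implementation.

#include <cstdint>

// Accumulate four widened int8 products, summed in the pairwise order the
// layout comment describes.
inline int32_t dot4_s8(const int8_t* a, const int8_t* b, int32_t acc) {
    int32_t p0 = int16_t(a[0]) * int16_t(b[0]);
    int32_t p1 = int16_t(a[1]) * int16_t(b[1]);
    int32_t p2 = int16_t(a[2]) * int16_t(b[2]);
    int32_t p3 = int16_t(a[3]) * int16_t(b[3]);
    return acc + (p0 + p3) + (p1 + p2);
}

// Two-product tail, analogous to the Vdot2_s32_h path.
inline int32_t dot2_s8(const int8_t* a, const int8_t* b, int32_t acc) {
    return acc + int16_t(a[0]) * int16_t(b[0]) + int16_t(a[1]) * int16_t(b[1]);
}

// Single-product tail, analogous to the vmlal_s16 path.
inline int32_t dot1_s8(int8_t a, int8_t b, int32_t acc) {
    return acc + int16_t(a) * int16_t(b);
}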
load_helper_x<4, 0, 16, 0, Vldq_tbl_low_s8>( src_dot2, nchw_src_ptr, 0, tbl); cal_helper<0, 0, c_dim, Vdot2_s32_h, stride>( @@ -828,10 +807,9 @@ struct KerNeonXXs2NchwNchw44 { int16x8_t src_dot1[4]; load_helper<1, 32, 8, c_dim, Vldq_dup_4s8_8s16>( dot1_weight, weight_ptr, ld_weight_oc); - load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>(src_dot1, - nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, - dot1_weight); + load_helper<4, 1 * 16, 16, 0, Vld1_dup_s8_s16>( + src_dot1, nchw_src_ptr, 0); + cal_helper<0, 0, c_dim, Vmlal_s16, stride>(c, src_dot1, dot1_weight); weight_ptr += filter_size * filter_size * pack_iw_len; } } @@ -842,9 +820,9 @@ struct KerNeonXXs2NchwNchw44 { #if MEGDNN_AARCH64 template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int filter_size = 3; static const uint8_t src_idx_buffer[16] = {0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8, 0, 8}; @@ -867,8 +845,7 @@ struct KerNeonXXs2NchwNchw44 { for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; const int8_t* nchw_src_ptr_last_line = - src_ptr + ic_idx * ic_stride + - 2 * iw * ic_step * pack_iw_len; + src_ptr + ic_idx * ic_stride + 2 * iw * ic_step * pack_iw_len; const int8_t* weight_ptr_oc = weight_ptr + ld_weight_oc; /** * r0-r7 c @@ -980,31 +957,28 @@ struct KerNeonXXs2NchwNchw44 { "smlal %[c13].4s, v15.4h, v19.4h\n" : - [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), - [c01] "+w"(c[0][1]), [c11] "+w"(c[1][1]), - [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), + [c00] "+w"(c[0][0]), [c10] "+w"(c[1][0]), [c01] "+w"(c[0][1]), + [c11] "+w"(c[1][1]), [c02] "+w"(c[0][2]), [c12] "+w"(c[1][2]), [c03] "+w"(c[0][3]), [c13] "+w"(c[1][3]), - [weight_ptr] "+r"(weight_ptr), - [weight_ptr_oc] "+r"(weight_ptr_oc) + [weight_ptr] "+r"(weight_ptr), [weight_ptr_oc] "+r"(weight_ptr_oc) : [vtbl] "w"(vtbl), [nchw_src_ptr] "r"(nchw_src_ptr), [nchw_src_ptr_last_line] "r"(nchw_src_ptr_last_line), [weight_step] "r"(weight_step) - : "x5", "x6", "x7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "v16", "v17", "v18", "v19", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); + : "x5", "x6", "x7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "cc", "memory"); } store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); } }; #endif -template +template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { constexpr int filter_size = 2; constexpr int oc_step = 4; constexpr int loop_ic_step = 1; @@ -1022,31 +996,30 @@ struct KerNeonXXs2NchwNchw44 { int8x16_t src[4]; int8x16_t dot4_weight[c_dim][1]; int16x8_t temp_c[4]; - load_helper<1, 0, 16, c_dim, Vld1q_s8>(dot4_weight, weight_ptr, - ld_weight_oc); + load_helper<1, 0, 16, c_dim, Vld1q_s8>( + dot4_weight, weight_ptr, ld_weight_oc); load_helper<4, 0, 16, 0, Vld1q_s8>(src, 
nchw_src_ptr, 0); - cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, - temp_c); + cal_helper<0, 0, c_dim, Vdotq_s32_h, stride>(c, src, dot4_weight, temp_c); weight_ptr += oc_step * filter_size * filter_size; } store_ocx_ow4_remain_static(c, op, dst_ptr, ld_dst_oc); } }; -template +template struct KerNeonXXs2NchwNchw44 { - static void impl(const int8_t*, const int8_t*, const int32_t*, int8_t*, int, - int, int, int, const Op&) { + static void impl( + const int8_t*, const int8_t*, const int32_t*, int8_t*, int, int, int, int, + const Op&) { megdnn_assert(0, "not impl nchw_nchw44 1x1 s2"); } }; enum PACK_MODE { NO_PAD = 0, FIRST_PAD = 1, LAST_PAD = 2 }; template -MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, - int left_pad, int right_pad, - const int iw) { +MEGDNN_ALWAYS_INLINE void pack_src_one_line( + const int8_t* inptr, int8_t* outptr, int left_pad, int right_pad, + const int iw) { const int8_t* src_row_0 = inptr; const int8_t* src_row_1 = inptr + iw; constexpr int combine_row = 2; @@ -1068,17 +1041,17 @@ MEGDNN_ALWAYS_INLINE void pack_src_one_line(const int8_t* inptr, int8_t* outptr, row0 = vdupq_n_s8(0); } int8x16x2_t pack_rows = vzipq_s8(row0, row1); -#define STORE_8S8(step) \ - vst1_s8(outptr + step * 8, \ - vreinterpret_s8_s16(vdup_laneq_s16( \ - vreinterpretq_s16_s8(pack_rows.val[0]), step))); +#define STORE_8S8(step) \ + vst1_s8(outptr + step * 8, \ + vreinterpret_s8_s16( \ + vdup_laneq_s16(vreinterpretq_s16_s8(pack_rows.val[0]), step))); UNROLL_CALL_RAW(8, STORE_8S8); #undef STORE_8S8 -#define STORE_8S8(step) \ - vst1_s8(outptr + out_gap + step * 8, \ - vreinterpret_s8_s16(vdup_laneq_s16( \ - vreinterpretq_s16_s8(pack_rows.val[1]), step))); +#define STORE_8S8(step) \ + vst1_s8(outptr + out_gap + step * 8, \ + vreinterpret_s8_s16( \ + vdup_laneq_s16(vreinterpretq_s16_s8(pack_rows.val[1]), step))); UNROLL_CALL_RAW(8, STORE_8S8); #undef STORE_8S8 @@ -1109,12 +1082,10 @@ namespace int8_direct_nchw_nchw44 { * pack interleave two adjacent row in src and repeat 4 times, store to one row * */ template <> -void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr, - const int ic, const int top_pad, - const int bottom_pad, const int left_pad, - const int right_pad, const int ih, - const int iw, const int, const int, - int8_t*) { +void pack_nchw_src_for_nchw44_conv<2>( + const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad, + const int bottom_pad, const int left_pad, const int right_pad, const int ih, + const int iw, const int, const int, int8_t*) { constexpr int src_expand = 4; constexpr int oh_step = 2; const int oh = ih + top_pad + bottom_pad; @@ -1127,16 +1098,16 @@ void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr, if (top_pad - oh_idx >= oh_step) { memset(outptr, 0, oh_step * ow * sizeof(int8_t)); } else { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); + pack_src_one_line( + inptr, outptr, left_pad, right_pad, iw); inptr += iw; } outptr += oh_step * ow; } for (; oh_idx < oh_end; oh_idx += oh_step) { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); + pack_src_one_line( + inptr, outptr, left_pad, right_pad, iw); inptr += oh_step * iw; outptr += oh_step * ow; } @@ -1146,8 +1117,8 @@ void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr, if (last_pad >= 0) { memset(outptr, 0, oh_step * ow * sizeof(int8_t)); } else { - pack_src_one_line(inptr, outptr, left_pad, - right_pad, iw); + pack_src_one_line( + inptr, outptr, left_pad, right_pad, 
iw); inptr += iw; } outptr += oh_step * ow; @@ -1160,9 +1131,9 @@ void pack_nchw_src_for_nchw44_conv<2>(const int8_t* inptr, int8_t* outptr, * pack interleave two adjacent row in filter to one row * */ template <> -void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, - const int ic, const int fh, - const int fw, const int oc) { +void pack_nchw44_weight_for_nchw_conv<2>( + const int8_t* inptr, int8_t* outptr, const int ic, const int fh, const int fw, + const int oc) { constexpr int oc_step = 4; constexpr int ic_step = 2; constexpr int fh_step = 2; @@ -1185,9 +1156,8 @@ void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { const int fh_offset = fh_idx * fw * filter_stride; for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; + const int8_t* filter_ptr = + inptr + fh_offset + fw_idx * filter_stride + ic_offset; int8x8_t row_0 = vld1_s8(filter_ptr); int8x8_t row_1 = vld1_s8(filter_ptr + fw * filter_stride); int8x16_t combine_row = vcombine_s8(row_0, row_1); @@ -1201,9 +1171,8 @@ void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, if (fh_remain > 0) { const int fh_offset = fh_end * fw * filter_stride; for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; + const int8_t* filter_ptr = + inptr + fh_offset + fw_idx * filter_stride + ic_offset; int8x8_t row_0 = vld1_s8(filter_ptr); int8x8_t row_1 = vld1_s8(filter_ptr + filter_stride); int8x16_t combine_row = vcombine_s8(row_0, row_1); @@ -1214,14 +1183,11 @@ void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, output_ic1 += 8; } if (fw_remain > 0) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_end * filter_stride + - ic_offset; + const int8_t* filter_ptr = + inptr + fh_offset + fw_end * filter_stride + ic_offset; int8x8_t row_0 = vld1_s8(filter_ptr); - vst1_lane_s32((int32_t*)output_ic0, - vreinterpret_s32_s8(row_0), 0); - vst1_lane_s32((int32_t*)output_ic1, - vreinterpret_s32_s8(row_0), 1); + vst1_lane_s32((int32_t*)output_ic0, vreinterpret_s32_s8(row_0), 0); + vst1_lane_s32((int32_t*)output_ic1, vreinterpret_s32_s8(row_0), 1); output_ic0 += 4; output_ic1 += 4; } @@ -1233,9 +1199,8 @@ void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, for (int fh_idx = 0; fh_idx < fh_end; fh_idx += fh_step) { const int fh_offset = fh_idx * fw * filter_stride; for (int fw_idx = 0; fw_idx < fw; ++fw_idx) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; + const int8_t* filter_ptr = + inptr + fh_offset + fw_idx * filter_stride + ic_offset; int8x8_t row_0 = vreinterpret_s8_s32( vld1_dup_s32((const int32_t*)(filter_ptr))); int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( @@ -1249,22 +1214,20 @@ void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, if (fh_remain > 0) { const int fh_offset = fh_end * fw * filter_stride; for (int fw_idx = 0; fw_idx < fw_end; fw_idx += fw_step) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_idx * filter_stride + - ic_offset; + const int8_t* filter_ptr = + inptr + fh_offset + fw_idx * filter_stride + ic_offset; int8x8_t row_0 = vreinterpret_s8_s32( vld1_dup_s32((const int32_t*)(filter_ptr))); - int8x8_t row_1 = vreinterpret_s8_s32(vld1_dup_s32( - (const int32_t*)(filter_ptr + filter_stride))); + 
int8x8_t row_1 = vreinterpret_s8_s32( + vld1_dup_s32((const int32_t*)(filter_ptr + filter_stride))); int8x16_t combine_row = vcombine_s8(row_0, row_1); combine_row = vqtbl1q_s8(combine_row, ic2_idx_h); vst1_s8(output_ic0, vget_low_s8(combine_row)); output_ic0 += 8; } if (fw_remain > 0) { - const int8_t* filter_ptr = inptr + fh_offset + - fw_end * filter_stride + - ic_offset; + const int8_t* filter_ptr = + inptr + fh_offset + fw_end * filter_stride + ic_offset; *(int32_t*)(output_ic0) = *(const int32_t*)(filter_ptr); output_ic0 += 4; } @@ -1277,16 +1240,14 @@ void pack_nchw44_weight_for_nchw_conv<2>(const int8_t* inptr, int8_t* outptr, template struct ConvDiectStrideInt8NchwNchw44 { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, int8_t* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { MEGDNN_MARK_USED_VAR(temp); constexpr size_t stride = 2; constexpr size_t fh = filter_size; - constexpr size_t fw = - stride == 2 ? filter_size : (filter_size + 3) / 4 * 4; + constexpr size_t fw = stride == 2 ? filter_size : (filter_size + 3) / 4 * 4; constexpr size_t ic_step = 1; constexpr size_t big_oc_step = 8; constexpr size_t oc_step = 4; @@ -1306,19 +1267,17 @@ struct ConvDiectStrideInt8NchwNchw44 { using remain_fun = std::function; + const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op& op)>; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ - kern_small_oc_remain = \ - KerNeonXXs2NchwNchw44::impl; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = KerNeonXXs2NchwNchw44< \ + bias_mode, Op, step, filter_size, big_oc_step, stride>::impl; \ + kern_small_oc_remain = KerNeonXXs2NchwNchw44< \ + bias_mode, Op, step, filter_size, oc_step, stride>::impl; \ break; UNROLL_CALL_RAW(4, cb); @@ -1331,28 +1290,26 @@ struct ConvDiectStrideInt8NchwNchw44 { const size_t weight_offset = oc_idx * ic * fh * fw; for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44< + bias_mode, Op, 0, filter_size, big_oc_step, stride>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); } if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, 
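The switch over ow_remain above turns a runtime remainder into a compile-time template argument by binding a function pointer to the matching kernel instantiation. A stripped-down sketch of that dispatch pattern, with a hypothetical Kernel<> standing in for KerNeonXXs2NchwNchw44:

#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the remainder kernels: remain_w is a template
// parameter, so the loop below is fully unrolled at compile time.
template <int remain_w>
struct Kernel {
    static void impl(const int8_t* src, int8_t* dst) {
        for (int i = 0; i < remain_w; ++i)
            dst[i] = src[i];
    }
};

using kern_fn = void (*)(const int8_t*, int8_t*);

// Runtime remainder -> compile-time kernel, mirroring the cb/UNROLL_CALL_RAW
// switch above.
inline kern_fn select_remain_kernel(size_t ow_remain) {
    switch (ow_remain) {
        case 0: return Kernel<0>::impl;
        case 1: return Kernel<1>::impl;
        case 2: return Kernel<2>::impl;
        case 3: return Kernel<3>::impl;
        default: assert(!"unexpected ow_remain"); return nullptr;
    }
}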
ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } @@ -1361,46 +1318,42 @@ struct ConvDiectStrideInt8NchwNchw44 { const size_t weight_offset = oc_idx * ic * fh * fw; for (size_t oh_idx = 0; oh_idx < oh; oh_idx += oh_step) { for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_idx * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_idx) * oc_step; - KerNeonXXs2NchwNchw44::impl(src + src_offset, - filter + weight_offset, - bias + oc_idx, - dst + dst_offset, ic, - ih, iw, ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; + KerNeonXXs2NchwNchw44< + bias_mode, Op, 0, filter_size, oc_step, stride>:: + impl(src + src_offset, filter + weight_offset, + bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc, + op); } if (ow_remain > 0) { - const size_t src_offset = (oh_idx * stride_h * iw + - ow_end * stride_w * ih_step) * - ic_step * pack_iw_len; - const size_t dst_offset = oc_idx * img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, - filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, - ld_dst_oc, op); + const size_t src_offset = + (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) * + ic_step * pack_iw_len; + const size_t dst_offset = + oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op); } } } } }; -#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ - template struct ConvDiectStrideInt8NchwNchw44; - -#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ - INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ - TypeCvtOp) \ - INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ - ReluOp) \ - INSTANCE_CONV_KERN_FUN(stride, filter, bias_mode, \ - HSwishOp) +#define INSTANCE_CONV_KERN_FUN(stride, filter_size, bias_mode, Op) \ + template struct ConvDiectStrideInt8NchwNchw44; + +#define INSTANCE_OP_PARAM(stride, filter, bias_mode) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, TypeCvtOp) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, ReluOp) \ + INSTANCE_CONV_KERN_FUN( \ + stride, filter, bias_mode, HSwishOp) #define INSTANCE_BIAS_MODE_PARAM(stride, filter) \ INSTANCE_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp index dcef72d4..27d3cae4 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_algo.cpp @@ -22,10 +22,9 @@ using namespace megdnn; using namespace arm_common; using conv_fun = std::function; + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range)>; MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44) static void get_rectified_size( @@ -52,8 +51,7 @@ 
static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { int IH2, IW2; get_rectified_size(param, IH2, IW2); if (group == 1) { - size_t src_size = - batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; + size_t src_size = batch * group * IC * IH2 * IW2 * sizeof(int8_t) * src_expand; size_t weight_size = group * OC * IC * FH * FW * sizeof(int8_t); return {nullptr, {src_size, weight_size}}; } else { @@ -64,10 +62,9 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { } }; -static void copy_padding_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +static void copy_padding_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { int IH = kern_param.isz[0]; int IW = kern_param.isz[1]; int IC = kern_param.filter_meta.icpg; @@ -99,11 +96,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); //! copy to sptr_base to eliminate padding effect - int8_t* sptr_base = static_cast(bundle.get(0)) + - (workspace_batch_id * GROUP * padding_group_size + - workspace_group_id * padding_group_size + - workspace_ic * IH2 * IW2) * - expend_element; + int8_t* sptr_base = + static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + workspace_ic * IH2 * IW2) * + expend_element; size_t nr_ic = workspace_ic_block; if (GROUP > 1) { nr_ic = IC; @@ -125,13 +122,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, } } -template -static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids, - const CpuNDRange& ncb_range) { +template +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t FH = kern_param.filter_meta.spatial[0]; @@ -141,13 +136,11 @@ static void do_conv_kern(const WorkspaceBundle& bundle, size_t GROUP = kern_param.filter_meta.group; int IH2, IW2; get_rectified_size(kern_param, IH2, IW2); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; //! 
if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) Op op(1.f, 4.f); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } @@ -167,27 +160,25 @@ static void do_conv_kern(const WorkspaceBundle& bundle, if (oc_id == (oc_block_num - 1)) { oc_block = OC - oc_id * nr_pack_per_step * pack_c; } - megdnn_assert(oc_block % pack_c == 0, - "oc must be devisible by 4, but oc = %zu", oc_block); + megdnn_assert( + oc_block % pack_c == 0, "oc must be devisible by 4, but oc = %zu", + oc_block); const int8_t* sptr = static_cast(bundle.get(0)) + workspace_batch_id * GROUP * padding_group_size * src_expand_size + workspace_group_id * padding_group_size * src_expand_size; - const int8_t* fptr = - kern_param.filter(group_id) + oc_idx * FH * FW * IC; + const int8_t* fptr = kern_param.filter(group_id) + oc_idx * FH * FW * IC; DstType* dst = reinterpret_cast( kern_param.dst(batch_id, group_id, oc_idx)); - const int32_t* bptr = - kern_param.bias(batch_id, group_id) + oc_idx; + const int32_t* bptr = kern_param.bias(batch_id, group_id) + oc_idx; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * OC * IC * FH * FW + oc_idx * IC * FH * FW; - int8_direct_nchw44::nchw44_pack_filter(fptr, packed_weight, - oc_block / 4 * IC / 4 * FH * FW); - int8_direct_nchw44::conv_direct_int8_nchw44( - sptr, packed_weight, bptr, nullptr, static_cast(dst), - oc_block, IC, IH2, IW2, OH, OW, op); + int8_direct_nchw44::nchw44_pack_filter( + fptr, packed_weight, oc_block / 4 * IC / 4 * FH * FW); + int8_direct_nchw44::conv_direct_int8_nchw44( + sptr, packed_weight, bptr, nullptr, static_cast(dst), oc_block, + IC, IH2, IW2, OH, OW, op); } bool ConvBiasImpl::AlgoS8DirectNCHW44::usable( @@ -206,33 +197,32 @@ bool ConvBiasImpl::AlgoS8DirectNCHW44::usable( param.dst_type.enumv() == DTypeEnum::QuantizedS32))) && (fm.format == param::Convolution::Format::NCHW44) && (oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && - (fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && - (fh == 2 || fh == 3 || fh == 5 || fh == 7) && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == fm.stride[1] && (fm.stride[0] == 2 || fm.stride[0] == 1) && + fh == fw && (fh == 2 || fh == 3 || fh == 5 || fh == 7) && param.bias_mode != BiasMode::BIAS; return avaible; } bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred( - const NCBKernSizeParam& param) const { + const NCBKernSizeParam& param) const { // TODO: benchmark and fix MEGDNN_MARK_USED_VAR(param); return false; } size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace( - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, - midout_iv("AlgoS8DirectNCHW44::get_workspace"_hash)) { + const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8_nchw44, + midout_iv("AlgoS8DirectNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; size_t N = param.n; @@ -247,47 +237,49 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( bool need_post_process = 
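The Op constructed above from (scale_bias, scale_dst) requantizes the int32 accumulator into the qint8 output. A sketch of that conversion, under the usual convention that the accumulator and the bias carry the scale scale_src * scale_filter while the destination carries scale_dst (the exact rounding of the real TypeCvtOp is not shown in this hunk):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Sketch: dequantize the s32 accumulator, requantize to the dst scale, and
// saturate to int8.
inline int8_t requantize_s32_to_s8(int32_t acc, float scale_bias, float scale_dst) {
    float real = static_cast<float>(acc) * scale_bias;
    int32_t q = static_cast<int32_t>(std::lround(real / scale_dst));
    q = std::min(127, std::max(-128, q));
    return static_cast<int8_t>(q);
}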
param.dst_type.enumv() == DTypeEnum::QuantizedS8; // NOTE: remain_w is not used to gen hash of midout for compatible with changing // shape runtime -#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44, \ - midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8_nchw44, \ + midout_iv(#stride #dst_type #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(stride, filter, bias_mode) \ - if (need_post_process) { \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(stride, dt_qint8, filter, bias_mode, \ - \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0, "no supported noline mode"); \ - break; \ - } \ - } else { \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(stride, dt_int32, filter, bias_mode, \ - NoneOp) \ - break; \ - default: \ - megdnn_assert( \ - 0, \ - "only support IDENTITY mode when dst is not qint8"); \ - break; \ - } \ +#define GET_OP_PARAM(stride, filter, bias_mode) \ + if (need_post_process) { \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN( \ + stride, dt_qint8, filter, bias_mode, \ + \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0, "no supported noline mode"); \ + break; \ + } \ + } else { \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + stride, dt_int32, filter, bias_mode, NoneOp) \ + break; \ + default: \ + megdnn_assert(0, "only support IDENTITY mode when dst is not qint8"); \ + break; \ + } \ } #define GET_BIAS_MODE_PARAM(stride, filter) \ @@ -330,8 +322,9 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( DISPATCH_CONV_KERN(2); break; default: - megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", - param.filter_meta.stride[0]) + megdnn_throw(ssprintf( + "Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) .c_str()); break; } @@ -353,11 +346,11 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( } if (group == 1) { CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; - auto copy_padding = [wbundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [wbundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - copy_padding_kern(wbundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(wbundle, kern_param, ncb_index, ncb_index.ndrange_id); }; constexpr size_t pack_ic = 4; ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}}); @@ -365,8 +358,8 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( const NCBKernParam& kern_param, 
const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); + do_conv_fun( + wbundle, kern_param, ncb_index, ncb_index.ndrange_id, ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); } else { @@ -375,10 +368,11 @@ ConvBiasImpl::AlgoS8DirectNCHW44::dispatch_kerns( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - copy_padding_kern(wbundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}); - do_conv_fun(wbundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}, ncb_range); + copy_padding_kern( + wbundle, kern_param, ncb_index, {0, ncb_index.thread_id, 0}); + do_conv_fun( + wbundle, kern_param, ncb_index, {0, ncb_index.thread_id, 0}, + ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); } diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h index 71bd9c0a..3e5cea4d 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw44_kern.h @@ -35,8 +35,7 @@ low 64 bit <0, 0> <0, 1> <1, 2> <1, 3> | <2, 0> <2, 1> <3, 2> <3, 3> --------------------------------------------------------------------- high 64 bit <0, 3> <0, 2> <1, 1> <1, 0> | <2, 3> <2, 2> <3, 1> <3, 0> **/ -static inline void nchw44_pack_filter(const int8_t* src, int8_t* dst, - int length) { +static inline void nchw44_pack_filter(const int8_t* src, int8_t* dst, int length) { static const uint8_t weight_idx_buffer[16] = {0, 4, 9, 13, 2, 6, 11, 15, 12, 8, 5, 1, 14, 10, 7, 3}; constexpr int simd_len = 16; @@ -69,26 +68,23 @@ static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { } } -template +template < + BiasMode bias_mode, typename Op, int filter_size, typename DstType, int stride> struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op); + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op); }; -template -void conv_direct_int8_nchw44(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, DstType* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op) { - ConvDirectInt8Nchw44Choose::impl(src, filter, bias, temp, dst, oc, - ic, ih, iw, oh, ow, op); +template < + BiasMode bias_mode, typename Op, int filter_size, typename DstType, int stride> +void conv_direct_int8_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + DstType* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op) { + ConvDirectInt8Nchw44Choose::impl( + src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } } // namespace int8_direct_nchw44 diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp index be29f345..fb82f210 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_algo.cpp @@ -23,10 
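nchw44_pack_filter applies a fixed 16-byte permutation (the weight_idx_buffer table above) to each group of 16 filter bytes; this is exactly what vqtbl1q_s8 computes, assuming the elided body feeds each 16-byte block through the table. A scalar equivalent:

#include <cstdint>

// dst[i] = src[idx[i]] for one 16-byte block, matching weight_idx_buffer.
inline void permute16_ref(const int8_t* src, int8_t* dst) {
    static const uint8_t idx[16] = {0,  4, 9, 13, 2,  6,  11, 15,
                                    12, 8, 5, 1,  14, 10, 7,  3};
    for (int i = 0; i < 16; ++i)
        dst[i] = src[idx[i]];
}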
+23,9 @@ using namespace megdnn; using namespace arm_common; using conv_fun = std::function; + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range)>; MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw_nchw44) static void get_rectified_size( @@ -68,8 +67,7 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { const size_t src_expand = stride_h == 2 ? 4 : 16; get_rectified_size(param, ih2, iw2, oh2, ow2); megdnn_assert(group == 1, "only support group == 1 now"); - size_t src_size = - batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand; + size_t src_size = batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand; size_t weight_size = group * oc * ic * fh * fw * sizeof(int8_t); size_t tmp_size = 0; if (stride_h == 1) { @@ -79,10 +77,9 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { return {nullptr, {src_size, weight_size, tmp_size * param.nr_threads}}; }; -static void copy_padding_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +static void copy_padding_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { int ih = kern_param.isz[0]; int iw = kern_param.isz[1]; int ic = kern_param.filter_meta.icpg; @@ -109,11 +106,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, const int8_t* sptr = static_cast( kern_param.src(batch_id, group_id, workspace_ic_id, 1, 1)); //! copy to sptr_base to eliminate padding effect - int8_t* sptr_base = static_cast(bundle.get(0)) + - (workspace_batch_id * group * padding_group_size + - workspace_group_id * padding_group_size + - workspace_ic * ih2 * iw2) * - src_expand; + int8_t* sptr_base = + static_cast(bundle.get(0)) + + (workspace_batch_id * group * padding_group_size + + workspace_group_id * padding_group_size + workspace_ic * ih2 * iw2) * + src_expand; if (stride_h == 1) { const size_t tmp_size = get_temp_bytes(iw, pw); int8_t* tmp_ptr = reinterpret_cast(bundle.get(2)) + @@ -125,9 +122,9 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, sptr, sptr_base, 1, ph, ph, pw, pw, ih, iw, iw2, pw, nullptr); } } -static void pack_weight(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index) { +static void pack_weight( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { const int group_id = ncb_index.ndrange_id[0]; int fh = kern_param.filter_meta.spatial[0]; int fw = kern_param.filter_meta.spatial[1]; @@ -137,8 +134,7 @@ static void pack_weight(const WorkspaceBundle& bundle, int fw2 = stride_h == 2 ? 
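Restating the workspace arithmetic of get_bundle above as a standalone helper may make the stride-dependent expansion easier to follow: the padded source is expanded 4x for stride 2 and 16x for stride 1, and the stride-1 path also reserves a per-thread temporary. ih2/iw2 come from get_rectified_size and are treated as inputs here; this is a sketch, not the real bundle.

#include <cstddef>
#include <cstdint>

struct NchwNchw44WorkspaceSketch {
    size_t src_bytes, weight_bytes, tmp_bytes;
};

inline NchwNchw44WorkspaceSketch workspace_sketch(
        size_t batch, size_t group, size_t ic, size_t oc, size_t fh, size_t fw,
        size_t ih2, size_t iw2, size_t stride_h, size_t per_thread_tmp_bytes,
        size_t nr_threads) {
    const size_t src_expand = stride_h == 2 ? 4 : 16;
    NchwNchw44WorkspaceSketch ws;
    ws.src_bytes = batch * group * ic * ih2 * iw2 * sizeof(int8_t) * src_expand;
    ws.weight_bytes = group * oc * ic * fh * fw * sizeof(int8_t);
    ws.tmp_bytes = stride_h == 1 ? per_thread_tmp_bytes * nr_threads : 0;
    return ws;
}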
fw : round_up(fw, 4); int oc_block = oc; int oc_idx = 0; - const int8_t* fptr = - kern_param.filter(group_id) + oc_idx * fh * fw * ic; + const int8_t* fptr = kern_param.filter(group_id) + oc_idx * fh * fw * ic; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; @@ -151,11 +147,10 @@ static void pack_weight(const WorkspaceBundle& bundle, } } template -static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids, - const CpuNDRange& ncb_range) { +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { int oh = kern_param.osz[0]; int ow = kern_param.osz[1]; int fh = kern_param.filter_meta.spatial[0]; @@ -166,13 +161,11 @@ static void do_conv_kern(const WorkspaceBundle& bundle, int group = kern_param.filter_meta.group; int ih2, iw2, oh2, ow2; get_rectified_size(kern_param, ih2, iw2, oh2, ow2); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) Op op = Op(1.0f, 4.0f); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } @@ -192,35 +185,31 @@ static void do_conv_kern(const WorkspaceBundle& bundle, if (oc_id == (oc_block_num - 1)) { oc_block = oc - oc_id * nr_pack_per_step * pack_c; } - megdnn_assert(oc_block % pack_c == 0, - "oc must be devisible by 4, but oc = %d", oc_block); + megdnn_assert( + oc_block % pack_c == 0, "oc must be devisible by 4, but oc = %d", oc_block); const int8_t* sptr = static_cast(bundle.get(0)) + workspace_batch_id * group * padding_group_size * src_expand_size + workspace_group_id * padding_group_size * src_expand_size; int8_t* dst = reinterpret_cast( - reinterpret_cast( - kern_param.dst(batch_id, group_id)) + + reinterpret_cast(kern_param.dst(batch_id, group_id)) + oc_idx * oh * ow); - const int32_t* bptr = - kern_param.bias(batch_id, group_id) + oc_idx; + const int32_t* bptr = kern_param.bias(batch_id, group_id) + oc_idx; int8_t* packed_weight = reinterpret_cast(bundle.get(1)) + - group_id * oc * ic * fh * fw2 + - oc_idx * ic * fh * fw2; - int8_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44( - sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh, - ow, op); + group_id * oc * ic * fh * fw2 + oc_idx * ic * fh * fw2; + int8_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44< + bias_mode, Op, filter, stride>( + sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih2, iw2, oh, ow, + op); } -bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return nchw_nchwxx_valid( - param.src_type.enumv(), param.filter_type.enumv(), - param.dst_type.enumv(), param.filter_meta, param.bias_mode, - param.nonlineMode); + param.src_type.enumv(), param.filter_type.enumv(), param.dst_type.enumv(), + param.filter_meta, param.bias_mode, param.nonlineMode); } bool 
ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( @@ -232,16 +221,16 @@ bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred( size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, - midout_iv("AlgoS8DirectNCHWNCHW44::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8_nchw_nchw44, + midout_iv("AlgoS8DirectNCHWNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( const NCBKernSizeParam& param) const { auto fm = param.filter_meta; size_t N = param.n; @@ -251,30 +240,34 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with changing // shape runtime -#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44, \ - midout_iv(#stride #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8_nchw_nchw44, \ + midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(stride, filter, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, \ + ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(stride, filter) \ switch (param.bias_mode) { \ @@ -319,8 +312,9 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( DISPATCH_CONV_KERN(2); break; default: - megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", - param.filter_meta.stride[0]) + megdnn_throw(ssprintf( + "Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) .c_str()); break; } @@ -337,15 +331,17 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( constexpr size_t pack_oc = 8; size_t oc_step = pack_oc; - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {N, group, fm.icpg}}); - auto do_pack_weight = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_pack_weight = [bundle]( + const NCBKernParam& kern_param, + const 
NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); pack_weight(bundle, kern_param, ncb_index); }; @@ -356,8 +352,7 @@ ConvBiasImpl::AlgoS8DirectNCHWNCHW44::dispatch_kerns( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); diff --git a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h index 3e46b807..403968b0 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h @@ -24,35 +24,29 @@ namespace arm_common { namespace int8_direct_nchw_nchw44 { template -void pack_nchw_src_for_nchw44_conv(const int8_t* inptr, int8_t* outptr, - const int ic, const int top_pad, - const int bottom_pad, const int left_pad, - const int right_pad, const int ih, - const int iw, const int iw2, const int pw, - int8_t* temp_ptr); +void pack_nchw_src_for_nchw44_conv( + const int8_t* inptr, int8_t* outptr, const int ic, const int top_pad, + const int bottom_pad, const int left_pad, const int right_pad, const int ih, + const int iw, const int iw2, const int pw, int8_t* temp_ptr); template -void pack_nchw44_weight_for_nchw_conv(const int8_t* inptr, int8_t* outptr, - const int ic, const int fh, const int fw, - const int oc); +void pack_nchw44_weight_for_nchw_conv( + const int8_t* inptr, int8_t* outptr, const int ic, const int fh, const int fw, + const int oc); template struct ConvDiectStrideInt8NchwNchw44 { - static void impl(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, int8_t* dst, - const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow, - const Op& op); + static void impl( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t oc, const size_t ic, const size_t ih, + const size_t iw, const size_t oh, const size_t ow, const Op& op); }; template -static void conv_direct_int8_nchw_nchw44(const int8_t* src, - const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const size_t oc, - const size_t ic, const size_t ih, - const size_t iw, const size_t oh, - const size_t ow, const Op& op) { +static void conv_direct_int8_nchw_nchw44( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow, const Op& op) { ConvDiectStrideInt8NchwNchw44::impl( src, filter, bias, temp, dst, oc, ic, ih, iw, oh, ow, op); } diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp index bfa8ac8b..8224d006 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -23,15 +23,13 @@ using namespace megdnn; using namespace arm_common; using conv_fun = std::function; + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range)>; MIDOUT_DECL(megdnn_arm_common_conv_bias_int8_nchw44_dot) namespace { -static inline size_t 
get_perthread_cache_bytes(const int ic, const int ih2, - const int iw2, - const int stride) { +static inline size_t get_perthread_cache_bytes( + const int ic, const int ih2, const int iw2, const int stride) { //! border_size is used to avoid read illegal memory constexpr int cacheline_size = 64; constexpr int border_size = 2 * cacheline_size; @@ -56,8 +54,8 @@ static void get_rectified_size( int ic = param.filter_meta.icpg; int iw = param.isz[1]; int oh = param.osz[0]; - int block_oh = l2_block_helper(param.nr_threads, oh, - ic * iw * sizeof(int8_t) * stride_h); + int block_oh = + l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(int8_t) * stride_h); ih2 = block_oh * stride_h + filter_h - stride_h; iw2 = iw + 2 * static_cast(fm.padding[1]); } @@ -80,13 +78,12 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { temp_size = get_temp_bytes(iw, pw); } return {nullptr, - {src_size * param.nr_threads, weight_size, - temp_size * param.nr_threads}}; + {src_size * param.nr_threads, weight_size, temp_size * param.nr_threads}}; }; -void do_weight_trans(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex&, const CpuNDRange&) { +void do_weight_trans( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex&, const CpuNDRange&) { const int ic = kern_param.filter_meta.icpg; const int oc = kern_param.filter_meta.ocpg; const int fh = kern_param.filter_meta.spatial[0]; @@ -99,10 +96,10 @@ void do_weight_trans(const WorkspaceBundle& bundle, } template -static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange&, const CpuNDRange&) { +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange&, + const CpuNDRange&) { const int oh = kern_param.osz[0]; const int ow = kern_param.osz[1]; const int fh = kern_param.filter_meta.spatial[0]; @@ -124,24 +121,22 @@ static void do_conv_kern(const WorkspaceBundle& bundle, const int group_id = ncb_index.ndrange_id[1]; constexpr int oc_idx = 0; int oc_block = oc; - int oh_block = l2_block_helper(kern_param.nr_threads, oh, - ic * iw * sizeof(int8_t) * stride_h); + int oh_block = l2_block_helper( + kern_param.nr_threads, oh, ic * iw * sizeof(int8_t) * stride_h); const int oh_idx = ncb_index.ndrange_id[2]; const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); const int ih_real = oh_block_real * stride_h + fh - stride_h; const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); const int src_bottom_pad = std::max( - (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, - 0); + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, 0); const int remain_right_pad = std::max(iw2 - iw - pw, 0); const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; - const int8_t* origin_sptr = - static_cast( - kern_param.src(batch_id, group_id, 0, 1, 1)) + - src_offset; + const int8_t* origin_sptr = static_cast(kern_param.src( + batch_id, group_id, 0, 1, 1)) + + src_offset; const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2, stride_w); - int8_t* sptr = reinterpret_cast(bundle.get(0)) + - ncb_index.thread_id * src_size; + int8_t* sptr = + reinterpret_cast(bundle.get(0)) + ncb_index.thread_id * src_size; int8_t* tmp_ptr = nullptr; if (stride 
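do_conv_kern above processes the output in oh blocks sized by l2_block_helper (whose policy is not shown in this hunk) and derives the per-block input window and padding from the block index. The arithmetic, restated as a standalone sketch:

#include <algorithm>

// Per-oh-block geometry, restating the expressions in do_conv_kern above.
// oh_block itself is an input here; l2_block_helper chooses it elsewhere.
struct OhBlockGeom {
    int oh_block_real, ih_real, src_top_pad, src_bottom_pad, src_offset;
};

inline OhBlockGeom oh_block_geometry(
        int oh_idx, int oh_block, int oh, int ih, int iw, int fh, int ph,
        int stride_h) {
    OhBlockGeom g;
    g.oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
    g.ih_real = g.oh_block_real * stride_h + fh - stride_h;
    g.src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0);
    g.src_bottom_pad = std::max(
            (oh_idx * oh_block + g.oh_block_real - 1) * stride_h + fh - ih - ph, 0);
    g.src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw;
    return g;
}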
== 1) { const size_t tmp_size = get_temp_bytes(iw, pw); @@ -159,14 +154,13 @@ static void do_conv_kern(const WorkspaceBundle& bundle, oh_idx * oh_block * ow * pack_c; const int bias_offset = oc_idx; - const int32_t* bptr = - kern_param.bias(batch_id, group_id) + bias_offset; + const int32_t* bptr = kern_param.bias(batch_id, group_id) + bias_offset; float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; Op op(scale_bias, scale_dst); - dot_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44_dot( + dot_direct_nchw_nchw44::conv_direct_int8_nchw_nchw44_dot< + bias_mode, Op, filter, stride>( sptr, fptr, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, oh_block_real, ow, op); } @@ -175,28 +169,27 @@ static void do_conv_kern(const WorkspaceBundle& bundle, bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } return nchw_nchwxx_valid( - param.src_type.enumv(), param.filter_type.enumv(), - param.dst_type.enumv(), param.filter_meta, param.bias_mode, - param.nonlineMode); + param.src_type.enumv(), param.filter_type.enumv(), param.dst_type.enumv(), + param.filter_meta, param.bias_mode, param.nonlineMode); } size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, - midout_iv("AlgoDotS8DirectNCHWNCHW44::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8_nchw44_dot, + midout_iv("AlgoDotS8DirectNCHWNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44:: + dispatch_kerns(const NCBKernSizeParam& param) const { auto fm = param.filter_meta; const int batch = param.n; const int group = fm.group; @@ -204,30 +197,34 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( conv_fun do_conv_fun = nullptr; // NOTE: remain_w is not used to gen hash of midout for compatible with // shape runtime -#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot, \ - midout_iv(#stride #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8_nchw44_dot, \ + midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(stride, filter, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(stride, filter, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(stride, filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, \ + TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, \ + ReluOp) \ + break; \ + case 
param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN( \ + stride, filter, bias_mode, \ + HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(stride, filter) \ @@ -291,15 +288,16 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( int iw = param.isz[1]; int stride_h = param.filter_meta.stride[0]; - int oh_block = l2_block_helper(param.nr_threads, oh, - ic * iw * sizeof(int8_t) * stride_h); + int oh_block = + l2_block_helper(param.nr_threads, oh, ic * iw * sizeof(int8_t) * stride_h); - CpuNDRange ncb_range = {static_cast(batch), - static_cast(group), - static_cast(div_ceil(oh, oh_block))}; + CpuNDRange ncb_range = { + static_cast(batch), static_cast(group), + static_cast(div_ceil(oh, oh_block))}; - auto do_trans_weight = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_trans_weight = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); do_weight_trans(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; @@ -309,8 +307,7 @@ ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::dispatch_kerns( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); return ret_kerns; diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h index bf930f9d..ba9394ec 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h @@ -23,17 +23,19 @@ namespace megdnn { namespace arm_common { namespace dot_direct_nchw_nchw44 { -template +template < + int src_idx, int weight_idx, int c_dim, typename Func, int ow_block, int stride, + typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static void impl(T& c, T2& src, T3& weight); }; -template +template < + int src_idx, int weight_idx, int c_dim, typename FUNC, int ow_block, int stride, + typename T, typename T2, typename T3> inline void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper:: + impl(c, src, weight); }; //! 
OCHelper is used to trans oc_block to row number of result regs template @@ -57,26 +59,24 @@ public: /** * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel * */ -template +template < + BiasMode bias_mode, typename Op, int remain_w, int filter_size, int oc_block, + int ow_block, int stride> struct KerNeonDotXXs2Nchw44Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op& op); + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, + int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op); }; template -void pack_src_int8_nchw_nchw44_dot(int8_t* sptr_base, const int8_t* sptr_origin, - const int, const int pw, const int, - const int ih, const int iw, const int iw2, - const int pad_top, const int pad_bottom, - const int ic, const int ic_stride, int8_t*); +void pack_src_int8_nchw_nchw44_dot( + int8_t* sptr_base, const int8_t* sptr_origin, const int, const int pw, + const int, const int ih, const int iw, const int iw2, const int pad_top, + const int pad_bottom, const int ic, const int ic_stride, int8_t*); -static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, - const int8_t* src_ptr, - const int oc, const int ic, - const int fh, const int fw, - const int fw2) { +static inline void pack_weight_int8_nchw_nchw44_dot( + int8_t* dst_ptr, const int8_t* src_ptr, const int oc, const int ic, + const int fh, const int fw, const int fw2) { constexpr int oc_step = 4; const int fw_remain = fw2 - fw; const int dst_ic_stride = fh * fw2; @@ -87,8 +87,8 @@ static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, rep_step(oc_idx, oc, oc_step) { int32_t* dst_temp_ptr = reinterpret_cast(dst_ptr + oc_idx * ic * fh * fw2); - const int32_t* src_temp_ptr = reinterpret_cast( - src_ptr + oc_idx * ic * fh * fw); + const int32_t* src_temp_ptr = + reinterpret_cast(src_ptr + oc_idx * ic * fh * fw); // transpose ic and pad rep(fh_idx, fh) { rep(fw_idx, fw) { @@ -110,19 +110,16 @@ static inline void pack_weight_int8_nchw_nchw44_dot(int8_t* dst_ptr, rep_step(idx, oc_step_stride, 16) { int8x16_t temp = vld1q_s8(trans_dst_temp_ptr + idx); - vst1q_s8(trans_dst_temp_ptr + idx, - vqtbl1q_s8(temp, tbl_transpose_4x4)); + vst1q_s8(trans_dst_temp_ptr + idx, vqtbl1q_s8(temp, tbl_transpose_4x4)); } } } template -void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, - const int32_t* bias, int32_t* temp, - int8_t* dst, const int oc, const int ic, - const int ih, const int iw, const int oh, - const int oh_block, const int ow, - const Op& op); +void conv_direct_int8_nchw_nchw44_dot( + const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, + int8_t* dst, const int oc, const int ic, const int ih, const int iw, + const int oh, const int oh_block, const int ow, const Op& op); } // namespace dot_direct_nchw_nchw44 } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/int8/strategy.h b/dnn/src/arm_common/conv_bias/int8/strategy.h index dbab2f4b..9217e0e2 100644 --- a/dnn/src/arm_common/conv_bias/int8/strategy.h +++ b/dnn/src/arm_common/conv_bias/int8/strategy.h @@ -18,13 +18,13 @@ namespace megdnn { namespace arm_common { namespace winograd { -MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, int16_t, int, 2, 3, 8, 8, - winograd_2x3_8x8_s8) -MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, int16_t, int, 2, 3, 8, 8, - winograd_2x3_8x8_s8_nchw44) -MEGDNN_REG_WINOGRAD_STRATEGY(int8_t, int8_t, 
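The oc8_ow8 and oc4_ow8 "gemm-like" kernels declared above keep an oc_block x ow_block accumulator tile in registers while sweeping ic and the filter window. A scalar sketch of that tiling idea, with plain NCHW layouts rather than the packed dot-product layouts the real kernels use:

#include <cstdint>

// Illustrative only: accumulate one oc_block x ow_block tile for output row 0,
// columns 0..ow_block-1. src is [ic][ih][iw], weight is [oc][ic][fh][fw].
template <int oc_block, int ow_block>
void tile_accumulate_ref(
        const int8_t* src, const int8_t* weight, int32_t* acc, int ic, int ih,
        int iw, int fh, int fw, int stride_w) {
    for (int i = 0; i < oc_block * ow_block; ++i)
        acc[i] = 0;
    for (int c = 0; c < ic; ++c)
        for (int kh = 0; kh < fh; ++kh)
            for (int kw = 0; kw < fw; ++kw)
                for (int o = 0; o < oc_block; ++o)
                    for (int w = 0; w < ow_block; ++w)
                        acc[o * ow_block + w] +=
                                int32_t(src[(c * ih + kh) * iw + w * stride_w + kw]) *
                                int32_t(weight[((o * ic + c) * fh + kh) * fw + kw]);
}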
float, float, 2, 3, 4, 4, - winograd_2x3_4x4_s8_f32_nchw44) -} +MEGDNN_REG_WINOGRAD_STRATEGY( + int8_t, int8_t, int16_t, int, 2, 3, 8, 8, winograd_2x3_8x8_s8) +MEGDNN_REG_WINOGRAD_STRATEGY( + int8_t, int8_t, int16_t, int, 2, 3, 8, 8, winograd_2x3_8x8_s8_nchw44) +MEGDNN_REG_WINOGRAD_STRATEGY( + int8_t, int8_t, float, float, 2, 3, 4, 4, winograd_2x3_4x4_s8_f32_nchw44) +} // namespace winograd } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp b/dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp index e92aab40..e36c130e 100644 --- a/dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp +++ b/dnn/src/arm_common/conv_bias/int8/strategy_2x3_8x8.cpp @@ -12,17 +12,17 @@ #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" +#include "src/arm_common/conv_bias/int8/helper.h" +#include "src/arm_common/conv_bias/int8/strategy.h" #include "src/arm_common/conv_bias/winograd_common/winograd_common.h" #include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/conv_bias/int8/helper.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" -#include "src/common/winograd/winograd_generator.h" +#include "midout.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" -#include "midout.h" +#include "src/common/winograd/winograd_generator.h" MIDOUT_DECL(megdnn_arm_common_winograd_s8_F23_8x8) using namespace megdnn; @@ -60,15 +60,14 @@ void transpose_8x4(const int16_t* src, int16_t* dst, int lda, int ldb) { } struct FilterTransform2X3_qs8 { - static void transform(const int8_t* filter_ptr, int16_t* filter_transform_buf, - int16_t* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { + static void transform( + const int8_t* filter_ptr, int16_t* filter_transform_buf, + int16_t* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { constexpr int alpha = 2 + 3 - 1; //! 
G * g * GT - int16x4_t g0{2, 0, 0, 0}, g1{1, 1, 1, 0}, g2{1, -1, 1, 0}, - g3{0, 0, 2, 0}; - int16x4_t gt0{2, 1, 1, 0}, gt1{0, 1, -1, 0}, gt2{0, 1, 1, 2}, - gt3{0, 0, 0, 0}; + int16x4_t g0{2, 0, 0, 0}, g1{1, 1, 1, 0}, g2{1, -1, 1, 0}, g3{0, 0, 2, 0}; + int16x4_t gt0{2, 1, 1, 0}, gt1{0, 1, -1, 0}, gt2{0, 1, 1, 2}, gt3{0, 0, 0, 0}; size_t OCB = OC / 8; size_t ICB = IC / 8; @@ -94,26 +93,26 @@ struct FilterTransform2X3_qs8 { int16x4_t v2 = vget_low_s16(vmovl_s8(s2)); \ int16x4_t v3 = vdup_n_s16(0); -#define cb(oc, ic, get_v) \ - get_v int16x4_t vsum0 = vdup_n_s16(0), vsum1 = vdup_n_s16(0), \ - vsum2 = vdup_n_s16(0), vsum3 = vdup_n_s16(0); \ - MATRIX_MUL4x4(vsum, g, v); \ - int16x4_t vres0 = vdup_n_s16(0), vres1 = vdup_n_s16(0), \ - vres2 = vdup_n_s16(0), vres3 = vdup_n_s16(0); \ - MATRIX_MUL4x4(vres, vsum, gt); \ - vst1_s16(transform_mid_buf, vres0); \ - vst1_s16(transform_mid_buf + 4, vres1); \ - vst1_s16(transform_mid_buf + 8, vres2); \ - vst1_s16(transform_mid_buf + 12, vres3); \ - size_t ocb = (oc) / 8; \ - size_t oc8 = (oc) % 8; \ - size_t icb = (ic) / 8; \ - size_t ic8 = (ic) % 8; \ - rep(i, alpha) rep(j, alpha) { \ - filter_transform_buf[(i * alpha + j) * OCB * ICB * 8 * 8 + \ - ocb * ICB * 8 * 8 + icb * 8 * 8 + ic8 * 8 + \ - oc8] = transform_mid_buf[i * alpha + j]; \ - } \ +#define cb(oc, ic, get_v) \ + get_v int16x4_t vsum0 = vdup_n_s16(0), vsum1 = vdup_n_s16(0), \ + vsum2 = vdup_n_s16(0), vsum3 = vdup_n_s16(0); \ + MATRIX_MUL4x4(vsum, g, v); \ + int16x4_t vres0 = vdup_n_s16(0), vres1 = vdup_n_s16(0), vres2 = vdup_n_s16(0), \ + vres3 = vdup_n_s16(0); \ + MATRIX_MUL4x4(vres, vsum, gt); \ + vst1_s16(transform_mid_buf, vres0); \ + vst1_s16(transform_mid_buf + 4, vres1); \ + vst1_s16(transform_mid_buf + 8, vres2); \ + vst1_s16(transform_mid_buf + 12, vres3); \ + size_t ocb = (oc) / 8; \ + size_t oc8 = (oc) % 8; \ + size_t icb = (ic) / 8; \ + size_t ic8 = (ic) % 8; \ + rep(i, alpha) rep(j, alpha) { \ + filter_transform_buf \ + [(i * alpha + j) * OCB * ICB * 8 * 8 + ocb * ICB * 8 * 8 + \ + icb * 8 * 8 + ic8 * 8 + oc8] = transform_mid_buf[i * alpha + j]; \ + } \ filter += 9; for (size_t oc = oc_start; oc < oc_end; oc++) { @@ -132,9 +131,7 @@ struct FilterTransform2X3_qs8 { cb(oc, ic, get_v_general); } } else { - rep(ic, IC - 1) { - cb(OC - 1, ic, get_v_general); - } + rep(ic, IC - 1) { cb(OC - 1, ic, get_v_general); } cb(OC - 1, IC - 1, get_v_searal); } } @@ -146,16 +143,15 @@ struct FilterTransform2X3_qs8 { struct InputTransform2X3_qs8 { template - static void prepare(const int8_t* input, int16_t* patch, int16_t* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC) { + static void prepare( + const int8_t* input, int16_t* patch, int16_t* patchT, int ih_start, + int iw_start, size_t IH, size_t IW, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; if (!(inner && ic + 8 < IC)) { memset(patch, 0, sizeof(int16_t) * 8 * alpha * alpha); } if (inner) { - const int8_t* input_ptr = - input + ic * IH * IW + ih_start * IW + iw_start; + const int8_t* input_ptr = input + ic * IH * IW + ih_start * IW + iw_start; InputGetter getter; for (size_t ico = 0; ico < 8; ++ico) { if (ic + ico < IC) { @@ -184,8 +180,7 @@ struct InputTransform2X3_qs8 { size_t iho = ih - ih_start, iwo = iw - iw_start; patch[ico * alpha * alpha + iho * alpha + iwo] = static_cast( - input[(ic + ico) * IH * IW + - ih * IW + iw]); + input[(ic + ico) * IH * IW + ih * IW + iw]); } } } @@ -197,14 +192,13 @@ struct InputTransform2X3_qs8 { transpose_8x4(patch + 4 * 3, patchT + 32 * 3, 16, 4); } - static 
void transform(const int16_t* patchT, int16_t* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const int16_t* patchT, int16_t* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; // BT * d * B -#define cb(m, n) \ - Vector d##m##n = \ - Vector::load(patchT + 8 * (m * 4 + n)); +#define cb(m, n) \ + Vector d##m##n = Vector::load(patchT + 8 * (m * 4 + n)); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -233,10 +227,10 @@ struct InputTransform2X3_qs8 { size_t ICB = IC / 8; size_t icb = ic / 8; -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * nr_units_in_tile * ICB * 8 + \ - icb * nr_units_in_tile * 8 + unit_idx * 8); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + (m * alpha + n) * nr_units_in_tile * ICB * 8 + \ + icb * nr_units_in_tile * 8 + unit_idx * 8); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -244,14 +238,12 @@ struct InputTransform2X3_qs8 { template struct OutputTransform2X3_qs8 { - static void transform(const int32_t* output_transform_buf, - const int32_t* bias, int8_t* output, - int32_t* transform_mid_buf, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t oc_index, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& filter_dtype, - const DType& dst_dtype) { + static void transform( + const int32_t* output_transform_buf, const int32_t* bias, int8_t* output, + int32_t* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& filter_dtype, + const DType& dst_dtype) { float scale_filter = 0.f; if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { scale_filter = filter_dtype.param().scale; @@ -261,8 +253,8 @@ struct OutputTransform2X3_qs8 { } float input_filter_scale = src_dtype.param().scale * scale_filter; - DType buffer_dtype = dtype::QuantizedS32(input_filter_scale * 0.5f * - 0.5f * 1.0f * 1.0f); + DType buffer_dtype = + dtype::QuantizedS32(input_filter_scale * 0.5f * 0.5f * 1.0f * 1.0f); Op op(buffer_dtype, dst_dtype); //! 
AT * m * A constexpr size_t alpha = 2 + 3 - 1; @@ -271,10 +263,9 @@ struct OutputTransform2X3_qs8 { size_t OCB = (oc_end - oc_start) / 8; size_t ocb = oc_index / 8; -#define cb(m, n) \ - auto v##m##n = Vector::load( \ - output_transform_buf + \ - (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ ocb * nr_units_in_tile * 8 + unit_idx * 8); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -314,15 +305,12 @@ struct OutputTransform2X3_qs8 { dt_qint8 res_int8 = dt_qint8(0); size_t oh = oh_start + oho; size_t ow = ow_start + owo; - int32_t res = - transform_mid_buf[oho * 2 * 8 + owo * 8 + oco]; + int32_t res = transform_mid_buf[oho * 2 * 8 + owo * 8 + oco]; if (bmode == BiasMode::BIAS) { - res += bias[(oc + oco) * OH * OW + oh * OW + ow] * 2 * - 2; + res += bias[(oc + oco) * OH * OW + oh * OW + ow] * 2 * 2; } res_int8 = op(dt_qint32(res)); - output[(oc + oco) * OH * OW + oh * OW + ow] = - res_int8.as_int8(); + output[(oc + oco) * OH * OW + oh * OW + ow] = res_int8.as_int8(); } } } @@ -336,24 +324,20 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_s8) -void winograd_2x3_8x8_s8::filter(const int8_t* filter, - int16_t* filter_transform_buf, - int16_t* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end) { - FilterTransform2X3_qs8::transform(filter, filter_transform_buf, - transform_mid_buf, OC, IC, oc_start, - oc_end); +void winograd_2x3_8x8_s8::filter( + const int8_t* filter, int16_t* filter_transform_buf, int16_t* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform2X3_qs8::transform( + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end); } -void winograd_2x3_8x8_s8::input(const int8_t* input, - int16_t* input_transform_buf, - int16_t* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_2x3_8x8_s8::input( + const int8_t* input, int16_t* input_transform_buf, int16_t* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { megdnn_assert(IC % 8 == 0); constexpr int alpha = 3 + 2 - 1; - + // OW = IW + 2 * PW - KERNEL_SIZE + 1 auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); int16_t* patch = transform_mid_buf; @@ -368,35 +352,30 @@ void winograd_2x3_8x8_s8::input(const int8_t* input, int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransform2X3_qs8::prepare(input, patch, patchT, - ih_start, iw_start, IH, IW, - ic, IC); - InputTransform2X3_qs8::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3_qs8::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransform2X3_qs8::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } else { - InputTransform2X3_qs8::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC); - InputTransform2X3_qs8::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3_qs8::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC); + InputTransform2X3_qs8::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } } } } -void 
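For reference, the transforms named in the comments above (G * g * GT, BT * d * B, AT * m * A) appear to be the standard Winograd F(2x2, 3x3) matrices; only 2G is visible verbatim in this hunk (the g0..g3 / gt0..gt3 vectors), so the rest is stated as an assumption. Scaling G by 2 keeps the filter transform in int16, and the factor is compensated by folding 0.5 * 0.5 into the accumulator scale and by the bias * 2 * 2 term in the BIAS path.

\[
B^T = \begin{bmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & -1 & 1 & 0 \\ 0 & 1 & 0 & -1 \end{bmatrix},\quad
G = \begin{bmatrix} 1 & 0 & 0 \\ \tfrac12 & \tfrac12 & \tfrac12 \\ \tfrac12 & -\tfrac12 & \tfrac12 \\ 0 & 0 & 1 \end{bmatrix},\quad
A^T = \begin{bmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & -1 \end{bmatrix},
\]
\[
Y = A^T\Big[\big(G\,g\,G^T\big)\odot\big(B^T d\,B\big)\Big]A,
\qquad
2G = \begin{bmatrix} 2 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & -1 & 1 \\ 0 & 0 & 2 \end{bmatrix}.
\]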
winograd_2x3_8x8_s8::output(const int* output_transform_buf, - const int* bias, int8_t* output, - int* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, - size_t nr_units_in_tile) { -#define cb(_bmode, _nonline_op, ...) \ - OutputTransform2X3_qs8<_bmode MEGDNN_COMMA _nonline_op>::transform( \ - __VA_ARGS__); +void winograd_2x3_8x8_s8::output( + const int* output_transform_buf, const int* bias, int8_t* output, + int* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + OutputTransform2X3_qs8<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); @@ -411,8 +390,9 @@ void winograd_2x3_8x8_s8::output(const int* output_transform_buf, DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( megdnn_arm_common_winograd_s8_F23_8x8, cb, dt_qint32, dt_qint8, bmode, nonline_mode, output_transform_buf, bias, output, - transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, - unit_idx, nr_units_in_tile, src_dtype, filter_dtype, dst_dtype); + transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, + oc_index, unit_idx, nr_units_in_tile, src_dtype, filter_dtype, + dst_dtype); } } #undef cb diff --git a/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_4x4.cpp b/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_4x4.cpp index a9d3c0f6..dc2141a5 100644 --- a/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_4x4.cpp +++ b/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_4x4.cpp @@ -16,10 +16,10 @@ #include "src/common/utils.h" #include "src/fallback/conv_bias/winograd/winograd.h" +#include "src/arm_common/conv_bias/fp32/helper.h" #include "src/arm_common/conv_bias/winograd_common/winograd_common.h" -#include "src/naive/matrix_mul/matrix_mul_helper.h" #include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/naive/matrix_mul/matrix_mul_helper.h" #include "midout.h" @@ -30,9 +30,10 @@ using namespace arm_common; namespace { struct InputTransform2X3 { template - static void prepare(const int8_t* input, float* patch, float* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC, size_t PH, size_t PW) { + static void prepare( + const int8_t* input, float* patch, float* patchT, int ih_start, + int iw_start, size_t IH, size_t IW, size_t ic, size_t IC, size_t PH, + size_t PW) { megdnn_assert( ic % 4 == 0 && IC % 4 == 0, "Winograd input prepare param is not times of 4!"); @@ -50,14 +51,10 @@ struct InputTransform2X3 { int32x4_t v_2 = vmovl_s16(vget_low_s16(v_high)); int32x4_t v_3 = vmovl_s16(vget_high_s16(v_high)); - vst1q_f32(patchT + ico * 4 * alpha + 0 * 4, - vcvtq_f32_s32(v_0)); - vst1q_f32(patchT + ico * 4 * alpha + 1 * 4, - vcvtq_f32_s32(v_1)); - vst1q_f32(patchT + ico * 4 * alpha + 2 * 4, - vcvtq_f32_s32(v_2)); - vst1q_f32(patchT + ico * 4 * alpha + 3 * 4, - vcvtq_f32_s32(v_3)); + vst1q_f32(patchT + ico * 4 * alpha + 0 * 4, vcvtq_f32_s32(v_0)); + vst1q_f32(patchT + ico * 4 * alpha + 1 * 4, vcvtq_f32_s32(v_1)); + vst1q_f32(patchT + ico * 4 * alpha + 2 * 4, vcvtq_f32_s32(v_2)); + vst1q_f32(patchT + ico * 4 * alpha + 3 * 4, vcvtq_f32_s32(v_3)); input_ptr += IW * 4; } } else { @@ -74,21 +71,21 @@ struct InputTransform2X3 { for (int ih = ih0_act; ih < ih1_act; ++ih) { for (int iw = iw0_act; iw < iw1_act; 
++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; - vst1q_f32(patchT + iho * alpha * 4 + iwo * 4, - getter(input_ptr + ih * IW * 4 + iw * 4)); + vst1q_f32( + patchT + iho * alpha * 4 + iwo * 4, + getter(input_ptr + ih * IW * 4 + iw * 4)); } } } } - static void transform(const float* patchT, float* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + const float* patchT, float* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; // BT * d * B -#define cb(m, n) \ - Vector d##m##n = \ - Vector::load(patchT + m * 4 * 4 + n * 4); +#define cb(m, n) \ + Vector d##m##n = Vector::load(patchT + m * 4 * 4 + n * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -117,10 +114,10 @@ struct InputTransform2X3 { size_t ICB = IC / 4; size_t icb = ic / 4; -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ - icb * nr_units_in_tile * 4 + unit_idx * 4); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 4 + \ + icb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -128,13 +125,12 @@ struct InputTransform2X3 { template struct OutputTransform2X3 { - static void transform(const float* output_transform_buf, const float* bias, - int8_t* output, float* transform_mid_buf, - size_t oh_start, size_t ow_start, size_t OH, - size_t OW, size_t oc_start, size_t oc_end, - size_t oc_index, size_t unit_idx, - size_t nr_units_in_tile, const DType& src_dtype, - const DType& filter_dtype, const DType& dst_dtype) { + static void transform( + const float* output_transform_buf, const float* bias, int8_t* output, + float* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& filter_dtype, + const DType& dst_dtype) { float scale_filter = 0.f; MEGDNN_MARK_USED_VAR(transform_mid_buf); if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { @@ -155,10 +151,9 @@ struct OutputTransform2X3 { size_t OCB = (oc_end - oc_start) / 4; size_t ocb = oc_index / 4; -#define cb(m, n) \ - auto v##m##n = Vector::load( \ - output_transform_buf + \ - (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 4 + \ ocb * nr_units_in_tile * 4 + unit_idx * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -187,10 +182,10 @@ struct OutputTransform2X3 { const float32x4_t vvbias = vcvtq_f32_s32(vld1q_s32(tmp_bias + oc)); vbias = Vector(vvbias); - result[0][0] += vbias; - result[0][1] += vbias; - result[1][0] += vbias; - result[1][1] += vbias; + result[0][0] += vbias; + result[0][1] += vbias; + result[1][0] += vbias; + result[1][1] += vbias; } #undef cb @@ -206,8 +201,8 @@ struct OutputTransform2X3 { Vector res; res = result[oho][owo]; if (bmode == BiasMode::BIAS) { - const float32x4_t vvbias = vcvtq_f32_s32(vld1q_s32( - tmp_bias + oc * OH * OW + oh * OW * 4 + ow * 4)); + const float32x4_t vvbias = vcvtq_f32_s32( + vld1q_s32(tmp_bias + oc * OH * OW + oh * OW * 4 + ow * 4)); res += Vector(vvbias); } #if MEGDNN_AARCH64 @@ -235,10 +230,9 @@ namespace arm_common { namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_4x4_s8_f32_nchw44) -void winograd_2x3_4x4_s8_f32_nchw44::filter(const int8_t* 
filter, - float* filter_transform_buf, - float* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { +void winograd_2x3_4x4_s8_f32_nchw44::filter( + const int8_t* filter, float* filter_transform_buf, float* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { constexpr int alpha = 2 + 3 - 1; /** * origin: (4x3) * (3 x 3) * (3 x 4) @@ -250,21 +244,21 @@ void winograd_2x3_4x4_s8_f32_nchw44::filter(const int8_t* filter, InputGetter getter; MEGDNN_MARK_USED_VAR(transform_mid_buf); - megdnn_assert((oc_end - oc_start) % 4 == 0 && oc_start % 4 == 0 && - oc_end % 4 == 0 && IC % 4 == 0 && OC % 4 == 0, - "Winograd filter transform input param is not times of 4!"); + megdnn_assert( + (oc_end - oc_start) % 4 == 0 && oc_start % 4 == 0 && oc_end % 4 == 0 && + IC % 4 == 0 && OC % 4 == 0, + "Winograd filter transform input param is not times of 4!"); size_t OCB = OC / 4; size_t ICB = IC / 4; for (size_t ocb = oc_start / 4; ocb < oc_end / 4; ocb++) { for (size_t icb = 0; icb < ICB; icb++) { for (size_t ic_inner = 0; ic_inner < 4; ic_inner++) { - const int8_t* fptr = filter + (ocb * ICB + icb) * 3 * 3 * 4 * 4 + - ic_inner * 4; + const int8_t* fptr = + filter + (ocb * ICB + icb) * 3 * 3 * 4 * 4 + ic_inner * 4; -#define cb(m, n) \ - Vector g##m##n = \ - Vector(getter(fptr + (m * 3 + n) * 4 * 4)); +#define cb(m, n) \ + Vector g##m##n = Vector(getter(fptr + (m * 3 + n) * 4 * 4)); UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) #undef cb @@ -281,11 +275,10 @@ void winograd_2x3_4x4_s8_f32_nchw44::filter(const int8_t* filter, UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM - -#define cb(m, n) \ - ret##m##n.save(filter_transform_buf + \ - (m * alpha + n) * OCB * ICB * 4 * 4 + ocb * ICB * 4 * 4 + \ - icb * 4 * 4 + ic_inner * 4); +#define cb(m, n) \ + ret##m##n.save( \ + filter_transform_buf + (m * alpha + n) * OCB * ICB * 4 * 4 + \ + ocb * ICB * 4 * 4 + icb * 4 * 4 + ic_inner * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -293,16 +286,14 @@ void winograd_2x3_4x4_s8_f32_nchw44::filter(const int8_t* filter, } } -void winograd_2x3_4x4_s8_f32_nchw44::input(const int8_t* input, float* input_transform_buf, - float* transform_mid_buf, size_t IH, size_t IW, - size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_2x3_4x4_s8_f32_nchw44::input( + const int8_t* input, float* input_transform_buf, float* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { megdnn_assert(IC % 4 == 0); constexpr int alpha = 3 + 2 - 1; - auto units_w = - div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); float* patch = transform_mid_buf; float* patchT = transform_mid_buf + 4 * alpha * alpha; @@ -315,31 +306,30 @@ void winograd_2x3_4x4_s8_f32_nchw44::input(const int8_t* input, float* input_tra int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransform2X3::prepare(input, patch, patchT, ih_start, - iw_start, IH, IW, ic, IC,PH,PW); - InputTransform2X3::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC, PH, + PW); + InputTransform2X3::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } else { - 
InputTransform2X3::prepare(input, patch, patchT, - ih_start, iw_start, IH, IW, - ic, IC,PH,PW); - InputTransform2X3::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC, PH, + PW); + InputTransform2X3::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } } } } -void winograd_2x3_4x4_s8_f32_nchw44::output(const float* output_transform_buf, - const float* bias, int8_t* output, - float* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_2x3_4x4_s8_f32_nchw44::output( + const float* output_transform_buf, const float* bias, int8_t* output, + float* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { #define cb(_bmode, _nonline_op, ...) \ OutputTransform2X3<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); @@ -353,11 +343,10 @@ void winograd_2x3_4x4_s8_f32_nchw44::output(const float* output_transform_buf, size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( - megdnn_arm_common_winograd_nchw44_s8_comp_fp32_f23, cb, - dt_qint32, dt_qint8, bmode, nonline_mode, - output_transform_buf, bias, output, transform_mid_buf, - oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, - unit_idx, nr_units_in_tile, src_dtype, filter_dtype, + megdnn_arm_common_winograd_nchw44_s8_comp_fp32_f23, cb, dt_qint32, + dt_qint8, bmode, nonline_mode, output_transform_buf, bias, output, + transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, + oc_index, unit_idx, nr_units_in_tile, src_dtype, filter_dtype, dst_dtype); } } diff --git a/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_8x8.cpp b/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_8x8.cpp index ab8310d2..a97fc54d 100644 --- a/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_8x8.cpp +++ b/dnn/src/arm_common/conv_bias/int8/strategy_nchw44_2x3_8x8.cpp @@ -12,16 +12,16 @@ #include "src/fallback/conv_bias/winograd/winograd.h" #include "src/naive/matrix_mul/matrix_mul_helper.h" +#include "src/arm_common/conv_bias/int8/helper.h" +#include "src/arm_common/conv_bias/int8/strategy.h" #include "src/arm_common/conv_bias/winograd_common/winograd_common.h" #include "src/arm_common/elemwise_helper/op_unary.h" -#include "src/arm_common/conv_bias/int8/strategy.h" -#include "src/arm_common/conv_bias/int8/helper.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/utils.h" -#include "src/common/winograd/winograd_generator.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" +#include "src/common/winograd/winograd_generator.h" #include "midout.h" @@ -33,30 +33,31 @@ using namespace arm_common; namespace { struct FilterTransform2X3_qs8 { - static void transform(const int8_t* filter_ptr, int16_t* filter_transform_buf, - int16_t* transform_mid_buf, size_t OC, size_t IC, - size_t oc_start, size_t oc_end) { + static void transform( + const int8_t* filter_ptr, int16_t* filter_transform_buf, + int16_t* transform_mid_buf, size_t OC, size_t IC, size_t oc_start, + size_t oc_end) { constexpr int alpha = 2 + 3 - 1; - /** - * origin: (4x3) * (3 x 3) * (3 x 4) - */ - //! 1 0 0 v00 v01 v02 1 0.5 0.5 0 - //! 0.5 0.5 0.5 v10 v11 v12 0 0.5 -0.5 0 - //! 
0.5 -0.5 0.5 v20 v21 v22 0 0.5 0.5 1 - //! 0 0 1 - - //! 2 0 0 v00 v01 v02 2 1 1 0 - //! 1 1 1 v10 v11 v12 0 1 -1 0 - //! 1 -1 1 v20 v21 v22 0 1 1 2 - //! 0 0 2 + /** + * origin: (4x3) * (3 x 3) * (3 x 4) + */ + //! 1 0 0 v00 v01 v02 1 0.5 0.5 0 + //! 0.5 0.5 0.5 v10 v11 v12 0 0.5 -0.5 0 + //! 0.5 -0.5 0.5 v20 v21 v22 0 0.5 0.5 1 + //! 0 0 1 + + //! 2 0 0 v00 v01 v02 2 1 1 0 + //! 1 1 1 v10 v11 v12 0 1 -1 0 + //! 1 -1 1 v20 v21 v22 0 1 1 2 + //! 0 0 2 //! G * g * GT InputGetter getter; MEGDNN_MARK_USED_VAR(transform_mid_buf); megdnn_assert( - (oc_end - oc_start) % 4 == 0 && oc_start % 4 == 0 && - oc_end % 4 == 0 && IC % 8 == 0 && OC % 8 == 0, + (oc_end - oc_start) % 4 == 0 && oc_start % 4 == 0 && oc_end % 4 == 0 && + IC % 8 == 0 && OC % 8 == 0, "Winograd filter transform input param is not times of 8!"); size_t OCB = OC / 8; size_t ICB = IC / 8; @@ -69,9 +70,8 @@ struct FilterTransform2X3_qs8 { const int8_t* fptr = filter_ptr + (ocb * ICB4 + icb) * 3 * 3 * 4 * 4 + ic_inner * 4; -#define cb(m, n) \ - Vector g##m##n = \ - Vector(getter(fptr + (m * 3 + n) * 4 * 4)); +#define cb(m, n) \ + Vector g##m##n = Vector(getter(fptr + (m * 3 + n) * 4 * 4)); UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) #undef cb @@ -87,9 +87,9 @@ struct FilterTransform2X3_qs8 { UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); #undef FILTER_TRANSFORM -#define cb(m, n) \ - ret##m##n.save( \ - filter_transform_buf + (m * alpha + n) * OCB * ICB * 8 * 8 + \ +#define cb(m, n) \ + ret##m##n.save( \ + filter_transform_buf + (m * alpha + n) * OCB * ICB * 8 * 8 + \ tmp_ocb * ICB * 8 * 8 + icb * 4 * 8 + ic_inner * 8 + index * 4); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) @@ -102,20 +102,21 @@ struct FilterTransform2X3_qs8 { struct InputTransform2X3_qs8 { template - static void prepare(const int8_t* input, int16_t* patch, int16_t* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC, size_t PH, size_t PW) { - megdnn_assert(ic % 8 == 0 && IC % 8 == 0, - "Winograd input prepare param is not times of 4!"); + static void prepare( + const int8_t* input, int16_t* patch, int16_t* patchT, int ih_start, + int iw_start, size_t IH, size_t IW, size_t ic, size_t IC, size_t PH, + size_t PW) { + megdnn_assert( + ic % 8 == 0 && IC % 8 == 0, + "Winograd input prepare param is not times of 4!"); MEGDNN_MARK_USED_VAR(patch); constexpr size_t alpha = 2 + 3 - 1; if (inner) { const int8_t* input_ptr = input + ic * IH * IW + ih_start * IW * 4 + iw_start * 4; for (size_t ico = 0; ico < alpha; ++ico) { - int8x16_t v_input0 = vld1q_s8(input_ptr); // c0123 - int8x16_t v_input1 = - vld1q_s8(input_ptr + IH * IW * 4); // c4567 + int8x16_t v_input0 = vld1q_s8(input_ptr); // c0123 + int8x16_t v_input1 = vld1q_s8(input_ptr + IH * IW * 4); // c4567 int32x4_t v32_00 = vreinterpretq_s32_s8(v_input0); int32x4_t v32_01 = vreinterpretq_s32_s8(v_input1); @@ -149,24 +150,24 @@ struct InputTransform2X3_qs8 { for (int ih = ih0_act; ih < ih1_act; ++ih) { for (int iw = iw0_act; iw < iw1_act; ++iw) { size_t iho = ih - ih_start, iwo = iw - iw_start; - vst1q_s16(patchT + iho * alpha * 8 + iwo * 8, - vcombine_s16( - getter(input_ptr + ih * IW * 4 + iw * 4), - getter(input_ptr + IH * IW * 4 + - ih * IW * 4 + iw * 4))); + vst1q_s16( + patchT + iho * alpha * 8 + iwo * 8, + vcombine_s16( + getter(input_ptr + ih * IW * 4 + iw * 4), + getter(input_ptr + IH * IW * 4 + ih * IW * 4 + + iw * 4))); } } } } - static void transform(const int16_t* patchT, int16_t* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { + static void transform( + 
const int16_t* patchT, int16_t* input_transform_buf, size_t unit_idx, + size_t nr_units_in_tile, size_t ic, size_t IC) { constexpr size_t alpha = 2 + 3 - 1; // BT * d * B -#define cb(m, n) \ - Vector d##m##n = \ - Vector::load(patchT + m * 4 * 8 + n * 8); +#define cb(m, n) \ + Vector d##m##n = Vector::load(patchT + m * 4 * 8 + n * 8); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -195,10 +196,10 @@ struct InputTransform2X3_qs8 { size_t ICB = IC / 8; size_t icb = ic / 8; -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * ICB * nr_units_in_tile * 8 + \ - icb * nr_units_in_tile * 8 + unit_idx * 8); +#define cb(m, n) \ + d##m##n.save( \ + input_transform_buf + (m * alpha + n) * ICB * nr_units_in_tile * 8 + \ + icb * nr_units_in_tile * 8 + unit_idx * 8); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) #undef cb } @@ -206,14 +207,12 @@ struct InputTransform2X3_qs8 { template struct OutputTransform2X3_qs8 { - static void transform(const int32_t* output_transform_buf, - const int32_t* bias, int8_t* output, - int32_t* transform_mid_buf, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, size_t oc_index, - size_t unit_idx, size_t nr_units_in_tile, - const DType& src_dtype, const DType& filter_dtype, - const DType& dst_dtype) { + static void transform( + const int32_t* output_transform_buf, const int32_t* bias, int8_t* output, + int32_t* transform_mid_buf, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, const DType& filter_dtype, + const DType& dst_dtype) { MEGDNN_MARK_USED_VAR(transform_mid_buf); float scale_filter = 0.f; if (filter_dtype.enumv() == DTypeEnum::QuantizedS8) { @@ -224,8 +223,8 @@ struct OutputTransform2X3_qs8 { } float input_filter_scale = src_dtype.param().scale * scale_filter; - DType buffer_dtype = dtype::QuantizedS32(input_filter_scale * 0.5f * - 0.5f * 1.0f * 1.0f); + DType buffer_dtype = + dtype::QuantizedS32(input_filter_scale * 0.5f * 0.5f * 1.0f * 1.0f); Op op(buffer_dtype, dst_dtype); //! 
AT * m * A constexpr size_t alpha = 2 + 3 - 1; @@ -234,10 +233,9 @@ struct OutputTransform2X3_qs8 { size_t OCB = (oc_end - oc_start) / 8; size_t ocb = oc_index / 8; -#define cb(m, n) \ - auto v##m##n = Vector::load( \ - output_transform_buf + \ - (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + (m * alpha + n) * OCB * nr_units_in_tile * 8 + \ ocb * nr_units_in_tile * 8 + unit_idx * 8); UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb @@ -252,23 +250,22 @@ struct OutputTransform2X3_qs8 { UNROLL_CALL_NOWRAPPER(4, cb); #undef cb + Vector result[2][2]; - Vector result[2][2]; + result[0][0] = t00 + t01 + t02; + result[1][0] = t10 + t11 + t12; + result[0][1] = t01 - t02 + t03; + result[1][1] = t11 - t12 + t13; - result[0][0] = t00 + t01 + t02; - result[1][0] = t10 + t11 + t12; - result[0][1] = t01 - t02 + t03; - result[1][1] = t11 - t12 + t13; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + Vector vbias; + vbias = Vector::load(bias + oc) * 4; - if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - Vector vbias; - vbias = Vector::load(bias + oc) * 4; - - result[0][0] += vbias; - result[0][1] += vbias; - result[1][0] += vbias; - result[1][1] += vbias; - } + result[0][0] += vbias; + result[0][1] += vbias; + result[1][0] += vbias; + result[1][1] += vbias; + } #if MEGDNN_AARCH64 int32_t* tmp_output = static_cast(static_cast(output)); @@ -280,17 +277,16 @@ struct OutputTransform2X3_qs8 { Vector res = result[oho][owo]; if (bmode == BiasMode::BIAS) { int32x4x2_t vbias; - vbias.val[0] = vld1q_s32(bias + oc * OH * OW + oh * OW * 4 + - ow * 4); - vbias.val[1] = vld1q_s32(bias + (oc + 4) * OH * OW + - oh * OW * 4 + ow * 4); + vbias.val[0] = + vld1q_s32(bias + oc * OH * OW + oh * OW * 4 + ow * 4); + vbias.val[1] = + vld1q_s32(bias + (oc + 4) * OH * OW + oh * OW * 4 + ow * 4); res += Vector(vbias) * 4; } #if MEGDNN_AARCH64 int8x8_t res_int8 = op(res.value); int32x2_t res32 = vreinterpret_s32_s8(res_int8); - tmp_output[oc / 4 * OH * OW + oh * OW + ow] = - vget_lane_s32(res32, 0); + tmp_output[oc / 4 * OH * OW + oh * OW + ow] = vget_lane_s32(res32, 0); tmp_output[(oc / 4 + 1) * OH * OW + oh * OW + ow] = vget_lane_s32(res32, 1); #else @@ -317,26 +313,22 @@ namespace winograd { MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_2x3_8x8_s8_nchw44) -void winograd_2x3_8x8_s8_nchw44::filter(const int8_t* filter, - int16_t* filter_transform_buf, - int16_t* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end) { - FilterTransform2X3_qs8::transform(filter, filter_transform_buf, - transform_mid_buf, OC, IC, oc_start, - oc_end); +void winograd_2x3_8x8_s8_nchw44::filter( + const int8_t* filter, int16_t* filter_transform_buf, int16_t* transform_mid_buf, + size_t OC, size_t IC, size_t oc_start, size_t oc_end) { + FilterTransform2X3_qs8::transform( + filter, filter_transform_buf, transform_mid_buf, OC, IC, oc_start, oc_end); } -void winograd_2x3_8x8_s8_nchw44::input(const int8_t* input, - int16_t* input_transform_buf, - int16_t* transform_mid_buf, size_t IH, - size_t IW, size_t IC, size_t PH, size_t PW, - size_t unit_start_idx, - size_t nr_units_in_tile) { +void winograd_2x3_8x8_s8_nchw44::input( + const int8_t* input, int16_t* input_transform_buf, int16_t* transform_mid_buf, + size_t IH, size_t IW, size_t IC, size_t PH, size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { megdnn_assert(IC % 8 == 0); constexpr int alpha = 3 + 2 - 1; auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); int16_t* patch = 
transform_mid_buf; - int16_t* patchT = transform_mid_buf;// + 8 * alpha * alpha; + int16_t* patchT = transform_mid_buf; // + 8 * alpha * alpha; for (size_t ic = 0; ic < IC; ic += 8) { rep(unit_idx, nr_units_in_tile) { @@ -347,35 +339,32 @@ void winograd_2x3_8x8_s8_nchw44::input(const int8_t* input, int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransform2X3_qs8::prepare(input, patch, patchT, - ih_start, iw_start, IH, IW, - ic, IC,PH,PW); - InputTransform2X3_qs8::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3_qs8::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC, PH, + PW); + InputTransform2X3_qs8::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } else { - InputTransform2X3_qs8::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC,PH,PW); - InputTransform2X3_qs8::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, ic, - IC); + InputTransform2X3_qs8::prepare( + input, patch, patchT, ih_start, iw_start, IH, IW, ic, IC, PH, + PW); + InputTransform2X3_qs8::transform( + patchT, input_transform_buf, unit_idx, nr_units_in_tile, ic, + IC); } } } } -void winograd_2x3_8x8_s8_nchw44::output(const int* output_transform_buf, - const int* bias, int8_t* output, - int* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t OH, size_t OW, - size_t oc_start, size_t oc_end, - size_t unit_start_idx, - size_t nr_units_in_tile) { -#define cb(_bmode, _nonline_op, ...) \ - OutputTransform2X3_qs8<_bmode MEGDNN_COMMA _nonline_op>::transform( \ - __VA_ARGS__); +void winograd_2x3_8x8_s8_nchw44::output( + const int* output_transform_buf, const int* bias, int8_t* output, + int* transform_mid_buf, BiasMode bmode, NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + OutputTransform2X3_qs8<_bmode MEGDNN_COMMA _nonline_op>::transform(__VA_ARGS__); auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); for (size_t oc = oc_start; oc < oc_end; oc += 8) { size_t oc_index = oc - oc_start; @@ -386,11 +375,10 @@ void winograd_2x3_8x8_s8_nchw44::output(const int* output_transform_buf, size_t oh_start = nh * OUTPUT_BLOCK_SIZE; size_t ow_start = nw * OUTPUT_BLOCK_SIZE; DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( - megdnn_arm_common_winograd_nchw44_s8_int16_8x8, cb, - dt_qint32, dt_qint8, bmode, nonline_mode, - output_transform_buf, bias, output, transform_mid_buf, - oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, - unit_idx, nr_units_in_tile, src_dtype, filter_dtype, + megdnn_arm_common_winograd_nchw44_s8_int16_8x8, cb, dt_qint32, + dt_qint8, bmode, nonline_mode, output_transform_buf, bias, output, + transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, oc_end, + oc_index, unit_idx, nr_units_in_tile, src_dtype, filter_dtype, dst_dtype); } } diff --git a/dnn/src/arm_common/conv_bias/int8/stride1.cpp b/dnn/src/arm_common/conv_bias/int8/stride1.cpp index ed874005..ce61f4fe 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride1.cpp @@ -22,20 +22,18 @@ using namespace arm_common; using namespace direct_int8_stride1; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; auto SW = fm.stride[1]; auto OH = param.osz[0]; @@ -49,8 +47,7 @@ void get_rectified_size( IW2 = SW * OW2 + FW - SW; } } // namespace -bool direct_int8_stride1::can_conv_direct_stride1_int8( - const NCBKernSizeParam& param) { +bool direct_int8_stride1::can_conv_direct_stride1_int8(const NCBKernSizeParam& param) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; auto OC = fm.ocpg; @@ -66,13 +63,14 @@ bool direct_int8_stride1::can_conv_direct_stride1_int8( param.filter_type.enumv() == DTypeEnum::Int8 && param.dst_type.enumv() == DTypeEnum::Int32)) && fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); if (param.bias_type.valid()) { - avaible &= ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && - param.bias_type.enumv() == DTypeEnum::QuantizedS32) || - (param.bias_type.enumv() == param.dst_type.enumv())); + avaible &= + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); } bool preferred = ((FH == 2 && (OC <= 10 || IC <= 
8)) || ((FH == 3 || FH == 5 || FH == 7) && @@ -91,9 +89,8 @@ WorkspaceBundle direct_int8_stride1::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -107,10 +104,8 @@ WorkspaceBundle direct_int8_stride1::get_bundle( } //! Process one input channel copy padding void direct_int8_stride1::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -123,9 +118,8 @@ void direct_int8_stride1::copy_padding_kern( bool need_src_copy_var = need_src_copy(kern_param); size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2], - group_id = ncb_index.ndrange_id[0], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2], group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const int8_t* sptr = static_cast( @@ -138,17 +132,17 @@ void direct_int8_stride1::copy_padding_kern( channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(int8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); } } }; //! compute one output channel template -void direct_int8_stride1::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +void direct_int8_stride1::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t FH = kern_param.filter_meta.spatial[0]; @@ -159,29 +153,25 @@ void direct_int8_stride1::do_conv_kern(const WorkspaceBundle& bundle, get_rectified_size(kern_param, IH2, IW2, OH2, OW2); bool need_src_copy_var = need_src_copy(kern_param); bool need_dst_copy_var = need_dst_copy(kern_param); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) Op op = Op(1.0f, 4.0f); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } size_t padding_group_size = IH2 * IW2 * IC; //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; //! If large group, each thread has its own worspace, set group_id with //! thread_id const int8_t* sptr = kern_param.src(batch_id, group_id); - const int8_t* fptr = - kern_param.filter(group_id) + oc * FH * FW * IC; - void* dst = reinterpret_cast(reinterpret_cast( - kern_param.dst(batch_id, group_id, oc))); + const int8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; + void* dst = reinterpret_cast( + reinterpret_cast(kern_param.dst(batch_id, group_id, oc))); const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); if (need_src_copy_var) { sptr = static_cast(bundle.get(0)) + @@ -244,13 +234,14 @@ void direct_int8_stride1::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -268,23 +259,20 @@ SmallVector direct_int8_stride1::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -331,21 +319,21 @@ SmallVector direct_int8_stride1::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - 
copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/int8/stride1.h b/dnn/src/arm_common/conv_bias/int8/stride1.h index 7a02f20a..cf63b9ea 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1.h +++ b/dnn/src/arm_common/conv_bias/int8/stride1.h @@ -28,17 +28,16 @@ bool can_conv_direct_stride1_int8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_int8_stride1 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp index 6ef8eb42..6916dc64 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp @@ -22,20 +22,18 @@ using namespace arm_common; using namespace direct_dotprod_int8_stride1; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; auto SW = fm.stride[1]; auto OH = param.osz[0]; @@ -67,18 +65,19 @@ bool direct_dotprod_int8_stride1::can_conv_direct_stride1_int8( param.filter_type.enumv() == DTypeEnum::Int8 && param.dst_type.enumv() == DTypeEnum::Int32)) && fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) || ((FH == 3 || FH == 5 || FH == 7) && (OC <= 16 || (IC <= 4 && OC <= 32)))) && param.bias_mode != BiasMode::BIAS; if (param.bias_type.valid()) { - avaible &= ((param.src_type.enumv() == 
DTypeEnum::QuantizedS8 && - param.bias_type.enumv() == DTypeEnum::QuantizedS32) || - (param.bias_type.enumv() == param.dst_type.enumv())); + avaible &= + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); } return avaible && preferred; } @@ -93,9 +92,8 @@ WorkspaceBundle direct_dotprod_int8_stride1::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -109,10 +107,8 @@ WorkspaceBundle direct_dotprod_int8_stride1::get_bundle( } //! Process one input channel copy padding void direct_dotprod_int8_stride1::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -126,11 +122,9 @@ void direct_dotprod_int8_stride1::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1], + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1], channel_id = workspace_ids[2]; - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1]; const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! copy to sptr_base to eliminate padding effect @@ -140,8 +134,9 @@ void direct_dotprod_int8_stride1::copy_padding_kern( channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(int8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); } } }; @@ -160,28 +155,24 @@ void direct_dotprod_int8_stride1::do_conv_kern( get_rectified_size(kern_param, IH2, IW2, OH2, OW2); bool need_src_copy_var = need_src_copy(kern_param); bool need_dst_copy_var = need_dst_copy(kern_param); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) Op op = Op(1.0f, 4.0f); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } size_t padding_group_size = IH2 * IW2 * IC; //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; //! If large group, each thread has its own worspace, set group_id //! with thread_id const int8_t* sptr = kern_param.src(batch_id, group_id); - const int8_t* fptr = - kern_param.filter(group_id) + oc * FH * FW * IC; + const int8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; void* dst = kern_param.dst(batch_id, group_id, oc); const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); if (need_src_copy_var) { @@ -245,13 +236,14 @@ void direct_dotprod_int8_stride1::do_conv_kern( #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -269,23 +261,20 @@ SmallVector direct_dotprod_int8_stride1::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -332,21 +321,21 @@ SmallVector direct_dotprod_int8_stride1::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, 
N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h index 443a2192..6bb5ef30 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h +++ b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h @@ -27,18 +27,16 @@ bool can_conv_direct_stride1_int8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_dotprod_int8_stride1 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/stride2.cpp b/dnn/src/arm_common/conv_bias/int8/stride2.cpp index 66e2c302..a91e52aa 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride2.cpp @@ -22,20 +22,18 @@ using namespace arm_common; using namespace direct_int8_stride2; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; size_t SW = fm.stride[1]; size_t IH = param.isz[0]; @@ -55,8 +53,7 @@ void get_rectified_size( IW2 = std::max(IW2, IW); } } // namespace -bool direct_int8_stride2::can_conv_direct_stride2_int8( - const NCBKernSizeParam& param) { +bool direct_int8_stride2::can_conv_direct_stride2_int8(const NCBKernSizeParam& param) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; auto OC = fm.ocpg; @@ -72,20 +69,21 @@ bool direct_int8_stride2::can_conv_direct_stride2_int8( param.filter_type.enumv() == DTypeEnum::Int8 && param.dst_type.enumv() == DTypeEnum::Int32)) && fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); if (param.bias_type.valid()) { - avaible &= ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && - 
param.bias_type.enumv() == DTypeEnum::QuantizedS32) || - (param.bias_type.enumv() == param.dst_type.enumv())); + avaible &= + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); } - bool preferred = (((FH == 2 || FH == 3) && - (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || - (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || - (FH == 7 && OC <= 16)) && - (param.bias_mode != BiasMode::BIAS); + bool preferred = + (((FH == 2 || FH == 3) && (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); return avaible && preferred; } @@ -99,9 +97,8 @@ WorkspaceBundle direct_int8_stride2::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -115,10 +112,8 @@ WorkspaceBundle direct_int8_stride2::get_bundle( } //! Process one input channel copy padding void direct_int8_stride2::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -132,10 +127,9 @@ void direct_int8_stride2::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; const int8_t* sptr = static_cast( kern_param.src(batch_id, group_id, channel_id)); if (need_src_copy_var) { @@ -146,17 +140,17 @@ void direct_int8_stride2::copy_padding_kern( channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(int8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); } } }; //! 
compute one output channel template -void direct_int8_stride2::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +void direct_int8_stride2::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t FH = kern_param.filter_meta.spatial[0]; @@ -167,27 +161,23 @@ void direct_int8_stride2::do_conv_kern(const WorkspaceBundle& bundle, get_rectified_size(kern_param, IH2, IW2, OH2, OW2); bool need_src_copy_var = need_src_copy(kern_param); bool need_dst_copy_var = need_dst_copy(kern_param); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) Op op = Op(1.0f, 4.0f); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; //! If large group, each thread has its own worspace, set group_id with //! 
thread_id const int8_t* sptr = kern_param.src(batch_id, group_id); - const int8_t* fptr = - kern_param.filter(group_id) + oc * FH * FW * IC; + const int8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; void* dst = kern_param.dst(batch_id, group_id, oc); const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); if (need_src_copy_var) { @@ -251,13 +241,14 @@ void direct_int8_stride2::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -275,23 +266,20 @@ SmallVector direct_int8_stride2::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -338,21 +326,21 @@ SmallVector direct_int8_stride2::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/int8/stride2.h b/dnn/src/arm_common/conv_bias/int8/stride2.h index cf566e47..b3bb1758 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2.h +++ b/dnn/src/arm_common/conv_bias/int8/stride2.h @@ -27,18 +27,16 @@ bool can_conv_direct_stride2_int8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const 
WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_int8_stride2 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp index ebe541b1..f51d2c8f 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp @@ -22,20 +22,18 @@ using namespace arm_common; using namespace direct_dotprod_int8_stride2; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; size_t SW = fm.stride[1]; size_t IH = param.isz[0]; @@ -73,19 +71,20 @@ bool direct_dotprod_int8_stride2::can_conv_direct_stride2_int8( param.filter_type.enumv() == DTypeEnum::Int8 && param.dst_type.enumv() == DTypeEnum::Int32)) && fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); - bool preferred = (((FH == 2 || FH == 3) && - (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || - (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || - (FH == 7 && OC <= 16)) && - (param.bias_mode != BiasMode::BIAS); + bool preferred = + (((FH == 2 || FH == 3) && (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); if (param.bias_type.valid()) { - avaible &= ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && - param.bias_type.enumv() == DTypeEnum::QuantizedS32) || - (param.bias_type.enumv() == param.dst_type.enumv())); + avaible &= + ((param.src_type.enumv() == DTypeEnum::QuantizedS8 && + param.bias_type.enumv() == DTypeEnum::QuantizedS32) || + (param.bias_type.enumv() == param.dst_type.enumv())); } return avaible && preferred; } @@ -100,9 +99,8 @@ WorkspaceBundle 
direct_dotprod_int8_stride2::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(int8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(int8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -116,10 +114,8 @@ WorkspaceBundle direct_dotprod_int8_stride2::get_bundle( } //! Process one input channel copy padding void direct_dotprod_int8_stride2::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -133,10 +129,9 @@ void direct_dotprod_int8_stride2::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! copy to sptr_base to eliminate padding effect @@ -146,8 +141,9 @@ void direct_dotprod_int8_stride2::copy_padding_kern( channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(int8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); } } }; @@ -166,27 +162,23 @@ void direct_dotprod_int8_stride2::do_conv_kern( get_rectified_size(kern_param, IH2, IW2, OH2, OW2); bool need_src_copy_var = need_src_copy(kern_param); bool need_dst_copy_var = need_dst_copy(kern_param); - bool need_post_process = - kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; + bool need_post_process = kern_param.dst_type.enumv() == DTypeEnum::QuantizedS8; //! if dst_type is qint32, the op is not used, just fill with (1.0f,4.0f) Op op = Op(1.0f, 4.0f); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; + float scale_bias = kern_param.bias_type.param().scale; float scale_dst = kern_param.dst_type.param().scale; op = Op(scale_bias, scale_dst); } size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; //! If large group, each thread has its own worspace, set group_id //! 
with thread_id const int8_t* sptr = kern_param.src(batch_id, group_id); - const int8_t* fptr = - kern_param.filter(group_id) + oc * FH * FW * IC; + const int8_t* fptr = kern_param.filter(group_id) + oc * FH * FW * IC; void* dst = kern_param.dst(batch_id, group_id, oc); const int32_t* bptr = kern_param.bias(batch_id, group_id, oc); @@ -251,13 +243,14 @@ void direct_dotprod_int8_stride2::do_conv_kern( #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -275,23 +268,20 @@ SmallVector direct_dotprod_int8_stride2::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -338,21 +328,21 @@ SmallVector direct_dotprod_int8_stride2::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h index 8cb2cd70..d14cd4db 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h +++ b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h @@ -28,18 +28,16 @@ bool can_conv_direct_stride2_int8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const 
WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_dotprod_int8_stride2 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp index 39978f84..a8c8ea09 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.cpp @@ -19,7 +19,6 @@ MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_kimpl) - using namespace megdnn; using namespace arm_common; @@ -30,8 +29,7 @@ bool need_dst_copy_str1( return true; return false; } -bool need_src_copy_str1( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy_str1(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { auto&& fm = param.filter_meta; if (fm.padding[0] != 0 || fm.padding[1] != 0) @@ -39,23 +37,21 @@ bool need_src_copy_str1( return need_dst_copy_str1(param); } -void get_rectified_size_str1(size_t IH, size_t IW, size_t OH, size_t OW, - size_t PH, size_t PW, size_t& IH2, size_t& IW2, - size_t& OH2, size_t& OW2) { +void get_rectified_size_str1( + size_t IH, size_t IW, size_t OH, size_t OW, size_t PH, size_t PW, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { OH2 = OH; OW2 = (OW + 7) & ~7; IH2 = OH2 + (IH - OH) + 2 * PH; IW2 = OW2 + (IW - OW) + 2 * PW; } -bool need_dst_copy_str2( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy_str2(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { // If the size of output is not multiples of 8, we need to copy it. if (param.osz[0] % 8 != 0 || param.osz[1] % 8 != 0) return true; return false; } -bool need_src_copy_str2( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy_str2(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { auto&& fm = param.filter_meta; // If padding is not zero, we need to copy to eliminate padding effect. 
if (fm.padding[0] != 0 || fm.padding[1] != 0) @@ -63,10 +59,9 @@ bool need_src_copy_str2( return need_dst_copy_str2(param); } -void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, - size_t FH, size_t FW, size_t PH, size_t PW, - size_t& IH2, size_t& IW2, size_t& OH2, - size_t& OW2) { +void get_rectified_size_str2( + size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, + size_t PW, size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { MEGDNN_MARK_USED_VAR(PH); MEGDNN_MARK_USED_VAR(PW); OH2 = (OH + 7) & ~7; @@ -81,10 +76,11 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW, } // namespace /* ===================== direct algo ===================== */ -bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Direct::usable"_hash)) { +bool ConvBiasImpl::AlgoI8x8x16Direct::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Direct::usable"_hash)) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.bias_mode == BiasMode::NO_BIAS && @@ -92,10 +88,10 @@ bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param, fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && param.src_type.enumv() == DTypeEnum::Int8 && param.filter_type.enumv() == DTypeEnum::Int8 && - param.dst_type.enumv() == DTypeEnum::Int16 && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); + param.dst_type.enumv() == DTypeEnum::Int16 && fm.spatial_ndim == 2 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.stride[0] == 1 && + fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5); } MIDOUT_END(); return false; @@ -123,8 +119,9 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle( } size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Direct::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Direct::get_workspace"_hash)) { auto bundle = get_bundle(param); return bundle.total_size_in_bytes(); } @@ -133,10 +130,8 @@ size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace( } //! Process one input channel copy padding void ConvBiasImpl::AlgoI8x8x16Direct::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -151,11 +146,9 @@ void ConvBiasImpl::AlgoI8x8x16Direct::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! copy to sptr_base to eliminate padding effect @@ -165,8 +158,9 @@ void ConvBiasImpl::AlgoI8x8x16Direct::copy_padding_kern( channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(int8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); } } }; @@ -190,34 +184,29 @@ void ConvBiasImpl::AlgoI8x8x16Direct::do_conv_kern( bool need_dst_copy_var = need_dst_copy_str1(kern_param); size_t padding_group_size = IH2 * IW2 * IC; //! Choose the compute kernel - using Func = - std::function; + using Func = std::function; Func fun_not_add_to_dst = nullptr, fun_add_to_dst = nullptr; if (FH == 2) { - fun_not_add_to_dst = - conv_bias::conv_direct_2x2_sc_int8_int8_int16; + fun_not_add_to_dst = conv_bias::conv_direct_2x2_sc_int8_int8_int16; fun_add_to_dst = conv_bias::conv_direct_2x2_sc_int8_int8_int16; } else if (FH == 3) { - fun_not_add_to_dst = - conv_bias::conv_direct_3x3_sc_int8_int8_int16; + fun_not_add_to_dst = conv_bias::conv_direct_3x3_sc_int8_int8_int16; fun_add_to_dst = conv_bias::conv_direct_3x3_sc_int8_int8_int16; } else if (FH == 5) { - fun_not_add_to_dst = - conv_bias::conv_direct_5x5_sc_int8_int8_int16; + fun_not_add_to_dst = conv_bias::conv_direct_5x5_sc_int8_int8_int16; fun_add_to_dst = conv_bias::conv_direct_5x5_sc_int8_int8_int16; } //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const int8_t* sptr = kern_param.src(batch_id, group_id); - const int8_t* filter = - kern_param.filter(group_id) + oc * FH * FW * IC; + const int8_t* filter = kern_param.filter(group_id) + oc * FH * FW * IC; int16_t* dst = kern_param.dst(batch_id, group_id, oc); if (need_src_copy_var) { sptr = static_cast(bundle.get(0)) + @@ -226,15 +215,15 @@ void ConvBiasImpl::AlgoI8x8x16Direct::do_conv_kern( } int16_t* dptr = nullptr; if (need_dst_copy_var) { - dptr = static_cast(bundle.get(1)) + - ncb_index.thread_id * OH2 * OW2; + dptr = static_cast(bundle.get(1)) + ncb_index.thread_id * OH2 * OW2; } else { dptr = dst; } fun_not_add_to_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0); for (size_t ic = 1; ic < IC; ++ic) { - fun_add_to_dst(sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, - IW2, OH2, OW2, 0, 0); + fun_add_to_dst( + sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, IW2, OH2, OW2, + 0, 0); } if (need_dst_copy_var) { rep(oh, OH) { @@ -253,32 +242,34 @@ SmallVector ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( WorkspaceBundle bundle = get_bundle(param); SmallVector ret_kerns; if (large_group) { - auto exec_one_group = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto exec_one_group = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { auto fm = kern_param.filter_meta; size_t IC = fm.icpg; size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); - auto do_conv = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_conv = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); do_conv_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; @@ -286,11 +277,11 @@ SmallVector ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls( } return ret_kerns; } -SmallVector -ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Direct::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + 
megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Direct::dispatch_kerns"_hash)) { return get_kimpls(param); } MIDOUT_END(); @@ -298,10 +289,11 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns( } /* ===================== stride-2 algo ===================== */ -bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Stride2::usable"_hash)) { +bool ConvBiasImpl::AlgoI8x8x16Stride2::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Stride2::usable"_hash)) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; return param.bias_mode == BiasMode::NO_BIAS && @@ -309,10 +301,9 @@ bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param, fm.format == param::ConvBias::Format::NCHW && !fm.should_flip && param.src_type.enumv() == DTypeEnum::Int8 && param.filter_type.enumv() == DTypeEnum::Int8 && - param.dst_type.enumv() == DTypeEnum::Int16 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && - (FH == 2 || FH == 3 || FH == 5); + param.dst_type.enumv() == DTypeEnum::Int16 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5); } MIDOUT_END(); return false; @@ -341,8 +332,9 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle( } size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Stride2::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Stride2::get_workspace"_hash)) { auto bundle = get_bundle(param); return bundle.total_size_in_bytes(); } @@ -351,10 +343,8 @@ size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace( } //! 
Process one input channel copy padding void ConvBiasImpl::AlgoI8x8x16Stride2::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -362,19 +352,16 @@ void ConvBiasImpl::AlgoI8x8x16Stride2::copy_padding_kern( size_t OW = kern_param.osz[1]; size_t PH = kern_param.filter_meta.padding[0]; size_t PW = kern_param.filter_meta.padding[1]; - auto FH = kern_param.filter_meta.spatial[0], - FW = kern_param.filter_meta.spatial[1]; + auto FH = kern_param.filter_meta.spatial[0], FW = kern_param.filter_meta.spatial[1]; size_t GROUP = kern_param.filter_meta.group; size_t IH2, IW2, OH2, OW2; get_rectified_size_str2(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OH2, OW2); bool need_src_copy_var = need_src_copy_str2(kern_param); size_t padding_group_size = IH2 * IW2 * IC; - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const int8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! copy to sptr_base to eliminate padding effect @@ -384,8 +371,9 @@ void ConvBiasImpl::AlgoI8x8x16Stride2::copy_padding_kern( channel_id * IH2 * IW2; std::memset(sptr_base, 0, sizeof(int8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(int8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(int8_t) * IW); } } }; @@ -409,32 +397,27 @@ void ConvBiasImpl::AlgoI8x8x16Stride2::do_conv_kern( bool need_dst_copy_var = need_dst_copy_str2(kern_param); size_t padding_group_size = IH2 * IW2 * IC; //! Choose the compute kernel - using Func = - std::function; + using Func = std::function; Func fun_not_add_to_dst = nullptr, fun_add_to_dst = nullptr; if (FH == 2) { - fun_not_add_to_dst = - conv_bias::conv_stride2_2x2_sc_int8_int8_int16; + fun_not_add_to_dst = conv_bias::conv_stride2_2x2_sc_int8_int8_int16; fun_add_to_dst = conv_bias::conv_stride2_2x2_sc_int8_int8_int16; } else if (FH == 3) { - fun_not_add_to_dst = - conv_bias::conv_stride2_3x3_sc_int8_int8_int16; + fun_not_add_to_dst = conv_bias::conv_stride2_3x3_sc_int8_int8_int16; fun_add_to_dst = conv_bias::conv_stride2_3x3_sc_int8_int8_int16; } else if (FH == 5) { - fun_not_add_to_dst = - conv_bias::conv_stride2_5x5_sc_int8_int8_int16; + fun_not_add_to_dst = conv_bias::conv_stride2_5x5_sc_int8_int8_int16; fun_add_to_dst = conv_bias::conv_stride2_5x5_sc_int8_int8_int16; } //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const int8_t* sptr = kern_param.src(batch_id, group_id); - const int8_t* filter = - kern_param.filter(group_id) + oc * FH * FW * IC; + const int8_t* filter = kern_param.filter(group_id) + oc * FH * FW * IC; int16_t* dst = kern_param.dst(batch_id, group_id, oc); if (need_src_copy_var) { sptr = static_cast(bundle.get(0)) + @@ -443,15 +426,15 @@ void ConvBiasImpl::AlgoI8x8x16Stride2::do_conv_kern( } int16_t* dptr = nullptr; if (need_dst_copy_var) { - dptr = static_cast(bundle.get(1)) + - ncb_index.thread_id * OH2 * OW2; + dptr = static_cast(bundle.get(1)) + ncb_index.thread_id * OH2 * OW2; } else { dptr = dst; } fun_not_add_to_dst(sptr, filter, dptr, IH2, IW2, OH2, OW2, 0, 0); for (size_t ic = 1; ic < IC; ++ic) { - fun_add_to_dst(sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, - IW2, OH2, OW2, 0, 0); + fun_add_to_dst( + sptr + ic * IH2 * IW2, filter + ic * FH * FW, dptr, IH2, IW2, OH2, OW2, + 0, 0); } if (need_dst_copy_var) { rep(oh, OH) { @@ -470,32 +453,34 @@ SmallVector ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( WorkspaceBundle bundle = get_bundle(param); SmallVector ret_kerns; if (large_group) { - auto exec_one_group = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto exec_one_group = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { auto fm = kern_param.filter_meta; size_t IC = fm.icpg; size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); - auto do_conv = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_conv = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); do_conv_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; @@ -503,11 +488,11 @@ SmallVector ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls( } return ret_kerns; } -SmallVector -ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Stride2::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + 
megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Stride2::dispatch_kerns"_hash)) { return get_kimpls(param); } MIDOUT_END(); @@ -516,11 +501,11 @@ ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns( bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy /*algo_selection_strategy*/) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Stride2Filter2::usable"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Stride2Filter2::usable"_hash)) { return param.bias_mode == BiasMode::NO_BIAS && - param.nonlineMode == NonlineMode::IDENTITY && - param.nr_threads == 1_z && + param.nonlineMode == NonlineMode::IDENTITY && param.nr_threads == 1_z && conv_bias::can_conv_int8x8x16_stride2_flt2(param); } MIDOUT_END(); @@ -529,22 +514,22 @@ bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable( size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Stride2Filter2::get_workspace"_hash)) { - return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2( - param); + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Stride2Filter2::get_workspace"_hash)) { + return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2(param); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::AlgoI8x8x16Stride2Filter2:: + dispatch_kerns(const NCBKernSizeParam& param) const { // return {conv_bias::conv_int8x8x16_stride2_flt2,true}; auto kern = [](const NCBKernParam& param, const NCBKernIndex& ncb_index) { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_int8816_kimpl, + midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) { auto ncb_param = param; ncb_param.src_ptr = param.src(0, ncb_index.ndrange_id[0]); ncb_param.dst_ptr = param.dst(0, ncb_index.ndrange_id[0]); @@ -558,7 +543,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns( return {{kern, {group, 1_z, 1_z}}}; } -/* =====================8int8x8x16 channel_wise_nchw44 stride1 stride2 algo ===================== */ +/* =====================8int8x8x16 channel_wise_nchw44 stride1 stride2 algo + * ===================== */ bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { auto&& fm = param.filter_meta; @@ -570,13 +556,12 @@ bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable( param.dst_type.enumv() == DTypeEnum::Int16) && fm.format == param::Convolution::Format::NCHW44 && param.bias_mode != megdnn::BiasMode::BIAS && - param.nonlineMode == megdnn::NonlineMode::IDENTITY && - !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && + param.nonlineMode == megdnn::NonlineMode::IDENTITY && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && (fm.stride[0] == fm.stride[1] && (fm.stride[0] == 1 || fm.stride[0] == 2)) && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) && - fm.icpg == 1 && fm.ocpg == 1 && fm.group % 4 == 0; + FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) && fm.icpg == 1 && + fm.ocpg == 1 && fm.group % 4 == 0; return avaible; } 
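
Note on the hunks above: the reformatting does not touch the size-rectification logic these direct kernels share. Working sizes are rounded up to a multiple of 8 with the `(x + 7) & ~7` idiom in the `get_rectified_size_*` helpers, and the `need_dst_copy_*` / `need_src_copy_*` predicates force a copy whenever an output size is not a multiple of 8 or padding is non-zero. The following is a minimal standalone sketch of that rounding idiom only; `round_up8` and the `main` harness are illustrative names introduced here, not megdnn API — only the bit trick and the `% 8` condition are taken from the code above.

    #include <cassert>
    #include <cstddef>

    // Round x up to the next multiple of 8 -- the (x + 7) & ~7 idiom used by
    // the get_rectified_size_* helpers in the hunks above. The helper name is
    // hypothetical; megdnn writes the expression inline.
    static inline std::size_t round_up8(std::size_t x) {
        return (x + 7) & ~std::size_t(7);
    }

    int main() {
        assert(round_up8(56) == 56);  // already a multiple of 8: no dst copy needed
        assert(round_up8(57) == 64);  // 57 % 8 != 0: the kernel computes into a
                                      // 64-wide temporary row instead
        assert(round_up8(1) == 8);
        // When the rectified width OW2 differs from OW, the need_dst_copy branch
        // in do_conv_kern copies OW elements per output row back from the
        // temporary buffer, as the std::memcpy loops above do.
        return 0;
    }

Rounding the rectified sizes up lets the NEON inner loops always operate on full 8-element vectors, at the cost of the per-row copy back into the caller's destination whenever the rectified and real output widths differ.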
@@ -584,8 +569,7 @@ size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace( const NCBKernSizeParam& param) const { MIDOUT_BEGIN( megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv( - "AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace"_hash)) { + midout_iv("AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace"_hash)) { size_t stride_h = param.filter_meta.stride[0]; size_t stride_w = param.filter_meta.stride[1]; megdnn_assert(stride_h == stride_w); @@ -603,9 +587,9 @@ size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl:: + AlgoS8x8x16ChanWiseStride1Stride2NCHW44::dispatch_kerns( + const NCBKernSizeParam& param) const { size_t stride_h = param.filter_meta.stride[0]; size_t stride_w = param.filter_meta.stride[1]; if (stride_h == stride_w && stride_h == 1) { @@ -620,8 +604,7 @@ ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::dispatch_kerns( } else if (stride_h == stride_w && stride_h == 2) { MIDOUT_BEGIN( megdnn_arm_common_conv_bias_int8816_kimpl, - midout_iv( - "AlgoS8x8x16ChanWiseStride2NCHW44_dispatch_kerns"_hash)) { + midout_iv("AlgoS8x8x16ChanWiseStride2NCHW44_dispatch_kerns"_hash)) { return channel_wise_nchw44_8x8x16::stride2::get_kimpls(param); } MIDOUT_END(); diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h index 1ebc2083..c0a7127b 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h @@ -19,22 +19,19 @@ namespace arm_common { class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; - static void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); - static void do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); + static void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); + static void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "I8816DIRECT"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -47,12 +44,11 @@ public: class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { public: AlgoS8x8x16DirectNCHW44() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; } - bool usable(const 
NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; @@ -65,22 +61,19 @@ public: class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; - static void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); - static void do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); + static void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); + static void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "I8816STRD2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -93,13 +86,12 @@ public: class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "I8816STRD2F2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -110,17 +102,14 @@ public: MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_F2_INT8X8X16) }; -class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final - : public AlgoBase { +class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; - size_t get_workspace( - const NCBKernSizeParam& param) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; ConvAlgoTypePack get_algo_type() const override { @@ -134,12 +123,11 @@ class 
ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { public: AlgoI8x8x16DirectNCHWNCHW44() {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "I8816_CONV_NCHW_NCHW44"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h index 09c299cd..8cdf84e6 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel.h @@ -16,12 +16,11 @@ namespace megdnn { namespace arm_common { namespace channel_wise_nchw44_8x8x16 { -#define KERN(stride, i) \ - template \ - void direct_##stride##_##i##x##i##_int8x8x16( \ - const int8_t* src, const int8_t* filter, const int16_t* bias, \ - void* dst, const size_t IH, const size_t IW, const size_t OH, \ - const size_t OW); +#define KERN(stride, i) \ + template \ + void direct_##stride##_##i##x##i##_int8x8x16( \ + const int8_t* src, const int8_t* filter, const int16_t* bias, void* dst, \ + const size_t IH, const size_t IW, const size_t OH, const size_t OW); KERN(stride1, 2) KERN(stride1, 3) diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp index 6eeb7580..65c986e9 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_kernel_int8x8x16_nchw44.cpp @@ -29,36 +29,33 @@ using namespace arm_common; init_sum = vdupq_n_s16(0); \ } -#define STORE_1_LINE_RESULT(dst, oh, ow, OW, sum) \ - do { \ - dt_int16* dptr = \ - reinterpret_cast(dst) + (oh)*OW * 4 + ow * 4; \ - vst1q_s16(dptr, sum[0]); \ - vst1q_s16(dptr + 8, sum[1]); \ - vst1q_s16(dptr + 16, sum[2]); \ - vst1q_s16(dptr + 24, sum[3]); \ +#define STORE_1_LINE_RESULT(dst, oh, ow, OW, sum) \ + do { \ + dt_int16* dptr = reinterpret_cast(dst) + (oh)*OW * 4 + ow * 4; \ + vst1q_s16(dptr, sum[0]); \ + vst1q_s16(dptr + 8, sum[1]); \ + vst1q_s16(dptr + 16, sum[2]); \ + vst1q_s16(dptr + 24, sum[3]); \ } while (0); -#define STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum) \ - do { \ - dt_int16* dptr = \ - reinterpret_cast(dst) + (oh)*OW * 4 + ow * 4; \ - vst1q_s16(dptr, sum[0]); \ - vst1q_s16(dptr + 8, sum[1]); \ +#define STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum) \ + do { \ + dt_int16* dptr = reinterpret_cast(dst) + (oh)*OW * 4 + ow * 4; \ + vst1q_s16(dptr, sum[0]); \ + vst1q_s16(dptr + 8, sum[1]); \ } while (0); -#define STORE_REMAIN(dst, oh, ow, OW, sum, remain) \ - do { \ - dt_int16* dptr = \ - reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ - if (remain == 1) { \ - vst1_s16(dptr, vget_low_s16(sum[0])); \ - } else if (remain == 2) { \ - vst1q_s16(dptr, sum[0]); \ - } else if (remain == 3) { \ - vst1q_s16(dptr, sum[0]); \ - vst1_s16(dptr + 8, vget_low_s16(sum[1])); \ - } \ +#define STORE_REMAIN(dst, oh, ow, OW, sum, remain) \ + do { \ + dt_int16* dptr = reinterpret_cast(dst) + oh * OW * 4 + ow * 4; \ + if (remain == 1) { \ + vst1_s16(dptr, vget_low_s16(sum[0])); \ 
+ } else if (remain == 2) { \ + vst1q_s16(dptr, sum[0]); \ + } else if (remain == 3) { \ + vst1q_s16(dptr, sum[0]); \ + vst1_s16(dptr + 8, vget_low_s16(sum[1])); \ + } \ } while (0); template @@ -91,15 +88,15 @@ void channel_wise_nchw44_8x8x16::direct_stride1_2x2_int8x8x16( _src[1] = vextq_s8(_src[0], _src[2], 4); \ _src[3] = vextq_s8(_src[2], _src[3], 4); -#define CALC_ONE_LINE_8_RESULT(_sum,_src,_kid0,_kid1)\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[0]),kern[_kid0]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\ - _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[2]),kern[_kid0]);\ - _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[2]),kern[_kid0]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\ - _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[3]),kern[_kid1]);\ - _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[3]),kern[_kid1]); +#define CALC_ONE_LINE_8_RESULT(_sum, _src, _kid0, _kid1) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[2]), kern[_kid0]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[2]), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[3]), kern[_kid1]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[3]), kern[_kid1]); size_t oh = 0_z; for (; oh + 2 <= OH; oh += 2) { @@ -225,7 +222,6 @@ void channel_wise_nchw44_8x8x16::direct_stride1_2x2_int8x8x16( CALC_ONE_LINE_4_RESULT(sum, src[1], 2, 3); STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum); - } if (ow < OW) { @@ -379,33 +375,33 @@ void channel_wise_nchw44_8x8x16::direct_stride1_3x3_int8x8x16( int8x16_t src[2][3]; - LOAD_3_SRC(sptr0,src[0]); - LOAD_3_SRC(sptr1,src[1]); + LOAD_3_SRC(sptr0, src[0]); + LOAD_3_SRC(sptr1, src[1]); - CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum0);//line0 - CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum0);//line1 - CALC_ONE_LINE(src[1],kern[0],kern[1],kern[2],sum1);//line1 + CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum0); // line0 + CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum0); // line1 + CALC_ONE_LINE(src[1], kern[0], kern[1], kern[2], sum1); // line1 - LOAD_3_SRC(sptr2,src[0]);//line2 + LOAD_3_SRC(sptr2, src[0]); // line2 - CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum0);//line2 + CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum0); // line2 STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum0) - CALC_ONE_LINE(src[0],kern[3],kern[4],kern[5],sum1);//line2 - CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum2);//line2 - LOAD_3_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE(src[0], kern[3], kern[4], kern[5], sum1); // line2 + CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum2); // line2 + LOAD_3_SRC(sptr3, src[1]); // line3 - CALC_ONE_LINE(src[1],kern[6],kern[7],kern[8],sum1);//line3 - STORE_1_LINE_4_RESULT(dst, (oh+1), ow, OW, sum1) - CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum2);//line3 - LOAD_3_SRC(sptr4,src[0]); - CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum2);//line4 - STORE_1_LINE_4_RESULT(dst, (oh+2), ow, OW, sum2) + CALC_ONE_LINE(src[1], kern[6], kern[7], kern[8], sum1); // line3 + STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum1) + CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum2); // line3 + LOAD_3_SRC(sptr4, src[0]); + CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum2); // line4 + 
STORE_1_LINE_4_RESULT(dst, (oh + 2), ow, OW, sum2) } if (ow < OW) { size_t iw = ow; size_t remain = OW - ow; - + const int8_t* __restrict sptr0 = sptr + ih * IW * 4 + iw * 4; const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW * 4 + iw * 4; const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW * 4 + iw * 4; @@ -420,27 +416,27 @@ void channel_wise_nchw44_8x8x16::direct_stride1_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(2, cb); #undef cb - LOAD_3_SRC(sptr0,src[0]);//line2 - LOAD_3_SRC(sptr1,src[1]);//line2 + LOAD_3_SRC(sptr0, src[0]); // line2 + LOAD_3_SRC(sptr1, src[1]); // line2 CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum0); // line0 - CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum0);//line1 - CALC_ONE_LINE(src[1],kern[0],kern[1],kern[2],sum1);//line1 + CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum0); // line1 + CALC_ONE_LINE(src[1], kern[0], kern[1], kern[2], sum1); // line1 - LOAD_3_SRC(sptr2,src[0]);//line2 + LOAD_3_SRC(sptr2, src[0]); // line2 - CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum0);//line2 - STORE_REMAIN(dst, (oh+0), ow, OW, sum0,remain) + CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum0); // line2 + STORE_REMAIN(dst, (oh + 0), ow, OW, sum0, remain) - CALC_ONE_LINE(src[0],kern[3],kern[4],kern[5],sum1);//line2 - CALC_ONE_LINE(src[0],kern[0],kern[1],kern[2],sum2);//line2 - LOAD_3_SRC(sptr3,src[1]);//line3 + CALC_ONE_LINE(src[0], kern[3], kern[4], kern[5], sum1); // line2 + CALC_ONE_LINE(src[0], kern[0], kern[1], kern[2], sum2); // line2 + LOAD_3_SRC(sptr3, src[1]); // line3 - CALC_ONE_LINE(src[1],kern[6],kern[7],kern[8],sum1);//line3 - STORE_REMAIN(dst, (oh+1), ow, OW, sum1,remain) - CALC_ONE_LINE(src[1],kern[3],kern[4],kern[5],sum2);//line3 - LOAD_3_SRC(sptr4,src[0]); - CALC_ONE_LINE(src[0],kern[6],kern[7],kern[8],sum2);//line4 - STORE_REMAIN(dst, (oh+2), ow, OW, sum2, remain) + CALC_ONE_LINE(src[1], kern[6], kern[7], kern[8], sum1); // line3 + STORE_REMAIN(dst, (oh + 1), ow, OW, sum1, remain) + CALC_ONE_LINE(src[1], kern[3], kern[4], kern[5], sum2); // line3 + LOAD_3_SRC(sptr4, src[0]); + CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum2); // line4 + STORE_REMAIN(dst, (oh + 2), ow, OW, sum2, remain) } } for (; oh < OH; oh++) { @@ -521,7 +517,7 @@ void channel_wise_nchw44_8x8x16::direct_stride1_3x3_int8x8x16( LOAD_3_SRC(sptr2, src[0]); // line2 CALC_ONE_LINE(src[0], kern[6], kern[7], kern[8], sum00); // line2 - STORE_REMAIN(dst, oh, ow, OW, sum00,(OW-ow)) + STORE_REMAIN(dst, oh, ow, OW, sum00, (OW - ow)) } } #undef LOAD_3_SRC @@ -560,40 +556,39 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( src[6] = vextq_s8(src[4], src[8], 8); \ src[7] = vextq_s8(src[4], src[8], 12); +#define CALC_ONE_LINE_4_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, _kid4) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]); -#define CALC_ONE_LINE_4_RESULT(_sum,_src,_kid0,_kid1,_kid2,_kid3,_kid4)\ - _sum[0]=vmlal_s8(_sum[0], 
vget_low_s8(_src[0]),kern[_kid0]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[2]),kern[_kid2]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[2]),kern[_kid2]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[3]),kern[_kid3]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[3]),kern[_kid3]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[4]),kern[_kid4]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[4]),kern[_kid4]); - -#define CALC_ONE_LINE_8_RESULT(_sum,_src,_kid0,_kid1,_kid2,_kid3,_kid4)\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[0]),kern[_kid0]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[0]),kern[_kid0]);\ - _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[4]),kern[_kid0]);\ - _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[4]),kern[_kid0]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[1]),kern[_kid1]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[1]),kern[_kid1]);\ - _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[5]),kern[_kid1]);\ - _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[5]),kern[_kid1]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[2]),kern[_kid2]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[2]),kern[_kid2]);\ - _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[6]),kern[_kid2]);\ - _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[6]),kern[_kid2]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[3]),kern[_kid3]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[3]),kern[_kid3]);\ - _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[7]),kern[_kid3]);\ - _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[7]),kern[_kid3]);\ - _sum[0]=vmlal_s8(_sum[0], vget_low_s8(_src[4]),kern[_kid4]);\ - _sum[1]=vmlal_s8(_sum[1],vget_high_s8(_src[4]),kern[_kid4]);\ - _sum[2]=vmlal_s8(_sum[2], vget_low_s8(_src[8]),kern[_kid4]);\ - _sum[3]=vmlal_s8(_sum[3],vget_high_s8(_src[8]),kern[_kid4]); +#define CALC_ONE_LINE_8_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, _kid4) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[4]), kern[_kid0]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[4]), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[5]), kern[_kid1]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[5]), kern[_kid1]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[6]), kern[_kid2]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[6]), kern[_kid2]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[7]), kern[_kid3]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[7]), kern[_kid3]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[8]), kern[_kid4]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[8]), kern[_kid4]); size_t oh = 0_z; for (; oh + 2 <= OH; oh += 2) { @@ -617,27 +612,27 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( 
UNROLL_CALL_NOWRAPPER(4, cb); #undef cb - - LOAD_1_LINE_10_SRC(sptr0,src[0]); - LOAD_1_LINE_10_SRC(sptr1,src[1]); - - CALC_ONE_LINE_8_RESULT(sum[0],src[0],0,1,2,3,4); - LOAD_1_LINE_10_SRC(sptr2,src[0]);//line2 - CALC_ONE_LINE_8_RESULT(sum[0],src[1],5,6,7,8,9);//line1 - CALC_ONE_LINE_8_RESULT(sum[1],src[1],0,1,2,3,4);//line1 - LOAD_1_LINE_10_SRC(sptr3,src[1]);//line3 - CALC_ONE_LINE_8_RESULT(sum[0],src[0],10,11,12,13,14);//line2 - CALC_ONE_LINE_8_RESULT(sum[1],src[0],5,6,7,8,9);//line2 - LOAD_1_LINE_10_SRC(sptr4,src[0]);//line4 - CALC_ONE_LINE_8_RESULT(sum[0],src[1],15,16,17,18,19);//line3 - CALC_ONE_LINE_8_RESULT(sum[1],src[1],10,11,12,13,14);//line3 - LOAD_1_LINE_10_SRC(sptr5,src[1]);//line5 - CALC_ONE_LINE_8_RESULT(sum[0],src[0],20,21,22,23,24);//line4 - CALC_ONE_LINE_8_RESULT(sum[1],src[0],15,16,17,18,19);//line3 - CALC_ONE_LINE_8_RESULT(sum[1],src[1],20,21,22,23,24);//line3 - - STORE_1_LINE_RESULT(dst,oh,ow,OW,sum[0]); - STORE_1_LINE_RESULT(dst,(oh+1),ow,OW,sum[1]); + + LOAD_1_LINE_10_SRC(sptr0, src[0]); + LOAD_1_LINE_10_SRC(sptr1, src[1]); + + CALC_ONE_LINE_8_RESULT(sum[0], src[0], 0, 1, 2, 3, 4); + LOAD_1_LINE_10_SRC(sptr2, src[0]); // line2 + CALC_ONE_LINE_8_RESULT(sum[0], src[1], 5, 6, 7, 8, 9); // line1 + CALC_ONE_LINE_8_RESULT(sum[1], src[1], 0, 1, 2, 3, 4); // line1 + LOAD_1_LINE_10_SRC(sptr3, src[1]); // line3 + CALC_ONE_LINE_8_RESULT(sum[0], src[0], 10, 11, 12, 13, 14); // line2 + CALC_ONE_LINE_8_RESULT(sum[1], src[0], 5, 6, 7, 8, 9); // line2 + LOAD_1_LINE_10_SRC(sptr4, src[0]); // line4 + CALC_ONE_LINE_8_RESULT(sum[0], src[1], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_8_RESULT(sum[1], src[1], 10, 11, 12, 13, 14); // line3 + LOAD_1_LINE_10_SRC(sptr5, src[1]); // line5 + CALC_ONE_LINE_8_RESULT(sum[0], src[0], 20, 21, 22, 23, 24); // line4 + CALC_ONE_LINE_8_RESULT(sum[1], src[0], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_8_RESULT(sum[1], src[1], 20, 21, 22, 23, 24); // line3 + + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]); + STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]); } #endif for (; ow + 4 <= OW; ow += 4) { @@ -658,27 +653,26 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( UNROLL_CALL_NOWRAPPER(2, cb); #undef cb - - LOAD_1_LINE_SRC(sptr0,src[0]); - LOAD_1_LINE_SRC(sptr1,src[1]); - - CALC_ONE_LINE_4_RESULT(sum[0],src[0],0,1,2,3,4); - LOAD_1_LINE_SRC(sptr2,src[0]);//line2 - CALC_ONE_LINE_4_RESULT(sum[0],src[1],5,6,7,8,9);//line1 - CALC_ONE_LINE_4_RESULT(sum[1],src[1],0,1,2,3,4);//line1 - LOAD_1_LINE_SRC(sptr3,src[1]);//line3 - CALC_ONE_LINE_4_RESULT(sum[0],src[0],10,11,12,13,14);//line2 - CALC_ONE_LINE_4_RESULT(sum[1],src[0],5,6,7,8,9);//line2 - LOAD_1_LINE_SRC(sptr4,src[0]);//line4 - CALC_ONE_LINE_4_RESULT(sum[0],src[1],15,16,17,18,19);//line3 - CALC_ONE_LINE_4_RESULT(sum[1],src[1],10,11,12,13,14);//line3 - LOAD_1_LINE_SRC(sptr5,src[1]);//line5 - CALC_ONE_LINE_4_RESULT(sum[0],src[0],20,21,22,23,24);//line4 - CALC_ONE_LINE_4_RESULT(sum[1],src[0],15,16,17,18,19);//line3 - CALC_ONE_LINE_4_RESULT(sum[1],src[1],20,21,22,23,24);//line3 - - STORE_1_LINE_4_RESULT(dst,oh,ow,OW,sum[0]); - STORE_1_LINE_4_RESULT(dst,(oh+1),ow,OW,sum[1]); + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); + + CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1, 2, 3, 4); + LOAD_1_LINE_SRC(sptr2, src[0]); // line2 + CALC_ONE_LINE_4_RESULT(sum[0], src[1], 5, 6, 7, 8, 9); // line1 + CALC_ONE_LINE_4_RESULT(sum[1], src[1], 0, 1, 2, 3, 4); // line1 + LOAD_1_LINE_SRC(sptr3, src[1]); // line3 + CALC_ONE_LINE_4_RESULT(sum[0], src[0], 10, 11, 12, 13, 14); // line2 + 
CALC_ONE_LINE_4_RESULT(sum[1], src[0], 5, 6, 7, 8, 9); // line2 + LOAD_1_LINE_SRC(sptr4, src[0]); // line4 + CALC_ONE_LINE_4_RESULT(sum[0], src[1], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_4_RESULT(sum[1], src[1], 10, 11, 12, 13, 14); // line3 + LOAD_1_LINE_SRC(sptr5, src[1]); // line5 + CALC_ONE_LINE_4_RESULT(sum[0], src[0], 20, 21, 22, 23, 24); // line4 + CALC_ONE_LINE_4_RESULT(sum[1], src[0], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_4_RESULT(sum[1], src[1], 20, 21, 22, 23, 24); // line3 + + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]); + STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]); } if (ow < OW) { size_t remain = OW - ow; @@ -689,7 +683,7 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( const int8_t* __restrict sptr3 = sptr2 + IW * 4; const int8_t* __restrict sptr4 = sptr3 + IW * 4; const int8_t* __restrict sptr5 = sptr4 + IW * 4; - + int16x8_t sum[2][2]; int8x16_t src[2][5]; #define cb(j) \ @@ -698,26 +692,26 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( UNROLL_CALL_NOWRAPPER(2, cb); #undef cb - LOAD_1_LINE_SRC(sptr0,src[0]); - LOAD_1_LINE_SRC(sptr1,src[1]); - - CALC_ONE_LINE_4_RESULT(sum[0],src[0],0,1,2,3,4); - LOAD_1_LINE_SRC(sptr2,src[0]);//line2 - CALC_ONE_LINE_4_RESULT(sum[0],src[1],5,6,7,8,9);//line1 - CALC_ONE_LINE_4_RESULT(sum[1],src[1],0,1,2,3,4);//line1 - LOAD_1_LINE_SRC(sptr3,src[1]);//line3 - CALC_ONE_LINE_4_RESULT(sum[0],src[0],10,11,12,13,14);//line2 - CALC_ONE_LINE_4_RESULT(sum[1],src[0],5,6,7,8,9);//line2 - LOAD_1_LINE_SRC(sptr4,src[0]);//line4 - CALC_ONE_LINE_4_RESULT(sum[0],src[1],15,16,17,18,19);//line3 - CALC_ONE_LINE_4_RESULT(sum[1],src[1],10,11,12,13,14);//line3 - LOAD_1_LINE_SRC(sptr5,src[1]);//line5 - CALC_ONE_LINE_4_RESULT(sum[0],src[0],20,21,22,23,24);//line4 - CALC_ONE_LINE_4_RESULT(sum[1],src[0],15,16,17,18,19);//line3 - CALC_ONE_LINE_4_RESULT(sum[1],src[1],20,21,22,23,24);//line3 - - STORE_REMAIN(dst,oh,ow,OW,sum[0],remain); - STORE_REMAIN(dst,(oh+1),ow,OW,sum[1],remain); + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); + + CALC_ONE_LINE_4_RESULT(sum[0], src[0], 0, 1, 2, 3, 4); + LOAD_1_LINE_SRC(sptr2, src[0]); // line2 + CALC_ONE_LINE_4_RESULT(sum[0], src[1], 5, 6, 7, 8, 9); // line1 + CALC_ONE_LINE_4_RESULT(sum[1], src[1], 0, 1, 2, 3, 4); // line1 + LOAD_1_LINE_SRC(sptr3, src[1]); // line3 + CALC_ONE_LINE_4_RESULT(sum[0], src[0], 10, 11, 12, 13, 14); // line2 + CALC_ONE_LINE_4_RESULT(sum[1], src[0], 5, 6, 7, 8, 9); // line2 + LOAD_1_LINE_SRC(sptr4, src[0]); // line4 + CALC_ONE_LINE_4_RESULT(sum[0], src[1], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_4_RESULT(sum[1], src[1], 10, 11, 12, 13, 14); // line3 + LOAD_1_LINE_SRC(sptr5, src[1]); // line5 + CALC_ONE_LINE_4_RESULT(sum[0], src[0], 20, 21, 22, 23, 24); // line4 + CALC_ONE_LINE_4_RESULT(sum[1], src[0], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_4_RESULT(sum[1], src[1], 20, 21, 22, 23, 24); // line3 + + STORE_REMAIN(dst, oh, ow, OW, sum[0], remain); + STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); } } for (; oh < OH; oh++) { @@ -738,19 +732,19 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( UNROLL_CALL_NOWRAPPER(4, cb); #undef cb - LOAD_1_LINE_10_SRC(sptr0,src[0]); - LOAD_1_LINE_10_SRC(sptr1,src[1]); + LOAD_1_LINE_10_SRC(sptr0, src[0]); + LOAD_1_LINE_10_SRC(sptr1, src[1]); - CALC_ONE_LINE_8_RESULT(sum,src[0],0,1,2,3,4); - LOAD_1_LINE_10_SRC(sptr2,src[0]);//line2 - CALC_ONE_LINE_8_RESULT(sum,src[1],5,6,7,8,9);//line1 - LOAD_1_LINE_10_SRC(sptr3,src[1]);//line3 - 
CALC_ONE_LINE_8_RESULT(sum,src[0],10,11,12,13,14);//line2 - LOAD_1_LINE_10_SRC(sptr4,src[0]);//line4 - CALC_ONE_LINE_8_RESULT(sum,src[1],15,16,17,18,19);//line3 - CALC_ONE_LINE_8_RESULT(sum,src[0],20,21,22,23,24);//line4 + CALC_ONE_LINE_8_RESULT(sum, src[0], 0, 1, 2, 3, 4); + LOAD_1_LINE_10_SRC(sptr2, src[0]); // line2 + CALC_ONE_LINE_8_RESULT(sum, src[1], 5, 6, 7, 8, 9); // line1 + LOAD_1_LINE_10_SRC(sptr3, src[1]); // line3 + CALC_ONE_LINE_8_RESULT(sum, src[0], 10, 11, 12, 13, 14); // line2 + LOAD_1_LINE_10_SRC(sptr4, src[0]); // line4 + CALC_ONE_LINE_8_RESULT(sum, src[1], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_8_RESULT(sum, src[0], 20, 21, 22, 23, 24); // line4 - STORE_1_LINE_RESULT(dst,oh,ow,OW,sum); + STORE_1_LINE_RESULT(dst, oh, ow, OW, sum); } #endif for (; ow + 4 <= OW; ow += 4) { @@ -763,24 +757,23 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( int16x8_t sum[2]; int8x16_t src[2][5]; - sum[0]=init_sum; - sum[1]=init_sum; - - - LOAD_1_LINE_SRC(sptr0,src[0]); - LOAD_1_LINE_SRC(sptr1,src[1]); - - CALC_ONE_LINE_4_RESULT(sum,src[0],0,1,2,3,4); - LOAD_1_LINE_SRC(sptr2,src[0]);//line2 - CALC_ONE_LINE_4_RESULT(sum,src[1],5,6,7,8,9);//line1 - LOAD_1_LINE_SRC(sptr3,src[1]);//line3 - CALC_ONE_LINE_4_RESULT(sum,src[0],10,11,12,13,14);//line2 - LOAD_1_LINE_SRC(sptr4,src[0]);//line4 - CALC_ONE_LINE_4_RESULT(sum,src[1],15,16,17,18,19);//line3 - CALC_ONE_LINE_4_RESULT(sum,src[0],20,21,22,23,24);//line4 - - STORE_1_LINE_4_RESULT(dst,oh,ow,OW,sum); - } + sum[0] = init_sum; + sum[1] = init_sum; + + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); + + CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1, 2, 3, 4); + LOAD_1_LINE_SRC(sptr2, src[0]); // line2 + CALC_ONE_LINE_4_RESULT(sum, src[1], 5, 6, 7, 8, 9); // line1 + LOAD_1_LINE_SRC(sptr3, src[1]); // line3 + CALC_ONE_LINE_4_RESULT(sum, src[0], 10, 11, 12, 13, 14); // line2 + LOAD_1_LINE_SRC(sptr4, src[0]); // line4 + CALC_ONE_LINE_4_RESULT(sum, src[1], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_4_RESULT(sum, src[0], 20, 21, 22, 23, 24); // line4 + + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum); + } if (ow < OW) { size_t remain = OW - ow; size_t iw = ow; @@ -791,21 +784,21 @@ void channel_wise_nchw44_8x8x16::direct_stride1_5x5_int8x8x16( const int8_t* __restrict sptr4 = sptr3 + IW * 4; int16x8_t sum[2]; int8x16_t src[2][5]; - sum[0]=init_sum; - sum[1]=init_sum; - - LOAD_1_LINE_SRC(sptr0,src[0]); - LOAD_1_LINE_SRC(sptr1,src[1]); - - CALC_ONE_LINE_4_RESULT(sum,src[0],0,1,2,3,4); - LOAD_1_LINE_SRC(sptr2,src[0]);//line2 - CALC_ONE_LINE_4_RESULT(sum,src[1],5,6,7,8,9);//line1 - LOAD_1_LINE_SRC(sptr3,src[1]);//line3 - CALC_ONE_LINE_4_RESULT(sum,src[0],10,11,12,13,14);//line2 - LOAD_1_LINE_SRC(sptr4,src[0]);//line4 - CALC_ONE_LINE_4_RESULT(sum,src[1],15,16,17,18,19);//line3 - CALC_ONE_LINE_4_RESULT(sum,src[0],20,21,22,23,24);//line4 - STORE_REMAIN(dst,oh,ow,OW,sum,remain); + sum[0] = init_sum; + sum[1] = init_sum; + + LOAD_1_LINE_SRC(sptr0, src[0]); + LOAD_1_LINE_SRC(sptr1, src[1]); + + CALC_ONE_LINE_4_RESULT(sum, src[0], 0, 1, 2, 3, 4); + LOAD_1_LINE_SRC(sptr2, src[0]); // line2 + CALC_ONE_LINE_4_RESULT(sum, src[1], 5, 6, 7, 8, 9); // line1 + LOAD_1_LINE_SRC(sptr3, src[1]); // line3 + CALC_ONE_LINE_4_RESULT(sum, src[0], 10, 11, 12, 13, 14); // line2 + LOAD_1_LINE_SRC(sptr4, src[0]); // line4 + CALC_ONE_LINE_4_RESULT(sum, src[1], 15, 16, 17, 18, 19); // line3 + CALC_ONE_LINE_4_RESULT(sum, src[0], 20, 21, 22, 23, 24); // line4 + STORE_REMAIN(dst, oh, ow, OW, sum, remain); } } #undef LOAD_1_LINE_SRC @@ -821,13 +814,13 @@ void 
channel_wise_nchw44_8x8x16::direct_stride2_2x2_int8x8x16( MEGDNN_MARK_USED_VAR(IH); const int16_t* __restrict bptr = bias; INIT_SUM(); -const int* fptr = reinterpret_cast(filter); + const int* fptr = reinterpret_cast(filter); int8x8_t kern[4]; #define cb(i) kern[i] = vreinterpret_s8_s32(vld1_dup_s32(fptr + i)); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define CALC_ONE_LINE_8_RESULT(_sum, _rowid, _kid0, _kid1) \ +#define CALC_ONE_LINE_8_RESULT(_sum, _rowid, _kid0, _kid1) \ _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \ _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \ _sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##2), kern[_kid0]); \ @@ -837,11 +830,11 @@ const int* fptr = reinterpret_cast(filter); _sum[2] = vmlal_s8(_sum[2], vget_low_s8(row##_rowid##3), kern[_kid1]); \ _sum[3] = vmlal_s8(_sum[3], vget_high_s8(row##_rowid##3), kern[_kid1]); -#define CALC_ONE_LINE_4_RESULT(_sum, _rowid, _kid0, _kid1) \ +#define CALC_ONE_LINE_4_RESULT(_sum, _rowid, _kid0, _kid1) \ _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##0), kern[_kid0]); \ _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##0), kern[_kid0]); \ _sum[0] = vmlal_s8(_sum[0], vget_low_s8(row##_rowid##1), kern[_kid1]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(row##_rowid##1), kern[_kid1]); size_t oh = 0_z; for (; oh + 2 <= OH; oh += 2) { @@ -861,32 +854,31 @@ const int* fptr = reinterpret_cast(filter); sum[1][i] = init_sum; UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i)\ -const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); - UNROLL_CALL_NOWRAPPER(4,cb) + UNROLL_CALL_NOWRAPPER(4, cb) #undef cb -#define cb(i)\ - int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i);\ - int32x4x2_t tmp_row##i##_01 = vld2q_s32(tmp_sptr##i+8); +#define cb(i) \ + int32x4x2_t tmp_row##i##_00 = vld2q_s32(tmp_sptr##i); \ + int32x4x2_t tmp_row##i##_01 = vld2q_s32(tmp_sptr##i + 8); - UNROLL_CALL_NOWRAPPER(4,cb) + UNROLL_CALL_NOWRAPPER(4, cb) #undef cb -#define cb(i)\ - int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i##_00.val[0]);\ - int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i##_00.val[1]);\ - int8x16_t row##i##2 =vreinterpretq_s8_s32(tmp_row##i##_01.val[0]);\ - int8x16_t row##i##3 =vreinterpretq_s8_s32(tmp_row##i##_01.val[1]); +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \ + int8x16_t row##i##2 = vreinterpretq_s8_s32(tmp_row##i##_01.val[0]); \ + int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_01.val[1]); - UNROLL_CALL_NOWRAPPER(4,cb) + UNROLL_CALL_NOWRAPPER(4, cb) #undef cb - CALC_ONE_LINE_8_RESULT(sum[0],0,0,1); - CALC_ONE_LINE_8_RESULT(sum[0],1,2,3); - CALC_ONE_LINE_8_RESULT(sum[1],2,0,1); - CALC_ONE_LINE_8_RESULT(sum[1],3,2,3); + CALC_ONE_LINE_8_RESULT(sum[0], 0, 0, 1); + CALC_ONE_LINE_8_RESULT(sum[0], 1, 2, 3); + CALC_ONE_LINE_8_RESULT(sum[1], 2, 0, 1); + CALC_ONE_LINE_8_RESULT(sum[1], 3, 2, 3); STORE_1_LINE_RESULT(dst, oh, ow, OW, sum[0]); STORE_1_LINE_RESULT(dst, (oh + 1), ow, OW, sum[1]); } @@ -904,29 +896,27 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i)\ - 
int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i); +#define cb(i) int32x4x2_t tmp_row##i = vld2q_s32(tmp_sptr##i); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i)\ - int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ - int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);\ +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i.val[1]); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb - CALC_ONE_LINE_4_RESULT(sum[0],0,0,1); - CALC_ONE_LINE_4_RESULT(sum[0],1,2,3); + CALC_ONE_LINE_4_RESULT(sum[0], 0, 0, 1); + CALC_ONE_LINE_4_RESULT(sum[0], 1, 2, 3); - CALC_ONE_LINE_4_RESULT(sum[1],2,0,1); - CALC_ONE_LINE_4_RESULT(sum[1],3,2,3); + CALC_ONE_LINE_4_RESULT(sum[1], 2, 0, 1); + CALC_ONE_LINE_4_RESULT(sum[1], 3, 2, 3); STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum[0]); STORE_1_LINE_4_RESULT(dst, (oh + 1), ow, OW, sum[1]); } @@ -944,8 +934,7 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); sum[1][i] = init_sum; UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb @@ -955,20 +944,20 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i)\ - int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ - int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]);\ +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i.val[1]); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb - CALC_ONE_LINE_4_RESULT(sum[0],0,0,1); - CALC_ONE_LINE_4_RESULT(sum[0],1,2,3); + CALC_ONE_LINE_4_RESULT(sum[0], 0, 0, 1); + CALC_ONE_LINE_4_RESULT(sum[0], 1, 2, 3); - CALC_ONE_LINE_4_RESULT(sum[1],2,0,1); - CALC_ONE_LINE_4_RESULT(sum[1],3,2,3); + CALC_ONE_LINE_4_RESULT(sum[1], 2, 0, 1); + CALC_ONE_LINE_4_RESULT(sum[1], 3, 2, 3); - STORE_REMAIN(dst, (oh+0), ow, OW, sum[0], remain); - STORE_REMAIN(dst, (oh+1), ow, OW, sum[1], remain); + STORE_REMAIN(dst, (oh + 0), ow, OW, sum[0], remain); + STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); } } for (; oh < OH; oh++) { @@ -980,8 +969,7 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(2, cb) #undef cb @@ -1011,9 +999,8 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); size_t iw = ow * 2; const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; - int16x8_t sum[2]={init_sum,init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + int16x8_t sum[2] = {init_sum, init_sum}; +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb @@ -1023,14 +1010,14 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i)\ - int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ - int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]); +#define cb(i) \ + int8x16_t row##i##0 = 
vreinterpretq_s8_s32(tmp_row##i.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i.val[1]); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb - CALC_ONE_LINE_4_RESULT(sum,0,0,1); - CALC_ONE_LINE_4_RESULT(sum,1,2,3); + CALC_ONE_LINE_4_RESULT(sum, 0, 0, 1); + CALC_ONE_LINE_4_RESULT(sum, 1, 2, 3); STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum); } @@ -1039,9 +1026,8 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); size_t remain = OW - ow; const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; - int16x8_t sum[2]={init_sum,init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); + int16x8_t sum[2] = {init_sum, init_sum}; +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb @@ -1051,14 +1037,14 @@ const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i)\ - int8x16_t row##i##0 =vreinterpretq_s8_s32(tmp_row##i.val[0]);\ - int8x16_t row##i##1 =vreinterpretq_s8_s32(tmp_row##i.val[1]); +#define cb(i) \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i.val[1]); UNROLL_CALL_NOWRAPPER(2, cb); #undef cb - CALC_ONE_LINE_4_RESULT(sum,0,0,1); - CALC_ONE_LINE_4_RESULT(sum,1,2,3); + CALC_ONE_LINE_4_RESULT(sum, 0, 0, 1); + CALC_ONE_LINE_4_RESULT(sum, 1, 2, 3); STORE_REMAIN(dst, oh, ow, OW, sum, remain); } } @@ -1122,8 +1108,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb @@ -1142,8 +1127,8 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( vextq_s32(tmp_row##i##_00.val[0], tmp_row##i##_03.val[0], 1)); \ int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_03.val[0]); \ int8x16_t row##i##4 = vreinterpretq_s8_s32(tmp_row##i##_03.val[1]); \ - int8x16_t row##i##5 = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1)); + int8x16_t row##i##5 = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1)); UNROLL_CALL_NOWRAPPER(5, cb) #undef cb @@ -1174,8 +1159,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb @@ -1187,10 +1171,10 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(5, cb); #undef cb -#define cb(i) \ +#define cb(i) \ int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ - int8x16_t row##i##2 = \ + int8x16_t row##i##2 = \ vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); UNROLL_CALL_NOWRAPPER(5, cb); @@ -1221,8 +1205,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb @@ -1234,10 +1217,10 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(5, cb); #undef cb 
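The stride-2 kernels above lean on the fact that one NCHW44 pixel is exactly 4 int8 values, i.e. one int32. A hedged standalone sketch of that de-interleave step follows; the helper name is an assumption, not part of the patch.

#include <arm_neon.h>
#include <cstdint>

// vld2q_s32 on the int32-reinterpreted row splits eight consecutive packed
// pixels into even columns (val[0]) and odd columns (val[1]); the even half
// then feeds the same vmlal_s8 accumulation as the stride-1 kernels, so a
// stride-2 convolution needs no extra shuffling per output pixel.
static inline void load_stride2_pixels(
        const int8_t* sptr, int8x16_t& even4, int8x16_t& odd4) {
    const int32_t* p = reinterpret_cast<const int32_t*>(sptr);
    int32x4x2_t rows = vld2q_s32(p);  // pixels 0,2,4,6 and 1,3,5,7
    even4 = vreinterpretq_s8_s32(rows.val[0]);
    odd4 = vreinterpretq_s8_s32(rows.val[1]);
}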
-#define cb(i) \ +#define cb(i) \ int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ - int8x16_t row##i##2 = \ + int8x16_t row##i##2 = \ vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); UNROLL_CALL_NOWRAPPER(5, cb); @@ -1264,8 +1247,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(3, cb); #undef cb @@ -1278,14 +1260,14 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( #undef cb #define cb(i) \ - int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \ - int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \ - int8x16_t row##i##2 = vreinterpretq_s8_s32( \ + int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_00.val[0]); \ + int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_00.val[1]); \ + int8x16_t row##i##2 = vreinterpretq_s8_s32( \ vextq_s32(tmp_row##i##_00.val[0], tmp_row##i##_03.val[0], 1)); \ int8x16_t row##i##3 = vreinterpretq_s8_s32(tmp_row##i##_03.val[0]); \ int8x16_t row##i##4 = vreinterpretq_s8_s32(tmp_row##i##_03.val[1]); \ - int8x16_t row##i##5 = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1)); + int8x16_t row##i##5 = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_03.val[0], tmp_row##i, 1)); UNROLL_CALL_NOWRAPPER(3, cb) #undef cb @@ -1302,8 +1284,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; int16x8_t sum[2] = {init_sum, init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(3, cb) #undef cb @@ -1315,10 +1296,10 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(3, cb) #undef cb -#define cb(i) \ +#define cb(i) \ int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ - int8x16_t row##i##2 = \ + int8x16_t row##i##2 = \ vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); UNROLL_CALL_NOWRAPPER(3, cb) @@ -1336,8 +1317,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4 + iw * 4; const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4 + iw * 4; int16x8_t sum[2] = {init_sum, init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(3, cb) #undef cb @@ -1349,10 +1329,10 @@ void channel_wise_nchw44_8x8x16::direct_stride2_3x3_int8x8x16( UNROLL_CALL_NOWRAPPER(3, cb) #undef cb -#define cb(i) \ +#define cb(i) \ int8x16_t row##i##0 = vreinterpretq_s8_s32(tmp_row##i##_0.val[0]); \ int8x16_t row##i##1 = vreinterpretq_s8_s32(tmp_row##i##_0.val[1]); \ - int8x16_t row##i##2 = \ + int8x16_t row##i##2 = \ vreinterpretq_s8_s32(vextq_s32(tmp_row##i##_0.val[0], tmp_row##i##_1, 1)); UNROLL_CALL_NOWRAPPER(3, cb) @@ -1386,76 +1366,73 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( 
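The per-tap kern[] registers used throughout these kernels are built by broadcasting one packed tap, as in the vld1_dup_s32 setup shown earlier for the 2x2 case. A minimal sketch under the same layout assumption (helper name illustrative):

#include <arm_neon.h>
#include <cstdint>

// The 4 int8 weights of one NCHW44 tap occupy a single int32; vld1_dup_s32
// copies that int32 into both lanes of an int32x2_t, so the resulting
// int8x8_t carries the tap twice and lines up with the two packed output
// pixels processed by each vmlal_s8 half.
static inline int8x8_t broadcast_tap(const int8_t* filter, int tap_idx) {
    const int32_t* f = reinterpret_cast<const int32_t*>(filter);
    return vreinterpret_s8_s32(vld1_dup_s32(f + tap_idx));
}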
UNROLL_CALL_NOWRAPPER(25, cb); #undef cb -#define LOAD_5_SRC(_src, _id) \ - do { \ - int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id); \ - int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 2); \ - int32x4_t tmp_row = vld1q_s32(tmp_sptr##_id + 10); \ - _src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]); \ - _src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]); \ - _src[2] = vreinterpretq_s8_s32(tmp_row_23.val[0]); \ - _src[3] = vreinterpretq_s8_s32(tmp_row_23.val[1]); \ - _src[4] = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row_23.val[0], tmp_row, 1)); \ +#define LOAD_5_SRC(_src, _id) \ + do { \ + int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id); \ + int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 2); \ + int32x4_t tmp_row = vld1q_s32(tmp_sptr##_id + 10); \ + _src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]); \ + _src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]); \ + _src[2] = vreinterpretq_s8_s32(tmp_row_23.val[0]); \ + _src[3] = vreinterpretq_s8_s32(tmp_row_23.val[1]); \ + _src[4] = vreinterpretq_s8_s32(vextq_s32(tmp_row_23.val[0], tmp_row, 1)); \ } while (0); -#define CALC_ONE_LINE_4_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, \ - _kid4) \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]); \ +#define CALC_ONE_LINE_4_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, _kid4) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]); \ _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]); -#define LOAD_10_SRC(_src, _id) \ - do { \ - int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id); \ - int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 8); \ - int32x4x2_t tmp_row = vld2q_s32(tmp_sptr##_id + 16); \ - _src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]); \ - _src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]); \ - _src[2] = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 1)); \ - _src[3] = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row_01.val[1], tmp_row_23.val[1], 1)); \ - _src[4] = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 2)); \ - _src[5] = vreinterpretq_s8_s32(tmp_row_23.val[0]); \ - _src[6] = vreinterpretq_s8_s32(tmp_row_23.val[1]); \ - _src[7] = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 1)); \ - _src[8] = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row_23.val[1], tmp_row.val[1], 1)); \ - _src[9] = vreinterpretq_s8_s32( \ - vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 2)); \ +#define LOAD_10_SRC(_src, _id) 
\ + do { \ + int32x4x2_t tmp_row_01 = vld2q_s32(tmp_sptr##_id); \ + int32x4x2_t tmp_row_23 = vld2q_s32(tmp_sptr##_id + 8); \ + int32x4x2_t tmp_row = vld2q_s32(tmp_sptr##_id + 16); \ + _src[0] = vreinterpretq_s8_s32(tmp_row_01.val[0]); \ + _src[1] = vreinterpretq_s8_s32(tmp_row_01.val[1]); \ + _src[2] = vreinterpretq_s8_s32( \ + vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 1)); \ + _src[3] = vreinterpretq_s8_s32( \ + vextq_s32(tmp_row_01.val[1], tmp_row_23.val[1], 1)); \ + _src[4] = vreinterpretq_s8_s32( \ + vextq_s32(tmp_row_01.val[0], tmp_row_23.val[0], 2)); \ + _src[5] = vreinterpretq_s8_s32(tmp_row_23.val[0]); \ + _src[6] = vreinterpretq_s8_s32(tmp_row_23.val[1]); \ + _src[7] = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 1)); \ + _src[8] = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row_23.val[1], tmp_row.val[1], 1)); \ + _src[9] = \ + vreinterpretq_s8_s32(vextq_s32(tmp_row_23.val[0], tmp_row.val[0], 2)); \ } while (0); -#define CALC_ONE_LINE_8_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, \ - _kid4) \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \ - _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[5]), kern[_kid0]); \ - _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[5]), kern[_kid0]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \ - _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[6]), kern[_kid1]); \ - _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[6]), kern[_kid1]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \ - _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[7]), kern[_kid2]); \ - _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[7]), kern[_kid2]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \ - _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[8]), kern[_kid3]); \ - _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[8]), kern[_kid3]); \ - _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[4]), kern[_kid4]); \ - _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]); \ - _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[9]), kern[_kid4]); \ +#define CALC_ONE_LINE_8_RESULT(_sum, _src, _kid0, _kid1, _kid2, _kid3, _kid4) \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[0]), kern[_kid0]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[0]), kern[_kid0]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[5]), kern[_kid0]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[5]), kern[_kid0]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[1]), kern[_kid1]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[1]), kern[_kid1]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[6]), kern[_kid1]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[6]), kern[_kid1]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[2]), kern[_kid2]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[2]), kern[_kid2]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[7]), kern[_kid2]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[7]), kern[_kid2]); \ + _sum[0] = vmlal_s8(_sum[0], vget_low_s8(_src[3]), kern[_kid3]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[3]), kern[_kid3]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[8]), kern[_kid3]); \ + _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[8]), kern[_kid3]); \ + _sum[0] = vmlal_s8(_sum[0], 
vget_low_s8(_src[4]), kern[_kid4]); \ + _sum[1] = vmlal_s8(_sum[1], vget_high_s8(_src[4]), kern[_kid4]); \ + _sum[2] = vmlal_s8(_sum[2], vget_low_s8(_src[9]), kern[_kid4]); \ _sum[3] = vmlal_s8(_sum[3], vget_high_s8(_src[9]), kern[_kid4]); size_t oh = 0_z; @@ -1479,8 +1456,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( UNROLL_CALL_NOWRAPPER(4, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(7, cb); #undef cb @@ -1523,8 +1499,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(7, cb); #undef cb @@ -1567,8 +1542,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( UNROLL_CALL_NOWRAPPER(2, cb); #undef cb -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(7, cb); #undef cb LOAD_5_SRC(src[0], 0); // line0 @@ -1605,8 +1579,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; int16x8_t sum[4] = {init_sum, init_sum, init_sum, init_sum}; int8x16_t src[3][10]; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb LOAD_10_SRC(src[0], 0); // line0 @@ -1631,8 +1604,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; int16x8_t sum[2] = {init_sum, init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb @@ -1660,8 +1632,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; int16x8_t sum[2] = {init_sum, init_sum}; -#define cb(i) \ - const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); +#define cb(i) const int32_t* tmp_sptr##i = reinterpret_cast(sptr##i); UNROLL_CALL_NOWRAPPER(5, cb); #undef cb @@ -1695,41 +1666,41 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( const int32_t* tmp_filter = reinterpret_cast(filter); INIT_SUM(); int8x8_t kern0[3], kern1[3], kern2[3], kern3[3], kern4[3]; - + int32x2_t tmp_kern = vdup_n_s32(tmp_filter[4]); - tmp_kern = vset_lane_s32(0,tmp_kern,1); + tmp_kern = vset_lane_s32(0, tmp_kern, 1); kern0[0] = vld1_s8(filter); kern0[1] = vld1_s8(filter + 8); kern0[2] = vreinterpret_s8_s32(tmp_kern); - + tmp_kern = vdup_n_s32(tmp_filter[9]); - tmp_kern = vset_lane_s32(0,tmp_kern,1); + tmp_kern = vset_lane_s32(0, tmp_kern, 1); kern1[0] = vld1_s8(filter + 20); kern1[1] = vld1_s8(filter + 28); kern1[2] = vreinterpret_s8_s32(tmp_kern); - + tmp_kern = vdup_n_s32(tmp_filter[14]); - tmp_kern = vset_lane_s32(0,tmp_kern,1); + tmp_kern = vset_lane_s32(0, tmp_kern, 1); kern2[0] = vld1_s8(filter + 40); kern2[1] = vld1_s8(filter + 48); kern2[2] = vreinterpret_s8_s32(tmp_kern); - + tmp_kern = vdup_n_s32(tmp_filter[19]); - tmp_kern = vset_lane_s32(0,tmp_kern,1); + tmp_kern = vset_lane_s32(0, tmp_kern, 1); kern3[0] = vld1_s8(filter + 60); kern3[1] = vld1_s8(filter + 68); 
kern3[2] = vreinterpret_s8_s32(tmp_kern); - + tmp_kern = vdup_n_s32(tmp_filter[24]); - tmp_kern = vset_lane_s32(0,tmp_kern,1); + tmp_kern = vset_lane_s32(0, tmp_kern, 1); kern4[0] = vld1_s8(filter + 80); kern4[1] = vld1_s8(filter + 88); kern4[2] = vreinterpret_s8_s32(tmp_kern); -#define LOAD_3_SRC_ARRAY(_src,_sptr)\ - _src[0] = vld1q_s8(_sptr);/*0 1 2 3 */\ - _src[1] = vld1q_s8(_sptr + 16);/*4 5 6 7 */\ - _src[2] = vld1q_s8(_sptr + 32);/*8 9 10 11*/ +#define LOAD_3_SRC_ARRAY(_src, _sptr) \ + _src[0] = vld1q_s8(_sptr); /*0 1 2 3 */ \ + _src[1] = vld1q_s8(_sptr + 16); /*4 5 6 7 */ \ + _src[2] = vld1q_s8(_sptr + 32); /*8 9 10 11*/ #define CALC_ONE_LINE(_src, _kern, _sum) \ tmpsum0 = vmull_s8(vget_low_s8(_src[0]), _kern[0]); /*01*/ \ @@ -1744,7 +1715,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( \ tmpsum0 = vmull_s8(vget_low_s8(_src[1]), _kern[0]); /*45*/ \ tmpsum1 = vmull_s8(vget_high_s8(_src[1]), _kern[0]); /*67*/ \ - tmpsum0 = vmlal_s8(tmpsum0, vget_high_s8(_src[1]), _kern[1]); /*67*/ \ + tmpsum0 = vmlal_s8(tmpsum0, vget_high_s8(_src[1]), _kern[1]); /*67*/ \ tmpsum1 = vmlal_s8(tmpsum1, vget_low_s8(_src[2]), _kern[1]); /*89*/ \ tmpsum0 = vmlal_s8(tmpsum0, vget_low_s8(_src[2]), _kern[2]); /*8*/ \ tmpsum1 = vmlal_s8(tmpsum1, vget_high_s8(_src[2]), _kern[2]); /*10*/ \ @@ -1795,7 +1766,7 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( for (; oh + 2 <= OH; oh += 2) { size_t ih = oh * 2; size_t ow = 0_z; - + for (; ow + 4 <= OW; ow += 4) { size_t iw = ow * 2; const int8_t* __restrict sptr0 = src + ih * IW * 4 + iw * 4; @@ -1841,8 +1812,8 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( int16x4_t res0, res1; CALC_8_RESULT(); - STORE_REMAIN(dst, oh, ow, OW, sum[0],remain); - STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1],remain); + STORE_REMAIN(dst, oh, ow, OW, sum[0], remain); + STORE_REMAIN(dst, (oh + 1), ow, OW, sum[1], remain); } } for (; oh < OH; oh++) { @@ -1856,13 +1827,13 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4 + iw * 4; const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4 + iw * 4; - int16x8_t sum[2]={init_sum,init_sum}; + int16x8_t sum[2] = {init_sum, init_sum}; int8x16_t src0[3], src1[3]; int16x8_t tmpsum0, tmpsum1; int16x4_t res0, res1; CALC_4_RESULT(); - STORE_1_LINE_4_RESULT(dst, oh,ow, OW, sum); + STORE_1_LINE_4_RESULT(dst, oh, ow, OW, sum); } if (OW > ow) { size_t iw = ow * 2; @@ -1892,11 +1863,11 @@ void channel_wise_nchw44_8x8x16::direct_stride2_5x5_int8x8x16( #undef STORE_1_LINE_4_RESULT #undef STORE_REMAIN -#define INSTANTIATION(stride, i, bias) \ - template void channel_wise_nchw44_8x8x16:: \ - direct_##stride##_##i##x##i##_int8x8x16( \ - const int8_t*, const int8_t*, const int16_t*, void*, \ - const size_t, const size_t, const size_t, const size_t); +#define INSTANTIATION(stride, i, bias) \ + template void \ + channel_wise_nchw44_8x8x16::direct_##stride##_##i##x##i##_int8x8x16( \ + const int8_t*, const int8_t*, const int16_t*, void*, const size_t, \ + const size_t, const size_t, const size_t); #define FOR_OP(stride, i, bias) INSTANTIATION(stride, i, bias) diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h index 599b5642..c5ebd824 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44.h @@ -21,9 +21,9 @@ using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; 
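The kern0[2] through kern4[2] values prepared a few lines above handle the odd fifth tap of each 5x5 filter row. A hedged sketch of that preparation in isolation (names are illustrative, not the library's symbols):

#include <arm_neon.h>
#include <cstdint>

// A 5-tap row is consumed as two int8x8_t holding taps {0,1} and {2,3},
// plus one int8x8_t whose low 32 bits carry tap 4 and whose high 32 bits
// are zeroed, so the shared two-tap vmlal_s8 pattern can be reused without
// accumulating a spurious fifth product.
static inline int8x8_t last_tap_of_row(const int32_t* packed_filter, int tap_idx) {
    int32x2_t t = vdup_n_s32(packed_filter[tap_idx]);  // tap in both lanes
    t = vset_lane_s32(0, t, 1);                        // clear the upper lane
    return vreinterpret_s8_s32(t);
}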
using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; -using conv_fun = std::function; +using conv_fun = std::function; namespace stride1 { @@ -32,8 +32,9 @@ bool is_available(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); SmallVector get_kimpls(const NCBKernSizeParam& param); } // namespace stride1 @@ -44,13 +45,14 @@ bool is_available(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); SmallVector get_kimpls(const NCBKernSizeParam& param); } // namespace stride2 -} // namespace direct_int8_stride1 +} // namespace channel_wise_nchw44 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp index b56f033c..41d98733 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.cpp @@ -22,8 +22,8 @@ using namespace channel_wise_nchw44_8x8x16; namespace { void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2) { auto&& fm = param.filter_meta; auto SW = fm.stride[1]; auto OH = param.osz[0]; @@ -40,9 +40,7 @@ void get_rectified_size( MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride1) MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride2) - -WorkspaceBundle stride1::get_bundle( - const ConvBiasImpl::NCBKernSizeParam& param) { +WorkspaceBundle stride1::get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { size_t nr_threads = param.nr_threads; size_t IH2, IW2; get_rectified_size(param, IH2, IW2); @@ -55,9 +53,9 @@ WorkspaceBundle stride1::get_bundle( //! compute one output channel template -void stride1::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { +void stride1::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { size_t PH = kern_param.filter_meta.padding[0]; size_t PW = kern_param.filter_meta.padding[1]; size_t OH = kern_param.osz[0]; @@ -82,9 +80,9 @@ void stride1::do_conv_kern(const WorkspaceBundle& bundle, //! 
copy in case of illegal read src when padding is zero std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size); rep(ih, IH) { - std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, - sptr + ih * IW * pack_ic_size, - sizeof(int8_t) * IW * pack_ic_size); + std::memcpy( + padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, + sptr + ih * IW * pack_ic_size, sizeof(int8_t) * IW * pack_ic_size); } sptr = padding_src; @@ -95,21 +93,21 @@ void stride1::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN } -SmallVector stride1::get_kimpls( - const NCBKernSizeParam& param) { +SmallVector stride1::get_kimpls(const NCBKernSizeParam& param) { auto fm = param.filter_meta; size_t N = param.n; size_t group = fm.group / 4; - megdnn_assert(fm.group % 4 == 0, - "nchw44 channel wise conv with group is not times of 4"); + megdnn_assert( + fm.group % 4 == 0, "nchw44 channel wise conv with group is not times of 4"); WorkspaceBundle wbundle = get_bundle(param); conv_fun do_conv_fun = nullptr; -#define DO_CONV_KERN_FUN(filter, bias_mode) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride1, \ - midout_iv(#filter #bias_mode##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(filter, bias_mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride1, \ + midout_iv(#filter #bias_mode##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); #define GET_OP_PARAM(i, bias_mode) \ @@ -122,19 +120,20 @@ SmallVector stride1::get_kimpls( break; \ } -#define GET_BIAS_MODE_PARAM(i) \ - switch (param.bias_mode) { \ - case BiasMode::NO_BIAS: \ - GET_OP_PARAM(i, BiasMode::NO_BIAS) \ - break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ - break; \ - default: \ - megdnn_assert(0, \ - "only support BiasMode::NO_BIAS and " \ - "BiasMode::BROADCAST_CHANNEL_BIAS"); \ - break; \ +#define GET_BIAS_MODE_PARAM(i) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(i, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(i, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_assert( \ + 0, \ + "only support BiasMode::NO_BIAS and " \ + "BiasMode::BROADCAST_CHANNEL_BIAS"); \ + break; \ } #define DISPATCH_CONV_KERN() \ @@ -168,8 +167,7 @@ SmallVector stride1::get_kimpls( #undef DO_CONV_KERN_FUN } -WorkspaceBundle stride2::get_bundle( - const ConvBiasImpl::NCBKernSizeParam& param) { +WorkspaceBundle stride2::get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { size_t nr_threads = param.nr_threads; size_t IH2, IW2; get_rectified_size(param, IH2, IW2); @@ -182,9 +180,9 @@ WorkspaceBundle stride2::get_bundle( //! compute one output channel template -void stride2::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) { +void stride2::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) { size_t PH = kern_param.filter_meta.padding[0]; size_t PW = kern_param.filter_meta.padding[1]; size_t OH = kern_param.osz[0]; @@ -209,9 +207,9 @@ void stride2::do_conv_kern(const WorkspaceBundle& bundle, //! 
copy in case of illegal read src when padding is zero std::memset(padding_src, 0, sizeof(int8_t) * IH2 * IW2 * pack_ic_size); rep(ih, IH) { - std::memcpy(padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, - sptr + ih * IW * pack_ic_size, - sizeof(int8_t) * IW * pack_ic_size); + std::memcpy( + padding_src + ((ih + PH) * IW2 + PW) * pack_ic_size, + sptr + ih * IW * pack_ic_size, sizeof(int8_t) * IW * pack_ic_size); } sptr = padding_src; @@ -222,21 +220,21 @@ void stride2::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN } -SmallVector stride2::get_kimpls( - const NCBKernSizeParam& param) { +SmallVector stride2::get_kimpls(const NCBKernSizeParam& param) { auto fm = param.filter_meta; size_t N = param.n; size_t group = fm.group / 4; - megdnn_assert(fm.group % 4 == 0, - "nchw44 channel wise conv with group is not times of 4"); + megdnn_assert( + fm.group % 4 == 0, "nchw44 channel wise conv with group is not times of 4"); WorkspaceBundle wbundle = get_bundle(param); conv_fun do_conv_fun = nullptr; -#define DO_CONV_KERN_FUN(filter, bias_mode) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride2, \ - midout_iv(#filter #bias_mode##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(filter, bias_mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8x8x16_nchw44_stride2, \ + midout_iv(#filter #bias_mode##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); DISPATCH_CONV_KERN(); diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h index 6a2668ff..5a4604f9 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/channel_wise_nchw44_8x8x16.h @@ -21,9 +21,9 @@ using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex; -using conv_fun = std::function; +using conv_fun = std::function; namespace stride1 { @@ -32,8 +32,9 @@ bool is_available(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); SmallVector get_kimpls(const NCBKernSizeParam& param); } // namespace stride1 @@ -44,13 +45,14 @@ bool is_available(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index); SmallVector get_kimpls(const NCBKernSizeParam& param); } // namespace stride2 -} // namespace direct_int8_stride1 +} // namespace channel_wise_nchw44_8x8x16 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp index c5b85bed..a7a76644 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.cpp @@ -20,10 +20,9 @@ using namespace arm_common; using namespace conv_bias; template -void conv_bias::conv_direct_2x2_sc_int8_int8_int16(const int8_t* src, const int8_t* 
filter, - int16_t* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t PH, - size_t PW) { +void conv_bias::conv_direct_2x2_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW) { size_t OH_start = PH, OH_stop = OH - PH; size_t OW_start = PW, OW_stop = OW - PW; auto run_single = [&](size_t oh, size_t ow) { @@ -222,10 +221,9 @@ void conv_bias::conv_direct_2x2_sc_int8_int8_int16(const int8_t* src, const int8 } template -void conv_bias::conv_direct_3x3_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, - int16_t* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t PH, - size_t PW) { +void conv_bias::conv_direct_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW) { size_t OH_start = PH, OH_stop = OH - PH; size_t OW_start = PW, OW_stop = OW - PW; @@ -450,10 +448,9 @@ void conv_bias::conv_direct_3x3_sc_int8_int8_int16(const int8_t* src, const int8 } template -void conv_bias::conv_direct_5x5_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, - int16_t* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t PH, - size_t PW) { +void conv_bias::conv_direct_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW) { size_t OH_start = PH, OH_stop = OH - PH; size_t OW_start = PW, OW_stop = OW - PW; auto run_single = [&](size_t oh, size_t ow) { @@ -574,22 +571,22 @@ void conv_bias::conv_direct_5x5_sc_int8_int8_int16(const int8_t* src, const int8 } template void conv_bias::conv_direct_2x2_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_direct_2x2_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_direct_3x3_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_direct_3x3_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_direct_5x5_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_direct_5x5_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); // vim: 
syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h index e319ed62..eb146235 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_direct.h @@ -20,20 +20,17 @@ namespace arm_common { namespace conv_bias { template -void conv_direct_2x2_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, - int16_t* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t PH, - size_t PW); +void conv_direct_2x2_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template -void conv_direct_3x3_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, - int16_t* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t PH, - size_t PW); +void conv_direct_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template -void conv_direct_5x5_sc_int8_int8_int16(const int8_t* src, const int8_t* filter, - int16_t* dst, size_t IH, size_t IW, - size_t OH, size_t OW, size_t PH, - size_t PW); +void conv_direct_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); } // namespace conv_bias } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp index 7bca7c57..f5b6d836 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.cpp @@ -22,8 +22,8 @@ using namespace conv_bias; template void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW) { + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW) { size_t OH_start = div_ceil(PH, 2), OH_stop = div_floor(IH + PH - 2, 2) + 1, OW_start = div_ceil(PW, 2), @@ -88,10 +88,8 @@ void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( dptr += 4; } for (; ow < OW_stop; ++ow) { - int16_t s0 = sptr[0], s1 = sptr[1], s2 = sptr[IW + 0], - s3 = sptr[IW + 1]; - int16_t f0 = filter[0], f1 = filter[1], f2 = filter[2], - f3 = filter[3]; + int16_t s0 = sptr[0], s1 = sptr[1], s2 = sptr[IW + 0], s3 = sptr[IW + 1]; + int16_t f0 = filter[0], f1 = filter[1], f2 = filter[2], f3 = filter[3]; int16_t d = s0 * f0 + s1 * f1 + s2 * f2 + s3 * f3; if (add_to_dst) { *dptr += d; @@ -106,8 +104,8 @@ void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( template void conv_bias::conv_stride2_3x3_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW) { + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW) { size_t OH_start = div_ceil(PH, 2), OH_stop = div_floor(IH + PH - 3, 2) + 1, OW_start = div_ceil(PW, 2), @@ -250,8 +248,8 @@ void conv_bias::conv_stride2_3x3_sc_int8_int8_int16( template void conv_bias::conv_stride2_5x5_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW) { + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + 
size_t OH, size_t OW, size_t PH, size_t PW) { size_t OH_start = div_ceil(PH, 2), OH_stop = div_floor(IH + PH - 5, 2) + 1, OW_start = div_ceil(PW, 2), @@ -413,29 +411,28 @@ void conv_bias::conv_stride2_5x5_sc_int8_int8_int16( } } template void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_stride2_2x2_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_stride2_3x3_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_stride2_3x3_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_stride2_5x5_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template void conv_bias::conv_stride2_5x5_sc_int8_int8_int16( - const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t PH, size_t PW); + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); namespace { -void conv_2x2_optimize_single_channel(const int8_t* src, const uint32_t IH, - const uint32_t IW, const int8_t* filter, - int16_t* dst, const uint32_t OH, - const uint32_t OW) { +void conv_2x2_optimize_single_channel( + const int8_t* src, const uint32_t IH, const uint32_t IW, const int8_t* filter, + int16_t* dst, const uint32_t OH, const uint32_t OW) { int8_t workspace[16]; workspace[0] = filter[0]; workspace[1] = filter[1]; @@ -470,14 +467,11 @@ void conv_2x2_optimize_single_channel(const int8_t* src, const uint32_t IH, dst += 4; } for (; j < IW; j += 2) { - (*dst++) += static_cast(src[0]) * - static_cast(filter[0]) + - static_cast(src[1]) * - static_cast(filter[1]) + - static_cast(src[IW]) * - static_cast(filter[2]) + - static_cast(src[IW + 1]) * - static_cast(filter[3]); + (*dst++) += + static_cast(src[0]) * static_cast(filter[0]) + + static_cast(src[1]) * static_cast(filter[1]) + + static_cast(src[IW]) * static_cast(filter[2]) + + static_cast(src[IW + 1]) * static_cast(filter[3]); src += 2; } src += IW; @@ -501,13 +495,12 @@ bool conv_bias::can_conv_int8x8x16_stride2_flt2( param.dst_type.enumv() == DTypeEnum::Int16 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5) && - param.isz[0] % 2 == 0 && param.isz[1] % 2 == 0 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.spatial[0] == 2 && - fm.spatial[1] == 2 
&& fm.padding[0] == 0 && fm.padding[1] == 0; + param.isz[0] % 2 == 0 && param.isz[1] % 2 == 0 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.spatial[0] == 2 && fm.spatial[1] == 2 && + fm.padding[0] == 0 && fm.padding[1] == 0; } -void conv_bias::conv_int8x8x16_stride2_flt2( - const ConvBiasImpl::NCBKernParam& param) { +void conv_bias::conv_int8x8x16_stride2_flt2(const ConvBiasImpl::NCBKernParam& param) { UNPACK_CONV_F32_NCB_KERN_SIZES(param); megdnn_ignore(FH); megdnn_ignore(FW); @@ -525,8 +518,8 @@ void conv_bias::conv_int8x8x16_stride2_flt2( memset(dst, 0, sizeof(dst[0]) * OC * OH * OW); for (uint32_t j = 0; j < OC; ++j) { for (uint32_t k = 0; k < IC; ++k) { - conv_2x2_optimize_single_channel(src + k * shape, IH, IW, fptr, - dst, OH, OW); + conv_2x2_optimize_single_channel( + src + k * shape, IH, IW, fptr, dst, OH, OW); fptr += 4; } dst += OH * OW; diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h index 19aa599e..603146b0 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/conv_stride2.h @@ -11,31 +11,27 @@ #pragma once #include "src/arm_common/conv_bias/opr_impl.h" -#include #include +#include namespace megdnn { namespace arm_common { namespace conv_bias { template -void conv_stride2_2x2_sc_int8_int8_int16(const int8_t* src, - const int8_t* filter, int16_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW); +void conv_stride2_2x2_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template -void conv_stride2_3x3_sc_int8_int8_int16(const int8_t* src, - const int8_t* filter, int16_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW); +void conv_stride2_3x3_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); template -void conv_stride2_5x5_sc_int8_int8_int16(const int8_t* src, - const int8_t* filter, int16_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW); +void conv_stride2_5x5_sc_int8_int8_int16( + const int8_t* src, const int8_t* filter, int16_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t PH, size_t PW); -bool can_conv_int8x8x16_stride2_flt2( - const ConvBiasImpl::NCBKernSizeParam& param); +bool can_conv_int8x8x16_stride2_flt2(const ConvBiasImpl::NCBKernSizeParam& param); void conv_int8x8x16_stride2_flt2(const ConvBiasImpl::NCBKernParam& param); diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_algo.cpp index e09d61d9..2434798b 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_algo.cpp @@ -18,10 +18,9 @@ using namespace megdnn; using namespace arm_common; using conv_fun = std::function; + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range)>; MIDOUT_DECL(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct) static void get_rectified_size( @@ -47,11 +46,9 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { if (group == 1) { size_t src_size = 0; - bool need_padding = param.filter_meta.padding[0] > 0 || - 
param.filter_meta.padding[1] > 0; - src_size = need_padding - ? batch * group * IC * IH2 * IW2 * sizeof(int8_t) - : 0; + bool need_padding = + param.filter_meta.padding[0] > 0 || param.filter_meta.padding[1] > 0; + src_size = need_padding ? batch * group * IC * IH2 * IW2 * sizeof(int8_t) : 0; #if MEGDNN_ARMV7 if (fm.stride[0] == 1) { constexpr int src_expand_element = 4; @@ -62,11 +59,10 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { return {nullptr, {src_size}}; } else { size_t src_size = 0; - bool need_padding = param.filter_meta.padding[0] > 0 || - param.filter_meta.padding[1] > 0; - src_size = need_padding - ? param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) - : 0; + bool need_padding = + param.filter_meta.padding[0] > 0 || param.filter_meta.padding[1] > 0; + src_size = + need_padding ? param.nr_threads * IC * IH2 * IW2 * sizeof(int8_t) : 0; #if MEGDNN_ARMV7 if (fm.stride[0] == 1) { constexpr int src_expand_element = 4; @@ -79,10 +75,9 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { }; #if MEGDNN_ARMV7 -static void copy_padding_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +static void copy_padding_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { int IH = kern_param.isz[0]; int IW = kern_param.isz[1]; int IC = kern_param.filter_meta.icpg; @@ -94,7 +89,8 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, int padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset constexpr int pack_ic = 4; - constexpr int src_expand_element = 4;; + constexpr int src_expand_element = 4; + ; size_t workspace_ic_block = 4; size_t workspace_batch_id = workspace_ids[0]; size_t workspace_group_id = workspace_ids[1]; @@ -112,11 +108,11 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); //! copy to sptr_base to eliminate padding effect - int8_t* sptr_base = static_cast(bundle.get(0)) + - (workspace_batch_id * GROUP * padding_group_size + - workspace_group_id * padding_group_size + - workspace_ic * IH2 * IW2) * - src_expand_element; + int8_t* sptr_base = + static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + workspace_ic * IH2 * IW2) * + src_expand_element; size_t nr_ic = workspace_ic_block; if (GROUP > 1) { nr_ic = IC; @@ -139,10 +135,9 @@ static void copy_padding_kern(const WorkspaceBundle& bundle, } #endif -static void copy_padding_kern_no_pack_src(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +static void copy_padding_kern_no_pack_src( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { int IH = kern_param.isz[0]; int IW = kern_param.isz[1]; int IC = kern_param.filter_meta.icpg; @@ -172,11 +167,11 @@ static void copy_padding_kern_no_pack_src(const WorkspaceBundle& bundle, batch_id, group_id, workspace_ic_id, group_pack_size, pack_ic)); //! 
copy to sptr_base to eliminate padding effect - int8_t* sptr_base = static_cast(bundle.get(0)) + - (workspace_batch_id * GROUP * padding_group_size + - workspace_group_id * padding_group_size + - workspace_ic * IH2 * IW2) * - src_expand_element; + int8_t* sptr_base = + static_cast(bundle.get(0)) + + (workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + workspace_ic * IH2 * IW2) * + src_expand_element; size_t nr_ic = workspace_ic_block; if (GROUP > 1) { nr_ic = IC; @@ -199,11 +194,10 @@ static void copy_padding_kern_no_pack_src(const WorkspaceBundle& bundle, } template -static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids, - const CpuNDRange& ncb_range) { +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const CpuNDRange& ncb_range) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t FH = kern_param.filter_meta.spatial[0]; @@ -230,36 +224,33 @@ static void do_conv_kern(const WorkspaceBundle& bundle, if (oc_id == (oc_block_num - 1)) { oc_block = OC - oc_id * nr_pack_per_step * pack_c; } - megdnn_assert(oc_block % pack_c == 0, - "oc must be devisible by 4, but oc = %zu", oc_block); + megdnn_assert( + oc_block % pack_c == 0, "oc must be devisible by 4, but oc = %zu", + oc_block); bool need_padding = kern_param.filter_meta.padding[0] > 0 || kern_param.filter_meta.padding[1] > 0; - const int8_t* sptr = need_padding - ? static_cast(bundle.get(0)) + - workspace_batch_id * GROUP * padding_group_size + - workspace_group_id * padding_group_size - : kern_param.src(batch_id, group_id); + const int8_t* sptr = + need_padding ? 
static_cast(bundle.get(0)) + + workspace_batch_id * GROUP * padding_group_size + + workspace_group_id * padding_group_size + : kern_param.src(batch_id, group_id); //!armv7 use packsrc mode #if MEGDNN_ARMV7 if (stride == 1) { constexpr size_t src_expand_size = 4; sptr = static_cast(bundle.get(0)) + - workspace_batch_id * GROUP * padding_group_size * - src_expand_size + + workspace_batch_id * GROUP * padding_group_size * src_expand_size + workspace_group_id * padding_group_size * src_expand_size; } #endif - const int8_t* fptr = - kern_param.filter(group_id) + oc_idx * FH * FW * IC; + const int8_t* fptr = kern_param.filter(group_id) + oc_idx * FH * FW * IC; int16_t* dst = reinterpret_cast( kern_param.dst(batch_id, group_id, oc_idx)); - const int16_t* bptr = - kern_param.bias(batch_id, group_id) + oc_idx; - int8x8x16_direct_nchw44::ConvDirectInt8Nchw44Choose< - bias_mode, filter, stride>::impl(sptr, fptr, bptr, dst, oc_block, - IC, IH2, IW2, OH, OW); + const int16_t* bptr = kern_param.bias(batch_id, group_id) + oc_idx; + int8x8x16_direct_nchw44::ConvDirectInt8Nchw44Choose:: + impl(sptr, fptr, bptr, dst, oc_block, IC, IH2, IW2, OH, OW); } bool ConvBiasImpl::AlgoS8x8x16DirectNCHW44::usable( @@ -277,23 +268,21 @@ bool ConvBiasImpl::AlgoS8x8x16DirectNCHW44::usable( param.dst_type.enumv() == DTypeEnum::Int16) && (fm.format == param::Convolution::Format::NCHW44) && (oc % 4 == 0 && ic % 4 == 0 && oc >= 4) && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && - (fm.stride[0] == 2 || fm.stride[0] == 1) && fh == fw && - (fh == 2 || fh == 3 || fh == 5 || fh == 7) && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == fm.stride[1] && (fm.stride[0] == 2 || fm.stride[0] == 1) && + fh == fw && (fh == 2 || fh == 3 || fh == 5 || fh == 7) && param.nonlineMode == NonlineMode::IDENTITY && param.bias_mode != BiasMode::BIAS; return avaible; } size_t ConvBiasImpl::AlgoS8x8x16DirectNCHW44::get_workspace( - const NCBKernSizeParam& param) const { + const NCBKernSizeParam& param) const { return get_bundle(param).total_size_in_bytes(); } -SmallVector -ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::AlgoS8x8x16DirectNCHW44:: + dispatch_kerns(const NCBKernSizeParam& param) const { auto fm = param.filter_meta; size_t N = param.n; size_t IC = fm.icpg; @@ -306,67 +295,71 @@ ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( WorkspaceBundle wbundle = get_bundle(param); conv_fun do_conv_fun = nullptr; -#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct, \ - midout_iv("int8x8x16_nchw44_direct_" \ - "conv" #stride #filter #bias_mode##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(stride, dst_type, filter, bias_mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_int8x8x16_nchw44_direct, \ + midout_iv("int8x8x16_nchw44_direct_" \ + "conv" #stride #filter #bias_mode##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); -#define GET_OP_PARAM(stride, filter, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(stride, dt_int16, filter, bias_mode) \ - break; \ - default: \ - megdnn_throw(ssprintf("only support IDENTITY mode when dst is " \ - "dt_int16 nonlineMode is %d", \ - uint32_t(param.nonlineMode)) \ - .c_str()); \ - break; \ +#define GET_OP_PARAM(stride, 
filter, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN(stride, dt_int16, filter, bias_mode) \ + break; \ + default: \ + megdnn_throw(ssprintf( \ + "only support IDENTITY mode when dst is " \ + "dt_int16 nonlineMode is %d", \ + uint32_t(param.nonlineMode)) \ + .c_str()); \ + break; \ + } + +#define GET_BIAS_MODE_PARAM(stride, filter) \ + switch (param.bias_mode) { \ + case BiasMode::NO_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ + break; \ + case BiasMode::BROADCAST_CHANNEL_BIAS: \ + GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + break; \ + default: \ + megdnn_throw(ssprintf( \ + "only support NO_BIAS/BROADCAST biasmode " \ + "when dst is " \ + "dt_int16 biasmode is %d", \ + uint32_t(param.bias_mode)) \ + .c_str()); \ + break; \ } -#define GET_BIAS_MODE_PARAM(stride, filter) \ - switch (param.bias_mode) { \ - case BiasMode::NO_BIAS: \ - GET_OP_PARAM(stride, filter, BiasMode::NO_BIAS) \ +#define DISPATCH_CONV_KERN(stride) \ + switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(stride, 2) \ + break; \ + case 3: \ + GET_BIAS_MODE_PARAM(stride, 3) \ + break; \ + case 5: \ + GET_BIAS_MODE_PARAM(stride, 5) \ break; \ - case BiasMode::BROADCAST_CHANNEL_BIAS: \ - GET_OP_PARAM(stride, filter, BiasMode::BROADCAST_CHANNEL_BIAS) \ + case 7: \ + GET_BIAS_MODE_PARAM(stride, 7) \ break; \ default: \ - megdnn_throw(ssprintf("only support NO_BIAS/BROADCAST biasmode " \ - "when dst is " \ - "dt_int16 biasmode is %d", \ - uint32_t(param.bias_mode)) \ + megdnn_throw(ssprintf( \ + "only support 2x2 3x3 5x5 7x7 filters size " \ + "when dst is " \ + "dt_int16 filter size is %u", \ + uint32_t(param.filter_meta.spatial[0])) \ .c_str()); \ break; \ } -#define DISPATCH_CONV_KERN(stride) \ - switch (param.filter_meta.spatial[0]) { \ - case 2: \ - GET_BIAS_MODE_PARAM(stride, 2) \ - break; \ - case 3: \ - GET_BIAS_MODE_PARAM(stride, 3) \ - break; \ - case 5: \ - GET_BIAS_MODE_PARAM(stride, 5) \ - break; \ - case 7: \ - GET_BIAS_MODE_PARAM(stride, 7) \ - break; \ - default: \ - megdnn_throw(ssprintf("only support 2x2 3x3 5x5 7x7 filters size " \ - "when dst is " \ - "dt_int16 filter size is %u", \ - uint32_t(param.filter_meta.spatial[0])) \ - .c_str()); \ - break; \ - } - switch (param.filter_meta.stride[0]) { case 1: DISPATCH_CONV_KERN(1); @@ -375,8 +368,9 @@ ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( DISPATCH_CONV_KERN(2); break; default: - megdnn_throw(ssprintf("Unsupport stride size %u for the 8x8x16 direct conv", - param.filter_meta.stride[0]) + megdnn_throw(ssprintf( + "Unsupport stride size %u for the 8x8x16 direct conv", + param.filter_meta.stride[0]) .c_str()); break; } @@ -405,18 +399,17 @@ ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - copy_padding_kern(wbundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(wbundle, kern_param, ncb_index, ncb_index.ndrange_id); }; constexpr size_t pack_ic = 4; - ret_kerns.push_back( - {copy_padding, {N, group, div_ceil(IC, pack_ic)}}); + ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}}); auto do_conv = [wbundle, do_conv_fun, ncb_range]( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - do_conv_fun(wbundle, kern_param, ncb_index, - ncb_index.ndrange_id, ncb_range); + do_conv_fun( + wbundle, kern_param, ncb_index, 
ncb_index.ndrange_id, + ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); } else { @@ -425,38 +418,39 @@ ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - copy_padding_kern(wbundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}); - do_conv_fun(wbundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}, ncb_range); + copy_padding_kern( + wbundle, kern_param, ncb_index, {0, ncb_index.thread_id, 0}); + do_conv_fun( + wbundle, kern_param, ncb_index, {0, ncb_index.thread_id, 0}, + ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); } - return ret_kerns; + return ret_kerns; } #endif - bool need_padding = ph > 0 || pw >0; + bool need_padding = ph > 0 || pw > 0; if (group == 1) { CpuNDRange ncb_range = {N, group, div_ceil(OC, oc_step)}; - auto copy_padding = [wbundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [wbundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern_no_pack_src( + wbundle, kern_param, ncb_index, ncb_index.ndrange_id); }; constexpr size_t pack_ic = 4; if (need_padding) { - ret_kerns.push_back( - {copy_padding, {N, group, div_ceil(IC, pack_ic)}}); + ret_kerns.push_back({copy_padding, {N, group, div_ceil(IC, pack_ic)}}); } auto do_conv = [wbundle, do_conv_fun, ncb_range]( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); - do_conv_fun(wbundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); + do_conv_fun( + wbundle, kern_param, ncb_index, ncb_index.ndrange_id, ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); } else { @@ -466,11 +460,12 @@ ConvBiasImpl::AlgoS8x8x16DirectNCHW44::dispatch_kerns( const NCBKernIndex& ncb_index) mutable { wbundle.set(kern_param.workspace_ptr); if (need_padding) { - copy_padding_kern_no_pack_src(wbundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}); + copy_padding_kern_no_pack_src( + wbundle, kern_param, ncb_index, {0, ncb_index.thread_id, 0}); }; - do_conv_fun(wbundle, kern_param, ncb_index, - {0, ncb_index.thread_id, 0}, ncb_range); + do_conv_fun( + wbundle, kern_param, ncb_index, {0, ncb_index.thread_id, 0}, + ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); } diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h index e307b8d9..9efa6aca 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_8x8x16_nchw44_kern.h @@ -43,13 +43,13 @@ static inline void nchw44_pack_src(const int8_t* src, int8_t* dst, int length) { template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow); + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow); }; -} // namespace int8_direct_nchw44 +} // namespace int8x8x16_direct_nchw44 } // namespace arm_common } // namespace megdnn diff --git 
a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_aarch64.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_aarch64.cpp index f6e15969..37766bd7 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_aarch64.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_aarch64.cpp @@ -71,32 +71,31 @@ namespace { megdnn_assert(0, "oc 1 error remainw"); \ }; -#define STORE_2_LINE_RESULT_OW4() \ - switch (remain_w) { \ - case 4: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - break; \ - case 1: \ - vst1_s16(dst_ptr, c[0][0]); \ - vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ - break; \ - case 2: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - break; \ - case 3: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1_s16(dst_ptr + 8, c[0][2]); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ - break; \ - default: \ - megdnn_assert(0, "oc 2 error remainw"); \ - break; \ +#define STORE_2_LINE_RESULT_OW4() \ + switch (remain_w) { \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ } #define STORE_1_LINE_RESULT_OW4_OH2() \ @@ -146,12 +145,10 @@ namespace { megdnn_assert(0, "oc 1 error remainw"); \ }; -template -static void ker_neon_dirctconv_2x2s1_oc8_ow4(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int remain_w,int ld_dst_oc) { +template +static void ker_neon_dirctconv_2x2s1_oc8_ow4( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int remain_w, int ld_dst_oc) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -179,10 +176,8 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow4(const int8_t* src_ptr, #undef cb for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* src_row0 = - src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; - const int8_t* src_row1 = - src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row0 = src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; src[0] = vld_dup_tbl_s32(src_row0 + 0, idx); src[1] = vld_dup_tbl_s32(src_row0 + 4, idx); @@ -248,12 +243,9 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow4(const int8_t* src_ptr, } 
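For orientation, the NCHW44 int8x8x16 kernels in this file accumulate four packed output channels at a time, widening int8 products into int16 accumulators (the int16x4 `c` registers). A minimal scalar sketch of what one such 2x2, stride-1 accumulation covers is given below; the assumed packing (src as [ih][iw][4 ic], the weight block as [fh][fw][4 ic][4 oc]) and the function name are illustrative only, not taken from the kernel itself.

#include <cstdint>

// Scalar sketch (illustrative only): one 2x2, stride-1 output element for an
// NCHW44-style block. The real kernels keep acc[] in NEON int16x4 registers
// and compute several output columns per invocation.
inline void calc_one_result_2x2_scalar(
        const int8_t* src,     // top-left input pixel of the window, 4 ic packed
        const int8_t* weight,  // assumed 2x2x4x4 weight block for this (oc4, ic4) pair
        int iw,                // input width in pixels
        int16_t acc[4]) {      // one partial sum per packed output channel
    for (int fh = 0; fh < 2; ++fh) {
        for (int fw = 0; fw < 2; ++fw) {
            const int8_t* s = src + (fh * iw + fw) * 4;
            const int8_t* w = weight + (fh * 2 + fw) * 4 * 4;
            for (int ic = 0; ic < 4; ++ic) {
                for (int oc = 0; oc < 4; ++oc) {
                    acc[oc] = static_cast<int16_t>(
                            acc[oc] +
                            static_cast<int16_t>(s[ic]) *
                                    static_cast<int16_t>(w[ic * 4 + oc]));
                }
            }
        }
    }
}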
template -static void ker_neon_dirctconv_2x2s1_oc4_ow4(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int remain_w, - int /*ld_dst_oc*/) { +static void ker_neon_dirctconv_2x2s1_oc4_ow4( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int remain_w, int /*ld_dst_oc*/) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -284,23 +276,18 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow4(const int8_t* src_ptr, src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); int16x8_t tmp0; - CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], - c[0][0]); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); - CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], - c[0][1]); + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); - CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], - c[0][2]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); - CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], - c[0][3]); + CALC_ONE_RESULT(src[3], src[4], weight[0][0], weight[0][1], c[0][3]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -320,12 +307,9 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow4(const int8_t* src_ptr, } while (0); template -static void ker_neon_dirctconv_3x3s1_oc4_ow4(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int remain_w, - int /*ld_dst_oc*/) { +static void ker_neon_dirctconv_3x3s1_oc4_ow4( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int remain_w, int /*ld_dst_oc*/) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -347,12 +331,9 @@ static void ker_neon_dirctconv_3x3s1_oc4_ow4(const int8_t* src_ptr, #undef cb for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* src_row0 = - src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; - const int8_t* src_row1 = - src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; - const int8_t* src_row2 = - src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + const int8_t* src_row0 = src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; src[0] = vld_dup_tbl_s32(src_row0 + 0, idx); src[1] = vld_dup_tbl_s32(src_row0 + 4, idx); @@ -416,12 +397,10 @@ static void ker_neon_dirctconv_3x3s1_oc4_ow4(const int8_t* src_ptr, } template -static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, - int ih, int iw, int remain_w, - int /*ld_dst_oc*/, int ow) { +static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int remain_w, int /*ld_dst_oc*/, + int ow) { constexpr int fh = filter_size; constexpr int fw = filter_size; 
constexpr int ic_step = 4; @@ -443,14 +422,10 @@ static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr, #undef cb for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* src_row0 = - src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; - const int8_t* src_row1 = - src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; - const int8_t* src_row2 = - src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; - const int8_t* src_row3 = - src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; + const int8_t* src_row0 = src_ptr + ic_idx * ic_stride + 0 * iw * ic_step; + const int8_t* src_row1 = src_ptr + ic_idx * ic_stride + 1 * iw * ic_step; + const int8_t* src_row2 = src_ptr + ic_idx * ic_stride + 2 * iw * ic_step; + const int8_t* src_row3 = src_ptr + ic_idx * ic_stride + 3 * iw * ic_step; #define LOAD_SRC(_src, _src_ptr) \ _src[0] = vld_dup_tbl_s32(_src_ptr + 0, idx); \ _src[1] = vld_dup_tbl_s32(_src_ptr + 4, idx); \ @@ -469,21 +444,24 @@ static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr, weight[1][1] = vld1q_s8(weight_ptr + 64); weight[1][2] = vld1q_s8(weight_ptr + 80); - CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], - c[0][0]); // row0 src0 w0 + CALC_ONE_RESULT( + src[0][0], src[0][1], src[0][2], weight[0], + c[0][0]); // row0 src0 w0 CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]); CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]); CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]); LOAD_SRC(src[0], src_row1); - CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], - c[0][4]); // row1 src1 w0 + CALC_ONE_RESULT( + src[0][0], src[0][1], src[0][2], weight[0], + c[0][4]); // row1 src1 w0 CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]); CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]); CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]); - CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1], - c[0][0]); // row1 src1 w1 + CALC_ONE_RESULT( + src[0][0], src[0][1], src[0][2], weight[1], + c[0][0]); // row1 src1 w1 CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][1]); CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][2]); CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][3]); @@ -493,22 +471,25 @@ static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr, weight[0][0] = vld1q_s8(weight_ptr + 96); weight[0][1] = vld1q_s8(weight_ptr + 112); weight[0][2] = vld1q_s8(weight_ptr + 128); - CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[1], - c[0][4]); // row2 src0 w1 + CALC_ONE_RESULT( + src[0][0], src[0][1], src[0][2], weight[1], + c[0][4]); // row2 src0 w1 CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[1], c[0][5]); CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[1], c[0][6]); CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[1], c[0][7]); - CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], - c[0][0]); // row2 w0 src[0] + CALC_ONE_RESULT( + src[0][0], src[0][1], src[0][2], weight[0], + c[0][0]); // row2 w0 src[0] CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][1]); CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][2]); CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][3]); LOAD_SRC(src[0], src_row3); - CALC_ONE_RESULT(src[0][0], src[0][1], src[0][2], weight[0], - c[0][4]); // row3 w0 src1 + CALC_ONE_RESULT( + src[0][0], src[0][1], src[0][2], weight[0], + c[0][4]); // row3 w0 src1 
CALC_ONE_RESULT(src[0][1], src[0][2], src[0][3], weight[0], c[0][5]); CALC_ONE_RESULT(src[0][2], src[0][3], src[0][4], weight[0], c[0][6]); CALC_ONE_RESULT(src[0][3], src[0][4], src[0][5], weight[0], c[0][7]); @@ -522,16 +503,16 @@ static void ker_neon_dirctconv_3x3s1_oc4_ow4_oh2(const int8_t* src_ptr, template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int remain_w, int ld_dst_oc); + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int remain_w, int ld_dst_oc); }; template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int remain_w, int /*ld_dst_oc*/) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int remain_w, int /*ld_dst_oc*/) { constexpr int filter_size = 5; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -577,49 +558,49 @@ struct KerNeonDirectStride1Int8 { weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); -#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \ - _w4, _c) \ - do { \ - int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ - int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ - tmp0 = vaddq_s16(tmp0, tmp1); \ - _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ +#define CALC_ONE_RESULT( \ + _src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, _w4, _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ } while (0); - CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][0]); - CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][1]); - CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][2]); - 
CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][3]); - CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][4]); - CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][5]); + CALC_ONE_RESULT( + src[0], src[1], src[2], src[3], src[4], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][0]); + CALC_ONE_RESULT( + src[1], src[2], src[3], src[4], src[5], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][1]); + CALC_ONE_RESULT( + src[2], src[3], src[4], src[5], src[6], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][2]); + CALC_ONE_RESULT( + src[3], src[4], src[5], src[6], src[7], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][3]); + CALC_ONE_RESULT( + src[4], src[5], src[6], src[7], src[8], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][4]); + CALC_ONE_RESULT( + src[5], src[6], src[7], src[8], src[9], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][5]); src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx); src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx); - CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][6]); - CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][7]); + CALC_ONE_RESULT( + src[6], src[7], src[8], src[9], src[0], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][6]); + CALC_ONE_RESULT( + src[7], src[8], src[9], src[0], src[1], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][7]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -630,9 +611,9 @@ struct KerNeonDirectStride1Int8 { #undef CALC_ONE_RESULT template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int remain_w, int /*ld_dst_oc*/) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int remain_w, int /*ld_dst_oc*/) { constexpr int filter_size = 7; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -682,51 +663,58 @@ struct KerNeonDirectStride1Int8 { weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); -#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ - _c) \ - do { \ - int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ - int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ - int16x8_t tmp2 = vmull_s8(vget_low_s8(_src1), vget_low_s8(_w[1])); \ - int16x8_t tmp3 = vmull_s8(vget_high_s8(_src1), vget_high_s8(_w[1])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ - tmp2 = vmlal_s8(tmp2, vget_low_s8(_src3), vget_low_s8(_w[3])); \ - tmp3 = vmlal_s8(tmp3, vget_high_s8(_src3), vget_high_s8(_w[3])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ - tmp2 = vmlal_s8(tmp2, vget_low_s8(_src5), vget_low_s8(_w[5])); \ - tmp3 = vmlal_s8(tmp3, vget_high_s8(_src5), vget_high_s8(_w[5])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ - tmp1 = vmlal_s8(tmp1, 
vget_high_s8(_src6), vget_high_s8(_w[6])); \ - tmp0 = vaddq_s16(tmp0, tmp1); \ - tmp2 = vaddq_s16(tmp2, tmp3); \ - tmp0 = vaddq_s16(tmp0, tmp2); \ - _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ + int16x8_t tmp2 = vmull_s8(vget_low_s8(_src1), vget_low_s8(_w[1])); \ + int16x8_t tmp3 = vmull_s8(vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp2 = vmlal_s8(tmp2, vget_low_s8(_src3), vget_low_s8(_w[3])); \ + tmp3 = vmlal_s8(tmp3, vget_high_s8(_src3), vget_high_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ + tmp2 = vmlal_s8(tmp2, vget_low_s8(_src5), vget_low_s8(_w[5])); \ + tmp3 = vmlal_s8(tmp3, vget_high_s8(_src5), vget_high_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + tmp2 = vaddq_s16(tmp2, tmp3); \ + tmp0 = vaddq_s16(tmp0, tmp2); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ } while (0); - CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], weight, c[0][0]); - CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6], - src[7], weight, c[0][1]); + CALC_ONE_RESULT( + src[0], src[1], src[2], src[3], src[4], src[5], src[6], weight, + c[0][0]); + CALC_ONE_RESULT( + src[1], src[2], src[3], src[4], src[5], src[6], src[7], weight, + c[0][1]); src[0] = vld_dup_tbl_s32(src_ic_0_3 + 10 * 4, idx); src[1] = vld_dup_tbl_s32(src_ic_0_3 + 11 * 4, idx); - CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7], - src[8], weight, c[0][2]); - CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8], - src[9], weight, c[0][3]); + CALC_ONE_RESULT( + src[2], src[3], src[4], src[5], src[6], src[7], src[8], weight, + c[0][2]); + CALC_ONE_RESULT( + src[3], src[4], src[5], src[6], src[7], src[8], src[9], weight, + c[0][3]); src[2] = vld_dup_tbl_s32(src_ic_0_3 + 12 * 4, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 13 * 4, idx); - CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9], - src[0], weight, c[0][4]); - CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0], - src[1], weight, c[0][5]); - CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1], - src[2], weight, c[0][6]); - CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2], - src[3], weight, c[0][7]); + CALC_ONE_RESULT( + src[4], src[5], src[6], src[7], src[8], src[9], src[0], weight, + c[0][4]); + CALC_ONE_RESULT( + src[5], src[6], src[7], src[8], src[9], src[0], src[1], weight, + c[0][5]); + CALC_ONE_RESULT( + src[6], src[7], src[8], src[9], src[0], src[1], src[2], weight, + c[0][6]); + CALC_ONE_RESULT( + src[7], src[8], src[9], src[0], src[1], src[2], src[3], weight, + c[0][7]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -736,12 +724,10 @@ struct KerNeonDirectStride1Int8 { #undef CALC_ONE_RESULT template -void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, - const int8_t* filter, - const int16_t* bias, int16_t* dst, - const size_t oc, const size_t ic, - const size_t ih, const size_t iw, - const size_t oh, 
const size_t ow) { +void conv_direct_stride1_2x2_int8_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t filter_size = 2; constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -806,9 +792,9 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, template void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( - const int8_t* src, const int8_t* filter, const int16_t* bias, - int16_t* dst, const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow) { + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t filter_size = 3; constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -834,8 +820,7 @@ void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; ker_neon_dirctconv_3x3s1_oc4_ow4_oh2( src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ow_step, ld_oc, - ow * oc_step); + dst + dst_offset, ic, ih, iw, ow_step, ld_oc, ow * oc_step); } if (ow_remain > 0) { const size_t src_offset = (oh_idx * iw + ow_end) * ic_step; @@ -843,8 +828,7 @@ void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; ker_neon_dirctconv_3x3s1_oc4_ow4_oh2( src + src_offset, filter + weight_offset, bias + oc_idx, - dst + dst_offset, ic, ih, iw, ow_remain, ld_oc, - ow * oc_step); + dst + dst_offset, ic, ih, iw, ow_remain, ld_oc, ow * oc_step); } } for (; oh_idx < oh; oh_idx += oh_step) { @@ -869,12 +853,10 @@ void conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( } template -void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, - const int8_t* filter, - const int16_t* bias, int16_t* dst, - const size_t oc, const size_t ic, - const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { +void conv_direct_stride1_int8_nchw44_kern( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; constexpr size_t ic_step = 4; @@ -914,10 +896,10 @@ void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, namespace int8x8x16_direct_nchw44 { template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { conv_direct_stride1_int8_nchw44_kern( src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } @@ -925,21 +907,21 @@ struct ConvDirectInt8Nchw44Choose { template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { - conv_direct_stride1_2x2_int8_nchw44(src, filter, bias, dst, - oc, ic, ih, iw, oh, ow); + static void impl( + const int8_t* src, const 
int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride1_2x2_int8_nchw44( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } }; template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { conv_direct_stride1_3x3_int8x8x16_oh2_nchw44( src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_armv7.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_armv7.cpp index 2a4aec8a..064bb2d6 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_armv7.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s1_armv7.cpp @@ -70,84 +70,73 @@ namespace { megdnn_assert(0, "oc 1 error remainw"); \ }; -#define STORE_2_LINE_RESULT() \ - switch (remain_w) { \ - case 8: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ - vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 16, \ - vcombine_s16(c[1][4], c[1][5])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 24, \ - vcombine_s16(c[1][6], c[1][7])); \ - break; \ - case 1: \ - vst1_s16(dst_ptr, c[0][0]); \ - vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ - break; \ - case 2: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - break; \ - case 3: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1_s16(dst_ptr + 8, c[0][2]); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ - break; \ - case 4: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - break; \ - case 5: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1_s16(dst_ptr + 16, c[0][4]); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ - break; \ - case 6: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 16, \ - vcombine_s16(c[1][4], c[1][5])); \ - break; \ - case 7: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - 
vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ - vst1_s16(dst_ptr + 24, c[0][6]); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 16, \ - vcombine_s16(c[1][4], c[1][5])); \ - vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ - break; \ - default: \ - megdnn_assert(0, "oc 2 error remainw"); \ - break; \ +#define STORE_2_LINE_RESULT() \ + switch (remain_w) { \ + case 8: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, vcombine_s16(c[1][4], c[1][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 24, vcombine_s16(c[1][6], c[1][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 5: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1_s16(dst_ptr + 16, c[0][4]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ + break; \ + case 6: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, vcombine_s16(c[1][4], c[1][5])); \ + break; \ + case 7: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + 24, c[0][6]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, vcombine_s16(c[1][4], c[1][5])); \ + vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ } template -static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc) { +static void ker_neon_dirctconv_2x2s1_oc8_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int 
ic, int ih, int iw, int ld_dst_oc) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -171,10 +160,10 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, #undef cb for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { - const int8_t* src_row0 = src_ptr + ic_idx * ic_stride + - 0 * iw * ic_step * src_expand_size; - const int8_t* src_row1 = src_ptr + ic_idx * ic_stride + - 1 * iw * ic_step * src_expand_size; + const int8_t* src_row0 = + src_ptr + ic_idx * ic_stride + 0 * iw * ic_step * src_expand_size; + const int8_t* src_row1 = + src_ptr + ic_idx * ic_stride + 1 * iw * ic_step * src_expand_size; src[0] = vld1q_s8(src_row0); src[1] = vld1q_s8(src_row0 + 16); @@ -277,11 +266,9 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, } template -static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int /*ld_dst_oc*/) { +static void ker_neon_dirctconv_2x2s1_oc4_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int /*ld_dst_oc*/) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -307,42 +294,33 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, fh_idx * iw * ic_step * src_expand_size; src[0] = vld1q_s8(src_ic_0_3); src[1] = vld1q_s8(src_ic_0_3 + 16); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); int16x8_t tmp0; src[2] = vld1q_s8(src_ic_0_3 + 2 * 16); src[3] = vld1q_s8(src_ic_0_3 + 3 * 16); - CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], - c[0][0]); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); src[0] = vld1q_s8(src_ic_0_3 + 4 * 16); - CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], - c[0][1]); + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][1]); src[1] = vld1q_s8(src_ic_0_3 + 5 * 16); - CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], - c[0][2]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][2]); src[2] = vld1q_s8(src_ic_0_3 + 6 * 16); - CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], - c[0][3]); + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][3]); src[3] = vld1q_s8(src_ic_0_3 + 7 * 16); - CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], - c[0][4]); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][4]); src[0] = vld1q_s8(src_ic_0_3 + 8 * 16); - CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], - c[0][5]); + CALC_ONE_RESULT(src[1], src[2], weight[0][0], weight[0][1], c[0][5]); - CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], - c[0][6]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][6]); - CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], - c[0][7]); + CALC_ONE_RESULT(src[3], src[0], weight[0][0], weight[0][1], c[0][7]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -353,16 +331,16 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc); + static void 
impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc); }; template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int /*ld_dst_oc*/) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int /*ld_dst_oc*/) { constexpr int filter_size = 3; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -385,9 +363,8 @@ struct KerNeonDirectStride1Int8 { for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * src_expand_size; + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * src_expand_size; src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); @@ -415,28 +392,36 @@ struct KerNeonDirectStride1Int8 { int16x8_t tmp0; - CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], - weight[2], c[0][0]); + CALC_ONE_RESULT( + src[0], src[1], src[2], weight[0], weight[1], weight[2], + c[0][0]); src[4] = vld1q_s8(src_ic_0_3 + 4 * 16); - CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], - weight[2], c[0][1]); + CALC_ONE_RESULT( + src[1], src[2], src[3], weight[0], weight[1], weight[2], + c[0][1]); src[0] = vld1q_s8(src_ic_0_3 + 5 * 16); - CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], - weight[2], c[0][2]); + CALC_ONE_RESULT( + src[2], src[3], src[4], weight[0], weight[1], weight[2], + c[0][2]); src[1] = vld1q_s8(src_ic_0_3 + 6 * 16); - CALC_ONE_RESULT(src[3], src[4], src[0], weight[0], weight[1], - weight[2], c[0][3]); + CALC_ONE_RESULT( + src[3], src[4], src[0], weight[0], weight[1], weight[2], + c[0][3]); src[2] = vld1q_s8(src_ic_0_3 + 7 * 16); - CALC_ONE_RESULT(src[4], src[0], src[1], weight[0], weight[1], - weight[2], c[0][4]); + CALC_ONE_RESULT( + src[4], src[0], src[1], weight[0], weight[1], weight[2], + c[0][4]); src[3] = vld1q_s8(src_ic_0_3 + 8 * 16); - CALC_ONE_RESULT(src[0], src[1], src[2], weight[0], weight[1], - weight[2], c[0][5]); + CALC_ONE_RESULT( + src[0], src[1], src[2], weight[0], weight[1], weight[2], + c[0][5]); src[4] = vld1q_s8(src_ic_0_3 + 9 * 16); - CALC_ONE_RESULT(src[1], src[2], src[3], weight[0], weight[1], - weight[2], c[0][6]); - CALC_ONE_RESULT(src[2], src[3], src[4], weight[0], weight[1], - weight[2], c[0][7]); + CALC_ONE_RESULT( + src[1], src[2], src[3], weight[0], weight[1], weight[2], + c[0][6]); + CALC_ONE_RESULT( + src[2], src[3], src[4], weight[0], weight[1], weight[2], + c[0][7]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -447,9 +432,9 @@ struct KerNeonDirectStride1Int8 { #undef CALC_ONE_RESULT template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int /*ld_dst_oc*/) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int /*ld_dst_oc*/) { constexpr int filter_size = 5; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -471,9 +456,8 @@ struct KerNeonDirectStride1Int8 { for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = - 
src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * src_expand_size; + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * src_expand_size; src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); @@ -493,49 +477,49 @@ struct KerNeonDirectStride1Int8 { weight[2] = vld1q_s8(read_weight_ptr + 2 * 16); weight[3] = vld1q_s8(read_weight_ptr + 3 * 16); weight[4] = vld1q_s8(read_weight_ptr + 4 * 16); -#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, \ - _w4, _c) \ - do { \ - int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ - int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ - tmp0 = vaddq_s16(tmp0, tmp1); \ - _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ +#define CALC_ONE_RESULT( \ + _src0, _src1, _src2, _src3, _src4, _w0, _w1, _w2, _w3, _w4, _c) \ + do { \ + int16x8_t tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w0)); \ + int16x8_t tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w0)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w1)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w1)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w2)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w2)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w3)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w3)); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w4)); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w4)); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ } while (0); - CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][0]); - CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][1]); - CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][2]); - CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][3]); - CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][4]); - CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][5]); + CALC_ONE_RESULT( + src[0], src[1], src[2], src[3], src[4], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][0]); + CALC_ONE_RESULT( + src[1], src[2], src[3], src[4], src[5], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][1]); + CALC_ONE_RESULT( + src[2], src[3], src[4], src[5], src[6], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][2]); + CALC_ONE_RESULT( + src[3], src[4], src[5], src[6], src[7], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][3]); + CALC_ONE_RESULT( + src[4], src[5], src[6], 
src[7], src[8], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][4]); + CALC_ONE_RESULT( + src[5], src[6], src[7], src[8], src[9], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][5]); src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); - CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][6]); - CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], - weight[0], weight[1], weight[2], weight[3], - weight[4], c[0][7]); + CALC_ONE_RESULT( + src[6], src[7], src[8], src[9], src[0], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][6]); + CALC_ONE_RESULT( + src[7], src[8], src[9], src[0], src[1], weight[0], weight[1], + weight[2], weight[3], weight[4], c[0][7]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -546,9 +530,9 @@ struct KerNeonDirectStride1Int8 { #undef CALC_ONE_RESULT template struct KerNeonDirectStride1Int8 { - static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int /*ld_dst_oc*/) { + static void impl( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int /*ld_dst_oc*/) { constexpr int filter_size = 7; constexpr int fh = filter_size; constexpr int fw = filter_size; @@ -571,9 +555,8 @@ struct KerNeonDirectStride1Int8 { for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { - const int8_t* src_ic_0_3 = - src_ptr + ic_idx * ic_stride + - fh_idx * iw * ic_step * src_expand_size; + const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + + fh_idx * iw * ic_step * src_expand_size; src[0] = vld1q_s8(src_ic_0_3 + 0 * 16); src[1] = vld1q_s8(src_ic_0_3 + 1 * 16); @@ -597,48 +580,55 @@ struct KerNeonDirectStride1Int8 { weight[5] = vld1q_s8(read_weight_ptr + 5 * 16); weight[6] = vld1q_s8(read_weight_ptr + 6 * 16); -#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ - _c) \ - do { \ - tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ - tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ - tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ - tmp0 = vmlal_s8(tmp0, vget_high_s8(_src2), vget_high_s8(_w[2])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ - tmp0 = vmlal_s8(tmp0, vget_high_s8(_src3), vget_high_s8(_w[3])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ - tmp0 = vmlal_s8(tmp0, vget_high_s8(_src4), vget_high_s8(_w[4])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ - tmp0 = vmlal_s8(tmp0, vget_high_s8(_src5), vget_high_s8(_w[5])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ - tmp0 = vmlal_s8(tmp0, vget_high_s8(_src6), vget_high_s8(_w[6])); \ - _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, 
vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src3), vget_high_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src4), vget_high_s8(_w[4])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src5), vget_high_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ + tmp0 = vmlal_s8(tmp0, vget_high_s8(_src6), vget_high_s8(_w[6])); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ } while (0); int16x8_t tmp0; - CALC_ONE_RESULT(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], weight, c[0][0]); - CALC_ONE_RESULT(src[1], src[2], src[3], src[4], src[5], src[6], - src[7], weight, c[0][1]); + CALC_ONE_RESULT( + src[0], src[1], src[2], src[3], src[4], src[5], src[6], weight, + c[0][0]); + CALC_ONE_RESULT( + src[1], src[2], src[3], src[4], src[5], src[6], src[7], weight, + c[0][1]); src[0] = vld1q_s8(src_ic_0_3 + 10 * 16); src[1] = vld1q_s8(src_ic_0_3 + 11 * 16); - CALC_ONE_RESULT(src[2], src[3], src[4], src[5], src[6], src[7], - src[8], weight, c[0][2]); - CALC_ONE_RESULT(src[3], src[4], src[5], src[6], src[7], src[8], - src[9], weight, c[0][3]); + CALC_ONE_RESULT( + src[2], src[3], src[4], src[5], src[6], src[7], src[8], weight, + c[0][2]); + CALC_ONE_RESULT( + src[3], src[4], src[5], src[6], src[7], src[8], src[9], weight, + c[0][3]); src[2] = vld1q_s8(src_ic_0_3 + 12 * 16); src[3] = vld1q_s8(src_ic_0_3 + 13 * 16); - CALC_ONE_RESULT(src[4], src[5], src[6], src[7], src[8], src[9], - src[0], weight, c[0][4]); - CALC_ONE_RESULT(src[5], src[6], src[7], src[8], src[9], src[0], - src[1], weight, c[0][5]); - CALC_ONE_RESULT(src[6], src[7], src[8], src[9], src[0], src[1], - src[2], weight, c[0][6]); - CALC_ONE_RESULT(src[7], src[8], src[9], src[0], src[1], src[2], - src[3], weight, c[0][7]); + CALC_ONE_RESULT( + src[4], src[5], src[6], src[7], src[8], src[9], src[0], weight, + c[0][4]); + CALC_ONE_RESULT( + src[5], src[6], src[7], src[8], src[9], src[0], src[1], weight, + c[0][5]); + CALC_ONE_RESULT( + src[6], src[7], src[8], src[9], src[0], src[1], src[2], weight, + c[0][6]); + CALC_ONE_RESULT( + src[7], src[8], src[9], src[0], src[1], src[2], src[3], weight, + c[0][7]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -648,9 +638,9 @@ struct KerNeonDirectStride1Int8 { template void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( - const int8_t* src, const int8_t* filter, const int16_t* bias, - int16_t* dst, const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow) { + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t filter_size = 2; constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -667,21 +657,19 @@ void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( const size_t oc_end = oc / big_oc_step * big_oc_step; const size_t oc_remain = oc - oc_end; const int ld_oc = oh * ow * oc_step; - using remain_fun = - std::function; + using remain_fun = std::function; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = ker_neon_dirctconv_2x2s1_oc8_ow8; \ - kern_small_oc_remain = \ - ker_neon_dirctconv_2x2s1_oc4_ow8; \ 
+#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + ker_neon_dirctconv_2x2s1_oc8_ow8; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_2x2s1_oc4_ow8; \ break; UNROLL_CALL_RAW(8, cb); @@ -707,9 +695,9 @@ void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( (oh_idx * iw + ow_end) * ic_step * src_expand_size; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_oc); + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc); } } } @@ -731,9 +719,9 @@ void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( (oh_idx * iw + ow_end) * ic_step * src_expand_size; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_oc); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_oc); } } } @@ -741,12 +729,10 @@ void conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( #undef CALC_ONE_RESULT template -void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, - const int8_t* filter, - const int16_t* bias, int16_t* dst, - const size_t oc, const size_t ic, - const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { +void conv_direct_stride1_int8_nchw44_kern( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; constexpr size_t ic_step = 4; @@ -760,17 +746,16 @@ void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, const size_t ow_end = ow / ow_step * ow_step; const size_t ow_remain = ow - ow_end; - using remain_fun = - std::function; + using remain_fun = std::function; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_small_oc_remain = KerNeonDirectStride1Int8::impl; \ +#define cb(step) \ + case step: \ + kern_small_oc_remain = \ + KerNeonDirectStride1Int8::impl; \ break; UNROLL_CALL_RAW(8, cb); @@ -796,9 +781,9 @@ void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, (oh_idx * iw + ow_end) * ic_step * src_expand_size; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc); + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); } } } @@ -808,10 +793,10 @@ void conv_direct_stride1_int8_nchw44_kern(const int8_t* src, namespace int8x8x16_direct_nchw44 { template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { conv_direct_stride1_int8_nchw44_kern( src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } @@ -819,10 +804,10 @@ struct ConvDirectInt8Nchw44Choose { template struct ConvDirectInt8Nchw44Choose { - static void 
impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { conv_direct_stride1_2x2_int8_oc8_ow8_nchw44( src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp index d32eb0be..e32aa32b 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_kernels/int8x8x16_direct_nchw44_s2.cpp @@ -92,112 +92,100 @@ namespace { break; \ }; -#define STORE_2_LINE_RESULT() \ - switch (remain_w) { \ - case 8: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ - vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 16, \ - vcombine_s16(c[1][4], c[1][5])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 24, \ - vcombine_s16(c[1][6], c[1][7])); \ - break; \ - case 1: \ - vst1_s16(dst_ptr, c[0][0]); \ - vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ - break; \ - case 2: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - break; \ - case 3: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1_s16(dst_ptr + 8, c[0][2]); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ - break; \ - case 4: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - break; \ - case 5: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1_s16(dst_ptr + 16, c[0][4]); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ - break; \ - case 6: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 16, \ - vcombine_s16(c[1][4], c[1][5])); \ - break; \ - case 7: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ - vst1_s16(dst_ptr + 24, c[0][6]); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 16, \ - vcombine_s16(c[1][4], c[1][5])); \ - 
vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ - break; \ - default: \ - megdnn_assert(0, "oc 2 error remainw"); \ - break; \ +#define STORE_2_LINE_RESULT() \ + switch (remain_w) { \ + case 8: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + 24, vcombine_s16(c[0][6], c[0][7])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, vcombine_s16(c[1][4], c[1][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 24, vcombine_s16(c[1][6], c[1][7])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 5: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1_s16(dst_ptr + 16, c[0][4]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1_s16(dst_ptr + ld_dst_oc + 16, c[1][4]); \ + break; \ + case 6: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, vcombine_s16(c[1][4], c[1][5])); \ + break; \ + case 7: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + 16, vcombine_s16(c[0][4], c[0][5])); \ + vst1_s16(dst_ptr + 24, c[0][6]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 16, vcombine_s16(c[1][4], c[1][5])); \ + vst1_s16(dst_ptr + ld_dst_oc + 24, c[1][6]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ } -#define STORE_2_LINE_RESULT_OW4() \ - switch (remain_w) { \ - case 4: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc + 8, \ - vcombine_s16(c[1][2], c[1][3])); \ - break; \ - case 1: \ - vst1_s16(dst_ptr, c[0][0]); \ - vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ - break; \ - case 2: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - break; \ - case 3: \ - vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ - vst1_s16(dst_ptr + 8, c[0][2]); \ - vst1q_s16(dst_ptr + 
ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ - vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ - break; \ - default: \ - megdnn_assert(0, "oc 2 error remainw"); \ - break; \ +#define STORE_2_LINE_RESULT_OW4() \ + switch (remain_w) { \ + case 4: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc + 8, vcombine_s16(c[1][2], c[1][3])); \ + break; \ + case 1: \ + vst1_s16(dst_ptr, c[0][0]); \ + vst1_s16(dst_ptr + ld_dst_oc, c[1][0]); \ + break; \ + case 2: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + break; \ + case 3: \ + vst1q_s16(dst_ptr, vcombine_s16(c[0][0], c[0][1])); \ + vst1_s16(dst_ptr + 8, c[0][2]); \ + vst1q_s16(dst_ptr + ld_dst_oc, vcombine_s16(c[1][0], c[1][1])); \ + vst1_s16(dst_ptr + ld_dst_oc + 8, c[1][2]); \ + break; \ + default: \ + megdnn_assert(0, "oc 2 error remainw"); \ + break; \ } template -static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc) { +static void ker_neon_dirctconv_2x2s2_oc8_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -232,8 +220,7 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); @@ -250,62 +237,46 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, } while (0); int16x8_t tmp0; - CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], - c[0][0]); - CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], - c[1][0]); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][0]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][0]); src[0] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); src[1] = vld_dup_tbl_s32(src_ic_0_3 + 20, idx); - CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], - c[0][1]); - CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], - c[1][1]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][1]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][1]); src[2] = vld_dup_tbl_s32(src_ic_0_3 + 24, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 28, idx); - CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], - c[0][2]); - CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], - c[1][2]); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][2]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][2]); src[0] = vld_dup_tbl_s32(src_ic_0_3 + 32, idx); src[1] = vld_dup_tbl_s32(src_ic_0_3 + 36, idx); - CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], - c[0][3]); - CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], - c[1][3]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][3]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][3]); src[2] = 
vld_dup_tbl_s32(src_ic_0_3 + 40, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 44, idx); - CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], - c[0][4]); - CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], - c[1][4]); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][4]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][4]); src[0] = vld_dup_tbl_s32(src_ic_0_3 + 48, idx); src[1] = vld_dup_tbl_s32(src_ic_0_3 + 52, idx); - CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], - c[0][5]); - CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], - c[1][5]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][5]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][5]); src[2] = vld_dup_tbl_s32(src_ic_0_3 + 56, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 60, idx); - CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], - c[0][6]); - CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], - c[1][6]); - CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], - c[0][7]); - CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], - c[1][7]); + CALC_ONE_RESULT(src[0], src[1], weight[0][0], weight[0][1], c[0][6]); + CALC_ONE_RESULT(src[0], src[1], weight[1][0], weight[1][1], c[1][6]); + CALC_ONE_RESULT(src[2], src[3], weight[0][0], weight[0][1], c[0][7]); + CALC_ONE_RESULT(src[2], src[3], weight[1][0], weight[1][1], c[1][7]); } weight_ptr += fh * fw * ld_weight_ic4; } @@ -313,11 +284,9 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, } template -static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int /*ld_dst_oc*/) { +static void ker_neon_dirctconv_2x2s2_oc4_ow8( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int /*ld_dst_oc*/) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -349,8 +318,7 @@ static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0] = vld1q_s8(read_weight_ptr); weight[1] = vld1q_s8(read_weight_ptr + 16); @@ -399,11 +367,9 @@ static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, } while (0); template -static void ker_neon_dirctconv_3x3s2_oc8_ow4(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc) { +static void ker_neon_dirctconv_3x3s2_oc8_ow4( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -439,8 +405,7 @@ static void ker_neon_dirctconv_3x3s2_oc8_ow4(const int8_t* src_ptr, src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); @@ -477,12 +442,9 @@ static 
void ker_neon_dirctconv_3x3s2_oc8_ow4(const int8_t* src_ptr, } template -static void ker_neon_dirctconv_3x3s2_oc8_ow4_remain(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, - int ih, int iw, - int ld_dst_oc) { +static void ker_neon_dirctconv_3x3s2_oc8_ow4_remain( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -518,8 +480,7 @@ static void ker_neon_dirctconv_3x3s2_oc8_ow4_remain(const int8_t* src_ptr, src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0][0] = vld1q_s8(read_weight_ptr); weight[0][1] = vld1q_s8(read_weight_ptr + 16); @@ -565,11 +526,9 @@ static void ker_neon_dirctconv_3x3s2_oc8_ow4_remain(const int8_t* src_ptr, } while (0); template -static void ker_neon_dirctconv_3x3s2_oc4_ow4(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, int ih, - int iw, int /*ld_dst_oc*/) { +static void ker_neon_dirctconv_3x3s2_oc4_ow4( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int /*ld_dst_oc*/) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -602,8 +561,7 @@ static void ker_neon_dirctconv_3x3s2_oc4_ow4(const int8_t* src_ptr, src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0] = vld1q_s8(read_weight_ptr); weight[1] = vld1q_s8(read_weight_ptr + 16); @@ -628,12 +586,9 @@ static void ker_neon_dirctconv_3x3s2_oc4_ow4(const int8_t* src_ptr, vst1q_s16(dst_ptr + 8, vcombine_s16(c[0][2], c[0][3])); } template -static void ker_neon_dirctconv_3x3s2_oc4_ow4_remain(const int8_t* src_ptr, - const int8_t* weight_ptr, - const int16_t* bias_ptr, - int16_t* dst_ptr, int ic, - int ih, int iw, - int /*ld_dst_oc*/) { +static void ker_neon_dirctconv_3x3s2_oc4_ow4_remain( + const int8_t* src_ptr, const int8_t* weight_ptr, const int16_t* bias_ptr, + int16_t* dst_ptr, int ic, int ih, int iw, int /*ld_dst_oc*/) { constexpr int fh = filter_size; constexpr int fw = filter_size; constexpr int ic_step = 4; @@ -664,8 +619,7 @@ static void ker_neon_dirctconv_3x3s2_oc4_ow4_remain(const int8_t* src_ptr, src[2] = vld_dup_tbl_s32(src_ic_0_3 + 8, idx); src[3] = vld_dup_tbl_s32(src_ic_0_3 + 12, idx); src[4] = vld_dup_tbl_s32(src_ic_0_3 + 16, idx); - const int8_t* read_weight_ptr = - weight_ptr + fh_idx * fw * ld_weight_ic4; + const int8_t* read_weight_ptr = weight_ptr + fh_idx * fw * ld_weight_ic4; weight[0] = vld1q_s8(read_weight_ptr); weight[1] = vld1q_s8(read_weight_ptr + 16); @@ -692,12 +646,10 @@ static void ker_neon_dirctconv_3x3s2_oc4_ow4_remain(const int8_t* src_ptr, #undef CALC_ONE_RESULT template -void conv_direct_stride2_2x2_int8_nchw44(const int8_t* src, - const int8_t* filter, - const int16_t* bias, int16_t* dst, - const size_t oc, const size_t ic, - const size_t ih, const size_t iw, - const size_t oh, const 
size_t ow) { +void conv_direct_stride2_2x2_int8_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t filter_size = 2; constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -716,21 +668,19 @@ void conv_direct_stride2_2x2_int8_nchw44(const int8_t* src, const size_t oc_remain = oc - oc_end; const int ld_dst_oc = oh * ow * oc_step; - using remain_fun = - std::function; + using remain_fun = std::function; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = ker_neon_dirctconv_2x2s2_oc8_ow8; \ - kern_small_oc_remain = \ - ker_neon_dirctconv_2x2s2_oc4_ow8; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + ker_neon_dirctconv_2x2s2_oc8_ow8; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_2x2s2_oc4_ow8; \ break; UNROLL_CALL_RAW(8, cb); @@ -745,21 +695,20 @@ void conv_direct_stride2_2x2_int8_nchw44(const int8_t* src, for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc8_ow8( + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc8_ow8( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc); } if (ow_remain > 0) { const size_t src_offset = (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc); + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); } } } @@ -771,33 +720,30 @@ void conv_direct_stride2_2x2_int8_nchw44(const int8_t* src, for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s2_oc4_ow8( + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_2x2s2_oc4_ow8( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc); } if (ow_remain > 0) { const size_t src_offset = (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc); + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); } } } } template -void conv_direct_stride2_3x3_int8_nchw44(const int8_t* src, - const int8_t* filter, - const int16_t* bias, int16_t* dst, - const size_t oc, const size_t ic, - const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { +void 
conv_direct_stride2_3x3_int8_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t filter_size = 3; constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -817,22 +763,19 @@ void conv_direct_stride2_3x3_int8_nchw44(const int8_t* src, const size_t oc_remain = oc - oc_end; const int ld_dst_oc = oh * ow * oc_step; - using remain_fun = - std::function; + using remain_fun = std::function; remain_fun kern_big_oc_remain = nullptr; remain_fun kern_small_oc_remain = nullptr; switch (ow_remain) { -#define cb(step) \ - case step: \ - kern_big_oc_remain = \ - ker_neon_dirctconv_3x3s2_oc8_ow4_remain; \ - kern_small_oc_remain = \ - ker_neon_dirctconv_3x3s2_oc4_ow4_remain; \ +#define cb(step) \ + case step: \ + kern_big_oc_remain = \ + ker_neon_dirctconv_3x3s2_oc8_ow4_remain; \ + kern_small_oc_remain = \ + ker_neon_dirctconv_3x3s2_oc4_ow4_remain; \ break; UNROLL_CALL_RAW(8, cb); @@ -847,21 +790,20 @@ void conv_direct_stride2_3x3_int8_nchw44(const int8_t* src, for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step4) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_3x3s2_oc8_ow4( + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s2_oc8_ow4( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc); } if (ow_remain > 0) { const size_t src_offset = (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_big_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, iw, - ld_dst_oc); + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_big_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); } } } @@ -873,21 +815,20 @@ void conv_direct_stride2_3x3_int8_nchw44(const int8_t* src, for (size_t ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_3x3s2_oc4_ow4( + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; + ker_neon_dirctconv_3x3s2_oc4_ow4( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_dst_oc); } if (ow_remain > 0) { const size_t src_offset = (oh_idx * stride_h * iw + ow_end * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_end) * oc_step; - kern_small_oc_remain(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc); + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_end) * oc_step; + kern_small_oc_remain( + src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc); } } } @@ -896,9 +837,9 @@ void conv_direct_stride2_3x3_int8_nchw44(const int8_t* src, #undef LOAD_SRC template void conv_direct_stride2_5x5_int8x8x16_nchw44( - const int8_t* src, const int8_t* filter, const int16_t* bias, - int16_t* dst, const size_t oc, const size_t 
ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow) { + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t filter_size = 5; constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -950,8 +891,8 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( for (; ow_idx + ow_step - 1 < ow; ow_idx += ow_step) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; constexpr int ld_weight_ic4 = 16; @@ -965,7 +906,7 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( c[0][1] = init_sum; c[0][2] = init_sum; c[0][3] = init_sum; -#if MEGDNN_AARCH64 +#if MEGDNN_AARCH64 int8x16_t weight[3][5]; int8x16_t ssrc[2][5]; #else @@ -1017,7 +958,7 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( LOAD_WEIGHT(weight[0], weight_ptr, 0, 1, 2, 3, 4); LOAD_WEIGHT(weight[1], weight_ptr, 5, 6, 7, 8, 9); CALC_4_RESULT(ssrc[0], weight[0], src_row0); - + LOAD_SRC(ssrc[1], src_row1); LOAD_WEIGHT(weight[2], weight_ptr, 10, 11, 12, 13, 14); LOAD_SRC(ssrc[0], src_row2); @@ -1026,7 +967,7 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( LOAD_SRC(ssrc[1], src_row3); LOAD_WEIGHT(weight[0], weight_ptr, 15, 16, 17, 18, 19); CALC_4_RESULT(ssrc[0], weight[2], src_row2); - + LOAD_SRC(ssrc[0], src_row4); LOAD_WEIGHT(weight[1], weight_ptr, 20, 21, 22, 23, 24); CALC_4_RESULT(ssrc[1], weight[0], src_row3); @@ -1051,20 +992,24 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( weight[0][3] = vld1q_s8(_w_ptr + _id3 * 16); \ weight[0][4] = vld1q_s8(_w_ptr + _id4 * 16); -#define CALC_4_RESULT(_src_ptr) \ - CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ - ssrc[0][4], weight[0], c[0][0]); \ - ssrc[0][0] = vld_dup_tbl_s32(_src_ptr + 36, idx); \ - ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 40, idx); \ - CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ - ssrc[0][6], weight[0], c[0][1]); \ - CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \ - ssrc[0][8], weight[0], c[0][2]); \ - CALC_ONE_RESULT(ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], \ - ssrc[0][1], weight[0], c[0][3]); +#define CALC_4_RESULT(_src_ptr) \ + CALC_ONE_RESULT( \ + ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], ssrc[0][4], weight[0], \ + c[0][0]); \ + ssrc[0][0] = vld_dup_tbl_s32(_src_ptr + 36, idx); \ + ssrc[0][1] = vld_dup_tbl_s32(_src_ptr + 40, idx); \ + CALC_ONE_RESULT( \ + ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], \ + c[0][1]); \ + CALC_ONE_RESULT( \ + ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], ssrc[0][8], weight[0], \ + c[0][2]); \ + CALC_ONE_RESULT( \ + ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], ssrc[0][1], weight[0], \ + c[0][3]); int16x8_t tmp0, tmp1; - + LOAD_WEIGHT(weight_ptr, 0, 1, 2, 3, 4); LOAD_SRC(src_row0); CALC_4_RESULT(src_row0); @@ -1072,15 +1017,15 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( LOAD_WEIGHT(weight_ptr, 5, 6, 7, 8, 9); LOAD_SRC(src_row1); CALC_4_RESULT(src_row1); - + LOAD_WEIGHT(weight_ptr, 10, 11, 12, 13, 14); LOAD_SRC(src_row2); CALC_4_RESULT(src_row2); - + LOAD_WEIGHT(weight_ptr, 15, 16, 17, 18, 19); LOAD_SRC(src_row3); CALC_4_RESULT(src_row3); - + LOAD_WEIGHT(weight_ptr, 20, 21, 22, 23, 24); LOAD_SRC(src_row4); CALC_4_RESULT(src_row4); @@ -1094,8 +1039,8 @@ void 
conv_direct_stride2_5x5_int8x8x16_nchw44( if (remain_w > 0) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; constexpr int ld_weight_ic4 = 16; @@ -1191,13 +1136,16 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( weight[0][3] = vld1q_s8(_w_ptr + _id3 * 16); \ weight[0][4] = vld1q_s8(_w_ptr + _id4 * 16); -#define CALC_3_RESULT(_src_ptr) \ - CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ - ssrc[0][4], weight[0], c[0][0]); \ - CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ - ssrc[0][6], weight[0], c[0][1]); \ - CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \ - ssrc[0][8], weight[0], c[0][2]); +#define CALC_3_RESULT(_src_ptr) \ + CALC_ONE_RESULT( \ + ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], ssrc[0][4], weight[0], \ + c[0][0]); \ + CALC_ONE_RESULT( \ + ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], \ + c[0][1]); \ + CALC_ONE_RESULT( \ + ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], ssrc[0][8], weight[0], \ + c[0][2]); int16x8_t tmp0, tmp1; @@ -1249,9 +1197,9 @@ void conv_direct_stride2_5x5_int8x8x16_nchw44( template void conv_direct_stride2_7x7_int8x8x16_nchw44( - const int8_t* src, const int8_t* filter, const int16_t* bias, - int16_t* dst, const size_t oc, const size_t ic, const size_t ih, - const size_t iw, const size_t oh, const size_t ow) { + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { constexpr size_t filter_size = 7; constexpr size_t fh = filter_size; constexpr size_t fw = filter_size; @@ -1281,25 +1229,24 @@ void conv_direct_stride2_7x7_int8x8x16_nchw44( } size_t oh_idx = 0; -#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, \ - _c) \ - do { \ - tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ - tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src2), vget_high_s8(_w[2])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w[3])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src5), vget_high_s8(_w[5])); \ - tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ - tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \ - tmp0 = vaddq_s16(tmp0, tmp1); \ - _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ +#define CALC_ONE_RESULT(_src0, _src1, _src2, _src3, _src4, _src5, _src6, _w, _c) \ + do { \ + tmp0 = vmull_s8(vget_low_s8(_src0), vget_low_s8(_w[0])); \ + tmp1 = vmull_s8(vget_high_s8(_src0), vget_high_s8(_w[0])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src1), vget_low_s8(_w[1])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src1), vget_high_s8(_w[1])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src2), vget_low_s8(_w[2])); \ + tmp1 = vmlal_s8(tmp1, 
vget_high_s8(_src2), vget_high_s8(_w[2])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src3), vget_low_s8(_w[3])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src3), vget_high_s8(_w[3])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src4), vget_low_s8(_w[4])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src4), vget_high_s8(_w[4])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src5), vget_low_s8(_w[5])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src5), vget_high_s8(_w[5])); \ + tmp0 = vmlal_s8(tmp0, vget_low_s8(_src6), vget_low_s8(_w[6])); \ + tmp1 = vmlal_s8(tmp1, vget_high_s8(_src6), vget_high_s8(_w[6])); \ + tmp0 = vaddq_s16(tmp0, tmp1); \ + _c = vadd_s16(_c, vadd_s16(vget_low_s16(tmp0), vget_high_s16(tmp0))); \ } while (0); for (; oh_idx < oh; oh_idx += oh_step1) { @@ -1307,8 +1254,8 @@ void conv_direct_stride2_7x7_int8x8x16_nchw44( for (; ow_idx + ow_step - 1 < ow; ow_idx += ow_step) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; constexpr int ld_weight_ic4 = 16; @@ -1359,25 +1306,29 @@ void conv_direct_stride2_7x7_int8x8x16_nchw44( weight[0][5] = vld1q_s8(weight_ptr + _id5 * 16); \ weight[0][6] = vld1q_s8(weight_ptr + _id6 * 16); -#define CALC_4_RESULT(_row) \ - CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ - ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], c[0][0]); \ - \ - ssrc[0][7] = vld_dup_tbl_s32(_row + 28, idx); \ - ssrc[0][8] = vld_dup_tbl_s32(_row + 32, idx); \ - CALC_ONE_RESULT(ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ - ssrc[0][6], ssrc[0][7], ssrc[0][8], weight[0], c[0][1]); \ - \ - ssrc[0][0] = vld_dup_tbl_s32(_row + 36, idx); \ - ssrc[0][1] = vld_dup_tbl_s32(_row + 40, idx); \ - \ - CALC_ONE_RESULT(ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \ - ssrc[0][8], ssrc[0][0], ssrc[0][1], weight[0], c[0][2]); \ - ssrc[0][2] = vld_dup_tbl_s32(_row + 44, idx); \ - ssrc[0][3] = vld_dup_tbl_s32(_row + 48, idx); \ - \ - CALC_ONE_RESULT(ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], \ - ssrc[0][1], ssrc[0][2], ssrc[0][3], weight[0], c[0][3]); +#define CALC_4_RESULT(_row) \ + CALC_ONE_RESULT( \ + ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ + ssrc[0][6], weight[0], c[0][0]); \ + \ + ssrc[0][7] = vld_dup_tbl_s32(_row + 28, idx); \ + ssrc[0][8] = vld_dup_tbl_s32(_row + 32, idx); \ + CALC_ONE_RESULT( \ + ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], \ + ssrc[0][8], weight[0], c[0][1]); \ + \ + ssrc[0][0] = vld_dup_tbl_s32(_row + 36, idx); \ + ssrc[0][1] = vld_dup_tbl_s32(_row + 40, idx); \ + \ + CALC_ONE_RESULT( \ + ssrc[0][4], ssrc[0][5], ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], \ + ssrc[0][1], weight[0], c[0][2]); \ + ssrc[0][2] = vld_dup_tbl_s32(_row + 44, idx); \ + ssrc[0][3] = vld_dup_tbl_s32(_row + 48, idx); \ + \ + CALC_ONE_RESULT( \ + ssrc[0][6], ssrc[0][7], ssrc[0][8], ssrc[0][0], ssrc[0][1], ssrc[0][2], \ + ssrc[0][3], weight[0], c[0][3]); int16x8_t tmp0, tmp1; @@ -1417,8 +1368,8 @@ void conv_direct_stride2_7x7_int8x8x16_nchw44( for (; ow_idx < ow; ow_idx++) { const size_t src_offset = (oh_idx * stride_h * iw + ow_idx * stride_w) * ic_step; - const size_t dst_offset = oc_idx * out_img_stride + - (oh_idx * ow + ow_idx) * oc_step; + const size_t dst_offset = + oc_idx * out_img_stride + (oh_idx * ow + ow_idx) * oc_step; constexpr int ld_weight_ic4 = 16; @@ -1446,9 +1397,10 @@ void 
conv_direct_stride2_7x7_int8x8x16_nchw44( src_ptr + ic_idx * ic_stride + 5 * iw * ic_step; const int8_t* src_row6 = src_ptr + ic_idx * ic_stride + 6 * iw * ic_step; -#define CALC_1_RESULT(_row) \ - CALC_ONE_RESULT(ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], \ - ssrc[0][4], ssrc[0][5], ssrc[0][6], weight[0], c); +#define CALC_1_RESULT(_row) \ + CALC_ONE_RESULT( \ + ssrc[0][0], ssrc[0][1], ssrc[0][2], ssrc[0][3], ssrc[0][4], ssrc[0][5], \ + ssrc[0][6], weight[0], c); int16x8_t tmp0, tmp1; LOAD_SRC(src_row0); @@ -1494,42 +1446,42 @@ namespace int8x8x16_direct_nchw44 { template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { - conv_direct_stride2_2x2_int8_nchw44(src, filter, bias, dst, - oc, ic, ih, iw, oh, ow); + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride2_2x2_int8_nchw44( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } }; template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { - conv_direct_stride2_3x3_int8_nchw44(src, filter, bias, dst, - oc, ic, ih, iw, oh, ow); + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { + conv_direct_stride2_3x3_int8_nchw44( + src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } }; template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { conv_direct_stride2_5x5_int8x8x16_nchw44( src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } }; template struct ConvDirectInt8Nchw44Choose { - static void impl(const int8_t* src, const int8_t* filter, - const int16_t* bias, int16_t* dst, const size_t oc, - const size_t ic, const size_t ih, const size_t iw, - const size_t oh, const size_t ow) { + static void impl( + const int8_t* src, const int8_t* filter, const int16_t* bias, int16_t* dst, + const size_t oc, const size_t ic, const size_t ih, const size_t iw, + const size_t oh, const size_t ow) { conv_direct_stride2_7x7_int8x8x16_nchw44( src, filter, bias, dst, oc, ic, ih, iw, oh, ow); } diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp index 86ac80a5..231c2dfd 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_algo.cpp @@ -24,15 +24,14 @@ using namespace megdnn; using namespace arm_common; using conv_fun = std::function; + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids, + const 
CpuNDRange& ncb_range)>; MIDOUT_DECL(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44) namespace { -static inline size_t get_perthread_cache_bytes(const int ic, const int ih2, - const int iw2) { +static inline size_t get_perthread_cache_bytes( + const int ic, const int ih2, const int iw2) { //! border_size is used to avoid read illegal memory constexpr int iw_expand = 8; int border_size = 64 * 2; @@ -54,8 +53,8 @@ static void get_rectified_size( const int filter_h = static_cast(fm.spatial[0]); const int ic = fm.icpg; iw2 = iw + 2 * static_cast(fm.padding[1]); - int block_oh = l2_block_helper(param.nr_threads, oh, - ic * iw2 * stride_h * iw_expand); + int block_oh = + l2_block_helper(param.nr_threads, oh, ic * iw2 * stride_h * iw_expand); ih2 = block_oh * stride_h + filter_h - stride_h; } @@ -74,20 +73,18 @@ static WorkspaceBundle get_bundle(const ConvBiasImpl::NCBKernSizeParam& param) { constexpr int pack_oc = 8; const int weight_expand = stride == 1 ? 2 : 1; size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); - size_t weight_size = group * round_up(oc, 8) * ic * fh * fw * - sizeof(int8_t) * weight_expand; + size_t weight_size = + group * round_up(oc, 8) * ic * fh * fw * sizeof(int8_t) * weight_expand; size_t bisa_size = 0; - if (param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && - oc % pack_oc != 0) { + if (param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && oc % pack_oc != 0) { bisa_size = round_up(oc, 8) * sizeof(int16_t); } return {nullptr, {src_size * param.nr_threads, weight_size, bisa_size}}; }; -static inline void copy_pad_src(int8_t* sptr_base, const int8_t* sptr_origin, - int ph, int pw, int pad_right, int ih, int iw, - int iw2, int pad_top, int pad_bottom, int ic, - int ic_stride) { +static inline void copy_pad_src( + int8_t* sptr_base, const int8_t* sptr_origin, int ph, int pw, int pad_right, + int ih, int iw, int iw2, int pad_top, int pad_bottom, int ic, int ic_stride) { constexpr int iw_expand = 8; MEGDNN_MARK_USED_VAR(ph); rep(ic_idx, ic) { @@ -107,9 +104,9 @@ static inline void copy_pad_src(int8_t* sptr_base, const int8_t* sptr_origin, sptr_base += iw2 * pad_bottom * iw_expand; } } -static void pack_weight(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index) { +static void pack_weight( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index) { const int group_id = ncb_index.ndrange_id[0]; int fh = kern_param.filter_meta.spatial[0]; int fw = kern_param.filter_meta.spatial[1]; @@ -118,8 +115,7 @@ static void pack_weight(const WorkspaceBundle& bundle, int oc_block = oc; int stride = kern_param.filter_meta.stride[0]; constexpr int oc_idx = 0; - const int8_t* fptr = - kern_param.filter(group_id) + oc_idx * fh * fw * ic; + const int8_t* fptr = kern_param.filter(group_id) + oc_idx * fh * fw * ic; auto packed_weight = reinterpret_cast(bundle.get(1)) + group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw; switch (stride) { @@ -135,19 +131,17 @@ static void pack_weight(const WorkspaceBundle& bundle, break; } constexpr int pack_oc = 8; - if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && - oc % pack_oc != 0) { + if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && oc % pack_oc != 0) { auto packed_bias = reinterpret_cast(bundle.get(2)); - memcpy(packed_bias, kern_param.bias_ptr, - round_up(oc, 8) * sizeof(int16_t)); + memcpy(packed_bias, kern_param.bias_ptr, round_up(oc, 8) * sizeof(int16_t)); } } template 
-static void do_conv_kern(const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange&, const CpuNDRange&) { +static void do_conv_kern( + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange&, + const CpuNDRange&) { const int oh = kern_param.osz[0]; const int ow = kern_param.osz[1]; const int fh = kern_param.filter_meta.spatial[0]; @@ -172,49 +166,45 @@ static void do_conv_kern(const WorkspaceBundle& bundle, const int group_id = ncb_index.ndrange_id[1]; constexpr int oc_idx = 0; int oc_block = oc; - int oh_block = l2_block_helper(kern_param.nr_threads, oh, - ic * iw2 * stride_h * src_expand); + int oh_block = l2_block_helper( + kern_param.nr_threads, oh, ic * iw2 * stride_h * src_expand); const int oh_idx = ncb_index.ndrange_id[2]; const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block); const int ih_real = oh_block_real * stride_h + fh - stride_h; const int src_top_pad = std::max(ph - oh_idx * oh_block * stride_h, 0); const int src_bottom_pad = std::max( - (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, - 0); + (oh_idx * oh_block + oh_block_real - 1) * stride_h + fh - ih - ph, 0); const int remain_right_pad = std::max(iw2 - iw - pw, 0); const int src_offset = std::max(oh_idx * oh_block * stride_h - ph, 0) * iw; - const int8_t* origin_sptr = - static_cast( - kern_param.src(batch_id, group_id, 0, 1, 1)) + - src_offset; + const int8_t* origin_sptr = static_cast(kern_param.src( + batch_id, group_id, 0, 1, 1)) + + src_offset; const size_t src_size = get_perthread_cache_bytes(ic, ih2, iw2); - int8_t* sptr = reinterpret_cast((int8_t*)bundle.get(0) + - ncb_index.thread_id * src_size); + int8_t* sptr = reinterpret_cast( + (int8_t*)bundle.get(0) + ncb_index.thread_id * src_size); - copy_pad_src(sptr, origin_sptr, ph, pw, remain_right_pad, - ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, - src_bottom_pad, ic, ih * iw); + copy_pad_src( + sptr, origin_sptr, ph, pw, remain_right_pad, + ih_real - src_top_pad - src_bottom_pad, iw, iw2, src_top_pad, + src_bottom_pad, ic, ih * iw); //! pack weight auto packed_weight = reinterpret_cast(bundle.get(1)) + - (group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw) * - weight_expand; + (group_id * oc * ic * fh * fw + oc_idx * ic * fh * fw) * weight_expand; //! 
get param int16_t* dst = kern_param.dst(batch_id, group_id) + oh_idx * oh_block * ow * pack_c; - const int16_t* bptr = - kern_param.bias(batch_id, group_id) + oc_idx; + const int16_t* bptr = kern_param.bias(batch_id, group_id) + oc_idx; constexpr int pack_oc = 8; - if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && - oc % pack_oc != 0) { + if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && oc % pack_oc != 0) { bptr = reinterpret_cast(bundle.get(2)); } Op op; i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44< bias_mode, Op, filter_size, stride>( - sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, - oh, oh_block_real, ow, op, ph, pw); + sptr, packed_weight, bptr, nullptr, dst, oc_block, ic, ih_real, iw2, oh, + oh_block_real, ow, op, ph, pw); } } // namespace @@ -222,24 +212,23 @@ static void do_conv_kern(const WorkspaceBundle& bundle, bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return nchw_nchwxx_valid( - param.src_type.enumv(), param.filter_type.enumv(), - param.dst_type.enumv(), param.filter_meta, param.bias_mode, - param.nonlineMode); + param.src_type.enumv(), param.filter_type.enumv(), param.dst_type.enumv(), + param.filter_meta, param.bias_mode, param.nonlineMode); } size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, - midout_iv("AlgoI8x8x16DirectNCHWNCHW44::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, + midout_iv("AlgoI8x8x16DirectNCHWNCHW44::get_workspace"_hash)) { return get_bundle(param).total_size_in_bytes(); } MIDOUT_END(); return 0; } -SmallVector -ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::dispatch_kerns( - const NCBKernSizeParam& param) const { +SmallVector ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44:: + dispatch_kerns(const NCBKernSizeParam& param) const { auto fm = param.filter_meta; const int batch = param.n; const int group = fm.group; @@ -247,11 +236,12 @@ ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::dispatch_kerns( conv_fun do_conv_fun = nullptr; //! NOTE: remain_w is not used to gen hash of midout for compatible with //! 
shape runtime -#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, \ - midout_iv(#stride #filter #bias_mode #op##_hash)) { \ - do_conv_fun = do_conv_kern; \ - } \ +#define DO_CONV_KERN_FUN(stride, filter, bias_mode, op) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44, \ + midout_iv(#stride #filter #bias_mode #op##_hash)) { \ + do_conv_fun = do_conv_kern; \ + } \ MIDOUT_END(); #define GET_OP_PARAM(stride, filter, bias_mode) \ @@ -303,8 +293,9 @@ ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::dispatch_kerns( DISPATCH_CONV_KERN(2); break; default: - megdnn_throw(ssprintf("Unsupport stride size %u for the first conv", - param.filter_meta.stride[0]) + megdnn_throw(ssprintf( + "Unsupport stride size %u for the first conv", + param.filter_meta.stride[0]) .c_str()); break; } @@ -323,24 +314,24 @@ ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::dispatch_kerns( const int stride_h = static_cast(fm.stride[0]); const int ic = fm.icpg; get_rectified_size(param, ih2, iw2, oh2, ow2); - int oh_block = l2_block_helper(param.nr_threads, oh, - ic * iw2 * stride_h * iw_expand); + int oh_block = + l2_block_helper(param.nr_threads, oh, ic * iw2 * stride_h * iw_expand); - auto do_pack_weight = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto do_pack_weight = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); pack_weight(bundle, kern_param, ncb_index); }; ret_kerns.push_back({do_pack_weight, {static_cast(group)}}); - CpuNDRange ncb_range = {static_cast(batch), - static_cast(group), - static_cast(div_ceil(oh, oh_block))}; + CpuNDRange ncb_range = { + static_cast(batch), static_cast(group), + static_cast(div_ceil(oh, oh_block))}; auto do_conv = [bundle, do_conv_fun, ncb_range]( const NCBKernParam& kern_param, const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, - ncb_range); + do_conv_fun(bundle, kern_param, ncb_index, ncb_index.ndrange_id, ncb_range); }; ret_kerns.push_back({do_conv, ncb_range}); diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h index 2d518a31..d00cba5b 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h @@ -34,14 +34,13 @@ namespace i8i8i16_direct_nchw_nchw44 { * @param ic */ template -inline void pack_weight_int8_nchw_nchw44(const int8_t* in_ptr, int8_t* dst_ptr, - const int oc, const int kh, - const int kw, const int ic); +inline void pack_weight_int8_nchw_nchw44( + const int8_t* in_ptr, int8_t* dst_ptr, const int oc, const int kh, const int kw, + const int ic); template <> -inline void pack_weight_int8_nchw_nchw44<2>(const int8_t* in_ptr, - int8_t* dst_ptr, const int oc, - const int kh, const int kw, - const int ic) { +inline void pack_weight_int8_nchw_nchw44<2>( + const int8_t* in_ptr, int8_t* dst_ptr, const int oc, const int kh, const int kw, + const int ic) { constexpr int in_pack_oc = 4; constexpr int out_pack_oc = 8; constexpr int out_pair = 2; @@ -59,10 +58,10 @@ inline void pack_weight_int8_nchw_nchw44<2>(const int8_t* in_ptr, for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { int32_t temp0 = *in_oc0_ptr++; int32_t temp1 = *in_oc1_ptr++; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 0] = temp0; - 
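The pack_weight_int8_nchw_nchw44 helpers that this hunk reformats repack OIHW int8 filters so that a block of output channels becomes the innermost dimension, letting the kernel fetch the weights of a whole channel block for one (ic, kh, kw) position with a single vector load; the stride-1 variant additionally duplicates each pair of 4-channel words. The loop below is only a scalar sketch of the target layout (a hypothetical helper, not the real packing routine, which works on packed 4-channel int32 words and zero-fills the odd remainder group):

#include <cstdint>

// Hypothetical scalar repack: OIHW -> (O/pack_oc, I, KH*KW, pack_oc), padding
// with zeros when oc is not a multiple of pack_oc, as the remainder branch of
// the real packing code does.
static void pack_weight_oc_innermost_sketch(
        const int8_t* src, int8_t* dst, int oc, int ic, int fh, int fw,
        int pack_oc) {
    const int filter = fh * fw;
    for (int oc_outer = 0; oc_outer < oc; oc_outer += pack_oc) {
        for (int ic_idx = 0; ic_idx < ic; ++ic_idx) {
            for (int f = 0; f < filter; ++f) {
                for (int oc_inner = 0; oc_inner < pack_oc; ++oc_inner) {
                    const int oc_idx = oc_outer + oc_inner;
                    *dst++ = oc_idx < oc
                                   ? src[(oc_idx * ic + ic_idx) * filter + f]
                                   : int8_t(0);
                }
            }
        }
    }
}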
pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 1] = temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 0] = + temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 1] = + temp1; } } pack_dst_ptr += ic * filter_size * out_pair; @@ -74,19 +73,18 @@ inline void pack_weight_int8_nchw_nchw44<2>(const int8_t* in_ptr, for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { int32_t temp0 = *in_oc0_ptr++; int32_t temp1 = 0; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 0] = temp0; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 1] = temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 0] = + temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 1] = + temp1; } } } } template <> -inline void pack_weight_int8_nchw_nchw44<1>(const int8_t* in_ptr, - int8_t* dst_ptr, const int oc, - const int kh, const int kw, - const int ic) { +inline void pack_weight_int8_nchw_nchw44<1>( + const int8_t* in_ptr, int8_t* dst_ptr, const int oc, const int kh, const int kw, + const int ic) { constexpr int in_pack_oc = 4; constexpr int out_pack_oc = 8; constexpr int out_pair = 4; @@ -104,14 +102,14 @@ inline void pack_weight_int8_nchw_nchw44<1>(const int8_t* in_ptr, for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { int32_t temp0 = *in_oc0_ptr++; int32_t temp1 = *in_oc1_ptr++; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 0] = temp0; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 1] = temp1; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 2] = temp0; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 3] = temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 0] = + temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 1] = + temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 2] = + temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 3] = + temp1; } } pack_dst_ptr += ic * filter_size * out_pair; @@ -123,26 +121,25 @@ inline void pack_weight_int8_nchw_nchw44<1>(const int8_t* in_ptr, for (int ic_idx = 0; ic_idx < ic; ++ic_idx) { int32_t temp0 = *in_oc0_ptr++; int32_t temp1 = 0; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 0] = temp0; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 1] = temp1; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 2] = temp0; - pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + - 3] = temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 0] = + temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 1] = + temp1; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 2] = + temp0; + pack_dst_ptr[(ic_idx * filter_size + filter_idx) * out_pair + 3] = + temp1; } } } } template -void conv_direct_i8i8i16_nchw_nchw44(const int8_t* src, const int8_t* filter, - const int16_t* bias, int8_t*, int16_t* dst, - const int oc, const int ic, const int ih, - const int iw, const int oh, - const int oh_block, const int ow, - const Op& op, const int, const int); +void conv_direct_i8i8i16_nchw_nchw44( + const int8_t* src, const int8_t* filter, const int16_t* bias, int8_t*, + int16_t* dst, const int oc, const int ic, const int ih, const int iw, + const int oh, const int oh_block, const int ow, const Op& op, const int, + const int); } // namespace i8i8i16_direct_nchw_nchw44 diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h 
b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h index 38bc59eb..16ad7de7 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/kernel/direct_nchw_nchw44_kern_impl.h @@ -39,21 +39,22 @@ namespace { * @tparam T3 * @tparam T4 */ -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, bool half_adv, int stride, + typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step] = vmlal_s8(c[0][step], vget_low_s8(weight[0][weight_idx]), \ - vget_low_s8(src[step + src_idx])); \ - c[0][step] = vmlal_high_s8(c[0][step], weight[0][weight_idx], \ - src[step + src_idx]); +#define cb(step) \ + c[0][step] = vmlal_s8( \ + c[0][step], vget_low_s8(weight[0][weight_idx]), \ + vget_low_s8(src[step + src_idx])); \ + c[0][step] = vmlal_high_s8(c[0][step], weight[0][weight_idx], src[step + src_idx]); UNROLL_CALL_RAW(8, cb); @@ -61,13 +62,14 @@ struct ShiftCalHelper { } }; -template +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][step] = vmlal_s8(c[0][step], vget_low_s8(weight[0][weight_idx]), \ - vget_low_s8(src[step + src_idx])); +#define cb(step) \ + c[0][step] = vmlal_s8( \ + c[0][step], vget_low_s8(weight[0][weight_idx]), \ + vget_low_s8(src[step + src_idx])); UNROLL_CALL_RAW(8, cb); @@ -75,30 +77,28 @@ struct ShiftCalHelper { } }; -template +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { //! for compatible with stride2 kernel, step, weight_idx, src_idx should //! 
mul 2 -#define cb(step) \ - c[0][2 * step] = \ - vmlal_s8(c[0][2 * step], vget_low_s8(weight[0][2 * weight_idx]), \ - vget_low_s8(src[step + src_idx])); \ - c[0][2 * step + 1] = \ - vmlal_high_s8(c[0][2 * step + 1], weight[0][2 * weight_idx], \ - src[step + src_idx]); +#define cb(step) \ + c[0][2 * step] = vmlal_s8( \ + c[0][2 * step], vget_low_s8(weight[0][2 * weight_idx]), \ + vget_low_s8(src[step + src_idx])); \ + c[0][2 * step + 1] = vmlal_high_s8( \ + c[0][2 * step + 1], weight[0][2 * weight_idx], src[step + src_idx]); UNROLL_CALL_RAW(4, cb); #undef cb -#define cb(step) \ - c[0][2 * step] = \ - vmlal_high_s8(c[0][2 * step], weight[0][2 * weight_idx + 1], \ - src[step + src_idx]); \ - c[0][2 * step + 1] = vmlal_s8(c[0][2 * step + 1], \ - vget_low_s8(weight[0][2 * weight_idx + 1]), \ - vget_low_s8(src[step + 1 + src_idx])); +#define cb(step) \ + c[0][2 * step] = vmlal_high_s8( \ + c[0][2 * step], weight[0][2 * weight_idx + 1], src[step + src_idx]); \ + c[0][2 * step + 1] = vmlal_s8( \ + c[0][2 * step + 1], vget_low_s8(weight[0][2 * weight_idx + 1]), \ + vget_low_s8(src[step + 1 + src_idx])); UNROLL_CALL_RAW(4, cb); @@ -106,17 +106,16 @@ struct ShiftCalHelper { } }; -template +template < + int src_idx, int weight_idx, typename T, typename T2, typename T3, typename T4> struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ - c[0][2 * step] = \ - vmlal_s8(c[0][2 * step], vget_low_s8(weight[0][2 * weight_idx]), \ - vget_low_s8(src[step + src_idx])); \ - c[0][2 * step + 1] = \ - vmlal_high_s8(c[0][2 * step + 1], weight[0][2 * weight_idx], \ - src[step + src_idx]); +#define cb(step) \ + c[0][2 * step] = vmlal_s8( \ + c[0][2 * step], vget_low_s8(weight[0][2 * weight_idx]), \ + vget_low_s8(src[step + src_idx])); \ + c[0][2 * step + 1] = vmlal_high_s8( \ + c[0][2 * step + 1], weight[0][2 * weight_idx], src[step + src_idx]); UNROLL_CALL_RAW(4, cb); @@ -140,26 +139,29 @@ public: static const int val = 2; }; -template +template < + int src_idx, int weight_idx, int c_dim, int ow_block, bool half_adv, int stride, + typename T, typename T2, typename T3> MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, weight); + ShiftCalHelper< + src_idx, weight_idx, c_dim, ow_block, half_adv, stride, T, T2, T3, + int>::impl(c, src, weight); }; -template +template < + BiasMode bias_mode, typename Op, int filter_size, int oc_block, int stride, + int ow_block> struct KerNeonXXs2NchwNchw44I8I8I16 { - static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op&, const int remain_ow); + static void impl( + const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op&, const int remain_ow); }; -template -struct KerNeonXXs2NchwNchw44I8I8I16 { - static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op&, const int remain_ow) { +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl( + const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op&, const int remain_ow) { constexpr int loop_ic_step = 1; constexpr int filter_size = 2; constexpr int iw_expand = 8; @@ -173,8 +175,7 @@ struct KerNeonXXs2NchwNchw44I8I8I16(c, 
bias_ptr, 0); @@ -183,11 +184,10 @@ struct KerNeonXXs2NchwNchw44I8I8I16( \ - src, src_ptr + step * ld_src_iw, 0); \ - load_helper( \ - weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ +#define cb(step) \ + load_helper(src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); UNROLL_CALL_RAW(2, cb) #undef cb @@ -201,13 +201,12 @@ struct KerNeonXXs2NchwNchw44I8I8I16 -struct KerNeonXXs2NchwNchw44I8I8I16 { - static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op&, const int remain_ow) { +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl( + const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op&, const int remain_ow) { constexpr int loop_ic_step = 1; constexpr int filter_size = 3; constexpr int iw_expand = 8; @@ -221,8 +220,7 @@ struct KerNeonXXs2NchwNchw44I8I8I16(c, bias_ptr, 0); @@ -231,12 +229,11 @@ struct KerNeonXXs2NchwNchw44I8I8I16( \ - src, src_ptr + step * ld_src_iw, 0); \ - load_helper( \ - weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ - cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ +#define cb(step) \ + load_helper(src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ + cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ cal_helper<1, 1, c_dim, ow_block, true, stride>(c, src, weight); UNROLL_CALL_RAW(3, cb) #undef cb @@ -250,13 +247,12 @@ struct KerNeonXXs2NchwNchw44I8I8I16 -struct KerNeonXXs2NchwNchw44I8I8I16 { - static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op&, const int remain_ow) { +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl( + const int8_t* src_ptr_origin, const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op&, const int remain_ow) { constexpr int loop_ic_step = 1; constexpr int filter_size = 5; constexpr int iw_expand = 8; @@ -270,8 +266,7 @@ struct KerNeonXXs2NchwNchw44I8I8I16(c, bias_ptr, 0); @@ -280,13 +275,12 @@ struct KerNeonXXs2NchwNchw44I8I8I16( \ - src, src_ptr + step * ld_src_iw, 0); \ - load_helper( \ - weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ - cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ - cal_helper<1, 1, c_dim, ow_block, false, stride>(c, src, weight); \ +#define cb(step) \ + load_helper(src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ + cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, ow_block, false, stride>(c, src, weight); \ cal_helper<2, 2, c_dim, ow_block, true, stride>(c, src, weight); UNROLL_CALL_RAW(5, cb) #undef cb @@ -300,13 +294,12 @@ struct KerNeonXXs2NchwNchw44I8I8I16 -struct KerNeonXXs2NchwNchw44I8I8I16 { - static void impl(const int8_t* src_ptr_origin, const int8_t* weight_ptr, - const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, - int iw, int ld_dst_oc, const Op&, const int remain_ow) { +template +struct KerNeonXXs2NchwNchw44I8I8I16 { + static void impl( + const int8_t* src_ptr_origin, 
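All of these per-filter-size kernels accumulate through the same widening multiply-add pattern that ShiftCalHelper wraps: vmlal_s8 multiplies two int8x8_t halves and accumulates into an int16x8_t, and vmlal_high_s8 (AArch64) does the same for the upper halves of two int8x16_t registers, which is what lets the i8i8i16 kernels keep their accumulators in 16-bit lanes with no 32-bit intermediate. A minimal standalone illustration of one such step, assuming only <arm_neon.h>:

#include <arm_neon.h>

// One accumulation step of the pattern used by ShiftCalHelper:
// acc (8 x int16) += low half of src  * low half of weight,
// then             += high half of src * high half of weight.
static inline int16x8_t mla_i8_to_i16_sketch(
        int16x8_t acc, int8x16_t src, int8x16_t weight) {
    acc = vmlal_s8(acc, vget_low_s8(weight), vget_low_s8(src));
    acc = vmlal_high_s8(acc, weight, src);  // AArch64-only intrinsic
    return acc;
}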
const int8_t* weight_ptr, + const int16_t* bias_ptr, int16_t* dst_ptr, int ic, int ih, int iw, + int ld_dst_oc, const Op&, const int remain_ow) { constexpr int loop_ic_step = 1; constexpr int filter_size = 7; constexpr int iw_expand = 8; @@ -320,8 +313,7 @@ struct KerNeonXXs2NchwNchw44I8I8I16(c, bias_ptr, 0); @@ -330,14 +322,13 @@ struct KerNeonXXs2NchwNchw44I8I8I16( \ - src, src_ptr + step * ld_src_iw, 0); \ - load_helper( \ - weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ - cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ - cal_helper<1, 1, c_dim, ow_block, false, stride>(c, src, weight); \ - cal_helper<2, 2, c_dim, ow_block, false, stride>(c, src, weight); \ +#define cb(step) \ + load_helper(src, src_ptr + step * ld_src_iw, 0); \ + load_helper( \ + weight, weight_ptr + step * filter_size * filter_pack_oc, 0); \ + cal_helper<0, 0, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<1, 1, c_dim, ow_block, false, stride>(c, src, weight); \ + cal_helper<2, 2, c_dim, ow_block, false, stride>(c, src, weight); \ cal_helper<3, 3, c_dim, ow_block, true, stride>(c, src, weight); UNROLL_CALL_RAW(7, cb) #undef cb @@ -389,10 +380,9 @@ void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44( const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; KerNeonXXs2NchwNchw44I8I8I16< - bias_mode, Op, filter_size, big_oc_step, stride, - ow_step>::impl(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op, ow_step); + bias_mode, Op, filter_size, big_oc_step, stride, ow_step>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op, ow_step); } if (ow_remain > 0) { const int src_offset = @@ -401,10 +391,9 @@ void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44( const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; KerNeonXXs2NchwNchw44I8I8I16< - bias_mode, Op, filter_size, big_oc_step, stride, - ow_step>::impl(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op, ow_remain); + bias_mode, Op, filter_size, big_oc_step, stride, ow_step>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op, ow_remain); } } } @@ -419,10 +408,9 @@ void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44( const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; KerNeonXXs2NchwNchw44I8I8I16< - bias_mode, Op, filter_size, oc_step, stride, - ow_step>::impl(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op, ow_step); + bias_mode, Op, filter_size, oc_step, stride, ow_step>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op, ow_step); } if (ow_remain > 0) { const int src_offset = @@ -431,22 +419,21 @@ void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44( const int dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step; KerNeonXXs2NchwNchw44I8I8I16< - bias_mode, Op, filter_size, oc_step, stride, - ow_step>::impl(src + src_offset, filter + weight_offset, - bias + oc_idx, dst + dst_offset, ic, ih, - iw, ld_dst_oc, op, ow_remain); + bias_mode, Op, filter_size, oc_step, stride, ow_step>:: + impl(src + src_offset, filter + weight_offset, bias + oc_idx, + dst + dst_offset, ic, ih, iw, ld_dst_oc, op, ow_remain); } } } } -#define INSTANTIATION(stride, 
filter_size, bias_mode, Op) \ - template void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44< \ - bias_mode, Op, filter_size, stride>( \ - const int8_t* src, const int8_t* filter, const int16_t* bias, \ - int8_t*, int16_t* dst, const int oc, const int ic, const int ih, \ - const int iw, const int oh, const int oh_block, const int ow, \ - const Op& op, const int, const int); +#define INSTANTIATION(stride, filter_size, bias_mode, Op) \ + template void i8i8i16_direct_nchw_nchw44::conv_direct_i8i8i16_nchw_nchw44< \ + bias_mode, Op, filter_size, stride>( \ + const int8_t* src, const int8_t* filter, const int16_t* bias, int8_t*, \ + int16_t* dst, const int oc, const int ic, const int ih, const int iw, \ + const int oh, const int oh_block, const int ow, const Op& op, const int, \ + const int); #define FOR_OP(stride, filter, bias) \ INSTANTIATION(stride, filter, bias, NoneOp) diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index 2ee70d20..77098abd 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -101,8 +101,7 @@ struct Store_OC4_OW8_Remain<1, Op> { }; template -__ai void store_oc4_ow8_remain_static(int32x4_t c[8], const Op& op, - int8_t* dst_ptr) { +__ai void store_oc4_ow8_remain_static(int32x4_t c[8], const Op& op, int8_t* dst_ptr) { Store_OC4_OW8_Remain::impl(c, op, dst_ptr); } @@ -113,15 +112,13 @@ struct StoreOcxOw4Remain { template struct StoreOcxOw4Remain<2, 0, Op, T> { - static __ai void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][4], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); } }; @@ -131,8 +128,7 @@ struct StoreOcxOw4Remain<2, 3, Op, T> { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op(c[0][2], reinterpret_cast(dst_ptr + 8)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + 8)); } }; @@ -140,8 +136,7 @@ template struct StoreOcxOw4Remain<2, 2, Op, T> { static __ai void impl(T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); } }; template @@ -154,8 +149,8 @@ struct StoreOcxOw4Remain<2, 1, Op, T> { template struct StoreOcxOw4Remain<1, 0, Op, T> { - static __ai void impl(int32x4_t c[2][4], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][4], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { MEGDNN_MARK_USED_VAR(ld_dst_oc); op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); @@ -185,13 +180,12 @@ struct StoreOcxOw4Remain<1, 1, Op, T> { } }; template -__ai void store_ocx_ow4_remain_static(T& c, const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { +__ai void store_ocx_ow4_remain_static( + T& c, const Op& op, int8_t* dst_ptr, int ld_dst_oc) { StoreOcxOw4Remain::impl(c, 
op, dst_ptr, ld_dst_oc); } ////////////////////Store_OCX_OW8_Remain///////////////////////// -template +template struct StoreOcxOw8Remain { static __ai void impl(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc); }; @@ -206,10 +200,8 @@ struct StoreOcxOw8Remain<2, 0, Op, T, T2, T3> { op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 16)); - op({{c[1][6], c[1][7]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; template @@ -222,10 +214,8 @@ struct StoreOcxOw8Remain<2, 8, Op, T, T2, T3> { op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 16)); - op({{c[1][6], c[1][7]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; template @@ -238,8 +228,7 @@ struct StoreOcxOw8Remain<2, 7, Op, T, T2, T3> { op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); op(c[1][6], reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; @@ -252,8 +241,7 @@ struct StoreOcxOw8Remain<2, 6, Op, T, T2, T3> { op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); } }; template @@ -375,34 +363,32 @@ struct StoreOcxOw8Remain<1, 1, Op, T, T2, T3> { }; template -__ai void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, - int ld_dst_oc) { - StoreOcxOw8Remain::impl(c, op, dst_ptr, - ld_dst_oc); +__ai void store_ocx_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); } -template -__ai void store_ocx_ow8_remain_static_dt(T& c, const Op& op, T2 dst_ptr, - int ld_dst_oc) { - StoreOcxOw8Remain::impl(c, op, dst_ptr, - ld_dst_oc); +template +__ai void store_ocx_ow8_remain_static_dt( + T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { + StoreOcxOw8Remain::impl(c, op, dst_ptr, ld_dst_oc); } ////////////////////Store_OCX_OW8_Remain///////////////////////// -template +template < + int c_dim, int ow_block, int nr_group, int out_group, typename T, typename T2, + typename T3> struct StoreOc4Ow8Remain { static __ai void impl(T& c, T2 dst_ptr, int ld_dst_oc, const int ow_remain); }; -#define cb(step) \ - vst1q_lane_s64((int64_t*)(dst_ptr + step * 4), \ - vreinterpretq_s64_s16(c[0][step]), 0); \ - vst1q_lane_s64((int64_t*)(dst_ptr + step * 4 + ld_dst_oc), \ - vreinterpretq_s64_s16(c[0][step]), 1); +#define cb(step) \ + vst1q_lane_s64( \ + (int64_t*)(dst_ptr + step * 4), vreinterpretq_s64_s16(c[0][step]), 0); \ + vst1q_lane_s64( \ + (int64_t*)(dst_ptr + step * 4 + ld_dst_oc), \ + vreinterpretq_s64_s16(c[0][step]), 1); -#define cb2(step) \ - vst1q_lane_s64((int64_t*)(dst_ptr + step * 4), \ - vreinterpretq_s64_s16(c[0][step]), 0); +#define 
cb2(step) \ + vst1q_lane_s64( \ + (int64_t*)(dst_ptr + step * 4), vreinterpretq_s64_s16(c[0][step]), 0); #define cb_case(step) \ case step: \ @@ -415,8 +401,7 @@ struct StoreOc4Ow8Remain { break; template struct StoreOc4Ow8Remain<1, 8, 2, 2, T, T2, T3> { - static __ai void impl(T& c, T2 dst_ptr, int ld_dst_oc, - const int ow_remain) { + static __ai void impl(T& c, T2 dst_ptr, int ld_dst_oc, const int ow_remain) { if (ow_remain == 8) { UNROLL_CALL_RAW(8, cb) } else { @@ -462,10 +447,9 @@ struct StoreOc4Ow8Remain<1, 8, 2, 1, T, T2, T3> { #undef cb_case #undef cb_case2 -template -__ai void store_oc4_ow8_remain_static(T& c, T2 dst_ptr, const int ld_dst_oc, - const int ow_remain) { +template +__ai void store_oc4_ow8_remain_static( + T& c, T2 dst_ptr, const int ld_dst_oc, const int ow_remain) { StoreOc4Ow8Remain::impl( c, dst_ptr, ld_dst_oc, ow_remain); } @@ -474,121 +458,105 @@ __ai void store_oc4_ow8_remain_static(T& c, T2 dst_ptr, const int ld_dst_oc, template struct Store_OC8_OW8_Remain { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc); + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc); }; template struct Store_OC8_OW8_Remain<0, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); op({{c[0][6], c[0][7]}}, reinterpret_cast(dst_ptr + 24)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 16)); - op({{c[1][6], c[1][7]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 24)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][6], c[1][7]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; template struct Store_OC8_OW8_Remain<7, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 16)); op(c[0][6], reinterpret_cast(dst_ptr + 24)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); op(c[1][6], reinterpret_cast(dst_ptr + ld_dst_oc + 24)); } }; template struct Store_OC8_OW8_Remain<6, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); op({{c[0][4], c[0][5]}}, reinterpret_cast(dst_ptr + 
16)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 8)); - op({{c[1][4], c[1][5]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 16)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][4], c[1][5]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 16)); } }; template struct Store_OC8_OW8_Remain<5, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); op(c[0][4], reinterpret_cast(dst_ptr + 16)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); op(c[1][4], reinterpret_cast(dst_ptr + ld_dst_oc + 16)); } }; template struct Store_OC8_OW8_Remain<4, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op({{c[0][2], c[0][3]}}, reinterpret_cast(dst_ptr + 8)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); - op({{c[1][2], c[1][3]}}, - reinterpret_cast(dst_ptr + ld_dst_oc + 8)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][2], c[1][3]}}, reinterpret_cast(dst_ptr + ld_dst_oc + 8)); } }; template struct Store_OC8_OW8_Remain<3, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); op(c[0][2], reinterpret_cast(dst_ptr + 8)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); op(c[1][2], reinterpret_cast(dst_ptr + ld_dst_oc + 8)); } }; template struct Store_OC8_OW8_Remain<2, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op({{c[0][0], c[0][1]}}, reinterpret_cast(dst_ptr)); - op({{c[1][0], c[1][1]}}, - reinterpret_cast(dst_ptr + ld_dst_oc)); + op({{c[1][0], c[1][1]}}, reinterpret_cast(dst_ptr + ld_dst_oc)); } }; template struct Store_OC8_OW8_Remain<1, Op> { - static __ai void impl(int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, - int ld_dst_oc) { + static __ai void impl( + int32x4_t c[2][8], const Op& op, int8_t* dst_ptr, int ld_dst_oc) { op(c[0][0], reinterpret_cast(dst_ptr)); op(c[1][0], reinterpret_cast(dst_ptr + ld_dst_oc)); } @@ -597,8 +565,7 @@ struct Store_OC8_OW8_Remain<1, Op> { /////////// template -__ai void store_oc8_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, - int ld_dst_oc) { +__ai void store_oc8_ow8_remain_static(T& c, const Op& op, T2 dst_ptr, int ld_dst_oc) { Store_OC8_OW8_Remain::impl(c, op, dst_ptr, ld_dst_oc); } #pragma GCC diagnostic pop @@ -618,8 +585,7 @@ __ai void init_oc4_ow8(int32x4_t c[8], const int32_t* bias_ptr) { } template -__ai void init_oc8_ow8(int32x4_t 
c[2][8], const int32_t* bias_ptr, - int oc_step) { +__ai void init_oc8_ow8(int32x4_t c[2][8], const int32_t* bias_ptr, int oc_step) { if (bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) { #define BAIS_INIT(step) \ c[0][step] = vld1q_s32(bias_ptr); \ @@ -682,8 +648,7 @@ struct InitOcxOw8 { #define BAIS_INIT_NO_BIAS_C2(step) \ c[0][step] = neon_vdupq_n(static_cast(0)); \ c[1][step] = neon_vdupq_n(static_cast(0)); -#define BAIS_INIT_NO_BIAS_C1(step) \ - c[0][step] = neon_vdupq_n(static_cast(0)); +#define BAIS_INIT_NO_BIAS_C1(step) c[0][step] = neon_vdupq_n(static_cast(0)); #define BAIS_INIT_BROADCAST_C2(step) \ c[0][step] = neon_vld1q(bias_ptr); \ @@ -694,31 +659,29 @@ struct InitOcxOw8 { c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); -#define BAIS_INIT_BIAS_C1(step) \ - c[0][step] = neon_vld1q(bias_ptr + step * simd_len); - -#define INSTANCE_InitOcxOw8(ow_remain, cdim) \ - template \ - struct InitOcxOw8 { \ - static __ai void impl(T& c, const T2*, int) { \ - UNROLL_CALL_RAW(ow_remain, BAIS_INIT_NO_BIAS_C##cdim); \ - } \ - }; \ - template \ - struct InitOcxOw8 { \ - static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ - (void)oc_step; \ - UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BROADCAST_C##cdim); \ - } \ - }; \ - template \ - struct InitOcxOw8 { \ - static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ - constexpr int simd_len = NeonLdqSimd::simd_len; \ - (void)oc_step; \ - UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BIAS_C##cdim); \ - } \ +#define BAIS_INIT_BIAS_C1(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); + +#define INSTANCE_InitOcxOw8(ow_remain, cdim) \ + template \ + struct InitOcxOw8 { \ + static __ai void impl(T& c, const T2*, int) { \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_NO_BIAS_C##cdim); \ + } \ + }; \ + template \ + struct InitOcxOw8 { \ + static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ + (void)oc_step; \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BROADCAST_C##cdim); \ + } \ + }; \ + template \ + struct InitOcxOw8 { \ + static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ + constexpr int simd_len = NeonLdqSimd::simd_len; \ + (void)oc_step; \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BIAS_C##cdim); \ + } \ }; #define INSTANCE_InitOcxOw8_C(ow_remain) \ INSTANCE_InitOcxOw8(ow_remain, 2); \ @@ -793,8 +756,7 @@ __ai void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) { } /////////////////////////////////////// -static inline void memcpy_s8_dup(int8_t* outptr, const int8_t* inptr, - int count) { +static inline void memcpy_s8_dup(int8_t* outptr, const int8_t* inptr, int count) { constexpr int expand = 8; for (; count >= 8; count -= 8) { int8x8_t in = vld1_s8(inptr); diff --git a/dnn/src/arm_common/conv_bias/matmul_postprocess.h b/dnn/src/arm_common/conv_bias/matmul_postprocess.h index 456970e6..00c75f8b 100644 --- a/dnn/src/arm_common/conv_bias/matmul_postprocess.h +++ b/dnn/src/arm_common/conv_bias/matmul_postprocess.h @@ -20,75 +20,85 @@ namespace megdnn { namespace arm_common { -#define SAVE(C, vres, n, idx) \ - switch (n) { \ - case 4: \ - vst1_lane_s32(reinterpret_cast(C), \ - vreinterpret_s32_s8(vres), idx / 4); \ - break; \ - case 3: \ - vst1_lane_s8(C + 2, vres, idx + 2); MEGDNN_FALLTHRU\ - case 2: \ - vst1_lane_s8(C + 1, vres, idx + 1); MEGDNN_FALLTHRU\ - case 1: \ - vst1_lane_s8(C + 0, vres, idx + 0); \ - break; \ - default: \ - megdnn_assert(0); \ +#define SAVE(C, vres, n, idx) \ + switch (n) { \ + case 4: \ + vst1_lane_s32( \ + 
reinterpret_cast(C), vreinterpret_s32_s8(vres), \ + idx / 4); \ + break; \ + case 3: \ + vst1_lane_s8(C + 2, vres, idx + 2); \ + MEGDNN_FALLTHRU \ + case 2: \ + vst1_lane_s8(C + 1, vres, idx + 1); \ + MEGDNN_FALLTHRU \ + case 1: \ + vst1_lane_s8(C + 0, vres, idx + 0); \ + break; \ + default: \ + megdnn_assert(0); \ } -#define SAVEU(C, vres, n, idx) \ - switch (n) { \ - case 4: \ - vst1_lane_s32(reinterpret_cast(C), \ - vreinterpret_s32_u8(vres), idx / 4); \ - break; \ - case 3: \ - vst1_lane_u8(C + 2, vres, idx + 2); MEGDNN_FALLTHRU\ - case 2: \ - vst1_lane_u8(C + 1, vres, idx + 1); MEGDNN_FALLTHRU\ - case 1: \ - vst1_lane_u8(C + 0, vres, idx + 0); \ - break; \ - default: \ - megdnn_assert(0); \ +#define SAVEU(C, vres, n, idx) \ + switch (n) { \ + case 4: \ + vst1_lane_s32( \ + reinterpret_cast(C), vreinterpret_s32_u8(vres), \ + idx / 4); \ + break; \ + case 3: \ + vst1_lane_u8(C + 2, vres, idx + 2); \ + MEGDNN_FALLTHRU \ + case 2: \ + vst1_lane_u8(C + 1, vres, idx + 1); \ + MEGDNN_FALLTHRU \ + case 1: \ + vst1_lane_u8(C + 0, vres, idx + 0); \ + break; \ + default: \ + megdnn_assert(0); \ } -template +template < + typename Op, typename dst_type, typename dst_neon_type, typename enable = void> struct Process; template -struct Process, Op>::value>> { - static dst_neon_type run(const int32x4x2_t& wp, const int32x4x2_t, - const Op& op) { +struct Process< + Op, dst_type, dst_neon_type, + std::enable_if_t< + std::is_base_of, Op>::value>> { + static dst_neon_type run(const int32x4x2_t& wp, const int32x4x2_t, const Op& op) { return op(wp); } }; template -struct Process, Op>::value>> { - static dst_neon_type run(const int32x4x2_t& wp, const int32x4x2_t bias, - const Op& op) { +struct Process< + Op, dst_type, dst_neon_type, + std::enable_if_t< + std::is_base_of, Op>::value>> { + static dst_neon_type run( + const int32x4x2_t& wp, const int32x4x2_t bias, const Op& op) { return op(wp, bias); } }; -template +template < + BiasMode bmode, typename Op, typename dst_ctype, int block_m, int block_n, + int m, int n> struct ConvBiasMatmul { - static void postprocess(const dt_int32* bias, const dt_int32* workspace, - dst_ctype* C, size_t LDC, Op op); + static void postprocess( + const dt_int32* bias, const dt_int32* workspace, dst_ctype* C, size_t LDC, + Op op); }; template struct ConvBiasMatmul { - static void postprocess(const dt_int32* bias, const dt_int32* workspace, - dt_int8* C, size_t LDC, const Op& op) { + static void postprocess( + const dt_int32* bias, const dt_int32* workspace, dt_int8* C, size_t LDC, + const Op& op) { static_assert(m > 0 && m <= block_m, "invalid m or n"); int32x4_t vbias0, vwp0, vwp1, vwp2; if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { @@ -103,15 +113,15 @@ struct ConvBiasMatmul { vwp2 = vld1q_s32(workspace + 8); int8x8_t vres; - vres = Process::run({{vwp0, vwp1}}, - {{vbias0, vbias0}}, op); + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias0}}, op); vst1_s8(C, vres); - vres = Process::run({{vwp1, vwp2}}, - {{vbias0, vbias0}}, op); + vres = Process::run( + {{vwp1, vwp2}}, {{vbias0, vbias0}}, op); //! 
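The SAVE/SAVEU macros being reformatted here store an n-element tail (1 <= n <= 4) of an 8-lane result register: a full group of four int8 outputs goes out as one reinterpreted int32 lane via vst1_lane_s32, while shorter tails fall through byte by byte with vst1_lane_s8/vst1_lane_u8 (MEGDNN_FALLTHRU marks the intentional fallthrough). A standalone sketch of the signed variant, with the lane index as a template parameter so the intrinsics still receive compile-time lane numbers:

#include <arm_neon.h>
#include <cassert>
#include <cstdint>

// idx selects which half of `vres` to store from (0 for lanes 0..3, 4 for
// lanes 4..7); n is how many of those bytes are valid. Mirrors the SAVE macro.
template <int idx>
static inline void save_tail_sketch(int8_t* C, int8x8_t vres, int n) {
    static_assert(idx == 0 || idx == 4, "idx must address a 4-byte group");
    switch (n) {
        case 4:
            vst1_lane_s32(
                    reinterpret_cast<int32_t*>(C), vreinterpret_s32_s8(vres),
                    idx / 4);
            break;
        case 3:
            vst1_lane_s8(C + 2, vres, idx + 2);
            // fall through
        case 2:
            vst1_lane_s8(C + 1, vres, idx + 1);
            // fall through
        case 1:
            vst1_lane_s8(C + 0, vres, idx + 0);
            break;
        default:
            assert(0 && "n must be in [1, 4]");
    }
}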
save the high half - vst1_lane_s32(reinterpret_cast(C + 8), - vreinterpret_s32_s8(vres), 1); + vst1_lane_s32( + reinterpret_cast(C + 8), vreinterpret_s32_s8(vres), 1); bias++; C += LDC; @@ -120,13 +130,12 @@ struct ConvBiasMatmul { } }; - template struct ConvBiasMatmul { - static void postprocess(const dt_int32* bias, const dt_int32* workspace, - dt_int8* C, size_t LDC, const Op& op) { - static_assert(m > 0 && m <= block_m && n > 0 && n <= 4, - "invalid m or n"); + static void postprocess( + const dt_int32* bias, const dt_int32* workspace, dt_int8* C, size_t LDC, + const Op& op) { + static_assert(m > 0 && m <= block_m && n > 0 && n <= 4, "invalid m or n"); int i = 0; int32x4_t vbias0, vbias1, vwp0, vwp1; if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { @@ -146,8 +155,8 @@ struct ConvBiasMatmul { vwp1 = vld1q_s32(workspace); int8x8_t vres; - vres = Process::run({{vwp0, vwp1}}, - {{vbias0, vbias1}}, op); + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias1}}, op); SAVE(C, vres, n, 0); C += LDC; SAVE(C, vres, n, 4); @@ -168,8 +177,8 @@ struct ConvBiasMatmul { vwp1 = QConverterBase::vzero(); int8x8_t vres; - vres = Process::run({{vwp0, vwp1}}, - {{vbias0, vbias1}}, op); + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias1}}, op); SAVE(C, vres, n, 0); C += LDC; } @@ -178,8 +187,9 @@ struct ConvBiasMatmul { template struct ConvBiasMatmul { - static void postprocess(const dt_int32* bias, const dt_int32* workspace, - dt_int8* C, size_t LDC, const Op& op) { + static void postprocess( + const dt_int32* bias, const dt_int32* workspace, dt_int8* C, size_t LDC, + const Op& op) { static_assert(m > 0 && m <= block_m, "invalid m or n"); int i = 0; int32x4_t vbias0, vbias1, vwp0, vwp1; @@ -200,8 +210,8 @@ struct ConvBiasMatmul { vwp1 = vcombine_s32(vld1_s32(workspace), vdup_n_s32(0)); int8x8_t vres; - vres = Process::run({{vwp0, vwp1}}, - {{vbias0, vbias1}}, op); + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias1}}, op); SAVE(C, vres, n, 0); C += LDC; SAVE(C, vres, n, 4); @@ -222,8 +232,8 @@ struct ConvBiasMatmul { vwp1 = QConverterBase::vzero(); int8x8_t vres; - vres = Process::run({{vwp0, vwp1}}, - {{vbias0, vbias1}}, op); + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias1}}, op); SAVE(C, vres, n, 0); C += LDC; } @@ -232,8 +242,9 @@ struct ConvBiasMatmul { template struct ConvBiasMatmul { - static void postprocess(const dt_int32* bias, const dt_int32* workspace, - dt_uint8* C, size_t LDC, const Op& op) { + static void postprocess( + const dt_int32* bias, const dt_int32* workspace, dt_uint8* C, size_t LDC, + const Op& op) { static_assert(m > 0 && m <= block_m, "invalid m or n"); int32x4_t vbias0, vwp0, vwp1; if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { @@ -258,13 +269,12 @@ struct ConvBiasMatmul { } }; - template struct ConvBiasMatmul { - static void postprocess(const dt_int32* bias, const dt_int32* workspace, - dt_uint8* C, size_t LDC, const Op& op) { - static_assert(m > 0 && m <= block_m && n > 0 && n <= 4, - "invalid m or n"); + static void postprocess( + const dt_int32* bias, const dt_int32* workspace, dt_uint8* C, size_t LDC, + const Op& op) { + static_assert(m > 0 && m <= block_m && n > 0 && n <= 4, "invalid m or n"); int i = 0; int32x4_t vbias0, vbias1, vwp0, vwp1; if (bmode != BiasMode::BROADCAST_CHANNEL_BIAS) { @@ -284,8 +294,8 @@ struct ConvBiasMatmul { vwp1 = vld1q_s32(workspace); uint8x8_t vres; - vres = Process::run({{vwp0, vwp1}}, - {{vbias0, vbias1}}, op); + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias1}}, op); SAVEU(C, vres, n, 0); C += LDC; SAVEU(C, 
vres, n, 4); @@ -306,15 +316,14 @@ struct ConvBiasMatmul { vwp1 = QConverterBase::vzero(); uint8x8_t vres; - vres = Process::run({{vwp0, vwp1}}, - {{vbias0, vbias1}}, op); + vres = Process::run( + {{vwp0, vwp1}}, {{vbias0, vbias1}}, op); SAVEU(C, vres, n, 0); C += LDC; } } }; - #define DISPATCH_M(cb, _m, _n, ...) \ switch (_m) { \ case 4: { \ @@ -360,26 +369,26 @@ struct ConvBiasMatmul { } //! _n should be a compiler time constant -#define DISPATCH_M_N(cb, _m, _n, ...) \ - switch (_m) { \ - case 4: { \ - cb(4, _n, ##__VA_ARGS__); \ - break; \ - } \ - case 3: { \ - cb(3, _n, ##__VA_ARGS__); \ - break; \ - } \ - case 2: { \ - cb(2, _n, ##__VA_ARGS__); \ - break; \ - } \ - case 1: { \ - cb(1, _n, ##__VA_ARGS__); \ - break; \ - } \ - default: \ - megdnn_assert(0); \ +#define DISPATCH_M_N(cb, _m, _n, ...) \ + switch (_m) { \ + case 4: { \ + cb(4, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 3: { \ + cb(3, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 2: { \ + cb(2, _n, ##__VA_ARGS__); \ + break; \ + } \ + case 1: { \ + cb(1, _n, ##__VA_ARGS__); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ } } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 19ca06b3..0b82b96d 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -57,8 +57,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8DirectStride1 s8_direct_stride1; AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; - AlgoS8x8x16ChanWiseStride1Stride2NCHW44 - s8x8x16_channel_wise_stride1_stride2_nchw44; + AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44; #if MGB_ENABLE_DOT AlgoDotS8DirectStride1 ds8_direct_stride1; @@ -113,8 +112,7 @@ public: m_direct_algos.emplace_back(&s8_direct_nchw_nchw44); m_direct_algos.emplace_back(&s8_direct_stride1); - m_direct_algos.emplace_back( - &s8x8x16_channel_wise_stride1_stride2_nchw44); + m_direct_algos.emplace_back(&s8x8x16_channel_wise_stride1_stride2_nchw44); m_direct_algos.emplace_back(&s8_channel_wise_stride1_nchw44); m_direct_algos.emplace_back(&s8_channel_wise_stride2_nchw44); @@ -142,8 +140,7 @@ public: using MatmulFormat = param::MatrixMul::Format; auto&& matmul_algos = static_cast(matmul_opr) - ->select_algo_type( - {AlgoDataType::FLOAT32, MatmulFormat::MK4}); + ->select_algo_type({AlgoDataType::FLOAT32, MatmulFormat::MK4}); for (auto&& algo : matmul_algos) { if (is_fallback_or_naive(algo)) continue; @@ -179,8 +176,8 @@ public: } } matmul_algos = static_cast(matmul_opr) - ->select_algo_type({AlgoDataType::FLOAT32, - MatmulFormat::DEFAULT}); + ->select_algo_type( + {AlgoDataType::FLOAT32, MatmulFormat::DEFAULT}); for (auto&& algo : matmul_algos) { if (is_fallback_or_naive(algo)) continue; @@ -202,8 +199,8 @@ public: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC matmul_algos = static_cast(matmul_opr) - ->select_algo_type({AlgoDataType::FLOAT16, - MatmulFormat::DEFAULT}); + ->select_algo_type( + {AlgoDataType::FLOAT16, MatmulFormat::DEFAULT}); for (auto&& algo : matmul_algos) { if (is_fallback_or_naive(algo)) continue; @@ -222,9 +219,9 @@ public: m_winograd_algos.emplace_back(refhold.back().get()); } } - matmul_algos = static_cast(matmul_opr) - ->select_algo_type({AlgoDataType::FLOAT16, - MatmulFormat::MK8}); + matmul_algos = + static_cast(matmul_opr) + ->select_algo_type({AlgoDataType::FLOAT16, MatmulFormat::MK8}); for (auto&& algo : matmul_algos) { if 
(is_fallback_or_naive(algo)) continue; @@ -237,8 +234,8 @@ public: } #endif matmul_algos = static_cast(matmul_opr) - ->select_algo_type({AlgoDataType::INT16X16X32, - MatmulFormat::MK8}); + ->select_algo_type( + {AlgoDataType::INT16X16X32, MatmulFormat::MK8}); for (auto&& algo : matmul_algos) { if (is_fallback_or_naive(algo)) continue; @@ -265,8 +262,7 @@ public: const SmallVector& direct_algos() const { return m_direct_algos; } - const SmallVector& winograd_algos() - const { + const SmallVector& winograd_algos() const { return m_winograd_algos; } const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } @@ -279,21 +275,21 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) -SmallVector -ConvBiasImpl::get_all_packed_algo() { +SmallVector ConvBiasImpl::get_all_packed_algo() { auto&& algos = fallback::ConvBiasImpl::get_all_packed_algo(); - algos.insert(algos.begin(), algo_pack().direct_algos().begin(), - algo_pack().direct_algos().end()); - algos.insert(algos.end(), algo_pack().winograd_algos().begin(), - algo_pack().winograd_algos().end()); + algos.insert( + algos.begin(), algo_pack().direct_algos().begin(), + algo_pack().direct_algos().end()); + algos.insert( + algos.end(), algo_pack().winograd_algos().begin(), + algo_pack().winograd_algos().end()); return std::move(algos); } bool ConvBiasImpl::is_matmul_quantized_prefer( const ConvBiasImpl::NCBKernSizeParam& param) const { fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( - param, {}, 0, BiasMode::NO_BIAS, - param::ConvBias::NonlineMode::IDENTITY); + param, {}, 0, BiasMode::NO_BIAS, param::ConvBias::NonlineMode::IDENTITY); conv_ncb_param.dst_type = param.bias_type; conv_ncb_param.filter_meta.group = 1; @@ -307,10 +303,10 @@ bool ConvBiasImpl::is_matmul_quantized_prefer( conv_ncb_param); } else if (param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) { conv_direct_unusable = - !arm_common::direct_quint8_stride1:: - can_conv_direct_stride1_quint8(conv_ncb_param) && - !arm_common::direct_quint8_stride2:: - can_conv_direct_stride2_quint8(conv_ncb_param); + !arm_common::direct_quint8_stride1::can_conv_direct_stride1_quint8( + conv_ncb_param) && + !arm_common::direct_quint8_stride2::can_conv_direct_stride2_quint8( + conv_ncb_param); } return conv_direct_unusable; } @@ -338,11 +334,9 @@ SmallVector ConvBiasImpl::suggest_algo_category_order( } } if (im2col_prefer) { - return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, - AlgoCategory::NAIVE}; + return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE}; } else { - return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, - AlgoCategory::NAIVE}; + return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, AlgoCategory::NAIVE}; } } diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 38106a77..65c19f49 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -28,12 +28,10 @@ public: } }; - SmallVector get_all_packed_algo() - override; + SmallVector get_all_packed_algo() override; bool is_matmul_quantized_prefer( - const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) - const override; + const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param) const override; SmallVector suggest_algo_category_order( const NCBKernSizeParam& param) const override; diff --git a/dnn/src/arm_common/conv_bias/postprocess_helper.h b/dnn/src/arm_common/conv_bias/postprocess_helper.h index 11c9f60c..f150614c 100644 --- 
a/dnn/src/arm_common/conv_bias/postprocess_helper.h +++ b/dnn/src/arm_common/conv_bias/postprocess_helper.h @@ -43,35 +43,32 @@ namespace { case megdnn::NonlineMode::IDENTITY: \ break; -#define FOR_NONLINEAR_UNARY(_op) \ - megdnn::arm_common::OpCallerUnary<_op, megdnn::arm_common::VEC>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(dst_ptr), bias_type, dst_type, \ - N* OC* OH* OW* pack_oc_size); - -#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW); - -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101xX>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW, pack_oc_size); - -#define FOR_NONLINEAR_BINARY(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N* OC* OH* OW* pack_oc_size); +#define FOR_NONLINEAR_UNARY(_op) \ + megdnn::arm_common::OpCallerUnary<_op, megdnn::arm_common::VEC>::run( \ + static_cast(conv_dst_ptr), reinterpret_cast(dst_ptr), \ + bias_type, dst_type, N* OC* OH* OW* pack_oc_size); + +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, \ + OC, OH* OW); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ + megdnn::arm_common:: \ + OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101xX>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW, pack_oc_size); + +#define FOR_NONLINEAR_BINARY(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N* OC* OH* OW* pack_oc_size); #define FOR_BIAS(_mode) \ switch (_mode) { \ @@ -86,8 +83,9 @@ namespace { if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ - megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ - "Only support nchw44/nchw88 in ARM"); \ + megdnn_assert( \ + pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \ } \ } \ @@ -122,24 +120,26 @@ namespace { DEFAULT \ } -template +template < + typename ctype, typename dtype = ctype, + megdnn::PostprocessMode postprocess_mode = megdnn::PostprocessMode::FLOAT> struct PostProcess { - static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + static void run( + void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { FOR_BIAS(bias_mode) } }; template 
struct PostProcess { - static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + static void run( + void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { MEGDNN_MARK_USED_VAR(conv_dst_ptr); MEGDNN_MARK_USED_VAR(bias_ptr); MEGDNN_MARK_USED_VAR(dst_ptr); @@ -152,9 +152,10 @@ struct PostProcess { MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(OW); MEGDNN_MARK_USED_VAR(pack_oc_size); - megdnn_throw_if(bias_mode != megdnn::BiasMode::NO_BIAS || - nonlineMode != megdnn::NonlineMode::IDENTITY, - megdnn_error, "biasmode or nonlineMode do not support"); + megdnn_throw_if( + bias_mode != megdnn::BiasMode::NO_BIAS || + nonlineMode != megdnn::NonlineMode::IDENTITY, + megdnn_error, "biasmode or nonlineMode do not support"); } }; @@ -167,37 +168,36 @@ struct PostProcess { #undef FOR_BIAS #undef HANDLE_IDENTITY -#define FOR_NONLINEAR_UNARY(_op) \ - megdnn::arm_common::OpCallerUnary< \ - _op, \ - megdnn::arm_common::VEC>::run(static_cast(conv_dst_ptr), \ - reinterpret_cast(dst_ptr), \ - bias_type, dst_type, \ - N* OC* OH* OW* pack_oc_size); - -#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ - megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW); - -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101xX>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW, pack_oc_size); - -#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ - megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101xX>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW, pack_oc_size); +#define FOR_NONLINEAR_UNARY(_op) \ + megdnn::arm_common:: \ + OpCallerUnary<_op, megdnn::arm_common::VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(dst_ptr), bias_type, dst_type, \ + N* OC* OH* OW* pack_oc_size); + +#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ + megdnn::arm_common::OpCallerBinary< \ + _op, megdnn::arm_common::VEC_BCAST101>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \ + megdnn::arm_common::OpCallerBinary< \ + _op, megdnn::arm_common::VEC_BCAST101xX>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW, pack_oc_size); + +#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \ + megdnn::arm_common::OpCallerBinary< \ + _op, megdnn::arm_common::VEC_BCAST101xX>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW, pack_oc_size); #define HANDLE_IDENTITY(_caller, _op) \ case megdnn::NonlineMode::IDENTITY: \ @@ -228,8 +228,9 @@ struct PostProcess { if 
(pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ - megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ - "Only support nchw44/nchw88 in ARM"); \ + megdnn_assert( \ + pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \ } \ break; \ @@ -238,8 +239,9 @@ struct PostProcess { if (pack_oc_size == 1) { \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ } else { \ - megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ - "Only support nchw44/nchw88 in ARM"); \ + megdnn_assert( \ + pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \ } \ break; \ @@ -250,10 +252,11 @@ struct PostProcess { template struct PostProcess { - static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + static void run( + void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { //! when OH * OW = 1, the bias_mode will be BiasMode::BIAS. It is wrong, //! we deal this case at default branch. FOR_BIAS(bias_mode, OH, OW); @@ -268,59 +271,60 @@ struct PostProcess { #undef FOR_NONLINEAR #undef FOR_BIAS -#define FOR_BINARY_BROADCAST(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW); - -#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ - megdnn::arm_common::OpCallerBinary<_op, \ - megdnn::arm_common::VEC_BCAST101xX>:: \ - run(static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N, OC, OH* OW, pack_oc_size); - -#define FOR_BINARY(_op) \ - megdnn::arm_common:: \ - OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ - static_cast(conv_dst_ptr), \ - reinterpret_cast(bias_ptr), \ - reinterpret_cast(dst_ptr), bias_type, bias_type, \ - dst_type, N* OC* OH* OW* pack_oc_size); - -#define FOR_BIAS(_bias_mode, OH, OW) \ - switch (_bias_mode) { \ - case megdnn::BiasMode::NO_BIAS: \ - break; \ - case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ - if (pack_oc_size == 1) { \ - FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ - } else { \ - megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \ - "Only support nchw44/nchw88 in ARM"); \ - FOR_BINARY_BROADCAST_NCHWXX(CONCAT_OP(AddOp)); \ - } \ - break; \ - case megdnn::BiasMode::BIAS: \ - FOR_BINARY(CONCAT_OP(AddOp)); \ - break; \ - default: \ - megdnn_throw("unknow biasmode"); \ - break; \ +#define FOR_BINARY_BROADCAST(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101>:: \ + run(static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, N, \ + OC, OH* OW); + +#define FOR_BINARY_BROADCAST_NCHWXX(_op) \ + megdnn::arm_common:: \ + OpCallerBinary<_op, megdnn::arm_common::VEC_BCAST101xX>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N, OC, OH* OW, pack_oc_size); + +#define 
FOR_BINARY(_op) \ + megdnn::arm_common::OpCallerBinary<_op, megdnn::arm_common::VEC_VEC>::run( \ + static_cast(conv_dst_ptr), \ + reinterpret_cast(bias_ptr), \ + reinterpret_cast(dst_ptr), bias_type, bias_type, dst_type, \ + N* OC* OH* OW* pack_oc_size); + +#define FOR_BIAS(_bias_mode, OH, OW) \ + switch (_bias_mode) { \ + case megdnn::BiasMode::NO_BIAS: \ + break; \ + case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \ + if (pack_oc_size == 1) { \ + FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ + } else { \ + megdnn_assert( \ + pack_oc_size == 4 || pack_oc_size == 8, \ + "Only support nchw44/nchw88 in ARM"); \ + FOR_BINARY_BROADCAST_NCHWXX(CONCAT_OP(AddOp)); \ + } \ + break; \ + case megdnn::BiasMode::BIAS: \ + FOR_BINARY(CONCAT_OP(AddOp)); \ + break; \ + default: \ + megdnn_throw("unknow biasmode"); \ + break; \ } template struct PostProcess { - static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { - megdnn_throw_if(nonlineMode != megdnn::NonlineMode::IDENTITY, - megdnn_error, "nonlineMode do not support"); + static void run( + void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { + megdnn_throw_if( + nonlineMode != megdnn::NonlineMode::IDENTITY, megdnn_error, + "nonlineMode do not support"); FOR_BIAS(bias_mode, OH, OW); } }; @@ -335,123 +339,114 @@ struct PostProcess { #undef DEFAULT #undef HANDLE_IDENTITY -#define DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, _bias_id, _src_type, \ - _dst_type, _bmode, _nonline_mode, ...) \ - switch (_nonline_mode) { \ - case param::ConvBias::NonlineMode::IDENTITY: { \ - MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ - cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, \ - __VA_ARGS__); \ - } \ - MIDOUT_END(); \ - break; \ - } \ - case param::ConvBias::NonlineMode::RELU: { \ - MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ - cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, \ - __VA_ARGS__); \ - } \ - MIDOUT_END(); \ - break; \ - } \ - case param::ConvBias::NonlineMode::SIGMOID: { \ - MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \ - cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, \ - __VA_ARGS__); \ - } \ - MIDOUT_END(); \ - break; \ - } \ - case param::ConvBias::NonlineMode::H_SWISH: { \ - MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \ - cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, \ - __VA_ARGS__); \ - } \ - MIDOUT_END(); \ - break; \ - } \ - default: \ - megdnn_assert(0); \ - break; \ +#define DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, _bias_id, _src_type, _dst_type, _bmode, _nonline_mode, ...) 
\ + switch (_nonline_mode) { \ + case param::ConvBias::NonlineMode::IDENTITY: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ + cb(_bmode, NoneOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::RELU: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ + cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::SIGMOID: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 2) { \ + cb(_bmode, SigmoidOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::H_SWISH: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 3) { \ + cb(_bmode, HSwishOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ } -#define DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED(_midout_tag, cb, _bias_id, \ - _src_type, _dst_type, _bmode, \ - _nonline_mode, ...) \ - switch (_nonline_mode) { \ - case param::ConvBias::NonlineMode::IDENTITY: { \ - MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ - cb(_bmode, TypeCvtOp<_src_type MEGDNN_COMMA _dst_type>, \ - __VA_ARGS__); \ - } \ - MIDOUT_END(); \ - break; \ - } \ - case param::ConvBias::NonlineMode::RELU: { \ - MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ - cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, \ - __VA_ARGS__); \ - } \ - MIDOUT_END(); \ - break; \ - } \ - default: \ - megdnn_assert(0); \ - break; \ +#define DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ + _midout_tag, cb, _bias_id, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ + switch (_nonline_mode) { \ + case param::ConvBias::NonlineMode::IDENTITY: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 0) { \ + cb(_bmode, TypeCvtOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + case param::ConvBias::NonlineMode::RELU: { \ + MIDOUT_BEGIN(_midout_tag, _bias_id, 1) { \ + cb(_bmode, ReluOp<_src_type MEGDNN_COMMA _dst_type>, __VA_ARGS__); \ + } \ + MIDOUT_END(); \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ } -#define DISPATCH_CONV_WINOGRAD_BIAS(_midout_tag, cb, _src_type, _dst_type, \ - _bmode, _nonline_mode, ...) \ - switch (_bmode) { \ - case BiasMode::BIAS: { \ - DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 0, _src_type, \ - _dst_type, BiasMode::BIAS, \ - _nonline_mode, __VA_ARGS__) \ - break; \ - } \ - case BiasMode::NO_BIAS: { \ - DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 1, _src_type, \ - _dst_type, BiasMode::NO_BIAS, \ - _nonline_mode, __VA_ARGS__) \ - break; \ - } \ - case BiasMode::BROADCAST_CHANNEL_BIAS: { \ - DISPATCH_CONV_WINOGRAD_NONLINE(_midout_tag, cb, 2, _src_type, \ - _dst_type, \ - BiasMode::BROADCAST_CHANNEL_BIAS, \ - _nonline_mode, __VA_ARGS__) \ - break; \ - } \ - default: \ - megdnn_assert(0); \ - break; \ +#define DISPATCH_CONV_WINOGRAD_BIAS( \ + _midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) 
\ + switch (_bmode) { \ + case BiasMode::BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::NO_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, 1, _src_type, _dst_type, BiasMode::NO_BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::BROADCAST_CHANNEL_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE( \ + _midout_tag, cb, 2, _src_type, _dst_type, \ + BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ } -#define DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( \ - _midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ - switch (_bmode) { \ - case BiasMode::BIAS: { \ - DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ - _midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ - _nonline_mode, __VA_ARGS__) \ - break; \ - } \ - case BiasMode::NO_BIAS: { \ - DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ - _midout_tag, cb, 1, _src_type, _dst_type, \ - BiasMode::NO_BIAS, _nonline_mode, __VA_ARGS__) \ - break; \ - } \ - case BiasMode::BROADCAST_CHANNEL_BIAS: { \ - DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ - _midout_tag, cb, 2, _src_type, _dst_type, \ - BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, \ - __VA_ARGS__) \ - break; \ - } \ - default: \ - megdnn_assert(0); \ - break; \ +#define DISPATCH_CONV_WINOGRAD_BIAS_QUANTIZED( \ + _midout_tag, cb, _src_type, _dst_type, _bmode, _nonline_mode, ...) \ + switch (_bmode) { \ + case BiasMode::BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ + _midout_tag, cb, 0, _src_type, _dst_type, BiasMode::BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::NO_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ + _midout_tag, cb, 1, _src_type, _dst_type, BiasMode::NO_BIAS, \ + _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + case BiasMode::BROADCAST_CHANNEL_BIAS: { \ + DISPATCH_CONV_WINOGRAD_NONLINE_QUANTIZED( \ + _midout_tag, cb, 2, _src_type, _dst_type, \ + BiasMode::BROADCAST_CHANNEL_BIAS, _nonline_mode, __VA_ARGS__) \ + break; \ + } \ + default: \ + megdnn_assert(0); \ + break; \ } } // namespace diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.cpp b/dnn/src/arm_common/conv_bias/quint8/algos.cpp index afd41fb0..9c7992e0 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/algos.cpp @@ -26,15 +26,16 @@ using namespace megdnn; using namespace arm_common; /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { +bool ConvBiasImpl::AlgoQU8DirectStride1::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return direct_quint8_stride1::can_conv_direct_stride1_quint8(param); } size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_quint8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -43,11 +44,11 @@ size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( 
+SmallVector ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_quint8_stride1::get_kimpls(param, large_group); } @@ -57,15 +58,15 @@ ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns( /* ===================== stride2 algo ===================== */ bool ConvBiasImpl::AlgoQU8DirectStride2::usable( - const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { return direct_quint8_stride2::can_conv_direct_stride2_quint8(param); } size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_quint8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -74,11 +75,11 @@ size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_quint8_stride2::get_kimpls(param, large_group); } @@ -87,9 +88,9 @@ ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( } #if MGB_ENABLE_DOT /* ===================== stride1 algo ===================== */ -bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { - if (!cpuinfo_has_arm_neon_dot()){ +bool ConvBiasImpl::AlgoDotU8DirectStride1::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()) { return false; } return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param); @@ -97,8 +98,9 @@ bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -107,11 +109,11 @@ size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + 
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_dotprod_quint8_stride1::get_kimpls(param, large_group); } @@ -120,9 +122,9 @@ ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( } /* ===================== stride2 algo ===================== */ -bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy) const { - if (!cpuinfo_has_arm_neon_dot()){ +bool ConvBiasImpl::AlgoDotU8DirectStride2::usable( + const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()) { return false; } return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param); @@ -130,8 +132,9 @@ bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group); return bundle.total_size_in_bytes(); @@ -140,11 +143,11 @@ size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace( return 0; } -SmallVector -ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( +SmallVector ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns( const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, - midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_common_conv_bias_quint8, + midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) { bool large_group = param.filter_meta.group >= param.nr_threads; return direct_dotprod_quint8_stride2::get_kimpls(param, large_group); } diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.h b/dnn/src/arm_common/conv_bias/quint8/algos.h index ae188c0f..9d8c2f58 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.h +++ b/dnn/src/arm_common/conv_bias/quint8/algos.h @@ -18,15 +18,13 @@ namespace megdnn { namespace arm_common { class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "QU8STRD1"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -38,14 +36,12 @@ public: }; class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "QU8STRD2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ 
-57,15 +53,13 @@ public: }; #if MGB_ENABLE_DOT class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMDOTU8STRD1"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( @@ -77,14 +71,12 @@ public: }; class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { - public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMDOTU8STRD2"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( diff --git a/dnn/src/arm_common/conv_bias/quint8/direct.cpp b/dnn/src/arm_common/conv_bias/quint8/direct.cpp index 59490afa..bd337e83 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/direct.cpp @@ -42,10 +42,10 @@ MIDOUT_DECL(conv_direct_stride) template void conv_bias::conv_direct_stride1_2x2_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); MIDOUT_BEGIN(conv_direct_stride, 0, 0) { int16x8_t v128 = vdupq_n_s16(128); @@ -365,10 +365,10 @@ void conv_bias::conv_direct_stride1_2x2_quint8( template void conv_bias::conv_direct_stride1_3x3_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); MIDOUT_BEGIN(conv_direct_stride, 0, 1) { int16x8_t v128 = vdupq_n_s16(128); @@ -530,10 +530,10 @@ void conv_bias::conv_direct_stride1_3x3_quint8( template void conv_bias::conv_direct_stride1_5x5_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* 
dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); MIDOUT_BEGIN(conv_direct_stride, 0, 2) { int16x8_t v128 = vdupq_n_s16(128); @@ -802,10 +802,10 @@ void conv_bias::conv_direct_stride1_5x5_quint8( template void conv_bias::conv_direct_stride1_7x7_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); MIDOUT_BEGIN(conv_direct_stride, 0, 3) { int16x8_t v128 = vdupq_n_s16(128); @@ -1235,11 +1235,11 @@ void conv_bias::conv_direct_stride1_7x7_quint8( template void conv_bias::conv_direct_stride2_2x2_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { - MEGDNN_MARK_USED_VAR(IH); + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { + MEGDNN_MARK_USED_VAR(IH); #define GET_R2(sptr) \ _r00 = SUB128VECTOR(vld1_u8(sptr)); \ _r00 = vtbl1_s8(_r00, _idx); \ @@ -1316,10 +1316,10 @@ void conv_bias::conv_direct_stride2_2x2_quint8( template void conv_bias::conv_direct_stride2_3x3_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); #define GET_R3(sptr) \ _r00 = SUB128VECTOR(vld1_u8(sptr)); \ @@ -1481,10 +1481,10 @@ void conv_bias::conv_direct_stride2_3x3_quint8( template void conv_bias::conv_direct_stride2_5x5_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); #define GET_R5(sptr) \ _r00 = SUB128VECTOR(vld1_u8(sptr)); \ @@ -1723,10 +1723,10 @@ void conv_bias::conv_direct_stride2_5x5_quint8( template void conv_bias::conv_direct_stride2_7x7_quint8( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t 
OH, const size_t OW, const int8_t src_zp, - const int8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const int8_t src_zp, const int8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); #define GET_R7(sptr) \ _r00 = SUB128VECTOR(vld1_u8(sptr)); \ @@ -2082,25 +2082,27 @@ void conv_bias::conv_direct_stride2_7x7_quint8( #undef POSTPROCESS #undef ACC_S16_S32 -#define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ - template void conv_bias::conv_direct_##stride##_##i##x##i##_quint8< \ - first_ic, last_ic, bias, Op>( \ - const uint8_t*, const uint8_t*, const int32_t*, int32_t*, \ - uint8_t*, const size_t, const size_t, const size_t, const size_t, \ - const int8_t, const int8_t, const int32_t, const Op&); - -#define FOR_NONLINEAR(stride, i, first_ic, last_ic, bias) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - TypeCvtOp) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - ReluOp) \ - INSTANTIATION(stride, i, first_ic, last_ic, bias, \ - HSwishOp) +#define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_quint8< \ + first_ic, last_ic, bias, Op>( \ + const uint8_t*, const uint8_t*, const int32_t*, int32_t*, uint8_t*, \ + const size_t, const size_t, const size_t, const size_t, const int8_t, \ + const int8_t, const int32_t, const Op&); + +#define FOR_NONLINEAR(stride, i, first_ic, last_ic, bias) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + TypeCvtOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + ReluOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, bias, \ + HSwishOp) #define FOR_BIAS(stride, i, first_ic, last_ic) \ FOR_NONLINEAR(stride, i, first_ic, last_ic, BiasMode::NO_BIAS) \ - FOR_NONLINEAR(stride, i, first_ic, last_ic, \ - BiasMode::BROADCAST_CHANNEL_BIAS) + FOR_NONLINEAR(stride, i, first_ic, last_ic, BiasMode::BROADCAST_CHANNEL_BIAS) #define FOR_IC(stride, i) \ FOR_BIAS(stride, i, true, true) \ diff --git a/dnn/src/arm_common/conv_bias/quint8/direct.h b/dnn/src/arm_common/conv_bias/quint8/direct.h index 74cf58fc..98db4126 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct.h +++ b/dnn/src/arm_common/conv_bias/quint8/direct.h @@ -22,8 +22,7 @@ namespace conv_bias { const uint8_t* src, const uint8_t* filter, const int32_t* bias, \ int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, \ const size_t OH, const size_t OW, const int8_t src_zp, \ - const int8_t filter_zp, const int32_t src_filter_zp, \ - const Op& op); + const int8_t filter_zp, const int32_t src_filter_zp, const Op& op); KERN(stride1, 2) KERN(stride1, 3) @@ -37,7 +36,7 @@ KERN(stride2, 7) #undef KERN -} // namesapce conv_bias +} // namespace conv_bias } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp index 92e172ed..925b1eea 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp @@ -23,16 +23,16 @@ using megdnn::arm_common::TypeCvtOp; constexpr int32_t SHIFT = (1 << 30); -inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){ - int8x8x2_t src; - src.val[0] = vget_low_s8(a); - src.val[1] = vget_high_s8(a); - uint8x8_t index_low = vget_low_u8(index); - uint8x8_t index_high = 
vget_high_u8(index); - int8x8_t r00 = vtbl2_s8(src,vreinterpret_s8_u8(index_low)) ; - int8x8_t r01 = vtbl2_s8(src,vreinterpret_s8_u8(index_high)); - int8x16_t r = vcombine_s8(r00,r01); - return r; +inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) { + int8x8x2_t src; + src.val[0] = vget_low_s8(a); + src.val[1] = vget_high_s8(a); + uint8x8_t index_low = vget_low_u8(index); + uint8x8_t index_high = vget_high_u8(index); + int8x8_t r00 = vtbl2_s8(src, vreinterpret_s8_u8(index_low)); + int8x8_t r01 = vtbl2_s8(src, vreinterpret_s8_u8(index_high)); + int8x16_t r = vcombine_s8(r00, r01); + return r; } #define ST1_S32X4(dst0, tptr) vst1q_u32(tptr, dst0); @@ -48,28 +48,27 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){ ST1_S32X4(dst1, tptr + 4); \ } -#define POSTPROCESS2_1X8(dst0, tptr, dptr) \ - if (last_ic && fused_kern) { \ - uint32x4x2_t temp; \ - uint32x4_t temp00, temp11; \ - temp = vzipq_u32(dst0.val[0], dst0.val[1]); \ - temp00 = temp.val[0]; \ - temp11 = temp.val[1]; \ - op({{temp00,temp11}},reinterpret_cast(dptr)); \ - } else { \ - ST2_S32X4X2(dst0, tptr); \ +#define POSTPROCESS2_1X8(dst0, tptr, dptr) \ + if (last_ic && fused_kern) { \ + uint32x4x2_t temp; \ + uint32x4_t temp00, temp11; \ + temp = vzipq_u32(dst0.val[0], dst0.val[1]); \ + temp00 = temp.val[0]; \ + temp11 = temp.val[1]; \ + op({{temp00, temp11}}, reinterpret_cast(dptr)); \ + } else { \ + ST2_S32X4X2(dst0, tptr); \ } -#define POSTPROCESS_2X4(dst0, dst1, tptr1, tptr2, dptr1, dptr2) \ - if (last_ic && fused_kern) { \ - uint32x2_t res = reinterpret_cast( \ - op({{vreinterpretq_u32_s32(dst0), \ - vreinterpretq_u32_s32(dst1)}})); \ - vst1_lane_u32(reinterpret_cast(dptr1), res, 0); \ - vst1_lane_u32(reinterpret_cast(dptr2), res, 1); \ - } else { \ - ST1_S32X4(dst0, tptr1); \ - ST1_S32X4(dst1, tptr2); \ +#define POSTPROCESS_2X4(dst0, dst1, tptr1, tptr2, dptr1, dptr2) \ + if (last_ic && fused_kern) { \ + uint32x2_t res = reinterpret_cast( \ + op({{vreinterpretq_u32_s32(dst0), vreinterpretq_u32_s32(dst1)}})); \ + vst1_lane_u32(reinterpret_cast(dptr1), res, 0); \ + vst1_lane_u32(reinterpret_cast(dptr2), res, 1); \ + } else { \ + ST1_S32X4(dst0, tptr1); \ + ST1_S32X4(dst1, tptr2); \ } #define POSTPROCESS_1X4(dst0, tptr, dptr) \ @@ -91,9 +90,8 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){ tptr = vgetq_lane_u32(dst0, 0); \ } -#define CALC_DST(_sum) \ - _sum = vreinterpretq_u32_s32( \ - vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)) +#define CALC_DST(_sum) \ + _sum = vreinterpretq_u32_s32(vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)) #define CALC_0(_k_idx, _c_idx) \ _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ @@ -107,25 +105,22 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ - _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum0##_c_idx = 
vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_2x2_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; @@ -138,10 +133,8 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); } - const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, - 2, 3, 16, 16, 3, 4, 16, 16}; - const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, - 6, 7, 16, 16, 7, 8, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, 6, 7, 16, 16, 7, 8, 16, 16}; //! here we use uint32_t for calc uint32_t* outptr = reinterpret_cast(temp); uint32_t* outptr2 = outptr + OW; @@ -154,8 +147,8 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( const uint8_t* k0 = filter; - uint8x16_t _k = vreinterpretq_u8_u32( - vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _k = + vreinterpretq_u8_u32(vdupq_n_u32(*reinterpret_cast(k0))); uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; uint8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; @@ -200,13 +193,13 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); uint8x16_t _r21 = vextq_u8(_r20, _r21_, 1); - int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), - vreinterpretq_s16_u8(_r10)); + int16x8x2_t r_0 = + vzipq_s16(vreinterpretq_s16_u8(_r00), vreinterpretq_s16_u8(_r10)); uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); - int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), - vreinterpretq_s16_u8(_r11)); + int16x8x2_t r_1 = + vzipq_s16(vreinterpretq_s16_u8(_r01), vreinterpretq_s16_u8(_r11)); int8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); int8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -215,13 +208,11 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( SUB_ZP(_sum01.val[0], _r2); SUB_ZP(_sum01.val[1], _r3); - r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), - vreinterpretq_s16_u8(_r20)); + r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), vreinterpretq_s16_u8(_r20)); _r0 = vreinterpretq_u8_s8(r_0.val[0]); _r2 = vreinterpretq_u8_s8(r_0.val[1]); - r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), - vreinterpretq_s16_u8(_r21)); + r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), vreinterpretq_s16_u8(_r21)); _r1 = vreinterpretq_u8_s8(r_1.val[0]); _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -362,13 +353,13 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( uint8x16_t _r01 = vextq_u8(_r00, _r01_, 1); uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); - int16x8x2_t r_0 = 
vzipq_s16(vreinterpretq_s16_u8(_r00), - vreinterpretq_s16_u8(_r10)); + int16x8x2_t r_0 = + vzipq_s16(vreinterpretq_s16_u8(_r00), vreinterpretq_s16_u8(_r10)); uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); - int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), - vreinterpretq_s16_u8(_r11)); + int16x8x2_t r_1 = + vzipq_s16(vreinterpretq_s16_u8(_r01), vreinterpretq_s16_u8(_r11)); int8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); int8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -451,14 +442,13 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( #undef SUB_ZP } -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_3x3_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; @@ -471,10 +461,8 @@ void conv_bias::conv_direct_stride1_3x3_quint8_dot( _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); } - const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, - 2, 3, 4, 16, 3, 4, 5, 16}; - const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; uint32_t* outptr = reinterpret_cast(temp); @@ -691,14 +679,13 @@ void conv_bias::conv_direct_stride1_3x3_quint8_dot( } } -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_2x2_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; @@ -711,8 +698,7 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot( _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); } - const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, - 4, 5, 16, 16, 6, 7, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 2, 3, 16, 16, 4, 5, 16, 16, 6, 7, 16, 16}; uint32_t* outptr = reinterpret_cast(temp); uint8_t* dstptr = dst; @@ -721,8 +707,8 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot( const uint8_t* k0 = filter; - uint8x16_t _k = vreinterpretq_u8_u32( - vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _k = + vreinterpretq_u8_u32(vdupq_n_u32(*reinterpret_cast(k0))); uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; uint8x16_t _k1 = vqtbl1q_s8_v7(_k, _idx); _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; @@ -752,8 +738,8 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot( //! 
here will not not read out of bound uint8x16_t _r10 = vld1q_u8(r1); - int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), - vreinterpretq_s16_u8(_r10)); + int16x8x2_t r_0 = + vzipq_s16(vreinterpretq_s16_u8(_r00), vreinterpretq_s16_u8(_r10)); uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); uint8x16_t _r1 = vreinterpretq_u8_s8(r_0.val[1]); SUB_ZP(_sum0, _r0); @@ -802,14 +788,13 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot( #undef SUB_ZP } -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_3x3_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; @@ -822,13 +807,11 @@ void conv_bias::conv_direct_stride2_3x3_quint8_dot( _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); } - const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, - 4, 5, 6, 16, 6, 7, 8, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 2, 3, 4, 16, 4, 5, 6, 16, 6, 7, 8, 16}; const uint8x16_t _idx1 = {8, 9, 10, 16, 10, 11, 12, 16, 12, 13, 14, 16, 16, 16, 16, 16}; //! start from 12 13 14 15 - const uint8x16_t _idx2 = {2, 3, 4, 16, 4, 5, 6, 16, - 6, 7, 8, 16, 8, 9, 10, 16}; + const uint8x16_t _idx2 = {2, 3, 4, 16, 4, 5, 6, 16, 6, 7, 8, 16, 8, 9, 10, 16}; const uint8x16_t _idx3 = {10, 11, 12, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16}; uint32_t* outptr = reinterpret_cast(temp); @@ -859,8 +842,7 @@ void conv_bias::conv_direct_stride2_3x3_quint8_dot( int w = 0; for (; w + 3 < width; w += 3) { - uint32x4_t _sum00, _sum01, _sum02, _sum03, _sum10, _sum11, _sum12, - _sum13; + uint32x4_t _sum00, _sum01, _sum02, _sum03, _sum10, _sum11, _sum12, _sum13; if (!first_ic) { _sum00 = vld1q_u32(outptr); _sum01 = vld1q_u32(outptr + 4); @@ -1091,60 +1073,51 @@ void conv_bias::conv_direct_stride2_3x3_quint8_dot( #undef CALC_1 #undef CALC_2 -#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ - _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ - _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ - _sum1##_c_idx = 
vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ - _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _elem2 = vdotq2_u32(_filter_zp, _elem); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ - _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ - _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ - _elem2 = vdotq2_u32(_filter_zp, _elem); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ + _elem = vqtbl1q_s8_v7(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_5x5_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, 
const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; @@ -1446,14 +1419,13 @@ void conv_bias::conv_direct_stride1_5x5_quint8_dot( } } -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_7x7_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; @@ -1467,8 +1439,7 @@ void conv_bias::conv_direct_stride1_7x7_quint8_dot( } const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; @@ -1789,14 +1760,13 @@ void conv_bias::conv_direct_stride1_7x7_quint8_dot( } } -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_5x5_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; @@ -2095,14 +2065,13 @@ void conv_bias::conv_direct_stride2_5x5_quint8_dot( } } -template +template MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_7x7_quint8_dot( - const uint8_t* src, const uint8_t* filter, const int32_t* bias, - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, - const size_t OH, const size_t OW, const uint8_t src_zp, - const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op) { + const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, + uint8_t* dst, const size_t IH, const size_t IW, const size_t OH, + const size_t OW, const uint8_t src_zp, const uint8_t filter_zp, + const int32_t src_filter_zp, const Op& op) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - 2 * OW + IW; @@ -2116,8 +2085,7 @@ void conv_bias::conv_direct_stride2_7x7_quint8_dot( } const uint8x16_t _idx00 = {0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 6, 7, 8, 16, - 8, 9, 10, 16, 10, 11, 12, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 6, 7, 8, 16, 8, 9, 10, 16, 10, 11, 12, 16}; //! 
start from 8 const uint8x16_t& _idx10 = _idx00; const uint8x16_t& _idx11 = _idx01; @@ -2477,25 +2445,29 @@ void conv_bias::conv_direct_stride2_7x7_quint8_dot( #undef ST1_S32X4 #undef ST2_S32X4X2 -#define INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, Op) \ - template void conv_bias::conv_direct_##stride##_##i##x##i##_quint8_dot< \ - first_ic, last_ic, fused_kern, bias, Op>( \ - const uint8_t*, const uint8_t*, const int32_t*, int32_t*, \ - uint8_t*, const size_t, const size_t, const size_t, const size_t, \ - const uint8_t, const uint8_t, const int32_t, const Op&); +#define INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, Op) \ + template void conv_bias::conv_direct_##stride##_##i##x##i##_quint8_dot< \ + first_ic, last_ic, fused_kern, bias, Op>( \ + const uint8_t*, const uint8_t*, const int32_t*, int32_t*, uint8_t*, \ + const size_t, const size_t, const size_t, const size_t, const uint8_t, \ + const uint8_t, const int32_t, const Op&); #define FOR_NONLINEAR(stride, i, first_ic, last_ic, fused_kern, bias) \ - INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, \ - TypeCvtOp) \ - INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, \ - ReluOp) \ - INSTANTIATION(stride, i, first_ic, last_ic, fused_kern, bias, \ - HSwishOp) + INSTANTIATION( \ + stride, i, first_ic, last_ic, fused_kern, bias, \ + TypeCvtOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, fused_kern, bias, \ + ReluOp) \ + INSTANTIATION( \ + stride, i, first_ic, last_ic, fused_kern, bias, \ + HSwishOp) #define FOR_BIAS(stride, i, first_ic, last_ic, fused_kern) \ FOR_NONLINEAR(stride, i, first_ic, last_ic, fused_kern, BiasMode::NO_BIAS) \ - FOR_NONLINEAR(stride, i, first_ic, last_ic, fused_kern, \ - BiasMode::BROADCAST_CHANNEL_BIAS) + FOR_NONLINEAR( \ + stride, i, first_ic, last_ic, fused_kern, \ + BiasMode::BROADCAST_CHANNEL_BIAS) #define FOR_KERN(stride, i, first_ic, last_ic) \ FOR_BIAS(stride, i, first_ic, last_ic, true) \ diff --git a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h index 6e0f2e82..7232d8d5 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h +++ b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h @@ -17,15 +17,15 @@ namespace megdnn { namespace arm_common { namespace conv_bias { -#define KERN(stride, i) \ - template \ - void conv_direct_##stride##_##i##x##i##_quint8_dot( \ - const uint8_t* src, const uint8_t* filter, const int32_t* bias, \ - int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, \ - const size_t OH, const size_t OW, const uint8_t src_zp, \ - const uint8_t filter_zp, const int32_t src_filter_zp, \ - const Op& op); +#define KERN(stride, i) \ + template < \ + bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode, \ + typename Op> \ + void conv_direct_##stride##_##i##x##i##_quint8_dot( \ + const uint8_t* src, const uint8_t* filter, const int32_t* bias, \ + int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, \ + const size_t OH, const size_t OW, const uint8_t src_zp, \ + const uint8_t filter_zp, const int32_t src_filter_zp, const Op& op); KERN(stride1, 2) KERN(stride1, 3) @@ -39,7 +39,7 @@ KERN(stride2, 7) #undef KERN -} // namesapce conv_bias +} // namespace conv_bias } // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1.cpp b/dnn/src/arm_common/conv_bias/quint8/stride1.cpp index b0350d64..37099e79 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1.cpp +++ 
b/dnn/src/arm_common/conv_bias/quint8/stride1.cpp @@ -20,20 +20,18 @@ using namespace arm_common; using namespace direct_quint8_stride1; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; auto SW = fm.stride[1]; auto OH = param.osz[0]; @@ -54,15 +52,14 @@ bool direct_quint8_stride1::can_conv_direct_stride1_quint8( auto FH = fm.spatial[0]; auto OC = fm.ocpg; auto IC = fm.icpg; - bool avaible = - param.src_type.enumv() == DTypeEnum::Quantized8Asymm && - param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && - (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || - param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && - fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + bool avaible = param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); if (param.bias_type.valid()) { avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; } @@ -83,9 +80,8 @@ WorkspaceBundle direct_quint8_stride1::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -99,10 +95,8 @@ WorkspaceBundle direct_quint8_stride1::get_bundle( } //! Process one input channel copy padding void direct_quint8_stride1::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -116,14 +110,11 @@ void direct_quint8_stride1::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; - const uint8_t* sptr = - kern_param.src(batch_id, group_id, channel_id); + const uint8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! copy to sptr_base to eliminate padding effect uint8_t* sptr_base = static_cast(bundle.get(0)) + @@ -134,17 +125,17 @@ void direct_quint8_stride1::copy_padding_kern( kern_param.src_type.param().zero_point; std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(uint8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); } } }; //! compute one output channel template -void direct_quint8_stride1::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +void direct_quint8_stride1::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t FH = kern_param.filter_meta.spatial[0]; @@ -159,30 +150,25 @@ void direct_quint8_stride1::do_conv_kern(const WorkspaceBundle& bundle, (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); #define SUB128(n) static_cast(static_cast(n) - 128) - uint8_t _src_zp = - kern_param.src_type.param().zero_point; + uint8_t _src_zp = kern_param.src_type.param().zero_point; int8_t src_zp = SUB128(_src_zp); - int8_t filter_zp = SUB128( - kern_param.filter_type.param().zero_point); + int8_t filter_zp = + SUB128(kern_param.filter_type.param().zero_point); int32_t src_filter_zp = static_cast(filter_zp) * static_cast(src_zp) * IC * FH * FW; #undef SUB128 Op op = Op(1.0f, 1.0f, 0); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; - float scale_dst = - kern_param.dst_type.param().scale; - uint8_t dst_zp = - kern_param.dst_type.param().zero_point; + float scale_bias = kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + uint8_t dst_zp = kern_param.dst_type.param().zero_point; op = Op(scale_bias, scale_dst, dst_zp); } size_t padding_group_size = IH2 * IW2 * IC; //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], - group_id = ncb_index.ndrange_id[0], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2], group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const uint8_t* sptr = kern_param.src(batch_id, group_id); @@ -204,12 +190,12 @@ void direct_quint8_stride1::do_conv_kern(const WorkspaceBundle& bundle, dptr = dst; } -#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ - conv_bias::conv_direct_stride1_##filter##x##filter##_quint8< \ - first_ic, last_ic, bias_mode, Op>( \ - sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ - static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ - filter_zp, src_filter_zp, op) +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_quint8< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, filter_zp, \ + src_filter_zp, op) #define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ conv_bias::conv_direct_stride1_##filter##x##filter##_quint8< \ @@ -252,13 +238,14 @@ void direct_quint8_stride1::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -276,23 +263,21 @@ SmallVector direct_quint8_stride1::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -339,21 +324,21 @@ SmallVector direct_quint8_stride1::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); } else { - auto 
copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1.h b/dnn/src/arm_common/conv_bias/quint8/stride1.h index 5f0f0606..e1806a54 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1.h +++ b/dnn/src/arm_common/conv_bias/quint8/stride1.h @@ -28,18 +28,16 @@ bool can_conv_direct_stride1_quint8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_quint8_stride1 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp index bf503a86..30dbf6a0 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp @@ -20,20 +20,18 @@ using namespace arm_common; using namespace direct_dotprod_quint8_stride1; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; auto SW = fm.stride[1]; auto OH = param.osz[0]; @@ -56,15 +54,14 @@ bool direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8( auto FH = fm.spatial[0]; auto OC = fm.ocpg; auto IC = fm.icpg; - bool avaible = - param.src_type.enumv() == DTypeEnum::Quantized8Asymm && - param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && - (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || - param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && - fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && - FH 
== fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + bool avaible = param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 1 && fm.stride[1] == 1 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); if (param.bias_type.valid()) { avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; } @@ -85,9 +82,8 @@ WorkspaceBundle direct_dotprod_quint8_stride1::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -101,10 +97,8 @@ WorkspaceBundle direct_dotprod_quint8_stride1::get_bundle( } //! Process one input channel copy padding void direct_dotprod_quint8_stride1::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -118,13 +112,10 @@ void direct_dotprod_quint8_stride1::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; - const uint8_t* sptr = - kern_param.src(batch_id, group_id, channel_id); + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; + const uint8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! 
copy to sptr_base to eliminate padding effect @@ -136,8 +127,9 @@ void direct_dotprod_quint8_stride1::copy_padding_kern( kern_param.src_type.param().zero_point; std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(uint8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); } } }; @@ -159,28 +151,23 @@ void direct_dotprod_quint8_stride1::do_conv_kern( bool need_post_process = (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); - uint8_t src_zp = - kern_param.src_type.param().zero_point; + uint8_t src_zp = kern_param.src_type.param().zero_point; uint8_t filter_zp = kern_param.filter_type.param().zero_point; int32_t src_filter_zp = static_cast(filter_zp) * static_cast(src_zp) * IC * FH * FW; Op op(1.0f, 1.0f, static_cast(0)); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; - float scale_dst = - kern_param.dst_type.param().scale; - uint8_t dst_zp = - kern_param.dst_type.param().zero_point; + float scale_bias = kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + uint8_t dst_zp = kern_param.dst_type.param().zero_point; op = Op(scale_bias, scale_dst, dst_zp); } size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], - group_id = ncb_index.ndrange_id[0], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2], group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const uint8_t* sptr = kern_param.src(batch_id, group_id); @@ -203,12 +190,12 @@ void direct_dotprod_quint8_stride1::do_conv_kern( dptr = dst; } -#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ - conv_bias::conv_direct_stride1_##filter##x##filter##_quint8_dot< \ - first_ic, last_ic, true, bias_mode, Op>( \ - sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ - static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ - filter_zp, src_filter_zp, op) +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride1_##filter##x##filter##_quint8_dot< \ + first_ic, last_ic, true, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, filter_zp, \ + src_filter_zp, op) #define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ conv_bias::conv_direct_stride1_##filter##x##filter##_quint8_dot< \ @@ -252,13 +239,14 @@ void direct_dotprod_quint8_stride1::do_conv_kern( #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -276,23 +264,21 @@ SmallVector direct_dotprod_quint8_stride1::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - 
break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -339,21 +325,21 @@ SmallVector direct_dotprod_quint8_stride1::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); - }else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + } else { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h index 85c14fc2..d1b99267 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h +++ b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h @@ -28,18 +28,16 @@ bool can_conv_direct_stride1_quint8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_dotprod_quint8_stride1 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2.cpp b/dnn/src/arm_common/conv_bias/quint8/stride2.cpp index 718db855..f3ec300f 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride2.cpp @@ -20,20 +20,18 @@ using namespace arm_common; using namespace direct_quint8_stride2; namespace { -bool need_dst_copy( - const 
megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; size_t SW = fm.stride[1]; size_t IH = param.isz[0]; @@ -62,23 +60,22 @@ bool direct_quint8_stride2::can_conv_direct_stride2_quint8( auto FH = fm.spatial[0]; auto OC = fm.ocpg; auto IC = fm.icpg; - bool avaible = - param.src_type.enumv() == DTypeEnum::Quantized8Asymm && - param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && - (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || - param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && - fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + bool avaible = param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); if (param.bias_type.valid()) { avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; } - bool preferred = (((FH == 2 || FH == 3) && - (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || - (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || - (FH == 7 && OC <= 16)) && - (param.bias_mode != BiasMode::BIAS); + bool preferred = + (((FH == 2 || FH == 3) && (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); return avaible && preferred; } @@ -92,9 +89,8 @@ WorkspaceBundle direct_quint8_stride2::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -108,10 +104,8 @@ WorkspaceBundle direct_quint8_stride2::get_bundle( } //! 
Process one input channel copy padding void direct_quint8_stride2::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -125,13 +119,11 @@ void direct_quint8_stride2::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2], - group_id = ncb_index.ndrange_id[0], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + channel_id = workspace_ids[2], group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; - const uint8_t* sptr = - kern_param.src(batch_id, group_id, channel_id); + const uint8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! copy to sptr_base to eliminate padding effect uint8_t* sptr_base = static_cast(bundle.get(0)) + @@ -142,17 +134,17 @@ void direct_quint8_stride2::copy_padding_kern( kern_param.src_type.param().zero_point; std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(uint8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); } } }; //! compute one output channel template -void direct_quint8_stride2::do_conv_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { +void direct_quint8_stride2::do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t OH = kern_param.osz[0]; size_t OW = kern_param.osz[1]; size_t FH = kern_param.filter_meta.spatial[0]; @@ -167,30 +159,25 @@ void direct_quint8_stride2::do_conv_kern(const WorkspaceBundle& bundle, (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); #define SUB128(n) static_cast(static_cast(n) - 128) - uint8_t _src_zp = - kern_param.src_type.param().zero_point; + uint8_t _src_zp = kern_param.src_type.param().zero_point; int8_t src_zp = SUB128(_src_zp); - int8_t filter_zp = SUB128( - kern_param.filter_type.param().zero_point); + int8_t filter_zp = + SUB128(kern_param.filter_type.param().zero_point); int32_t src_filter_zp = static_cast(filter_zp) * static_cast(src_zp) * IC * FH * FW; #undef SUB128 Op op = Op(1.0f, 1.0f, 0); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; - float scale_dst = - kern_param.dst_type.param().scale; - uint8_t dst_zp = - kern_param.dst_type.param().zero_point; + float scale_bias = kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + uint8_t dst_zp = kern_param.dst_type.param().zero_point; op = Op(scale_bias, scale_dst, dst_zp); } size_t padding_group_size = IH2 * IW2 * IC; //! 
Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], - group_id = ncb_index.ndrange_id[0], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2], group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const uint8_t* sptr = kern_param.src(batch_id, group_id); @@ -212,12 +199,12 @@ void direct_quint8_stride2::do_conv_kern(const WorkspaceBundle& bundle, dptr = dst; } -#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ - conv_bias::conv_direct_stride2_##filter##x##filter##_quint8< \ - first_ic, last_ic, bias_mode, Op>( \ - sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ - static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ - filter_zp, src_filter_zp, op) +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_quint8< \ + first_ic, last_ic, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, filter_zp, \ + src_filter_zp, op) #define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ conv_bias::conv_direct_stride2_##filter##x##filter##_quint8< \ @@ -260,13 +247,14 @@ void direct_quint8_stride2::do_conv_kern(const WorkspaceBundle& bundle, #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -284,23 +272,21 @@ SmallVector direct_quint8_stride2::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -347,21 +333,21 @@ SmallVector direct_quint8_stride2::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); - }else { - auto 
copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + } else { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2.h b/dnn/src/arm_common/conv_bias/quint8/stride2.h index eca45c36..d2ad0418 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2.h +++ b/dnn/src/arm_common/conv_bias/quint8/stride2.h @@ -28,18 +28,16 @@ bool can_conv_direct_stride2_quint8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_quint8_stride2 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp index 818dfd97..839b0b30 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp @@ -20,20 +20,18 @@ using namespace arm_common; using namespace direct_dotprod_quint8_stride2; namespace { -bool need_dst_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_dst_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { return param.osz[1] % 8; } -bool need_src_copy( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { +bool need_src_copy(const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param) { if (param.filter_meta.padding[0] || param.filter_meta.padding[1]) { return true; } return need_dst_copy(param); } void get_rectified_size( - const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, - size_t& IH2, size_t& IW2, size_t& OH2, size_t& OW2) { + const megdnn::fallback::ConvBiasImpl::NCBKernSizeParam& param, size_t& IH2, + size_t& IW2, size_t& OH2, size_t& OW2) { auto&& fm = param.filter_meta; size_t SW = fm.stride[1]; size_t IH = param.isz[0]; @@ -62,23 +60,22 @@ bool direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8( auto FH = fm.spatial[0]; auto OC = fm.ocpg; auto IC = fm.icpg; - bool avaible = - param.src_type.enumv() == DTypeEnum::Quantized8Asymm && - param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && - (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || - param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && - fm.format == param::Convolution::Format::NCHW && !fm.should_flip && - fm.spatial_ndim == 2 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && fm.stride[0] == 2 && 
fm.stride[1] == 2 && - FH == fm.spatial[1] && (FH == 2 || FH == 3 || FH == 5 || FH == 7); + bool avaible = param.src_type.enumv() == DTypeEnum::Quantized8Asymm && + param.filter_type.enumv() == DTypeEnum::Quantized8Asymm && + (param.dst_type.enumv() == DTypeEnum::QuantizedS32 || + param.dst_type.enumv() == DTypeEnum::Quantized8Asymm) && + fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && + fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && + (FH == 2 || FH == 3 || FH == 5 || FH == 7); if (param.bias_type.valid()) { avaible &= param.bias_type.enumv() == DTypeEnum::QuantizedS32; } - bool preferred = (((FH == 2 || FH == 3) && - (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || - (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || - (FH == 7 && OC <= 16)) && - (param.bias_mode != BiasMode::BIAS); + bool preferred = + (((FH == 2 || FH == 3) && (IC == 1 || (IC <= 8 && OC <= 12) || OC <= 8)) || + (FH == 5 && ((IC == 1 && OC <= 16) || OC <= 12)) || + (FH == 7 && OC <= 16)) && + (param.bias_mode != BiasMode::BIAS); return avaible && preferred; } @@ -92,9 +89,8 @@ WorkspaceBundle direct_dotprod_quint8_stride2::get_bundle( get_rectified_size(param, IH2, IW2, OH2, OW2); size_t src_size = 0, dst_size = 0; if (need_src_copy(param)) { - src_size = m_large_group - ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads - : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; + src_size = m_large_group ? IC * IH2 * IW2 * sizeof(uint8_t) * nr_threads + : IC * IH2 * IW2 * sizeof(uint8_t) * group * batch; }; if (need_dst_copy(param)) { dst_size = OH2 * OW2 * param.dst_type.size() * nr_threads; @@ -108,10 +104,8 @@ WorkspaceBundle direct_dotprod_quint8_stride2::get_bundle( } //! Process one input channel copy padding void direct_dotprod_quint8_stride2::copy_padding_kern( - const WorkspaceBundle& bundle, - const ConvBiasImpl::NCBKernParam& kern_param, - const ConvBiasImpl::NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids) { + const WorkspaceBundle& bundle, const ConvBiasImpl::NCBKernParam& kern_param, + const ConvBiasImpl::NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids) { size_t IH = kern_param.isz[0]; size_t IW = kern_param.isz[1]; size_t IC = kern_param.filter_meta.icpg; @@ -125,13 +119,10 @@ void direct_dotprod_quint8_stride2::copy_padding_kern( size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], channel_id = workspace_ids[2]; - size_t group_id = ncb_index.ndrange_id[0], - batch_id = ncb_index.ndrange_id[1]; - const uint8_t* sptr = - kern_param.src(batch_id, group_id, channel_id); + size_t group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; + const uint8_t* sptr = kern_param.src(batch_id, group_id, channel_id); if (need_src_copy_var) { //! 
copy to sptr_base to eliminate padding effect uint8_t* sptr_base = static_cast(bundle.get(0)) + @@ -142,8 +133,9 @@ void direct_dotprod_quint8_stride2::copy_padding_kern( kern_param.src_type.param().zero_point; std::memset(sptr_base, _src_zp, sizeof(uint8_t) * IH2 * IW2); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, - sizeof(uint8_t) * IW); + std::memcpy( + sptr_base + (ih + PH) * IW2 + PW, sptr + ih * IW, + sizeof(uint8_t) * IW); } } }; @@ -165,28 +157,23 @@ void direct_dotprod_quint8_stride2::do_conv_kern( bool need_post_process = (kern_param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); - uint8_t src_zp = - kern_param.src_type.param().zero_point; + uint8_t src_zp = kern_param.src_type.param().zero_point; uint8_t filter_zp = kern_param.filter_type.param().zero_point; int32_t src_filter_zp = static_cast(filter_zp) * static_cast(src_zp) * IC * FH * FW; Op op(1.0f, 1.0f, static_cast(0)); if (need_post_process) { - float scale_bias = - kern_param.bias_type.param().scale; - float scale_dst = - kern_param.dst_type.param().scale; - uint8_t dst_zp = - kern_param.dst_type.param().zero_point; + float scale_bias = kern_param.bias_type.param().scale; + float scale_dst = kern_param.dst_type.param().scale; + uint8_t dst_zp = kern_param.dst_type.param().zero_point; op = Op(scale_bias, scale_dst, dst_zp); } size_t padding_group_size = IH2 * IW2 * IC; //! Used for get the workspace offset - size_t workspace_group_id = workspace_ids[0], - workspace_batch_id = workspace_ids[1], oc = workspace_ids[2], - group_id = ncb_index.ndrange_id[0], + size_t workspace_group_id = workspace_ids[0], workspace_batch_id = workspace_ids[1], + oc = workspace_ids[2], group_id = ncb_index.ndrange_id[0], batch_id = ncb_index.ndrange_id[1]; const uint8_t* sptr = kern_param.src(batch_id, group_id); @@ -209,12 +196,12 @@ void direct_dotprod_quint8_stride2::do_conv_kern( dptr = dst; } -#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ - conv_bias::conv_direct_stride2_##filter##x##filter##_quint8_dot< \ - first_ic, last_ic, true, bias_mode, Op>( \ - sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ - static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, \ - filter_zp, src_filter_zp, op) +#define KERN0_NEED_POST_PROCESS(filter, first_ic, last_ic) \ + conv_bias::conv_direct_stride2_##filter##x##filter##_quint8_dot< \ + first_ic, last_ic, true, bias_mode, Op>( \ + sptr + ic * IH2 * IW2, fptr + ic * FH * FW, bptr, tptr, \ + static_cast(dptr), IH2, IW2, OH2, OW2, src_zp, filter_zp, \ + src_filter_zp, op) #define KERN0_NO_POST_PROCESS(filter, first_ic, last_ic) \ conv_bias::conv_direct_stride2_##filter##x##filter##_quint8_dot< \ @@ -258,13 +245,14 @@ void direct_dotprod_quint8_stride2::do_conv_kern( #undef KERN1_NO_POST_PROCESS if (need_dst_copy_var) { rep(oh, OH) { - std::memcpy(reinterpret_cast( - reinterpret_cast(dst) + - oh * OW * kern_param.dst_type.size()), - reinterpret_cast( - reinterpret_cast(dptr) + - oh * OW2 * kern_param.dst_type.size()), - kern_param.dst_type.size() * OW); + std::memcpy( + reinterpret_cast( + reinterpret_cast(dst) + + oh * OW * kern_param.dst_type.size()), + reinterpret_cast( + reinterpret_cast(dptr) + + oh * OW2 * kern_param.dst_type.size()), + kern_param.dst_type.size() * OW); } } } @@ -282,23 +270,21 @@ SmallVector direct_dotprod_quint8_stride2::get_kimpls( #define DO_CONV_KERN_FUN(filter, bias_mode, op) \ do_conv_fun = do_conv_kern; -#define GET_OP_PARAM(i, bias_mode) \ - switch (param.nonlineMode) { \ - case param::ConvBias::NonlineMode::IDENTITY: \ - 
DO_CONV_KERN_FUN(i, bias_mode, \ - TypeCvtOp) \ - break; \ - case param::ConvBias::NonlineMode::RELU: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - ReluOp) \ - break; \ - case param::ConvBias::NonlineMode::H_SWISH: \ - DO_CONV_KERN_FUN(i, bias_mode, \ - HSwishOp) \ - break; \ - default: \ - megdnn_assert(0); \ - break; \ +#define GET_OP_PARAM(i, bias_mode) \ + switch (param.nonlineMode) { \ + case param::ConvBias::NonlineMode::IDENTITY: \ + DO_CONV_KERN_FUN( \ + i, bias_mode, TypeCvtOp) \ + break; \ + case param::ConvBias::NonlineMode::RELU: \ + DO_CONV_KERN_FUN(i, bias_mode, ReluOp) \ + break; \ + case param::ConvBias::NonlineMode::H_SWISH: \ + DO_CONV_KERN_FUN(i, bias_mode, HSwishOp) \ + break; \ + default: \ + megdnn_assert(0); \ + break; \ } #define GET_BIAS_MODE_PARAM(i) \ @@ -345,21 +331,21 @@ SmallVector direct_dotprod_quint8_stride2::get_kimpls( size_t OC = fm.ocpg; bundle.set(kern_param.workspace_ptr); for (size_t ic = 0; ic < IC; ic++) { - copy_padding_kern(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, ic}); + copy_padding_kern( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, ic}); } for (size_t oc = 0; oc < OC; oc++) { - do_conv_fun(bundle, kern_param, ncb_index, - {ncb_index.thread_id, 0, oc}); + do_conv_fun( + bundle, kern_param, ncb_index, {ncb_index.thread_id, 0, oc}); } }; ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); - }else { - auto copy_padding = [bundle](const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index) mutable { + } else { + auto copy_padding = [bundle]( + const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index) mutable { bundle.set(kern_param.workspace_ptr); - copy_padding_kern(bundle, kern_param, ncb_index, - ncb_index.ndrange_id); + copy_padding_kern(bundle, kern_param, ncb_index, ncb_index.ndrange_id); }; ret_kerns.push_back({copy_padding, {group, N, IC}}); auto do_conv = [bundle, do_conv_fun]( diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h index 7cd9e64a..d93d43ad 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h +++ b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h @@ -28,18 +28,16 @@ bool can_conv_direct_stride2_quint8(const NCBKernSizeParam& param); WorkspaceBundle get_bundle(const NCBKernSizeParam& param, bool m_large_group); -void copy_padding_kern(const WorkspaceBundle& bundle, - const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void copy_padding_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); template -void do_conv_kern(const WorkspaceBundle& bundle, const NCBKernParam& kern_param, - const NCBKernIndex& ncb_index, - const CpuNDRange& workspace_ids); +void do_conv_kern( + const WorkspaceBundle& bundle, const NCBKernParam& kern_param, + const NCBKernIndex& ncb_index, const CpuNDRange& workspace_ids); -SmallVector get_kimpls(const NCBKernSizeParam& param, - bool); +SmallVector get_kimpls(const NCBKernSizeParam& param, bool); } // namespace direct_dotprod_quint8_stride2 } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/convolution/img2col_helper.h b/dnn/src/arm_common/convolution/img2col_helper.h index b830eee2..58df029f 100644 --- a/dnn/src/arm_common/convolution/img2col_helper.h +++ b/dnn/src/arm_common/convolution/img2col_helper.h @@ -14,10 +14,10 @@ namespace { template -void img2col_stride(const dtype* __restrict src, - dtype* __restrict dst, 
const int OC, const int OH, - const int OW, const int IC, const int IH, const int IW, - const int FH, const int FW, const int SH, const int SW) { +void img2col_stride( + const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, + const int OW, const int IC, const int IH, const int IW, const int FH, + const int FW, const int SH, const int SW) { (void)OC; size_t i = 0; rep(ic, IC) { @@ -33,8 +33,9 @@ void img2col_stride(const dtype* __restrict src, fh2 = FH - fh - 1; fw2 = FW - fw - 1; } - dst[i++] = src[ic * IH * IW + (oh * SH + fh2) * IW + - (ow * SW + fw2)]; + dst[i++] = + src[ic * IH * IW + (oh * SH + fh2) * IW + + (ow * SW + fw2)]; } } } @@ -43,8 +44,9 @@ void img2col_stride(const dtype* __restrict src, } template -void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH, - size_t OW, size_t IC, size_t IH, size_t IW, size_t FH, size_t FW) { +void img2col( + const dtype* src, dtype* dst, size_t /* OC */, size_t OH, size_t OW, size_t IC, + size_t IH, size_t IW, size_t FH, size_t FW) { size_t offset = (4 - OW % 4) % 4; size_t i = 0; rep(ic, IC) { @@ -61,14 +63,10 @@ void img2col(const dtype* src, dtype* dst, size_t /* OC */, size_t OH, fh2 = FH - fh - 1; fw2 = FW - fw - 1; } - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 0]; - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 1]; - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 2]; - dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + - (ow + fw2) + 3]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 0]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 1]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 2]; + dst[i++] = src[ic * IH * IW + (oh + fh2) * IW + (ow + fw2) + 3]; } i -= offset; } diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.cpp b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp index ea9a2a76..abcc4192 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/algos.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "src/arm_common/convolution/img2col_helper.h" #include "src/arm_common/convolution/int8x8x32/algos.h" +#include "src/arm_common/convolution/img2col_helper.h" #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" #include "src/common/opr_delegate.h" @@ -26,30 +26,30 @@ using namespace arm_common; /* ===================== ConvolutionBackwardData ===================== */ /* ===================== direct stride 1 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - if (!cpuinfo_has_arm_neon_dot()){ + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + if (!cpuinfo_has_arm_neon_dot()) { return false; } return deconv::can_stride1_int8x8x32_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, - midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_conv_int8832_kimpl, + midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param); } MIDOUT_END(); return 0; } -ConvolutionBackwardDataImpl::ncb_kern_t -ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( - fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { - MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, - midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) { +ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl:: + AlgoSdot8DirectStride1::dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + MIDOUT_BEGIN( + megdnn_arm_conv_int8832_kimpl, + midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) { return deconv::stride1_int8x8x32_dot; } MIDOUT_END(); @@ -58,30 +58,30 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( /* ===================== direct stride 2 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - if (!cpuinfo_has_arm_neon_dot()){ + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + if (!cpuinfo_has_arm_neon_dot()) { return false; } return deconv::can_stride2_int8x8x32_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, - midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_conv_int8832_kimpl, + midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param); } MIDOUT_END(); return 0; } -ConvolutionBackwardDataImpl::ncb_kern_t -ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern( - fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { - MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, - midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) { +ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl:: 
+ AlgoSdot8DirectStride2::dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + MIDOUT_BEGIN( + megdnn_arm_conv_int8832_kimpl, + midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) { return deconv::stride2_int8x8x32_dot; } MIDOUT_END(); diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.h b/dnn/src/arm_common/convolution/int8x8x32/algos.h index 7a71a9de..934a9593 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/algos.h +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.h @@ -20,45 +20,39 @@ namespace arm_common { #if MGB_ENABLE_DOT /* ===================== ConvolutionBackwardData ===================== */ -class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final - : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } - const char* name() const override { - return "AARCH32_I8x8x32_DECONV_STRIDE1"; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE1"; } - bool usable(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) + const override; - size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + size_t get_workspace( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam&) const override; + ncb_kern_t dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32) }; -class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final - : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } - const char* name() const override { - return "AARCH32_I8x8x32_DECONV_STRIDE2"; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE2"; } - bool usable(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) + const override; - size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + size_t get_workspace( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam&) const override; + ncb_kern_t dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_INT8X8X32) }; diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp index c03576be..98a9bfc3 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp @@ -11,8 +11,8 @@ #include 
"src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" #if MGB_ENABLE_DOT -#include "src/common/utils.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" using namespace megdnn; using namespace arm_common; @@ -34,9 +34,9 @@ bool need_src_copy(const NCBKernSizeParam& param) { return FH > PH + 1 || FW > PW + 1 || need_dst_copy(param); } -void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, - size_t FW, size_t PH, size_t PW, size_t& IH2, - size_t& IW2, size_t& OW2) { +void get_rectified_size( + size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, + size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(IW); @@ -93,15 +93,14 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_2x2( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; - const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, - 2, 3, 16, 16, 3, 4, 16, 16}; - const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, - 6, 7, 16, 16, 7, 8, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, 6, 7, 16, 16, 7, 8, 16, 16}; rep(ic, IC) { const int8_t* src_ptr = src; int32_t* dst_ptr = dst + OW * OH * ic; @@ -118,8 +117,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, vdupq_n_s32(*reinterpret_cast(k0))); uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; int8x16_t _k = vqtbl1q_s8_common(_k0, _idx_k); - uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, - 0, 1, 16, 16, 0, 1, 16, 16}; + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; int8x16_t _k1 = vqtbl1q_s8_common(_k, _idx); _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; int8x16_t _k23 = vqtbl1q_s8_common(_k, _idx); @@ -147,13 +145,13 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, int8x16_t _r11 = vextq_s8(_r10, _r11_, 1); int8x16_t _r21 = vextq_s8(_r20, _r21_, 1); - int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), - vreinterpretq_s16_s8(_r10)); + int16x8x2_t r_00 = vzipq_s16( + vreinterpretq_s16_s8(_r00), vreinterpretq_s16_s8(_r10)); int8x16_t _r0 = r_00.val[0]; int8x16_t _r2 = r_00.val[1]; - int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), - vreinterpretq_s16_s8(_r11)); + int16x8x2_t r_11 = vzipq_s16( + vreinterpretq_s16_s8(_r01), vreinterpretq_s16_s8(_r11)); int8x16_t _r1 = r_11.val[0]; int8x16_t _r3 = r_11.val[1]; @@ -162,13 +160,13 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, _sum01.val[0] = vdotq_s32(_sum01.val[0], _k, _r2); _sum01.val[1] = vdotq_s32(_sum01.val[1], _k, _r3); - r_00 = vzipq_s16(vreinterpretq_s16_s8(_r10), - vreinterpretq_s16_s8(_r20)); + r_00 = vzipq_s16( + vreinterpretq_s16_s8(_r10), vreinterpretq_s16_s8(_r20)); _r0 = r_00.val[0]; _r2 = r_00.val[1]; - r_11 = vzipq_s16(vreinterpretq_s16_s8(_r11), - vreinterpretq_s16_s8(_r21)); + r_11 = vzipq_s16( + vreinterpretq_s16_s8(_r11), vreinterpretq_s16_s8(_r21)); _r1 = r_11.val[0]; _r3 = r_11.val[1]; @@ -262,13 
+260,13 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, int8x16_t _r01 = vextq_s8(_r00, _r01_, 1); int8x16_t _r11 = vextq_s8(_r10, _r11_, 1); - int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), - vreinterpretq_s16_s8(_r10)); + int16x8x2_t r_00 = vzipq_s16( + vreinterpretq_s16_s8(_r00), vreinterpretq_s16_s8(_r10)); int8x16_t _r0 = r_00.val[0]; int8x16_t _r2 = r_00.val[1]; - int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), - vreinterpretq_s16_s8(_r11)); + int16x8x2_t r_11 = vzipq_s16( + vreinterpretq_s16_s8(_r01), vreinterpretq_s16_s8(_r11)); int8x16_t _r1 = r_11.val[0]; int8x16_t _r3 = r_11.val[1]; @@ -328,15 +326,14 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, } MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_3x3( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; - const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, - 2, 3, 4, 16, 3, 4, 5, 16}; - const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; rep(ic, IC) { @@ -531,8 +528,9 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_5x5( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; @@ -779,14 +777,14 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, } MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_7x7( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW; const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; @@ -1072,7 +1070,6 @@ void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, } // anonymous namespace - size_t deconv::get_workspace_in_bytes_stride1_int8x8x32_dot( const NCBKernSizeParam& param) { return get_bundle(param).total_size_in_bytes(); @@ -1082,20 +1079,18 @@ bool deconv::can_stride1_int8x8x32_dot(const NCBKernSizeParam& param) { auto&& fm = param.filter_meta; auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, IC = fm.icpg, PH = fm.padding[0], PW = fm.padding[1]; - bool avaiable = fm.format == param::Convolution::Format::NCHW && 
- !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 1 && fm.stride[1] == 1 && FH == FW && - (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + bool avaiable = fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == FW && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || param.filter_type.enumv() == DTypeEnum::Int8) && (param.grad_type.enumv() == DTypeEnum::QuantizedS32 || param.grad_type.enumv() == DTypeEnum::Int32); - return avaiable && - ((FH == 2 && OC <= 8) || - ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); + return avaiable && ((FH == 2 && OC <= 8) || + ((FH == 3 || FH == 5 || FH == 7) && (IC < 32 && OC <= 16))); } void deconv::stride1_int8x8x32_dot(const NCBKernParam& param) { @@ -1109,8 +1104,9 @@ void deconv::stride1_int8x8x32_dot(const NCBKernParam& param) { get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); - using Func = std::function; + using Func = std::function; Func conv = nullptr; if (FH == 2) { conv = deconv_direct_2x2; @@ -1151,23 +1147,22 @@ void deconv::stride1_int8x8x32_dot(const NCBKernParam& param) { if (need_src_copy_var) { // copy sptr_ori to sptr_copied std::memset(sptr_copied, 0, sizeof(int8_t) * IH2 * IW2); - copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w, - sptr_ori + oc * IH * IW, IH, - IW * sizeof(int8_t), IW2 * sizeof(int8_t), - IW * sizeof(int8_t)); + copy_plane_in_bytes( + sptr_copied + padding_h * IW2 + padding_w, + sptr_ori + oc * IH * IW, IH, IW * sizeof(int8_t), + IW2 * sizeof(int8_t), IW * sizeof(int8_t)); sptr = sptr_copied; } else { sptr = sptr_ori + oc * IH * IW; } - conv(sptr, fptr + oc * IC * FH * FW, dptr, IH2, IW2, OH, OW_real, - IC); + conv(sptr, fptr + oc * IC * FH * FW, dptr, IH2, IW2, OH, OW_real, IC); } if (need_dst_copy_var) { for (size_t ic = 0; ic < IC; ++ic) { - copy_plane_in_bytes(dptr_ori + ic * OH * OW, - dptr + ic * OH * OW2, OH, - OW * sizeof(int32_t), OW * sizeof(int32_t), - OW2 * sizeof(int32_t)); + copy_plane_in_bytes( + dptr_ori + ic * OH * OW, dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); } } } diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h index 6c661639..06986077 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h @@ -27,8 +27,7 @@ bool can_stride1_int8x8x32_dot(const NCBKernSizeParam& param); void stride1_int8x8x32_dot(const NCBKernParam& param); -size_t get_workspace_in_bytes_stride1_int8x8x32_dot( - const NCBKernSizeParam& param); +size_t get_workspace_in_bytes_stride1_int8x8x32_dot(const NCBKernSizeParam& param); } // namespace deconv } // namespace arm_common diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp index 8fbb0b85..6994d826 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp @@ -11,8 +11,8 @@ #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" #if MGB_ENABLE_DOT -#include "src/common/utils.h" #include 
"src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" using namespace megdnn; using namespace arm_common; @@ -28,9 +28,9 @@ bool need_dst_copy(const NCBKernSizeParam& param) { return false; } -void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, - size_t FW, size_t PH, size_t PW, size_t& IH2, - size_t& IW2, size_t& OW2) { +void get_rectified_size( + size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, + size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(IW); //! OW should be a multiple of 4 @@ -82,15 +82,14 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_2x2( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW / 2; - const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, - 2, 3, 16, 16, 3, 4, 16, 16}; - const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, - 6, 7, 16, 16, 7, 8, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, 6, 7, 16, 16, 7, 8, 16, 16}; uint8x16_t _idx_r_0, _idx_r_1; if (even) { _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; @@ -115,8 +114,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, vdupq_n_s32(*reinterpret_cast(k0))); uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; int8x16_t _k = vqtbl1q_s8_common(_k0, _idx_k); - uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, - 0, 1, 16, 16, 0, 1, 16, 16}; + uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; int8x16_t _k1 = vqtbl1q_s8_common(_k, _idx); _idx = {2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16, 2, 3, 16, 16}; int8x16_t _k23 = vqtbl1q_s8_common(_k, _idx); @@ -143,13 +141,13 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, int8x16_t _r20 = vqtbl1q_s8_common(_r2_ori, _idx_r_0); int8x16_t _r21 = vqtbl1q_s8_common(_r2_ori, _idx_r_1); - int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), - vreinterpretq_s16_s8(_r10)); + int16x8x2_t r_00 = vzipq_s16( + vreinterpretq_s16_s8(_r00), vreinterpretq_s16_s8(_r10)); int8x16_t _r0 = r_00.val[0]; int8x16_t _r2 = r_00.val[1]; - int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), - vreinterpretq_s16_s8(_r11)); + int16x8x2_t r_11 = vzipq_s16( + vreinterpretq_s16_s8(_r01), vreinterpretq_s16_s8(_r11)); int8x16_t _r1 = r_11.val[0]; int8x16_t _r3 = r_11.val[1]; @@ -158,13 +156,13 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, _sum01.val[0] = vdotq_s32(_sum01.val[0], _k, _r2); _sum01.val[1] = vdotq_s32(_sum01.val[1], _k, _r3); - r_00 = vzipq_s16(vreinterpretq_s16_s8(_r10), - vreinterpretq_s16_s8(_r20)); + r_00 = vzipq_s16( + vreinterpretq_s16_s8(_r10), vreinterpretq_s16_s8(_r20)); _r0 = r_00.val[0]; _r2 = r_00.val[1]; - r_11 = vzipq_s16(vreinterpretq_s16_s8(_r11), - vreinterpretq_s16_s8(_r21)); + r_11 = vzipq_s16( + vreinterpretq_s16_s8(_r11), vreinterpretq_s16_s8(_r21)); _r1 = r_11.val[0]; _r3 = r_11.val[1]; @@ -263,13 +261,13 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, int8x16_t _r10 = vqtbl1q_s8_common(_r1_ori, _idx_r_0); int8x16_t 
_r11 = vqtbl1q_s8_common(_r1_ori, _idx_r_1); - int16x8x2_t r_00 = vzipq_s16(vreinterpretq_s16_s8(_r00), - vreinterpretq_s16_s8(_r10)); + int16x8x2_t r_00 = vzipq_s16( + vreinterpretq_s16_s8(_r00), vreinterpretq_s16_s8(_r10)); int8x16_t _r0 = r_00.val[0]; int8x16_t _r2 = r_00.val[1]; - int16x8x2_t r_11 = vzipq_s16(vreinterpretq_s16_s8(_r01), - vreinterpretq_s16_s8(_r11)); + int16x8x2_t r_11 = vzipq_s16( + vreinterpretq_s16_s8(_r01), vreinterpretq_s16_s8(_r11)); int8x16_t _r1 = r_11.val[0]; int8x16_t _r3 = r_11.val[1]; @@ -334,15 +332,14 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_3x3( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW / 2; - const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, - 2, 3, 4, 16, 3, 4, 5, 16}; - const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; @@ -559,8 +556,9 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_5x5( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW / 2; @@ -837,14 +835,14 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { +void deconv_direct_7x7( + const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); const size_t tail_step = IW - OW / 2; const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; @@ -1173,13 +1171,12 @@ size_t deconv::get_workspace_in_bytes_stride2_int8x8x32_dot( bool deconv::can_stride2_int8x8x32_dot(const NCBKernSizeParam& param) { auto&& fm = param.filter_meta; - auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, - PH = fm.padding[0], PW = fm.padding[1]; - bool avaiable = fm.format == param::Convolution::Format::NCHW && - !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2 && FH == FW && - (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, PH = fm.padding[0], + PW = fm.padding[1]; + bool avaiable = fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + 
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == FW && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; avaiable &= (param.filter_type.enumv() == DTypeEnum::QuantizedS8 || @@ -1201,8 +1198,9 @@ void deconv::stride2_int8x8x32_dot(const NCBKernParam& param) { get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); - using Func = std::function; + using Func = std::function; Func conv = nullptr; if (FH == 2) { if ((padding_w & 1) == 0) @@ -1251,21 +1249,20 @@ void deconv::stride2_int8x8x32_dot(const NCBKernParam& param) { int8_t* sptr = nullptr; rep(oc, OC) { std::memset(sptr_copied, 0, sizeof(int8_t) * IH2 * IW2); - copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w / 2, - sptr_ori + oc * IH * IW, IH, - IW * sizeof(int8_t), 2 * IW2 * sizeof(int8_t), - IW * sizeof(int8_t)); + copy_plane_in_bytes( + sptr_copied + padding_h * IW2 + padding_w / 2, + sptr_ori + oc * IH * IW, IH, IW * sizeof(int8_t), + 2 * IW2 * sizeof(int8_t), IW * sizeof(int8_t)); sptr = sptr_copied; - conv(sptr, fptr + oc * IC * FH * FW, dptr, IH2, IW2, OH, OW_real, - IC); + conv(sptr, fptr + oc * IC * FH * FW, dptr, IH2, IW2, OH, OW_real, IC); } if (need_dst_copy_var) { for (size_t ic = 0; ic < IC; ++ic) { - copy_plane_in_bytes(dptr_ori + ic * OH * OW, - dptr + ic * OH * OW2, OH, - OW * sizeof(int32_t), OW * sizeof(int32_t), - OW2 * sizeof(int32_t)); + copy_plane_in_bytes( + dptr_ori + ic * OH * OW, dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); } } } diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h index b04b2ed2..a89c88e5 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h @@ -29,7 +29,7 @@ void stride2_int8x8x32_dot(const NCBKernParam& param); size_t get_workspace_in_bytes_stride2_int8x8x32_dot(const NCBKernSizeParam& param); -} // namespace convolution +} // namespace deconv } // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/convolution/opr_impl.cpp b/dnn/src/arm_common/convolution/opr_impl.cpp index 0131037b..afef91e1 100644 --- a/dnn/src/arm_common/convolution/opr_impl.cpp +++ b/dnn/src/arm_common/convolution/opr_impl.cpp @@ -14,14 +14,13 @@ #include "./quint8/algos.h" #include "src/common/metahelper.h" +#include "src/common/opr_delegate.h" #include "src/common/utils.h" #include "src/naive/handle.h" -#include "src/common/opr_delegate.h" using namespace megdnn; using namespace arm_common; - /* ===================== ConvolutionBackwardData ===================== */ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { #if MGB_ENABLE_DOT @@ -32,8 +31,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { #endif fallback::ConvolutionBackwardDataImpl::AlgoBase::Mapper m_all_algos_map; - SmallVector - m_all_algos; + SmallVector m_all_algos; public: AlgoPack() { @@ -49,15 +47,14 @@ public: } } - const SmallVector& - all_algos() const { + const SmallVector& all_algos() + const { return m_all_algos; } const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; -const ConvolutionBackwardDataImpl::AlgoPack& -ConvolutionBackwardDataImpl::algo_pack() { +const ConvolutionBackwardDataImpl::AlgoPack& ConvolutionBackwardDataImpl::algo_pack() { static AlgoPack algo_pack; return algo_pack; } @@ -67,19 +64,18 
@@ MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) SmallVector ConvolutionBackwardDataImpl::get_all_packed_algo() { auto&& algos = fallback::ConvolutionBackwardDataImpl::get_all_packed_algo(); - algos.insert(algos.begin(), algo_pack().all_algos().begin(), - algo_pack().all_algos().end()); + algos.insert( + algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } -ConvolutionBackwardDataImpl::ncb_kern_t -ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( - Algorithm* algo, const NCBKernSizeParam& param) { +ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl:: + ncb_1g_dispatch_kern(Algorithm* algo, const NCBKernSizeParam& param) { if (algo->handle_type() == Handle::HandleType::ARM_COMMON) { return static_cast(algo)->dispatch_kern(this, param); } - return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, - param); + return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, param); } size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( @@ -87,8 +83,7 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( if (algo->handle_type() == Handle::HandleType::ARM_COMMON) { return static_cast(algo)->get_workspace(this, param); } - return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, - param); + return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, param); } const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { diff --git a/dnn/src/arm_common/convolution/opr_impl.h b/dnn/src/arm_common/convolution/opr_impl.h index d0e66122..30d02e2f 100644 --- a/dnn/src/arm_common/convolution/opr_impl.h +++ b/dnn/src/arm_common/convolution/opr_impl.h @@ -9,17 +9,16 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #pragma once +#include "src/arm_common/conv_bias/opr_impl.h" #include "src/common/utils.h" #include "src/fallback/convolution/opr_impl.h" -#include "src/arm_common/conv_bias/opr_impl.h" namespace megdnn { namespace arm_common { class ConvBiasImpl; -class ConvolutionBackwardDataImpl - : public fallback::ConvolutionBackwardDataImpl { +class ConvolutionBackwardDataImpl : public fallback::ConvolutionBackwardDataImpl { public: using fallback::ConvolutionBackwardDataImpl::ConvolutionBackwardDataImpl; @@ -32,25 +31,27 @@ protected: AlgoBase() : fallback::ConvolutionBackwardDataImpl::AlgoBase() { m_handle_type = Handle::HandleType::ARM_COMMON; } - virtual bool usable(fallback::ConvolutionBackwardDataImpl* opr, - const NCBKernSizeParam& param) const = 0; - virtual size_t get_workspace(fallback::ConvolutionBackwardDataImpl* opr, - const NCBKernSizeParam& param) const = 0; + virtual bool usable( + fallback::ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; + virtual size_t get_workspace( + fallback::ConvolutionBackwardDataImpl* opr, + const NCBKernSizeParam& param) const = 0; virtual ncb_kern_t dispatch_kern( fallback::ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0; }; - ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo, - const NCBKernSizeParam& param) override; + ncb_kern_t ncb_1g_dispatch_kern( + Algorithm* algo, const NCBKernSizeParam& param) override; - size_t ncb_1g_get_workspace(Algorithm* algo, - const NCBKernSizeParam& param) override; + size_t ncb_1g_get_workspace( + Algorithm* algo, const NCBKernSizeParam& param) override; const char* get_algorithm_set_name() const override; - SmallVector - get_all_packed_algo() override; + SmallVector get_all_packed_algo() + override; public: MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); diff --git a/dnn/src/arm_common/convolution/quint8/algos.cpp b/dnn/src/arm_common/convolution/quint8/algos.cpp index 01a191be..8cf47987 100644 --- a/dnn/src/arm_common/convolution/quint8/algos.cpp +++ b/dnn/src/arm_common/convolution/quint8/algos.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "src/arm_common/convolution/img2col_helper.h" #include "src/arm_common/convolution/quint8/algos.h" +#include "src/arm_common/convolution/img2col_helper.h" #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" #include "src/common/opr_delegate.h" @@ -28,31 +28,30 @@ using namespace arm_common; /* ===================== direct stride 1 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - - if (!cpuinfo_has_arm_neon_dot()){ + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + if (!cpuinfo_has_arm_neon_dot()) { return false; } return deconv::can_stride1_quint8_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, - midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_conv_quint8_kimpl, + midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride1_quint8_dot(param); } MIDOUT_END(); return 0; } -ConvolutionBackwardDataImpl::ncb_kern_t -ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( - fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { - MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, - midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) { +ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl:: + AlgoUdot8DirectStride1::dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { + MIDOUT_BEGIN( + megdnn_arm_conv_quint8_kimpl, + midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) { return deconv::stride1_quint8_dot; } MIDOUT_END(); @@ -61,30 +60,30 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( /* ===================== direct stride 2 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - if (!cpuinfo_has_arm_neon_dot()){ + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + if (!cpuinfo_has_arm_neon_dot()) { return false; } return deconv::can_stride2_quint8_dot(param); } size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace( - fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const { - MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, - midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) { + fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + MIDOUT_BEGIN( + megdnn_arm_conv_quint8_kimpl, + midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) { return deconv::get_workspace_in_bytes_stride2_quint8_dot(param); } MIDOUT_END(); return 0; } -ConvolutionBackwardDataImpl::ncb_kern_t -ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern( - fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { - MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, - midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) { +ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl:: + AlgoUdot8DirectStride2::dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, const 
NCBKernSizeParam&) const { + MIDOUT_BEGIN( + megdnn_arm_conv_quint8_kimpl, + midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) { return deconv::stride2_quint8_dot; } MIDOUT_END(); diff --git a/dnn/src/arm_common/convolution/quint8/algos.h b/dnn/src/arm_common/convolution/quint8/algos.h index 44b7d6c4..7380d6d8 100644 --- a/dnn/src/arm_common/convolution/quint8/algos.h +++ b/dnn/src/arm_common/convolution/quint8/algos.h @@ -19,46 +19,44 @@ namespace arm_common { #if MGB_ENABLE_DOT /* ===================== ConvolutionBackwardData ===================== */ -class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final - : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; } - bool usable(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) + const override; - size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + size_t get_workspace( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam&) const override; + ncb_kern_t dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD1_DOT_QU8) }; -class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final - : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; } - bool usable(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + bool usable(fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) + const override; - size_t get_workspace(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam& param) const override; + size_t get_workspace( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam& param) const override; - ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*, - const NCBKernSizeParam&) const override; + ncb_kern_t dispatch_kern( + fallback::ConvolutionBackwardDataImpl*, + const NCBKernSizeParam&) const override; MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_DOT_QU8) }; diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp index b936bd74..53fb656d 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp @@ -11,8 +11,8 @@ #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" #if MGB_ENABLE_DOT -#include "src/common/utils.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" using namespace megdnn; using namespace arm_common; @@ -21,7 +21,7 @@ using namespace deconv; namespace { #define SHIFT_BITS 30 -#define SHIFT (1 << 
SHIFT_BITS) +#define SHIFT (1 << SHIFT_BITS) bool need_dst_copy(const NCBKernSizeParam& param) { if (param.osz[1] % 4 != 0) { @@ -37,9 +37,9 @@ bool need_src_copy(const NCBKernSizeParam& param) { return FH > PH + 1 || FW > PW + 1 || need_dst_copy(param); } -void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, - size_t FW, size_t PH, size_t PW, size_t& IH2, - size_t& IW2, size_t& OW2) { +void get_rectified_size( + size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, + size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(IW); MEGDNN_MARK_USED_VAR(PW); @@ -79,9 +79,8 @@ inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) { return r; } -#define CALC_DST(_sum) \ - _sum = vreinterpretq_u32_s32( \ - vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)); +#define CALC_DST(_sum) \ + _sum = vreinterpretq_u32_s32(vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)); #define CALC_0(_k_idx, _c_idx) \ _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ @@ -95,23 +94,21 @@ inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) { _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_2x2( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW; @@ -120,10 +117,8 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); - const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, - 2, 3, 16, 16, 3, 4, 16, 16}; - const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, - 6, 7, 16, 16, 7, 8, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, 6, 7, 16, 16, 7, 8, 16, 16}; const uint8_t* src_ptr = src; //! 
here we use uint32_t for calc uint32_t* outptr = reinterpret_cast(dst); @@ -135,8 +130,8 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, const uint8_t* k0 = filter; - uint8x16_t _k0 = vreinterpretq_u8_u32( - vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _k0 = + vreinterpretq_u8_u32(vdupq_n_u32(*reinterpret_cast(k0))); uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; uint8x16_t _k = vqtbl1q_u8_common(_k0, _idx_k); uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; @@ -172,13 +167,13 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); uint8x16_t _r21 = vextq_u8(_r20, _r21_, 1); - int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), - vreinterpretq_s16_u8(_r10)); + int16x8x2_t r_0 = + vzipq_s16(vreinterpretq_s16_u8(_r00), vreinterpretq_s16_u8(_r10)); uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); - int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), - vreinterpretq_s16_u8(_r11)); + int16x8x2_t r_1 = + vzipq_s16(vreinterpretq_s16_u8(_r01), vreinterpretq_s16_u8(_r11)); uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -187,13 +182,11 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, SUB_ZP(_sum01.val[0], _r2); SUB_ZP(_sum01.val[1], _r3); - r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), - vreinterpretq_s16_u8(_r20)); + r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), vreinterpretq_s16_u8(_r20)); _r0 = vreinterpretq_u8_s8(r_0.val[0]); _r2 = vreinterpretq_u8_s8(r_0.val[1]); - r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), - vreinterpretq_s16_u8(_r21)); + r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), vreinterpretq_s16_u8(_r21)); _r1 = vreinterpretq_u8_s8(r_1.val[0]); _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -307,13 +300,13 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _r01 = vextq_u8(_r00, _r01_, 1); uint8x16_t _r11 = vextq_u8(_r10, _r11_, 1); - int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), - vreinterpretq_s16_u8(_r10)); + int16x8x2_t r_0 = + vzipq_s16(vreinterpretq_s16_u8(_r00), vreinterpretq_s16_u8(_r10)); uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); - int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), - vreinterpretq_s16_u8(_r11)); + int16x8x2_t r_1 = + vzipq_s16(vreinterpretq_s16_u8(_r01), vreinterpretq_s16_u8(_r11)); uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -385,10 +378,10 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_3x3( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW; @@ -397,10 +390,8 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); - const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 
3, 16, - 2, 3, 4, 16, 3, 4, 5, 16}; - const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; const uint8_t* src_ptr = src; @@ -589,58 +580,50 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, #undef CALC_1 #undef CALC_2 -#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _elem2 = vdotq2_u32(_filter_zp, _elem); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ - _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, 
vdotq2_u32(_src_zp, _k##_k01_idx)); \ - _elem2 = vdotq2_u32(_filter_zp, _elem); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_5x5( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW; @@ -676,8 +659,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _k = vld1q_u8(k0 + 9); //! 
filter row 1 - uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, - 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); _idx = {11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16}; uint8x16_t _k4 = vqtbl1q_u8_common(_k, _idx); @@ -909,10 +891,10 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_7x7( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW; @@ -922,8 +904,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; @@ -947,8 +928,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _k = vld1q_u8(k0 + 33); //! filter row 1 - uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, - 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; uint8x16_t _k456 = vqtbl1q_u8_common(_k, _idx); @@ -1222,7 +1202,6 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, } // anonymous namespace - size_t deconv::get_workspace_in_bytes_stride1_quint8_dot( const NCBKernSizeParam& param) { return get_bundle(param).total_size_in_bytes(); @@ -1230,17 +1209,17 @@ size_t deconv::get_workspace_in_bytes_stride1_quint8_dot( bool deconv::can_stride1_quint8_dot(const NCBKernSizeParam& param) { auto&& fm = param.filter_meta; - auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, - PH = fm.padding[0], PW = fm.padding[1]; - bool avaiable = fm.format == param::Convolution::Format::NCHW && - !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 1 && fm.stride[1] == 1 && FH == FW && - (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, PH = fm.padding[0], + PW = fm.padding[1]; + bool avaiable = fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1 && + FH == FW && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; - avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || - param.grad_type.enumv() == DTypeEnum::Int32); + avaiable &= + (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || + param.grad_type.enumv() == DTypeEnum::Int32); /** * \note In the kernel, we use int32_t to 
calc the value, in order @@ -1265,15 +1244,14 @@ void deconv::stride1_quint8_dot(const NCBKernParam& param) { int padding_h = FH - PH - 1, padding_w = FW - PW - 1; get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); - uint8_t filter_zp = - param.filter_type.param().zero_point; + uint8_t filter_zp = param.filter_type.param().zero_point; uint8_t src_zp = param.diff_type.param().zero_point; int32_t src_filter_zp = static_cast(filter_zp) * static_cast(src_zp) * OC * FH * FH; - using Func = std::function; + using Func = std::function; Func deconv = nullptr, deconv_last_oc = nullptr; if (FH == 2) { deconv = deconv_direct_2x2; @@ -1293,11 +1271,10 @@ void deconv::stride1_quint8_dot(const NCBKernParam& param) { bool need_src_copy_var = need_src_copy(param); bool need_dst_copy_var = need_dst_copy(param); - uint8_t* base_src_ptr = reinterpret_cast( - const_cast(param.diff())); + uint8_t* base_src_ptr = + reinterpret_cast(const_cast(param.diff())); int32_t* base_dst_ptr = reinterpret_cast(param.grad()); - const uint8_t* fptr = - reinterpret_cast(param.filter()); + const uint8_t* fptr = reinterpret_cast(param.filter()); for (size_t n = 0; n < N; ++n) { int32_t* dptr_copied = static_cast(bundle.get(1)); @@ -1320,10 +1297,10 @@ void deconv::stride1_quint8_dot(const NCBKernParam& param) { if (need_src_copy_var) { // copy sptr_ori to sptr_copied std::memset(sptr_copied, src_zp, sizeof(uint8_t) * IH2 * IW2); - copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w, - sptr_ori + oc * IH * IW, IH, - IW * sizeof(uint8_t), IW2 * sizeof(uint8_t), - IW * sizeof(uint8_t)); + copy_plane_in_bytes( + sptr_copied + padding_h * IW2 + padding_w, + sptr_ori + oc * IH * IW, IH, IW * sizeof(uint8_t), + IW2 * sizeof(uint8_t), IW * sizeof(uint8_t)); sptr = sptr_copied; } else { sptr = sptr_ori + oc * IH * IW; @@ -1333,11 +1310,12 @@ void deconv::stride1_quint8_dot(const NCBKernParam& param) { const uint8_t* filter = fptr + oc * IC * FH * FW; for (size_t ic = 0; ic < IC; ic++) { if (oc != OC - 1) { - deconv(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, - src_zp, filter_zp, src_filter_zp); + deconv(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, src_zp, + filter_zp, src_filter_zp); } else { - deconv_last_oc(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, - IC, src_zp, filter_zp, src_filter_zp); + deconv_last_oc( + sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, src_zp, + filter_zp, src_filter_zp); } dst_ptr += OH * OW_real; filter += FH * FH; @@ -1345,10 +1323,10 @@ void deconv::stride1_quint8_dot(const NCBKernParam& param) { } if (need_dst_copy_var) { for (size_t ic = 0; ic < IC; ++ic) { - copy_plane_in_bytes(dptr_ori + ic * OH * OW, - dptr + ic * OH * OW2, OH, - OW * sizeof(int32_t), OW * sizeof(int32_t), - OW2 * sizeof(int32_t)); + copy_plane_in_bytes( + dptr_ori + ic * OH * OW, dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); } } } diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp index 77da3c20..417938cc 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp @@ -11,8 +11,8 @@ #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" #if MGB_ENABLE_DOT -#include "src/common/utils.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/utils.h" using namespace megdnn; using namespace arm_common; @@ -21,7 +21,7 @@ using namespace deconv; 
namespace { #define SHIFT_BITS 30 -#define SHIFT (1 << SHIFT_BITS) +#define SHIFT (1 << SHIFT_BITS) bool need_dst_copy(const NCBKernSizeParam& param) { if (param.osz[1] % 4 != 0) { @@ -31,9 +31,9 @@ bool need_dst_copy(const NCBKernSizeParam& param) { return false; } -void get_rectified_size(size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, - size_t FW, size_t PH, size_t PW, size_t& IH2, - size_t& IW2, size_t& OW2) { +void get_rectified_size( + size_t IH, size_t IW, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, + size_t PW, size_t& IH2, size_t& IW2, size_t& OW2) { MEGDNN_MARK_USED_VAR(OH); MEGDNN_MARK_USED_VAR(IW); //! OW should be a multiple of 4 @@ -69,8 +69,7 @@ inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) { return r; } -inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t, - uint8x16_t idx) { +inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t, uint8x16_t idx) { uint8x8x2_t _temp; _temp.val[0] = vget_low_u8(t); _temp.val[1] = vget_high_u8(t); @@ -80,9 +79,8 @@ inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t, return r; } -#define CALC_DST(_sum) \ - _sum = vreinterpretq_u32_s32( \ - vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)); +#define CALC_DST(_sum) \ + _sum = vreinterpretq_u32_s32(vaddq_s32(vreinterpretq_s32_u32(_sum), _shift_zp)); #define CALC_0(_k_idx, _c_idx) \ _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ @@ -96,23 +94,21 @@ inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t, _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ +#define CALC_2(_k1_idx, _k2_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k1_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k1_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k2_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k2_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_2x2( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW / 2; @@ -121,10 +117,8 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); - const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, - 2, 3, 16, 16, 3, 4, 16, 16}; - const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, - 6, 
7, 16, 16, 7, 8, 16, 16}; + const uint8x16_t _idx0 = {0, 1, 16, 16, 1, 2, 16, 16, 2, 3, 16, 16, 3, 4, 16, 16}; + const uint8x16_t _idx1 = {4, 5, 16, 16, 5, 6, 16, 16, 6, 7, 16, 16, 7, 8, 16, 16}; uint8x16_t _idx_r_0, _idx_r_1; if (even) { _idx_r_0 = {0, 16, 1, 16, 2, 16, 3, 16, 4, 16, 5, 16, 6, 16, 7, 16}; @@ -144,8 +138,8 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, const uint8_t* k0 = filter; - uint8x16_t _k0 = vreinterpretq_u8_u32( - vdupq_n_u32(*reinterpret_cast(k0))); + uint8x16_t _k0 = + vreinterpretq_u8_u32(vdupq_n_u32(*reinterpret_cast(k0))); uint8x16_t _idx_k = {3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0}; uint8x16_t _k = vqtbl1q_u8_common(_k0, _idx_k); uint8x16_t _idx = {0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16, 0, 1, 16, 16}; @@ -180,13 +174,13 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _r20 = vqtbx1q_u8_common(_src_zp, _r2_ori, _idx_r_0); uint8x16_t _r21 = vqtbx1q_u8_common(_src_zp, _r2_ori, _idx_r_1); - int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), - vreinterpretq_s16_u8(_r10)); + int16x8x2_t r_0 = + vzipq_s16(vreinterpretq_s16_u8(_r00), vreinterpretq_s16_u8(_r10)); uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); - int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), - vreinterpretq_s16_u8(_r11)); + int16x8x2_t r_1 = + vzipq_s16(vreinterpretq_s16_u8(_r01), vreinterpretq_s16_u8(_r11)); uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -195,13 +189,11 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, SUB_ZP(_sum01.val[0], _r2); SUB_ZP(_sum01.val[1], _r3); - r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), - vreinterpretq_s16_u8(_r20)); + r_0 = vzipq_s16(vreinterpretq_s16_u8(_r10), vreinterpretq_s16_u8(_r20)); _r0 = vreinterpretq_u8_s8(r_0.val[0]); _r2 = vreinterpretq_u8_s8(r_0.val[1]); - r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), - vreinterpretq_s16_u8(_r21)); + r_1 = vzipq_s16(vreinterpretq_s16_u8(_r11), vreinterpretq_s16_u8(_r21)); _r1 = vreinterpretq_u8_s8(r_1.val[0]); _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -320,13 +312,13 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _r10 = vqtbx1q_u8_common(_src_zp, _r1_ori, _idx_r_0); uint8x16_t _r11 = vqtbx1q_u8_common(_src_zp, _r1_ori, _idx_r_1); - int16x8x2_t r_0 = vzipq_s16(vreinterpretq_s16_u8(_r00), - vreinterpretq_s16_u8(_r10)); + int16x8x2_t r_0 = + vzipq_s16(vreinterpretq_s16_u8(_r00), vreinterpretq_s16_u8(_r10)); uint8x16_t _r0 = vreinterpretq_u8_s8(r_0.val[0]); uint8x16_t _r2 = vreinterpretq_u8_s8(r_0.val[1]); - int16x8x2_t r_1 = vzipq_s16(vreinterpretq_s16_u8(_r01), - vreinterpretq_s16_u8(_r11)); + int16x8x2_t r_1 = + vzipq_s16(vreinterpretq_s16_u8(_r01), vreinterpretq_s16_u8(_r11)); uint8x16_t _r1 = vreinterpretq_u8_s8(r_1.val[0]); uint8x16_t _r3 = vreinterpretq_u8_s8(r_1.val[1]); @@ -402,10 +394,10 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_3x3( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { 
MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW / 2; @@ -414,10 +406,8 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _filter_zp = vdupq_n_u8(filter_zp); int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); - const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, - 2, 3, 4, 16, 3, 4, 5, 16}; - const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx0 = {0, 1, 2, 16, 1, 2, 3, 16, 2, 3, 4, 16, 3, 4, 5, 16}; + const uint8x16_t _idx1 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx2 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; uint8x16_t _idx_r_0; @@ -626,58 +616,50 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, #undef CALC_1 #undef CALC_2 -#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ +#define CALC_0(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ +#define CALC_1(_k00_idx, _k01_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k00_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k01_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); -#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ - _elem2 = vdotq2_u32(_filter_zp, _elem); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, 
_k##_k10_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ - _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ - _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ - _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ - _sum0##_c_idx = \ - vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ - _elem2 = vdotq2_u32(_filter_zp, _elem); \ - _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ - _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ - _sum1##_c_idx = \ - vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ +#define CALC_2(_k00_idx, _k01_idx, _k10_idx, _k11_idx, _c_idx) \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##0); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k00_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k00_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k10_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k10_idx)); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); \ + _elem = vqtbl1q_u8_common(_tmp, _idx##_c_idx##1); \ + _sum0##_c_idx = vdotq_u32(_sum0##_c_idx, _k##_k01_idx, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, vdotq2_u32(_src_zp, _k##_k01_idx)); \ + _elem2 = vdotq2_u32(_filter_zp, _elem); \ + _sum0##_c_idx = vsubq_u32(_sum0##_c_idx, _elem2); \ + _sum1##_c_idx = vdotq_u32(_sum1##_c_idx, _k##_k11_idx, _elem); \ + _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_src_zp, _k##_k11_idx)); \ _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_5x5( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW / 2; @@ -719,8 +701,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _k = vld1q_u8(k0 + 9); //! 
filter row 1 - uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, - 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); _idx = {11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16, 11, 16, 16, 16}; uint8x16_t _k4 = vqtbl1q_u8_common(_k, _idx); @@ -974,10 +955,10 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, template MEGDNN_ATTRIBUTE_TARGET("dotprod") -void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, - size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, - uint8_t src_zp, uint8_t filter_zp, - int32_t src_filter_zp) { +void deconv_direct_7x7( + const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, + int32_t src_filter_zp) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(IC); const size_t tail_step = IW - OW / 2; @@ -987,8 +968,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, int32x4_t _shift_zp = vdupq_n_s32(src_filter_zp - SHIFT); const uint8x16_t _idx00 = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6}; - const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, - 6, 7, 8, 16, 7, 8, 9, 16}; + const uint8x16_t _idx01 = {4, 5, 6, 16, 5, 6, 7, 16, 6, 7, 8, 16, 7, 8, 9, 16}; const uint8x16_t _idx10 = {4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10}; const uint8x16_t _idx11 = {8, 9, 10, 16, 9, 10, 11, 16, 10, 11, 12, 16, 11, 12, 13, 16}; @@ -1018,8 +998,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, uint8x16_t _k = vld1q_u8(k0 + 33); //! filter row 1 - uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, - 15, 14, 13, 12, 15, 14, 13, 12}; + uint8x16_t _idx = {15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12, 15, 14, 13, 12}; uint8x16_t _k123 = vqtbl1q_u8_common(_k, _idx); _idx = {11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16, 11, 10, 9, 16}; uint8x16_t _k456 = vqtbl1q_u8_common(_k, _idx); @@ -1330,17 +1309,17 @@ size_t deconv::get_workspace_in_bytes_stride2_quint8_dot( bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) { auto&& fm = param.filter_meta; - auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, - PH = fm.padding[0], PW = fm.padding[1]; - bool avaiable = fm.format == param::Convolution::Format::NCHW && - !fm.should_flip && fm.spatial_ndim == 2 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.stride[0] == 2 && fm.stride[1] == 2 && FH == FW && - (FH == 2 || FH == 3 || FH == 5 || FH == 7) && + auto FH = fm.spatial[0], FW = fm.spatial[1], OC = fm.ocpg, PH = fm.padding[0], + PW = fm.padding[1]; + bool avaiable = fm.format == param::Convolution::Format::NCHW && !fm.should_flip && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && + FH == FW && (FH == 2 || FH == 3 || FH == 5 || FH == 7) && FH >= PH + 1 && FW >= PW + 1; - avaiable &= (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || - param.grad_type.enumv() == DTypeEnum::Int32); + avaiable &= + (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm || + param.grad_type.enumv() == DTypeEnum::Int32); /** * \note In the kernel, we use uint32_t to calc the value, in order @@ -1352,8 +1331,8 @@ bool deconv::can_stride2_quint8_dot(const NCBKernSizeParam& param) { * be possible(7*7*OC*2^8*2^8 > SHIFT => OC > 334). 
*/ avaiable &= (7 * 7 * OC < (1 << (SHIFT_BITS - 8 - 8))); - return avaiable && ((FH == 2 && OC <= 4) || - ((FH == 3 || FH == 5 || FH == 7) && OC <= 8)); + return avaiable && + ((FH == 2 && OC <= 4) || ((FH == 3 || FH == 5 || FH == 7) && OC <= 8)); } void deconv::stride2_quint8_dot(const NCBKernParam& param) { @@ -1366,15 +1345,14 @@ void deconv::stride2_quint8_dot(const NCBKernParam& param) { int padding_h = FH - PH - 1, padding_w = FW - PW - 1; get_rectified_size(IH, IW, OH, OW, FH, FW, PH, PW, IH2, IW2, OW2); - uint8_t filter_zp = - param.filter_type.param().zero_point; + uint8_t filter_zp = param.filter_type.param().zero_point; uint8_t src_zp = param.diff_type.param().zero_point; int32_t src_filter_zp = static_cast(filter_zp) * static_cast(src_zp) * OC * FH * FH; - using Func = std::function; + using Func = std::function; Func deconv = nullptr, deconv_last_oc = nullptr; switch (FH) { @@ -1400,11 +1378,10 @@ void deconv::stride2_quint8_dot(const NCBKernParam& param) { } bool need_dst_copy_var = need_dst_copy(param); - uint8_t* base_src_ptr = reinterpret_cast( - const_cast(param.diff())); + uint8_t* base_src_ptr = + reinterpret_cast(const_cast(param.diff())); int32_t* base_dst_ptr = reinterpret_cast(param.grad()); - const uint8_t* fptr = - reinterpret_cast(param.filter()); + const uint8_t* fptr = reinterpret_cast(param.filter()); for (size_t n = 0; n < N; ++n) { int32_t* dptr_copied = static_cast(bundle.get(1)); @@ -1426,21 +1403,22 @@ void deconv::stride2_quint8_dot(const NCBKernParam& param) { rep(oc, OC) { // copy sptr_ori to sptr_copied std::memset(sptr_copied, src_zp, sizeof(uint8_t) * IH2 * IW2); - copy_plane_in_bytes(sptr_copied + padding_h * IW2 + padding_w / 2, - sptr_ori + oc * IH * IW, IH, - IW * sizeof(uint8_t), 2 * IW2 * sizeof(uint8_t), - IW * sizeof(uint8_t)); + copy_plane_in_bytes( + sptr_copied + padding_h * IW2 + padding_w / 2, + sptr_ori + oc * IH * IW, IH, IW * sizeof(uint8_t), + 2 * IW2 * sizeof(uint8_t), IW * sizeof(uint8_t)); sptr = sptr_copied; int32_t* dst_ptr = dptr; const uint8_t* filter = fptr + oc * IC * FH * FW; for (size_t ic = 0; ic < IC; ic++) { if (oc != OC - 1) { - deconv(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, - src_zp, filter_zp, src_filter_zp); + deconv(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, src_zp, + filter_zp, src_filter_zp); } else { - deconv_last_oc(sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, - IC, src_zp, filter_zp, src_filter_zp); + deconv_last_oc( + sptr, filter, dst_ptr, IH2, IW2, OH, OW_real, IC, src_zp, + filter_zp, src_filter_zp); } dst_ptr += OH * OW_real; filter += FH * FH; @@ -1448,10 +1426,10 @@ void deconv::stride2_quint8_dot(const NCBKernParam& param) { } if (need_dst_copy_var) { for (size_t ic = 0; ic < IC; ++ic) { - copy_plane_in_bytes(dptr_ori + ic * OH * OW, - dptr + ic * OH * OW2, OH, - OW * sizeof(int32_t), OW * sizeof(int32_t), - OW2 * sizeof(int32_t)); + copy_plane_in_bytes( + dptr_ori + ic * OH * OW, dptr + ic * OH * OW2, OH, + OW * sizeof(int32_t), OW * sizeof(int32_t), + OW2 * sizeof(int32_t)); } } } diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h index 3822c14d..277d7798 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h @@ -26,7 +26,7 @@ void stride2_quint8_dot(const NCBKernParam& param); size_t get_workspace_in_bytes_stride2_quint8_dot(const NCBKernSizeParam& param); -} // namespace convolution +} // namespace deconv 
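// --- Illustrative sketch, not part of the patch ---
// The overflow note in can_stride2_quint8_dot() above bounds the worst-case
// accumulation FH * FH * OC * 2^8 * 2^8 by 1 << SHIFT_BITS. Taking the quoted
// "OC > 334 overflows for 7x7 filters" figure at face value implies
// SHIFT_BITS == 30 (an assumption here, not shown in this hunk). A standalone
// check with that hypothetical constant reproduces the same predicate:
constexpr unsigned kShiftBitsGuess = 30;  // hypothetical; inferred from the OC > 334 remark
constexpr bool accum_fits(unsigned FH, unsigned OC) {
    // same condition as the kernel check: FH * FH * OC < 2^(SHIFT_BITS - 8 - 8)
    return FH * FH * OC < (1u << (kShiftBitsGuess - 8 - 8));
}
static_assert(accum_fits(7, 334), "OC = 334 still fits for 7x7 filters");
static_assert(!accum_fits(7, 335), "OC = 335 would overflow, matching the note");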
} // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/cvt_color/opr_impl.cpp b/dnn/src/arm_common/cvt_color/opr_impl.cpp index 8ec548bb..56ce613b 100644 --- a/dnn/src/arm_common/cvt_color/opr_impl.cpp +++ b/dnn/src/arm_common/cvt_color/opr_impl.cpp @@ -58,15 +58,15 @@ * * --------------------------------------------------------------------------- */ -#include #include "src/arm_common/cvt_color/opr_impl.h" +#include +#include "midout.h" #include "src/arm_common/handle.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/cv/common.h" #include "src/common/cv/cvt_color.h" #include "src/common/cv/helper.h" #include "src/common/utils.h" -#include "midout.h" MIDOUT_DECL(megdnn_arm_cvtcolor) MIDOUT_DECL(megdnn_arm_cvtcolor_cases) @@ -147,23 +147,17 @@ void cvt_yuv_transform(const Mat8u& src, Mat8u& dst) { for (; c <= (int)(width - 16); c += 16, index0 += 48, index1 += 48) { int16x8x2_t v_vu_s16; if (is_planar) { - v_vu_s16.val[0] = - vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pV + c / 2))); - v_vu_s16.val[1] = - vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pU + c / 2))); + v_vu_s16.val[0] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pV + c / 2))); + v_vu_s16.val[1] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pU + c / 2))); } else { if (is_uv) { v_vu = vld2_u8(pU + c); - v_vu_s16.val[0] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); - v_vu_s16.val[1] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + v_vu_s16.val[0] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + v_vu_s16.val[1] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); } else { v_vu = vld2_u8(pV + c); - v_vu_s16.val[0] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); - v_vu_s16.val[1] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + v_vu_s16.val[0] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + v_vu_s16.val[1] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); } } @@ -180,11 +174,9 @@ void cvt_yuv_transform(const Mat8u& src, Mat8u& dst) { v_RV1 = vshrq_n_s32(vmull_s16(v_v0, v_359), 8); v_RV3 = vshrq_n_s32(vmull_s16(v_v1, v_359), 8); v_GVU1 = vshrq_n_s32( - vaddq_s32(vmull_s16(v_u0, v_88), vmull_s16(v_v0, v_183)), - 8); + vaddq_s32(vmull_s16(v_u0, v_88), vmull_s16(v_v0, v_183)), 8); v_GVU3 = vshrq_n_s32( - vaddq_s32(vmull_s16(v_u1, v_88), vmull_s16(v_v1, v_183)), - 8); + vaddq_s32(vmull_s16(v_u1, v_88), vmull_s16(v_v1, v_183)), 8); v_BU1 = vshrq_n_s32(vmull_s16(v_u0, v_454), 8); v_BU3 = vshrq_n_s32(vmull_s16(v_u1, v_454), 8); @@ -239,38 +231,38 @@ void cvt_yuv_transform(const Mat8u& src, Mat8u& dst) { if (rgb) { v_RGB.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); v_RGB.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_RGB.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), 
vmovn_s32(v_B.val[3])))); vst3q_u8((dst0 + c * 3), v_RGB); } else { v_BGR.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), vmovn_s32(v_B.val[3])))); v_BGR.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_BGR.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); vst3q_u8((dst0 + c * 3), v_BGR); } @@ -302,38 +294,38 @@ void cvt_yuv_transform(const Mat8u& src, Mat8u& dst) { if (rgb) { v_RGB.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); v_RGB.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_RGB.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), vmovn_s32(v_B.val[3])))); vst3q_u8((dst1 + c * 3), v_RGB); } else { v_BGR.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), vmovn_s32(v_B.val[3])))); v_BGR.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_BGR.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); vst3q_u8((dst1 + c * 3), v_BGR); } } @@ -450,7 +442,7 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { out[index++] = R; \ } -#define YG 18997 /* round(1.164 * 64 * 256 * 256 / 257) */ +#define YG 18997 /* round(1.164 * 64 * 256 * 256 / 257) */ 
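// --- Illustrative note, not part of the patch ---
// Quick arithmetic check of the fixed-point constants above/below:
//   YG  = round(1.164 * 64 * 256 * 256 / 257) = round(18996.77) = 18997
//   YGB = 1.164 * 64 * (-16) + 64 / 2 = -1191.94 + 32 = -1159.94, i.e. -1160 after rounding
// These fold the BT.601 luma gain (1.164) and the -16 luma offset into the
// integer scale consumed by the CALC macro further down in this function.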
#define YGB -1160 /* 1.164 * 64 * -16 + 64 / 2 */ // U and V contributions to R,G,B. @@ -503,24 +495,18 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { for (; j <= (int)(width - 16); j += 16, index += 48, index1 += 48) { int16x8x2_t v_vu_s16; if (is_planar) { - v_vu_s16.val[0] = - vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pV + jV))); - v_vu_s16.val[1] = - vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pU + jV))); + v_vu_s16.val[0] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pV + jV))); + v_vu_s16.val[1] = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(pU + jV))); jV += 8; } else { if (is_uv) { v_vu = vld2_u8(pU + j); - v_vu_s16.val[0] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); - v_vu_s16.val[1] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + v_vu_s16.val[0] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + v_vu_s16.val[1] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); } else { v_vu = vld2_u8(pV + j); - v_vu_s16.val[0] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); - v_vu_s16.val[1] = - vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); + v_vu_s16.val[0] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[0])); + v_vu_s16.val[1] = vreinterpretq_s16_u16(vmovl_u8(v_vu.val[1])); } } @@ -560,16 +546,17 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { int32x4_t v_y3 = vmovl_s16(vget_high_s16(v_y_2quarter)); //! calc -#define CALC(_idx) \ - v_Y1 = vshrq_n_s32(vmulq_s32(vmulq_s32(v_y##_idx, v_0101), v_YG), 16); \ - v_B.val[_idx] = vshrq_n_s32( \ - vsubq_s32(vaddq_s32(v_Y1, v_BB), vmulq_s32(v_u##_idx, v_UB)), 6); \ - v_G.val[_idx] = \ - vshrq_n_s32(vsubq_s32(vaddq_s32(v_Y1, v_BG), \ - vaddq_s32(vmulq_s32(v_u##_idx, v_UG), \ - vmulq_s32(v_v##_idx, v_VG))), \ - 6); \ - v_R.val[_idx] = vshrq_n_s32( \ +#define CALC(_idx) \ + v_Y1 = vshrq_n_s32(vmulq_s32(vmulq_s32(v_y##_idx, v_0101), v_YG), 16); \ + v_B.val[_idx] = vshrq_n_s32( \ + vsubq_s32(vaddq_s32(v_Y1, v_BB), vmulq_s32(v_u##_idx, v_UB)), 6); \ + v_G.val[_idx] = vshrq_n_s32( \ + vsubq_s32( \ + vaddq_s32(v_Y1, v_BG), \ + vaddq_s32( \ + vmulq_s32(v_u##_idx, v_UG), vmulq_s32(v_v##_idx, v_VG))), \ + 6); \ + v_R.val[_idx] = vshrq_n_s32( \ vsubq_s32(vaddq_s32(v_Y1, v_BR), vmulq_s32(v_v##_idx, v_VR)), 6); CALC(0); @@ -579,37 +566,37 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { if (rgb) { v_RGB.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); v_RGB.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_RGB.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), vmovn_s32(v_B.val[3])))); vst3q_u8((out + index), v_RGB); } else { v_BGR.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + 
vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), vmovn_s32(v_B.val[3])))); v_BGR.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_BGR.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); vst3q_u8((out + index), v_BGR); } @@ -630,37 +617,37 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { if (rgb) { v_RGB.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); v_RGB.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_RGB.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), vmovn_s32(v_B.val[3])))); vst3q_u8((out1 + index1), v_RGB); } else { v_BGR.val[0] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[0]), - vmovn_s32(v_B.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_B.val[2]), - vmovn_s32(v_B.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[0]), vmovn_s32(v_B.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_B.val[2]), vmovn_s32(v_B.val[3])))); v_BGR.val[1] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[0]), - vmovn_s32(v_G.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_G.val[2]), - vmovn_s32(v_G.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[0]), vmovn_s32(v_G.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_G.val[2]), vmovn_s32(v_G.val[3])))); v_BGR.val[2] = vcombine_u8( - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[0]), - vmovn_s32(v_R.val[1]))), - vqmovun_s16(vcombine_s16(vmovn_s32(v_R.val[2]), - vmovn_s32(v_R.val[3])))); + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[0]), vmovn_s32(v_R.val[1]))), + vqmovun_s16(vcombine_s16( + vmovn_s32(v_R.val[2]), vmovn_s32(v_R.val[3])))); vst3q_u8((out1 + index1), v_BGR); } #undef CALC @@ -748,7 +735,7 @@ void cvt_BT601_yuv_transform(const Mat8u& src, Mat8u& dst) { } // namespace -template +template void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) { static const float coef[] = {0.299f, 0.587f, 0.114f}; // load coef into neon types @@ -759,12 +746,13 @@ void cvt_rgb2gray_32f_neon(const Mat32f& src, Mat32f& dst) { const float32x4_t v_cr(vdupq_n_f32(coef_c0)), v_cg(vdupq_n_f32(coef_c1)), v_cb(vdupq_n_f32(coef_c2)); -#define 
EXPAND(offset) \ - v_src = vld3q_f32(psrc + offset * 3); \ - vst1q_f32(pdst + offset, \ - vmlaq_f32(vmlaq_f32(vmulq_f32(v_src.val[0], v_cr), v_src.val[1], \ - v_cg), \ - v_src.val[2], v_cb)); +#define EXPAND(offset) \ + v_src = vld3q_f32(psrc + offset * 3); \ + vst1q_f32( \ + pdst + offset, \ + vmlaq_f32( \ + vmlaq_f32(vmulq_f32(v_src.val[0], v_cr), v_src.val[1], v_cg), \ + v_src.val[2], v_cb)); for (size_t r = 0; r < src.rows(); ++r) { const float* psrc = src.ptr(r); float* pdst = dst.ptr(r); @@ -845,44 +833,43 @@ void cvt_rgb2yuv_8u_neon(const Mat8u& src, Mat8u& dst) { v_src0.val[1] = vget_low_s16(v_src16.val[1]); v_src0.val[2] = vget_low_s16(v_src16.val[2]); - int32x4_t v_Y0 = vmlal_s16(vmlal_s16(vmull_s16(v_src0.val[0], v_c0), - v_src0.val[1], v_c1), - v_src0.val[2], v_c2); + int32x4_t v_Y0 = vmlal_s16( + vmlal_s16(vmull_s16(v_src0.val[0], v_c0), v_src0.val[1], v_c1), + v_src0.val[2], v_c2); v_Y0 = vshrq_n_s32(vaddq_s32(v_Y0, v_delta2), yuv_shift); - int32x4_t v_Cr0 = vmlaq_s32( - v_delta, vsubq_s32(vmovl_s16(v_src0.val[0]), v_Y0), v_c3); + int32x4_t v_Cr0 = + vmlaq_s32(v_delta, vsubq_s32(vmovl_s16(v_src0.val[0]), v_Y0), v_c3); v_Cr0 = vshrq_n_s32(vaddq_s32(v_Cr0, v_delta2), yuv_shift); - int32x4_t v_Cb0 = vmlaq_s32( - v_delta, vsubq_s32(vmovl_s16(v_src0.val[2]), v_Y0), v_c4); + int32x4_t v_Cb0 = + vmlaq_s32(v_delta, vsubq_s32(vmovl_s16(v_src0.val[2]), v_Y0), v_c4); v_Cb0 = vshrq_n_s32(vaddq_s32(v_Cb0, v_delta2), yuv_shift); v_src0.val[0] = vget_high_s16(v_src16.val[0]); v_src0.val[1] = vget_high_s16(v_src16.val[1]); v_src0.val[2] = vget_high_s16(v_src16.val[2]); - int32x4_t v_Y1 = vmlal_s16(vmlal_s16(vmull_s16(v_src0.val[0], v_c0), - v_src0.val[1], v_c1), - v_src0.val[2], v_c2); + int32x4_t v_Y1 = vmlal_s16( + vmlal_s16(vmull_s16(v_src0.val[0], v_c0), v_src0.val[1], v_c1), + v_src0.val[2], v_c2); v_Y1 = vshrq_n_s32(vaddq_s32(v_Y1, v_delta2), yuv_shift); - int32x4_t v_Cr1 = vmlaq_s32( - v_delta, vsubq_s32(vmovl_s16(v_src0.val[0]), v_Y1), v_c3); + int32x4_t v_Cr1 = + vmlaq_s32(v_delta, vsubq_s32(vmovl_s16(v_src0.val[0]), v_Y1), v_c3); v_Cr1 = vshrq_n_s32(vaddq_s32(v_Cr1, v_delta2), yuv_shift); - int32x4_t v_Cb1 = vmlaq_s32( - v_delta, vsubq_s32(vmovl_s16(v_src0.val[2]), v_Y1), v_c4); + int32x4_t v_Cb1 = + vmlaq_s32(v_delta, vsubq_s32(vmovl_s16(v_src0.val[2]), v_Y1), v_c4); v_Cb1 = vshrq_n_s32(vaddq_s32(v_Cb1, v_delta2), yuv_shift); - v_dst.val[0] = vqmovun_s16( - vcombine_s16(vqmovn_s32(v_Y0), vqmovn_s32(v_Y1))); - v_dst.val[1] = vqmovun_s16( - vcombine_s16(vqmovn_s32(v_Cr0), vqmovn_s32(v_Cr1))); - v_dst.val[2] = vqmovun_s16( - vcombine_s16(vqmovn_s32(v_Cb0), vqmovn_s32(v_Cb1))); + v_dst.val[0] = + vqmovun_s16(vcombine_s16(vqmovn_s32(v_Y0), vqmovn_s32(v_Y1))); + v_dst.val[1] = + vqmovun_s16(vcombine_s16(vqmovn_s32(v_Cr0), vqmovn_s32(v_Cr1))); + v_dst.val[2] = + vqmovun_s16(vcombine_s16(vqmovn_s32(v_Cb0), vqmovn_s32(v_Cb1))); vst3_u8(pdst, v_dst); } for (; psrc < pend; psrc += 3, pdst += 3) { - int Y = descale(psrc[0] * C0 + psrc[1] * C1 + psrc[2] * C2, - yuv_shift); + int Y = descale(psrc[0] * C0 + psrc[1] * C1 + psrc[2] * C2, yuv_shift); int Cr = descale((psrc[0] - Y) * C3 + delta, yuv_shift); int Cb = descale((psrc[2] - Y) * C4 + delta, yuv_shift); pdst[0] = saturate_cast(Y); @@ -912,13 +899,13 @@ void cvt_rgb2yuv_32f_neon(const Mat32f& src, Mat32f& dst) { for (; psrc <= pend - 4 * 3; psrc += 4 * 3, pdst += 4 * 3) { float32x4x3_t v_src = vld3q_f32(psrc), v_dst; - v_dst.val[0] = vmlaq_f32(vmlaq_f32(vmulq_f32(v_src.val[0], v_c0), - v_src.val[1], v_c1), - v_src.val[2], v_c2); - 
v_dst.val[1] = vmlaq_f32( - v_delta, vsubq_f32(v_src.val[0], v_dst.val[0]), v_c3); - v_dst.val[2] = vmlaq_f32( - v_delta, vsubq_f32(v_src.val[2], v_dst.val[0]), v_c4); + v_dst.val[0] = vmlaq_f32( + vmlaq_f32(vmulq_f32(v_src.val[0], v_c0), v_src.val[1], v_c1), + v_src.val[2], v_c2); + v_dst.val[1] = + vmlaq_f32(v_delta, vsubq_f32(v_src.val[0], v_dst.val[0]), v_c3); + v_dst.val[2] = + vmlaq_f32(v_delta, vsubq_f32(v_src.val[2], v_dst.val[0]), v_c4); vst3q_f32(pdst, v_dst); } @@ -963,39 +950,30 @@ void cvt_yuv2rgb_8u_neon(const Mat8u& src, Mat8u& dst) { v_Cb = vget_low_s16(v_src16.val[2]); int32x4_t v_b0 = vmulq_s32(v_c3, vsubl_s16(v_Cb, v_delta)); - v_b0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_b0, v_delta2), yuv_shift), - v_Y); - int32x4_t v_g0 = - vmlaq_s32(vmulq_s32(vsubl_s16(v_Cr, v_delta), v_c1), - vsubl_s16(v_Cb, v_delta), v_c2); - v_g0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_g0, v_delta2), yuv_shift), - v_Y); + v_b0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_b0, v_delta2), yuv_shift), v_Y); + int32x4_t v_g0 = vmlaq_s32( + vmulq_s32(vsubl_s16(v_Cr, v_delta), v_c1), vsubl_s16(v_Cb, v_delta), + v_c2); + v_g0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_g0, v_delta2), yuv_shift), v_Y); int32x4_t v_r0 = vmulq_s32(v_c0, vsubl_s16(v_Cr, v_delta)); - v_r0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_r0, v_delta2), yuv_shift), - v_Y); + v_r0 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_r0, v_delta2), yuv_shift), v_Y); v_Y = vget_high_s16(v_src16.val[0]); v_Cr = vget_high_s16(v_src16.val[1]); v_Cb = vget_high_s16(v_src16.val[2]); int32x4_t v_b1 = vmulq_s32(v_c3, vsubl_s16(v_Cb, v_delta)); - v_b1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_b1, v_delta2), yuv_shift), - v_Y); - int32x4_t v_g1 = - vmlaq_s32(vmulq_s32(vsubl_s16(v_Cr, v_delta), v_c1), - vsubl_s16(v_Cb, v_delta), v_c2); - v_g1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_g1, v_delta2), yuv_shift), - v_Y); + v_b1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_b1, v_delta2), yuv_shift), v_Y); + int32x4_t v_g1 = vmlaq_s32( + vmulq_s32(vsubl_s16(v_Cr, v_delta), v_c1), vsubl_s16(v_Cb, v_delta), + v_c2); + v_g1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_g1, v_delta2), yuv_shift), v_Y); int32x4_t v_r1 = vmulq_s32(v_c0, vsubl_s16(v_Cr, v_delta)); - v_r1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_r1, v_delta2), yuv_shift), - v_Y); + v_r1 = vaddw_s16(vshrq_n_s32(vaddq_s32(v_r1, v_delta2), yuv_shift), v_Y); - uint8x8_t v_b = - vqmovun_s16(vcombine_s16(vmovn_s32(v_b0), vmovn_s32(v_b1))); - uint8x8_t v_g = - vqmovun_s16(vcombine_s16(vmovn_s32(v_g0), vmovn_s32(v_g1))); - uint8x8_t v_r = - vqmovun_s16(vcombine_s16(vmovn_s32(v_r0), vmovn_s32(v_r1))); + uint8x8_t v_b = vqmovun_s16(vcombine_s16(vmovn_s32(v_b0), vmovn_s32(v_b1))); + uint8x8_t v_g = vqmovun_s16(vcombine_s16(vmovn_s32(v_g0), vmovn_s32(v_g1))); + uint8x8_t v_r = vqmovun_s16(vcombine_s16(vmovn_s32(v_r0), vmovn_s32(v_r1))); uint8x8x3_t v_dst; v_dst.val[0] = v_r; @@ -1009,8 +987,7 @@ void cvt_yuv2rgb_8u_neon(const Mat8u& src, Mat8u& dst) { uchar Cb = psrc[2]; int b = Y + descale((Cb - delta) * C3, yuv_shift); - int g = Y + - descale((Cb - delta) * C2 + (Cr - delta) * C1, yuv_shift); + int g = Y + descale((Cb - delta) * C2 + (Cr - delta) * C1, yuv_shift); int r = Y + descale((Cr - delta) * C0, yuv_shift); pdst[0] = saturate_cast(r); @@ -1038,13 +1015,13 @@ void cvt_yuv2rgb_32f_neon(const Mat32f& src, Mat32f& dst) { const float* const pend = psrc + src.cols() * 3; for (; psrc <= pend - 4 * 3; psrc += 4 * 3, pdst += 4 * 3) { float32x4x3_t v_src = vld3q_f32(psrc), v_dst; - float32x4_t v_Y = v_src.val[0], v_Cr = v_src.val[1], - v_Cb = v_src.val[2]; + float32x4_t v_Y = v_src.val[0], 
v_Cr = v_src.val[1], v_Cb = v_src.val[2]; v_dst.val[0] = vmlaq_f32(v_Y, vsubq_f32(v_Cr, v_delta), v_c0); v_dst.val[1] = vaddq_f32( - vmlaq_f32(vmulq_f32(vsubq_f32(v_Cb, v_delta), v_c2), - vsubq_f32(v_Cr, v_delta), v_c1), + vmlaq_f32( + vmulq_f32(vsubq_f32(v_Cb, v_delta), v_c2), + vsubq_f32(v_Cr, v_delta), v_c1), v_Y); v_dst.val[2] = vmlaq_f32(v_Y, vsubq_f32(v_Cb, v_delta), v_c3); @@ -1178,9 +1155,8 @@ void cvt_rgb2gray(const Mat8u& src, Mat8u& dst) { uchar x0 = temp_src[0]; uchar x1 = temp_src[1]; uchar x2 = temp_src[2]; - temp_dst[0] = - (x0 * R2Y + x1 * G2Y + x2 * B2Y + (1 << (yuv_shift - 1))) >> - yuv_shift; + temp_dst[0] = (x0 * R2Y + x1 * G2Y + x2 * B2Y + (1 << (yuv_shift - 1))) >> + yuv_shift; } } } @@ -1333,9 +1309,8 @@ void cvt_rgba2gray(const Mat8u& src, Mat8u& dst) { uchar x0 = temp_src[0]; uchar x1 = temp_src[1]; uchar x2 = temp_src[2]; - temp_dst[0] = - (x0 * R2Y + x1 * G2Y + x2 * B2Y + (1 << (yuv_shift - 1))) >> - yuv_shift; + temp_dst[0] = (x0 * R2Y + x1 * G2Y + x2 * B2Y + (1 << (yuv_shift - 1))) >> + yuv_shift; } } } @@ -1380,8 +1355,7 @@ void cvt_bgr2gray(const Mat8u& src, Mat8u& dst) { uchar x0 = temp_src[0]; uchar x1 = temp_src[1]; uchar x2 = temp_src[2]; - temp_dst[0] = - (tab[x2] + tab[x1 + 256] + tab[x0 + 512]) >> yuv_shift; + temp_dst[0] = (tab[x2] + tab[x1 + 256] + tab[x0 + 512]) >> yuv_shift; } } } @@ -1459,8 +1433,8 @@ void cvt_yuv2bgr_yu12(const Mat8u& src, Mat8u& dst) { } template -void cvt_bt601_yuv(const megcv::Mat& src, megcv::Mat& dst, - param::CvtColor::Mode mode) { +void cvt_bt601_yuv( + const megcv::Mat& src, megcv::Mat& dst, param::CvtColor::Mode mode) { MEGDNN_MARK_USED_VAR(src); MEGDNN_MARK_USED_VAR(dst); MEGDNN_MARK_USED_VAR(mode); @@ -1468,8 +1442,9 @@ void cvt_bt601_yuv(const megcv::Mat& src, megcv::Mat& dst, } template <> -void cvt_bt601_yuv(const megcv::Mat& src, megcv::Mat& dst, - param::CvtColor::Mode mode) { +void cvt_bt601_yuv( + const megcv::Mat& src, megcv::Mat& dst, + param::CvtColor::Mode mode) { using Mode = param::CvtColor::Mode; switch (mode) { case Mode::BT601_YUV2RGB_NV21: @@ -1518,8 +1493,8 @@ void cvt_bt601_yuv(const megcv::Mat& src, megcv::Mat& dst, } template -void CvtColorImpl::cvt_color_exec(const TensorND& src_tensor, - const TensorND& dst_tensor) { +void CvtColorImpl::cvt_color_exec( + const TensorND& src_tensor, const TensorND& dst_tensor) { auto mode = param().mode; for (size_t i = 0; i < src_tensor.layout.shape[0]; ++i) { Mat src = TensorND2Mat(src_tensor, i); @@ -1660,19 +1635,23 @@ void CvtColorImpl::cvt_color_exec(const TensorND& src_tensor, } } } -void CvtColorImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { +void CvtColorImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { using namespace megcv; check_exec(src.layout, dst.layout, workspace.size); if (dst.layout.dtype == dtype::Float32()) { MIDOUT_BEGIN(megdnn_arm_cvtcolor MEGDNN_COMMA midout_iv(0)) { - MEGDNN_DISPATCH_CPU_KERN_OPR(cvt_color_exec(src, dst)); - } MIDOUT_END(); + MEGDNN_DISPATCH_CPU_KERN_OPR(cvt_color_exec(src, dst)); + } + MIDOUT_END(); } else if (dst.layout.dtype == dtype::Uint8()) { MIDOUT_BEGIN(megdnn_arm_cvtcolor MEGDNN_COMMA midout_iv(1)) { MEGDNN_DISPATCH_CPU_KERN_OPR(cvt_color_exec(src, dst)); - } MIDOUT_END(); - } else { megdnn_throw("Unsupported datatype of CvtColor optr."); }; + } + MIDOUT_END(); + } else { + megdnn_throw("Unsupported datatype of CvtColor optr."); + }; } } // namespace arm_common diff --git a/dnn/src/arm_common/cvt_color/opr_impl.h 
b/dnn/src/arm_common/cvt_color/opr_impl.h index ef9b7d50..79715b18 100644 --- a/dnn/src/arm_common/cvt_color/opr_impl.h +++ b/dnn/src/arm_common/cvt_color/opr_impl.h @@ -14,27 +14,24 @@ namespace megdnn { namespace arm_common { -class CvtColorImpl: public CvtColor { - private: - template - void cvt_color_exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst); +class CvtColorImpl : public CvtColor { +private: + template + void cvt_color_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst); - public: - using CvtColor::CvtColor; +public: + using CvtColor::CvtColor; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &) override - { - return 0; - } + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { + return 0; + } }; -} // namespace x86 -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise/binary/algo.cpp b/dnn/src/arm_common/elemwise/binary/algo.cpp index 81811136..5cad7e43 100644 --- a/dnn/src/arm_common/elemwise/binary/algo.cpp +++ b/dnn/src/arm_common/elemwise/binary/algo.cpp @@ -61,8 +61,7 @@ static inline bool is_available_common(Elemwise::Mode mode) { mode == Mode::FUSE_ADD_RELU) \ return true; -bool ElemwiseImpl::AlgoBinaryVecVec::is_available( - const KernParam& kern_param) const { +bool ElemwiseImpl::AlgoBinaryVecVec::is_available(const KernParam& kern_param) const { if (!is_available_common(kern_param.mode) || (BcastType::VEC_VEC != kern_param.broad_cast_type)) return false; @@ -124,56 +123,56 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available( #undef DISPATCH_MODE_INT #if MEGDNN_AARCH64 -#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ - switch (kern_param.mode) { \ - DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ - DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ - DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ - DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ - DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ - DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ - DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ - DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, \ - FuseAddReluOp); \ - DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, \ - FuseAddHSwishOp); \ - default: \ - megdnn_throw(ssprintf("No avaiable algo find for: %d", \ - static_cast(kern_param.mode))); \ +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ + DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ + DISPATCH_BINARY( \ + FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ } 
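// --- Illustrative note, not part of the patch ---
// For orientation: inside these switch-generating macros, each
// DISPATCH_BINARY(_mode, ...) entry (defined per broadcast case further down)
// roughly expands to a `case Mode::_mode:` that opens a midout trace point,
// builds the statically-typed caller, e.g.
//     run = OpCallerBinary<AddOp<float, float>, BcastType::VEC_VEC>::run;
// and dispatches it through MEGDNN_DISPATCH_CPU_KERN before returning.
// The float and int variants only differ in which Modes they admit:
// POW, TRUE_DIV and FUSE_ADD_H_SWISH are float-only (TRUE_DIV additionally
// only on AArch64), while RMULH is int-only.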
#else -#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ - switch (kern_param.mode) { \ - DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ - DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ - DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ - DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ - DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ - DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ - DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, \ - FuseAddReluOp); \ - DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, \ - FuseAddHSwishOp); \ - default: \ - megdnn_throw(ssprintf("No avaiable algo find for: %d", \ - static_cast(kern_param.mode))); \ +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ + DISPATCH_BINARY( \ + FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, FuseAddHSwishOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ } #endif -#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ - switch (kern_param.mode) { \ - DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ - DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ - DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ - DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ - DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ - DISPATCH_BINARY(RMULH, _case, _type, _type_midout_id, RmulhOp); \ - DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, \ - FuseAddReluOp); \ - default: \ - megdnn_throw(ssprintf("No avaiable algo find for: %d", \ - static_cast(kern_param.mode))); \ +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp); \ + DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp); \ + DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp); \ + DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp); \ + DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp); \ + DISPATCH_BINARY(RMULH, _case, _type, _type_midout_id, RmulhOp); \ + DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id, FuseAddReluOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ } void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const { @@ -181,23 +180,23 @@ void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const { auto &src0 = elparam[0], &src1 = elparam[1]; //! 
exactly match [x, y] + [x, y] -#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerBinary<_op<_type, _type>, \ - BcastType::VEC_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, dst.layout.dtype, \ - src0.layout.total_nr_elems())); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary<_op<_type, _type>, BcastType::VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ return auto&& dst = *(kern_param.m_dst); @@ -208,30 +207,30 @@ void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const { return; } -void ElemwiseImpl::AlgoBinaryVecScalar::exec( - const KernParam& kern_param) const { +void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const { auto& elparam = kern_param.binary_elparam; auto &src0 = elparam[0], &src1 = elparam[1]; auto&& dst = *(kern_param.m_dst); // Case 2: vector + scalar -#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerBinary<_op<_type, _type>, \ - BcastType::VEC_SCALAR>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr)[0], \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, dst.layout.dtype, \ - src0.layout.total_nr_elems())); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr)[0], \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ return if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) { @@ -240,23 +239,24 @@ void ElemwiseImpl::AlgoBinaryVecScalar::exec( #undef DISPATCH_BINARY // scalar + vector -#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerBinary<_op<_type, _type>, \ - BcastType::SCALAR_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr)[0], \ - static_cast(src1.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ 
- src1.layout.dtype, dst.layout.dtype, \ - src1.layout.total_nr_elems())); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr)[0], \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, \ + src1.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ return if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) { @@ -267,8 +267,7 @@ void ElemwiseImpl::AlgoBinaryVecScalar::exec( return; } -void ElemwiseImpl::AlgoBinaryVecBcast101::exec( - const KernParam& kern_param) const { +void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) const { auto& elparam = kern_param.binary_elparam; auto &src0 = elparam[0], &src1 = elparam[1]; auto&& dst = *(kern_param.m_dst); @@ -277,23 +276,25 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec( // Case 3: BcastType::VEC + BCAST_101 if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type && is_broadcasted_channel_like(src1.layout, binfo)) { -#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerBinary<_op<_type, _type>, \ - BcastType::VEC_BCAST101>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ - binfo.z)); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ return DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_vec_b"_hash); @@ -304,23 +305,25 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec( // BCAST_101 + BcastType::VEC if (BcastType::BCAST101_VEC == kern_param.broad_cast_type && is_broadcasted_channel_like(src0.layout, binfo)) { -#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerBinary<_op<_type, _type>, \ - BcastType::BCAST101_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ - binfo.z)); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + 
megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \ + binfo.z)); \ + } \ + MIDOUT_END(); \ return DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_b_vec"_hash); @@ -330,8 +333,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec( return; } -void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec( - const KernParam& kern_param) const { +void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const { auto& elparam = kern_param.binary_elparam; auto &src0 = elparam[0], &src1 = elparam[1]; auto&& dst = *(kern_param.m_dst); @@ -343,26 +345,27 @@ void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec( is_broadcastedx_channel_like<4>(src1.layout, binfo) || is_broadcastedx_channel_like<8>(src1.layout, binfo), "only nchw44 and nchw88 supported"); -#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerBinary<_op<_type, _type>, \ - BcastType::VEC_BCAST101xX>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, dst.layout.dtype, batch_size, \ - binfo.x, binfo.y, binfo.z)); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::VEC_BCAST101xX>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \ + binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ return - size_t batch_size = - src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_vec_b"_hash); #undef DISPATCH_BINARY @@ -374,26 +377,27 @@ void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec( is_broadcastedx_channel_like<4>(src0.layout, binfo) || is_broadcastedx_channel_like<8>(src0.layout, binfo), "only nchw44 and nchw88 supported"); -#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerBinary<_op<_type, _type>, \ - BcastType::BCAST101xX_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, dst.layout.dtype, batch_size, \ - binfo.x, binfo.y, binfo.z)); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + 
megdnn_arm_common_elemwise_binary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerBinary< \ + _op<_type, _type>, BcastType::BCAST101xX_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \ + binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ return - size_t batch_size = - src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); + size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_b_vec"_hash); diff --git a/dnn/src/arm_common/elemwise/binary/algo.h b/dnn/src/arm_common/elemwise/binary/algo.h index 42669b53..05ac6937 100644 --- a/dnn/src/arm_common/elemwise/binary/algo.h +++ b/dnn/src/arm_common/elemwise/binary/algo.h @@ -14,21 +14,20 @@ namespace megdnn { namespace arm_common { -#define DECL_CB(case) \ - class ElemwiseImpl::AlgoBinary##case final \ - : public ElemwiseImpl::AlgoBase { \ - mutable std::string m_name; \ - AlgoAttribute attribute() const override { \ - return AlgoAttribute::REPRODUCIBLE; \ - } \ - const char* name() const override { \ - if (m_name.empty()) { \ - m_name = ssprintf("Elemwise::AlgoBinaryCase" #case); \ - } \ - return m_name.c_str(); \ - } \ - bool is_available(const KernParam&) const override; \ - void exec(const KernParam&) const override; \ +#define DECL_CB(case) \ + class ElemwiseImpl::AlgoBinary##case final : public ElemwiseImpl::AlgoBase { \ + mutable std::string m_name; \ + AlgoAttribute attribute() const override { \ + return AlgoAttribute::REPRODUCIBLE; \ + } \ + const char* name() const override { \ + if (m_name.empty()) { \ + m_name = ssprintf("Elemwise::AlgoBinaryCase" #case); \ + } \ + return m_name.c_str(); \ + } \ + bool is_available(const KernParam&) const override; \ + void exec(const KernParam&) const override; \ }; DECL_CB(VecVec); diff --git a/dnn/src/arm_common/elemwise/neon_mathfun.cpp b/dnn/src/arm_common/elemwise/neon_mathfun.cpp index e53189bc..beffb726 100644 --- a/dnn/src/arm_common/elemwise/neon_mathfun.cpp +++ b/dnn/src/arm_common/elemwise/neon_mathfun.cpp @@ -68,8 +68,7 @@ namespace arm_common { v4sf log_ps_f32(v4sf x) { v4sf one = vdupq_n_f32(1); - x = vmaxq_f32(x, - vdupq_n_f32(0)); /* force flush to zero on denormal values */ + x = vmaxq_f32(x, vdupq_n_f32(0)); /* force flush to zero on denormal values */ v4su invalid_mask = vcleq_f32(x, vdupq_n_f32(0)); v4si ux = vreinterpretq_s32_f32(x); @@ -95,8 +94,8 @@ v4sf log_ps_f32(v4sf x) { v4su mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF)); v4sf tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); x = vsubq_f32(x, one); - e = vsubq_f32(e, vreinterpretq_f32_u32( - vandq_u32(vreinterpretq_u32_f32(one), mask))); + e = vsubq_f32( + e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask))); x = vaddq_f32(x, tmp); v4sf z = vmulq_f32(x, x); @@ -131,9 +130,9 @@ v4sf log_ps_f32(v4sf x) { tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2)); x = vaddq_f32(x, y); x = vaddq_f32(x, tmp); - x = vreinterpretq_f32_u32( - vorrq_u32(vreinterpretq_u32_f32(x), - invalid_mask)); // negative arg will be NAN + x = vreinterpretq_f32_u32(vorrq_u32( + vreinterpretq_u32_f32(x), + invalid_mask)); // negative arg will be NAN return x; } @@ -227,13 +226,13 @@ float16x8_t exp_ps_f16(float16x8_t x) { #define c_minus_cephes_DP1 -0.78515625 #define c_minus_cephes_DP2 
-2.4187564849853515625e-4 #define c_minus_cephes_DP3 -3.77489497744594108e-8 -#define c_sincof_p0 -1.9515295891E-4 -#define c_sincof_p1 8.3321608736E-3 -#define c_sincof_p2 -1.6666654611E-1 -#define c_coscof_p0 2.443315711809948E-005 -#define c_coscof_p1 -1.388731625493765E-003 -#define c_coscof_p2 4.166664568298827E-002 -#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI +#define c_sincof_p0 -1.9515295891E-4 +#define c_sincof_p1 8.3321608736E-3 +#define c_sincof_p2 -1.6666654611E-1 +#define c_coscof_p0 2.443315711809948E-005 +#define c_coscof_p1 -1.388731625493765E-003 +#define c_coscof_p2 4.166664568298827E-002 +#define c_cephes_FOPI 1.27323954473516 // 4 / M_PI /* evaluation of 4 sines & cosines at once. diff --git a/dnn/src/arm_common/elemwise/opr_impl.cpp b/dnn/src/arm_common/elemwise/opr_impl.cpp index 6e705205..5e8c2f3e 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise/opr_impl.cpp @@ -31,8 +31,7 @@ class ElemwiseImpl::AlgoPack { AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; - AlgoTernaryFma3Bcast101xXVecBcast101xX - algo_ternaryfma3_bcast101xX_vec_bcast101xX; + AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX; AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec; AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; @@ -68,8 +67,7 @@ void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) { if (m_dst->layout.dtype == dtype::Float32() || DNN_FLOAT16_SELECT(m_dst->layout.dtype == dtype::Float16(), false) || m_dst->layout.dtype == dtype::Int32() || - m_dst->layout.dtype == dtype::Int16() || - m_dst->layout.dtype == dtype::Int8()) { + m_dst->layout.dtype == dtype::Int16() || m_dst->layout.dtype == dtype::Int8()) { auto kern_param = make_kern_param(this); kern_param.m_dst = &dst; static AlgoPack m_algo_pack; @@ -89,8 +87,7 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { kern_param.mode = opr->param().mode; kern_param.handle = opr->handle(); - if ((opr->m_src->size() == 3) && - (opr->param().mode == Mode::FUSE_MUL_ADD3)) { + if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) { kern_param.ternary_elparam = opr->make_elemwise_op_param<3>(); bool c_is_scalar; opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar); @@ -110,8 +107,7 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { return kern_param; } - if (is_vector(src1.layout) && - is_broadcasted_channel_like(src0.layout, binfo) && + if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo) && src0.layout.eq_layout(src2.layout)) { kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101; return kern_param; @@ -151,8 +147,7 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { } } else if (opr->m_src->size() == 2) { kern_param.binary_elparam = opr->make_elemwise_op_param<2>(); - auto &src0 = kern_param.binary_elparam[0], - &src1 = kern_param.binary_elparam[1]; + auto &src0 = kern_param.binary_elparam[0], &src1 = kern_param.binary_elparam[1]; BroadcastChannelInfo binfo; if (is_vector(src0.layout) && is_vector(src1.layout)) { kern_param.broad_cast_type = BcastType::VEC_VEC; @@ -169,14 +164,12 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { return kern_param; } - if 
(is_vector(src0.layout) && - is_broadcasted_channel_like(src1.layout, binfo)) { + if (is_vector(src0.layout) && is_broadcasted_channel_like(src1.layout, binfo)) { kern_param.broad_cast_type = BcastType::VEC_BCAST101; return kern_param; } - if (is_vector(src1.layout) && - is_broadcasted_channel_like(src0.layout, binfo)) { + if (is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) { kern_param.broad_cast_type = BcastType::BCAST101_VEC; return kern_param; } diff --git a/dnn/src/arm_common/elemwise/opr_impl.h b/dnn/src/arm_common/elemwise/opr_impl.h index f22db09c..76990713 100644 --- a/dnn/src/arm_common/elemwise/opr_impl.h +++ b/dnn/src/arm_common/elemwise/opr_impl.h @@ -64,18 +64,17 @@ public: }; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#define DISPATCH_TYPE(_case) \ - if (src0.layout.dtype == dtype::Float32{}) { \ - DISPATCH_MODE_FLOAT(_case, float, 0); \ - } else if (DNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, \ - false)) { \ - DISPATCH_MODE_FLOAT(_case, __fp16, 1); \ - } else if (src0.layout.dtype == dtype::Int32{}) { \ - DISPATCH_MODE_INT(_case, int, 2); \ - } else if (src0.layout.dtype == dtype::Int16{}) { \ - DISPATCH_MODE_INT(_case, dt_int16, 3); \ - } else if (src0.layout.dtype == dtype::Int8{}) { \ - DISPATCH_MODE_INT(_case, dt_int8, 4); \ +#define DISPATCH_TYPE(_case) \ + if (src0.layout.dtype == dtype::Float32{}) { \ + DISPATCH_MODE_FLOAT(_case, float, 0); \ + } else if (DNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, false)) { \ + DISPATCH_MODE_FLOAT(_case, __fp16, 1); \ + } else if (src0.layout.dtype == dtype::Int32{}) { \ + DISPATCH_MODE_INT(_case, int, 2); \ + } else if (src0.layout.dtype == dtype::Int16{}) { \ + DISPATCH_MODE_INT(_case, dt_int16, 3); \ + } else if (src0.layout.dtype == dtype::Int8{}) { \ + DISPATCH_MODE_INT(_case, dt_int8, 4); \ } #else #define DISPATCH_TYPE(_case) \ diff --git a/dnn/src/arm_common/elemwise/ternary/algo.cpp b/dnn/src/arm_common/elemwise/ternary/algo.cpp index 8070f0f8..0372114d 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.cpp +++ b/dnn/src/arm_common/elemwise/ternary/algo.cpp @@ -51,39 +51,40 @@ DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); #undef DISPATCH_MODE_FLOAT #undef DISPATCH_MODE_INT -#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ - switch (kern_param.mode) { \ - DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, \ - FuseMulAdd3Op); \ - default: \ - megdnn_throw(ssprintf("No avaiable algo find for: %d", \ - static_cast(kern_param.mode))); \ +#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, FuseMulAdd3Op); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ } #define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT -void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec( - const KernParam& kern_param) const { +void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) const { auto& elparam = kern_param.ternary_elparam; auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; // Case 1: shape of (src0, src2) and src1 are exactly match -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary<_op<_type, _type>, \ - BcastType::VEC_VEC_VEC>::run; \ - 
MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast(src2.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, src0.layout.total_nr_elems())); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ return auto&& dst = *(kern_param.m_dst); @@ -98,24 +99,26 @@ void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec( auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; // Case 2: (src2 is a scalar) && (src0 and src1 has the same shape) -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary<_op<_type, _type>, \ - BcastType::VEC_VEC_SCALAR>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast(src2.raw_ptr)[0], \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, src0.layout.total_nr_elems())); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr)[0], \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ return auto&& dst = *(kern_param.m_dst); @@ -132,26 +135,26 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( // Case 3: shape of src0 and src2 is {1, C, 1, 1} BroadcastChannelInfo binfo; is_broadcasted_channel_like(src0.layout, binfo); -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary< \ - _op<_type, _type>, \ - BcastType::BCAST101_VEC_BCAST101>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast(src2.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case 
Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ return auto&& dst = *(kern_param.m_dst); @@ -167,30 +170,31 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec( auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; BroadcastChannelInfo binfo; - megdnn_assert(is_broadcastedx_channel_like<4>(src0.layout, binfo) || - is_broadcastedx_channel_like<8>(src0.layout, binfo), - "only nchw44 and nchw88 supported"); -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary< \ - _op<_type, _type>, \ - BcastType::BCAST101xX_VEC_BCAST101xX>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast(src2.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, batch_size, binfo.x, binfo.y, \ - binfo.z)); \ - } \ - MIDOUT_END(); \ + megdnn_assert( + is_broadcastedx_channel_like<4>(src0.layout, binfo) || + is_broadcastedx_channel_like<8>(src0.layout, binfo), + "only nchw44 and nchw88 supported"); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, \ + BcastType::BCAST101xX_VEC_BCAST101xX>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + batch_size, binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ return size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); @@ -207,29 +211,30 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec( auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; BroadcastChannelInfo binfo; - megdnn_assert(is_broadcastedx_channel_like<4>(src1.layout, binfo) || - is_broadcastedx_channel_like<8>(src1.layout, binfo), - "only nchw44 and nchw88 supported"); -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary<_op<_type, _type>, \ - BcastType::VEC_BCAST101xX_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast(src2.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, batch_size, binfo.x, binfo.y, \ - 
binfo.z)); \ - } \ - MIDOUT_END(); \ + megdnn_assert( + is_broadcastedx_channel_like<4>(src1.layout, binfo) || + is_broadcastedx_channel_like<8>(src1.layout, binfo), + "only nchw44 and nchw88 supported"); +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + batch_size, binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ return size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); @@ -248,25 +253,26 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec( // Case 4: shape of src1 is {1, C, 1, 1}, and src0 and src2 are contig BroadcastChannelInfo binfo; is_broadcasted_channel_like(src1.layout, binfo); -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary<_op<_type, _type>, \ - BcastType::VEC_BCAST101_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr), \ - static_cast(src2.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr), \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + binfo.x, binfo.y, binfo.z)); \ + } \ + MIDOUT_END(); \ return auto&& dst = *(kern_param.m_dst); @@ -282,24 +288,26 @@ void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec( auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; // Case 5: (src1 is a scalar) && (src0 and src2 has the same shape) -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary<_op<_type, _type>, \ - BcastType::VEC_SCALAR_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr)[0], \ - static_cast(src2.raw_ptr), \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, src0.layout.total_nr_elems())); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, 
midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr)[0], \ + static_cast(src2.raw_ptr), \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ return auto&& dst = *(kern_param.m_dst); @@ -314,24 +322,26 @@ void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec( auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; // Case 6: (src1 and src2 is scalar) && (src0 is vector) -#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ - midout_iv(Mode::_mode), _type_midout_id) { \ - thin_function \ - run = OpCallerTernary<_op<_type, _type>, \ - BcastType::VEC_SCALAR_SCALAR>::run; \ - MEGDNN_DISPATCH_CPU_KERN( \ - static_cast(kern_param.handle), \ - run(static_cast(src0.raw_ptr), \ - static_cast(src1.raw_ptr)[0], \ - static_cast(src2.raw_ptr)[0], \ - static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ - src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, src0.layout.total_nr_elems())); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ + midout_iv(Mode::_mode), _type_midout_id) { \ + thin_function \ + run = OpCallerTernary< \ + _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(kern_param.handle), \ + run(static_cast(src0.raw_ptr), \ + static_cast(src1.raw_ptr)[0], \ + static_cast(src2.raw_ptr)[0], \ + static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + } \ + MIDOUT_END(); \ return auto&& dst = *(kern_param.m_dst); diff --git a/dnn/src/arm_common/elemwise/ternary/algo.h b/dnn/src/arm_common/elemwise/ternary/algo.h index 2864d352..211ba451 100644 --- a/dnn/src/arm_common/elemwise/ternary/algo.h +++ b/dnn/src/arm_common/elemwise/ternary/algo.h @@ -14,21 +14,20 @@ namespace megdnn { namespace arm_common { -#define DECL_CB(case) \ - class ElemwiseImpl::AlgoTernaryFma3##case final \ - : public ElemwiseImpl::AlgoBase { \ - mutable std::string m_name; \ - AlgoAttribute attribute() const override { \ - return AlgoAttribute::REPRODUCIBLE; \ - } \ - const char* name() const override { \ - if (m_name.empty()) { \ - m_name = ssprintf("Elemwise::AlgoTernaryFma3" #case); \ - } \ - return m_name.c_str(); \ - } \ - bool is_available(const KernParam&) const override; \ - void exec(const KernParam&) const override; \ +#define DECL_CB(case) \ + class ElemwiseImpl::AlgoTernaryFma3##case final : public ElemwiseImpl::AlgoBase { \ + mutable std::string m_name; \ + AlgoAttribute attribute() const override { \ + return AlgoAttribute::REPRODUCIBLE; \ + } \ + const char* name() const override { \ + if (m_name.empty()) { \ + m_name = ssprintf("Elemwise::AlgoTernaryFma3" #case); \ + } \ + return m_name.c_str(); \ + } \ + bool is_available(const KernParam&) const override; \ + void exec(const KernParam&) const override; \ }; DECL_CB(VecVecVec); diff --git a/dnn/src/arm_common/elemwise/unary/algo.cpp b/dnn/src/arm_common/elemwise/unary/algo.cpp index 
2c947eb0..35ab39e2 100644 --- a/dnn/src/arm_common/elemwise/unary/algo.cpp +++ b/dnn/src/arm_common/elemwise/unary/algo.cpp @@ -64,28 +64,27 @@ bool ElemwiseImpl::AlgoUnary::is_available(const KernParam& kern_param) const { } void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const { -#define DISPATCH_UNARY(_mode, _case, _type, _type_midout_id, _op) \ - case Mode::_mode: \ - MIDOUT_BEGIN(megdnn_arm_common_elemwise_unary, midout_iv(_case), \ - midout_iv(Mode::_mode), midout_iv(_type_midout_id)) { \ - thin_function \ - run = OpCallerUnary<_op<_type, _type>, \ - BcastType::VEC>::run; \ - auto kernel = [nr_elems, nr_elems_per_thread, src0, dst_tensor, \ - run](size_t task_id, size_t) { \ - size_t offset = task_id * nr_elems_per_thread; \ - size_t nr_elems_thread = \ - std::min(nr_elems - offset, nr_elems_per_thread); \ - run(static_cast(src0.raw_ptr) + offset, \ - static_cast<_type*>(dst_tensor.raw_ptr) + offset, \ - src0.layout.dtype, dst_tensor.layout.dtype, \ - nr_elems_thread); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast(kern_param.handle), \ - nr_threads, kernel); \ - } \ - MIDOUT_END(); \ +#define DISPATCH_UNARY(_mode, _case, _type, _type_midout_id, _op) \ + case Mode::_mode: \ + MIDOUT_BEGIN( \ + megdnn_arm_common_elemwise_unary, midout_iv(_case), \ + midout_iv(Mode::_mode), midout_iv(_type_midout_id)) { \ + thin_function run = \ + OpCallerUnary<_op<_type, _type>, BcastType::VEC>::run; \ + auto kernel = [nr_elems, nr_elems_per_thread, src0, dst_tensor, run]( \ + size_t task_id, size_t) { \ + size_t offset = task_id * nr_elems_per_thread; \ + size_t nr_elems_thread = \ + std::min(nr_elems - offset, nr_elems_per_thread); \ + run(static_cast(src0.raw_ptr) + offset, \ + static_cast<_type*>(dst_tensor.raw_ptr) + offset, \ + src0.layout.dtype, dst_tensor.layout.dtype, nr_elems_thread); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(kern_param.handle), nr_threads, \ + kernel); \ + } \ + MIDOUT_END(); \ return auto& elparam = kern_param.unary_elparam; @@ -110,17 +109,19 @@ void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const { DISPATCH_UNARY(FAST_TANH, _case, _type, _type_midout_id, FastTanhOp); \ DISPATCH_UNARY(H_SWISH, _case, _type, _type_midout_id, HSwishOp); \ default: \ - megdnn_throw(ssprintf("No avaiable algo find for: %d", \ - static_cast(kern_param.mode))); \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ } -#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ - switch (kern_param.mode) { \ - DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ - DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ - default: \ - megdnn_throw(ssprintf("No avaiable algo find for: %d", \ - static_cast(kern_param.mode))); \ +#define DISPATCH_MODE_INT(_case, _type, _type_midout_id) \ + switch (kern_param.mode) { \ + DISPATCH_UNARY(RELU, _case, _type, _type_midout_id, ReluOp); \ + DISPATCH_UNARY(ABS, _case, _type, _type_midout_id, AbsOp); \ + default: \ + megdnn_throw(ssprintf( \ + "No avaiable algo find for: %d", \ + static_cast(kern_param.mode))); \ } DISPATCH_TYPE("AlgoUnary::exec"_hash); diff --git a/dnn/src/arm_common/elemwise/unary/algo.h b/dnn/src/arm_common/elemwise/unary/algo.h index 8d31fbef..bcabf5cd 100644 --- a/dnn/src/arm_common/elemwise/unary/algo.h +++ b/dnn/src/arm_common/elemwise/unary/algo.h @@ -16,9 +16,7 @@ namespace arm_common { class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase { mutable std::string m_name; - AlgoAttribute 
attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { if (m_name.empty()) { m_name = ssprintf("Elemwise::AlgoUnary"); diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/abs.h b/dnn/src/arm_common/elemwise_helper/kimpl/abs.h index 97b4f5ec..70b3cf7e 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/abs.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/abs.h @@ -21,9 +21,7 @@ struct AbsOpBase : UnaryOpBase { void operator()(const src_ctype& src, dst_ctype* dst) const { *dst = operator()(src); } - dst_ctype operator()(const src_ctype& src) const { - return src > 0 ? src : (-src); - } + dst_ctype operator()(const src_ctype& src) const { return src > 0 ? src : (-src); } }; template diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/add.h b/dnn/src/arm_common/elemwise_helper/kimpl/add.h index 2083264b..c194413d 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/add.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/add.h @@ -18,8 +18,8 @@ namespace arm_common { template struct AddOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } @@ -28,37 +28,39 @@ struct AddOpBase : BinaryOpBase { } }; -template +template < + typename src_ctype, typename dst_ctype = src_ctype, + bool enable_opt_or_fixup = false> struct AddOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct AddOp<_ctype> : AddOpBase<_ctype> { \ - using AddOpBase::AddOpBase; \ - using AddOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto vitem0 = vaddq_##_func_suffix(src0.val[0], src1.val[0]); \ - auto vitem1 = vaddq_##_func_suffix(src0.val[1], src1.val[1]); \ - return {{vitem0, vitem1}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - return vaddq_##_func_suffix(src0, src1); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct AddOp<_ctype> : AddOpBase<_ctype> { \ + using AddOpBase::AddOpBase; \ + using AddOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto vitem0 = vaddq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vaddq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem 
= operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + return vaddq_##_func_suffix(src0, src1); \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -72,8 +74,7 @@ OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) template <> struct AddOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint8& src0, const dt_qint8& src1, - dt_qint8* dst) const { + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } @@ -93,15 +94,15 @@ struct AddOpBase : BinaryOpBase { szp = this->szp0 + this->szp1; vszp = vdupq_n_f32(szp); } - void operator()(const dt_quint8& src0, const dt_quint8& src1, - dt_quint8* dst) const { + void operator()( + const dt_quint8& src0, const dt_quint8& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { return QConverter::convert( - src0.as_uint8() * this->scale0 + - src1.as_uint8() * this->scale1 - this->szp, + src0.as_uint8() * this->scale0 + src1.as_uint8() * this->scale1 - + this->szp, this->dzp); } }; @@ -112,13 +113,12 @@ struct AddOp : AddOpBase { constexpr static size_t SIMD_WIDTH = 16; using AddOpBase::operator(); - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8; } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { auto vitem0 = vaddq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); @@ -136,13 +136,13 @@ struct AddOp : AddOpBase { constexpr static size_t SIMD_WIDTH = 16; using AddOpBase::operator(); - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { OPERATOR_BINARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, - const uint32x4x2_t& vsrc1) const { + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1) const { auto vitem0 = vsubq_f32( vaddq_f32( vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), @@ -162,30 +162,27 @@ struct AddOp : AddOpBase { template <> struct AddOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint32& src0, const dt_qint32& src1, - dt_qint8* dst) const { + void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { return QConverter::convert( - src0.as_int32() * this->scale0 + - src1.as_int32() * this->scale1); + src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1); } }; template <> struct AddOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint32& src0, const dt_qint32& src1, - dt_quint8* dst) const { + void operator()( + const dt_qint32& src0, const dt_qint32& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { return QConverter::convert( - src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, - zp); + 
src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, zp); } }; @@ -197,13 +194,12 @@ struct AddOp using AddOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_qint8* dst) const { vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { if (enable_opt_or_fixup) { auto vitem0 = vmulq_f32( vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), @@ -211,8 +207,7 @@ struct AddOp auto vitem1 = vmulq_f32( vcvtq_f32_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1])), this->vscale0); - return QConverter::convert( - {{vitem0, vitem1}}); + return QConverter::convert({{vitem0, vitem1}}); } else { auto vitem0 = vaddq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), @@ -220,15 +215,14 @@ struct AddOp auto vitem1 = vaddq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); - return QConverter::convert( - {{vitem0, vitem1}}); + return QConverter::convert({{vitem0, vitem1}}); } } }; #else template -struct AddOp - : AddOpBase, FixupBase { +struct AddOp : AddOpBase, + FixupBase { using AddOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; @@ -238,26 +232,25 @@ struct AddOp AddOp(float src0_scale, float src1_scale, float dst_scale) : AddOpBase(src0_scale, src1_scale, dst_scale), FixupBase(scale0) {} - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_qint8* dst) const { vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { if (enable_opt_or_fixup) { - auto vitem0 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0]), - vmultiplier); - auto vitem1 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1]), - vmultiplier); + auto vitem0 = + vqrdmulhq_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0]), vmultiplier); + auto vitem1 = + vqrdmulhq_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1]), vmultiplier); // FIXME Theoretically, we should check shift != 0 here. 
auto fixup0 = vshrq_n_s32(vitem0, 31); auto fixup1 = vshrq_n_s32(vitem1, 31); vitem0 = vqaddq_s32(vitem0, fixup0); vitem1 = vqaddq_s32(vitem1, fixup1); - return vqmovn_s16( - vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)), - vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + return vqmovn_s16(vcombine_s16( + vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); } else { auto vitem0 = vaddq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), @@ -265,8 +258,7 @@ struct AddOp auto vitem1 = vaddq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[1]), this->vscale1)); - return QConverter::convert( - {{vitem0, vitem1}}); + return QConverter::convert({{vitem0, vitem1}}); } } }; @@ -279,13 +271,12 @@ struct AddOp constexpr static size_t SIMD_WIDTH = 4; using AddOpBase::operator(); - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_quint8* dst) const { vst1_u8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - uint8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + uint8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { if (enable_opt_or_fixup) { auto vitem0 = vmulq_f32( vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h index 2e777c2f..b501997b 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h @@ -10,8 +10,8 @@ */ #pragma once -#include "src/arm_common/elemwise_helper/kimpl/op_base.h" #include "src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h" +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" namespace megdnn { namespace arm_common { @@ -19,8 +19,8 @@ namespace arm_common { template struct FuseAddHSwishOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -30,46 +30,48 @@ struct FuseAddHSwishOpBase : BinaryOpBase { } }; -template +template < + typename src_ctype, typename dst_ctype = src_ctype, + bool enable_opt_or_fixup = false> struct FuseAddHSwishOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct FuseAddHSwishOp<_ctype> : FuseAddHSwishOpBase<_ctype> { \ - using FuseAddHSwishOpBase::FuseAddHSwishOpBase; \ - using FuseAddHSwishOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto val1 = src0.val[0]; \ - auto val2 = src0.val[1]; \ - auto val3 = src1.val[0]; \ - auto val4 = src1.val[1]; \ - val1 = vaddq_##_func_suffix(val1, val3); \ - val2 = vaddq_##_func_suffix(val2, val4); \ - H_SWISH_KERN(_func_suffix, val1, val2); \ - return {{val1, val2}}; \ - } \ - void operator()(const _neon_type& src0, const 
_neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - auto val1 = src0; \ - auto val2 = src1; \ - val1 = vaddq_##_func_suffix(val1, val2); \ - H_SWISH_KERN_N1(_func_suffix, val1); \ - return val1; \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddHSwishOp<_ctype> : FuseAddHSwishOpBase<_ctype> { \ + using FuseAddHSwishOpBase::FuseAddHSwishOpBase; \ + using FuseAddHSwishOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vaddq_##_func_suffix(val1, val3); \ + val2 = vaddq_##_func_suffix(val2, val4); \ + H_SWISH_KERN(_func_suffix, val1, val2); \ + return {{val1, val2}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + auto val1 = src0; \ + auto val2 = src1; \ + val1 = vaddq_##_func_suffix(val1, val2); \ + H_SWISH_KERN_N1(_func_suffix, val1); \ + return val1; \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -78,17 +80,15 @@ OP(__fp16, float16x8_t, float16x8x2_t, f16, 8) #undef OP template <> -struct FuseAddHSwishOpBase - : BinaryOpBase { +struct FuseAddHSwishOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint32& src0, const dt_qint32& src1, - dt_qint8* dst) const { + void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { - float tmp = src0.as_int32() * this->scale_src0 + - src1.as_int32() * this->scale_src1; + float tmp = + src0.as_int32() * this->scale_src0 + src1.as_int32() * this->scale_src1; tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; tmp *= this->scale_dst; return QConverter::convert(tmp); @@ -96,17 +96,16 @@ struct FuseAddHSwishOpBase }; template <> -struct FuseAddHSwishOpBase - : BinaryOpBase { +struct FuseAddHSwishOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint32& src0, const dt_qint32& src1, - dt_quint8* dst) const { + void operator()( + const dt_qint32& src0, const dt_qint32& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { - float tmp = src0.as_int32() * this->scale_src0 + - src1.as_int32() * this->scale_src1; + float tmp = + src0.as_int32() * this->scale_src0 + src1.as_int32() * this->scale_src1; tmp = tmp * std::max(std::min(tmp + 3.f, 6.f), 0.f) / 6.f; tmp *= this->scale_dst; return QConverter::convert(tmp, zp); @@ -119,13 +118,12 @@ struct FuseAddHSwishOp using FuseAddHSwishOpBase::FuseAddHSwishOpBase; using FuseAddHSwishOpBase::operator(); constexpr static size_t SIMD_WIDTH 
= 4; - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_qint8* dst) const { vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { float32x4_t vitem0, vitem1; if (enable_opt_or_fixup) { vitem0 = vmulq_f32( @@ -156,13 +154,12 @@ struct FuseAddHSwishOp using FuseAddHSwishOpBase::FuseAddHSwishOpBase; using FuseAddHSwishOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_quint8* dst) const { vst1_u8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - uint8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + uint8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { float32x4_t vitem0, vitem1; if (enable_opt_or_fixup) { vitem0 = vmulq_f32( @@ -184,8 +181,8 @@ struct FuseAddHSwishOp H_SWISH_KERN(f32, vitem0, vitem1); vitem0 = vmulq_f32(vitem0, this->vscale_dst); vitem1 = vmulq_f32(vitem1, this->vscale_dst); - return QConverter::convert({{vitem0, vitem1}}, - this->vzp); + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); } }; diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h index 7fe16b30..bd18d41a 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h @@ -10,8 +10,8 @@ */ #pragma once -#include "src/arm_common/elemwise_helper/kimpl/op_base.h" #include "src/arm_common/elemwise/neon_util_impl_helper.h" +#include "src/arm_common/elemwise_helper/kimpl/op_base.h" namespace megdnn { namespace arm_common { @@ -19,8 +19,8 @@ namespace arm_common { template struct FuseAddReluOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -29,43 +29,45 @@ struct FuseAddReluOpBase : BinaryOpBase { } }; -template +template < + typename src_ctype, typename dst_ctype = src_ctype, + bool enable_opt_or_fixup = false> struct FuseAddReluOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct FuseAddReluOp<_ctype> : FuseAddReluOpBase<_ctype> { \ - using FuseAddReluOpBase::FuseAddReluOpBase; \ - using FuseAddReluOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto val1 = src0.val[0]; \ - auto val2 = src0.val[1]; \ - auto val3 = src1.val[0]; \ - auto val4 = src1.val[1]; \ - FUSE_ADD_RELU_NEON_PACK2(val1, val2, val3, val4, _func_suffix); \ - return {{val1, val2}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* 
dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - auto val1 = src0; \ - auto val2 = src1; \ - FUSE_ADD_RELU_NEON_PACK(val1, val2, _func_suffix); \ - return val1; \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddReluOp<_ctype> : FuseAddReluOpBase<_ctype> { \ + using FuseAddReluOpBase::FuseAddReluOpBase; \ + using FuseAddReluOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + FUSE_ADD_RELU_NEON_PACK2(val1, val2, val3, val4, _func_suffix); \ + return {{val1, val2}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + auto val1 = src0; \ + auto val2 = src1; \ + FUSE_ADD_RELU_NEON_PACK(val1, val2, _func_suffix); \ + return val1; \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -90,24 +92,20 @@ struct FuseAddReluOpCommon { }; template <> -struct FuseAddReluOpBase - : BinaryOpBase { +struct FuseAddReluOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint8& src0, const dt_qint8& src1, - dt_qint8* dst) const { + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { return QConverter::convert(std::max( - src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1, - 0.f)); + src0.as_int8() * this->scale0 + src1.as_int8() * this->scale1, 0.f)); } }; template <> -struct FuseAddReluOpBase - : BinaryOpBase { +struct FuseAddReluOpBase : BinaryOpBase { float szp; float32x4_t vszp; @@ -116,35 +114,34 @@ struct FuseAddReluOpBase szp = this->szp0 + this->szp1; vszp = vdupq_n_f32(szp); } - void operator()(const dt_quint8& src0, const dt_quint8& src1, - dt_quint8* dst) const { + void operator()( + const dt_quint8& src0, const dt_quint8& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { return QConverter::convert( - std::max(src0.as_uint8() * this->scale0 + - src1.as_uint8() * this->scale1 - - this->szp, - 0.f), + std::max( + src0.as_uint8() * this->scale0 + + src1.as_uint8() * this->scale1 - this->szp, + 0.f), this->dzp); } }; template <> -struct FuseAddReluOp - : FuseAddReluOpBase, FuseAddReluOpCommon { +struct FuseAddReluOp : FuseAddReluOpBase, + FuseAddReluOpCommon { using FuseAddReluOpBase::FuseAddReluOpBase; using FuseAddReluOpBase::operator(); constexpr static size_t SIMD_WIDTH = 16; - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8; } - 
int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { auto vitem0 = vaddq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); @@ -159,19 +156,19 @@ struct FuseAddReluOp }; template <> -struct FuseAddReluOp - : FuseAddReluOpBase, FuseAddReluOpCommon { +struct FuseAddReluOp : FuseAddReluOpBase, + FuseAddReluOpCommon { using FuseAddReluOpBase::FuseAddReluOpBase; using FuseAddReluOpBase::operator(); constexpr static size_t SIMD_WIDTH = 16; - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { OPERATOR_BINARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, - const uint32x4x2_t& vsrc1) const { + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1) const { auto vitem0 = vsubq_f32( vaddq_f32( vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), @@ -185,41 +182,37 @@ struct FuseAddReluOp vitem0 = vmaxq_f32(vitem0, this->vzero()); vitem1 = vmaxq_f32(vitem1, this->vzero()); - return QConverter::convert({{vitem0, vitem1}}, - this->vdzp); + return QConverter::convert( + {{vitem0, vitem1}}, this->vdzp); } }; template <> -struct FuseAddReluOpBase - : BinaryOpBase { +struct FuseAddReluOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint32& src0, const dt_qint32& src1, - dt_qint8* dst) const { + void operator()(const dt_qint32& src0, const dt_qint32& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } dt_qint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { return QConverter::convert(std::max( - src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, - 0.f)); + src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, 0.f)); } }; template <> -struct FuseAddReluOpBase - : BinaryOpBase { +struct FuseAddReluOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint32& src0, const dt_qint32& src1, - dt_quint8* dst) const { + void operator()( + const dt_qint32& src0, const dt_qint32& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_qint32& src0, const dt_qint32& src1) const { return QConverter::convert( - std::max(src0.as_int32() * this->scale0 + - src1.as_int32() * this->scale1, - 0.f), + std::max( + src0.as_int32() * this->scale0 + src1.as_int32() * this->scale1, + 0.f), zp); } }; @@ -231,13 +224,12 @@ struct FuseAddReluOp using FuseAddReluOpBase::FuseAddReluOpBase; using FuseAddReluOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_qint8* dst) const { vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { if (enable_opt_or_fixup) { auto vitem0 = vmulq_f32( vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), @@ -247,8 +239,7 @@ struct FuseAddReluOp this->vscale0); vitem0 = vmaxq_f32(vitem0, this->vzero()); vitem1 = vmaxq_f32(vitem1, this->vzero()); - return QConverter::convert( - {{vitem0, vitem1}}); + return QConverter::convert({{vitem0, vitem1}}); } else { auto 
vitem0 = vaddq_f32( @@ -260,8 +251,7 @@ struct FuseAddReluOp vitem0 = vmaxq_f32(vitem0, this->vzero()); vitem1 = vmaxq_f32(vitem1, this->vzero()); - return QConverter::convert( - {{vitem0, vitem1}}); + return QConverter::convert({{vitem0, vitem1}}); } } }; @@ -274,30 +264,27 @@ struct FuseAddReluOp using FuseAddReluOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; FuseAddReluOp(DType src0_dtype, DType src1_dtype, DType dst_dtype) - : FuseAddReluOpBase(src0_dtype, src1_dtype, dst_dtype), - FixupBase(scale0) {} + : FuseAddReluOpBase(src0_dtype, src1_dtype, dst_dtype), FixupBase(scale0) {} FuseAddReluOp(float src0_scale, float src1_scale, float dst_scale) - : FuseAddReluOpBase(src0_scale, src1_scale, dst_scale), - FixupBase(scale0) {} + : FuseAddReluOpBase(src0_scale, src1_scale, dst_scale), FixupBase(scale0) {} - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_qint8* dst) const { vst1_s8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { if (enable_opt_or_fixup) { - auto vitem0 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0]), - vmultiplier); - auto vitem1 = vqrdmulhq_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1]), - vmultiplier); + auto vitem0 = + vqrdmulhq_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0]), vmultiplier); + auto vitem1 = + vqrdmulhq_s32(vaddq_s32(vsrc0.val[1], vsrc1.val[1]), vmultiplier); vitem0 = vmaxq_s32(vitem0, FuseAddReluOpCommon::vzero()); vitem1 = vmaxq_s32(vitem1, FuseAddReluOpCommon::vzero()); - return vqmovn_s16( - vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)), - vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + return vqmovn_s16(vcombine_s16( + vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); } else { auto vitem0 = vaddq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), @@ -308,8 +295,7 @@ struct FuseAddReluOp vitem0 = vmaxq_f32(vitem0, this->vzero()); vitem1 = vmaxq_f32(vitem1, this->vzero()); - return QConverter::convert( - {{vitem0, vitem1}}); + return QConverter::convert({{vitem0, vitem1}}); } } }; @@ -321,13 +307,12 @@ struct FuseAddReluOp using FuseAddReluOpBase::FuseAddReluOpBase; using FuseAddReluOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; - void operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, dt_quint8* dst) const { vst1_u8(reinterpret_cast(dst), operator()(vsrc0, vsrc1)); } - uint8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + uint8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { if (enable_opt_or_fixup) { auto vitem0 = vmulq_f32( vcvtq_f32_s32(vaddq_s32(vsrc0.val[0], vsrc1.val[0])), diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h index b3b33c0e..56c11219 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h @@ -18,8 +18,8 @@ namespace arm_common { template struct FuseAddSigmoidOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, 
const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -33,42 +33,40 @@ struct FuseAddSigmoidOpBase : BinaryOpBase { template struct FuseAddSigmoidOp; -#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ - template <> \ - struct FuseAddSigmoidOp<_ctype> : FuseAddSigmoidOpBase<_ctype> { \ - using FuseAddSigmoidOpBase::FuseAddSigmoidOpBase; \ - using FuseAddSigmoidOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - auto zero_val = vdupq_n_##_func_suffix(0.f); \ - auto one_val = vdupq_n_##_func_suffix(1.f); \ - auto val1 = src0.val[0]; \ - auto val2 = src0.val[1]; \ - auto val3 = src1.val[0]; \ - auto val4 = src1.val[1]; \ - val1 = vaddq_##_func_suffix(val1, val3); \ - val2 = vaddq_##_func_suffix(val2, val4); \ - val1 = vsubq_##_func_suffix(zero_val, val1); \ - val2 = vsubq_##_func_suffix(zero_val, val2); \ - val1 = exp_ps_##_func_suffix(val1); \ - val2 = exp_ps_##_func_suffix(val2); \ - auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ - auto recipe2 = vaddq_##_func_suffix(one_val, val2); \ - val1 = vrecpeq_##_func_suffix(recipe1); \ - val2 = vrecpeq_##_func_suffix(recipe2); \ - val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), \ - val1); \ - val2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe2, val2), \ - val2); \ - return {{val1, val2}}; \ - } \ +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddSigmoidOp<_ctype> : FuseAddSigmoidOpBase<_ctype> { \ + using FuseAddSigmoidOpBase::FuseAddSigmoidOpBase; \ + using FuseAddSigmoidOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + auto zero_val = vdupq_n_##_func_suffix(0.f); \ + auto one_val = vdupq_n_##_func_suffix(1.f); \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vaddq_##_func_suffix(val1, val3); \ + val2 = vaddq_##_func_suffix(val2, val4); \ + val1 = vsubq_##_func_suffix(zero_val, val1); \ + val2 = vsubq_##_func_suffix(zero_val, val2); \ + val1 = exp_ps_##_func_suffix(val1); \ + val2 = exp_ps_##_func_suffix(val2); \ + auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ + auto recipe2 = vaddq_##_func_suffix(one_val, val2); \ + val1 = vrecpeq_##_func_suffix(recipe1); \ + val2 = vrecpeq_##_func_suffix(recipe2); \ + val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \ + val2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe2, val2), val2); \ + return {{val1, val2}}; \ + } \ }; OP(dt_float32, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h index 3716c49b..32d42f3b 100644 --- 
a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h @@ -18,8 +18,8 @@ namespace arm_common { template struct FuseAddTanhOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -32,48 +32,44 @@ struct FuseAddTanhOpBase : BinaryOpBase { template struct FuseAddTanhOp; -#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ - template <> \ - struct FuseAddTanhOp<_ctype> : FuseAddTanhOpBase<_ctype> { \ - using FuseAddTanhOpBase::FuseAddTanhOpBase; \ - using FuseAddTanhOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - auto val1 = src0.val[0]; \ - auto val2 = src0.val[1]; \ - auto val3 = src1.val[0]; \ - auto val4 = src1.val[1]; \ - val1 = vaddq_##_func_suffix(val1, val3); \ - val2 = vaddq_##_func_suffix(val2, val4); \ - auto exp1 = exp_ps_##_func_suffix(val1); \ - auto exp2 = exp_ps_##_func_suffix(val2); \ - auto rexp1 = vrecpeq_##_func_suffix(exp1); \ - auto rexp2 = vrecpeq_##_func_suffix(exp2); \ - rexp1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp1, rexp1), \ - rexp1); \ - rexp2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp2, rexp2), \ - rexp2); \ - val1 = vsubq_##_func_suffix(exp1, rexp1); \ - val2 = vsubq_##_func_suffix(exp2, rexp2); \ - exp1 = vaddq_##_func_suffix(exp1, rexp1); \ - exp2 = vaddq_##_func_suffix(exp2, rexp2); \ - rexp1 = vrecpeq_##_func_suffix(exp1); \ - rexp2 = vrecpeq_##_func_suffix(exp2); \ - rexp1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp1, rexp1), \ - rexp1); \ - rexp2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp2, rexp2), \ - rexp2); \ - val1 = vmulq_##_func_suffix(val1, rexp1); \ - val2 = vmulq_##_func_suffix(val2, rexp2); \ - return {{val1, val2}}; \ - } \ +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseAddTanhOp<_ctype> : FuseAddTanhOpBase<_ctype> { \ + using FuseAddTanhOpBase::FuseAddTanhOpBase; \ + using FuseAddTanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vaddq_##_func_suffix(val1, val3); \ + val2 = vaddq_##_func_suffix(val2, val4); \ + auto exp1 = exp_ps_##_func_suffix(val1); \ + auto exp2 = exp_ps_##_func_suffix(val2); \ + auto rexp1 = vrecpeq_##_func_suffix(exp1); \ + auto rexp2 = vrecpeq_##_func_suffix(exp2); \ + rexp1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp1, rexp1), rexp1); \ + rexp2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp2, rexp2), rexp2); \ + val1 = vsubq_##_func_suffix(exp1, 
rexp1); \ + val2 = vsubq_##_func_suffix(exp2, rexp2); \ + exp1 = vaddq_##_func_suffix(exp1, rexp1); \ + exp2 = vaddq_##_func_suffix(exp2, rexp2); \ + rexp1 = vrecpeq_##_func_suffix(exp1); \ + rexp2 = vrecpeq_##_func_suffix(exp2); \ + rexp1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp1, rexp1), rexp1); \ + rexp2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(exp2, rexp2), rexp2); \ + val1 = vmulq_##_func_suffix(val1, rexp1); \ + val2 = vmulq_##_func_suffix(val2, rexp2); \ + return {{val1, val2}}; \ + } \ }; OP(dt_float32, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h index cb30c3b3..90bd1404 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/fuse_mul_add3.h @@ -18,13 +18,14 @@ namespace arm_common { template struct FuseMulAdd3OpBase : TernaryOpBase { using TernaryOpBase::TernaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - const src_ctype src2, dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, const src_ctype src2, + dst_ctype* dst) const { *dst = operator()(src0, src1, src2); } - dst_ctype operator()(const src_ctype& src0, const src_ctype& src1, - const src_ctype& src2) const { + dst_ctype operator()( + const src_ctype& src0, const src_ctype& src1, const src_ctype& src2) const { return (src0 * src1) + src2; } }; @@ -32,26 +33,26 @@ struct FuseMulAdd3OpBase : TernaryOpBase { template struct FuseMulAdd3Op; -#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ - template <> \ - struct FuseMulAdd3Op<_ctype> : FuseMulAdd3OpBase<_ctype> { \ - using FuseMulAdd3OpBase::FuseMulAdd3OpBase; \ - using FuseMulAdd3OpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - const _neon_type& src2, dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1, src2); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type operator()(const _neon_type& src0, const _neon_type& src1, \ - const _neon_type& src2) const { \ - auto vitem0 = vmlaq_##_func_suffix(src2.val[0], src0.val[0], \ - src1.val[0]); \ - auto vitem1 = vmlaq_##_func_suffix(src2.val[1], src0.val[1], \ - src1.val[1]); \ - return {{vitem0, vitem1}}; \ - } \ +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct FuseMulAdd3Op<_ctype> : FuseMulAdd3OpBase<_ctype> { \ + using FuseMulAdd3OpBase::FuseMulAdd3OpBase; \ + using FuseMulAdd3OpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + const _neon_type& src2, dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1, src2); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + const _neon_type& src2) const { \ + auto vitem0 = vmlaq_##_func_suffix(src2.val[0], src0.val[0], src1.val[0]); \ + auto vitem1 = vmlaq_##_func_suffix(src2.val[1], src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ }; OP(dt_float32, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h b/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h 
index 22f5daae..56a16a18 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/hswish.h @@ -34,39 +34,39 @@ struct HSwishOpBase : UnaryOpBase { template struct HSwishOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \ - using HSwishOpBase::HSwishOpBase; \ - using HSwishOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src, _ctype* dst) const { \ - auto vitem = operator()(src); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - void operator()(const _neon_type& src, _ctype* dst) const { \ - auto vitem = operator()(src); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type2 operator()(const _neon_type2& src) const { \ - auto val1 = src.val[0]; \ - auto val2 = src.val[1]; \ - H_SWISH_KERN(_func_suffix, val1, val2); \ - return {{val1, val2}}; \ - } \ - _neon_type operator()(const _neon_type& src) const { \ - auto val_zero = vdupq_n_##_func_suffix(0.f); \ - auto val_six = vdupq_n_##_func_suffix(6.f); \ - auto val_three = vdupq_n_##_func_suffix(3.f); \ - auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ - auto clip1 = vmaxq_##_func_suffix( \ - vminq_##_func_suffix(vaddq_##_func_suffix(src, val_three), \ - val_six), \ - val_zero); \ - return vmulq_##_func_suffix(vmulq_##_func_suffix(src, clip1), \ - val_rec_six); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct HSwishOp<_ctype> : HSwishOpBase<_ctype> { \ + using HSwishOpBase::HSwishOpBase; \ + using HSwishOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type2 operator()(const _neon_type2& src) const { \ + auto val1 = src.val[0]; \ + auto val2 = src.val[1]; \ + H_SWISH_KERN(_func_suffix, val1, val2); \ + return {{val1, val2}}; \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto val_zero = vdupq_n_##_func_suffix(0.f); \ + auto val_six = vdupq_n_##_func_suffix(6.f); \ + auto val_three = vdupq_n_##_func_suffix(3.f); \ + auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ + auto clip1 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix( \ + vaddq_##_func_suffix(src, val_three), val_six), \ + val_zero); \ + return vmulq_##_func_suffix( \ + vmulq_##_func_suffix(src, clip1), val_rec_six); \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) @@ -115,8 +115,8 @@ struct HSwishOp : HSwishOpBase { vst1_s8(reinterpret_cast(dst), operator()(vsrc)); } void operator()(const int32x4_t& vsrc, dt_qint8* dst) const { - vst1_lane_s32(reinterpret_cast(dst), - (int32x2_t)(operator()(vsrc)), 0); + vst1_lane_s32( + reinterpret_cast(dst), (int32x2_t)(operator()(vsrc)), 0); } int8x8_t operator()(const int32x4x2_t& vsrc) const { @@ -156,8 +156,8 @@ struct HSwishOp : HSwishOpBase { vitem0 = vmulq_f32(vitem0, this->vscale_dst); vitem1 = vmulq_f32(vitem1, this->vscale_dst); - return QConverter::convert({{vitem0, vitem1}}, - this->vzp); + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); } }; diff --git 
a/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h b/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h index 52843c95..835e252b 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/kern_macro_prologue.h @@ -10,38 +10,32 @@ * implied. */ -#define H_SWISH_KERN(_func_suffix, _val1, _val2) \ - do { \ - auto val_zero = vdupq_n_##_func_suffix(0.f); \ - auto val_six = vdupq_n_##_func_suffix(6.f); \ - auto val_three = vdupq_n_##_func_suffix(3.f); \ - auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ - auto clip1 = vmaxq_##_func_suffix( \ - vminq_##_func_suffix(vaddq_##_func_suffix(_val1, val_three), \ - val_six), \ - val_zero); \ - auto clip2 = vmaxq_##_func_suffix( \ - vminq_##_func_suffix(vaddq_##_func_suffix(_val2, val_three), \ - val_six), \ - val_zero); \ - _val1 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val1, clip1), \ - val_rec_six); \ - _val2 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val2, clip2), \ - val_rec_six); \ +#define H_SWISH_KERN(_func_suffix, _val1, _val2) \ + do { \ + auto val_zero = vdupq_n_##_func_suffix(0.f); \ + auto val_six = vdupq_n_##_func_suffix(6.f); \ + auto val_three = vdupq_n_##_func_suffix(3.f); \ + auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ + auto clip1 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix(vaddq_##_func_suffix(_val1, val_three), val_six), \ + val_zero); \ + auto clip2 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix(vaddq_##_func_suffix(_val2, val_three), val_six), \ + val_zero); \ + _val1 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val1, clip1), val_rec_six); \ + _val2 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val2, clip2), val_rec_six); \ } while (0); -#define H_SWISH_KERN_N1(_func_suffix, _val1) \ - do { \ - auto val_zero = vdupq_n_##_func_suffix(0.f); \ - auto val_six = vdupq_n_##_func_suffix(6.f); \ - auto val_three = vdupq_n_##_func_suffix(3.f); \ - auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ - auto clip1 = vmaxq_##_func_suffix( \ - vminq_##_func_suffix(vaddq_##_func_suffix(_val1, val_three), \ - val_six), \ - val_zero); \ - _val1 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val1, clip1), \ - val_rec_six); \ +#define H_SWISH_KERN_N1(_func_suffix, _val1) \ + do { \ + auto val_zero = vdupq_n_##_func_suffix(0.f); \ + auto val_six = vdupq_n_##_func_suffix(6.f); \ + auto val_three = vdupq_n_##_func_suffix(3.f); \ + auto val_rec_six = vdupq_n_##_func_suffix(1.f / 6.f); \ + auto clip1 = vmaxq_##_func_suffix( \ + vminq_##_func_suffix(vaddq_##_func_suffix(_val1, val_three), val_six), \ + val_zero); \ + _val1 = vmulq_##_func_suffix(vmulq_##_func_suffix(_val1, clip1), val_rec_six); \ } while (0); // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/max.h b/dnn/src/arm_common/elemwise_helper/kimpl/max.h index 1dce3896..073f3a86 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/max.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/max.h @@ -17,8 +17,8 @@ namespace arm_common { template struct MaxOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -29,33 +29,34 @@ struct MaxOpBase : BinaryOpBase { template struct MaxOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template 
<> \ - struct MaxOp<_ctype> : MaxOpBase<_ctype> { \ - using MaxOpBase::MaxOpBase; \ - using MaxOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto vitem0 = vmaxq_##_func_suffix(src0.val[0], src1.val[0]); \ - auto vitem1 = vmaxq_##_func_suffix(src0.val[1], src1.val[1]); \ - return {{vitem0, vitem1}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - return vmaxq_##_func_suffix(src0, src1); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct MaxOp<_ctype> : MaxOpBase<_ctype> { \ + using MaxOpBase::MaxOpBase; \ + using MaxOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto vitem0 = vmaxq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vmaxq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + return vmaxq_##_func_suffix(src0, src1); \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -72,16 +73,15 @@ struct MaxOpBase : BinaryOpBase { using dst_ctype = dt_qint8; using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { float fsrc0 = src0.as_int8() * this->scale0; float fsrc1 = src1.as_int8() * this->scale1; - return QConverter::convert(fsrc0 > fsrc1 ? fsrc0 - : fsrc1); + return QConverter::convert(fsrc0 > fsrc1 ? 
fsrc0 : fsrc1); } }; @@ -91,8 +91,8 @@ struct MaxOpBase : BinaryOpBase { using dst_ctype = dt_quint8; using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } @@ -110,13 +110,12 @@ struct MaxOp : MaxOpBase { constexpr static size_t SIMD_WIDTH = 16; using MaxOpBase::operator(); - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8; } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { auto vitem0 = vmaxq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); @@ -133,26 +132,22 @@ struct MaxOp : MaxOpBase { constexpr static size_t SIMD_WIDTH = 16; using MaxOpBase::operator(); - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { OPERATOR_BINARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, - const uint32x4x2_t vsrc1) const { - auto vsrct0 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), - this->vszp0); - auto vsrct1 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), - this->vszp1); + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t vsrc1) const { + auto vsrct0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), this->vszp0); + auto vsrct1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), this->vszp1); auto vitem0 = vmaxq_f32(vsrct0, vsrct1); - vsrct0 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), - this->vszp0); - vsrct1 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), - this->vszp1); + vsrct0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), this->vszp0); + vsrct1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), this->vszp1); auto vitem1 = vmaxq_f32(vsrct0, vsrct1); return QConverter::convert( {{vitem0, vitem1}}, this->vdzp); diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/min.h b/dnn/src/arm_common/elemwise_helper/kimpl/min.h index 9abb4ef3..80bf3e0d 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/min.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/min.h @@ -18,8 +18,8 @@ namespace arm_common { template struct MinOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -30,33 +30,34 @@ struct MinOpBase : BinaryOpBase { template struct MinOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct MinOp<_ctype> : MinOpBase<_ctype> { \ - using MinOpBase::MinOpBase; \ - using MinOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - 
vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto vitem0 = vminq_##_func_suffix(src0.val[0], src1.val[0]); \ - auto vitem1 = vminq_##_func_suffix(src0.val[1], src1.val[1]); \ - return {{vitem0, vitem1}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - return vminq_##_func_suffix(src0, src1); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct MinOp<_ctype> : MinOpBase<_ctype> { \ + using MinOpBase::MinOpBase; \ + using MinOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto vitem0 = vminq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vminq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + return vminq_##_func_suffix(src0, src1); \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -70,24 +71,22 @@ OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) template <> struct MinOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint8& src0, const dt_qint8& src1, - dt_qint8* dst) const { + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { float fsrc0 = src0.as_int8() * this->scale0; float fsrc1 = src1.as_int8() * this->scale1; - return QConverter::convert(fsrc0 < fsrc1 ? fsrc0 - : fsrc1); + return QConverter::convert(fsrc0 < fsrc1 ? 
fsrc0 : fsrc1); } }; template <> struct MinOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_quint8& src0, const dt_quint8& src1, - dt_quint8* dst) const { + void operator()( + const dt_quint8& src0, const dt_quint8& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { @@ -104,13 +103,12 @@ struct MinOp : MinOpBase { constexpr static size_t SIMD_WIDTH = 16; using MinOpBase::operator(); - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8; } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { auto vitem0 = vminq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); @@ -127,26 +125,22 @@ struct MinOp : MinOpBase { constexpr static size_t SIMD_WIDTH = 16; using MinOpBase::operator(); - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { OPERATOR_BINARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, - const uint32x4x2_t& vsrc1) const { - auto vsrct0 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), - this->vszp0); - auto vsrct1 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), - this->vszp1); + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1) const { + auto vsrct0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), this->vszp0); + auto vsrct1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), this->vszp1); auto vitem0 = vminq_f32(vsrct0, vsrct1); - vsrct0 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), - this->vszp0); - vsrct1 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), - this->vszp1); + vsrct0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), this->vszp0); + vsrct1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), this->vszp1); auto vitem1 = vminq_f32(vsrct0, vsrct1); return QConverter::convert( {{vitem0, vitem1}}, this->vdzp); diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/mul.h b/dnn/src/arm_common/elemwise_helper/kimpl/mul.h index ac57f94e..6d9dbe48 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/mul.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/mul.h @@ -18,8 +18,8 @@ namespace arm_common { template struct MulOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -30,33 +30,34 @@ struct MulOpBase : BinaryOpBase { template struct MulOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct MulOp<_ctype> : MulOpBase<_ctype> { \ - using MulOpBase::MulOpBase; \ - using MulOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = 
operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto vitem0 = vmulq_##_func_suffix(src0.val[0], src1.val[0]); \ - auto vitem1 = vmulq_##_func_suffix(src0.val[1], src1.val[1]); \ - return {{vitem0, vitem1}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - return vmulq_##_func_suffix(src0, src1); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct MulOp<_ctype> : MulOpBase<_ctype> { \ + using MulOpBase::MulOpBase; \ + using MulOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto vitem0 = vmulq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vmulq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + return vmulq_##_func_suffix(src0, src1); \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -70,8 +71,7 @@ OP(dt_int8, int8x16_t, int8x16x2_t, s8, 16) template <> struct MulOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint8& src0, const dt_qint8& src1, - dt_qint8* dst) const { + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } @@ -84,8 +84,8 @@ struct MulOpBase : BinaryOpBase { template <> struct MulOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_quint8& src0, const dt_quint8& src1, - dt_quint8* dst) const { + void operator()( + const dt_quint8& src0, const dt_quint8& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { @@ -102,13 +102,12 @@ struct MulOp : MulOpBase { constexpr static size_t SIMD_WIDTH = 16; using MulOpBase::operator(); - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8; } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { auto vitem0 = vmulq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); @@ -126,12 +125,12 @@ struct MulOp : MulOpBase { constexpr static size_t SIMD_WIDTH = 16; using MulOpBase::operator(); - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const 
uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { OPERATOR_BINARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, - const uint32x4x2_t vsrc1) const { + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t vsrc1) const { auto vfsrc0 = vsubq_f32( vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale_src0), this->vscale_zp0); diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/op_base.h b/dnn/src/arm_common/elemwise_helper/kimpl/op_base.h index 7345208e..81317686 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/op_base.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/op_base.h @@ -36,39 +36,47 @@ struct UnaryOpBase : OpBase { UnaryOpBase(DType /*src_dtype*/, DType /*dst_dtype*/) {} }; -#define OPERATOR_UNARY_QINT8 \ - int16x8_t vsrct = vmovl_low_s8(vsrc.val[0]); \ - vst1_s8(reinterpret_cast(dst), \ - operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); \ - \ - vsrct = vmovl_high_s8(vsrc.val[0]); \ - vst1_s8(reinterpret_cast(dst + 8), \ - operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); \ - \ - vsrct = vmovl_low_s8(vsrc.val[1]); \ - vst1_s8(reinterpret_cast(dst + 16), \ - operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); \ - \ - vsrct = vmovl_high_s8(vsrc.val[1]); \ - vst1_s8(reinterpret_cast(dst + 24), \ - operator()({{vmovl_low_s16(vsrct), vmovl_high_s16(vsrct)}})); - -#define OPERATOR_UNARY_QUINT8 \ - uint16x8_t vsrct = vmovl_low_u8(vsrc.val[0]); \ - vst1_u8(reinterpret_cast(dst), \ - operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); \ - \ - vsrct = vmovl_high_u8(vsrc.val[0]); \ - vst1_u8(reinterpret_cast(dst + 8), \ - operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); \ - \ - vsrct = vmovl_low_u8(vsrc.val[1]); \ - vst1_u8(reinterpret_cast(dst + 16), \ - operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); \ - \ - vsrct = vmovl_high_u8(vsrc.val[1]); \ - vst1_u8(reinterpret_cast(dst + 24), \ - operator()({{vmovl_low_u16(vsrct), vmovl_high_u16(vsrct)}})); +#define OPERATOR_UNARY_QINT8 \ + int16x8_t vsrct = vmovl_low_s8(vsrc.val[0]); \ + vst1_s8(reinterpret_cast(dst), operator()( \ + {{vmovl_low_s16(vsrct), \ + vmovl_high_s16(vsrct)}})); \ + \ + vsrct = vmovl_high_s8(vsrc.val[0]); \ + vst1_s8(reinterpret_cast(dst + 8), operator()( \ + {{vmovl_low_s16(vsrct), \ + vmovl_high_s16(vsrct)}})); \ + \ + vsrct = vmovl_low_s8(vsrc.val[1]); \ + vst1_s8(reinterpret_cast(dst + 16), operator()( \ + {{vmovl_low_s16(vsrct), \ + vmovl_high_s16(vsrct)}})); \ + \ + vsrct = vmovl_high_s8(vsrc.val[1]); \ + vst1_s8(reinterpret_cast(dst + 24), operator()( \ + {{vmovl_low_s16(vsrct), \ + vmovl_high_s16(vsrct)}})); + +#define OPERATOR_UNARY_QUINT8 \ + uint16x8_t vsrct = vmovl_low_u8(vsrc.val[0]); \ + vst1_u8(reinterpret_cast(dst), operator()( \ + {{vmovl_low_u16(vsrct), \ + vmovl_high_u16(vsrct)}})); \ + \ + vsrct = vmovl_high_u8(vsrc.val[0]); \ + vst1_u8(reinterpret_cast(dst + 8), operator()( \ + {{vmovl_low_u16(vsrct), \ + vmovl_high_u16(vsrct)}})); \ + \ + vsrct = vmovl_low_u8(vsrc.val[1]); \ + vst1_u8(reinterpret_cast(dst + 16), operator()( \ + {{vmovl_low_u16(vsrct), \ + vmovl_high_u16(vsrct)}})); \ + \ + vsrct = vmovl_high_u8(vsrc.val[1]); \ + vst1_u8(reinterpret_cast(dst + 24), operator()( \ + {{vmovl_low_u16(vsrct), \ + vmovl_high_u16(vsrct)}})); //! scale_src = src.scale; scale_dst = 1.f / dst.scale (div -> mul) //! 
scale = src.scale / dst.scale @@ -94,12 +102,9 @@ struct UnaryOpBase : OpBase { float dst_scale = dst_dtype.param().scale; init(src_scale, dst_scale); } - UnaryOpBase(float src_scale, float dst_scale) { - init(src_scale, dst_scale); - } + UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); } }; - //! scale_src = src.scale; scale_dst = 1.f / dst.scale //! scale_zp = src.zp * src.scale; dzp = dst.zp //! scale = src.scale / dst.scale; szp = src.zp * scale @@ -115,8 +120,7 @@ struct UnaryOpBase : OpBase { float scale, szp; float32x4_t vscale, vszp; - void init(float src_scale, float dst_scale, uint8_t src_zp, - uint8_t dst_zp) { + void init(float src_scale, float dst_scale, uint8_t src_zp, uint8_t dst_zp) { scale_src = src_scale; scale_dst = 1.f / dst_scale; vscale_src = vdupq_n_f32(scale_src); @@ -137,8 +141,7 @@ struct UnaryOpBase : OpBase { uint8_t dst_zp = dst_dtype.param().zero_point; init(src_scale, dst_scale, src_zp, dst_zp); } - UnaryOpBase(float src_scale, float dst_scale, uint8_t src_zp, - uint8_t dst_zp) { + UnaryOpBase(float src_scale, float dst_scale, uint8_t src_zp, uint8_t dst_zp) { init(src_scale, dst_scale, src_zp, dst_zp); } float32x4x2_t cvt_to_float(const uint32x4x2_t& vsrc) { @@ -192,9 +195,7 @@ struct UnaryOpBase : OpBase { init(src_scale, dst_scale); } - UnaryOpBase(float src_scale, float dst_scale) { - init(src_scale, dst_scale); - } + UnaryOpBase(float src_scale, float dst_scale) { init(src_scale, dst_scale); } }; template <> @@ -237,59 +238,73 @@ template struct BinaryOpBase : OpBase { using OpBase::OpBase; BinaryOpBase() = default; - BinaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, - DType /*dst_dtype*/) {} + BinaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*dst_dtype*/) {} }; -#define OPERATOR_BINARY_QINT8 \ - int16x8_t vsrct0 = vmovl_low_s8(vsrc0.val[0]); \ - int16x8_t vsrct1 = vmovl_low_s8(vsrc1.val[0]); \ - vst1_s8(reinterpret_cast(dst), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})); \ - \ - vsrct0 = vmovl_high_s8(vsrc0.val[0]); \ - vsrct1 = vmovl_high_s8(vsrc1.val[0]); \ - vst1_s8(reinterpret_cast(dst + 8), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})); \ - \ - vsrct0 = vmovl_low_s8(vsrc0.val[1]); \ - vsrct1 = vmovl_low_s8(vsrc1.val[1]); \ - vst1_s8(reinterpret_cast(dst + 16), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})); \ - \ - vsrct0 = vmovl_high_s8(vsrc0.val[1]); \ - vsrct1 = vmovl_high_s8(vsrc1.val[1]); \ - vst1_s8(reinterpret_cast(dst + 24), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}})) - -#define OPERATOR_BINARY_QUINT8 \ - uint16x8_t vsrct0 = vmovl_low_u8(vsrc0.val[0]); \ - uint16x8_t vsrct1 = vmovl_low_u8(vsrc1.val[0]); \ - vst1_u8(reinterpret_cast(dst), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})); \ - \ - vsrct0 = vmovl_high_u8(vsrc0.val[0]); \ - vsrct1 = vmovl_high_u8(vsrc1.val[0]); \ - vst1_u8(reinterpret_cast(dst + 8), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})); \ - \ - vsrct0 = vmovl_low_u8(vsrc0.val[1]); \ - vsrct1 = vmovl_low_u8(vsrc1.val[1]); \ - vst1_u8(reinterpret_cast(dst + 16), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, 
\ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})); \ - \ - vsrct0 = vmovl_high_u8(vsrc0.val[1]); \ - vsrct1 = vmovl_high_u8(vsrc1.val[1]); \ - vst1_u8(reinterpret_cast(dst + 24), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})) +#define OPERATOR_BINARY_QINT8 \ + int16x8_t vsrct0 = vmovl_low_s8(vsrc0.val[0]); \ + int16x8_t vsrct1 = vmovl_low_s8(vsrc1.val[0]); \ + vst1_s8(reinterpret_cast(dst), operator()( \ + {{vmovl_low_s16(vsrct0), \ + vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), \ + vmovl_high_s16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[0]); \ + vst1_s8(reinterpret_cast(dst + 8), operator()( \ + {{vmovl_low_s16(vsrct0), \ + vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), \ + vmovl_high_s16(vsrct1)}})); \ + \ + vsrct0 = vmovl_low_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_s8(vsrc1.val[1]); \ + vst1_s8(reinterpret_cast(dst + 16), operator()( \ + {{vmovl_low_s16(vsrct0), \ + vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), \ + vmovl_high_s16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[1]); \ + vst1_s8(reinterpret_cast(dst + 24), operator()( \ + {{vmovl_low_s16(vsrct0), \ + vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), \ + vmovl_high_s16(vsrct1)}})) + +#define OPERATOR_BINARY_QUINT8 \ + uint16x8_t vsrct0 = vmovl_low_u8(vsrc0.val[0]); \ + uint16x8_t vsrct1 = vmovl_low_u8(vsrc1.val[0]); \ + vst1_u8(reinterpret_cast(dst), operator()( \ + {{vmovl_low_u16(vsrct0), \ + vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), \ + vmovl_high_u16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[0]); \ + vst1_u8(reinterpret_cast(dst + 8), operator()( \ + {{vmovl_low_u16(vsrct0), \ + vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), \ + vmovl_high_u16(vsrct1)}})); \ + \ + vsrct0 = vmovl_low_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_u8(vsrc1.val[1]); \ + vst1_u8(reinterpret_cast(dst + 16), \ + operator()( \ + {{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[1]); \ + vst1_u8(reinterpret_cast(dst + 24), operator()( \ + {{vmovl_low_u16(vsrct0), \ + vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), \ + vmovl_high_u16(vsrct1)}})) /* ================= binary op for quantized types ================== */ @@ -349,8 +364,9 @@ struct BinaryOpBase : OpBase { uint8_t dzp; int32x4_t vdzp; - void init(float src0_scale, float src1_scale, float dst_scale, - uint8_t src0_zp, uint8_t src1_zp, uint8_t dst_zp) { + void init( + float src0_scale, float src1_scale, float dst_scale, uint8_t src0_zp, + uint8_t src1_zp, uint8_t dst_zp) { scale_src0 = src0_scale; vscale_src0 = vdupq_n_f32(scale_src0); scale_src1 = src1_scale; @@ -383,8 +399,9 @@ struct BinaryOpBase : OpBase { init(src0_scale, src1_scale, dst_scale, src0_zp, src1_zp, dst_zp); } - BinaryOpBase(float src0_scale, float src1_scale, float dst_scale, - uint8_t src0_zp, uint8_t src1_zp, uint8_t dst_zp) { + BinaryOpBase( + float src0_scale, float src1_scale, float dst_scale, uint8_t src0_zp, + uint8_t src1_zp, uint8_t dst_zp) { init(src0_scale, src1_scale, dst_scale, src0_zp, src1_zp, dst_zp); } }; @@ -436,8 +453,7 @@ struct BinaryOpBase : OpBase { float scale_src0, scale_src1, scale_dst; float32x4_t vscale_src0, vscale_src1, vscale_dst; - void init(float src0_scale, 
float src1_scale, float dst_scale, - uint8_t zero_point) { + void init(float src0_scale, float src1_scale, float dst_scale, uint8_t zero_point) { scale_src0 = src0_scale; vscale_src0 = vdupq_n_f32(src0_scale); scale_src1 = src1_scale; @@ -460,8 +476,8 @@ struct BinaryOpBase : OpBase { init(src0_scale, src1_scale, dst_scale, zp); } - BinaryOpBase(float src0_scale, float src1_scale, float dst_scale, - uint8_t zero_point) { + BinaryOpBase( + float src0_scale, float src1_scale, float dst_scale, uint8_t zero_point) { init(src0_scale, src1_scale, dst_scale, zero_point); } }; @@ -471,75 +487,84 @@ template struct TernaryOpBase : OpBase { using OpBase::OpBase; TernaryOpBase() = default; - TernaryOpBase(DType /*src0_dtype*/, DType /*src1_dtype*/, - DType /*src2_dtype*/, DType /*dst_dtype*/) {} + TernaryOpBase( + DType /*src0_dtype*/, DType /*src1_dtype*/, DType /*src2_dtype*/, + DType /*dst_dtype*/) {} }; -#define OPERATOR_TERNARY_QINT8 \ - int16x8_t vsrct0 = vmovl_low_s8(vsrc0.val[0]); \ - int16x8_t vsrct1 = vmovl_low_s8(vsrc1.val[0]); \ - int16x8_t vsrct2 = vmovl_low_s8(vsrc2.val[0]); \ - vst1_s8(reinterpret_cast(dst), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ - {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ - \ - vsrct0 = vmovl_high_s8(vsrc0.val[0]); \ - vsrct1 = vmovl_high_s8(vsrc1.val[0]); \ - vsrct2 = vmovl_high_s8(vsrc2.val[0]); \ - vst1_s8(reinterpret_cast(dst + 8), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ - {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ - \ - vsrct0 = vmovl_low_s8(vsrc0.val[1]); \ - vsrct1 = vmovl_low_s8(vsrc1.val[1]); \ - vsrct2 = vmovl_low_s8(vsrc2.val[1]); \ - vst1_s8(reinterpret_cast(dst + 16), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ - {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ - \ - vsrct0 = vmovl_high_s8(vsrc0.val[1]); \ - vsrct1 = vmovl_high_s8(vsrc1.val[1]); \ - vsrct2 = vmovl_high_s8(vsrc2.val[1]); \ - vst1_s8(reinterpret_cast(dst + 24), \ - operator()({{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ - {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ - {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})) - -#define OPERATOR_TERNARY_QUINT8 \ - uint16x8_t vsrct0 = vmovl_low_u8(vsrc0.val[0]); \ - uint16x8_t vsrct1 = vmovl_low_u8(vsrc1.val[0]); \ - uint16x8_t vsrct2 = vmovl_low_u8(vsrc2.val[0]); \ - vst1_u8(reinterpret_cast(dst), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ - {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ - \ - vsrct0 = vmovl_high_u8(vsrc0.val[0]); \ - vsrct1 = vmovl_high_u8(vsrc1.val[0]); \ - vsrct2 = vmovl_high_u8(vsrc2.val[0]); \ - vst1_u8(reinterpret_cast(dst + 8), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ - {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ - \ - vsrct0 = vmovl_low_u8(vsrc0.val[1]); \ - vsrct1 = vmovl_low_u8(vsrc1.val[1]); \ - vsrct2 = vmovl_low_u8(vsrc2.val[1]); \ - vst1_u8(reinterpret_cast(dst + 16), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ - {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ - \ - vsrct0 = vmovl_high_u8(vsrc0.val[1]); \ - vsrct1 = vmovl_high_u8(vsrc1.val[1]); \ - vsrct2 = 
vmovl_high_u8(vsrc2.val[1]); \ - vst1_u8(reinterpret_cast(dst + 24), \ - operator()({{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ - {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ - {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})) +#define OPERATOR_TERNARY_QINT8 \ + int16x8_t vsrct0 = vmovl_low_s8(vsrc0.val[0]); \ + int16x8_t vsrct1 = vmovl_low_s8(vsrc1.val[0]); \ + int16x8_t vsrct2 = vmovl_low_s8(vsrc2.val[0]); \ + vst1_s8(reinterpret_cast(dst), \ + operator()( \ + {{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[0]); \ + vsrct2 = vmovl_high_s8(vsrc2.val[0]); \ + vst1_s8(reinterpret_cast(dst + 8), \ + operator()( \ + {{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ + \ + vsrct0 = vmovl_low_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_s8(vsrc1.val[1]); \ + vsrct2 = vmovl_low_s8(vsrc2.val[1]); \ + vst1_s8(reinterpret_cast(dst + 16), \ + operator()( \ + {{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_s8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_s8(vsrc1.val[1]); \ + vsrct2 = vmovl_high_s8(vsrc2.val[1]); \ + vst1_s8(reinterpret_cast(dst + 24), \ + operator()( \ + {{vmovl_low_s16(vsrct0), vmovl_high_s16(vsrct0)}}, \ + {{vmovl_low_s16(vsrct1), vmovl_high_s16(vsrct1)}}, \ + {{vmovl_low_s16(vsrct2), vmovl_high_s16(vsrct2)}})) + +#define OPERATOR_TERNARY_QUINT8 \ + uint16x8_t vsrct0 = vmovl_low_u8(vsrc0.val[0]); \ + uint16x8_t vsrct1 = vmovl_low_u8(vsrc1.val[0]); \ + uint16x8_t vsrct2 = vmovl_low_u8(vsrc2.val[0]); \ + vst1_u8(reinterpret_cast(dst), \ + operator()( \ + {{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[0]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[0]); \ + vsrct2 = vmovl_high_u8(vsrc2.val[0]); \ + vst1_u8(reinterpret_cast(dst + 8), \ + operator()( \ + {{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ + \ + vsrct0 = vmovl_low_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_low_u8(vsrc1.val[1]); \ + vsrct2 = vmovl_low_u8(vsrc2.val[1]); \ + vst1_u8(reinterpret_cast(dst + 16), \ + operator()( \ + {{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})); \ + \ + vsrct0 = vmovl_high_u8(vsrc0.val[1]); \ + vsrct1 = vmovl_high_u8(vsrc1.val[1]); \ + vsrct2 = vmovl_high_u8(vsrc2.val[1]); \ + vst1_u8(reinterpret_cast(dst + 24), \ + operator()( \ + {{vmovl_low_u16(vsrct0), vmovl_high_u16(vsrct0)}}, \ + {{vmovl_low_u16(vsrct1), vmovl_high_u16(vsrct1)}}, \ + {{vmovl_low_u16(vsrct2), vmovl_high_u16(vsrct2)}})) /*========================= ternaty op for quanzited ====================*/ template <> @@ -551,8 +576,7 @@ struct TernaryOpBase : OpBase { float32x4_t vscale_src0, vscale_src1, vscale_src2, vscale_dst; float scale0, scale1, scale2; float32x4_t vscale0, vscale1, vscale2; - void init(float src0_scale, float src1_scale, float src2_scale, - float dst_scale) { + void init(float src0_scale, 
float src1_scale, float src2_scale, float dst_scale) { scale_src0 = src0_scale; scale_src1 = src1_scale; scale_src2 = src2_scale; @@ -568,16 +592,16 @@ struct TernaryOpBase : OpBase { vscale1 = vdupq_n_f32(scale1); vscale2 = vdupq_n_f32(scale2); } - TernaryOpBase(DType src0_dtype, DType src1_dtype, DType src2_dtype, - DType dst_dtype) { + TernaryOpBase( + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype) { float src0_scale = src0_dtype.param().scale; float src1_scale = src1_dtype.param().scale; float src2_scale = src2_dtype.param().scale; float dst_scale = dst_dtype.param().scale; init(src0_scale, src1_scale, src2_scale, dst_scale); } - TernaryOpBase(float src0_scale, float src1_scale, float src2_scale, - float dst_scale) { + TernaryOpBase( + float src0_scale, float src1_scale, float src2_scale, float dst_scale) { init(src0_scale, src1_scale, src2_scale, dst_scale); } }; @@ -595,9 +619,9 @@ struct TernaryOpBase : OpBase { float32x4_t vscale0, vscale1, vscale2; uint8_t dzp; int32x4_t vdzp; - void init(float src0_scale, float src1_scale, float src2_scale, - float dst_scale, uint8_t src0_zp, uint8_t src1_zp, - uint8_t src2_zp, uint8_t dst_zp) { + void init( + float src0_scale, float src1_scale, float src2_scale, float dst_scale, + uint8_t src0_zp, uint8_t src1_zp, uint8_t src2_zp, uint8_t dst_zp) { scale_src0 = src0_scale; scale_src1 = src1_scale; scale_src2 = src2_scale; @@ -621,8 +645,8 @@ struct TernaryOpBase : OpBase { dzp = dst_zp; vdzp = vdupq_n_s32(static_cast(dzp)); } - TernaryOpBase(DType src0_dtype, DType src1_dtype, DType src2_dtype, - DType dst_dtype) { + TernaryOpBase( + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype) { float src0_scale = src0_dtype.param().scale; float src1_scale = src1_dtype.param().scale; float src2_scale = src2_dtype.param().scale; @@ -631,14 +655,14 @@ struct TernaryOpBase : OpBase { uint8_t src1_zp = src1_dtype.param().zero_point; uint8_t src2_zp = src2_dtype.param().zero_point; uint8_t dst_zp = dst_dtype.param().zero_point; - init(src0_scale, src1_scale, src2_scale, dst_scale, src0_zp, src1_zp, - src2_zp, dst_zp); + init(src0_scale, src1_scale, src2_scale, dst_scale, src0_zp, src1_zp, src2_zp, + dst_zp); } - TernaryOpBase(float src0_scale, float src1_scale, float src2_scale, - float dst_scale, uint8_t src0_zp, uint8_t src1_zp, - uint8_t src2_zp, uint8_t dst_zp) { - init(src0_scale, src1_scale, src2_scale, dst_scale, src0_zp, src1_zp, - src2_zp, dst_zp); + TernaryOpBase( + float src0_scale, float src1_scale, float src2_scale, float dst_scale, + uint8_t src0_zp, uint8_t src1_zp, uint8_t src2_zp, uint8_t dst_zp) { + init(src0_scale, src1_scale, src2_scale, dst_scale, src0_zp, src1_zp, src2_zp, + dst_zp); } }; @@ -654,8 +678,8 @@ struct FixupBase { int shift = static_cast(::ceilf(::log2f(0.5 / scale))); scale *= ::powf(2, shift); //! Using double can get full precision here, but it can be ignored. 
- vmultiplier = vdupq_n_s32( - std::round(static_cast(scale) * ((2LL) << 30))); + vmultiplier = + vdupq_n_s32(std::round(static_cast(scale) * ((2LL) << 30))); vshift = vdupq_n_s32(-shift); } }; @@ -665,8 +689,7 @@ template struct UnaryQuantizationOp; template -struct UnaryQuantizationOp - : UnaryOpBase { +struct UnaryQuantizationOp : UnaryOpBase { using UnaryOpBase::UnaryOpBase; constexpr static size_t SIMD_WIDTH = 16; Op op; @@ -735,14 +758,12 @@ template struct BinaryQuantizationOp; template -struct BinaryQuantizationOp - : BinaryOpBase { +struct BinaryQuantizationOp : BinaryOpBase { using BinaryOpBase::BinaryOpBase; constexpr static size_t SIMD_WIDTH = 16; Op op; - void operator()(const dt_qint8& src0, const dt_qint8& src1, - dt_qint8* dst) const { + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } @@ -754,13 +775,12 @@ struct BinaryQuantizationOp return QConverter::convert(fdst); } - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8; } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { auto val0 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0); auto val1 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale_src0); auto val2 = vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale_src1); @@ -779,8 +799,8 @@ struct BinaryQuantizationOp constexpr static size_t SIMD_WIDTH = 16; Op op; - void operator()(const dt_quint8& src0, const dt_quint8& src1, - dt_quint8* dst) const { + void operator()( + const dt_quint8& src0, const dt_quint8& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } @@ -792,13 +812,13 @@ struct BinaryQuantizationOp return QConverter::convert(fdst, this->dzp); } - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { OPERATOR_BINARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, - const uint32x4x2_t& vsrc1) const { + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1) const { auto val0 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale_src0); val0 = vsubq_f32(val0, this->vscale_zp0); auto val1 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale_src0); @@ -825,13 +845,14 @@ struct TernaryQuantizationOp constexpr static size_t SIMD_WIDTH = 16; Op op; - void operator()(const dt_qint8& src0, const dt_qint8& src1, - const dt_qint8& src2, dt_qint8* dst) const { + void operator()( + const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2, + dt_qint8* dst) const { *dst = operator()(src0, src1, src2); } - dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1, - const dt_qint8& src2) const { + dt_qint8 operator()( + const dt_qint8& src0, const dt_qint8& src1, const dt_qint8& src2) const { float fsrc0 = src0.as_int8() * this->scale_src0; float fsrc1 = src1.as_int8() * this->scale_src1; float fsrc2 = src2.as_int8() * this->scale_src2; @@ -840,14 +861,15 @@ struct TernaryQuantizationOp return QConverter::convert(fdst); } - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - const int8x16x2_t& vsrc2, dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, + const int8x16x2_t& 
vsrc2, dt_qint8* dst) const { OPERATOR_TERNARY_QINT8; } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1, - const int32x4x2_t& vsrc2) const { + int8x8_t operator()( + const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1, + const int32x4x2_t& vsrc2) const { auto val0 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale_src0); auto val1 = vmulq_f32(vcvtq_f32_s32(vsrc0.val[1]), this->vscale_src0); auto val2 = vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale_src1); @@ -868,13 +890,14 @@ struct TernaryQuantizationOp constexpr static size_t SIMD_WIDTH = 16; Op op; - void operator()(const dt_quint8& src0, const dt_quint8& src1, - const dt_quint8& src2, dt_quint8* dst) const { + void operator()( + const dt_quint8& src0, const dt_quint8& src1, const dt_quint8& src2, + dt_quint8* dst) const { *dst = operator()(src0, src1, src2); } - dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1, - const dt_quint8& src2) const { + dt_quint8 operator()( + const dt_quint8& src0, const dt_quint8& src1, const dt_quint8& src2) const { float fsrc0 = src0.as_uint8() * this->scale_src0 - this->scale_zp0; float fsrc1 = src1.as_uint8() * this->scale_src1 - this->scale_zp1; float fsrc2 = src2.as_uint8() * this->scale_src2 - this->scale_zp2; @@ -883,13 +906,15 @@ struct TernaryQuantizationOp return QConverter::convert(fdst, this->dzp); } - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - const uint8x16x2_t& vsrc2, dt_quint8* dst) const { + void operator()( + const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + const uint8x16x2_t& vsrc2, dt_quint8* dst) const { OPERATOR_TERNARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1, - const uint32x4x2_t& vsrc2) const { + uint8x8_t operator()( + const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1, + const uint32x4x2_t& vsrc2) const { auto val0 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale_src0); val0 = vsubq_f32(val0, this->vscale_zp0); auto val1 = vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale_src0); diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/pow.h b/dnn/src/arm_common/elemwise_helper/kimpl/pow.h index 7250ac84..a8154598 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/pow.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/pow.h @@ -21,8 +21,8 @@ template struct PowOp : BinaryOpBase { using BinaryOpBase::BinaryOpBase; constexpr static size_t SIMD_WIDTH = 1; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/relu.h b/dnn/src/arm_common/elemwise_helper/kimpl/relu.h index 9a11e511..c8b715e0 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/relu.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/relu.h @@ -22,9 +22,7 @@ struct ReluOpBase : UnaryOpBase { void operator()(const src_ctype& src, dst_ctype* dst) const { *dst = operator()(src); } - dst_ctype operator()(const src_ctype& src) const { - return src > 0 ? src : 0; - } + dst_ctype operator()(const src_ctype& src) const { return src > 0 ? 
src : 0; } }; template @@ -172,8 +170,7 @@ struct ReluOp : ReluOpBase { vst1_s8(reinterpret_cast(dst), operator()(vsrc)); } void operator()(const int32x4_t& src, dt_qint8* dst) const { - vst1_lane_s32(reinterpret_cast(dst), - (int32x2_t)(operator()(src)), 0); + vst1_lane_s32(reinterpret_cast(dst), (int32x2_t)(operator()(src)), 0); } int8x8_t operator()(const int32x4x2_t& vsrc) const { @@ -197,8 +194,7 @@ struct ReluOp : ReluOpBase { }; #else template <> -struct ReluOp : ReluOpBase, - FixupBase { +struct ReluOp : ReluOpBase, FixupBase { using ReluOpBase::operator(); constexpr static size_t SIMD_WIDTH = 4; @@ -217,8 +213,9 @@ struct ReluOp : ReluOpBase, int32x4_t vitem1 = vqrdmulhq_s32(vsrc.val[1], vmultiplier); vitem0 = vmaxq_s32(vitem0, QConverterBase::vzero()); vitem1 = vmaxq_s32(vitem1, QConverterBase::vzero()); - return vqmovn_s16(vcombine_s16(vqmovn_s32(vrshlq_s32(vitem0, vshift)), - vqmovn_s32(vrshlq_s32(vitem1, vshift)))); + return vqmovn_s16(vcombine_s16( + vqmovn_s32(vrshlq_s32(vitem0, vshift)), + vqmovn_s32(vrshlq_s32(vitem1, vshift)))); } int8x8_t operator()(const float32x4_t& vsrc) const { int32x4_t vitem0 = vqrdmulhq_s32(vcvtq_s32_f32(vsrc), vmultiplier); @@ -258,8 +255,8 @@ struct ReluOp : ReluOpBase { vitem0 = vmaxq_f32(vitem0, QConverterBase::vfzero()); vitem1 = vmaxq_f32(vitem1, QConverterBase::vfzero()); - return QConverter::convert({{vitem0, vitem1}}, - this->vzp); + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); } }; diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h b/dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h index 476f2969..06989b7b 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/rmulh.h @@ -18,8 +18,8 @@ namespace arm_common { template struct RmulhOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -30,33 +30,34 @@ struct RmulhOpBase : BinaryOpBase { template struct RmulhOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct RmulhOp<_ctype> : RmulhOpBase<_ctype> { \ - using RmulhOpBase::RmulhOpBase; \ - using RmulhOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto vitem0 = vqrdmulhq_##_func_suffix(src0.val[0], src1.val[0]); \ - auto vitem1 = vqrdmulhq_##_func_suffix(src0.val[1], src1.val[1]); \ - return {{vitem0, vitem1}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - return vqrdmulhq_##_func_suffix(src0, src1); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct RmulhOp<_ctype> : RmulhOpBase<_ctype> { \ + using RmulhOpBase::RmulhOpBase; \ + using RmulhOpBase::operator(); \ + constexpr static 
size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto vitem0 = vqrdmulhq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vqrdmulhq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + return vqrdmulhq_##_func_suffix(src0, src1); \ + } \ }; OP(dt_int32, int32x4_t, int32x4x2_t, s32, 4) OP(dt_int16, int16x8_t, int16x8x2_t, s16, 8) @@ -70,14 +71,13 @@ struct RmulhOp : RmulhOpBase { using RmulhOpBase::RmulhOpBase; using RmulhOpBase::operator(); constexpr static size_t SIMD_WIDTH = 16; - void operator()(const int8x16x2_t& src0, const int8x16x2_t& src1, - int8_t* dst) const { + void operator()( + const int8x16x2_t& src0, const int8x16x2_t& src1, int8_t* dst) const { auto vitem = operator()(src0, src1); vst1q_s8(dst, vitem.val[0]); vst1q_s8(dst + SIMD_WIDTH, vitem.val[1]); } - int8x16x2_t operator()(const int8x16x2_t& src0, - const int8x16x2_t& src1) const { + int8x16x2_t operator()(const int8x16x2_t& src0, const int8x16x2_t& src1) const { int8x16_t val, var; int8x8_t lol, hil, lor, hir; int16x8_t mu0, mu1; @@ -112,8 +112,7 @@ struct RmulhOp : RmulhOpBase { return {{vcombine_s8(lol, hil), vcombine_s8(lol1, hil1)}}; } - void operator()(const int8x16_t& src0, const int8x16_t& src1, - int8_t* dst) const { + void operator()(const int8x16_t& src0, const int8x16_t& src1, int8_t* dst) const { auto vitem = operator()(src0, src1); vst1q_s8(dst, vitem); } diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h index f8dc2bad..3d10636c 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/sigmoid.h @@ -33,35 +33,34 @@ struct SigmoidOpBase : UnaryOpBase { template struct SigmoidOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \ - using SigmoidOpBase::SigmoidOpBase; \ - using SigmoidOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src, _ctype* dst) const { \ - auto vitem = operator()(src); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - void operator()(const _neon_type& src, _ctype* dst) const { \ - auto vitem = operator()(src); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type2 operator()(const _neon_type2& src) const { \ - return {{operator()(src.val[0]), operator()(src.val[1])}}; \ - } \ - _neon_type operator()(const _neon_type& src) const { \ - auto zero_val = vdupq_n_##_func_suffix(0.f); \ - auto one_val = vdupq_n_##_func_suffix(1.f); \ - auto val1 = vsubq_##_func_suffix(zero_val, src); \ - val1 = exp_ps_##_func_suffix(val1); \ - auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ - val1 = vrecpeq_##_func_suffix(recipe1); \ - val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), \ - val1); \ - return val1; \ - } \ 
+#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct SigmoidOp<_ctype> : SigmoidOpBase<_ctype> { \ + using SigmoidOpBase::SigmoidOpBase; \ + using SigmoidOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type2& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type2 operator()(const _neon_type2& src) const { \ + return {{operator()(src.val[0]), operator()(src.val[1])}}; \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto zero_val = vdupq_n_##_func_suffix(0.f); \ + auto one_val = vdupq_n_##_func_suffix(1.f); \ + auto val1 = vsubq_##_func_suffix(zero_val, src); \ + val1 = exp_ps_##_func_suffix(val1); \ + auto recipe1 = vaddq_##_func_suffix(one_val, val1); \ + val1 = vrecpeq_##_func_suffix(recipe1); \ + val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \ + return val1; \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/sub.h b/dnn/src/arm_common/elemwise_helper/kimpl/sub.h index 34da13ec..4ccfc273 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/sub.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/sub.h @@ -18,8 +18,8 @@ namespace arm_common { template struct SubOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -30,33 +30,34 @@ struct SubOpBase : BinaryOpBase { template struct SubOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct SubOp<_ctype> : SubOpBase<_ctype> { \ - using SubOpBase::SubOpBase; \ - using SubOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto vitem0 = vsubq_##_func_suffix(src0.val[0], src1.val[0]); \ - auto vitem1 = vsubq_##_func_suffix(src0.val[1], src1.val[1]); \ - return {{vitem0, vitem1}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - return vsubq_##_func_suffix(src0, src1); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct SubOp<_ctype> : SubOpBase<_ctype> { \ + using SubOpBase::SubOpBase; \ + using SubOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + 
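Note on the SigmoidOp macro above: each lane evaluates sigmoid(x) = 1 / (1 + exp(-x)). The negated input goes through the vendored exp_ps_##_func_suffix polynomial, and the division is replaced by vrecpeq (a coarse reciprocal estimate) refined with a single vrecpsq Newton-Raphson step. A scalar sketch of the same math, with std::exp standing in for exp_ps and an exact division in place of the estimate-plus-refinement pair:

#include <cmath>

// Scalar reference for one float lane of SigmoidOp.
// The NEON code computes 1/denom as r1 = r0 * (2 - denom * r0), where the
// initial r0 comes from vrecpeq_f32 and the correction factor from vrecpsq_f32.
static inline float sigmoid_ref(float x) {
    float denom = 1.0f + std::exp(-x);
    return 1.0f / denom;
}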
vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto vitem0 = vsubq_##_func_suffix(src0.val[0], src1.val[0]); \ + auto vitem1 = vsubq_##_func_suffix(src0.val[1], src1.val[1]); \ + return {{vitem0, vitem1}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + return vsubq_##_func_suffix(src0, src1); \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -71,13 +72,12 @@ template <> struct SubOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_qint8& src0, const dt_qint8& src1, - dt_qint8* dst) const { + void operator()(const dt_qint8& src0, const dt_qint8& src1, dt_qint8* dst) const { *dst = operator()(src0, src1); } dt_qint8 operator()(const dt_qint8& src0, const dt_qint8& src1) const { - return QConverter::convert(src0.as_int8() * scale0 - - src1.as_int8() * scale1); + return QConverter::convert( + src0.as_int8() * scale0 - src1.as_int8() * scale1); } }; @@ -85,15 +85,14 @@ template <> struct SubOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const dt_quint8& src0, const dt_quint8& src1, - dt_quint8* dst) const { + void operator()( + const dt_quint8& src0, const dt_quint8& src1, dt_quint8* dst) const { *dst = operator()(src0, src1); } dt_quint8 operator()(const dt_quint8& src0, const dt_quint8& src1) const { float fsrc0 = src0.as_uint8() * scale0 - this->szp0; float fsrc1 = src1.as_uint8() * scale1 - this->szp1; - return QConverter::convert(fsrc0 - fsrc1, - this->dzp); + return QConverter::convert(fsrc0 - fsrc1, this->dzp); } }; @@ -103,12 +102,11 @@ struct SubOp : SubOpBase { constexpr static size_t SIMD_WIDTH = 16; using SubOpBase::operator(); - void operator()(const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, - dt_qint8* dst) const { + void operator()( + const int8x16x2_t& vsrc0, const int8x16x2_t& vsrc1, dt_qint8* dst) const { OPERATOR_BINARY_QINT8; } - int8x8_t operator()(const int32x4x2_t& vsrc0, - const int32x4x2_t& vsrc1) const { + int8x8_t operator()(const int32x4x2_t& vsrc0, const int32x4x2_t& vsrc1) const { auto vitem0 = vsubq_f32( vmulq_f32(vcvtq_f32_s32(vsrc0.val[0]), this->vscale0), vmulq_f32(vcvtq_f32_s32(vsrc1.val[0]), this->vscale1)); @@ -125,25 +123,21 @@ struct SubOp : SubOpBase { constexpr static size_t SIMD_WIDTH = 16; using SubOpBase::operator(); - void operator()(const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, - dt_quint8* dst) const { + void operator()( + const uint8x16x2_t& vsrc0, const uint8x16x2_t& vsrc1, + dt_quint8* dst) const { OPERATOR_BINARY_QUINT8; } - uint8x8_t operator()(const uint32x4x2_t& vsrc0, - const uint32x4x2_t& vsrc1) const { - auto vfsrc0 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), - this->vszp0); - auto vfsrc1 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), - this->vszp1); + uint8x8_t operator()(const uint32x4x2_t& vsrc0, const uint32x4x2_t& vsrc1) const { + auto vfsrc0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[0]), this->vscale0), this->vszp0); + auto vfsrc1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[0]), this->vscale1), this->vszp1); auto vitem0 = vsubq_f32(vfsrc0, vfsrc1); - vfsrc0 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), 
this->vscale0), - this->vszp0); - vfsrc1 = - vsubq_f32(vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), - this->vszp1); + vfsrc0 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc0.val[1]), this->vscale0), this->vszp0); + vfsrc1 = vsubq_f32( + vmulq_f32(vcvtq_f32_u32(vsrc1.val[1]), this->vscale1), this->vszp1); auto vitem1 = vsubq_f32(vfsrc0, vfsrc1); return QConverter::convert( {{vitem0, vitem1}}, this->vdzp); diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h b/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h index ecfdcc5c..10792dcb 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/tanh.h @@ -30,40 +30,38 @@ struct TanhOpBase : UnaryOpBase { template struct TanhOp; -#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ - template <> \ - struct TanhOp<_ctype> : TanhOpBase<_ctype> { \ - using TanhOpBase::TanhOpBase; \ - using TanhOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type& src, _ctype* dst) const { \ - auto vitem = operator()(src); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type operator()(const _neon_type& src) const { \ - auto one_val = vdupq_n_##_func_suffix(1.f); \ - auto two_val = vdupq_n_##_func_suffix(2.f); \ - auto val1 = src.val[0]; \ - auto val2 = src.val[1]; \ - val1 = vmulq_##_func_suffix(two_val, val1); \ - val2 = vmulq_##_func_suffix(two_val, val2); \ - val1 = exp_ps_##_func_suffix(val1); \ - val2 = exp_ps_##_func_suffix(val2); \ - val1 = vaddq_##_func_suffix(one_val, val1); \ - val2 = vaddq_##_func_suffix(one_val, val2); \ - auto rval1 = vrecpeq_##_func_suffix(val1); \ - auto rval2 = vrecpeq_##_func_suffix(val2); \ - rval1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val1, rval1), \ - rval1); \ - rval2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val2, rval2), \ - rval2); \ - val1 = vmulq_##_func_suffix(two_val, rval1); \ - val2 = vmulq_##_func_suffix(two_val, rval2); \ - val1 = vsubq_##_func_suffix(one_val, val1); \ - val2 = vsubq_##_func_suffix(one_val, val2); \ - return {{val1, val2}}; \ - } \ +#define OP(_ctype, _neon_type, _func_suffix, _simd_width) \ + template <> \ + struct TanhOp<_ctype> : TanhOpBase<_ctype> { \ + using TanhOpBase::TanhOpBase; \ + using TanhOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()(const _neon_type& src, _ctype* dst) const { \ + auto vitem = operator()(src); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type operator()(const _neon_type& src) const { \ + auto one_val = vdupq_n_##_func_suffix(1.f); \ + auto two_val = vdupq_n_##_func_suffix(2.f); \ + auto val1 = src.val[0]; \ + auto val2 = src.val[1]; \ + val1 = vmulq_##_func_suffix(two_val, val1); \ + val2 = vmulq_##_func_suffix(two_val, val2); \ + val1 = exp_ps_##_func_suffix(val1); \ + val2 = exp_ps_##_func_suffix(val2); \ + val1 = vaddq_##_func_suffix(one_val, val1); \ + val2 = vaddq_##_func_suffix(one_val, val2); \ + auto rval1 = vrecpeq_##_func_suffix(val1); \ + auto rval2 = vrecpeq_##_func_suffix(val2); \ + rval1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val1, rval1), rval1); \ + rval2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(val2, rval2), rval2); \ + val1 = vmulq_##_func_suffix(two_val, rval1); \ + val2 = vmulq_##_func_suffix(two_val, rval2); \ + val1 = vsubq_##_func_suffix(one_val, val1); \ + val2 = vsubq_##_func_suffix(one_val, val2); \ + return 
{{val1, val2}}; \ + } \ }; OP(dt_float32, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/true_div.h b/dnn/src/arm_common/elemwise_helper/kimpl/true_div.h index cd1bd628..79fd2a65 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/true_div.h +++ b/dnn/src/arm_common/elemwise_helper/kimpl/true_div.h @@ -21,8 +21,8 @@ namespace arm_common { template struct TrueDivOpBase : BinaryOpBase { using BinaryOpBase::BinaryOpBase; - void operator()(const src_ctype& src0, const src_ctype& src1, - dst_ctype* dst) const { + void operator()( + const src_ctype& src0, const src_ctype& src1, dst_ctype* dst) const { *dst = operator()(src0, src1); } dst_ctype operator()(const src_ctype& src0, const src_ctype& src1) const { @@ -34,37 +34,38 @@ struct TrueDivOpBase : BinaryOpBase { template struct TrueDivOp; -#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ - template <> \ - struct TrueDivOp<_ctype> : TrueDivOpBase<_ctype> { \ - using TrueDivOpBase::TrueDivOpBase; \ - using TrueDivOpBase::operator(); \ - constexpr static size_t SIMD_WIDTH = _simd_width; \ - void operator()(const _neon_type2& src0, const _neon_type2& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem.val[0]); \ - vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ - } \ - _neon_type2 operator()(const _neon_type2& src0, \ - const _neon_type2& src1) const { \ - auto val1 = src0.val[0]; \ - auto val2 = src0.val[1]; \ - auto val3 = src1.val[0]; \ - auto val4 = src1.val[1]; \ - val1 = vdivq_##_func_suffix(val1, val3); \ - val2 = vdivq_##_func_suffix(val2, val4); \ - return {{val1, val2}}; \ - } \ - void operator()(const _neon_type& src0, const _neon_type& src1, \ - dst_ctype* dst) const { \ - auto vitem = operator()(src0, src1); \ - vst1q_##_func_suffix(dst, vitem); \ - } \ - _neon_type operator()(const _neon_type& src0, \ - const _neon_type& src1) const { \ - return vdivq_##_func_suffix(src0, src1); \ - } \ +#define OP(_ctype, _neon_type, _neon_type2, _func_suffix, _simd_width) \ + template <> \ + struct TrueDivOp<_ctype> : TrueDivOpBase<_ctype> { \ + using TrueDivOpBase::TrueDivOpBase; \ + using TrueDivOpBase::operator(); \ + constexpr static size_t SIMD_WIDTH = _simd_width; \ + void operator()( \ + const _neon_type2& src0, const _neon_type2& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem.val[0]); \ + vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ + } \ + _neon_type2 operator()( \ + const _neon_type2& src0, const _neon_type2& src1) const { \ + auto val1 = src0.val[0]; \ + auto val2 = src0.val[1]; \ + auto val3 = src1.val[0]; \ + auto val4 = src1.val[1]; \ + val1 = vdivq_##_func_suffix(val1, val3); \ + val2 = vdivq_##_func_suffix(val2, val4); \ + return {{val1, val2}}; \ + } \ + void operator()( \ + const _neon_type& src0, const _neon_type& src1, \ + dst_ctype* dst) const { \ + auto vitem = operator()(src0, src1); \ + vst1q_##_func_suffix(dst, vitem); \ + } \ + _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ + return vdivq_##_func_suffix(src0, src1); \ + } \ }; OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h b/dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h index 1cdfba02..a37f574a 100644 --- a/dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h +++ 
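The TanhOp macro above avoids evaluating tanh directly by using the identity tanh(x) = 1 - 2 / (1 + exp(2x)), so each lane needs one exp_ps call and one reciprocal, the reciprocal again built from vrecpeq plus one vrecpsq refinement step. A scalar sketch of the identity, with std::exp standing in for exp_ps:

#include <cmath>

// Scalar reference for one float lane of TanhOp:
//   tanh(x) = (exp(2x) - 1) / (exp(2x) + 1) = 1 - 2 / (1 + exp(2x))
static inline float tanh_ref(float x) {
    return 1.0f - 2.0f / (1.0f + std::exp(2.0f * x));
}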
b/dnn/src/arm_common/elemwise_helper/kimpl/typecvt.h @@ -28,8 +28,8 @@ struct TypeCvtOp : UnaryOpBase { vst1_s8(reinterpret_cast(dst), operator()(vsrc)); } void operator()(const int32x4_t& vsrc, dt_qint8* dst) const { - vst1_lane_s32(reinterpret_cast(dst), - (int32x2_t)(operator()(vsrc)), 0); + vst1_lane_s32( + reinterpret_cast(dst), (int32x2_t)(operator()(vsrc)), 0); } void operator()(const src_ctype& src, dst_ctype* dst) const { *dst = operator()(src); @@ -76,8 +76,8 @@ struct TypeCvtOp : UnaryOpBase { auto vitem0 = vmulq_f32(vcvtq_f32_s32(vsrc.val[0]), this->vscale); auto vitem1 = vmulq_f32(vcvtq_f32_s32(vsrc.val[1]), this->vscale); - return QConverter::convert({{vitem0, vitem1}}, - this->vzp); + return QConverter::convert( + {{vitem0, vitem1}}, this->vzp); } }; diff --git a/dnn/src/arm_common/elemwise_helper/op_binary.h b/dnn/src/arm_common/elemwise_helper/op_binary.h index 5b263681..ee8aeff3 100644 --- a/dnn/src/arm_common/elemwise_helper/op_binary.h +++ b/dnn/src/arm_common/elemwise_helper/op_binary.h @@ -11,33 +11,33 @@ #pragma once #include "src/arm_common/elemwise_helper/kimpl/add.h" -#include "src/arm_common/elemwise_helper/kimpl/mul.h" -#include "src/arm_common/elemwise_helper/kimpl/rmulh.h" +#include "src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h" #include "src/arm_common/elemwise_helper/kimpl/fuse_add_relu.h" #include "src/arm_common/elemwise_helper/kimpl/fuse_add_sigmoid.h" #include "src/arm_common/elemwise_helper/kimpl/fuse_add_tanh.h" -#include "src/arm_common/elemwise_helper/kimpl/fuse_add_h_swish.h" #include "src/arm_common/elemwise_helper/kimpl/max.h" #include "src/arm_common/elemwise_helper/kimpl/min.h" +#include "src/arm_common/elemwise_helper/kimpl/mul.h" #include "src/arm_common/elemwise_helper/kimpl/pow.h" +#include "src/arm_common/elemwise_helper/kimpl/rmulh.h" #include "src/arm_common/elemwise_helper/kimpl/sub.h" #include "src/arm_common/elemwise_helper/kimpl/true_div.h" //////////////////// quantization ////////////////////////////// namespace megdnn { namespace arm_common { -#define cb(op) \ - template <> \ - struct op \ - : BinaryQuantizationOp > { \ - using BinaryQuantizationOp >::BinaryQuantizationOp; \ - }; \ - template <> \ - struct op \ - : BinaryQuantizationOp > { \ - using BinaryQuantizationOp >::BinaryQuantizationOp; \ +#define cb(op) \ + template <> \ + struct op \ + : BinaryQuantizationOp> { \ + using BinaryQuantizationOp< \ + dt_qint8, dt_qint8, op>::BinaryQuantizationOp; \ + }; \ + template <> \ + struct op \ + : BinaryQuantizationOp> { \ + using BinaryQuantizationOp< \ + dt_quint8, dt_quint8, op>::BinaryQuantizationOp; \ }; cb(TrueDivOp); diff --git a/dnn/src/arm_common/elemwise_helper/op_ternary.h b/dnn/src/arm_common/elemwise_helper/op_ternary.h index 8b9be8b7..5bafadff 100644 --- a/dnn/src/arm_common/elemwise_helper/op_ternary.h +++ b/dnn/src/arm_common/elemwise_helper/op_ternary.h @@ -15,18 +15,18 @@ //////////////////// quantization ////////////////////////////// namespace megdnn { namespace arm_common { -#define cb(op) \ - template <> \ - struct op \ - : TernaryQuantizationOp > { \ - using TernaryQuantizationOp >::TernaryQuantizationOp; \ - }; \ - template <> \ - struct op \ - : TernaryQuantizationOp > { \ - using TernaryQuantizationOp >::TernaryQuantizationOp; \ +#define cb(op) \ + template <> \ + struct op \ + : TernaryQuantizationOp> { \ + using TernaryQuantizationOp< \ + dt_qint8, dt_qint8, op>::TernaryQuantizationOp; \ + }; \ + template <> \ + struct op \ + : TernaryQuantizationOp> { \ + using TernaryQuantizationOp< \ + 
dt_quint8, dt_quint8, op>::TernaryQuantizationOp; \ }; cb(FuseMulAdd3Op); diff --git a/dnn/src/arm_common/elemwise_helper/op_unary.h b/dnn/src/arm_common/elemwise_helper/op_unary.h index 80cde07a..6bdf72bc 100644 --- a/dnn/src/arm_common/elemwise_helper/op_unary.h +++ b/dnn/src/arm_common/elemwise_helper/op_unary.h @@ -10,32 +10,31 @@ */ #pragma once -#include "src/arm_common/elemwise_helper/kimpl/none.h" #include "src/arm_common/elemwise_helper/kimpl/abs.h" #include "src/arm_common/elemwise_helper/kimpl/exp.h" #include "src/arm_common/elemwise_helper/kimpl/fast_tanh.h" #include "src/arm_common/elemwise_helper/kimpl/hswish.h" +#include "src/arm_common/elemwise_helper/kimpl/none.h" #include "src/arm_common/elemwise_helper/kimpl/relu.h" #include "src/arm_common/elemwise_helper/kimpl/sigmoid.h" #include "src/arm_common/elemwise_helper/kimpl/tanh.h" -#include "src/arm_common/elemwise_helper/kimpl/hswish.h" #include "src/arm_common/elemwise_helper/kimpl/typecvt.h" //////////////////// quantization ////////////////////////////// namespace megdnn { namespace arm_common { -#define cb(op) \ - template <> \ - struct op \ - : UnaryQuantizationOp > { \ - using UnaryQuantizationOp >::UnaryQuantizationOp; \ - }; \ - template <> \ - struct op \ - : UnaryQuantizationOp > { \ - using UnaryQuantizationOp >::UnaryQuantizationOp; \ +#define cb(op) \ + template <> \ + struct op \ + : UnaryQuantizationOp> { \ + using UnaryQuantizationOp< \ + dt_qint8, dt_qint8, op>::UnaryQuantizationOp; \ + }; \ + template <> \ + struct op \ + : UnaryQuantizationOp> { \ + using UnaryQuantizationOp< \ + dt_quint8, dt_quint8, op>::UnaryQuantizationOp; \ }; cb(SigmoidOp); diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp index 96526a64..da5acbf6 100644 --- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp @@ -22,8 +22,8 @@ namespace { using namespace megdnn; template -void neon_round_shr_saturate_int16_static_k(const int16_t* a_ptr, size_t size, - int8_t* dst_ptr) { +void neon_round_shr_saturate_int16_static_k( + const int16_t* a_ptr, size_t size, int8_t* dst_ptr) { static_assert(k >= 1 && k <= 8, "Shift offset out of range"); size_t i = 0; int16x8_t x0, x1, f0, f1; @@ -38,8 +38,7 @@ void neon_round_shr_saturate_int16_static_k(const int16_t* a_ptr, size_t size, vst1_s8(dst_ptr + 8, vqrshrn_n_s16(x1, k)); } for (; i < size; i++, a_ptr++, dst_ptr++) { - *dst_ptr = megdnn::elemwise_multi_type::round_shr_saturate( + *dst_ptr = megdnn::elemwise_multi_type::round_shr_saturate( *a_ptr, k); } } @@ -78,8 +77,7 @@ void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar( vst1q_s8(dst_ptr + 16, vrshlq_s8(x1, shift_vec)); } for (; i < size; i++, a_ptr++, dst_ptr++) { - *dst_ptr = elemwise_multi_type::round_shr_saturate( - *a_ptr, k); + *dst_ptr = elemwise_multi_type::round_shr_saturate(*a_ptr, k); } } @@ -120,8 +118,7 @@ void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar( vst1_s8(dst_ptr + 8, vqmovn_s16(vrshlq_s16(x1, shift_vec))); } for (; i < size; i++, a_ptr++, dst_ptr++) { - *dst_ptr = elemwise_multi_type::round_shr_saturate( - *a_ptr, k); + *dst_ptr = elemwise_multi_type::round_shr_saturate(*a_ptr, k); } } @@ -139,13 +136,13 @@ void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar( f1 = vshrq_n_s32(x1, 31); x0 = vqaddq_s32(x0, f0); x1 = vqaddq_s32(x1, f1); - o0 = vqmovn_s16(vcombine_s16(vqmovn_s32(vrshlq_s32(x0, shift_vec)), - vqmovn_s32(vrshlq_s32(x1, shift_vec)))); + o0 = 
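The cb(...) macros in op_binary.h, op_ternary.h and op_unary.h above specialize each operator for dt_qint8/dt_quint8 by wrapping its float implementation in BinaryQuantizationOp, TernaryQuantizationOp or UnaryQuantizationOp: operands are dequantized to float, the float kernel runs, and the result is requantized (with a destination zero point dzp in the asymmetric case, as the dt_quint8 SubOpBase hunk earlier shows). Below is a minimal scalar sketch of that round trip; the explicit scale and zero-point parameters are assumptions for illustration, since the real ops appear to fold the destination scale into scale0/scale1 and go through QConverter rather than an explicit division.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical asymmetric-quantized subtraction, written with explicit
// scales/zero points instead of the precomputed szp0/szp1/dzp members.
static uint8_t quantized_sub_asymm_ref(
        uint8_t q0, float scale0, int zp0, uint8_t q1, float scale1, int zp1,
        float dst_scale, int dst_zp) {
    float f0 = (q0 - zp0) * scale0;                        // dequantize lhs
    float f1 = (q1 - zp1) * scale1;                        // dequantize rhs
    float r = std::round((f0 - f1) / dst_scale) + dst_zp;  // requantize
    return (uint8_t)std::min(255.0f, std::max(0.0f, r));   // saturate to u8
}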
vqmovn_s16(vcombine_s16( + vqmovn_s32(vrshlq_s32(x0, shift_vec)), + vqmovn_s32(vrshlq_s32(x1, shift_vec)))); vst1_s8(dst_ptr, o0); } for (; i < size; i++, a_ptr++, dst_ptr++) { - *dst_ptr = elemwise_multi_type::round_shr_saturate( - *a_ptr, k); + *dst_ptr = elemwise_multi_type::round_shr_saturate(*a_ptr, k); } } @@ -182,9 +179,8 @@ void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int16( size_t batch_size, size_t channel_size, size_t channel_stride, - const int16_t* x_ptr, const int16_t* b_ptr, const int16_t M, - const int offset, const int8_t minv, const int8_t maxv, size_t size, - int8_t* dst_ptr) { + const int16_t* x_ptr, const int16_t* b_ptr, const int16_t M, const int offset, + const int8_t minv, const int8_t maxv, size_t size, int8_t* dst_ptr) { MEGDNN_MARK_USED_VAR(size); const int16x8_t shift_vec = vdupq_n_s16(-offset); const int16x8_t M_vec = vdupq_n_s16(M); @@ -197,8 +193,7 @@ void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int16( for (size_t chan = 0; chan < channel_size; ++chan, ++b_pos) { auto b_vec = vdupq_n_s16(b_ptr[b_pos]); channel_offset += channel_stride; - for (; i + 15 < channel_offset; - i += 16, x_ptr += 16, dst_ptr += 16) { + for (; i + 15 < channel_offset; i += 16, x_ptr += 16, dst_ptr += 16) { auto x0 = vld1q_s16(x_ptr); auto x1 = vld1q_s16(x_ptr + 8); x0 = vaddq_s16(x0, b_vec); @@ -210,8 +205,9 @@ void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int16( auto fixup1 = vshrq_n_s16(x1, 15); x0 = vqaddq_s16(x0, fixup0); x1 = vqaddq_s16(x1, fixup1); - auto o0 = vcombine_s8(vqmovn_s16(vrshlq_s16(x0, shift_vec)), - vqmovn_s16(vrshlq_s16(x1, shift_vec))); + auto o0 = vcombine_s8( + vqmovn_s16(vrshlq_s16(x0, shift_vec)), + vqmovn_s16(vrshlq_s16(x1, shift_vec))); o0 = vminq_s8(o0, maxv_vec); o0 = vmaxq_s8(o0, minv_vec); vst1q_s8(dst_ptr, o0); @@ -231,10 +227,9 @@ void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int16( dt_int16 bias = b_ptr[b_pos]; for (; i < channel_offset; ++i, ++x_ptr, ++dst_ptr) { dt_int16 result = rounding_shift_right_away_from_zero( - round_mulh_saturate(*x_ptr + bias, M), - offset); - *dst_ptr = static_cast(std::max( - std::min(result, maxv), minv)); + round_mulh_saturate(*x_ptr + bias, M), offset); + *dst_ptr = static_cast( + std::max(std::min(result, maxv), minv)); } } } @@ -242,9 +237,8 @@ void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int16( void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int32( size_t batch_size, size_t channel_size, size_t channel_stride, - const int32_t* x_ptr, const int32_t* b_ptr, const int32_t M, - const int offset, const int8_t minv, const int8_t maxv, size_t size, - int8_t* dst_ptr) { + const int32_t* x_ptr, const int32_t* b_ptr, const int32_t M, const int offset, + const int8_t minv, const int8_t maxv, size_t size, int8_t* dst_ptr) { MEGDNN_MARK_USED_VAR(size); const int32x4_t shift_vec = vdupq_n_s32(-offset); const int32x4_t M_vec = vdupq_n_s32(M); @@ -279,10 +273,9 @@ void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int32( dt_int32 bias = b_ptr[b_pos]; for (; i < channel_offset; ++i, ++x_ptr, ++dst_ptr) { dt_int32 result = rounding_shift_right_away_from_zero( - round_mulh_saturate(*x_ptr + bias, M), - offset); - *dst_ptr = static_cast(std::max( - std::min(result, maxv), minv)); + round_mulh_saturate(*x_ptr + bias, M), offset); + *dst_ptr = static_cast( + std::max(std::min(result, maxv), minv)); } } } @@ -301,16 +294,16 @@ bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr( auto minv = param[4].ptr()[0]; auto 
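The two bcast_1c11 kernels above fuse four steps per element: add a per-channel bias, apply a rounding-doubling multiply-high by M (vqrdmulh), shift right by offset with halves rounded away from zero (the sign-bit fixup followed by a rounding shift), and clamp to [minv, maxv]. A scalar sketch of that pipeline for the int16 flavour; it is illustrative only, glossing over the exact wrap/saturate behaviour of vaddq_s16/vqaddq_s16 and assuming offset >= 1:

#include <algorithm>
#include <cstdint>

// dst = clamp(round_shr(rmulh(x + bias, M), offset), minv, maxv)
static int8_t fuse_add_rmulh_rshr_ref(
        int16_t x, int16_t bias, int16_t M, int offset, int8_t minv, int8_t maxv) {
    int64_t sum = (int64_t)x + bias;
    // rounding doubling multiply-high, what vqrdmulhq_s16 computes per lane
    int64_t prod = 2 * sum * M + (1 << 15);
    int32_t mulh = (int32_t)std::min<int64_t>(
            std::max<int64_t>(prod >> 16, INT16_MIN), INT16_MAX);
    // adding the sign first turns the round-half-up shift (vrshlq with a
    // negative shift count) into a round-half-away-from-zero shift
    mulh += (mulh < 0) ? -1 : 0;
    int32_t shifted = (mulh + (1 << (offset - 1))) >> offset;
    return (int8_t)std::min<int32_t>(std::max<int32_t>(shifted, minv), maxv);
}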
maxv = param[5].ptr()[0]; switch (param[0].layout.dtype.enumv()) { -#define DISPATCH(stype, suffix) \ - case DTypeTrait::enumv: { \ - auto x_ptr = param[0].ptr::ctype>(); \ - auto b_ptr = param[1].ptr::ctype>(); \ - auto M = param[2].ptr::ctype>()[0]; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_##suffix( \ - binfo.x, binfo.y, binfo.z, x_ptr, b_ptr, M, offset, \ - minv, maxv, param.size, dst)); \ - break; \ +#define DISPATCH(stype, suffix) \ + case DTypeTrait::enumv: { \ + auto x_ptr = param[0].ptr::ctype>(); \ + auto b_ptr = param[1].ptr::ctype>(); \ + auto M = param[2].ptr::ctype>()[0]; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_##suffix( \ + binfo.x, binfo.y, binfo.z, x_ptr, b_ptr, M, offset, minv, \ + maxv, param.size, dst)); \ + break; \ } DISPATCH(dtype::Int16, int16) DISPATCH(dtype::Int32, int32) @@ -327,79 +320,75 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) { if (dispatch_fuse_add_rmulh_rshr(param, dst)) return; - fallback::ElemwiseMultiTypeImpl:: - on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(param, dst); + fallback::ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( + param, dst); } void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) { if (dispatch_fuse_add_rmulh_rshr(param, dst)) return; - fallback::ElemwiseMultiTypeImpl:: - on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(param, dst); + fallback::ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( + param, dst); } -void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<1>& param, - const TensorND& dst, - Elemwise::Mode mode) { +void ElemwiseMultiTypeImpl::on_quantized_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, Elemwise::Mode mode) { megdnn_assert(param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); -#define DISPATCH_MODE(_src_dt, _dst_dt) \ - switch (mode) { \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, \ - HSwishOp) \ - default: \ - break; \ - } - -#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ - switch (mode) { \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ABS, AbsOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SIGMOID, \ - SigmoidOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::EXP, ExpOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TANH, TanhOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FAST_TANH, \ - FastTanhOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, \ - HSwishOp) \ - default: \ - break; \ - } - -#define DISPATCH() \ - if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ - dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ - DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ - } else if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ - dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ - DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, \ - dtype::Quantized8Asymm) \ - } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ - dst.layout.dtype.enumv() 
== DTypeEnum::QuantizedS8) { \ - DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ - } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ - dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ - DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ +#define DISPATCH_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp) \ + default: \ + break; \ + } + +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::RELU, ReluOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ABS, AbsOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SIGMOID, SigmoidOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::EXP, ExpOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TANH, TanhOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FAST_TANH, FastTanhOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::H_SWISH, HSwishOp) \ + default: \ + break; \ + } + +#define DISPATCH() \ + if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ + } else if ( \ + param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ + } else if ( \ + param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ + } else if ( \ + param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ } TensorND src = param[0]; size_t nr_elems = src.layout.total_nr_elems(); -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerUnary<_op, VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - run(src.ptr(), dst.ptr(), \ - src.layout.dtype, dst.layout.dtype, nr_elems)); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function run = \ + OpCallerUnary<_op, VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src.ptr(), dst.ptr(), src.layout.dtype, \ + dst.layout.dtype, nr_elems)); \ + return; \ } DISPATCH() @@ -412,86 +401,81 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<1>& param, #undef DISPATCH_MODE } -void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, - const TensorND& dst, - Elemwise::Mode mode) { - megdnn_assert(param[0].layout.dtype.enumv() == - param[1].layout.dtype.enumv() && - param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); +void ElemwiseMultiTypeImpl::on_quantized_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, Elemwise::Mode mode) { + megdnn_assert( + param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() && + param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); 
megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); -#define DISPATCH_MODE(_src_dt, _dst_dt) \ - switch (mode) { \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, \ - FuseAddReluOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ - Elemwise::Mode::FUSE_ADD_H_SWISH, \ - FuseAddHSwishOp) \ - default: \ - break; \ +#define DISPATCH_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp) \ + default: \ + break; \ } #if MEGDNN_AARCH64 -#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ - switch (mode) { \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TRUE_DIV, \ - TrueDivOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, \ - FuseAddReluOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ - Elemwise::Mode::FUSE_ADD_SIGMOID, \ - FuseAddSigmoidOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, \ - FuseAddTanhOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ - Elemwise::Mode::FUSE_ADD_H_SWISH, \ - FuseAddHSwishOp) \ - default: \ - break; \ +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::TRUE_DIV, TrueDivOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_SIGMOID, FuseAddSigmoidOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, FuseAddTanhOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp) \ + default: \ + break; \ } #else -#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ - switch (mode) { \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, \ - FuseAddReluOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ - Elemwise::Mode::FUSE_ADD_SIGMOID, \ - FuseAddSigmoidOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, \ - FuseAddTanhOp) \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, \ - Elemwise::Mode::FUSE_ADD_H_SWISH, \ - FuseAddHSwishOp) \ - default: \ - break; \ +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + 
switch (mode) { \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::ADD, AddOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MIN, MinOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MAX, MaxOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::SUB, SubOp) \ + DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::MUL, MulOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_RELU, FuseAddReluOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_SIGMOID, FuseAddSigmoidOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_TANH, FuseAddTanhOp) \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_ADD_H_SWISH, FuseAddHSwishOp) \ + default: \ + break; \ } #endif -#define DISPATCH() \ - if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ - dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ - DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ - } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ - dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ - DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ - } else if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ - dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ - DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ - } else if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ - dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ - DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, \ - dtype::Quantized8Asymm) \ +#define DISPATCH() \ + if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::QuantizedS8) \ + } else if ( \ + param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_MODE(dtype::QuantizedS32, dtype::Quantized8Asymm) \ + } else if ( \ + param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ + } else if ( \ + param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ } TensorND src0 = param[0]; @@ -500,18 +484,18 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, //! 
VEC + VEC if (is_vector(src0.layout) && is_vector(src1.layout)) { size_t nr_elems = src0.layout.total_nr_elems(); -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerBinary<_op, VEC_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - run(src0.ptr(), src1.ptr(), \ - dst.ptr(), src0.layout.dtype, \ - src1.layout.dtype, dst.layout.dtype, nr_elems)); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr(), dst.ptr(), \ + src0.layout.dtype, src1.layout.dtype, dst.layout.dtype, nr_elems)); \ + return; \ } DISPATCH() @@ -521,33 +505,31 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, //! VEC + SCALAR { - bool normal_case = - is_vector(src0.layout) && is_broadcasted_scalar(src1.layout); + bool normal_case = is_vector(src0.layout) && is_broadcasted_scalar(src1.layout); bool swap_case = false; bool commutable = false; if (mode != Elemwise::Mode::SUB && mode != Elemwise::Mode::TRUE_DIV) commutable = true; if (!normal_case && commutable) { - swap_case = is_vector(src1.layout) && - is_broadcasted_scalar(src0.layout); + swap_case = is_vector(src1.layout) && is_broadcasted_scalar(src0.layout); } if (normal_case || swap_case) { auto &lhs = src0, &rhs = src1; if (swap_case) std::swap(lhs, rhs); -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerBinary<_op, \ - VEC_SCALAR>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ - src0.ptr(), src1.ptr()[0], \ - dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ - dst.layout.dtype, src0.layout.total_nr_elems())); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr()[0], \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, src0.layout.total_nr_elems())); \ + return; \ } DISPATCH() @@ -558,19 +540,19 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, //! 
SCALAR + VEC if (!commutable && is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerBinary<_op, \ - SCALAR_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ - src0.ptr()[0], src1.ptr(), \ - dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ - dst.layout.dtype, src1.layout.total_nr_elems())); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, SCALAR_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr()[0], src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, src1.layout.total_nr_elems())); \ + return; \ } DISPATCH() @@ -596,19 +578,19 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, auto &lhs = src0, &rhs = src1; if (swap_case) std::swap(lhs, rhs); -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerBinary<_op, \ - VEC_BCAST101>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ - src0.ptr(), src1.ptr(), \ - dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ - dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + return; \ } DISPATCH() @@ -619,19 +601,19 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, //! 
BCAST101 + VEC : only for SUB or TRUE_DIV if (!commutable && is_vector(src1.layout) && is_broadcasted_channel_like(src0.layout, binfo)) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerBinary<_op, \ - BCAST101_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ - src0.ptr(), src1.ptr(), \ - dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ - dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, BCAST101_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ + return; \ } DISPATCH() @@ -646,23 +628,21 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, if (is_vector(src0.layout) && (is_broadcastedx_channel_like<4>(src1.layout, binfo) || is_broadcastedx_channel_like<8>(src1.layout, binfo))) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerBinary<_op, \ - VEC_BCAST101xX>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ - src0.ptr(), src1.ptr(), \ - dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ - dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ - return; \ - } - size_t batch_size = - src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, VEC_BCAST101xX>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH() #undef DISPATCH_SINGLE_MODE @@ -671,23 +651,21 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, //! 
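The DISPATCH_SINGLE_MODE blocks in this binary path select an OpCallerBinary specialization from the operand layouts: both contiguous (VEC_VEC), one broadcasted scalar (VEC_SCALAR / SCALAR_VEC), a per-channel {1, C, 1, 1} broadcast (VEC_BCAST101 / BCAST101_VEC), or the channel-blocked NCHWxX forms (VEC_BCAST101xX / BCAST101xX_VEC). For the plain channel broadcast, binfo splits the contiguous operand into (batch, channel, channel_stride) and the broadcast side advances once per channel, roughly as sketched below (simplified; the real callers run a NEON inner loop plus a scalar tail):

#include <cstddef>

// Sketch of the BCAST101_VEC iteration order; bcast has logical shape {1, C, 1, 1}.
template <typename Op, typename T>
void bcast101_vec_ref(
        const T* bcast, const T* vec, T* dst, const Op& op, size_t batch,
        size_t channel, size_t channel_stride) {
    for (size_t b = 0; b < batch; ++b)
        for (size_t c = 0; c < channel; ++c)
            for (size_t i = 0; i < channel_stride; ++i, ++vec, ++dst)
                *dst = op(bcast[c], *vec);
}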
BCAST101x + VEC if (is_vector(src1.layout) && is_broadcastedx_channel_like<4>(src0.layout, binfo)) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerBinary<_op, \ - BCAST101xX_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ - src0.ptr(), src1.ptr(), \ - dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ - dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ - return; \ - } - size_t batch_size = - src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerBinary<_op, BCAST101xX_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH() #undef DISPATCH_SINGLE_MODE @@ -701,31 +679,30 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, #undef DISPATCH } -void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, - const TensorND& dst, - Elemwise::Mode mode) { +void ElemwiseMultiTypeImpl::on_quantized_mode( + const ElemwiseOpParamN<3>& param, const TensorND& dst, Elemwise::Mode mode) { megdnn_assert( param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv() && param[0].layout.dtype.enumv() == param[2].layout.dtype.enumv() && param[0].layout.dtype.category() == DTypeCategory::QUANTIZED); megdnn_assert(dst.layout.dtype.category() == DTypeCategory::QUANTIZED); -#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ - switch (mode) { \ - DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, Elemwise::Mode::FUSE_MUL_ADD3, \ - FuseMulAdd3Op) \ - default: \ - break; \ +#define DISPATCH_QUANTIZED_MODE(_src_dt, _dst_dt) \ + switch (mode) { \ + DISPATCH_SINGLE_MODE( \ + _src_dt, _dst_dt, Elemwise::Mode::FUSE_MUL_ADD3, FuseMulAdd3Op) \ + default: \ + break; \ } -#define DISPATCH() \ - if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ - dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ - DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ - } else if (param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ - dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ - DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, \ - dtype::Quantized8Asymm) \ +#define DISPATCH() \ + if (param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 && \ + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8) { \ + DISPATCH_QUANTIZED_MODE(dtype::QuantizedS8, dtype::QuantizedS8) \ + } else if ( \ + param[0].layout.dtype.enumv() == DTypeEnum::Quantized8Asymm && \ + dst.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { \ + DISPATCH_QUANTIZED_MODE(dtype::Quantized8Asymm, dtype::Quantized8Asymm) \ } TensorND src0 = param[0]; @@ -733,24 +710,21 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, TensorND src2 = param[2]; //! 
VEC + VEC + VEC - if (is_vector(src0.layout) && is_vector(src1.layout) && - is_vector(src2.layout)) { + if (is_vector(src0.layout) && is_vector(src1.layout) && is_vector(src2.layout)) { size_t nr_elems = src0.layout.total_nr_elems(); -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerTernary<_op, \ - VEC_VEC_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - run(src0.ptr(), src1.ptr(), \ - src2.ptr(), dst.ptr(), \ - src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, nr_elems)); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary<_op, VEC_VEC_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr(), src2.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + src2.layout.dtype, dst.layout.dtype, nr_elems)); \ + return; \ } DISPATCH() @@ -761,21 +735,20 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, //! VEC + VEC + SCALAR if (is_vector(src0.layout) && is_vector(src1.layout) && is_broadcasted_scalar(src2.layout)) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerTernary<_op, \ - VEC_VEC_SCALAR>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - run(src0.ptr(), src1.ptr(), \ - src2.ptr()[0], dst.ptr(), \ - src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, src0.layout.total_nr_elems())); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary<_op, VEC_VEC_SCALAR>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr()[0], dst.ptr(), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + src0.layout.total_nr_elems())); \ + return; \ } DISPATCH() @@ -790,21 +763,20 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, is_broadcasted_channel_like(src0.layout, binfo) && src0.layout.eq_shape(src2.layout); if (normal_case) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerTernary<_op, \ - BCAST101_VEC_BCAST101>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - run(src0.ptr(), src1.ptr(), \ - src2.ptr(), dst.ptr(), \ - src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, binfo.x, binfo.y, binfo.z)); \ - return; \ +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary< \ + _op, BCAST101_VEC_BCAST101>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ + src0.ptr(), src1.ptr(), src2.ptr(), \ + dst.ptr(), src0.layout.dtype, src1.layout.dtype, \ + src2.layout.dtype, dst.layout.dtype, 
binfo.x, binfo.y, binfo.z)); \ + return; \ } DISPATCH() @@ -820,25 +792,24 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, (is_broadcastedx_channel_like<4>(src1.layout, binfo) || is_broadcastedx_channel_like<8>(src1.layout, binfo)) && src0.layout.eq_shape(src2.layout)) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerTernary<_op, \ - VEC_BCAST101xX_VEC>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - run(src0.ptr(), src1.ptr(), \ - src2.ptr(), dst.ptr(), \ - src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ - return; \ - } - - size_t batch_size = - src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary< \ + _op, VEC_BCAST101xX_VEC>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr(), dst.ptr(), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH() #undef DISPATCH_SINGLE_MODE @@ -849,25 +820,24 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, (is_broadcastedx_channel_like<4>(src0.layout, binfo) || is_broadcastedx_channel_like<8>(src0.layout, binfo)) && src0.layout.eq_shape(src2.layout)) { -#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ - case _mode: { \ - using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ - using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ - thin_function \ - run = OpCallerTernary<_op, \ - BCAST101xX_VEC_BCAST101xX>::run; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - run(src0.ptr(), src1.ptr(), \ - src2.ptr(), dst.ptr(), \ - src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ - dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ - return; \ - } - - size_t batch_size = - src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); +#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ + case _mode: { \ + using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ + using dst_ctype = typename DTypeTrait<_dst_dt>::ctype; \ + thin_function \ + run = OpCallerTernary< \ + _op, BCAST101xX_VEC_BCAST101xX>::run; \ + MEGDNN_DISPATCH_CPU_KERN_OPR( \ + run(src0.ptr(), src1.ptr(), \ + src2.ptr(), dst.ptr(), src0.layout.dtype, \ + src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ + batch_size, binfo.x, binfo.y, binfo.z)); \ + return; \ + } + + size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); DISPATCH() #undef DISPATCH_SINGLE_MODE diff --git a/dnn/src/arm_common/elemwise_multi_type/opr_impl.h b/dnn/src/arm_common/elemwise_multi_type/opr_impl.h index ed57b76c..96e55962 100644 --- a/dnn/src/arm_common/elemwise_multi_type/opr_impl.h +++ b/dnn/src/arm_common/elemwise_multi_type/opr_impl.h @@ -18,32 +18,35 @@ namespace arm_common { class ElemwiseMultiTypeImpl : public fallback::ElemwiseMultiTypeImpl { template - void neon_round_shr_saturate_bcast_scalar(const stype* a_ptr, int8_t k, - size_t size, dt_int8* dst_ptr); + void neon_round_shr_saturate_bcast_scalar( + const 
stype* a_ptr, int8_t k, size_t size, dt_int8* dst_ptr); template void dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar( const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst); - bool dispatch_fuse_add_rmulh_rshr(const ElemwiseOpParamN<6>& param, - megdnn::dt_int8* dst); + bool dispatch_fuse_add_rmulh_rshr( + const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst); protected: - void on_round_shr_saturate_iXxi8xi8(const ElemwiseOpParamN<2>& param, - dt_int8* dst) override; + void on_round_shr_saturate_iXxi8xi8( + const ElemwiseOpParamN<2>& param, dt_int8* dst) override; void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( const ElemwiseOpParamN<6>& param, dt_int8* dst) override; void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( const ElemwiseOpParamN<6>& param, dt_int8* dst) override; - void on_quantized_mode(const ElemwiseOpParamN<1>& param, - const TensorND& dst, Elemwise::Mode mode) override; + void on_quantized_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, + Elemwise::Mode mode) override; - void on_quantized_mode(const ElemwiseOpParamN<2>& param, - const TensorND& dst, Elemwise::Mode mode) override; + void on_quantized_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, + Elemwise::Mode mode) override; - void on_quantized_mode(const ElemwiseOpParamN<3>& param, - const TensorND& dst, Elemwise::Mode mode) override; + void on_quantized_mode( + const ElemwiseOpParamN<3>& param, const TensorND& dst, + Elemwise::Mode mode) override; public: using fallback::ElemwiseMultiTypeImpl::ElemwiseMultiTypeImpl; diff --git a/dnn/src/arm_common/elemwise_op.h b/dnn/src/arm_common/elemwise_op.h index bc8d2373..a69ad7a6 100644 --- a/dnn/src/arm_common/elemwise_op.h +++ b/dnn/src/arm_common/elemwise_op.h @@ -27,20 +27,18 @@ struct ParamElemVisitor; template struct ParamElemVisitorDup; -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitor<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix( \ - reinterpret_cast(src)); \ - } \ - }; \ - template <> \ - struct ParamElemVisitorDup<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vdupq_n_##_fun_suffix( \ - *reinterpret_cast(src)); \ - } \ +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitor<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix(reinterpret_cast(src)); \ + } \ + }; \ + template <> \ + struct ParamElemVisitorDup<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vdupq_n_##_fun_suffix(*reinterpret_cast(src)); \ + } \ } cb(dt_qint32, int32_t, int32x4_t, s32); cb(dt_qint8, int8_t, int8x16_t, s8); @@ -57,14 +55,13 @@ cb(dt_int8, int8_t, int8x16_t, s8); template struct ParamElemVisitorBcast101x4; -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vreinterpretq_##_fun_suffix##_##rel_suffix( \ - vld1q_dup_##rel_suffix( \ - reinterpret_cast(src))); \ - } \ +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix, rel_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vreinterpretq_##_fun_suffix##_##rel_suffix(vld1q_dup_##rel_suffix( \ + reinterpret_cast(src))); \ + } \ } cb(dt_qint8, int32_t, int8x16_t, s8, s32); @@ -75,13 +72,12 @@ cb(dt_int16, int64_t, int16x8_t, 
s16, s64); cb(__fp16, uint64_t, float16x8_t, f16, u64); #endif #undef cb -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x4<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix( \ - reinterpret_cast(src)); \ - } \ +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x4<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix(reinterpret_cast(src)); \ + } \ } cb(dt_qint32, int32_t, int32x4_t, s32); @@ -91,13 +87,12 @@ cb(dt_int32, int32_t, int32x4_t, s32); template struct ParamElemVisitorBcast101x8; -#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ - template <> \ - struct ParamElemVisitorBcast101x8<_ctype> { \ - _neon_type operator()(const _ctype* src) const { \ - return vld1q_##_fun_suffix( \ - reinterpret_cast(src)); \ - } \ +#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \ + template <> \ + struct ParamElemVisitorBcast101x8<_ctype> { \ + _neon_type operator()(const _ctype* src) const { \ + return vld1q_##_fun_suffix(reinterpret_cast(src)); \ + } \ } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC cb(__fp16, __fp16, float16x8_t, f16); @@ -134,9 +129,9 @@ struct OpCallerUnary; template struct OpCallerUnary { - static void run(const typename Op::src_ctype* src, - typename Op::dst_ctype* dst, DType src_dtype, - DType dst_dtype, size_t nr_elems) { + static void run( + const typename Op::src_ctype* src, typename Op::dst_ctype* dst, + DType src_dtype, DType dst_dtype, size_t nr_elems) { Op op(src_dtype, dst_dtype); ParamElemVisitor vis; size_t i = 0; @@ -164,10 +159,10 @@ struct OpCallerBinary; template struct OpCallerBinary, VEC_VEC> { using Op = PowOp; - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG @@ -186,10 +181,10 @@ struct OpCallerBinary, VEC_VEC> { template struct OpCallerBinary, VEC_SCALAR> { using Op = PowOp; - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG @@ -207,11 +202,10 @@ struct OpCallerBinary, VEC_SCALAR> { template struct OpCallerBinary, VEC_BCAST101> { using Op = PowOp; - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t channel, size_t channel_stride) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { const typename 
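[Reviewer note, not part of the patch] The `cb` macros above stamp out per-dtype NEON "visitor" functors: a plain full-vector load for contiguous operands and a lane-duplicating load for broadcast operands, so the element-wise callers can stay fully vectorized. A minimal stand-alone sketch of that pattern for a single type, assuming an ARM target with NEON and <arm_neon.h> (the names here are illustrative, not the library's):

#include <arm_neon.h>
#include <cstdint>

// Contiguous operand: load 16 consecutive int8 lanes into one q-register.
struct LoadVisitorS8 {
    int8x16_t operator()(const int8_t* src) const { return vld1q_s8(src); }
};

// Broadcast operand: read a single int8 and duplicate it into all 16 lanes,
// so a scalar or per-channel value can enter the same vector inner loop.
struct DupVisitorS8 {
    int8x16_t operator()(const int8_t* src) const { return vdupq_n_s8(*src); }
};

The real header generates one such pair per dtype through the macro, so the OpCaller templates below can pick the matching visitor at compile time.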
Op::src_ctype* src1_ptr = src1; @@ -235,10 +229,10 @@ struct OpCallerBinary, VEC_BCAST101> { template struct OpCallerBinary, SCALAR_VEC> { using Op = PowOp; - static void run(const typename Op::src_ctype src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t nr_elems) { + static void run( + const typename Op::src_ctype src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); size_t i = 0; #if MEGDNN_FIX_AARCH32_BUG @@ -256,11 +250,10 @@ struct OpCallerBinary, SCALAR_VEC> { template struct OpCallerBinary, BCAST101_VEC> { using Op = PowOp; - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t channel, size_t channel_stride) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; @@ -283,10 +276,10 @@ struct OpCallerBinary, BCAST101_VEC> { template struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitor vis1; @@ -313,11 +306,10 @@ struct OpCallerBinary { template struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t channel, size_t channel_stride) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; @@ -351,20 +343,18 @@ struct OpCallerBinary { template struct OpCallerBinary, BCAST101xX_VEC> { using Op = PowOp; - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; for (size_t i = 0; i < channel_stride; i++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t c_iter = 0; 
c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, dst); src1++; dst++; @@ -378,17 +368,16 @@ struct OpCallerBinary, BCAST101xX_VEC> { template struct OpCallerBinaryBcast101xXVec { template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; - img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, dst); src1++; dst++; @@ -402,10 +391,10 @@ struct OpCallerBinaryBcast101xXVec { template struct OpCallerBinaryBcast101xDVec { template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, - const Vis1& vis1, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { @@ -422,8 +411,7 @@ struct OpCallerBinaryBcast101xDVec { } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, dst); src1++; dst++; @@ -437,9 +425,10 @@ struct OpCallerBinaryBcast101xDVec { template struct OpCallerBinaryBcast101xXVec { template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { ParamElemVisitorBcast101x4 vis0; ParamElemVisitor vis1; OpCallerBinaryBcast101xDVec::run( @@ -454,9 +443,10 @@ struct OpCallerBinaryBcast101xXVec<__fp16, 8> { using src_ctype = __fp16; template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { ParamElemVisitorBcast101x8 vis0; ParamElemVisitor vis1; OpCallerBinaryBcast101xDVec::run( @@ -468,23 +458,21 @@ struct OpCallerBinaryBcast101xXVec<__fp16, 8> { template struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_channel_blocks, size_t channel_stride, - size_t 
channel_block_dim) { - megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, dst_dtype); if (channel_block_dim == 4) { OpCallerBinaryBcast101xXVec::run( - src0, src1, dst, op, batch, nr_channel_blocks, - channel_stride); + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } else { OpCallerBinaryBcast101xXVec::run( - src0, src1, dst, op, batch, nr_channel_blocks, - channel_stride); + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } } }; @@ -492,20 +480,18 @@ struct OpCallerBinary { template struct OpCallerBinary, VEC_BCAST101xX> { using Op = PowOp; - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; for (size_t i = 0; i < channel_stride; i++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0), *(src1_block_ptr + c_iter), dst); src0++; dst++; @@ -519,17 +505,16 @@ struct OpCallerBinary, VEC_BCAST101xX> { template struct OpCallerBinaryVecBcast101xX { template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; - img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), dst); src0++; dst++; @@ -543,10 +528,10 @@ struct OpCallerBinaryVecBcast101xX { template struct OpCallerBinaryVecBcast101xD { template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, - const Vis1& vis1, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, const Vis0& vis0, const Vis1& vis1, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { for 
(size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { @@ -563,8 +548,7 @@ struct OpCallerBinaryVecBcast101xD { } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), dst); src0++; dst++; @@ -578,9 +562,10 @@ struct OpCallerBinaryVecBcast101xD { template struct OpCallerBinaryVecBcast101xX { template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { ParamElemVisitor vis0; ParamElemVisitorBcast101x4 vis1; OpCallerBinaryVecBcast101xD::run( @@ -594,9 +579,10 @@ template <> struct OpCallerBinaryVecBcast101xX<__fp16, 8> { using src_ctype = __fp16; template - static void run(const src_ctype* src0, const src_ctype* src1, - typename Op::dst_ctype* dst, const Op& op, size_t batch, - size_t nr_channel_blocks, size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, typename Op::dst_ctype* dst, + const Op& op, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { ParamElemVisitor vis0; ParamElemVisitorBcast101x8 vis1; OpCallerBinaryVecBcast101xD::run( @@ -608,41 +594,39 @@ struct OpCallerBinaryVecBcast101xX<__fp16, 8> { template struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_channel_blocks, size_t channel_stride, - size_t channel_block_dim) { - megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t nr_channel_blocks, + size_t channel_stride, size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, dst_dtype); if (channel_block_dim == 4) { OpCallerBinaryVecBcast101xX::run( - src0, src1, dst, op, batch, nr_channel_blocks, - channel_stride); + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } else { OpCallerBinaryVecBcast101xX::run( - src0, src1, dst, op, batch, nr_channel_blocks, - channel_stride); + src0, src1, dst, op, batch, nr_channel_blocks, channel_stride); } } }; template struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; auto vis1_neon = vis1(&src1); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1_neon, 
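[Reviewer note, not part of the patch] The Bcast101xX callers above broadcast a per-channel-block operand (block size 4 for NCHW44, 8 for NCHW88) across every batch index and every spatial position of a contiguous operand; the scalar fallback loops in the diff spell out the addressing. A compact reference of that addressing with a plain binary functor instead of the library's Op interface (all names hypothetical):

#include <cstddef>

// dst = op(vec, bcast), where `bcast` holds `block` values per channel block
// and is reused for every batch index and every spatial position.
template <typename T, typename Op>
void run_vec_bcast101xX(const T* vec, const T* bcast, T* dst, Op op,
                        size_t batch, size_t nr_channel_blocks,
                        size_t channel_stride, size_t block) {
    for (size_t b = 0; b < batch; ++b) {
        for (size_t cb = 0; cb < nr_channel_blocks; ++cb) {
            const T* blk = bcast + cb * block;  // the 4 or 8 broadcast values
            for (size_t s = 0; s < channel_stride; ++s) {
                for (size_t lane = 0; lane < block; ++lane) {
                    *dst++ = op(*vec++, blk[lane]);
                }
            }
        }
    }
}

In the patched code the two innermost loops are replaced by the Bcast101x4/Bcast101x8 visitors for the vector main path; only the spatial remainder falls back to this scalar form.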
vis1_neon}}, dst); + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_neon, vis1_neon}}, + dst); src0 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } @@ -661,18 +645,18 @@ struct OpCallerBinary { //! this only for nonswap op, like SUB and DIV template struct OpCallerBinary { - static void run(const typename Op::src_ctype src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t nr_elems) { + static void run( + const typename Op::src_ctype src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t nr_elems) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitorDup vis0; ParamElemVisitor vis1; auto vis0_neon = vis0(&src0); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0_neon, vis0_neon}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, dst); + op({{vis0_neon, vis0_neon}}, {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, + dst); src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } @@ -690,11 +674,10 @@ struct OpCallerBinary { template struct OpCallerBinary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType dst_dtype, size_t batch, - size_t channel, size_t channel_stride) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, + DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) { Op op(src0_dtype, src1_dtype, dst_dtype); ParamElemVisitorDup vis0; ParamElemVisitor vis1; @@ -730,12 +713,11 @@ struct OpCallerTernary; template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitor vis1; @@ -767,12 +749,11 @@ struct OpCallerTernary { //! 
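[Reviewer note, not part of the patch] The VEC_SCALAR and SCALAR_VEC callers just above share one structure: duplicate the broadcast operand into a register once, consume two SIMD vectors per main-loop iteration, and finish the remainder with scalar code (the MEGDNN_FIX_AARCH32_BUG guard works around an AArch32 compiler issue in that tail). A self-contained sketch of that structure, assuming an ARM target with NEON and <arm_neon.h> (hypothetical helper, not the patch's code):

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>

// dst[i] = src[i] + scalar for i in [0, n), with wraparound int8 semantics.
void add_scalar_s8(const int8_t* src, int8_t scalar, int8_t* dst, size_t n) {
    const size_t simd_width = 16;            // int8 lanes per q-register
    int8x16_t vscalar = vdupq_n_s8(scalar);  // hoisted broadcast, loaded once
    size_t i = 0;
    for (; i + simd_width * 2 <= n; i += simd_width * 2) {
        int8x16_t a0 = vld1q_s8(src + i);
        int8x16_t a1 = vld1q_s8(src + i + simd_width);
        vst1q_s8(dst + i, vaddq_s8(a0, vscalar));
        vst1q_s8(dst + i + simd_width, vaddq_s8(a1, vscalar));
    }
    for (; i < n; ++i) {  // scalar tail for the last n % 32 elements
        dst[i] = static_cast<int8_t>(src[i] + scalar);
    }
}

Consuming two vectors per trip is what the `Op::SIMD_WIDTH * 2` loop bound in the diff expresses: it gives the pipeline two independent load/compute chains per iteration.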
src0: vector, src1: vector, src2: scalar template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitor vis1; @@ -781,8 +762,8 @@ struct OpCallerTernary { size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, - {{vis2_neon, vis2_neon}}, dst); + {{vis1(src1), vis1(src1 + Op::SIMD_WIDTH)}}, {{vis2_neon, vis2_neon}}, + dst); src0 += Op::SIMD_WIDTH * 2; src1 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; @@ -803,13 +784,11 @@ struct OpCallerTernary { //! src0: 1C11, src1: vector, src2: 1C11 template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch_size, size_t channel_size, - size_t channel_stride) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, size_t channel_stride) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis1; ParamElemVisitorDup vis0; @@ -848,20 +827,18 @@ struct OpCallerTernary { template struct OpCallerTernaryBcast101xXVecBcast101xX { template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; auto src2_ptr = src2; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; auto src2_block_ptr = src2_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; - img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, *(src2_block_ptr + c_iter), dst); src1++; @@ -876,11 +853,11 @@ struct OpCallerTernaryBcast101xXVecBcast101xX { template struct OpCallerTernaryBcast101xDVecBcast101xD { template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, const Vis0& vis0, const Vis1& vis1, - const Vis2& vis2, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename 
Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; auto src2_ptr = src2; @@ -901,8 +878,7 @@ struct OpCallerTernaryBcast101xDVecBcast101xD { } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*(src0_block_ptr + c_iter), *src1, *(src2_block_ptr + c_iter), dst); src1++; @@ -918,16 +894,16 @@ struct OpCallerTernaryBcast101xDVecBcast101xD { template struct OpCallerTernaryBcast101xXVecBcast101xX { template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitorBcast101x4 vis0; ParamElemVisitor vis1; ParamElemVisitorBcast101x4 vis2; OpCallerTernaryBcast101xDVecBcast101xD::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, - nr_channel_blocks, channel_stride); + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); } }; @@ -936,44 +912,40 @@ template <> struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> { using src_ctype = __fp16; template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitorBcast101x8 vis0; ParamElemVisitor vis1; ParamElemVisitorBcast101x8 vis2; OpCallerTernaryBcast101xDVecBcast101xD::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, - nr_channel_blocks, channel_stride); + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); } }; #endif template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); if (channel_block_dim == 4) { - OpCallerTernaryBcast101xXVecBcast101xX::run(src0, src1, src2, - dst, op, batch, - nr_channel_blocks, - channel_stride); + OpCallerTernaryBcast101xXVecBcast101xX::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + 
channel_stride); } else { - OpCallerTernaryBcast101xXVecBcast101xX::run(src0, src1, src2, - dst, op, batch, - nr_channel_blocks, - channel_stride); + OpCallerTernaryBcast101xXVecBcast101xX::run( + src0, src1, src2, dst, op, batch, nr_channel_blocks, + channel_stride); } } }; @@ -981,13 +953,11 @@ struct OpCallerTernary { //! src1: 1C11, src0 and src2 are contig template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch_size, size_t channel_size, - size_t channel_stride) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch_size, size_t channel_size, size_t channel_stride) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; @@ -1025,18 +995,16 @@ struct OpCallerTernary { template struct OpCallerTernaryVecBcast101xXVec { template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src1_block_ptr = src1_ptr + cb * channel_block_dim; - for (size_t img_index = 0; img_index < channel_stride; - img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t img_index = 0; img_index < channel_stride; img_index++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), *src2, dst); src0++; src2++; @@ -1052,11 +1020,11 @@ struct OpCallerTernaryVecBcast101xXVec { template struct OpCallerTernaryVecBcast101xDVec { template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, const Vis0& vis0, const Vis1& vis1, - const Vis2& vis2, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0, + const Vis1& vis1, const Vis2& vis2, size_t batch, size_t nr_channel_blocks, + size_t channel_stride) { for (size_t b = 0; b < batch; b++) { auto src1_ptr = src1; for (size_t cb = 0; cb < nr_channel_blocks; cb++) { @@ -1075,8 +1043,7 @@ struct OpCallerTernaryVecBcast101xDVec { } // TODO:all elemwise_multi_type op imp one simd mode for (; img_index < channel_stride; img_index++) { - for (size_t c_iter = 0; c_iter < channel_block_dim; - c_iter++) { + for (size_t c_iter = 0; c_iter < channel_block_dim; c_iter++) { op(*src0, *(src1_block_ptr + c_iter), *src2, dst); src0++; src2++; @@ -1091,16 +1058,16 @@ struct OpCallerTernaryVecBcast101xDVec { template struct OpCallerTernaryVecBcast101xXVec { template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t 
channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitor vis0; ParamElemVisitorBcast101x4 vis1; ParamElemVisitor vis2; OpCallerTernaryVecBcast101xDVec::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, - nr_channel_blocks, channel_stride); + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); } }; @@ -1109,31 +1076,31 @@ template <> struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> { using src_ctype = __fp16; template - static void run(const src_ctype* src0, const src_ctype* src1, - const src_ctype* src2, typename Op::dst_ctype* dst, - const Op& op, size_t batch, size_t nr_channel_blocks, - size_t channel_stride) { + static void run( + const src_ctype* src0, const src_ctype* src1, const src_ctype* src2, + typename Op::dst_ctype* dst, const Op& op, size_t batch, + size_t nr_channel_blocks, size_t channel_stride) { ParamElemVisitor vis0; ParamElemVisitorBcast101x8 vis1; ParamElemVisitor vis2; OpCallerTernaryVecBcast101xDVec::run( - src0, src1, src2, dst, op, vis0, vis1, vis2, batch, - nr_channel_blocks, channel_stride); + src0, src1, src2, dst, op, vis0, vis1, vis2, batch, nr_channel_blocks, + channel_stride); } }; #endif template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype* src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t batch, size_t nr_channel_blocks, - size_t channel_stride, size_t channel_block_dim) { - megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8, - "only imp for nchw44/nchw88"); + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype* src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t batch, size_t nr_channel_blocks, size_t channel_stride, + size_t channel_block_dim) { + megdnn_assert( + channel_block_dim == 4 || channel_block_dim == 8, + "only imp for nchw44/nchw88"); Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); if (channel_block_dim == 4) { @@ -1151,12 +1118,11 @@ struct OpCallerTernary { //! 
src1: scalar, src0 and src2 has the same shape template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype src1, - const typename Op::src_ctype* src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + const typename Op::src_ctype* src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; @@ -1164,8 +1130,7 @@ struct OpCallerTernary { auto vis1_neon = vis1(&src1); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1_neon, vis1_neon}}, + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_neon, vis1_neon}}, {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst); src0 += Op::SIMD_WIDTH * 2; src2 += Op::SIMD_WIDTH * 2; @@ -1187,12 +1152,11 @@ struct OpCallerTernary { //! src1, src2: scalar, src0 is vector template struct OpCallerTernary { - static void run(const typename Op::src_ctype* src0, - const typename Op::src_ctype src1, - const typename Op::src_ctype src2, - typename Op::dst_ctype* dst, DType src0_dtype, - DType src1_dtype, DType src2_dtype, DType dst_dtype, - size_t nr_elems) { + static void run( + const typename Op::src_ctype* src0, const typename Op::src_ctype src1, + const typename Op::src_ctype src2, typename Op::dst_ctype* dst, + DType src0_dtype, DType src1_dtype, DType src2_dtype, DType dst_dtype, + size_t nr_elems) { Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); ParamElemVisitor vis0; ParamElemVisitorDup vis1; @@ -1201,8 +1165,8 @@ struct OpCallerTernary { auto vis2_neon = vis2(&src2); size_t i = 0; for (; i + Op::SIMD_WIDTH * 2 <= nr_elems; i += Op::SIMD_WIDTH * 2) { - op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, - {{vis1_neon, vis1_neon}}, {{vis2_neon, vis2_neon}}, dst); + op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}}, {{vis1_neon, vis1_neon}}, + {{vis2_neon, vis2_neon}}, dst); src0 += Op::SIMD_WIDTH * 2; dst += Op::SIMD_WIDTH * 2; } diff --git a/dnn/src/arm_common/handle.cpp b/dnn/src/arm_common/handle.cpp index 5283a306..05de3a2d 100644 --- a/dnn/src/arm_common/handle.cpp +++ b/dnn/src/arm_common/handle.cpp @@ -13,21 +13,20 @@ #include "src/arm_common/handle.h" +#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/convolution/opr_impl.h" -#include "src/arm_common/pooling/opr_impl.h" +#include "src/arm_common/cvt_color/opr_impl.h" +#include "src/arm_common/elemwise/opr_impl.h" +#include "src/arm_common/elemwise_multi_type/opr_impl.h" #include "src/arm_common/local/opr_impl.h" +#include "src/arm_common/pooling/opr_impl.h" +#include "src/arm_common/reduce/opr_impl.h" +#include "src/arm_common/resize/opr_impl.h" #include "src/arm_common/separable_conv/opr_impl.h" #include "src/arm_common/separable_filter/opr_impl.h" -#include "src/arm_common/elemwise/opr_impl.h" -#include "src/arm_common/elemwise_multi_type/opr_impl.h" -#include "src/arm_common/cvt_color/opr_impl.h" +#include "src/arm_common/type_cvt/opr_impl.h" #include "src/arm_common/warp_affine/opr_impl.h" -#include "src/arm_common/resize/opr_impl.h" #include "src/arm_common/warp_perspective/opr_impl.h" -#include "src/arm_common/type_cvt/opr_impl.h" -#include "src/arm_common/reduce/opr_impl.h" -#include 
"src/arm_common/conv_bias/opr_impl.h" - namespace megdnn { namespace arm_common { @@ -58,7 +57,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData) MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) #pragma GCC diagnostic pop -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/handle.h b/dnn/src/arm_common/handle.h index 03888658..701b1255 100644 --- a/dnn/src/arm_common/handle.h +++ b/dnn/src/arm_common/handle.h @@ -17,22 +17,22 @@ namespace megdnn { namespace arm_common { -class HandleImpl: public fallback::HandleImpl { - public: - HandleImpl(megcoreComputingHandle_t computing_handle, - HandleType type = HandleType::ARM_COMMON): - fallback::HandleImpl::HandleImpl(computing_handle, type) - { - #if MGB_ENABLE_CPUINFO - cpuinfo_initialize(); - #endif - } +class HandleImpl : public fallback::HandleImpl { +public: + HandleImpl( + megcoreComputingHandle_t computing_handle, + HandleType type = HandleType::ARM_COMMON) + : fallback::HandleImpl::HandleImpl(computing_handle, type) { +#if MGB_ENABLE_CPUINFO + cpuinfo_initialize(); +#endif + } - template - std::unique_ptr create_operator(); + template + std::unique_ptr create_operator(); }; -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/intrinsic_helper.h b/dnn/src/arm_common/intrinsic_helper.h index b29f3e3c..bb67f044 100644 --- a/dnn/src/arm_common/intrinsic_helper.h +++ b/dnn/src/arm_common/intrinsic_helper.h @@ -17,8 +17,9 @@ namespace megdnn { namespace { -template +template < + int weight_number, int base_offset, int ptr_step, int oc_block, typename Func, + typename T, typename T2, typename... XT> struct LoadHelper { static __ai void impl(T& weight, T2 ptr, int oc_offset, XT... args); }; @@ -26,13 +27,14 @@ struct LoadHelper { #define WEIGHT_CB(step) \ src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); -#define LOAD_HELPER(step) \ - template \ - struct LoadHelper { \ - static __ai void impl(T& src, T2 ptr, int, XT... args) { \ - UNROLL_CALL_RAW(step, WEIGHT_CB); \ - } \ +#define LOAD_HELPER(step) \ + template < \ + int base_offset, int ptr_step, typename Func, typename T, typename T2, \ + typename... XT> \ + struct LoadHelper { \ + static __ai void impl(T& src, T2 ptr, int, XT... 
args) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ } LOAD_HELPER(1); @@ -56,16 +58,14 @@ LOAD_HELPER(16); #undef WEIGHT_CB ///////////////////////////c_dim = 1///////////////////////// -#define WEIGHT_CB(step) \ - src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); - -#define LOAD_HELPER(step) \ - template \ - struct LoadHelper { \ - static __ai void impl(T& src, T2 ptr, int) { \ - UNROLL_CALL_RAW(step, WEIGHT_CB); \ - } \ +#define WEIGHT_CB(step) src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); + +#define LOAD_HELPER(step) \ + template \ + struct LoadHelper { \ + static __ai void impl(T& src, T2 ptr, int) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ } LOAD_HELPER(1); @@ -86,13 +86,12 @@ LOAD_HELPER(9); src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); -#define LOAD_HELPER(step) \ - template \ - struct LoadHelper { \ - static __ai void impl(T& src, T2 ptr, int oc_offset) { \ - UNROLL_CALL_RAW(step, WEIGHT_CB); \ - } \ +#define LOAD_HELPER(step) \ + template \ + struct LoadHelper { \ + static __ai void impl(T& src, T2 ptr, int oc_offset) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ } LOAD_HELPER(1); @@ -107,18 +106,20 @@ LOAD_HELPER(8); #undef LOAD_HELPER #undef WEIGHT_CB -template +template < + int weight_number, int base_offset, int ptr_step, int c_dim, typename Func, + typename T, typename T2> __ai void load_helper(T& weight, T2 ptr, int oc_offset) { LoadHelper::impl( weight, ptr, oc_offset); } -template +template < + int weight_number, int base_offset, int ptr_step, int c_dim, typename Func, + typename T, typename T2, typename... XT> __ai void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { - LoadHelper::impl(weight, ptr, oc_offset, args...); + LoadHelper::impl( + weight, ptr, oc_offset, args...); } } // namespace diff --git a/dnn/src/arm_common/local/opr_impl.cpp b/dnn/src/arm_common/local/opr_impl.cpp index a871aa30..9f596ee0 100644 --- a/dnn/src/arm_common/local/opr_impl.cpp +++ b/dnn/src/arm_common/local/opr_impl.cpp @@ -36,23 +36,21 @@ void do_one_pixel(float* dst, const float* filter, float sval, int OC) { } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void do_one_pixel(dt_float16* dst, const dt_float16* filter, dt_float16 sval, - int OC) { +void do_one_pixel(dt_float16* dst, const dt_float16* filter, dt_float16 sval, int OC) { const __fp16* filter_ptr = reinterpret_cast(filter); __fp16* dst_ptr = reinterpret_cast<__fp16*>(dst); const int width = 8u; int oc = 0; float16x8_t vs = vdupq_n_f16(sval); - for (; oc + width <= OC; - oc += width, filter_ptr += width, dst_ptr += width) { + for (; oc + width <= OC; oc += width, filter_ptr += width, dst_ptr += width) { float16x8_t vf = vld1q_f16(filter_ptr); float16x8_t vd = vld1q_f16(dst_ptr); vd = vmlaq_f16(vd, vs, vf); vst1q_f16(dst_ptr, vd); } #if MEGDNN_FIX_AARCH32_BUG - // FIXME: as llvm may cause cannot select error if enable vectorize - #pragma clang loop vectorize(disable) +// FIXME: as llvm may cause cannot select error if enable vectorize +#pragma clang loop vectorize(disable) #endif for (; oc < OC; oc++, dst_ptr++, filter_ptr++) { *dst_ptr += sval * (*filter_ptr); @@ -93,17 +91,17 @@ void exec_internal(const LocalImpl::FloatNoncontigBatchKernParam& kparam) { } // anonymous namespace -size_t LocalImpl::get_workspace_in_bytes(const TensorLayout& /* src */, - const TensorLayout& /* filter */, - const TensorLayout& dst) { +size_t LocalImpl::get_workspace_in_bytes( + const TensorLayout& /* src */, 
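[Reviewer note, not part of the patch] `do_one_pixel` above accumulates dst += sval * filter eight fp16 lanes at a time and handles the leftover output channels with a scalar loop. The same shape in float32, which sidesteps the fp16 build-flag requirement of the original (hypothetical helper, assuming an ARM target with NEON and <arm_neon.h>):

#include <arm_neon.h>

// dst[oc] += sval * filter[oc] for oc in [0, OC), four float lanes per step.
void accumulate_scaled_f32(float* dst, const float* filter, float sval, int OC) {
    const int width = 4;               // float32 lanes per q-register
    float32x4_t vs = vdupq_n_f32(sval);
    int oc = 0;
    for (; oc + width <= OC; oc += width, filter += width, dst += width) {
        float32x4_t vf = vld1q_f32(filter);
        float32x4_t vd = vld1q_f32(dst);
        vd = vmlaq_f32(vd, vs, vf);    // vd += vs * vf
        vst1q_f32(dst, vd);
    }
    for (; oc < OC; ++oc, ++dst, ++filter) {  // scalar tail
        *dst += sval * (*filter);
    }
}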
const TensorLayout& /* filter */, + const TensorLayout& dst) { return dst.span().dist_byte(); } LocalImpl::float_noncontig_batch_kern LocalImpl::dispatch_float_noncontig_batch( const TensorLayout& src, const TensorLayout&, const TensorLayout&) { - megdnn_assert(src.stride[0] > 0 && - static_cast(src.stride[0]) >= - src.total_nr_elems() / src.shape[0]); + megdnn_assert( + src.stride[0] > 0 && + static_cast(src.stride[0]) >= src.total_nr_elems() / src.shape[0]); if (src.dtype == dtype::Float32()) { if (param().mode == Mode::CROSS_CORRELATION) { return exec_internal; @@ -124,8 +122,9 @@ LocalImpl::float_noncontig_batch_kern LocalImpl::dispatch_float_noncontig_batch( return nullptr; } -void LocalImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, _megdnn_workspace workspace) { +void LocalImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { return exec_use_float_noncontig_batch(src, filter, dst, workspace); } diff --git a/dnn/src/arm_common/local/opr_impl.h b/dnn/src/arm_common/local/opr_impl.h index 8e74efcb..44a0b46f 100644 --- a/dnn/src/arm_common/local/opr_impl.h +++ b/dnn/src/arm_common/local/opr_impl.h @@ -16,25 +16,23 @@ namespace megdnn { namespace arm_common { -class LocalImpl final: public naive::LocalForwardImpl { - public: - using naive::LocalForwardImpl::LocalForwardImpl; - - float_noncontig_batch_kern dispatch_float_noncontig_batch( - const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) override; - - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) override; +class LocalImpl final : public naive::LocalForwardImpl { +public: + using naive::LocalForwardImpl::LocalForwardImpl; + + float_noncontig_batch_kern dispatch_float_noncontig_batch( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; }; -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp index e2a2a055..41e3fc4a 100644 --- a/dnn/src/arm_common/matrix_mul/algos.cpp +++ b/dnn/src/arm_common/matrix_mul/algos.cpp @@ -17,7 +17,6 @@ #include "midout.h" - MIDOUT_DECL(megdnn_arm_hgemv) MIDOUT_DECL(megdnn_arm_exec_int8816) MIDOUT_DECL(megdnn_arm_exec_int8832) @@ -33,8 +32,8 @@ WorkspaceBundle get_workspace_bundle_int_8x8x16( const MatrixMulImpl::KernSizeParam& kern_size_param) { auto M = kern_size_param.M, K = kern_size_param.K, N = kern_size_param.N; // Use 8x8 tile - return WorkspaceBundle(nullptr, {(M + 8) * K * sizeof(int8_t), - K * (N + 8) * sizeof(int8_t)}); + return WorkspaceBundle( + nullptr, {(M + 8) * K * sizeof(int8_t), K * (N + 8) * sizeof(int8_t)}); } void exec_int_8x8x16(const MatrixMulImpl::KernParam& kern_param) { @@ -55,8 +54,7 @@ void exec_int_8x8x16(const MatrixMulImpl::KernParam& kern_param) { } } // anonymous namespace -bool MatrixMulImpl::AlgoInt8x8x16::usable( - const KernSizeParam& kern_size_param) const { +bool 
MatrixMulImpl::AlgoInt8x8x16::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.A_type == dtype::Int8() && kern_size_param.B_type == dtype::Int8() && kern_size_param.C_type == dtype::Int16() && @@ -67,8 +65,8 @@ bool MatrixMulImpl::AlgoInt8x8x16::usable( size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_arm_exec_int8816, - midout_iv("AlgoInt8x8x16::get_workspace"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_exec_int8816, midout_iv("AlgoInt8x8x16::get_workspace"_hash)) { auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param); return wbundle.total_size_in_bytes(); } @@ -84,8 +82,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern( /* ===================== Int8x8x32 Gemv algo ===================== */ namespace { void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_arm_exec_int8832, - midout_iv("int8x8x32_gemv_kern"_hash)) { + MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gemv_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; const auto Aptr = kern_param.A(), Bptr = kern_param.B(); @@ -99,8 +96,8 @@ void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x32Gemv::usable( const KernSizeParam& kern_size_param) const { auto N = kern_size_param.N, LDB = kern_size_param.LDB; - return can_be_treated_as_int8x8x32(kern_size_param) && - !kern_size_param.trA && !kern_size_param.trB && (N == 1 && LDB == 1); + return can_be_treated_as_int8x8x32(kern_size_param) && !kern_size_param.trA && + !kern_size_param.trB && (N == 1 && LDB == 1); } bool MatrixMulImpl::AlgoInt8x8x32Gemv::preferred( @@ -117,8 +114,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern( /* ===================== Int8x8x32 Gemv MK4 algo ===================== */ namespace { void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_arm_exec_int8832, - midout_iv("int8x8x32_gemv_mk4_kern"_hash)) { + MIDOUT_BEGIN(megdnn_arm_exec_int8832, midout_iv("int8x8x32_gemv_mk4_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; const auto Aptr = kern_param.A(), Bptr = kern_param.B(); @@ -136,17 +132,16 @@ bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::usable( auto K = kern_size_param.K; auto LDB = kern_size_param.LDB; - bool is_dtype_ok = - kern_size_param.A_type == kern_size_param.B_type && - (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || - kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && - (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || - kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); + bool is_dtype_ok = kern_size_param.A_type == kern_size_param.B_type && + (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - kern_size_param.format == param::MatrixMul::Format::MK4 && - is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && - M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; + kern_size_param.format == param::MatrixMul::Format::MK4 && is_dtype_ok && + !kern_size_param.trA && !kern_size_param.trB && M % 4 == 0 && K % 4 
== 0 && + N == 1 && LDB == 4; } bool MatrixMulImpl::AlgoInt8x8x32GemvMK4::preferred( @@ -164,8 +159,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ namespace { void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_arm_exec_int8832, - midout_iv("int8x8x32_gemv_mk4_dot_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_arm_exec_int8832, midout_iv("int8x8x32_gemv_mk4_dot_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; const auto Aptr = kern_param.A(), Bptr = kern_param.B(); @@ -178,8 +173,7 @@ void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable( const KernSizeParam& kern_size_param) const { - - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } auto M = kern_size_param.M; @@ -187,17 +181,16 @@ bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable( auto K = kern_size_param.K; auto LDB = kern_size_param.LDB; - bool is_dtype_ok = - kern_size_param.A_type == kern_size_param.B_type && - (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || - kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && - (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || - kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); + bool is_dtype_ok = kern_size_param.A_type == kern_size_param.B_type && + (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && + (kern_size_param.C_type.enumv() == DTypeEnum::Int32 || + kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32); return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - kern_size_param.format == param::MatrixMul::Format::MK4_DOT && - is_dtype_ok && !kern_size_param.trA && !kern_size_param.trB && - M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; + kern_size_param.format == param::MatrixMul::Format::MK4_DOT && is_dtype_ok && + !kern_size_param.trA && !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && + N == 1 && LDB == 4; } bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::preferred( @@ -214,12 +207,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern( /* ===================== F32 Gemv algo ===================== */ namespace { void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_arm_exec_fp32, - midout_iv("f32_gemv_kern"_hash)) { + MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); } @@ -227,8 +218,7 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { } } // anonymous namespace -bool MatrixMulImpl::AlgoF32Gemv::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF32Gemv::usable(const KernSizeParam& kern_size_param) const { // enumerate the M, N, K, only usable when preferred return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::DEFAULT && @@ -238,28 +228,24 @@ bool MatrixMulImpl::AlgoF32Gemv::usable( !kern_size_param.trB && 
preferred(kern_size_param); } -bool MatrixMulImpl::AlgoF32Gemv::preferred( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF32Gemv::preferred(const KernSizeParam& kern_size_param) const { auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K, LDB = kern_size_param.LDB; return M < 8 || (M == 8 && K <= 2) || (N == 1 && LDB == 1); } -MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern( - const KernSizeParam&) const { +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(const KernSizeParam&) const { return f32_gemv_kern; } /* ================== F32 Gemv MK4 algo ================== */ namespace { void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_arm_exec_fp32, - midout_iv("f32_gemv_mk4_kern"_hash)) { + MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("f32_gemv_mk4_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC); } @@ -267,8 +253,7 @@ void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) { } } // anonymous namespace -bool MatrixMulImpl::AlgoF32GemvMK4::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF32GemvMK4::usable(const KernSizeParam& kern_size_param) const { // enumerate the M, N, K, only usable when preferred auto M = kern_size_param.M; auto N = kern_size_param.N; @@ -280,8 +265,7 @@ bool MatrixMulImpl::AlgoF32GemvMK4::usable( kern_size_param.B_type == kern_size_param.A_type && kern_size_param.C_type == kern_size_param.A_type && kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && - !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && - LDB == 4; + !kern_size_param.trB && M % 4 == 0 && K % 4 == 0 && N == 1 && LDB == 4; } bool MatrixMulImpl::AlgoF32GemvMK4::preferred( @@ -299,8 +283,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern( namespace { template void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_arm_exec_fp32, - midout_iv("gevm_like_kern"_hash)) { + MIDOUT_BEGIN(megdnn_arm_exec_fp32, midout_iv("gevm_like_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDB = kern_param.LDB; const auto Aptr = kern_param.A(), Bptr = kern_param.B(); @@ -311,15 +294,13 @@ void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) { } } // anonymous namespace -bool MatrixMulImpl::AlgoGevm::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoGevm::usable(const KernSizeParam& kern_size_param) const { // enumerate the M, N, K, only usable when preferred - bool fp32_ok = - kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - kern_size_param.format == param::MatrixMul::Format::DEFAULT && - kern_size_param.B_type == kern_size_param.A_type && - kern_size_param.C_type == kern_size_param.A_type && - kern_size_param.A_type == dtype::Float32(); + bool fp32_ok = kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && + kern_size_param.format == param::MatrixMul::Format::DEFAULT && + kern_size_param.B_type == kern_size_param.A_type && + kern_size_param.C_type == kern_size_param.A_type && + kern_size_param.A_type == dtype::Float32(); bool fp16_ok = false; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC fp16_ok = 
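[Reviewer note, not part of the patch] Each GEMV algo in this file pairs a hard `usable()` check (dtype, format, transposition and stride constraints such as N == 1 with LDB == 1, or the MK4 requirements M % 4 == 0 and K % 4 == 0) with a `preferred()` shape heuristic and a `get_kern()` that returns a plain kernel function pointer. Stripped of the MegDNN types, the selection pattern looks roughly like this (every name below is hypothetical):

#include <cstddef>

struct KernSizeParam {
    size_t M, N, K, LDB;
    bool trA, trB;
};
using kern_t = void (*)(const KernSizeParam&);

struct Algo {
    virtual ~Algo() = default;
    // Hard requirement: the kernel is only correct when this holds.
    virtual bool usable(const KernSizeParam&) const = 0;
    // Soft heuristic: among usable algos, prefer this one for these shapes.
    virtual bool preferred(const KernSizeParam&) const = 0;
    virtual kern_t get_kern(const KernSizeParam&) const = 0;
};

static void gemv_kern_stub(const KernSizeParam&) { /* run the actual kernel */ }

struct AlgoGemvLike final : Algo {
    bool usable(const KernSizeParam& p) const override {
        return !p.trA && !p.trB && p.N == 1 && p.LDB == 1;  // column-vector rhs
    }
    bool preferred(const KernSizeParam& p) const override {
        return p.N == 1;  // e.g. favor GEMV for matrix-vector shapes
    }
    kern_t get_kern(const KernSizeParam&) const override { return gemv_kern_stub; }
};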
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && @@ -332,8 +313,7 @@ bool MatrixMulImpl::AlgoGevm::usable( return (fp32_ok || fp16_ok || int8_ok) && preferred(kern_size_param); } -bool MatrixMulImpl::AlgoGevm::preferred( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoGevm::preferred(const KernSizeParam& kern_size_param) const { auto M = kern_size_param.M; return kern_size_param.trB && M == 1; } @@ -342,8 +322,9 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoGevm::get_kern( const KernSizeParam& kern_size_param) const { if (kern_size_param.A_type == dtype::Float32()) { return gevm_like_kern; - } else if (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || - kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) { + } else if ( + kern_size_param.A_type.enumv() == DTypeEnum::Int8 || + kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) { return gevm_like_kern; } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -365,21 +346,19 @@ namespace { void f16_gemv_kern(const MatrixMulImpl::KernParam& kern_param) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); MIDOUT_BEGIN(megdnn_arm_hgemv, void) { - arm_common::gemv_like(reinterpret_cast(Aptr), - reinterpret_cast(Bptr), - reinterpret_cast<__fp16*>(Cptr), M, N, K, LDA, - LDB, LDC); + arm_common::gemv_like( + reinterpret_cast(Aptr), + reinterpret_cast(Bptr), reinterpret_cast<__fp16*>(Cptr), + M, N, K, LDA, LDB, LDC); } MIDOUT_END(); } } // anonymous namespace -bool MatrixMulImpl::AlgoF16Gemv::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF16Gemv::usable(const KernSizeParam& kern_size_param) const { // enumerate the M, N, K, only usable when preferred return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::DEFAULT && @@ -389,16 +368,14 @@ bool MatrixMulImpl::AlgoF16Gemv::usable( !kern_size_param.trB && preferred(kern_size_param); } -bool MatrixMulImpl::AlgoF16Gemv::preferred( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF16Gemv::preferred(const KernSizeParam& kern_size_param) const { auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K, LDB = kern_size_param.LDB; return M <= 4 || (M == 8 && K <= 2) || (N == 1 && LDB == 1); } -MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16Gemv::get_kern( - const KernSizeParam&) const { +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16Gemv::get_kern(const KernSizeParam&) const { return f16_gemv_kern; } #endif diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index 852e64b2..2d9ac4bd 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -19,9 +19,7 @@ namespace arm_common { class MatrixMulImpl::AlgoInt8x8x16 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARM_COMMON_INT8X8X16"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -34,8 +32,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { public: AlgoAttribute attribute() const override { - return 
AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } bool usable(const KernSizeParam&) const override; @@ -51,8 +48,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } bool usable(const KernSizeParam&) const override; @@ -69,8 +65,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } bool usable(const KernSizeParam&) const override; @@ -89,9 +84,7 @@ protected: ~AlgoF32Gemv() = default; public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARM_COMMON_F32_GEMV"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -105,8 +98,7 @@ public: class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } bool usable(const KernSizeParam&) const override; @@ -122,9 +114,7 @@ public: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARM_COMMON_F16_GEMV"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -139,9 +129,7 @@ public: class MatrixMulImpl::AlgoGevm : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARM_COMMON_GEVM"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; diff --git a/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp index 56966512..d02dbda1 100644 --- a/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp +++ b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.cpp @@ -15,10 +15,11 @@ namespace { -inline int8x8_t vreinterpret_s8_s8(int8x8_t x) { return x; } +inline int8x8_t vreinterpret_s8_s8(int8x8_t x) { + return x; +} -void packA(const int8_t *src, int8_t *dst, size_t M, size_t K) -{ +void packA(const int8_t* src, int8_t* dst, size_t M, size_t K) { #if 0 // naive impl megdnn_assert(M % 8 == 0); @@ -30,26 +31,23 @@ void packA(const int8_t *src, int8_t *dst, 
size_t M, size_t K) #else // 8x8 block at a time size_t m = 0; - int8_t * __restrict dptr = dst; - for (; m+8 <= M; m += 8) { + int8_t* __restrict dptr = dst; + for (; m + 8 <= M; m += 8) { size_t k = 0; - for (; k+8 <= K; k += 8) { - const int8_t * __restrict sptr = src + (m*K + k); - int8x8_t l0 = vld1_s8(sptr + 0*K), - l1 = vld1_s8(sptr + 1*K), - l2 = vld1_s8(sptr + 2*K), - l3 = vld1_s8(sptr + 3*K), - l4 = vld1_s8(sptr + 4*K), - l5 = vld1_s8(sptr + 5*K), - l6 = vld1_s8(sptr + 6*K), - l7 = vld1_s8(sptr + 7*K); + for (; k + 8 <= K; k += 8) { + const int8_t* __restrict sptr = src + (m * K + k); + int8x8_t l0 = vld1_s8(sptr + 0 * K), l1 = vld1_s8(sptr + 1 * K), + l2 = vld1_s8(sptr + 2 * K), l3 = vld1_s8(sptr + 3 * K), + l4 = vld1_s8(sptr + 4 * K), l5 = vld1_s8(sptr + 5 * K), + l6 = vld1_s8(sptr + 6 * K), l7 = vld1_s8(sptr + 7 * K); // do transpose -#define TRANS(lhs, rhs, bit) { \ - auto tmp = vtrn_s ## bit(vreinterpret_s ## bit ## _s8(lhs), \ - vreinterpret_s ## bit ## _s8(rhs)); \ - lhs = vreinterpret_s8_s ## bit(tmp.val[0]); \ - rhs = vreinterpret_s8_s ## bit(tmp.val[1]); \ -} +#define TRANS(lhs, rhs, bit) \ + { \ + auto tmp = vtrn_s##bit( \ + vreinterpret_s##bit##_s8(lhs), vreinterpret_s##bit##_s8(rhs)); \ + lhs = vreinterpret_s8_s##bit(tmp.val[0]); \ + rhs = vreinterpret_s8_s##bit(tmp.val[1]); \ + } TRANS(l0, l4, 32); TRANS(l1, l5, 32); TRANS(l2, l6, 32); @@ -63,53 +61,61 @@ void packA(const int8_t *src, int8_t *dst, size_t M, size_t K) TRANS(l4, l5, 8); TRANS(l6, l7, 8); #undef TRANS - vst1_s8(dptr, l0); dptr += 8; - vst1_s8(dptr, l1); dptr += 8; - vst1_s8(dptr, l2); dptr += 8; - vst1_s8(dptr, l3); dptr += 8; - vst1_s8(dptr, l4); dptr += 8; - vst1_s8(dptr, l5); dptr += 8; - vst1_s8(dptr, l6); dptr += 8; - vst1_s8(dptr, l7); dptr += 8; + vst1_s8(dptr, l0); + dptr += 8; + vst1_s8(dptr, l1); + dptr += 8; + vst1_s8(dptr, l2); + dptr += 8; + vst1_s8(dptr, l3); + dptr += 8; + vst1_s8(dptr, l4); + dptr += 8; + vst1_s8(dptr, l5); + dptr += 8; + vst1_s8(dptr, l6); + dptr += 8; + vst1_s8(dptr, l7); + dptr += 8; } for (; k < K; ++k) { - const int8_t * __restrict sptr = src + (m*K + k); - for (size_t i = 0; i < 8; ++i) *(dptr++) = *(sptr + i*K); + const int8_t* __restrict sptr = src + (m * K + k); + for (size_t i = 0; i < 8; ++i) + *(dptr++) = *(sptr + i * K); } } if (m < M) { for (size_t k = 0; k < K; ++k) { - const int8_t * __restrict sptr = src + (m*K + k); + const int8_t* __restrict sptr = src + (m * K + k); for (size_t i = 0; i < 8; ++i) { - *(dptr++) = (m+i < M ? *(sptr + i*K) : 0); + *(dptr++) = (m + i < M ? 
*(sptr + i * K) : 0); } } } #endif } -#define LOAD(i) \ - int8x8_t l ## i = vld1_s8(sptr); \ - int8x8_t s ## i = vld1_s8(sptr + 8); \ +#define LOAD(i) \ + int8x8_t l##i = vld1_s8(sptr); \ + int8x8_t s##i = vld1_s8(sptr + 8); \ sptr += LDB; -#define STORE(i) \ - vst1_s8(dptr, l ## i); \ - dptr += 8; \ - vst1_s8(dptr, s ## i); \ +#define STORE(i) \ + vst1_s8(dptr, l##i); \ + dptr += 8; \ + vst1_s8(dptr, s##i); \ dptr += 8; -#define TRANS(i) \ - int8x8_t l ## i = vld1_s8(sptr); \ - int8x8_t s ## i = vld1_s8(sptr + 8); \ - sptr += N; \ - vst1_s8(dptr, l ## i); \ - dptr += 8; \ - vst1_s8(dptr, s ## i); \ +#define TRANS(i) \ + int8x8_t l##i = vld1_s8(sptr); \ + int8x8_t s##i = vld1_s8(sptr + 8); \ + sptr += N; \ + vst1_s8(dptr, l##i); \ + dptr += 8; \ + vst1_s8(dptr, s##i); \ dptr += 8; -void packB(const int8_t *src, int8_t *dst, size_t K, size_t N, size_t LDB) -{ +void packB(const int8_t* src, int8_t* dst, size_t K, size_t N, size_t LDB) { #if 0 megdnn_assert(N % 8 == 0); for (size_t n = 0; n+8 <= N; n += 8) @@ -118,12 +124,12 @@ void packB(const int8_t *src, int8_t *dst, size_t K, size_t N, size_t LDB) for (size_t n2 = n; n2 < n+8; ++n2) *(dst++) = src[k*N + n2]; } #else - int8_t * __restrict dptr = dst; + int8_t* __restrict dptr = dst; size_t n = 0; - for(; n+16 <=N; n += 16) { + for (; n + 16 <= N; n += 16) { size_t k = 0; - for (; k+8 <= K; k += 8) { - const int8_t * __restrict sptr = src + k * LDB + n; + for (; k + 8 <= K; k += 8) { + const int8_t* __restrict sptr = src + k * LDB + n; LOAD(0); LOAD(1); @@ -144,175 +150,176 @@ void packB(const int8_t *src, int8_t *dst, size_t K, size_t N, size_t LDB) STORE(7); #undef STORE #undef TRANS - } for (; k < K; ++k) { - const int8_t * __restrict sptr = src + k * LDB + n; + const int8_t* __restrict sptr = src + k * LDB + n; int8x8_t l = vld1_s8(sptr); int8x8_t s = vld1_s8(sptr + 8); - vst1_s8(dptr, l); dptr += 8; - vst1_s8(dptr, s); dptr += 8; + vst1_s8(dptr, l); + dptr += 8; + vst1_s8(dptr, s); + dptr += 8; } } - for (; n+8 <= N; n += 8) { + for (; n + 8 <= N; n += 8) { size_t k = 0; - for (; k+8 <= K; k += 8) { - const int8_t * __restrict sptr = src + k * LDB + n; - int8x8_t l0 = vld1_s8(sptr + 0*N), - l1 = vld1_s8(sptr + 1*N), - l2 = vld1_s8(sptr + 2*N), - l3 = vld1_s8(sptr + 3*N), - l4 = vld1_s8(sptr + 4*N), - l5 = vld1_s8(sptr + 5*N), - l6 = vld1_s8(sptr + 6*N), - l7 = vld1_s8(sptr + 7*N); - vst1_s8(dptr, l0); dptr += 8; - vst1_s8(dptr, l1); dptr += 8; - vst1_s8(dptr, l2); dptr += 8; - vst1_s8(dptr, l3); dptr += 8; - vst1_s8(dptr, l4); dptr += 8; - vst1_s8(dptr, l5); dptr += 8; - vst1_s8(dptr, l6); dptr += 8; - vst1_s8(dptr, l7); dptr += 8; + for (; k + 8 <= K; k += 8) { + const int8_t* __restrict sptr = src + k * LDB + n; + int8x8_t l0 = vld1_s8(sptr + 0 * N), l1 = vld1_s8(sptr + 1 * N), + l2 = vld1_s8(sptr + 2 * N), l3 = vld1_s8(sptr + 3 * N), + l4 = vld1_s8(sptr + 4 * N), l5 = vld1_s8(sptr + 5 * N), + l6 = vld1_s8(sptr + 6 * N), l7 = vld1_s8(sptr + 7 * N); + vst1_s8(dptr, l0); + dptr += 8; + vst1_s8(dptr, l1); + dptr += 8; + vst1_s8(dptr, l2); + dptr += 8; + vst1_s8(dptr, l3); + dptr += 8; + vst1_s8(dptr, l4); + dptr += 8; + vst1_s8(dptr, l5); + dptr += 8; + vst1_s8(dptr, l6); + dptr += 8; + vst1_s8(dptr, l7); + dptr += 8; } for (; k < K; ++k) { - const int8_t * __restrict sptr = src + k * LDB + n; + const int8_t* __restrict sptr = src + k * LDB + n; int8x8_t l = vld1_s8(sptr); - vst1_s8(dptr, l); dptr += 8; + vst1_s8(dptr, l); + dptr += 8; } } if (n < N) { for (size_t k = 0; k < K; ++k) { - const int8_t * __restrict sptr = src + k * LDB + 
n; + const int8_t* __restrict sptr = src + k * LDB + n; int8_t l[8] = {0}; - for (size_t i = 0; n+i < N; ++i) l[i] = sptr[i]; - for (size_t i = 0; i < 8; ++i) *(dptr++) = l[i]; + for (size_t i = 0; n + i < N; ++i) + l[i] = sptr[i]; + for (size_t i = 0; i < 8; ++i) + *(dptr++) = l[i]; } } #endif } -} // anonymous namespace +} // anonymous namespace //#include namespace megdnn { namespace arm_common { -#define GAO(i) { \ - tmp = vdup_lane_s8(a, i); \ - l ## i = vmlal_s8(l ## i, tmp, b); \ -} +#define GAO(i) \ + { \ + tmp = vdup_lane_s8(a, i); \ + l##i = vmlal_s8(l##i, tmp, b); \ + } -#define STORE_REMAIN_N(i, p) \ - if(plen > p) \ +#define STORE_REMAIN_N(i, p) \ + if (plen > p) \ Cptr[p] = vgetq_lane_s16(l##i, p); \ - else \ + else \ break; -#define STORE_PARTRIAL_N(i) { \ - while(1) { \ - STORE_REMAIN_N(i, 0) \ - STORE_REMAIN_N(i, 1) \ - STORE_REMAIN_N(i, 2) \ - STORE_REMAIN_N(i, 3) \ - STORE_REMAIN_N(i, 4) \ - STORE_REMAIN_N(i, 5) \ - STORE_REMAIN_N(i, 6) \ - break; \ - } \ - Cptr += N; \ -} +#define STORE_PARTRIAL_N(i) \ + { \ + while (1) { \ + STORE_REMAIN_N(i, 0) \ + STORE_REMAIN_N(i, 1) \ + STORE_REMAIN_N(i, 2) \ + STORE_REMAIN_N(i, 3) \ + STORE_REMAIN_N(i, 4) \ + STORE_REMAIN_N(i, 5) \ + STORE_REMAIN_N(i, 6) \ + break; \ + } \ + Cptr += N; \ + } -#define STORE_PARTRIAL_M(i) { \ - if(plen > i) { \ - vst1q_s16(Cptr, l##i); \ - Cptr += N; \ - } \ - else \ - break; \ -} +#define STORE_PARTRIAL_M(i) \ + { \ + if (plen > i) { \ + vst1q_s16(Cptr, l##i); \ + Cptr += N; \ + } else \ + break; \ + } -#define GAO_16(i) { \ - tmp = vdup_lane_s8(a, i); \ - l ## i = vmlal_s8(l ## i, tmp, b0); \ - s ## i = vmlal_s8(s ## i, tmp, b1); \ -} +#define GAO_16(i) \ + { \ + tmp = vdup_lane_s8(a, i); \ + l##i = vmlal_s8(l##i, tmp, b0); \ + s##i = vmlal_s8(s##i, tmp, b1); \ + } -#define STORE_16(i) { \ - vst1q_s16(Cptr, l##i); \ - vst1q_s16(Cptr + 8, s##i); \ - Cptr += N; \ -} +#define STORE_16(i) \ + { \ + vst1q_s16(Cptr, l##i); \ + vst1q_s16(Cptr + 8, s##i); \ + Cptr += N; \ + } -#define STORE_REMAIN_N_16(i, p) \ - if(plen > p) \ - Cptr[8+p] = vgetq_lane_s16(s##i, p); \ - else \ +#define STORE_REMAIN_N_16(i, p) \ + if (plen > p) \ + Cptr[8 + p] = vgetq_lane_s16(s##i, p); \ + else \ break; -#define STORE_PARTRIAL_N_16(i) { \ - while(1) { \ - vst1q_s16(Cptr, l##i); \ - STORE_REMAIN_N_16(i, 0) \ - STORE_REMAIN_N_16(i, 1) \ - STORE_REMAIN_N_16(i, 2) \ - STORE_REMAIN_N_16(i, 3) \ - STORE_REMAIN_N_16(i, 4) \ - STORE_REMAIN_N_16(i, 5) \ - STORE_REMAIN_N_16(i, 6) \ - break; \ - } \ - Cptr += N; \ -} +#define STORE_PARTRIAL_N_16(i) \ + { \ + while (1) { \ + vst1q_s16(Cptr, l##i); \ + STORE_REMAIN_N_16(i, 0) \ + STORE_REMAIN_N_16(i, 1) \ + STORE_REMAIN_N_16(i, 2) \ + STORE_REMAIN_N_16(i, 3) \ + STORE_REMAIN_N_16(i, 4) \ + STORE_REMAIN_N_16(i, 5) \ + STORE_REMAIN_N_16(i, 6) \ + break; \ + } \ + Cptr += N; \ + } -#define STORE_PARTRIAL_M_16(i) { \ - if(plen > i) \ - STORE_16(i) \ - else \ - break; \ -} +#define STORE_PARTRIAL_M_16(i) \ + { \ + if (plen > i) \ + STORE_16(i) \ + else \ + break; \ + } -void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, - size_t M, size_t K, size_t N,size_t LDB, - int8_t *w0, int8_t *w1) -{ +void exec_gemm_int8_int8_int16( + const int8_t* A_, const int8_t* B_, int16_t* C, size_t M, size_t K, size_t N, + size_t LDB, int8_t* w0, int8_t* w1) { // for test - //printf("matrix_mul M %ld, K %ld, N %ld \n", M, K, N); + // printf("matrix_mul M %ld, K %ld, N %ld \n", M, K, N); packA(A_, w0, M, K); packB(B_, w1, K, N, LDB); - const int8_t * A = w0; - const int8_t * B = w1; 
+ const int8_t* A = w0; + const int8_t* B = w1; for (size_t m = 0; m < M; m += 8) { size_t n = 0; for (; n + 16 <= N; n += 16) { - //for (; n + 7 < N; n += 16) { - int16x8_t l0 = vdupq_n_s16(0), - l1 = vdupq_n_s16(0), - l2 = vdupq_n_s16(0), - l3 = vdupq_n_s16(0), - l4 = vdupq_n_s16(0), - l5 = vdupq_n_s16(0), - l6 = vdupq_n_s16(0), - l7 = vdupq_n_s16(0), - s0 = vdupq_n_s16(0), - s1 = vdupq_n_s16(0), - s2 = vdupq_n_s16(0), - s3 = vdupq_n_s16(0), - s4 = vdupq_n_s16(0), - s5 = vdupq_n_s16(0), - s6 = vdupq_n_s16(0), + // for (; n + 7 < N; n += 16) { + int16x8_t l0 = vdupq_n_s16(0), l1 = vdupq_n_s16(0), l2 = vdupq_n_s16(0), + l3 = vdupq_n_s16(0), l4 = vdupq_n_s16(0), l5 = vdupq_n_s16(0), + l6 = vdupq_n_s16(0), l7 = vdupq_n_s16(0), s0 = vdupq_n_s16(0), + s1 = vdupq_n_s16(0), s2 = vdupq_n_s16(0), s3 = vdupq_n_s16(0), + s4 = vdupq_n_s16(0), s5 = vdupq_n_s16(0), s6 = vdupq_n_s16(0), s7 = vdupq_n_s16(0); - const int8_t * __restrict Aptr = A + m*K; - const int8_t * __restrict Bptr = B + n*K; + const int8_t* __restrict Aptr = A + m * K; + const int8_t* __restrict Bptr = B + n * K; for (size_t k = 0; k < K; ++k) { int8x8_t tmp; - int8x8_t a = vld1_s8(Aptr), - b0 = vld1_s8(Bptr), - b1 = vld1_s8(Bptr + 8); + int8x8_t a = vld1_s8(Aptr), b0 = vld1_s8(Bptr), b1 = vld1_s8(Bptr + 8); Aptr += 8; Bptr += 16; //__builtin_prefetch(Aptr, 0, 0); @@ -326,13 +333,11 @@ void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, GAO_16(5); GAO_16(6); GAO_16(7); - - } - int16_t * __restrict Cptr = C + m*N + n; + int16_t* __restrict Cptr = C + m * N + n; - if (m+8 <= M) { // sub-case 1: m+8 <= M && n+16 <= N + if (m + 8 <= M) { // sub-case 1: m+8 <= M && n+16 <= N STORE_16(0) STORE_16(1) STORE_16(2) @@ -343,7 +348,7 @@ void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, STORE_16(7) } else { size_t plen = M - m; - while(1) { + while (1) { STORE_PARTRIAL_M_16(0) STORE_PARTRIAL_M_16(1) STORE_PARTRIAL_M_16(2) @@ -357,19 +362,13 @@ void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, } for (; n < N; n += 8) { - int16x8_t l0 = vdupq_n_s16(0), - l1 = vdupq_n_s16(0), - l2 = vdupq_n_s16(0), - l3 = vdupq_n_s16(0), - l4 = vdupq_n_s16(0), - l5 = vdupq_n_s16(0), - l6 = vdupq_n_s16(0), - l7 = vdupq_n_s16(0); - const int8_t * __restrict Aptr = A + m*K; - const int8_t * __restrict Bptr = B + n*K; + int16x8_t l0 = vdupq_n_s16(0), l1 = vdupq_n_s16(0), l2 = vdupq_n_s16(0), + l3 = vdupq_n_s16(0), l4 = vdupq_n_s16(0), l5 = vdupq_n_s16(0), + l6 = vdupq_n_s16(0), l7 = vdupq_n_s16(0); + const int8_t* __restrict Aptr = A + m * K; + const int8_t* __restrict Bptr = B + n * K; for (size_t k = 0; k < K; ++k) { - int8x8_t a = vld1_s8(Aptr), - b = vld1_s8(Bptr); + int8x8_t a = vld1_s8(Aptr), b = vld1_s8(Bptr); int8x8_t tmp; GAO(0); GAO(1); @@ -382,18 +381,18 @@ void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, Aptr += 8; Bptr += 8; } - int16_t * __restrict Cptr = C + m*N + n; - - if (m+8 <= M && n+8 <= N) { - vst1q_s16(Cptr + 0*N, l0); - vst1q_s16(Cptr + 1*N, l1); - vst1q_s16(Cptr + 2*N, l2); - vst1q_s16(Cptr + 3*N, l3); - vst1q_s16(Cptr + 4*N, l4); - vst1q_s16(Cptr + 5*N, l5); - vst1q_s16(Cptr + 6*N, l6); - vst1q_s16(Cptr + 7*N, l7); - } else if (m+8 <=M && n+8 > N) { // m+8<=M && n+8<=N && n+8>N + int16_t* __restrict Cptr = C + m * N + n; + + if (m + 8 <= M && n + 8 <= N) { + vst1q_s16(Cptr + 0 * N, l0); + vst1q_s16(Cptr + 1 * N, l1); + vst1q_s16(Cptr + 2 * N, l2); + vst1q_s16(Cptr + 3 * N, l3); + vst1q_s16(Cptr + 4 * N, l4); + vst1q_s16(Cptr + 5 * 
N, l5); + vst1q_s16(Cptr + 6 * N, l6); + vst1q_s16(Cptr + 7 * N, l7); + } else if (m + 8 <= M && n + 8 > N) { // m+8<=M && n+8<=N && n+8>N size_t plen = N - n; STORE_PARTRIAL_N(0) STORE_PARTRIAL_N(1) @@ -403,9 +402,9 @@ void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, STORE_PARTRIAL_N(5) STORE_PARTRIAL_N(6) STORE_PARTRIAL_N(7) - } else if(n+8 <= N) { // m+8>M && n+8<=N + } else if (n + 8 <= N) { // m+8>M && n+8<=N size_t plen = M - m; - while(1) { + while (1) { STORE_PARTRIAL_M(0) STORE_PARTRIAL_M(1) STORE_PARTRIAL_M(2) @@ -416,22 +415,20 @@ void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, break; } } else { - int16_t cache[8*8]; - vst1q_s16(cache + 0*8, l0); - vst1q_s16(cache + 1*8, l1); - vst1q_s16(cache + 2*8, l2); - vst1q_s16(cache + 3*8, l3); - vst1q_s16(cache + 4*8, l4); - vst1q_s16(cache + 5*8, l5); - vst1q_s16(cache + 6*8, l6); - vst1q_s16(cache + 7*8, l7); - - for (size_t i = 0; m+i < M && i < 8; ++i) - for (size_t j = 0; n+j < N && j < 8; ++j) - { - Cptr[i*N + j] = cache[i*8 + j]; - } - + int16_t cache[8 * 8]; + vst1q_s16(cache + 0 * 8, l0); + vst1q_s16(cache + 1 * 8, l1); + vst1q_s16(cache + 2 * 8, l2); + vst1q_s16(cache + 3 * 8, l3); + vst1q_s16(cache + 4 * 8, l4); + vst1q_s16(cache + 5 * 8, l5); + vst1q_s16(cache + 6 * 8, l6); + vst1q_s16(cache + 7 * 8, l7); + + for (size_t i = 0; m + i < M && i < 8; ++i) + for (size_t j = 0; n + j < N && j < 8; ++j) { + Cptr[i * N + j] = cache[i * 8 + j]; + } } } } @@ -447,8 +444,7 @@ void exec_gemm_int8_int8_int16(const int8_t *A_, const int8_t *B_, int16_t *C, #undef STORE_PARTRIAL_N_16 #undef STORE_PARTRIAL_M_16 -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h index 50a57269..79ecb96e 100644 --- a/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h +++ b/dnn/src/arm_common/matrix_mul/exec_gemm_int8_int8_int16.h @@ -9,18 +9,18 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once -#include #include +#include namespace megdnn { namespace arm_common { ///! 
Row-major gemm -void exec_gemm_int8_int8_int16(const int8_t* A, const int8_t* B, int16_t* C, - size_t M, size_t K, size_t N, size_t LDB, - int8_t* w0, int8_t* w1); +void exec_gemm_int8_int8_int16( + const int8_t* A, const int8_t* B, int16_t* C, size_t M, size_t K, size_t N, + size_t LDB, int8_t* w0, int8_t* w1); -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp b/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp index 56531c46..5d0c02f0 100644 --- a/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp +++ b/dnn/src/arm_common/matrix_mul/fp16/hgemv.cpp @@ -20,9 +20,9 @@ namespace { #define UNROLL_OUT(cb, step) UNROLL_CALL_RAW(step, cb) -void hgemv_naive_n(const __fp16* __restrict A, const __fp16* __restrict B, - __fp16* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void hgemv_naive_n( + const __fp16* __restrict A, const __fp16* __restrict B, __fp16* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1 && Bstride == 1); #define vaddvq_f16(v) \ ((v)[0] + (v)[1] + (v)[2] + (v)[3] + (v)[4] + (v)[5] + (v)[6] + (v)[7]) @@ -96,11 +96,9 @@ void hgemv_naive_n(const __fp16* __restrict A, const __fp16* __restrict B, } } // namespace -void megdnn::arm_common::gemv_like(const __fp16* __restrict A, - const __fp16* __restrict B, - __fp16* __restrict C, size_t M, size_t N, - size_t K, size_t Astride, size_t Bstride, - size_t Cstride) { +void megdnn::arm_common::gemv_like( + const __fp16* __restrict A, const __fp16* __restrict B, __fp16* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert((M <= 4) || (M == 8 && K <= 2) || (N == 1 && Bstride == 1)); if (N == 1) { return hgemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); @@ -112,12 +110,12 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, for (; k + 4 <= K; k += 4) { size_t n = 0; for (; n + 8 <= N; n += 8) { - float16x8_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, - a22, a23, a30, a31, a32, a33; + float16x8_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, + a30, a31, a32, a33; float16x8_t b0, b1, b2, b3; float16x8_t c0, c1, c2, c3; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); #define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]); @@ -151,12 +149,12 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, #undef vstore } for (; n + 4 <= N; n += 4) { - float16x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, - a22, a23, a30, a31, a32, a33; + float16x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, + a30, a31, a32, a33; float16x4_t b0, b1, b2, b3; float16x4_t c0, c1, c2, c3; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + 
i]); #define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]); @@ -190,8 +188,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, #undef vstore } for (; n < N; n += 1) { - __fp16 a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, - a23, a30, a31, a32, a33; + __fp16 a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, a30, + a31, a32, a33; __fp16 b0, b1, b2, b3; __fp16 c0, c1, c2, c3; #define loadC(i) c##i = C[(m + i) * Cstride + n]; @@ -227,8 +225,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00, a01, a10, a11, a20, a21, a30, a31; float16x8_t b0, b1; float16x8_t c0, c1, c2, c3; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); #define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]); @@ -265,8 +263,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00, a01, a10, a11, a20, a21, a30, a31; float16x4_t b0, b1; float16x4_t c0, c1, c2, c3; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); #define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]); @@ -336,8 +334,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00, a10, a20, a30; float16x8_t b0; float16x8_t c0, c1, c2, c3; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); #define loadA2(i) a2##i = vdupq_n_f16(A[(m + 2) * Astride + k + i]); @@ -374,8 +372,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00, a10, a20, a30; float16x4_t b0; float16x4_t c0, c1, c2, c3; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); #define loadA2(i) a2##i = vdup_n_f16(A[(m + 2) * Astride + k + i]); @@ -449,8 +447,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00, a01, a02, a03, a10, a11, a12, a13; float16x8_t b0, b1, b2, b3; float16x8_t c0, c1; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); 
UNROLL_OUT(loadC, 2) @@ -475,8 +473,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00, a01, a02, a03, a10, a11, a12, a13; float16x4_t b0, b1, b2, b3; float16x4_t c0, c1; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); UNROLL_OUT(loadC, 2) @@ -526,8 +524,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00, a01, a10, a11; float16x8_t b0, b1; float16x8_t c0, c1; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); UNROLL_OUT(loadC, 2) @@ -552,8 +550,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00, a01, a10, a11; float16x4_t b0, b1; float16x4_t c0, c1; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); UNROLL_OUT(loadC, 2) @@ -603,8 +601,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00, a10; float16x8_t b0; float16x8_t c0, c1; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f16(A[(m + 1) * Astride + k + i]); UNROLL_OUT(loadC, 2) @@ -629,8 +627,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00, a10; float16x4_t b0; float16x4_t c0, c1; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdup_n_f16(A[(m + 1) * Astride + k + i]); UNROLL_OUT(loadC, 2) @@ -684,8 +682,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00, a01, a02, a03; float16x8_t b0, b1, b2, b3; float16x8_t c0; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); UNROLL_OUT(loadC, 1) UNROLL_OUT(loadB, 4) @@ -704,8 +702,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00, a01, a02, a03; float16x4_t b0, b1, b2, b3; float16x4_t c0; -#define loadB(i) b##i = vld1_f16(B + (k 
+ i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); UNROLL_OUT(loadC, 1) UNROLL_OUT(loadB, 4) @@ -745,8 +743,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00, a01; float16x8_t b0, b1; float16x8_t c0; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); UNROLL_OUT(loadC, 1) UNROLL_OUT(loadB, 2) @@ -765,8 +763,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00, a01; float16x4_t b0, b1; float16x4_t c0; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); UNROLL_OUT(loadC, 1) UNROLL_OUT(loadB, 2) @@ -806,8 +804,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x8_t a00; float16x8_t b0; float16x8_t c0; -#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f16(A[(m + 0) * Astride + k + i]); UNROLL_OUT(loadC, 1) UNROLL_OUT(loadB, 1) @@ -826,8 +824,8 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, float16x4_t a00; float16x4_t b0; float16x4_t c0; -#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1_f16(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1_f16(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdup_n_f16(A[(m + 0) * Astride + k + i]); UNROLL_OUT(loadC, 1) UNROLL_OUT(loadB, 1) @@ -863,10 +861,9 @@ void megdnn::arm_common::gemv_like(const __fp16* __restrict A, } } } -bool megdnn::arm_common::is_hgemv_preferred(bool transposeA, bool transposeB, - size_t M, size_t N, size_t K, - size_t /*LDA*/, size_t LDB, - size_t /*LDC*/) { +bool megdnn::arm_common::is_hgemv_preferred( + bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t /*LDA*/, + size_t LDB, size_t /*LDC*/) { if (transposeA) return false; if (transposeB) diff --git a/dnn/src/arm_common/matrix_mul/fp16/hgemv.h b/dnn/src/arm_common/matrix_mul/fp16/hgemv.h index 88d52f44..4cdebcaa 100644 --- a/dnn/src/arm_common/matrix_mul/fp16/hgemv.h +++ b/dnn/src/arm_common/matrix_mul/fp16/hgemv.h @@ -16,15 +16,15 @@ namespace megdnn { namespace arm_common { -bool is_hgemv_preferred(bool transposeA, bool transposeB, size_t M, size_t N, - size_t K, size_t /*LDA*/, size_t LDB, size_t /*LDC*/); +bool is_hgemv_preferred( + bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t /*LDA*/, + size_t LDB, size_t /*LDC*/); -void gemv_like(const __fp16* __restrict A, const __fp16* __restrict B, - __fp16* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride); +void gemv_like( + const __fp16* __restrict A, 
const __fp16* __restrict B, __fp16* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); - -} // namespace aarch64 +} // namespace arm_common } // namespace megdnn #endif diff --git a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp index 72aafce3..4187351c 100644 --- a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp +++ b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.cpp @@ -30,17 +30,17 @@ namespace { #if !defined(__aarch64__) #define vaddvq_f32(v) (v)[0] + (v)[1] + (v)[2] + (v)[3] #endif -void sgemv_naive_n(const float* __restrict A, const float* __restrict B, - float* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void sgemv_naive_n( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1 && Bstride == 1); #define reset_acc(i) acc##i = 0; -#define acc_calu(i) acc##i += A[(m + i) * Astride + k] * B[k]; +#define acc_calu(i) acc##i += A[(m + i) * Astride + k] * B[k]; #define vdupq_sum(i) sum##i = vdupq_n_f32(0.f); -#define loadA(i) a##i = vld1q_f32(A + (m + i) * Astride + k); -#define loadB(i) b##i = vld1q_f32(B + k); +#define loadA(i) a##i = vld1q_f32(A + (m + i) * Astride + k); +#define loadB(i) b##i = vld1q_f32(B + k); #define calculate(i) sum##i = vmlaq_f32(sum##i, a##i, b0); -#define vstore(i) C[(m + i) * Cstride] = vaddvq_f32(sum##i) + acc##i; +#define vstore(i) C[(m + i) * Cstride] = vaddvq_f32(sum##i) + acc##i; size_t m = 0; for (; m < M; m += 1) { float acc0; @@ -69,9 +69,9 @@ void sgemv_naive_n(const float* __restrict A, const float* __restrict B, #undef vaddvq_f32 #endif -void sgemv_naive_m(const float* __restrict A, const float* __restrict B, - float* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void sgemv_naive_m( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { size_t m = 0; for (; m + 4 <= M; m += 4) { size_t k = 0; @@ -79,12 +79,12 @@ void sgemv_naive_m(const float* __restrict A, const float* __restrict B, for (; k + 4 <= K; k += 4) { size_t n = 0; for (; n + 4 <= N; n += 4) { - float32x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, - a22, a23, a30, a31, a32, a33; + float32x4_t a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, + a30, a31, a32, a33; float32x4_t b0, b1, b2, b3; float32x4_t c0, c1, c2, c3; -#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); #define loadA0(i) a0##i = vdupq_n_f32(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f32(A[(m + 1) * Astride + k + i]); #define loadA2(i) a2##i = vdupq_n_f32(A[(m + 2) * Astride + k + i]); @@ -147,8 +147,8 @@ void sgemv_naive_m(const float* __restrict A, const float* __restrict B, #undef vstore } for (; n < N; n += 1) { - float a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, - a23, a30, a31, a32, a33; + float a00, a01, a02, a03, a10, a11, a12, a13, a20, a21, a22, a23, a30, + a31, a32, a33; float b0, b1, b2, b3; float c0, c1, c2, c3; #define loadC(i) c##i = C[(m + i) * Cstride + n]; @@ -342,8 +342,8 @@ void sgemv_naive_m(const float* __restrict A, const float* 
__restrict B, float32x4_t c0, c1; #define loadA0(i) a0##i = vdupq_n_f32(A[(m + 0) * Astride + k + i]); #define loadA1(i) a1##i = vdupq_n_f32(A[(m + 1) * Astride + k + i]); -#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); UNROLL_OUT(loadC, 2) UNROLL_OUT(loadB, 4) UNROLL_OUT(loadA0, 4) @@ -561,8 +561,8 @@ void sgemv_naive_m(const float* __restrict A, const float* __restrict B, float32x4_t b0, b1, b2, b3; float32x4_t c0; #define loadA0(i) a0##i = vdupq_n_f32(A[m * Astride + k + i]); -#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); -#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); +#define loadB(i) b##i = vld1q_f32(B + (k + i) * Bstride + n); +#define loadC(i) c##i = vld1q_f32(C + (m + i) * Cstride + n); UNROLL_OUT(loadC, 1) UNROLL_OUT(loadB, 4) UNROLL_OUT(loadA0, 4) @@ -755,12 +755,12 @@ void sgemv_naive_m(const float* __restrict A, const float* __restrict B, } } -void sgemv_naive_n_mk4(const float* __restrict A, const float* __restrict B, - float* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void sgemv_naive_n_mk4( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { constexpr size_t PACK_SIZE = 4; - megdnn_assert(N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && - K % PACK_SIZE == 0); + megdnn_assert( + N == 1 && Bstride == PACK_SIZE && M % PACK_SIZE == 0 && K % PACK_SIZE == 0); auto Aptr = A; auto Cptr = C; size_t m = 0; @@ -806,9 +806,9 @@ void sgemv_naive_n_mk4(const float* __restrict A, const float* __restrict B, namespace megdnn { namespace arm_common { -void gemv_like(const float* __restrict A, const float* __restrict B, - float* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void gemv_like( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(M < 8 || (M == 8 && K <= 2) || (N == 1 && Bstride == 1)); if (N == 1) { MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW_N"_hash)) { @@ -823,9 +823,9 @@ void gemv_like(const float* __restrict A, const float* __restrict B, } } -void gemv_like_mk4(const float* __restrict A, const float* __restrict B, - float* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void gemv_like_mk4( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1 && Bstride == 4); MIDOUT_BEGIN(megdnn_fp32_sgemv, midout_iv("F32_GEMV_NCHW44_N"_hash)) { return sgemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); diff --git a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h index fc5a6c82..a6e46d5f 100644 --- a/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h +++ b/dnn/src/arm_common/matrix_mul/fp32/exec_sgemv.h @@ -15,18 +15,17 @@ namespace megdnn { namespace arm_common { -bool is_sgemv_like_preferred(bool row_major, bool transposeA, bool transposeB, - size_t M, size_t N, size_t K, float alpha, - size_t /* LDA */, size_t LDB, float beta, - size_t /* LDC */); +bool is_sgemv_like_preferred( + 
bool row_major, bool transposeA, bool transposeB, size_t M, size_t N, size_t K, + float alpha, size_t /* LDA */, size_t LDB, float beta, size_t /* LDC */); -void gemv_like(const float* __restrict A, const float* __restrict B, - float* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride); +void gemv_like( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); -void gemv_like_mk4(const float* __restrict A, const float* __restrict B, - float* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride); +void gemv_like_mk4( + const float* __restrict A, const float* __restrict B, float* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp index 0170dfec..affe114b 100644 --- a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp @@ -11,9 +11,9 @@ #include "src/arm_common/simd_macro/marm_neon.h" +#include "megdnn/oprs.h" #include "src/arm_common/matrix_mul/int8/gemv.h" #include "src/common/utils.h" -#include "megdnn/oprs.h" #include "midout.h" MIDOUT_DECL(megdnn_arm_common_int8_gemv) @@ -21,12 +21,11 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv) using namespace megdnn; using namespace arm_common; - namespace { -void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void gemv_naive_n( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1 && Bstride == 1); size_t m = 0; for (; m + 2 <= M; m += 2) { @@ -95,9 +94,9 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, } } -void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void gemv_naive_n_mk4( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { constexpr size_t PACK_SIZE = 4; megdnn_assert(N == 1 && Bstride == 4); auto Aptr = A; @@ -173,9 +172,9 @@ void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, #if MGB_ENABLE_DOT namespace { MEGDNN_ATTRIBUTE_TARGET("dotprod") -void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void gemv_naive_n_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1 && Bstride == 1); size_t m = 0; for (; m + 2 <= M; m += 2) { @@ -184,8 +183,7 @@ void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B, size_t k = 0; for (; k + 16 <= K; k += 16) { int64x2_t a0 = vreinterpretq_s64_s8(vld1q_s8(A + m * Astride + k)); - int64x2_t a1 = - vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); + int64x2_t a1 = vreinterpretq_s64_s8(vld1q_s8(A + (m + 1) * Astride + k)); //! 
the first 8 elements is m, the last 8 elements is m + 1 int8x16_t a2 = vreinterpretq_s8_s64(vzip1q_s64(a0, a1)); int8x16_t a3 = vreinterpretq_s8_s64(vzip2q_s64(a0, a1)); @@ -243,9 +241,9 @@ void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B, } MEGDNN_ATTRIBUTE_TARGET("dotprod") -void gemv_naive_n_mk4_dotprod(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride) { +void gemv_naive_n_mk4_dotprod( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { constexpr size_t PACK_SIZE = 4; megdnn_assert(N == 1 && Bstride == 4); @@ -323,10 +321,9 @@ void gemv_naive_n_mk4_dotprod(const int8_t* __restrict A, const int8_t* __restri } MEGDNN_ATTRIBUTE_TARGET("dotprod") -void gemv_naive_n_mk4_dot(const int8_t* __restrict A, - const int8_t* __restrict B, int32_t* __restrict C, - size_t M, size_t N, size_t K, size_t Astride, - size_t Bstride, size_t Cstride) { +void gemv_naive_n_mk4_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { constexpr size_t PACK_SIZE = 4; megdnn_assert(N == 1 && Bstride == 4); @@ -379,10 +376,9 @@ void gemv_naive_n_mk4_dot(const int8_t* __restrict A, } // namespace #endif -bool arm_common::is_gemv_like_preferred_int8(bool transposeA, bool transposeB, - size_t M, size_t N, size_t K, - size_t LDA, size_t LDB, - size_t LDC) { +bool arm_common::is_gemv_like_preferred_int8( + bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t LDA, + size_t LDB, size_t LDC) { MEGDNN_MARK_USED_VAR(LDA); MEGDNN_MARK_USED_VAR(LDB); MEGDNN_MARK_USED_VAR(LDC); @@ -396,17 +392,14 @@ bool arm_common::is_gemv_like_preferred_int8(bool transposeA, bool transposeB, return N == 1 && LDB == 1; } -void arm_common::gemv_like(const int8_t* __restrict A, - const int8_t* __restrict B, int32_t* __restrict C, - size_t M, size_t N, size_t K, size_t Astride, - size_t Bstride, size_t Cstride) { +void arm_common::gemv_like( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1); - MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, - midout_iv("INT8_gemv_like"_hash)) { + MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gemv_like"_hash)) { #if MGB_ENABLE_DOT if (cpuinfo_has_arm_neon_dot()) { - return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride, - Cstride); + return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride, Cstride); } else { return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); } @@ -417,21 +410,17 @@ void arm_common::gemv_like(const int8_t* __restrict A, MIDOUT_END(); } -void arm_common::gemv_like_mk4(const int8_t* __restrict A, - const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, - size_t K, size_t Astride, size_t Bstride, - size_t Cstride) { +void arm_common::gemv_like_mk4( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1); - MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, - midout_iv("INT8_gemv_like_mk4"_hash)) { + MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gemv_like_mk4"_hash)) { #if MGB_ENABLE_DOT if (cpuinfo_has_arm_neon_dot()) { - 
return gemv_naive_n_mk4_dotprod(A, B, C, M, N, K, Astride, Bstride, - Cstride); + return gemv_naive_n_mk4_dotprod( + A, B, C, M, N, K, Astride, Bstride, Cstride); } else { - return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, - Cstride); + return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); } #else return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); @@ -441,20 +430,16 @@ void arm_common::gemv_like_mk4(const int8_t* __restrict A, } #if MGB_ENABLE_DOT -void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A, - const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, - size_t K, size_t Astride, size_t Bstride, - size_t Cstride) { +void arm_common::gemv_like_mk4_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1); - MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, - midout_iv("INT8_gemv_like_mk4_dot"_hash)) { - return gemv_naive_n_mk4_dot(A, B, C, M, N, K, Astride, Bstride, - Cstride); + MIDOUT_BEGIN( + megdnn_arm_common_int8_gemv, midout_iv("INT8_gemv_like_mk4_dot"_hash)) { + return gemv_naive_n_mk4_dot(A, B, C, M, N, K, Astride, Bstride, Cstride); } MIDOUT_END(); } #endif - // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.h b/dnn/src/arm_common/matrix_mul/int8/gemv.h index 13ff27b7..1d92648a 100644 --- a/dnn/src/arm_common/matrix_mul/int8/gemv.h +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.h @@ -16,26 +16,25 @@ namespace megdnn { namespace arm_common { -bool is_gemv_like_preferred_int8(bool transposeA, bool transposeB, size_t M, - size_t N, size_t K, size_t LDA, size_t LDB, - size_t LDC); +bool is_gemv_like_preferred_int8( + bool transposeA, bool transposeB, size_t M, size_t N, size_t K, size_t LDA, + size_t LDB, size_t LDC); -void gemv_like(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride); +void gemv_like( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); -void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride); +void gemv_like_mk4( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); #if MGB_ENABLE_DOT -void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B, - int32_t* __restrict C, size_t M, size_t N, size_t K, - size_t Astride, size_t Bstride, size_t Cstride); +void gemv_like_mk4_dot( + const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, + size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); #endif } // namespace arm_common } // namespace megdnn - // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index 9d09eed4..6cfef7f1 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -63,12 +63,12 @@ const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(MatrixMulImpl) -SmallVector -MatrixMulImpl::get_all_packed_algo() { +SmallVector MatrixMulImpl::get_all_packed_algo() { static 
AlgoPack s_algo_pack; auto&& algos = fallback::MatrixMulImpl::get_all_packed_algo(); - algos.insert(algos.begin(), algo_pack().all_algos().begin(), - algo_pack().all_algos().end()); + algos.insert( + algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index cbdb120a..e79aebe0 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once +#include "src/common/algo_base.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/opr_impl.h" -#include "src/common/algo_base.h" namespace megdnn { namespace arm_common { @@ -22,28 +22,27 @@ public: bool is_thread_safe() const override { return true; } class AlgoBase : public fallback::MatrixMulImpl::AlgoBase { - public: - AlgoBase() : fallback::MatrixMulImpl::AlgoBase() { - m_handle_type = Handle::HandleType::ARM_COMMON; - } + public: + AlgoBase() : fallback::MatrixMulImpl::AlgoBase() { + m_handle_type = Handle::HandleType::ARM_COMMON; + } }; - SmallVector get_all_packed_algo() - override; + SmallVector get_all_packed_algo() override; MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); protected: - class AlgoF32Gemv; // Arm_common F32 Gemv - class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 - class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv - class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 - class AlgoGevm; // Arm_common Gevm(support int8 and fp32) + class AlgoF32Gemv; // Arm_common F32 Gemv + class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 + class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv + class AlgoInt8x8x32GemvMK4; // Arm_common Int8x8x32 Gemv NCHW44 + class AlgoGevm; // Arm_common Gevm(support int8 and fp32) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class AlgoF16Gemv; #endif #if MGB_ENABLE_DOT - class AlgoInt8x8x32GemvMK4Dot;// Arm_common Int8x8x32 Gemv NCHW44_DOT + class AlgoInt8x8x32GemvMK4Dot; // Arm_common Int8x8x32 Gemv NCHW44_DOT #endif class AlgoInt8x8x16; // Arm_common Int 8x8x16 class AlgoPack; diff --git a/dnn/src/arm_common/neon_struct.h b/dnn/src/arm_common/neon_struct.h index 43a4eace..98eca959 100644 --- a/dnn/src/arm_common/neon_struct.h +++ b/dnn/src/arm_common/neon_struct.h @@ -16,14 +16,13 @@ namespace megdnn { namespace { struct Vdotq_s32_h { - static __ai int32x4_t impl(int8x16_t& a, int8x16_t& b, int32x4_t& c, - int16x8_t& temp) { + static __ai int32x4_t + impl(int8x16_t& a, int8x16_t& b, int32x4_t& c, int16x8_t& temp) { return vdotq_s32_h(a, b, c, temp); } }; struct Vdot2_s32_h { - static __ai int32x4_t impl(int8x8_t a, int8x8_t b, int32x4_t c, - int16x8_t temp) { + static __ai int32x4_t impl(int8x8_t a, int8x8_t b, int32x4_t c, int16x8_t temp) { return vdot2_s32_h(a, b, c, temp); } }; @@ -38,17 +37,13 @@ struct Vld1q_s8 { static __ai int8x16_t impl(const int8_t* ptr) { return vld1q_s8(ptr); } }; struct Vld1q_f32 { - static __ai float32x4_t impl(const float32_t* ptr) { - return vld1q_f32(ptr); - } + static __ai float32x4_t impl(const float32_t* ptr) { return vld1q_f32(ptr); } }; struct Vld1_s8 { static __ai int8x8_t impl(const int8_t* ptr) { return vld1_s8(ptr); } }; struct Vldq_dup_4s8_8s16 { - static __ai int16x8_t impl(const int8_t* ptr) { - return vldq_dup_4s8_8s16(ptr); - } + static __ai int16x8_t impl(const int8_t* ptr) { return vldq_dup_4s8_8s16(ptr); } }; struct 
Vldq_tbl_low_s8 { @@ -58,9 +53,7 @@ struct Vldq_tbl_low_s8 { }; struct Vld1_dup_s8_s16 { - static __ai int16x8_t impl(const int8_t* ptr) { - return vld1_dup_s8_s16(ptr); - } + static __ai int16x8_t impl(const int8_t* ptr) { return vld1_dup_s8_s16(ptr); } }; struct Vfmaq_laneq_f32 { diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp index 980ee905..44b5db4b 100644 --- a/dnn/src/arm_common/pooling/algo.cpp +++ b/dnn/src/arm_common/pooling/algo.cpp @@ -27,18 +27,17 @@ namespace megdnn { namespace arm_common { WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { - megdnn_assert((param.src_type.category() == DTypeCategory::FLOAT || - param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Quantized8Asymm || - param.src_type == dtype::Int8{}) && - param.format == param::Pooling::Format::NCHW && - (param.mode == param::Pooling::Mode::MAX || - (param.mode == param::Pooling::Mode::AVERAGE && - param.filter[0] == 3)) && - param.filter[0] == param.filter[1] && - (param.filter[0] == 3 || param.filter[1] == 5) && - param.stride[0] == 2 && param.stride[1] == 2 && - param.isz[0] >= 2 && param.isz[1] >= 2); + megdnn_assert( + (param.src_type.category() == DTypeCategory::FLOAT || + param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Quantized8Asymm || + param.src_type == dtype::Int8{}) && + param.format == param::Pooling::Format::NCHW && + (param.mode == param::Pooling::Mode::MAX || + (param.mode == param::Pooling::Mode::AVERAGE && param.filter[0] == 3)) && + param.filter[0] == param.filter[1] && + (param.filter[0] == 3 || param.filter[1] == 5) && param.stride[0] == 2 && + param.stride[1] == 2 && param.isz[0] >= 2 && param.isz[1] >= 2); //! max pooling nxn stride 2 auto IW = param.isz[1]; auto OW = param.osz[1]; @@ -56,11 +55,11 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { return ws; } -WorkspaceBundle get_bundle_nchw44( - const PoolingImpl::PoolingKernSizeParam& param) { - megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Int8) && - (param.format == param::Pooling::Format::NCHW44)); +WorkspaceBundle get_bundle_nchw44(const PoolingImpl::PoolingKernSizeParam& param) { + megdnn_assert( + (param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && + (param.format == param::Pooling::Format::NCHW44)); auto IH = param.isz[0]; auto IW = param.isz[1]; auto PH = param.padding[0]; @@ -72,9 +71,9 @@ WorkspaceBundle get_bundle_nchw44( return WorkspaceBundle(nullptr, {padding_size}); } -const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, - size_t& IH2, size_t& IW2, size_t PH, size_t PW, - const WorkspaceBundle& ws, bool is_max_mode) { +const int8_t* handle_padding( + const int8_t* src, size_t IH, size_t IW, size_t& IH2, size_t& IW2, size_t PH, + size_t PW, const WorkspaceBundle& ws, bool is_max_mode) { int8_t* sptr_base = nullptr; int8_t padding_value = is_max_mode ? INT8_MIN : 0; bool need_pad = ((PH != 0) || (PW != 0)) ? 
true : false; @@ -84,8 +83,9 @@ const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, sptr_base = static_cast(ws.get(0)); memset(sptr_base, padding_value, sizeof(int8_t) * IH2 * IW2 * 4); rep(ih, IH) { - std::memcpy(sptr_base + (ih + PH) * IW2 * 4 + PW * 4, - src + ih * IW * 4, sizeof(int8_t) * IW * 4); + std::memcpy( + sptr_base + (ih + PH) * IW2 * 4 + PW * 4, src + ih * IW * 4, + sizeof(int8_t) * IW * 4); } } else { IH2 = IH; @@ -109,8 +109,7 @@ bool PoolingImpl::AlgoFilterxModexStride1::usable( return avaible && is_mode_ok; } -void PoolingImpl::AlgoFilterxModexStride1::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoFilterxModexStride1::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; @@ -121,61 +120,60 @@ void PoolingImpl::AlgoFilterxModexStride1::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(0), \ - midout_iv(midout_type_id), Pooler::MIDOUT_CASE_NUM, \ - NeonPooler::MIDOUT_CASE_NUM, window) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - src_dtype = param.src_type](size_t index, size_t) { \ - size_t n = index / C; \ - size_t c = index % C; \ - do_pooling_compact< \ - Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>( \ - static_cast(src_ptr) + \ - n * C * IH * IW + c * IH * IW, \ - static_cast(dst_ptr) + \ - n * C * OH * OW + c * OH * OW, \ - src_dtype, IH, IW, OH, OW, PH, PW); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_pooling, midout_iv(0), midout_iv(midout_type_id), \ + Pooler::MIDOUT_CASE_NUM, NeonPooler::MIDOUT_CASE_NUM, window) { \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + src_dtype = param.src_type](size_t index, size_t) { \ + size_t n = index / C; \ + size_t c = index % C; \ + do_pooling_compact( \ + static_cast(src_ptr) + \ + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + \ + c * OH * OW, \ + src_dtype, IH, IW, OH, OW, PH, PW); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END() -#define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, \ - midout_type_id) \ - switch (FH) { \ - case 2: { \ - using _Pooler = Pooler<4, dtype, ctype, comp_type>; \ - using _NeonPooler = NeonPooler<4, dtype, ctype, comp_type>; \ - DISPATCH_FUNC(_Pooler, _NeonPooler, 2, midout_type_id); \ - break; \ - } \ - case 3: { \ - using _Pooler = Pooler<9, dtype, ctype, comp_type>; \ - using _NeonPooler = NeonPooler<9, dtype, ctype, comp_type>; \ - DISPATCH_FUNC(_Pooler, _NeonPooler, 3, midout_type_id); \ - break; \ - } \ - default: \ - megdnn_assert(0, "unsupport pooling filter size"); \ - break; \ +#define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, midout_type_id) \ + switch (FH) { \ + case 2: { \ + using _Pooler = Pooler<4, dtype, ctype, comp_type>; \ + using _NeonPooler = NeonPooler<4, dtype, ctype, comp_type>; \ + DISPATCH_FUNC(_Pooler, _NeonPooler, 2, midout_type_id); \ + break; \ + } \ + case 3: { \ + using _Pooler = Pooler<9, dtype, ctype, comp_type>; \ + using _NeonPooler = NeonPooler<9, dtype, ctype, comp_type>; \ + 
DISPATCH_FUNC(_Pooler, _NeonPooler, 3, midout_type_id); \ + break; \ + } \ + default: \ + megdnn_assert(0, "unsupport pooling filter size"); \ + break; \ } -#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \ - switch (param.mode) { \ - case Mode::MAX: \ - DISPATCH_WINDOW(MaxPooler, NeonMaxPooler, dtype, ctype, comp_type, \ - midout_type_id); \ - break; \ - case Mode::AVERAGE: \ - DISPATCH_WINDOW(MeanInPooler, NeonMeanPooler, dtype, ctype, \ - comp_type, midout_type_id); \ - break; \ - default: \ - megdnn_assert(0, "unsupport pooling mode"); \ - break; \ +#define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \ + switch (param.mode) { \ + case Mode::MAX: \ + DISPATCH_WINDOW( \ + MaxPooler, NeonMaxPooler, dtype, ctype, comp_type, \ + midout_type_id); \ + break; \ + case Mode::AVERAGE: \ + DISPATCH_WINDOW( \ + MeanInPooler, NeonMeanPooler, dtype, ctype, comp_type, \ + midout_type_id); \ + break; \ + default: \ + megdnn_assert(0, "unsupport pooling mode"); \ + break; \ } if (param.src_type == dtype::Float32{}) { @@ -202,14 +200,13 @@ bool PoolingImpl::AlgoFilter2ModexStride2::usable( bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || param.src_type.category() == DTypeCategory::QUANTIZED) && - param.format == Param::Format::NCHW && FH == FW && - SH == SW && FH == 2 && SH == 2; + param.format == Param::Format::NCHW && FH == FW && SH == SW && + FH == 2 && SH == 2; bool is_mode_ok = (param.mode == Mode::MAX || param.mode == Mode::AVERAGE); return avaible && is_mode_ok; } -void PoolingImpl::AlgoFilter2ModexStride2::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoFilter2ModexStride2::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; @@ -218,24 +215,24 @@ void PoolingImpl::AlgoFilter2ModexStride2::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(Pooler, mode, midout_type_id) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(1), \ - midout_iv(midout_type_id), Pooler::MIDOUT_CASE_NUM) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - src_dtype = param.src_type](size_t index, size_t) { \ - size_t n = index / C; \ - size_t c = index % C; \ - do_pooling_2x2( \ - static_cast(src_ptr) + \ - n * C * IH * IW + c * IH * IW, \ - static_cast(dst_ptr) + \ - n * C * OH * OW + c * OH * OW, \ - src_dtype, IH, IW, OH, OW, PH, PW); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(Pooler, mode, midout_type_id) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_pooling, midout_iv(1), midout_iv(midout_type_id), \ + Pooler::MIDOUT_CASE_NUM) { \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + src_dtype = param.src_type](size_t index, size_t) { \ + size_t n = index / C; \ + size_t c = index % C; \ + do_pooling_2x2( \ + static_cast(src_ptr) + \ + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + \ + c * OH * OW, \ + src_dtype, IH, IW, OH, OW, PH, PW); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END() #define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \ @@ -275,16 +272,14 @@ bool PoolingImpl::AlgoFilter3MaxStride2::usable( const PoolingKernSizeParam& param) const { bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || 
param.src_type.category() == DTypeCategory::QUANTIZED) && - param.format == Param::Format::NCHW && - param.mode == Mode::MAX && param.filter[0] == 3 && - param.filter[1] == 3 && param.stride[0] == 2 && - param.stride[1] == 2 && param.isz[0] >= 2 && + param.format == Param::Format::NCHW && param.mode == Mode::MAX && + param.filter[0] == 3 && param.filter[1] == 3 && + param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 && param.isz[1] >= 2; return avaible; } -void PoolingImpl::AlgoFilter3MaxStride2::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoFilter3MaxStride2::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; @@ -294,29 +289,24 @@ void PoolingImpl::AlgoFilter3MaxStride2::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, midout_type_id) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ - midout_iv(midout_type_id)) { \ - WorkspaceBundle wbundle = get_bundle(param); \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - wbundle = wbundle, \ - workspace_ptr = param.workspace()]( \ - size_t index, size_t thread_id) { \ - auto ws = wbundle; \ - ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_max_pooling_3x3_s2x2_##func##_NEON( \ - static_cast(src_ptr) + n * C * IH * IW + \ - c * IH * IW, \ - static_cast(dst_ptr) + n * C * OH * OW + \ - c * OH * OW, \ - IH, IW, OH, OW, PH, PW, ws); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(type, func, midout_type_id) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), midout_iv(midout_type_id)) { \ + WorkspaceBundle wbundle = get_bundle(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_max_pooling_3x3_s2x2_##func##_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \ + IW, OH, OW, PH, PW, ws); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); if (param.src_type == dtype::Float32{}) { @@ -335,16 +325,14 @@ void PoolingImpl::AlgoFilter3MaxStride2::exec( bool PoolingImpl::AlgoFilter3AverageStride2::usable( const PoolingKernSizeParam& param) const { bool avaible = (param.src_type.category() == DTypeCategory::FLOAT) && - param.format == Param::Format::NCHW && - param.mode == Mode::AVERAGE && param.filter[0] == 3 && - param.filter[1] == 3 && param.stride[0] == 2 && - param.stride[1] == 2 && param.isz[0] >= 2 && + param.format == Param::Format::NCHW && param.mode == Mode::AVERAGE && + param.filter[0] == 3 && param.filter[1] == 3 && + param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 && param.isz[1] >= 2; return avaible; } -void PoolingImpl::AlgoFilter3AverageStride2::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoFilter3AverageStride2::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; 
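// Illustrative sketch, not part of this patch: every exec() touched in this file follows
// the same dispatch pattern, namely one task per (batch, channel) plane plus a private
// workspace slice per worker thread.  The megdnn helpers
// (MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN, WorkspaceBundle, PoolingKernParam) are replaced
// below by hypothetical stand-ins so only the index and workspace arithmetic remains; the
// "kernel" here is a trivial whole-plane max, not one of the real pooling kernels.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>

using Kern = std::function<void(size_t index, size_t thread_id)>;

// Stand-in for MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN: runs nr_tasks tasks serially.
static void dispatch_multi_thread_stub(size_t nr_tasks, const Kern& kern) {
    for (size_t i = 0; i < nr_tasks; ++i)
        kern(i, /*thread_id=*/0);
}

static void pooling_like_exec_sketch(
        const float* src, float* dst, size_t N, size_t C, size_t IH, size_t IW,
        size_t OH, size_t OW, uint8_t* workspace, size_t workspace_per_thread) {
    auto run = [=](size_t index, size_t thread_id) {
        // Same decomposition as the run lambdas above: flat index -> (n, c).
        size_t n = index / C;
        size_t c = index % C;
        const float* sptr = src + (n * C + c) * IH * IW;
        float* dptr = dst + (n * C + c) * OH * OW;
        // Per-thread slice, mirroring
        // ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id).
        uint8_t* ws = workspace + workspace_per_thread * thread_id;
        (void)ws;  // the real kernels build padded input copies here
        float best = -std::numeric_limits<float>::infinity();
        for (size_t i = 0; i < IH * IW; ++i)
            best = std::max(best, sptr[i]);
        for (size_t o = 0; o < OH * OW; ++o)
            dptr[o] = best;
    };
    dispatch_multi_thread_stub(N * C, run);
}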
@@ -354,29 +342,24 @@ void PoolingImpl::AlgoFilter3AverageStride2::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), \ - midout_iv(midout_type_id)) { \ - WorkspaceBundle wbundle = get_bundle(param); \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - wbundle = wbundle, \ - workspace_ptr = param.workspace()]( \ - size_t index, size_t thread_id) { \ - auto ws = wbundle; \ - ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_average_pooling_3x3_s2x2_NEON( \ - static_cast(src_ptr) + n * C * IH * IW + \ - c * IH * IW, \ - static_cast(dst_ptr) + n * C * OH * OW + \ - c * OH * OW, \ - IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), midout_iv(midout_type_id)) { \ + WorkspaceBundle wbundle = get_bundle(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_average_pooling_3x3_s2x2_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \ + IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); if (param.src_type == dtype::Float32{}) { DISPATCH_FUNC(dt_float32, 4, 0); @@ -397,14 +380,12 @@ bool PoolingImpl::AlgoFilter4MaxStride2::usable( bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || param.src_type.category() == DTypeCategory::QUANTIZED) && - param.format == Param::Format::NCHW && - param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == 2 && - SW == 2 && OH >= 2 && OW >= 2; + param.format == Param::Format::NCHW && param.mode == Mode::MAX && + FH == 4 && FW == 4 && SH == 2 && SW == 2 && OH >= 2 && OW >= 2; return avaible; } -void PoolingImpl::AlgoFilter4MaxStride2::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoFilter4MaxStride2::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; @@ -414,24 +395,20 @@ void PoolingImpl::AlgoFilter4MaxStride2::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, midout_type_id) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), \ - midout_iv(midout_type_id)) { \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - src_dtype = param.src_type](size_t index, size_t) { \ - size_t n = index / C; \ - size_t c = index % C; \ - do_max_pooling_w4x4_s2x2_##func##_NEON( \ - static_cast(src_ptr) + n * C * IH * IW + \ - c * IH * IW, \ - static_cast(dst_ptr) + n * C * OH * OW + \ - c * OH * OW, \ - src_dtype, IH, IW, OH, OW, PH, PW); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(type, func, midout_type_id) \ + 
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), midout_iv(midout_type_id)) { \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ + src_dtype = param.src_type](size_t index, size_t) { \ + size_t n = index / C; \ + size_t c = index % C; \ + do_max_pooling_w4x4_s2x2_##func##_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + c * OH * OW, \ + src_dtype, IH, IW, OH, OW, PH, PW); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); if (param.src_type == dtype::Float32{}) { @@ -457,14 +434,12 @@ bool PoolingImpl::AlgoFilter5MaxStride2::usable( bool avaible = (param.src_type.category() == DTypeCategory::FLOAT || param.src_type.category() == DTypeCategory::QUANTIZED) && - param.format == Param::Format::NCHW && - param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == 2 && - SW == 2 && OH >= 2 && OW >= 2; + param.format == Param::Format::NCHW && param.mode == Mode::MAX && + FH == 5 && FW == 5 && SH == 2 && SW == 2 && OH >= 2 && OW >= 2; return avaible; } -void PoolingImpl::AlgoFilter5MaxStride2::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoFilter5MaxStride2::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; @@ -474,29 +449,24 @@ void PoolingImpl::AlgoFilter5MaxStride2::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), \ - midout_iv(midout_type_id)) { \ - WorkspaceBundle wbundle = get_bundle(param); \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - wbundle = wbundle, \ - workspace_ptr = param.workspace()]( \ - size_t index, size_t thread_id) { \ - auto ws = wbundle; \ - ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_max_pooling_w5x5_s2x2_NEON( \ - static_cast(src_ptr) + n * C * IH * IW + \ - c * IH * IW, \ - static_cast(dst_ptr) + n * C * OH * OW + \ - c * OH * OW, \ - IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH) \ + MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), midout_iv(midout_type_id)) { \ + WorkspaceBundle wbundle = get_bundle(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_max_pooling_w5x5_s2x2_NEON( \ + static_cast(src_ptr) + n * C * IH * IW + c * IH * IW, \ + static_cast(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \ + IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); if (param.src_type == dtype::Float32{}) { @@ -523,14 +493,12 @@ bool PoolingImpl::AlgoInt8Filter2MaxStride2::usable( auto PW = param.padding[1]; bool avaible = param.src_type == dtype::Int8() && - param.format == Param::Format::NCHW && - param.mode == Mode::MAX && SH == 2 && SW == 2 && PH 
== 0 && - PW == 0 && FH == 2 && FW == 2; + param.format == Param::Format::NCHW && param.mode == Mode::MAX && + SH == 2 && SW == 2 && PH == 0 && PW == 0 && FH == 2 && FW == 2; return avaible; } -void PoolingImpl::AlgoInt8Filter2MaxStride2::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoInt8Filter2MaxStride2::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; @@ -542,13 +510,12 @@ void PoolingImpl::AlgoInt8Filter2MaxStride2::exec( auto run = [C, IH, IW, OH, OW, src_ptr, dst_ptr](size_t index, size_t) { size_t n = index / C; size_t c = index % C; - pooling_max_w2x2_s2x2(src_ptr + n * C * IH * IW + c * IH * IW, - dst_ptr + n * C * OH * OW + c * OH * OW, 1, 1, - IH, IW, OH, OW); + pooling_max_w2x2_s2x2( + src_ptr + n * C * IH * IW + c * IH * IW, + dst_ptr + n * C * OH * OW + c * OH * OW, 1, 1, IH, IW, OH, OW); }; MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, - run); + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, run); } MIDOUT_END(); } @@ -563,14 +530,12 @@ bool PoolingImpl::AlgoInt8Filter3MaxStride2::usable( auto IW = param.isz[1]; bool avaible = param.src_type == dtype::Int8() && - param.format == Param::Format::NCHW && - param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 2 && - SW == 2 && IH >= 2 && IW >= 2; + param.format == Param::Format::NCHW && param.mode == Mode::MAX && + FH == 3 && FW == 3 && SH == 2 && SW == 2 && IH >= 2 && IW >= 2; return avaible; } -void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( - const PoolingKernParam& param) const { +void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(const PoolingKernParam& param) const { auto IH = param.isz[0], IW = param.isz[1]; auto OH = param.osz[0], OW = param.osz[1]; auto N = param.n, C = param.ic; @@ -582,8 +547,7 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(7)) { WorkspaceBundle wbundle = get_bundle(param); - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, - wbundle = wbundle, + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, workspace_ptr = param.workspace()]( size_t index, size_t thread_id) { auto ws = wbundle; @@ -592,12 +556,11 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( size_t c = index % C; do_max_pooling_3x3_s2x2_int8_NEON( src_ptr + n * C * IH * IW + c * IH * IW, - dst_ptr + n * C * OH * OW + c * OH * OW, IH, IW, OH, OW, PH, - PW, ws); + dst_ptr + n * C * OH * OW + c * OH * OW, IH, IW, OH, OW, PH, PW, + ws); }; MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, - run); + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N * C, run); } MIDOUT_END(); } @@ -616,8 +579,8 @@ bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable( FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2); //! Int8 not support average, because its round mode is different form //! 
qint8 - avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && - param.mode == Mode::AVERAGE); + avaible &= + !(param.src_type.enumv() == DTypeEnum::Int8 && param.mode == Mode::AVERAGE); return avaible; } @@ -633,45 +596,44 @@ void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i, mode) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(8), \ - midout_iv(#type #i##_hash)) { \ - WorkspaceBundle wbundle = get_bundle_nchw44(param); \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - wbundle = wbundle, \ - workspace_ptr = param.workspace()]( \ - size_t index, size_t thread_id) { \ - auto ws = wbundle; \ - ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ - static_cast(src_ptr) + n * C * IH * IW * 4 + \ - c * IH * IW * 4, \ - static_cast(dst_ptr) + n * C * OH * OW * 4 + \ - c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW, ws); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_pooling, midout_iv(8), midout_iv(#type #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \ + static_cast(src_ptr) + n * C * IH * IW * 4 + \ + c * IH * IW * 4, \ + static_cast(dst_ptr) + n * C * OH * OW * 4 + \ + c * OH * OW * 4, \ + IH, IW, OH, OW, PH, PW, ws); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); -#define DISPATCH_MODE(type, func, stride) \ - switch (param.mode) { \ - case Mode::MAX: { \ - DISPATCH_FUNC(type, func, stride, max); \ - break; \ - } \ - case Mode::AVERAGE: { \ - DISPATCH_FUNC(type, func, stride, avg); \ - break; \ - } \ - default: \ - megdnn_throw(ssprintf("Unsupport pooling mode %d", \ - static_cast(param.mode)) \ - .c_str()); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw( \ + ssprintf( \ + "Unsupport pooling mode %d", static_cast(param.mode)) \ + .c_str()); \ } #define DISPATCH_STRIDE(type, func) \ @@ -709,8 +671,8 @@ bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable( FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2); //! Int8 not support average, because its round mode is different form //! 
qint8 - avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && - param.mode == Mode::AVERAGE); + avaible &= + !(param.src_type.enumv() == DTypeEnum::Int8 && param.mode == Mode::AVERAGE); return avaible; } @@ -726,45 +688,44 @@ void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i, mode) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(9), \ - midout_iv(#func #i##_hash)) { \ - WorkspaceBundle wbundle = get_bundle_nchw44(param); \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - wbundle = wbundle, \ - workspace_ptr = param.workspace()]( \ - size_t index, size_t thread_id) { \ - auto ws = wbundle; \ - ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ - static_cast(src_ptr) + n * C * IH * IW * 4 + \ - c * IH * IW * 4, \ - static_cast(dst_ptr) + n * C * OH * OW * 4 + \ - c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW, ws); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_pooling, midout_iv(9), midout_iv(#func #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ + static_cast(src_ptr) + n * C * IH * IW * 4 + \ + c * IH * IW * 4, \ + static_cast(dst_ptr) + n * C * OH * OW * 4 + \ + c * OH * OW * 4, \ + IH, IW, OH, OW, PH, PW, ws); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); -#define DISPATCH_MODE(type, func, stride) \ - switch (param.mode) { \ - case Mode::MAX: { \ - DISPATCH_FUNC(type, func, stride, max); \ - break; \ - } \ - case Mode::AVERAGE: { \ - DISPATCH_FUNC(type, func, stride, avg); \ - break; \ - } \ - default: \ - megdnn_throw(ssprintf("Unsupport pooling mode %d", \ - static_cast(param.mode)) \ - .c_str()); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw( \ + ssprintf( \ + "Unsupport pooling mode %d", static_cast(param.mode)) \ + .c_str()); \ } #define DISPATCH_STRIDE(type, func) \ @@ -803,8 +764,8 @@ bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable( //! Int8 not support average, because its round mode is different form //! 
qint8 - avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && - param.mode == Mode::AVERAGE); + avaible &= + !(param.src_type.enumv() == DTypeEnum::Int8 && param.mode == Mode::AVERAGE); return avaible; } @@ -820,45 +781,44 @@ void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i, mode) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(10), \ - midout_iv(#func #i##_hash)) { \ - WorkspaceBundle wbundle = get_bundle_nchw44(param); \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - wbundle = wbundle, \ - workspace_ptr = param.workspace()]( \ - size_t index, size_t thread_id) { \ - auto ws = wbundle; \ - ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ - static_cast(src_ptr) + n * C * IH * IW * 4 + \ - c * IH * IW * 4, \ - static_cast(dst_ptr) + n * C * OH * OW * 4 + \ - c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW, ws); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_pooling, midout_iv(10), midout_iv(#func #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ + static_cast(src_ptr) + n * C * IH * IW * 4 + \ + c * IH * IW * 4, \ + static_cast(dst_ptr) + n * C * OH * OW * 4 + \ + c * OH * OW * 4, \ + IH, IW, OH, OW, PH, PW, ws); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); -#define DISPATCH_MODE(type, func, stride) \ - switch (param.mode) { \ - case Mode::MAX: { \ - DISPATCH_FUNC(type, func, stride, max); \ - break; \ - } \ - case Mode::AVERAGE: { \ - DISPATCH_FUNC(type, func, stride, avg); \ - break; \ - } \ - default: \ - megdnn_throw(ssprintf("Unsupport pooling mode %d", \ - static_cast(param.mode)) \ - .c_str()); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw( \ + ssprintf( \ + "Unsupport pooling mode %d", static_cast(param.mode)) \ + .c_str()); \ } #define DISPATCH_STRIDE(type, func) \ @@ -896,8 +856,8 @@ bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable( FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2); //! Int8 not support average, because its round mode is different form //! 
qint8 - avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 && - param.mode == Mode::AVERAGE); + avaible &= + !(param.src_type.enumv() == DTypeEnum::Int8 && param.mode == Mode::AVERAGE); return avaible; } @@ -913,45 +873,44 @@ void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(type, func, i, mode) \ - MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(11), \ - midout_iv(#func #i##_hash)) { \ - WorkspaceBundle wbundle = get_bundle_nchw44(param); \ - auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \ - wbundle = wbundle, \ - workspace_ptr = param.workspace()]( \ - size_t index, size_t thread_id) { \ - auto ws = wbundle; \ - ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ - size_t n = index / C; \ - size_t c = index % C; \ - do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ - static_cast(src_ptr) + n * C * IH * IW * 4 + \ - c * IH * IW * 4, \ - static_cast(dst_ptr) + n * C * OH * OW * 4 + \ - c * OH * OW * 4, \ - IH, IW, OH, OW, PH, PW, ws); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ - run); \ - } \ +#define DISPATCH_FUNC(type, func, i, mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_pooling, midout_iv(11), midout_iv(#func #i##_hash)) { \ + WorkspaceBundle wbundle = get_bundle_nchw44(param); \ + auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, wbundle = wbundle, \ + workspace_ptr = param.workspace()]( \ + size_t index, size_t thread_id) { \ + auto ws = wbundle; \ + ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \ + size_t n = index / C; \ + size_t c = index % C; \ + do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ + static_cast(src_ptr) + n * C * IH * IW * 4 + \ + c * IH * IW * 4, \ + static_cast(dst_ptr) + n * C * OH * OW * 4 + \ + c * OH * OW * 4, \ + IH, IW, OH, OW, PH, PW, ws); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \ + } \ MIDOUT_END(); -#define DISPATCH_MODE(type, func, stride) \ - switch (param.mode) { \ - case Mode::MAX: { \ - DISPATCH_FUNC(type, func, stride, max); \ - break; \ - } \ - case Mode::AVERAGE: { \ - DISPATCH_FUNC(type, func, stride, avg); \ - break; \ - } \ - default: \ - megdnn_throw(ssprintf("Unsupport pooling mode %d", \ - static_cast(param.mode)) \ - .c_str()); \ +#define DISPATCH_MODE(type, func, stride) \ + switch (param.mode) { \ + case Mode::MAX: { \ + DISPATCH_FUNC(type, func, stride, max); \ + break; \ + } \ + case Mode::AVERAGE: { \ + DISPATCH_FUNC(type, func, stride, avg); \ + break; \ + } \ + default: \ + megdnn_throw( \ + ssprintf( \ + "Unsupport pooling mode %d", static_cast(param.mode)) \ + .c_str()); \ } #define DISPATCH_STRIDE(type, func) \ diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h index 71506f7d..d0bbfa39 100644 --- a/dnn/src/arm_common/pooling/algo.h +++ b/dnn/src/arm_common/pooling/algo.h @@ -22,9 +22,7 @@ using AlgoBase = PoolingImpl::AlgoBase; class PoolingImpl::AlgoFilterxModexStride1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_STRIDE1"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -33,9 +31,7 @@ public: 
class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_STRIDE2"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -43,9 +39,7 @@ public: }; class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -54,9 +48,7 @@ public: class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -65,9 +57,7 @@ public: class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -76,9 +66,7 @@ public: class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -87,9 +75,7 @@ public: class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -98,9 +84,7 @@ public: class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; @@ -109,10 +93,10 @@ public: class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; - const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } + 
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; + const char* name() const override { + return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; + } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; MEGDNN_DECL_ALGO_TYPE(ARM_Filter3ModexStridexNCHW44) @@ -120,10 +104,10 @@ public: class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; - const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; + const char* name() const override { + return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; + } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; MEGDNN_DECL_ALGO_TYPE(ARM_Filter2ModexStridexNCHW44) @@ -131,10 +115,10 @@ public: class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; - const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; + const char* name() const override { + return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; + } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; MEGDNN_DECL_ALGO_TYPE(ARM_Filter4ModexStridexNCHW44) @@ -142,29 +126,27 @@ public: class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; - const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; + const char* name() const override { + return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; + } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; MEGDNN_DECL_ALGO_TYPE(ARM_Filter5ModexStridexNCHW44) }; class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; - const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; + const char* name() const override { + return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; + } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; MEGDNN_DECL_ALGO_TYPE(ARM_Fp32ModexStridexNCHW44) }; class PoolingImpl::AlgoFallback final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - }; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }; const char* name() const override { return "FALLBACK_POOLING"; } bool usable(const PoolingKernSizeParam&) const override { return true; } void exec(const PoolingKernParam&) const override {} @@ -172,14 +154,12 @@ public: }; WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); -WorkspaceBundle get_bundle_nchw44( - const PoolingImpl::PoolingKernSizeParam& param); +WorkspaceBundle 
get_bundle_nchw44(const PoolingImpl::PoolingKernSizeParam& param); -const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, - size_t& IH2, size_t& IW2, size_t PH, size_t PW, - const WorkspaceBundle& ws, bool is_max_mode); +const int8_t* handle_padding( + const int8_t* src, size_t IH, size_t IW, size_t& IH2, size_t& IW2, size_t PH, + size_t PW, const WorkspaceBundle& ws, bool is_max_mode); } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp b/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp index 18c058da..3c341c68 100644 --- a/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp +++ b/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp @@ -31,8 +31,7 @@ bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable( param.format == Param::Format::NCHW44 && (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && fh == fw && sh == sw; - bool size_ok = ((fh == 2 || fh == 3 || fh == 4 || fh == 5) && - (sh == 1 || sh == 2)); + bool size_ok = ((fh == 2 || fh == 3 || fh == 4 || fh == 5) && (sh == 1 || sh == 2)); size_ok |= ((fh == 9 || fh == 13) && (sh == 1)); return avaible && size_ok; @@ -54,34 +53,32 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( void* src_ptr = param.src_ptr; void* dst_ptr = param.dst_ptr; -#define DISPATCH_FUNC(filter, stride, mode) \ - MIDOUT_BEGIN(megdnn_arm_common_fp32_pooling_nchw44, midout_iv(0), \ - midout_iv(#filter #stride #mode##_hash)) { \ - auto run = [ih, iw, oh, ow, ph, pw, src_ptr, dst_ptr](size_t index, \ - size_t) { \ - const int c_idx = index; \ - pooling_fp32_nchw44( \ - static_cast(src_ptr) + c_idx * ih * iw * 4, \ - static_cast(dst_ptr) + c_idx * oh * ow * 4, ih, \ - iw, oh, ow, ph, pw); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast<::megdnn::naive::HandleImpl*>(param.handle), \ - n* ic, run); \ - } \ +#define DISPATCH_FUNC(filter, stride, mode) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_fp32_pooling_nchw44, midout_iv(0), \ + midout_iv(#filter #stride #mode##_hash)) { \ + auto run = [ih, iw, oh, ow, ph, pw, src_ptr, dst_ptr](size_t index, size_t) { \ + const int c_idx = index; \ + pooling_fp32_nchw44( \ + static_cast(src_ptr) + c_idx * ih * iw * 4, \ + static_cast(dst_ptr) + c_idx * oh * ow * 4, ih, iw, oh, \ + ow, ph, pw); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), n* ic, run); \ + } \ MIDOUT_END(); -#define DISPATCH_MODE(filter, stride) \ - switch (param.mode) { \ - case PoolingBase::Mode::MAX: \ - DISPATCH_FUNC(filter, stride, PoolingBase::Mode::MAX); \ - break; \ - case PoolingBase::Mode::AVERAGE: \ - DISPATCH_FUNC(filter, stride, PoolingBase::Mode::AVERAGE); \ - break; \ - default: \ - megdnn_assert(0, "invalid mode %u", \ - static_cast(param.mode)); \ +#define DISPATCH_MODE(filter, stride) \ + switch (param.mode) { \ + case PoolingBase::Mode::MAX: \ + DISPATCH_FUNC(filter, stride, PoolingBase::Mode::MAX); \ + break; \ + case PoolingBase::Mode::AVERAGE: \ + DISPATCH_FUNC(filter, stride, PoolingBase::Mode::AVERAGE); \ + break; \ + default: \ + megdnn_assert(0, "invalid mode %u", static_cast(param.mode)); \ } #define DISPATCH_STRIDE(filter) \ @@ -96,7 +93,7 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( megdnn_assert(0, "invalid stride %d", sh); \ } -#define DISPATCH_STRIDE_1(filter) \ +#define DISPATCH_STRIDE_1(filter) \ switch (sh) { \ case 1: \ DISPATCH_MODE(filter, 1); \ diff --git 
a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp index c0489df8..2de5f8e8 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.cpp @@ -11,26 +11,25 @@ #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h" #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#include #include +#include #include +#include #include "src/arm_common/simd_macro/marm_neon.h" -#include namespace megdnn { namespace arm_common { #define MEGDNN_SIMD_WIDTH 8 -void do_max_pooling_3x3_s2x2_float16_NEON(const __fp16* src, __fp16* dst, - size_t IH_, size_t IW_, size_t OH_, - size_t OW_, size_t PH_, size_t PW_, - const WorkspaceBundle& ws) { +void do_max_pooling_3x3_s2x2_float16_NEON( + const __fp16* src, __fp16* dst, size_t IH_, size_t IW_, size_t OH_, size_t OW_, + size_t PH_, size_t PW_, const WorkspaceBundle& ws) { int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; // cache[i] stores the answer of the i-th line after // pooling along the W dimension. - __fp16* cache[3] = {static_cast<__fp16*>(ws.get(0)), - static_cast<__fp16*>(ws.get(1)), - static_cast<__fp16*>(ws.get(2))}; + __fp16* cache[3] = { + static_cast<__fp16*>(ws.get(0)), static_cast<__fp16*>(ws.get(1)), + static_cast<__fp16*>(ws.get(2))}; __fp16* odd = static_cast<__fp16*>(ws.get(3)); __fp16* even = static_cast<__fp16*>(ws.get(4)); int ih_next = 0; @@ -98,8 +97,7 @@ void do_max_pooling_3x3_s2x2_float16_NEON(const __fp16* src, __fp16* dst, vst1q_f16(cache[0] + ow, d); } } else { - for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; - ow += MEGDNN_SIMD_WIDTH) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { float16x8_t d, s0, s1, s2; s0 = vld1q_f16(even + ow - (PW >> 1)); s1 = vld1q_f16(odd + ow - (PW >> 1)); @@ -129,15 +127,13 @@ void do_max_pooling_3x3_s2x2_float16_NEON(const __fp16* src, __fp16* dst, vst1q_f16(dptr + ow, d); } for (; ow < OW; ++ow) { - dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), - cache[2][ow]); + dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), cache[2][ow]); } } else { std::memcpy(dptr, cache[0], sizeof(__fp16) * OW); for (int i = 1; i < ih_to - ih_from; ++i) { int ow = 0; - for (; ow + MEGDNN_SIMD_WIDTH <= OW; - ow += MEGDNN_SIMD_WIDTH) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { float16x8_t d, s; s = vld1q_f16(cache[i] + ow); d = vld1q_f16(dptr + ow); diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h index 92eacbc4..967f61e2 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_float16.h @@ -18,10 +18,9 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_3x3_s2x2_float16_NEON(const __fp16* src, __fp16* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW, - const WorkspaceBundle& ws); +void do_max_pooling_3x3_s2x2_float16_NEON( + const __fp16* src, __fp16* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp index 7f1b7145..6da5e5c8 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.cpp @@ -10,25 
+10,24 @@ */ #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" -#include #include +#include #include +#include #include "src/arm_common/simd_macro/marm_neon.h" -#include namespace megdnn { namespace arm_common { -void do_max_pooling_3x3_s2x2_int8_NEON(const int8_t* src, int8_t* dst, - size_t IH_, size_t IW_, size_t OH_, - size_t OW_, size_t PH_, size_t PW_, - const WorkspaceBundle& ws) { +void do_max_pooling_3x3_s2x2_int8_NEON( + const int8_t* src, int8_t* dst, size_t IH_, size_t IW_, size_t OH_, size_t OW_, + size_t PH_, size_t PW_, const WorkspaceBundle& ws) { int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; // cache[i] stores the answer of the i-th line after // pooling along the W dimension. - int8_t* cache[3] = {static_cast(ws.get(0)), - static_cast(ws.get(1)), - static_cast(ws.get(2))}; + int8_t* cache[3] = { + static_cast(ws.get(0)), static_cast(ws.get(1)), + static_cast(ws.get(2))}; int8_t* odd = static_cast(ws.get(3)); int8_t* even = static_cast(ws.get(4)); @@ -127,8 +126,7 @@ void do_max_pooling_3x3_s2x2_int8_NEON(const int8_t* src, int8_t* dst, vst1q_s8(dptr + ow, d); } for (; ow < OW; ++ow) { - dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), - cache[2][ow]); + dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), cache[2][ow]); } } else { std::memcpy(dptr, cache[0], sizeof(int8_t) * OW); @@ -149,16 +147,15 @@ void do_max_pooling_3x3_s2x2_int8_NEON(const int8_t* src, int8_t* dst, } } -void do_max_pooling_3x3_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, - size_t IH_, size_t IW_, size_t OH_, - size_t OW_, size_t PH_, size_t PW_, - const WorkspaceBundle& ws) { +void do_max_pooling_3x3_s2x2_uint8_NEON( + const uint8_t* src, uint8_t* dst, size_t IH_, size_t IW_, size_t OH_, + size_t OW_, size_t PH_, size_t PW_, const WorkspaceBundle& ws) { int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; // cache[i] stores the answer of the i-th line after // pooling along the W dimension. 
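// Illustrative sketch, not part of this patch: the 3x3/stride-2 NEON kernels above split
// the work into two passes.  First each needed input row is max-pooled along W into
// cache[0..2] ("cache[i] stores the answer of the i-th line after pooling along the W
// dimension"), then the output row is the element-wise max of the three cached rows.
// This scalar version assumes PH == PW == 0 and an input large enough for OH x OW
// outputs, and it recomputes the row caches for every output row; the real kernels
// additionally reuse rows shared between adjacent output rows and keep separate
// odd/even column buffers to vectorize the stride-2 access.
#include <algorithm>
#include <cstdint>
#include <vector>

static void max_pool_3x3_s2_rowcache_sketch(
        const int8_t* src, int8_t* dst, int IH, int IW, int OH, int OW) {
    std::vector<int8_t> cache[3];
    for (auto& c : cache)
        c.resize(OW);
    // Pool one input row along W: cache_row[ow] = max(row[2*ow], row[2*ow+1], row[2*ow+2]).
    auto pool_row_w = [&](const int8_t* row, int8_t* cache_row) {
        for (int ow = 0; ow < OW; ++ow) {
            const int8_t* p = row + 2 * ow;
            cache_row[ow] = std::max(p[0], std::max(p[1], p[2]));
        }
    };
    (void)IH;  // assumed: IH >= 2 * (OH - 1) + 3 and IW >= 2 * (OW - 1) + 3
    for (int oh = 0; oh < OH; ++oh) {
        // Output row oh reads input rows 2*oh, 2*oh + 1, 2*oh + 2.
        for (int k = 0; k < 3; ++k)
            pool_row_w(src + (2 * oh + k) * IW, cache[k].data());
        // Combine the three cached rows; this is the vmaxq/vst1q loop in the NEON
        // code, done one element at a time here.
        for (int ow = 0; ow < OW; ++ow)
            dst[oh * OW + ow] =
                    std::max(std::max(cache[0][ow], cache[1][ow]), cache[2][ow]);
    }
}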
- uint8_t* cache[3] = {static_cast(ws.get(0)), - static_cast(ws.get(1)), - static_cast(ws.get(2))}; + uint8_t* cache[3] = { + static_cast(ws.get(0)), static_cast(ws.get(1)), + static_cast(ws.get(2))}; uint8_t* odd = static_cast(ws.get(3)); uint8_t* even = static_cast(ws.get(4)); @@ -257,8 +254,7 @@ void do_max_pooling_3x3_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, vst1q_u8(dptr + ow, d); } for (; ow < OW; ++ow) { - dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), - cache[2][ow]); + dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), cache[2][ow]); } } else { std::memcpy(dptr, cache[0], sizeof(uint8_t) * OW); diff --git a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h index ca834b80..3dc23913 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h +++ b/dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h @@ -16,15 +16,13 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_3x3_s2x2_int8_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW, - const WorkspaceBundle& boudle); +void do_max_pooling_3x3_s2x2_int8_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& boudle); -void do_max_pooling_3x3_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, - size_t IH, size_t IW, size_t OH, - size_t OW, size_t PH, size_t PW, - const WorkspaceBundle& boudle); +void do_max_pooling_3x3_s2x2_uint8_NEON( + const uint8_t* src, uint8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& boudle); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp index 83f61a16..4efc0ea6 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp +++ b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.cpp @@ -15,8 +15,9 @@ namespace megdnn { namespace arm_common { -void pooling_max_w2x2_s2x2(const int8_t* src, int8_t* dst, size_t N, size_t C, - size_t IH, size_t IW, size_t OH, size_t OW) { +void pooling_max_w2x2_s2x2( + const int8_t* src, int8_t* dst, size_t N, size_t C, size_t IH, size_t IW, + size_t OH, size_t OW) { for (size_t nc = 0; nc < N * C; ++nc) { for (size_t oh = 0; oh < OH; ++oh) { size_t ih = oh << 1; @@ -35,8 +36,8 @@ void pooling_max_w2x2_s2x2(const int8_t* src, int8_t* dst, size_t N, size_t C, dptr += 8; } for (; ow < OW; ++ow) { - dptr[0] = std::max(std::max(sptr0[0], sptr0[1]), - std::max(sptr1[0], sptr1[1])); + dptr[0] = std::max( + std::max(sptr0[0], sptr0[1]), std::max(sptr1[0], sptr1[1])); sptr0 += 2; sptr1 += 2; dptr += 1; @@ -49,4 +50,3 @@ void pooling_max_w2x2_s2x2(const int8_t* src, int8_t* dst, size_t N, size_t C, } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h index cc7d8197..c798be38 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h +++ b/dnn/src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h @@ -14,10 +14,10 @@ namespace megdnn { namespace arm_common { -void pooling_max_w2x2_s2x2(const int8_t* src, int8_t* dst, size_t N, size_t C, - size_t IH, size_t IW, size_t OH, size_t OW); +void pooling_max_w2x2_s2x2( + const int8_t* src, int8_t* dst, size_t N, size_t C, size_t IH, size_t IW, + size_t OH, size_t OW); } 
// namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp index e74dfd8b..5cac9ccd 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp +++ b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.cpp @@ -15,11 +15,9 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_w4x4_s2x2_float_NEON(const dt_float32* src, dt_float32* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW) { +void do_max_pooling_w4x4_s2x2_float_NEON( + const dt_float32* src, dt_float32* dst, DType src_dtype, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { const int window = 4; const int stride = 2; using Pooler = MaxPooler<16, dt_float32, float, float>; @@ -27,15 +25,17 @@ void do_max_pooling_w4x4_s2x2_float_NEON(const dt_float32* src, dt_float32* dst, for (; oh < OH && -PH + stride * oh < 0; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { int ow = 0; for (; ow < OW && -PW + stride * ow < 0; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } dt_float32 last_hf_res = -std::numeric_limits::infinity(); int ih = -PH + stride * oh, iw = -PW + stride * ow; @@ -46,8 +46,7 @@ void do_max_pooling_w4x4_s2x2_float_NEON(const dt_float32* src, dt_float32* dst, i3 = vld1q_f32(src + (ih + 3) * IW + iw); float32x4_t sum0 = vmaxq_f32(vmaxq_f32(i0, i1), vmaxq_f32(i2, i3)); float32x2_t t = vpmax_f32(vget_low_f32(sum0), vget_high_f32(sum0)); - dst[oh * OW + ow] = - std::max(vget_lane_f32(t, 0), vget_lane_f32(t, 1)); + dst[oh * OW + ow] = std::max(vget_lane_f32(t, 0), vget_lane_f32(t, 1)); last_hf_res = vget_lane_f32(t, 1); ow += 1; } @@ -60,29 +59,28 @@ void do_max_pooling_w4x4_s2x2_float_NEON(const dt_float32* src, dt_float32* dst, float32x4_t sum0 = vmaxq_f32(vmaxq_f32(i0, i1), vmaxq_f32(i2, i3)); float32x2_t t = vpmax_f32(vget_low_f32(sum0), vget_high_f32(sum0)); dst[oh * OW + ow + 0] = std::max(vget_lane_f32(t, 0), last_hf_res); - dst[oh * OW + ow + 1] = - std::max(vget_lane_f32(t, 0), vget_lane_f32(t, 1)); + dst[oh * OW + ow + 1] = std::max(vget_lane_f32(t, 0), vget_lane_f32(t, 1)); last_hf_res = vget_lane_f32(t, 1); } for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } } -void do_max_pooling_w4x4_s2x2_int8_NEON(const int8_t* src, int8_t* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW) { +void do_max_pooling_w4x4_s2x2_int8_NEON( + const int8_t* src, int8_t* dst, DType src_dtype, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW) { const int window = 4; const int stride = 2; using Pooler = MaxPooler<16, dt_qint8, 
int8_t, float>; @@ -90,15 +88,17 @@ void do_max_pooling_w4x4_s2x2_int8_NEON(const int8_t* src, int8_t* dst, for (; oh < OH && -PH + stride * oh < 0; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { int ow = 0; for (; ow < OW && -PW + stride * ow < 0; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } int8_t last_res = std::numeric_limits::lowest(); int ih = -PH + stride * oh, iw = -PW + stride * ow; @@ -109,9 +109,8 @@ void do_max_pooling_w4x4_s2x2_int8_NEON(const int8_t* src, int8_t* dst, i3 = vld1q_s8(src + (ih + 3) * IW + iw); int8x16_t sum0 = vmaxq_s8(vmaxq_s8(i0, i1), vmaxq_s8(i2, i3)); int8x8_t t = vpmax_s8(vget_low_s8(sum0), vget_high_s8(sum0)); -#define cb(i) \ - dst[oh * OW + ow + i] = \ - std::max(vget_lane_s8(t, i), vget_lane_s8(t, i + 1)); +#define cb(i) \ + dst[oh * OW + ow + i] = std::max(vget_lane_s8(t, i), vget_lane_s8(t, i + 1)); UNROLL_CALL_NOWRAPPER(7, cb) #undef cb last_res = vget_lane_s8(t, 7); @@ -126,32 +125,31 @@ void do_max_pooling_w4x4_s2x2_int8_NEON(const int8_t* src, int8_t* dst, int8x16_t sum0 = vmaxq_s8(vmaxq_s8(i0, i1), vmaxq_s8(i2, i3)); int8x8_t t = vpmax_s8(vget_low_s8(sum0), vget_high_s8(sum0)); dst[oh * OW + ow + 0] = std::max(vget_lane_s8(t, 0), last_res); -#define cb(i) \ - dst[oh * OW + ow + i + 1] = \ - std::max(vget_lane_s8(t, i), vget_lane_s8(t, i + 1)); +#define cb(i) \ + dst[oh * OW + ow + i + 1] = std::max(vget_lane_s8(t, i), vget_lane_s8(t, i + 1)); UNROLL_CALL_NOWRAPPER(7, cb) #undef cb last_res = vget_lane_s8(t, 7); } for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } } -void do_max_pooling_w4x4_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW) { +void do_max_pooling_w4x4_s2x2_uint8_NEON( + const uint8_t* src, uint8_t* dst, DType src_dtype, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW) { const int window = 4; const int stride = 2; using Pooler = MaxPooler<16, dt_quint8, uint8_t, float>; @@ -159,15 +157,17 @@ void do_max_pooling_w4x4_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, for (; oh < OH && -PH + stride * oh < 0; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { int ow = 0; for (; ow < OW && -PW + stride * ow < 0; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } uint8_t last_res = std::numeric_limits::lowest(); int ih = -PH + stride * oh, iw = -PW + stride * ow; @@ 
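
The w4x4_s2x2 kernels in this file (float, int8 and uint8 variants) share one reduction pattern: take the vertical max of the four window rows with vmaxq, fold adjacent columns with vpmax so each load yields per-pair (half-window) maxima, and carry the right-most half into the next iteration (last_hf_res / last_res in the surrounding hunks). A compilable sketch of the float32 form of that reduction; the helper name is illustrative and not part of the patch:

#include <arm_neon.h>

// Vertical max of four rows of a 4-wide window, then pairwise horizontal max:
// result lane 0 = max(col0, col1), lane 1 = max(col2, col3).
static inline float32x2_t vertical_then_pairwise_max_f32(
        const float* r0, const float* r1, const float* r2, const float* r3) {
    float32x4_t m = vmaxq_f32(
            vmaxq_f32(vld1q_f32(r0), vld1q_f32(r1)),
            vmaxq_f32(vld1q_f32(r2), vld1q_f32(r3)));
    return vpmax_f32(vget_low_f32(m), vget_high_f32(m));
}

// Inside the stride-2 loop, with h = vertical_then_pairwise_max_f32(...) and
// last_half carried from the previous iteration, two outputs come from one load:
//     dst[ow]     = max(vget_lane_f32(h, 0), last_half);
//     dst[ow + 1] = max(vget_lane_f32(h, 0), vget_lane_f32(h, 1));
//     last_half   = vget_lane_f32(h, 1);
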
-178,9 +178,8 @@ void do_max_pooling_w4x4_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, i3 = vld1q_u8(src + (ih + 3) * IW + iw); uint8x16_t sum0 = vmaxq_u8(vmaxq_u8(i0, i1), vmaxq_u8(i2, i3)); uint8x8_t t = vpmax_u8(vget_low_u8(sum0), vget_high_u8(sum0)); -#define cb(i) \ - dst[oh * OW + ow + i] = \ - std::max(vget_lane_u8(t, i), vget_lane_u8(t, i + 1)); +#define cb(i) \ + dst[oh * OW + ow + i] = std::max(vget_lane_u8(t, i), vget_lane_u8(t, i + 1)); UNROLL_CALL_NOWRAPPER(7, cb) #undef cb last_res = vget_lane_u8(t, 7); @@ -195,32 +194,31 @@ void do_max_pooling_w4x4_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, uint8x16_t sum0 = vmaxq_u8(vmaxq_u8(i0, i1), vmaxq_u8(i2, i3)); uint8x8_t t = vpmax_u8(vget_low_u8(sum0), vget_high_u8(sum0)); dst[oh * OW + ow + 0] = std::max(vget_lane_u8(t, 0), last_res); -#define cb(i) \ - dst[oh * OW + ow + i + 1] = \ - std::max(vget_lane_u8(t, i), vget_lane_u8(t, i + 1)); +#define cb(i) \ + dst[oh * OW + ow + i + 1] = std::max(vget_lane_u8(t, i), vget_lane_u8(t, i + 1)); UNROLL_CALL_NOWRAPPER(7, cb) #undef cb last_res = vget_lane_u8(t, 7); } for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW) { +void do_max_pooling_w4x4_s2x2_float16_NEON( + const __fp16* src, __fp16* dst, DType src_dtype, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW) { const int window = 4; const int stride = 2; using Pooler = MaxPooler<16, dt_float16, __fp16, __fp16>; @@ -228,15 +226,17 @@ void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, for (; oh < OH && -PH + stride * oh < 0; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { int ow = 0; for (; ow < OW && -PW + stride * ow < 0; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } __fp16 last_hf_res = -std::numeric_limits::infinity(); int ih = -PH + stride * oh, iw = -PW + stride * ow; @@ -247,12 +247,9 @@ void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, i3 = vld1q_f16(src + (ih + 3) * IW + iw); float16x8_t sum0 = vmaxq_f16(vmaxq_f16(i0, i1), vmaxq_f16(i2, i3)); float16x4_t t = vpmax_f16(vget_low_f16(sum0), vget_high_f16(sum0)); - dst[oh * OW + ow] = - std::max(vget_lane_f16(t, 0), vget_lane_f16(t, 1)); - dst[oh * OW + ow + 1] = - std::max(vget_lane_f16(t, 1), vget_lane_f16(t, 2)); - dst[oh * OW + ow + 2] = - std::max(vget_lane_f16(t, 2), vget_lane_f16(t, 3)); + dst[oh * OW + ow] = std::max(vget_lane_f16(t, 0), vget_lane_f16(t, 1)); + dst[oh * OW + ow + 1] = std::max(vget_lane_f16(t, 1), vget_lane_f16(t, 2)); + dst[oh * OW + ow + 2] = std::max(vget_lane_f16(t, 2), 
vget_lane_f16(t, 3)); last_hf_res = vget_lane_f16(t, 3); ow += 3; } @@ -265,24 +262,23 @@ void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, float16x8_t sum0 = vmaxq_f16(vmaxq_f16(i0, i1), vmaxq_f16(i2, i3)); float16x4_t t = vpmax_f16(vget_low_f16(sum0), vget_high_f16(sum0)); dst[oh * OW + ow + 0] = std::max(vget_lane_f16(t, 0), last_hf_res); - dst[oh * OW + ow + 1] = - std::max(vget_lane_f16(t, 0), vget_lane_f16(t, 1)); - dst[oh * OW + ow + 2] = - std::max(vget_lane_f16(t, 1), vget_lane_f16(t, 2)); - dst[oh * OW + ow + 3] = - std::max(vget_lane_f16(t, 2), vget_lane_f16(t, 3)); + dst[oh * OW + ow + 1] = std::max(vget_lane_f16(t, 0), vget_lane_f16(t, 1)); + dst[oh * OW + ow + 2] = std::max(vget_lane_f16(t, 1), vget_lane_f16(t, 2)); + dst[oh * OW + ow + 3] = std::max(vget_lane_f16(t, 2), vget_lane_f16(t, 3)); last_hf_res = vget_lane_f16(t, 3); } for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } } @@ -290,4 +286,3 @@ void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h index 2a80c52b..396ea047 100644 --- a/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h +++ b/dnn/src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h @@ -15,30 +15,21 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_w4x4_s2x2_float_NEON(const dt_float32* src, dt_float32* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW); -void do_max_pooling_w4x4_s2x2_int8_NEON(const int8_t* src, int8_t* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW); -void do_max_pooling_w4x4_s2x2_uint8_NEON(const uint8_t* src, uint8_t* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW); +void do_max_pooling_w4x4_s2x2_float_NEON( + const dt_float32* src, dt_float32* dst, DType src_dtype, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW); +void do_max_pooling_w4x4_s2x2_int8_NEON( + const int8_t* src, int8_t* dst, DType src_dtype, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW); +void do_max_pooling_w4x4_s2x2_uint8_NEON( + const uint8_t* src, uint8_t* dst, DType src_dtype, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void do_max_pooling_w4x4_s2x2_float16_NEON(const __fp16* src, __fp16* dst, - DType src_dtype, const int IH, - const int IW, const int OH, - const int OW, const int PH, - const int PW); +void do_max_pooling_w4x4_s2x2_float16_NEON( + const __fp16* src, __fp16* dst, DType src_dtype, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW); #endif } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp 
b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp index a4f0cf69..a2cec33c 100644 --- a/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp @@ -17,11 +17,9 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_2x2_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -63,11 +61,9 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } } -void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_2x2_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -81,21 +77,21 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, for (; ow + 3 < OW; ow += 4) { int8x16_t src00 = vld1q_s8(sptr0); int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); - int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), - vreinterpretq_s32_s8(src04)); + int32x4x2_t src_tmp = + vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); int32x4_t src0246 = src_tmp.val[0]; int32x4_t src1357 = src_tmp.val[1]; - int8x16_t max0 = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); + int8x16_t max0 = vmaxq_s8( + vreinterpretq_s8_s32(src0246), vreinterpretq_s8_s32(src1357)); src00 = vld1q_s8(sptr1); src04 = vld1q_s8(sptr1 + 4 * 4); - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), - vreinterpretq_s32_s8(src04)); + src_tmp = + vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); src0246 = src_tmp.val[0]; src1357 = src_tmp.val[1]; - int8x16_t max1 = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); + int8x16_t max1 = vmaxq_s8( + vreinterpretq_s8_s32(src0246), vreinterpretq_s8_s32(src1357)); int8x16_t max_out = vmaxq_s8(max0, max1); @@ -120,11 +116,9 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } -void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_2x2_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 4; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -159,18 +153,14 @@ void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? 
(vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); @@ -197,9 +187,8 @@ void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int16x8_t src00 = vmovl_s8(src001); int16x8_t src10 = vmovl_s8(src101); int16x8_t max_tmp = vaddq_s16(src00, src10); -#define do_acc(i) \ - int16_t sum##i = \ - vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); #define do_avg(i) \ sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ : (sum##i - filter_size / 2) / filter_size; @@ -217,11 +206,9 @@ void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } -void do_avg_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_2x2_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 4; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -241,20 +228,19 @@ void do_avg_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int16x8_t sum01 = vdupq_n_s16(0); int16x8_t sum23 = vdupq_n_s16(0); -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ - src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ - src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ - src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ - sum01 = vaddq_s16(sum01, src02); \ - sum01 = vaddq_s16(sum01, src13); \ - sum23 = vaddq_s16(sum23, src46); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum23 = vaddq_s16(sum23, src46); \ sum23 = vaddq_s16(sum23, src57); UNROLL_CALL_NOWRAPPER(2, CACULATE_ROW) @@ -262,18 +248,14 @@ void do_avg_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? 
(vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); @@ -300,9 +282,8 @@ void do_avg_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int16x8_t src10 = vmovl_s8(src101); int16x8_t max_tmp = vaddq_s16(src00, src10); -#define do_acc(i) \ - int16_t sum##i = \ - vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); +#define do_acc(i) \ + int16_t sum##i = vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4); #define do_avg(i) \ sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \ : (sum##i - filter_size / 2) / filter_size; diff --git a/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp index e8975e4a..f5bafa9f 100644 --- a/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp @@ -17,11 +17,9 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_3x3_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -76,9 +74,9 @@ void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int8x8_t max12_tmp = vmax_s8(src012, src112); max12_tmp = vmax_s8(max12_tmp, src212); -#define cb(i) \ - int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ - max12_tmp[i + 4]); +#define cb(i) \ + int8_t dst##i = \ + std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), max12_tmp[i + 4]); #define store(i) *(dptr + i) = dst##i; UNROLL_CALL_NOWRAPPER(4, cb) UNROLL_CALL_NOWRAPPER(4, store) @@ -92,11 +90,9 @@ void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } -void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_3x3_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -111,44 +107,44 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, for (; ow + 3 < OW; ow += 4) { int8x16_t src00 = vld1q_s8(sptr0); int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); - int32x4_t src08 = vld1q_dup_s32( - reinterpret_cast(sptr0 + 4 * 8)); - int32x4x2_t 
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), - vreinterpretq_s32_s8(src04)); + int32x4_t src08 = + vld1q_dup_s32(reinterpret_cast(sptr0 + 4 * 8)); + int32x4x2_t src_tmp = + vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); int32x4_t src0246 = src_tmp.val[0]; int32x4_t src1357 = src_tmp.val[1]; int32x4_t src2468 = vextq_s32(src0246, src08, 1); - int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); + int8x16_t max_tmp = vmaxq_s8( + vreinterpretq_s8_s32(src0246), vreinterpretq_s8_s32(src1357)); int8x16_t max0 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); int8x16_t src10 = vld1q_s8(sptr1); int8x16_t src14 = vld1q_s8(sptr1 + 4 * 4); - int32x4_t src18 = vld1q_dup_s32( - reinterpret_cast(sptr1 + 4 * 8)); + int32x4_t src18 = + vld1q_dup_s32(reinterpret_cast(sptr1 + 4 * 8)); - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src10), - vreinterpretq_s32_s8(src14)); + src_tmp = + vuzpq_s32(vreinterpretq_s32_s8(src10), vreinterpretq_s32_s8(src14)); src0246 = src_tmp.val[0]; src1357 = src_tmp.val[1]; src2468 = vextq_s32(src0246, src18, 1); - max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); + max_tmp = vmaxq_s8( + vreinterpretq_s8_s32(src0246), vreinterpretq_s8_s32(src1357)); int8x16_t max1 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); int8x16_t src20 = vld1q_s8(sptr2); int8x16_t src24 = vld1q_s8(sptr2 + 4 * 4); - int32x4_t src28 = vld1q_dup_s32( - reinterpret_cast(sptr2 + 4 * 8)); + int32x4_t src28 = + vld1q_dup_s32(reinterpret_cast(sptr2 + 4 * 8)); - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src20), - vreinterpretq_s32_s8(src24)); + src_tmp = + vuzpq_s32(vreinterpretq_s32_s8(src20), vreinterpretq_s32_s8(src24)); src0246 = src_tmp.val[0]; src1357 = src_tmp.val[1]; src2468 = vextq_s32(src0246, src28, 1); - max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), - vreinterpretq_s8_s32(src1357)); + max_tmp = vmaxq_s8( + vreinterpretq_s8_s32(src0246), vreinterpretq_s8_s32(src1357)); int8x16_t max2 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); max_tmp = vmaxq_s8(max0, max1); int8x16_t max_out = vmaxq_s8(max_tmp, max2); @@ -174,9 +170,9 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int8x8_t max12_tmp = vmax_s8(src012, src112); max12_tmp = vmax_s8(max12_tmp, src212); -#define cb(i) \ - int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \ - max12_tmp[i + 4]); +#define cb(i) \ + int8_t dst##i = \ + std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), max12_tmp[i + 4]); #define store(i) *(dptr + i) = dst##i; UNROLL_CALL_NOWRAPPER(4, cb) UNROLL_CALL_NOWRAPPER(4, store) @@ -190,11 +186,9 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } -void do_avg_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_3x3_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 9; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -234,18 +228,14 @@ void do_avg_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? 
(vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); @@ -303,11 +293,9 @@ void do_avg_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } } -void do_avg_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_3x3_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 9; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -328,26 +316,25 @@ void do_avg_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int16x8_t sum01 = vdupq_n_s16(0); int16x8_t sum23 = vdupq_n_s16(0); -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src2468 = vextq_s32(src0246, src08, 1); \ - src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ - src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ - src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ - src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ - src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ - src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ - sum01 = vaddq_s16(sum01, src02); \ - sum01 = vaddq_s16(sum01, src13); \ - sum01 = vaddq_s16(sum01, src24); \ - sum23 = vaddq_s16(sum23, src46); \ - sum23 = vaddq_s16(sum23, src57); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ + src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum01 = vaddq_s16(sum01, src24); \ + sum23 = vaddq_s16(sum23, src46); \ + sum23 = vaddq_s16(sum23, src57); \ sum23 = vaddq_s16(sum23, src68); UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW) @@ -355,18 +342,14 @@ void 
do_avg_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); diff --git a/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp index f2facbf9..be340839 100644 --- a/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.cpp @@ -17,11 +17,9 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_4x4_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -35,21 +33,17 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { - int8x16_t src00, src04, max_out, max_tmp0, max_tmp1, max_tmp2, - max_tmp3; + int8x16_t src00, src04, max_out, max_tmp0, max_tmp1, max_tmp2, max_tmp3; int32x4_t src1234, src2345, src3456; -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src1234 = vextq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04), 1); \ - src2345 = vextq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04), 2); \ - src3456 = vextq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04), 3); \ - max_tmp##i = vmaxq_s8(src00, vreinterpretq_s8_s32(src1234)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2345)); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src1234 = vextq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04), 1); \ + src2345 = vextq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04), 2); \ + src3456 = vextq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04), 3); \ + max_tmp##i = vmaxq_s8(src00, vreinterpretq_s8_s32(src1234)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2345)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3456)); UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) @@ -92,11 +86,9 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } } -void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t 
IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_4x4_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -113,20 +105,19 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3; int32x4_t src0246, src1357, src2468, src3579, src08, src09; int32x4x2_t src_tmp; -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ - src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src2468 = vextq_s32(src0246, src08, 1); \ - src3579 = vextq_s32(src1357, src09, 1); \ - max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ - vreinterpretq_s8_s32(src1357)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + max_tmp##i = \ + vmaxq_s8(vreinterpretq_s8_s32(src0246), vreinterpretq_s8_s32(src1357)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) @@ -171,11 +162,9 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } -void do_avg_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_4x4_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 16; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -215,18 +204,14 @@ void do_avg_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? 
(vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); @@ -285,11 +270,9 @@ void do_avg_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } } -void do_avg_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_4x4_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 16; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -311,32 +294,31 @@ void do_avg_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int16x8_t sum01 = vdupq_n_s16(0); int16x8_t sum23 = vdupq_n_s16(0); -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ - src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src2468 = vextq_s32(src0246, src08, 1); \ - src3579 = vextq_s32(src1357, src09, 1); \ - src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ - src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ - src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ - src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ - src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ - src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ - src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ - src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ - sum01 = vaddq_s16(sum01, src02); \ - sum01 = vaddq_s16(sum01, src13); \ - sum01 = vaddq_s16(sum01, src24); \ - sum01 = vaddq_s16(sum01, src35); \ - sum23 = vaddq_s16(sum23, src46); \ - sum23 = vaddq_s16(sum23, src57); \ - sum23 = vaddq_s16(sum23, src68); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ + src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ + src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ + src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum01 = vaddq_s16(sum01, src24); \ + sum01 = vaddq_s16(sum01, src35); \ + sum23 = vaddq_s16(sum23, src46); \ + sum23 = vaddq_s16(sum23, src57); \ + sum23 = vaddq_s16(sum23, src68); \ sum23 = vaddq_s16(sum23, src79); UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) @@ -344,18 +326,14 @@ void 
do_avg_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); diff --git a/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp index 511f1a11..0a07ca19 100644 --- a/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp +++ b/dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp @@ -17,11 +17,9 @@ namespace megdnn { namespace arm_common { -void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_5x5_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -36,22 +34,19 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { - int8x16_t src00, src04, max_out, max_tmp0, max_tmp1, max_tmp2, - max_tmp3, max_tmp4; + int8x16_t src00, src04, max_out, max_tmp0, max_tmp1, max_tmp2, max_tmp3, + max_tmp4; int32x4_t src1234, src2345, src3456; -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src1234 = vextq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04), 1); \ - src2345 = vextq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04), 2); \ - src3456 = vextq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04), 3); \ - max_tmp##i = vmaxq_s8(src00, vreinterpretq_s8_s32(src1234)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2345)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3456)); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src1234 = vextq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04), 1); \ + src2345 = vextq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04), 2); \ + src3456 = vextq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04), 3); \ + max_tmp##i = vmaxq_s8(src00, vreinterpretq_s8_s32(src1234)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2345)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3456)); \ max_tmp##i = vmaxq_s8(max_tmp##i, src04); UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) @@ -89,8 +84,8 @@ void 
do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int32x2_t src##i##_45 = \ vld1_dup_s32(reinterpret_cast(sptr##i + 4 * 4)); UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) - int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45), - vreinterpret_s8_s32(src1_45)); + int8x8_t max_45 = + vmax_s8(vreinterpret_s8_s32(src0_45), vreinterpret_s8_s32(src1_45)); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45)); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45)); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45)); @@ -111,11 +106,9 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } } -void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_max_pooling_5x5_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { const int8_t* sptr = nullptr; size_t IH2, IW2; sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true); @@ -130,28 +123,25 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int8_t* __restrict dptr = dst + oh * OW * 4; size_t ow = 0; for (; ow + 3 < OW; ow += 4) { - int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3, - max_tmp4; - int32x4_t src0246, src1357, src2468, src3579, src46810, src10, - src09, src08; + int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3, max_tmp4; + int32x4_t src0246, src1357, src2468, src3579, src46810, src10, src09, src08; int32x4x2_t src_tmp; -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ - src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ - src10 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 10)); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src2468 = vextq_s32(src0246, src08, 1); \ - src3579 = vextq_s32(src1357, src09, 1); \ - src46810 = vextq_s32(src2468, src10, 1); \ - max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ - vreinterpretq_s8_s32(src1357)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ - max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src10 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 10)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + src46810 = vextq_s32(src2468, src10, 1); \ + max_tmp##i = \ + vmaxq_s8(vreinterpretq_s8_s32(src0246), vreinterpretq_s8_s32(src1357)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ + max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810)); UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) @@ -190,8 +180,8 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, int32x2_t src##i##_45 = \ vld1_dup_s32(reinterpret_cast(sptr##i + 4 * 4)); UNROLL_CALL_NOWRAPPER(5, 
COMPARE_SRC45) - int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45), - vreinterpret_s8_s32(src1_45)); + int8x8_t max_45 = + vmax_s8(vreinterpret_s8_s32(src0_45), vreinterpret_s8_s32(src1_45)); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45)); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45)); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45)); @@ -213,11 +203,9 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } -void do_avg_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_5x5_stride1_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 25; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -261,18 +249,14 @@ void do_avg_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); @@ -342,11 +326,9 @@ void do_avg_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, } } -void do_avg_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, - size_t IH, size_t IW, - size_t OH, size_t OW, - size_t PH, size_t PW, - const WorkspaceBundle& ws) { +void do_avg_pooling_5x5_stride2_int8_nchw44_NEON( + const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t PH, size_t PW, const WorkspaceBundle& ws) { int16_t filter_size = 25; const int8_t* sptr = nullptr; size_t IH2, IW2; @@ -364,45 +346,42 @@ void do_avg_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, for (; ow + 3 < OW; ow += 4) { int32x4x2_t src_tmp; int8x16_t src00, src04; - int16x8_t src02, src13, src57, src24, src68, src35, src79, src46, - src810; - int32x4_t src08, src09, src10, src0246, src1357, src2468, src3579, - src46810; + int16x8_t src02, src13, src57, src24, src68, src35, src79, src46, src810; + int32x4_t src08, src09, src10, src0246, src1357, src2468, src3579, src46810; int16x8_t sum01 = vdupq_n_s16(0); int16x8_t sum23 = vdupq_n_s16(0); -#define CACULATE_ROW(i) \ - src00 = vld1q_s8(sptr##i); \ - src04 = vld1q_s8(sptr##i + 4 * 4); \ - src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ - src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ - src10 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 10)); \ - src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ - vreinterpretq_s32_s8(src04)); \ - src0246 = src_tmp.val[0]; \ - src1357 = src_tmp.val[1]; \ - src2468 = 
vextq_s32(src0246, src08, 1); \ - src3579 = vextq_s32(src1357, src09, 1); \ - src46810 = vextq_s32(src2468, src10, 1); \ - src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ - src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ - src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ - src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ - src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ - src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ - src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ - src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ - src46 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src46810))); \ - src810 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src46810))); \ - sum01 = vaddq_s16(sum01, src02); \ - sum01 = vaddq_s16(sum01, src13); \ - sum01 = vaddq_s16(sum01, src24); \ - sum01 = vaddq_s16(sum01, src35); \ - sum01 = vaddq_s16(sum01, src46); \ - sum23 = vaddq_s16(sum23, src46); \ - sum23 = vaddq_s16(sum23, src57); \ - sum23 = vaddq_s16(sum23, src68); \ - sum23 = vaddq_s16(sum23, src79); \ +#define CACULATE_ROW(i) \ + src00 = vld1q_s8(sptr##i); \ + src04 = vld1q_s8(sptr##i + 4 * 4); \ + src08 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 8)); \ + src09 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 9)); \ + src10 = vld1q_dup_s32(reinterpret_cast(sptr##i + 4 * 10)); \ + src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), vreinterpretq_s32_s8(src04)); \ + src0246 = src_tmp.val[0]; \ + src1357 = src_tmp.val[1]; \ + src2468 = vextq_s32(src0246, src08, 1); \ + src3579 = vextq_s32(src1357, src09, 1); \ + src46810 = vextq_s32(src2468, src10, 1); \ + src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \ + src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \ + src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \ + src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \ + src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \ + src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \ + src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \ + src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \ + src46 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src46810))); \ + src810 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src46810))); \ + sum01 = vaddq_s16(sum01, src02); \ + sum01 = vaddq_s16(sum01, src13); \ + sum01 = vaddq_s16(sum01, src24); \ + sum01 = vaddq_s16(sum01, src35); \ + sum01 = vaddq_s16(sum01, src46); \ + sum23 = vaddq_s16(sum23, src46); \ + sum23 = vaddq_s16(sum23, src57); \ + sum23 = vaddq_s16(sum23, src68); \ + sum23 = vaddq_s16(sum23, src79); \ sum23 = vaddq_s16(sum23, src810); UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) @@ -410,18 +389,14 @@ void do_avg_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, #define sum_define(i) int16_t sum##i; UNROLL_CALL_NOWRAPPER(8, sum_define) -#define sum01_avg(i) \ - sum##i = vgetq_lane_s16(sum01, i) > 0 \ - ? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum01, i) - filter_size / 2) / \ - filter_size; -#define sum23_avg(i) \ - sum##i = vgetq_lane_s16(sum23, i) > 0 \ - ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \ - filter_size \ - : (vgetq_lane_s16(sum23, i) - filter_size / 2) / \ - filter_size; +#define sum01_avg(i) \ + sum##i = vgetq_lane_s16(sum01, i) > 0 \ + ? 
(vgetq_lane_s16(sum01, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum01, i) - filter_size / 2) / filter_size; +#define sum23_avg(i) \ + sum##i = vgetq_lane_s16(sum23, i) > 0 \ + ? (vgetq_lane_s16(sum23, i) + filter_size / 2) / filter_size \ + : (vgetq_lane_s16(sum23, i) - filter_size / 2) / filter_size; #define store_sum01(i) *(dptr + i) = static_cast(sum##i); #define store_sum23(i) *(dptr + i + 8) = static_cast(sum##i); diff --git a/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h b/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h index 4a6ad128..b03942e3 100644 --- a/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h +++ b/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h @@ -21,14 +21,16 @@ namespace megdnn { namespace arm_common { namespace { -template +template < + int filter, int stride, int ow_step, PoolingBase::Mode mode, typename T1, + typename T2> struct CalXsXNchw44 { static void impl(T1 result, T2 src); }; -template +template < + int filter, int stride, int ow_step, PoolingBase::Mode mode, typename T1, + typename T2> void calculate_xsx_nchw44(T1 result, T2 src) { CalXsXNchw44::impl(result, src); }; @@ -45,19 +47,18 @@ void calculate_xsx_nchw44(T1 result, T2 src) { result[2] = vaddq_f32(result[2], src[2 * stride + step]); \ result[3] = vaddq_f32(result[3], src[3 * stride + step]); -#define INSTANCE_CAL(filter) \ - template \ - struct CalXsXNchw44 { \ - static void impl(T1 result, T2 src) { \ - UNROLL_CALL_RAW(filter, CALCULATE_MAX_CB); \ - } \ - }; \ - template \ - struct CalXsXNchw44 { \ - static void impl(T1 result, T2 src) { \ - UNROLL_CALL_RAW(filter, CALCULATE_AVG_CB); \ - } \ +#define INSTANCE_CAL(filter) \ + template \ + struct CalXsXNchw44 { \ + static void impl(T1 result, T2 src) { \ + UNROLL_CALL_RAW(filter, CALCULATE_MAX_CB); \ + } \ + }; \ + template \ + struct CalXsXNchw44 { \ + static void impl(T1 result, T2 src) { \ + UNROLL_CALL_RAW(filter, CALCULATE_AVG_CB); \ + } \ }; INSTANCE_CAL(2) @@ -77,8 +78,7 @@ struct KerPoolingFilterXStrideXNchw44 { }; template -struct KerPoolingFilterXStrideXNchw44 { +struct KerPoolingFilterXStrideXNchw44 { static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw) { constexpr int src_reg_size = ow_step * stride + filter - stride; constexpr int packed_ic = 4; @@ -95,8 +95,8 @@ struct KerPoolingFilterXStrideXNchw44( src, src_ptr + fh_idx * iw * packed_ic, 0); - calculate_xsx_nchw44(result, src); + calculate_xsx_nchw44( + result, src); } vst1q_f32(dst_ptr + 0 * packed_ic, result[0]); @@ -107,8 +107,8 @@ struct KerPoolingFilterXStrideXNchw44 -struct KerPoolingFilterXStrideXNchw44 { +struct KerPoolingFilterXStrideXNchw44< + filter, stride, ow_step, PoolingBase::Mode::AVERAGE> { static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw) { constexpr int src_reg_size = ow_step * stride + filter - stride; constexpr int packed_ic = 4; @@ -127,8 +127,8 @@ struct KerPoolingFilterXStrideXNchw44( src, src_ptr + fh_idx * iw * packed_ic, 0); - calculate_xsx_nchw44(result, src); + calculate_xsx_nchw44( + result, src); } result[0] = vmulq_f32(result[0], div_filter_size_vec); result[1] = vmulq_f32(result[1], div_filter_size_vec); @@ -142,23 +142,22 @@ struct KerPoolingFilterXStrideXNchw44 -void ker_pooling_nchw44_remain_pad(const float32_t* src_ptr, float32_t* dst_ptr, - const int iw, const int pad_top, - const int pad_bottom, const int pad_left, - const int pad_right, const int filter); +void ker_pooling_nchw44_remain_pad( + const float32_t* src_ptr, float32_t* dst_ptr, const int iw, const 
int pad_top, + const int pad_bottom, const int pad_left, const int pad_right, + const int filter); template <> void ker_pooling_nchw44_remain_pad( - const float32_t* src_ptr, float32_t* dst_ptr, const int iw, - const int pad_top, const int pad_bottom, const int pad_left, - const int pad_right, const int filter) { + const float32_t* src_ptr, float32_t* dst_ptr, const int iw, const int pad_top, + const int pad_bottom, const int pad_left, const int pad_right, + const int filter) { constexpr int ic_step = 4; const int ih_end = filter - pad_bottom; const int iw_end = filter - pad_right; float32x4_t result = vdupq_n_f32(std::numeric_limits::lowest()); for (int ih_idx = pad_top; ih_idx < ih_end; ++ih_idx) { for (int iw_idx = pad_left; iw_idx < iw_end; ++iw_idx) { - float32x4_t src = - vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step); + float32x4_t src = vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step); result = vmaxq_f32(result, src); } src_ptr += iw * ic_step; @@ -168,9 +167,9 @@ void ker_pooling_nchw44_remain_pad( template <> void ker_pooling_nchw44_remain_pad( - const float32_t* src_ptr, float32_t* dst_ptr, const int iw, - const int pad_top, const int pad_bottom, const int pad_left, - const int pad_right, const int filter) { + const float32_t* src_ptr, float32_t* dst_ptr, const int iw, const int pad_top, + const int pad_bottom, const int pad_left, const int pad_right, + const int filter) { constexpr int ic_step = 4; const int ih_end = filter - pad_bottom; const int iw_end = filter - pad_right; @@ -180,8 +179,7 @@ void ker_pooling_nchw44_remain_pad( for (int ih_idx = pad_top; ih_idx < ih_end; ++ih_idx) { for (int iw_idx = pad_left; iw_idx < iw_end; ++iw_idx) { - float32x4_t src = - vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step); + float32x4_t src = vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step); result = vaddq_f32(result, src); } src_ptr += iw * ic_step; @@ -192,10 +190,10 @@ void ker_pooling_nchw44_remain_pad( template static inline void kern_pooling_with_pad_nchw44( - const float32_t* src, float32_t* dst, const int filter, - const int ow_start, const int ow_end, const int iw, const int ow, - const int stride_w, const int pw, const int real_ih_idx, - const int oh_idx, const int pad_top, const int pad_bottom) { + const float32_t* src, float32_t* dst, const int filter, const int ow_start, + const int ow_end, const int iw, const int ow, const int stride_w, const int pw, + const int real_ih_idx, const int oh_idx, const int pad_top, + const int pad_bottom) { constexpr int ic_step = 4; constexpr int oc_step = 4; for (int ow_idx = ow_start; ow_idx < ow_end; ++ow_idx) { @@ -205,16 +203,16 @@ static inline void kern_pooling_with_pad_nchw44( const int pad_right = std::max(0, iw_idx - pw + filter - iw); const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step; const int dst_offset = (oh_idx * ow + ow_idx) * oc_step; - ker_pooling_nchw44_remain_pad(src + src_offset, dst + dst_offset, - iw, pad_top, pad_bottom, pad_left, - pad_right, filter); + ker_pooling_nchw44_remain_pad( + src + src_offset, dst + dst_offset, iw, pad_top, pad_bottom, pad_left, + pad_right, filter); } } template -static inline void pooling_fp32_nchw44_pad(const float32_t* src, float32_t* dst, - int ih, int iw, int oh, int ow, - int ph, int pw) { +static inline void pooling_fp32_nchw44_pad( + const float32_t* src, float32_t* dst, int ih, int iw, int oh, int ow, int ph, + int pw) { constexpr int stride_h = stride; constexpr int stride_w = stride; constexpr int ic_step = 4; @@ -223,8 +221,7 @@ static inline void 
pooling_fp32_nchw44_pad(const float32_t* src, float32_t* dst, const int ow_pad_left_end = div_ceil(pw, stride_w); const int ow_pad_right_end = (iw - filter + pw - 1) / stride_w; const int ow_pad_right_step_end = - (ow_pad_right_end - ow_pad_left_end) / ow_step * ow_step + - ow_pad_left_end; + (ow_pad_right_end - ow_pad_left_end) / ow_step * ow_step + ow_pad_left_end; rep(oh_idx, oh) { const int ih_idx = oh_idx * stride_h; @@ -232,9 +229,9 @@ static inline void pooling_fp32_nchw44_pad(const float32_t* src, float32_t* dst, const int pad_top = std::max(0, ph - ih_idx); const int pad_bottom = std::max(0, ih_idx - ph + filter - ih); if (pad_top > 0 || pad_bottom > 0) { - kern_pooling_with_pad_nchw44(src, dst, filter, 0, ow, iw, ow, - stride_w, pw, real_ih_idx, - oh_idx, pad_top, pad_bottom); + kern_pooling_with_pad_nchw44( + src, dst, filter, 0, ow, iw, ow, stride_w, pw, real_ih_idx, oh_idx, + pad_top, pad_bottom); } else { kern_pooling_with_pad_nchw44( @@ -244,25 +241,21 @@ static inline void pooling_fp32_nchw44_pad(const float32_t* src, float32_t* dst, ow_idx += ow_step) { const int iw_idx = ow_idx * stride_w; const int real_iw_idx = std::max(iw_idx - pw, 0); - const int src_offset = - (real_ih_idx * iw + real_iw_idx) * ic_step; + const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step; const int dst_offset = (oh_idx * ow + ow_idx) * oc_step; - KerPoolingFilterXStrideXNchw44::impl(src + src_offset, - dst + dst_offset, - iw); + KerPoolingFilterXStrideXNchw44::impl( + src + src_offset, dst + dst_offset, iw); } kern_pooling_with_pad_nchw44( - src, dst, filter, ow_pad_right_step_end, ow, iw, ow, - stride_w, pw, real_ih_idx, oh_idx, pad_top, pad_bottom); + src, dst, filter, ow_pad_right_step_end, ow, iw, ow, stride_w, pw, + real_ih_idx, oh_idx, pad_top, pad_bottom); } } } template -static inline void pooling_fp32_nchw44_no_pad(const float32_t* src, - float32_t* dst, int, int iw, - int oh, int ow) { +static inline void pooling_fp32_nchw44_no_pad( + const float32_t* src, float32_t* dst, int, int iw, int oh, int ow) { constexpr int stride_h = stride; constexpr int stride_w = stride; constexpr int ic_step = 4; @@ -283,23 +276,21 @@ static inline void pooling_fp32_nchw44_no_pad(const float32_t* src, src + src_offset, dst + dst_offset, iw); } if (ow_remain > 0) { - kern_pooling_with_pad_nchw44(src, dst, filter, ow_end, ow, iw, - ow, stride_w, 0, ih_idx, oh_idx, - 0, 0); + kern_pooling_with_pad_nchw44( + src, dst, filter, ow_end, ow, iw, ow, stride_w, 0, ih_idx, oh_idx, + 0, 0); } } } template -static inline void pooling_fp32_nchw44(const float32_t* src, float32_t* dst, - int ih, int iw, int oh, int ow, int ph, - int pw) { +static inline void pooling_fp32_nchw44( + const float32_t* src, float32_t* dst, int ih, int iw, int oh, int ow, int ph, + int pw) { if (ph > 0 || pw > 0) { - pooling_fp32_nchw44_pad(src, dst, ih, iw, oh, ow, - ph, pw); + pooling_fp32_nchw44_pad(src, dst, ih, iw, oh, ow, ph, pw); } else { - pooling_fp32_nchw44_no_pad(src, dst, ih, iw, oh, - ow); + pooling_fp32_nchw44_no_pad(src, dst, ih, iw, oh, ow); } } diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index 3806c491..587a738d 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -11,8 +11,8 @@ */ #include "src/arm_common/pooling/opr_impl.h" #include "src/arm_common/pooling/algo.h" -#include "src/common/metahelper.h" #include "src/common/algo_chooser.h" +#include "src/common/metahelper.h" using namespace megdnn; using namespace 
arm_common; @@ -63,11 +63,10 @@ public: PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( - fallback::PoolingImpl* opr, const TensorLayout& src, - const TensorLayout& dst) { + fallback::PoolingImpl* opr, const TensorLayout& src, const TensorLayout& dst) { auto safe_u32 = [](size_t v) -> uint32_t { - megdnn_assert(v <= std::numeric_limits::max(), - "value too large: %zu", v); + megdnn_assert( + v <= std::numeric_limits::max(), "value too large: %zu", v); return v; }; return {safe_u32(src.shape[0]), @@ -75,10 +74,8 @@ PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( {{safe_u32(src.shape[2]), safe_u32(src.shape[3])}}, {{safe_u32(dst.shape[2]), safe_u32(dst.shape[3])}}, {{safe_u32(opr->param().pad_h), safe_u32(opr->param().pad_w)}}, - {{safe_u32(opr->param().window_h), - safe_u32(opr->param().window_w)}}, - {{safe_u32(opr->param().stride_h), - safe_u32(opr->param().stride_w)}}, + {{safe_u32(opr->param().window_h), safe_u32(opr->param().window_w)}}, + {{safe_u32(opr->param().stride_h), safe_u32(opr->param().stride_w)}}, src.dtype, dst.dtype, opr->handle(), @@ -87,8 +84,8 @@ PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( }; PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param( - fallback::PoolingImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_out dst, _megdnn_workspace workspace) { + fallback::PoolingImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { PoolingKernParam ret; static_cast(ret) = make_pooling_kern_szie_param(opr, src.layout, dst.layout); @@ -99,12 +96,12 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param( return ret; }; -size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) { +size_t PoolingImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { TensorLayoutArray layouts{src, dst}; HeuristicCache::Key key{this->handle(), this->get_opr_type(), - layouts.data(), layouts.size(), &this->param(), - sizeof(this->param())}; + layouts.data(), layouts.size(), + &this->param(), sizeof(this->param())}; auto rst = HeuristicCache::instance().get(key); if (rst.policy.algo.valid()) { return rst.workspace; @@ -148,8 +145,8 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, } } -void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void PoolingImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); auto param = make_pooling_kern_param(this, src, dst, workspace); auto algo = get_algorithm(this, src.layout, dst.layout); @@ -176,7 +173,7 @@ std::vector PoolingImpl::get_all_algorithms( } std::vector PoolingImpl::get_all_algorithms_safe( const TensorLayout& src, const TensorLayout& dst) { - auto ret_safe = get_all_algorithms(src,dst); + auto ret_safe = get_all_algorithms(src, dst); megdnn_assert(!ret_safe.empty(), "no usable pooling fwd algorithm"); return ret_safe; } @@ -193,11 +190,11 @@ Algorithm* PoolingImpl::get_algorithm_heuristic( return iter; } } - megdnn_throw( - ssprintf("require algorithm with attribute(%s) and without " - "attribute(%s), but can't get suitable algo.\n", - Algorithm::attribute_str(positive_attr).c_str(), - Algorithm::attribute_str(negative_attr).c_str())); + megdnn_throw(ssprintf( + "require algorithm with attribute(%s) and without " + 
"attribute(%s), but can't get suitable algo.\n", + Algorithm::attribute_str(positive_attr).c_str(), + Algorithm::attribute_str(negative_attr).c_str())); return nullptr; } diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index 1f25f9c0..229259d0 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -10,9 +10,9 @@ * implied. */ #pragma once +#include #include "megdnn/oprs/base.h" #include "src/fallback/pooling/opr_impl.h" -#include namespace megdnn { namespace arm_common { @@ -38,10 +38,10 @@ private: public: using fallback::PoolingImpl::PoolingImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, - const TensorLayout&) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override; static size_t constexpr MAX_SPATIAL_DIM = 2; @@ -83,10 +83,9 @@ public: fallback::PoolingImpl* opr, const TensorLayout& src, const TensorLayout& dst); - PoolingKernParam make_pooling_kern_param(fallback::PoolingImpl* opr, - _megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace); + PoolingKernParam make_pooling_kern_param( + fallback::PoolingImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace); class AlgoBase : public detail::Algorithm { public: enum class AlgoType : uint32_t { @@ -115,8 +114,7 @@ public: uint32_t type() const override { return INVALID_ALGO_TYPE; }; bool is_available_attribute( const PoolingKernSizeParam& param, - const AlgoAttribute& positive_attr = - AlgoAttribute::REPRODUCIBLE, + const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { return contain_attribute_all(positive_attr) && !contain_attribute_any(negative_attr) && usable(param); @@ -143,8 +141,8 @@ public: const TensorLayout& src, const TensorLayout& dst, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { - return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes, - positive_attr, negative_attr) + return get_algorithm_heuristic( + src, dst, workspace_limit_in_bytes, positive_attr, negative_attr) ->info(); } diff --git a/dnn/src/arm_common/pooling/pooling_helper.h b/dnn/src/arm_common/pooling/pooling_helper.h index 32c8a5f5..0cf38c85 100644 --- a/dnn/src/arm_common/pooling/pooling_helper.h +++ b/dnn/src/arm_common/pooling/pooling_helper.h @@ -84,8 +84,7 @@ struct MeanInPooler feed_cnt += 1; } void post(uint8_t* dst) { - this->res = - this->res + static_cast(area - feed_cnt) * zero_point; + this->res = this->res + static_cast(area - feed_cnt) * zero_point; this->res *= this->coef; *dst = std::round(this->res); } @@ -137,11 +136,12 @@ struct NeonMeanPooler { void feed(const int8_t* val) { int8x16_t item = vld1q_s8(val); float32x4_t tmp; -#define cb(i) \ - tmp = (float32x4_t){static_cast(vgetq_lane_s8(item, 4 * i + 0)), \ - static_cast(vgetq_lane_s8(item, 4 * i + 1)), \ - static_cast(vgetq_lane_s8(item, 4 * i + 2)), \ - static_cast(vgetq_lane_s8(item, 4 * i + 3))}; \ +#define cb(i) \ + tmp = (float32x4_t){ \ + static_cast(vgetq_lane_s8(item, 4 * i + 0)), \ + static_cast(vgetq_lane_s8(item, 4 * i + 1)), \ + static_cast(vgetq_lane_s8(item, 4 * i + 2)), \ + static_cast(vgetq_lane_s8(item, 4 * i + 3))}; \ sum##i = vaddq_f32(sum##i, 
tmp); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb @@ -211,11 +211,12 @@ struct NeonMeanPooler { void feed(const uint8_t* val) { uint8x16_t item = vld1q_u8(val); float32x4_t tmp; -#define cb(i) \ - tmp = (float32x4_t){static_cast(vgetq_lane_u8(item, 4 * i + 0)), \ - static_cast(vgetq_lane_u8(item, 4 * i + 1)), \ - static_cast(vgetq_lane_u8(item, 4 * i + 2)), \ - static_cast(vgetq_lane_u8(item, 4 * i + 3))}; \ +#define cb(i) \ + tmp = (float32x4_t){ \ + static_cast(vgetq_lane_u8(item, 4 * i + 0)), \ + static_cast(vgetq_lane_u8(item, 4 * i + 1)), \ + static_cast(vgetq_lane_u8(item, 4 * i + 2)), \ + static_cast(vgetq_lane_u8(item, 4 * i + 3))}; \ sum##i = vaddq_f32(sum##i, tmp); UNROLL_CALL_NOWRAPPER(4, cb); #undef cb @@ -322,8 +323,7 @@ struct NeonMaxPooler { static constexpr int SIMD_WIDTH = 16; int8x16_t res; - NeonMaxPooler(DType) - : res(vdupq_n_s8(std::numeric_limits::lowest())) {} + NeonMaxPooler(DType) : res(vdupq_n_s8(std::numeric_limits::lowest())) {} void feed(const int8_t* val) { res = vmaxq_s8(res, vld1q_s8(val)); } void post(int8_t* dst) { vst1q_s8(dst, res); } }; @@ -335,8 +335,7 @@ struct NeonMaxPooler { static constexpr int SIMD_WIDTH = 16; uint8x16_t res; - NeonMaxPooler(DType) - : res(vdupq_n_u8(std::numeric_limits::lowest())) {} + NeonMaxPooler(DType) : res(vdupq_n_u8(std::numeric_limits::lowest())) {} void feed(const uint8_t* val) { res = vmaxq_u8(res, vld1q_u8(val)); } void post(uint8_t* dst) { vst1q_u8(dst, res); } }; @@ -356,10 +355,10 @@ struct NeonMaxPooler { #endif template -void do_pxl_naive(int oh, int ow, const typename Pooler::ctype* src, - typename Pooler::ctype* dst, DType src_dtype, const int IH, - const int IW, const int OH, const int OW, const int PH, - const int PW, const int SH, const int SW) { +void do_pxl_naive( + int oh, int ow, const typename Pooler::ctype* src, typename Pooler::ctype* dst, + DType src_dtype, const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW, const int SH, const int SW) { MEGDNN_MARK_USED_VAR(OH); Pooler pooler(src_dtype); rep(wh, window) rep(ww, window) { @@ -376,18 +375,18 @@ namespace detail { template struct do_pxl_2x2_pack_proxy { - static void gao(int oh, int ow, const typename Pooler::ctype* src, - typename Pooler::ctype* dst, DType, const int IH, - const int IW, const int OH, const int OW, const int PH, - const int PW); + static void gao( + int oh, int ow, const typename Pooler::ctype* src, + typename Pooler::ctype* dst, DType, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW); }; template <> -struct do_pxl_2x2_pack_proxy, - Pooling::Mode::AVERAGE> { - static void gao(int oh, int ow, const dt_float32* src, dt_float32* dst, - DType, const int IH, const int IW, const int OH, - const int OW, const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MeanInPooler<4, dt_float32, float, float>, Pooling::Mode::AVERAGE> { + static void gao( + int oh, int ow, const dt_float32* src, dt_float32* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); static const auto avg_coef = vdupq_n_f32(0.25f); @@ -407,11 +406,11 @@ struct do_pxl_2x2_pack_proxy, }; template <> -struct do_pxl_2x2_pack_proxy, - Pooling::Mode::AVERAGE> { - static void gao(int oh, int ow, const int8_t* src, int8_t* dst, DType, - const int IH, const int IW, const int OH, const int OW, - const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MeanInPooler<4, dt_qint8, int8_t, float>, 
Pooling::Mode::AVERAGE> { + static void gao( + int oh, int ow, const int8_t* src, int8_t* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); int ih = -PH + 2 * oh; @@ -446,11 +445,11 @@ struct do_pxl_2x2_pack_proxy, }; template <> -struct do_pxl_2x2_pack_proxy, - Pooling::Mode::AVERAGE> { - static void gao(int oh, int ow, const uint8_t* src, uint8_t* dst, DType, - const int IH, const int IW, const int OH, const int OW, - const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MeanInPooler<4, dt_quint8, uint8_t, float>, Pooling::Mode::AVERAGE> { + static void gao( + int oh, int ow, const uint8_t* src, uint8_t* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); int ih = -PH + 2 * oh; @@ -478,11 +477,11 @@ struct do_pxl_2x2_pack_proxy, }; template <> -struct do_pxl_2x2_pack_proxy, - Pooling::Mode::MAX> { - static void gao(int oh, int ow, const dt_float32* src, dt_float32* dst, - DType, const int IH, const int IW, const int OH, - const int OW, const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MaxPooler<4, dt_float32, float, float>, Pooling::Mode::MAX> { + static void gao( + int oh, int ow, const dt_float32* src, dt_float32* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); int ih = -PH + 2 * oh; @@ -500,11 +499,11 @@ struct do_pxl_2x2_pack_proxy, }; template <> -struct do_pxl_2x2_pack_proxy, - Pooling::Mode::MAX> { - static void gao(int oh, int ow, const int8_t* src, int8_t* dst, DType, - const int IH, const int IW, const int OH, const int OW, - const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MaxPooler<4, dt_qint8, int8_t, float>, Pooling::Mode::MAX> { + static void gao( + int oh, int ow, const int8_t* src, int8_t* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); int ih = -PH + 2 * oh; @@ -522,11 +521,11 @@ struct do_pxl_2x2_pack_proxy, }; template <> -struct do_pxl_2x2_pack_proxy, - Pooling::Mode::MAX> { - static void gao(int oh, int ow, const uint8_t* src, uint8_t* dst, DType, - const int IH, const int IW, const int OH, const int OW, - const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MaxPooler<4, dt_quint8, uint8_t, float>, Pooling::Mode::MAX> { + static void gao( + int oh, int ow, const uint8_t* src, uint8_t* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); int ih = -PH + 2 * oh; @@ -545,11 +544,11 @@ struct do_pxl_2x2_pack_proxy, #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC template <> -struct do_pxl_2x2_pack_proxy, - Pooling::Mode::AVERAGE> { - static void gao(int oh, int ow, const __fp16* src, __fp16* dst, DType, - const int IH, const int IW, const int OH, const int OW, - const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MeanInPooler<4, dt_float16, __fp16, __fp16>, Pooling::Mode::AVERAGE> { + static void gao( + int oh, int ow, const __fp16* src, __fp16* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); static const auto avg_coef = vdupq_n_f16(0.25f); @@ -569,11 +568,11 @@ struct do_pxl_2x2_pack_proxy, }; template <> -struct 
do_pxl_2x2_pack_proxy, - Pooling::Mode::MAX> { - static void gao(int oh, int ow, const __fp16* src, __fp16* dst, DType, - const int IH, const int IW, const int OH, const int OW, - const int PH, const int PW) { +struct do_pxl_2x2_pack_proxy< + MaxPooler<4, dt_float16, __fp16, __fp16>, Pooling::Mode::MAX> { + static void gao( + int oh, int ow, const __fp16* src, __fp16* dst, DType, const int IH, + const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); int ih = -PH + 2 * oh; @@ -594,20 +593,19 @@ struct do_pxl_2x2_pack_proxy, } // namespace detail template -void do_pxl_2x2_pack(int oh, int ow, const typename Pooler::ctype* src, - typename Pooler::ctype* dst, DType src_dtype, const int IH, - const int IW, const int OH, const int OW, const int PH, - const int PW) { +void do_pxl_2x2_pack( + int oh, int ow, const typename Pooler::ctype* src, typename Pooler::ctype* dst, + DType src_dtype, const int IH, const int IW, const int OH, const int OW, + const int PH, const int PW) { detail::do_pxl_2x2_pack_proxy::gao( oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW); } template -void do_pxl_compact_packed(int oh, int ow, - const typename NeonPooler::ctype* src, - typename NeonPooler::ctype* dst, DType src_dtype, - const int IH, const int IW, const int OH, - const int OW, const int PH, const int PW) { +void do_pxl_compact_packed( + int oh, int ow, const typename NeonPooler::ctype* src, + typename NeonPooler::ctype* dst, DType src_dtype, const int IH, const int IW, + const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); NeonPooler pooler(src_dtype); @@ -620,27 +618,29 @@ void do_pxl_compact_packed(int oh, int ow, } template -void do_pooling_compact(const typename Pooler::ctype* src, - typename Pooler::ctype* dst, DType src_dtype, - const int IH, const int IW, const int OH, const int OW, - const int PH, const int PW) { - static_assert(std::is_same::value, - "ctype of Pooler and NeonPooler is not the same"); +void do_pooling_compact( + const typename Pooler::ctype* src, typename Pooler::ctype* dst, DType src_dtype, + const int IH, const int IW, const int OH, const int OW, const int PH, + const int PW) { + static_assert( + std::is_same::value, + "ctype of Pooler and NeonPooler is not the same"); const int stride = 1; int oh = 0; for (; oh < OH && oh - PH < 0; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH && oh - PH + window <= IH; ++oh) { int ow = 0; for (; ow < OW && ow - PW < 0; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } for (; ow + NeonPooler::SIMD_WIDTH <= OW && ow + NeonPooler::SIMD_WIDTH - 1 - PW + window <= IW; @@ -649,56 +649,62 @@ void do_pooling_compact(const typename Pooler::ctype* src, oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW); } for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, 
OH, OW, PH, PW, stride, + stride); } } } template -void do_pooling_2x2(const typename Pooler::ctype* src, - typename Pooler::ctype* dst, DType src_dtype, const int IH, - const int IW, const int OH, const int OW, const int PH, - const int PW) { +void do_pooling_2x2( + const typename Pooler::ctype* src, typename Pooler::ctype* dst, DType src_dtype, + const int IH, const int IW, const int OH, const int OW, const int PH, + const int PW) { const int window = 2; const int stride = 2; int oh = 0; for (; oh < OH && -PH + stride * oh < 0; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH && -PH + stride * oh + window <= IH; ++oh) { int ow = 0; for (; ow < OW && -PW + stride * ow < 0; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } for (; ow + Pooler::SIMD_WIDTH <= OW && -PW + stride * (ow + Pooler::SIMD_WIDTH - 1) + window <= IW; ow += Pooler::SIMD_WIDTH) { - do_pxl_2x2_pack(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW); + do_pxl_2x2_pack( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW); } for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } for (; oh < OH; ++oh) { int ow = 0; for (; ow < OW; ++ow) { - do_pxl_naive(oh, ow, src, dst, src_dtype, IH, IW, - OH, OW, PH, PW, stride, stride); + do_pxl_naive( + oh, ow, src, dst, src_dtype, IH, IW, OH, OW, PH, PW, stride, + stride); } } } @@ -782,11 +788,10 @@ inline float16x8x2_t vunzip(float16x8_t a, float16x8_t b) { // because the __fp16 can't get the lowest value, so add dtype template -void do_max_pooling_w5x5_s2x2_NEON(const ctype* src, ctype* dst, const int IH, - const int IW, const int OH, const int OW, - const int PH, const int PW, - const WorkspaceBundle& ws, - const int MEGDNN_SIMD_WIDTH) { +void do_max_pooling_w5x5_s2x2_NEON( + const ctype* src, ctype* dst, const int IH, const int IW, const int OH, + const int OW, const int PH, const int PW, const WorkspaceBundle& ws, + const int MEGDNN_SIMD_WIDTH) { ctype* cache[5] = { static_cast(ws.get(0)), static_cast(ws.get(1)), static_cast(ws.get(2)), static_cast(ws.get(3)), @@ -872,8 +877,9 @@ void do_max_pooling_w5x5_s2x2_NEON(const ctype* src, ctype* dst, const int IH, vset(dptr + ow, d); } for (; ow < OW; ++ow) - dptr[ow] = std::max({cache[0][ow], cache[1][ow], cache[2][ow], - cache[3][ow], cache[4][ow]}); + dptr[ow] = std::max( + {cache[0][ow], cache[1][ow], cache[2][ow], cache[3][ow], + cache[4][ow]}); } else { std::memcpy(dptr, cache[0], sizeof(ctype) * OW); for (int i = 1; i < ih_to - ih_from; ++i) { @@ -892,17 +898,16 @@ void do_max_pooling_w5x5_s2x2_NEON(const ctype* src, ctype* dst, const int IH, } template -void do_average_pooling_3x3_s2x2_NEON(const ctype* src, ctype* dst, size_t IH_, - size_t IW_, size_t OH_, size_t OW_, - size_t PH_, size_t PW_, - const WorkspaceBundle& ws, - const int MEGDNN_SIMD_WIDTH) { +void do_average_pooling_3x3_s2x2_NEON( + const ctype* src, ctype* dst, size_t IH_, size_t IW_, size_t OH_, size_t OW_, + size_t PH_, size_t PW_, const WorkspaceBundle& ws, + const int MEGDNN_SIMD_WIDTH) { int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; // cache[i] stores the answer of 
the i-th line after // pooling along the W dimension. - ctype* cache[3] = {static_cast(ws.get(0)), - static_cast(ws.get(1)), - static_cast(ws.get(2))}; + ctype* cache[3] = { + static_cast(ws.get(0)), static_cast(ws.get(1)), + static_cast(ws.get(2))}; ctype* odd = static_cast(ws.get(3)); ctype* even = static_cast(ws.get(4)); int ih_next = 0; @@ -1001,8 +1006,7 @@ void do_average_pooling_3x3_s2x2_NEON(const ctype* src, ctype* dst, size_t IH_, #pragma clang loop vectorize(disable) #endif for (; ow < OW; ++ow) { - dptr[ow] = - (cache[0][ow] + cache[1][ow] + cache[2][ow]) * factor; + dptr[ow] = (cache[0][ow] + cache[1][ow] + cache[2][ow]) * factor; } } else { std::memcpy(dptr, cache[0], sizeof(ctype) * OW); diff --git a/dnn/src/arm_common/quantized_converter.h b/dnn/src/arm_common/quantized_converter.h index 6959a3b3..c752dbe6 100644 --- a/dnn/src/arm_common/quantized_converter.h +++ b/dnn/src/arm_common/quantized_converter.h @@ -65,8 +65,7 @@ inline int8x8_t QConverter::convert(const float32x4_t& src) { } template <> -inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, - const int32x4_t& vzp) { +inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, const int32x4_t& vzp) { int32x4_t vres0 = vcvtaq_s32_f32(vsrc.val[0]); int32x4_t vres1 = vcvtaq_s32_f32(vsrc.val[1]); vres0 = vqaddq_s32(vres0, vzp); @@ -74,8 +73,8 @@ inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, vres0 = vmaxq_s32(vres0, QConverterBase::vzero()); vres1 = vmaxq_s32(vres1, QConverterBase::vzero()); - return vqmovn_u16(vreinterpretq_u16_s16( - vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)))); + return vqmovn_u16( + vreinterpretq_u16_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)))); } template <> @@ -86,12 +85,12 @@ inline int32x4_t QConverter::convert(const float32x4_t& vsrc) { #else template <> inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) { - float32x4_t vinc0 = - vbslq_f32(vcgeq_f32(vsrc.val[0], QConverterBase::vfzero()), - QConverterBase::vfhalf(), QConverterBase::vfneg_half()); - float32x4_t vinc1 = - vbslq_f32(vcgeq_f32(vsrc.val[1], QConverterBase::vfzero()), - QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + float32x4_t vinc0 = vbslq_f32( + vcgeq_f32(vsrc.val[0], QConverterBase::vfzero()), QConverterBase::vfhalf(), + QConverterBase::vfneg_half()); + float32x4_t vinc1 = vbslq_f32( + vcgeq_f32(vsrc.val[1], QConverterBase::vfzero()), QConverterBase::vfhalf(), + QConverterBase::vfneg_half()); int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0)); int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1)); @@ -101,9 +100,9 @@ inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) { template <> inline int8x8_t QConverter::convert(const float32x4_t& src) { - float32x4_t vinc0 = - vbslq_f32(vcgeq_f32(src, QConverterBase::vfzero()), - QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + float32x4_t vinc0 = vbslq_f32( + vcgeq_f32(src, QConverterBase::vfzero()), QConverterBase::vfhalf(), + QConverterBase::vfneg_half()); int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(src, vinc0)); int16x4_t vres0_int16 = vqmovn_s32(vres0); @@ -111,14 +110,13 @@ inline int8x8_t QConverter::convert(const float32x4_t& src) { } template <> -inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, - const int32x4_t& vzp) { - float32x4_t vinc0 = - vbslq_f32(vcgeq_f32(vsrc.val[0], QConverterBase::vfzero()), - QConverterBase::vfhalf(), QConverterBase::vfneg_half()); - float32x4_t vinc1 = - vbslq_f32(vcgeq_f32(vsrc.val[1], QConverterBase::vfzero()), - 
QConverterBase::vfhalf(), QConverterBase::vfneg_half()); +inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, const int32x4_t& vzp) { + float32x4_t vinc0 = vbslq_f32( + vcgeq_f32(vsrc.val[0], QConverterBase::vfzero()), QConverterBase::vfhalf(), + QConverterBase::vfneg_half()); + float32x4_t vinc1 = vbslq_f32( + vcgeq_f32(vsrc.val[1], QConverterBase::vfzero()), QConverterBase::vfhalf(), + QConverterBase::vfneg_half()); int32x4_t vres0 = vcvtq_s32_f32(vaddq_f32(vsrc.val[0], vinc0)); int32x4_t vres1 = vcvtq_s32_f32(vaddq_f32(vsrc.val[1], vinc1)); @@ -127,15 +125,15 @@ inline uint8x8_t QConverter::convert(const float32x4x2_t& vsrc, vres0 = vmaxq_s32(vres0, QConverterBase::vzero()); vres1 = vmaxq_s32(vres1, QConverterBase::vzero()); - return vqmovn_u16(vreinterpretq_u16_s16( - vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)))); + return vqmovn_u16( + vreinterpretq_u16_s16(vcombine_s16(vqmovn_s32(vres0), vqmovn_s32(vres1)))); } template <> inline int32x4_t QConverter::convert(const float32x4_t& vsrc) { - float32x4_t vinc = - vbslq_f32(vcgeq_f32(vsrc, QConverterBase::vfzero()), - QConverterBase::vfhalf(), QConverterBase::vfneg_half()); + float32x4_t vinc = vbslq_f32( + vcgeq_f32(vsrc, QConverterBase::vfzero()), QConverterBase::vfhalf(), + QConverterBase::vfneg_half()); return vcvtq_s32_f32(vaddq_f32(vsrc, vinc)); } diff --git a/dnn/src/arm_common/reduce/opr_impl.cpp b/dnn/src/arm_common/reduce/opr_impl.cpp index adb92c98..815d0628 100644 --- a/dnn/src/arm_common/reduce/opr_impl.cpp +++ b/dnn/src/arm_common/reduce/opr_impl.cpp @@ -11,8 +11,8 @@ #include "src/arm_common/reduce/opr_impl.h" #include -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/quantized_converter.h" +#include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/reduce_helper.h" #include "src/common/unroll_macro.h" #include "src/common/utils.h" @@ -26,7 +26,7 @@ MIDOUT_DECL(megdnn_arm_common_reduce) namespace { -//!FIXME: we should check this when update the compiler +//! 
FIXME: we should check this when update the compiler #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if MEGDNN_ARMV7 typedef float fp16_fix_t; @@ -75,8 +75,7 @@ struct MeanReducer { int32_t zp; int32_t cnt; float coef; - MeanReducer(DType src_dtype, size_t cnt) - : res(0), cnt(cnt), coef(1.0 / cnt) { + MeanReducer(DType src_dtype, size_t cnt) : res(0), cnt(cnt), coef(1.0 / cnt) { zp = src_dtype.param().zero_point; } MeanReducer() = default; @@ -84,8 +83,7 @@ struct MeanReducer { #if MEGDNN_AARCH64 res += vaddlvq_u8(vld1q_u8(val)); #elif MEGDNN_ARMV7 - auto sum = - vreinterpretq_s32_u32(vpaddlq_u16(vpaddlq_u8(vld1q_u8(val)))); + auto sum = vreinterpretq_s32_u32(vpaddlq_u16(vpaddlq_u8(vld1q_u8(val)))); res += (vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3)); #else @@ -143,8 +141,9 @@ struct MeanReducer<__fp16, __fp16, __fp16, true> { void feed_remain(const ctype* val) { result += *val; } void post(ctype* dst) { auto sum_tmp = vadd_f16(vget_low_f16(res), vget_high_f16(res)); - result += (vget_lane_f16(sum_tmp, 0) + vget_lane_f16(sum_tmp, 1) + - vget_lane_f16(sum_tmp, 2) + vget_lane_f16(sum_tmp, 3)); + result += + (vget_lane_f16(sum_tmp, 0) + vget_lane_f16(sum_tmp, 1) + + vget_lane_f16(sum_tmp, 2) + vget_lane_f16(sum_tmp, 3)); *dst = result * coef; } }; @@ -167,9 +166,7 @@ struct MeanReducer<__fp16, __fp16, __fp16, false> { res = vmulq_n_f16(res, coef); vst1q_f16(dst, res); } - void post_remain(ctype* dst){ - *dst = remain * coef; - } + void post_remain(ctype* dst) { *dst = remain * coef; } }; #endif @@ -191,9 +188,7 @@ struct MeanReducer { res = vmulq_n_f32(res, coef); vst1q_f32(dst, res); } - void post_remain(float* dst){ - *dst = remain * coef; - } + void post_remain(float* dst) { *dst = remain * coef; } }; template <> @@ -206,9 +201,8 @@ struct MeanReducer { int32_t cnt; float coef; float32x4_t vcoef; - MeanReducer(DType, size_t cnt) - : remain(0), cnt(cnt), coef(1.0 / cnt) { - memset(res, 0, sizeof (res)); + MeanReducer(DType, size_t cnt) : remain(0), cnt(cnt), coef(1.0 / cnt) { + memset(res, 0, sizeof(res)); vcoef = vdupq_n_f32(coef); } MeanReducer() = default; @@ -232,8 +226,8 @@ struct MeanReducer { for (int i = 0; i < 4; i += 2) { float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(res[i]), vcoef); float32x4_t vitem1 = vmulq_f32(vcvtq_f32_s32(res[i + 1]), vcoef); - vst1_s8(dst, (QConverter::convert({{vitem0, vitem1}}))); + vst1_s8(dst, + (QConverter::convert({{vitem0, vitem1}}))); dst += 8; } } @@ -256,11 +250,10 @@ struct MeanReducer { int32x4_t vcnt; float coef; float32x4_t vcoef; - MeanReducer(DType src_dtype, size_t cnt) - : remain(0), cnt(cnt), coef(1.0 / cnt) { + MeanReducer(DType src_dtype, size_t cnt) : remain(0), cnt(cnt), coef(1.0 / cnt) { zp = src_dtype.param().zero_point; vzp = vdupq_n_s32(zp); - memset(res, 0, sizeof (res)); + memset(res, 0, sizeof(res)); vcnt = vdupq_n_s32(cnt); vcoef = vdupq_n_f32(coef); } @@ -289,8 +282,8 @@ struct MeanReducer { float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(tmp0), vcoef); float32x4_t vitem1 = vmulq_f32(vcvtq_f32_s32(tmp1), vcoef); - vst1_u8(dst, (QConverter::convert({{vitem0, vitem1}}, vzp))); + vst1_u8(dst, (QConverter::convert( + {{vitem0, vitem1}}, vzp))); dst += 8; } } @@ -306,42 +299,40 @@ struct maxReducer; template struct minReducer; -#define REDUCER_MAX_MIN_C1(_mode, _dtype, _ctype, _comp_type, _stype, __stype, _init) \ - template<> \ - struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ - using ctype = _ctype; \ - static constexpr int SIMD_WIDTH = 16; \ - __stype##8x16_t res; \ - 
_mode##Reducer(DType, size_t) { res = vdupq_n_##_stype##8(_init); } \ - _mode##Reducer() = default; \ - void feed(const ctype* val) { \ - __stype##8x16_t vval = vld1q_##_stype##8(val); \ - res = v##_mode##q_##_stype##8(vval, res); \ - } \ - void feed_remain(const ctype* val) { \ - __stype##8x16_t vval = vdupq_n_##_stype##8(*val); \ - res = v##_mode##q_##_stype##8(vval, res); \ - } \ - void post(ctype* dst) { \ - __stype##16x8_t vval_low = \ - vmovl_##_stype##8(vget_low_##_stype##8(res)); \ - __stype##16x8_t vval_high = \ - vmovl_##_stype##8(vget_high_##_stype##8(res)); \ - __stype##16x8_t vval_m = \ - v##_mode##q_##_stype##16(vval_low, vval_high); \ - \ - __stype##32x4_t vval_m_low = \ - vmovl_##_stype##16(vget_low_##_stype##16(vval_m)); \ - __stype##32x4_t vval_m_high = \ - vmovl_##_stype##16(vget_high_##_stype##16(vval_m)); \ - __stype##32x4_t vval_m_m = \ - v##_mode##q_##_stype##32(vval_m_low, vval_m_high); \ - using namespace std; \ - *dst = _mode({vgetq_lane_##_stype##32(vval_m_m, 0), \ - vgetq_lane_##_stype##32(vval_m_m, 1), \ - vgetq_lane_##_stype##32(vval_m_m, 2), \ - vgetq_lane_##_stype##32(vval_m_m, 3)}); \ - } \ +#define REDUCER_MAX_MIN_C1(_mode, _dtype, _ctype, _comp_type, _stype, __stype, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = 16; \ + __stype##8x16_t res; \ + _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype##8(_init); } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype##8x16_t vval = vld1q_##_stype##8(val); \ + res = v##_mode##q_##_stype##8(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + __stype##8x16_t vval = vdupq_n_##_stype##8(*val); \ + res = v##_mode##q_##_stype##8(vval, res); \ + } \ + void post(ctype* dst) { \ + __stype##16x8_t vval_low = vmovl_##_stype##8(vget_low_##_stype##8(res)); \ + __stype##16x8_t vval_high = vmovl_##_stype##8(vget_high_##_stype##8(res)); \ + __stype##16x8_t vval_m = v##_mode##q_##_stype##16(vval_low, vval_high); \ + \ + __stype##32x4_t vval_m_low = \ + vmovl_##_stype##16(vget_low_##_stype##16(vval_m)); \ + __stype##32x4_t vval_m_high = \ + vmovl_##_stype##16(vget_high_##_stype##16(vval_m)); \ + __stype##32x4_t vval_m_m = \ + v##_mode##q_##_stype##32(vval_m_low, vval_m_high); \ + using namespace std; \ + *dst = \ + _mode({vgetq_lane_##_stype##32(vval_m_m, 0), \ + vgetq_lane_##_stype##32(vval_m_m, 1), \ + vgetq_lane_##_stype##32(vval_m_m, 2), \ + vgetq_lane_##_stype##32(vval_m_m, 3)}); \ + } \ } REDUCER_MAX_MIN_C1(max, dt_qint8, int8_t, int8_t, s, int, -128); @@ -351,7 +342,7 @@ REDUCER_MAX_MIN_C1(min, dt_quint8, uint8_t, uint8_t, u, uint, 255); #undef REDUCER_MAX_MIN_C1 #define REDUCER_MAX_MIN_C(_mode, _dtype, _ctype, _comp_type, _stype, __stype, _init) \ - template<> \ + template <> \ struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ using ctype = _ctype; \ static constexpr int SIMD_WIDTH = 16; \ @@ -369,12 +360,8 @@ REDUCER_MAX_MIN_C1(min, dt_quint8, uint8_t, uint8_t, u, uint, 255); __stype##8x16_t vval = vdupq_n_##_stype(*val); \ remain = v##_mode##q_##_stype(vval, remain); \ } \ - void post(ctype* dst) { \ - vst1q_##_stype(dst, res); \ - } \ - void post_remain(ctype* dst) { \ - vst1q_lane_##_stype(dst, remain, 0); \ - } \ + void post(ctype* dst) { vst1q_##_stype(dst, res); } \ + void post_remain(ctype* dst) { vst1q_lane_##_stype(dst, remain, 0); } \ } REDUCER_MAX_MIN_C(max, dt_qint8, int8_t, int8_t, s8, int, -128); @@ -383,182 +370,78 @@ REDUCER_MAX_MIN_C(max, dt_quint8, 
uint8_t, uint8_t, u8, uint, 0); REDUCER_MAX_MIN_C(min, dt_quint8, uint8_t, uint8_t, u8, uint, 255); #undef REDUCER_MAX_MIN_C -#define REDUCER_MAX_MIN_C1(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ - _num, _init) \ - template <> \ - struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ - using ctype = _ctype; \ - static constexpr int SIMD_WIDTH = _num; \ - __stype res; \ - _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype(_init); } \ - _mode##Reducer() = default; \ - void feed(const ctype* val) { \ - __stype vval = vld1q_##_stype(val); \ - res = v##_mode##q_##_stype(vval, res); \ - } \ - void feed_remain(const ctype* val) { \ - __stype vval = vdupq_n_##_stype(*val); \ - res = v##_mode##q_##_stype(vval, res); \ - } \ - void post(ctype* dst) { \ - auto val = v##_mode##_##_stype(vget_low_##_stype(res), \ - vget_high_##_stype(res)); \ - using namespace std; \ - *dst = _mode( \ - {vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1)}); \ - } \ - } - -REDUCER_MAX_MIN_C1(max, dt_float32, float, float, f32, float32x4_t, 4, - std::numeric_limits::lowest()); -REDUCER_MAX_MIN_C1(min, dt_float32, float, float, f32, float32x4_t, 4, - std::numeric_limits::max()); +#define REDUCER_MAX_MIN_C1( \ + _mode, _dtype, _ctype, _comp_type, _stype, __stype, _num, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype(_init); } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + __stype vval = vdupq_n_##_stype(*val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void post(ctype* dst) { \ + auto val = v##_mode##_##_stype( \ + vget_low_##_stype(res), vget_high_##_stype(res)); \ + using namespace std; \ + *dst = _mode({vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1)}); \ + } \ + } + +REDUCER_MAX_MIN_C1( + max, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C1( + min, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::max()); #undef REDUCER_MAX_MIN_C1 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#define REDUCER_MAX_MIN_C1(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ - _num, _init) \ - template <> \ - struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ - using ctype = _ctype; \ - static constexpr int SIMD_WIDTH = _num; \ - __stype res; \ - _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype(_init); } \ - _mode##Reducer() = default; \ - void feed(const ctype* val) { \ - __stype vval = vld1q_##_stype(val); \ - res = v##_mode##q_##_stype(vval, res); \ - } \ - void feed_remain(const ctype* val) { \ - __stype vval = vdupq_n_##_stype(*val); \ - res = v##_mode##q_##_stype(vval, res); \ - } \ - void post(ctype* dst) { \ - auto val = v##_mode##_##_stype(vget_low_##_stype(res), \ - vget_high_##_stype(res)); \ - using namespace std; \ - *dst = _mode( \ - {vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1), \ - vget_lane_##_stype(val, 2), vget_lane_##_stype(val, 3)}); \ - } \ - } - -REDUCER_MAX_MIN_C1(max, __fp16, __fp16, __fp16, f16, float16x8_t, 8, - std::numeric_limits::lowest()); -REDUCER_MAX_MIN_C1(min, __fp16, __fp16, __fp16, f16, float16x8_t, 8, - std::numeric_limits::max()); -#undef REDUCER_MAX_MIN_C1 -#endif +#define REDUCER_MAX_MIN_C1( \ + _mode, _dtype, _ctype, _comp_type, 
_stype, __stype, _num, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + _mode##Reducer(DType, size_t) { res = vdupq_n_##_stype(_init); } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + __stype vval = vdupq_n_##_stype(*val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void post(ctype* dst) { \ + auto val = v##_mode##_##_stype( \ + vget_low_##_stype(res), vget_high_##_stype(res)); \ + using namespace std; \ + *dst = \ + _mode({vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1), \ + vget_lane_##_stype(val, 2), vget_lane_##_stype(val, 3)}); \ + } \ + } -#define REDUCER_MAX_MIN_C(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ - _num, _init) \ - template <> \ - struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ - using ctype = _ctype; \ - static constexpr int SIMD_WIDTH = _num; \ - __stype res; \ - ctype remain; \ - _mode##Reducer(DType, size_t) { \ - res = vdupq_n_##_stype(_init); \ - remain = _init; \ - } \ - _mode##Reducer() = default; \ - void feed(const ctype* val) { \ - __stype vval = vld1q_##_stype(val); \ - res = v##_mode##q_##_stype(vval, res); \ - } \ - void feed_remain(const ctype* val) { \ - using namespace std; \ - remain = _mode(*val, remain); \ - } \ - void post(ctype* dst) { vst1q_##_stype(dst, res); } \ - void post_remain(ctype* dst) { *dst = remain; } \ - } - -REDUCER_MAX_MIN_C(max, dt_float32, float, float, f32, float32x4_t, 4, - std::numeric_limits::lowest()); -REDUCER_MAX_MIN_C(min, dt_float32, float, float, f32, float32x4_t, 4, - std::numeric_limits::max()); -#undef REDUCER_MAX_MIN_C -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#define REDUCER_MAX_MIN_C(_mode, _dtype, _ctype, _comp_type, _stype, __stype, \ - _num, _init) \ - template <> \ - struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ - using ctype = _ctype; \ - static constexpr int SIMD_WIDTH = _num; \ - __stype res; \ - fp16_fix_t remain; \ - _mode##Reducer(DType, size_t) { \ - res = vdupq_n_##_stype(_init); \ - remain = _init; \ - } \ - _mode##Reducer() = default; \ - void feed(const ctype* val) { \ - __stype vval = vld1q_##_stype(val); \ - res = v##_mode##q_##_stype(vval, res); \ - } \ - void feed_remain(const ctype* val) { \ - using namespace std; \ - remain = _mode(*val, static_cast<__fp16>(remain)); \ - } \ - void post(ctype* dst) { vst1q_##_stype(dst, res); } \ - void post_remain(ctype* dst) { *dst = static_cast<__fp16>(remain); } \ - } - -REDUCER_MAX_MIN_C(max, __fp16, __fp16, __fp16, f16, float16x8_t, 8, - std::numeric_limits::lowest()); -REDUCER_MAX_MIN_C(min, __fp16, __fp16, __fp16, f16, float16x8_t, 8, - std::numeric_limits::max()); -#undef REDUCER_MAX_MIN_C +REDUCER_MAX_MIN_C1( + max, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C1( + min, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::max()); +#undef REDUCER_MAX_MIN_C1 #endif -/***************************Sum Product Reducer***************************/ -template -struct SumReducer; -template -struct ProductReducer; - -#define REDUCER_SUM_PRODUCT_C1(_mode, _dtype, _ctype, _comp_type, _stype, \ - __stype, _num, _init, _act, _op) \ - template <> \ - struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ - using ctype = _ctype; \ - static constexpr int SIMD_WIDTH = _num; 
\ - __stype res; \ - ctype remain; \ - _mode##Reducer(DType, size_t) { \ - res = vdupq_n_##_stype(_init); \ - remain = _init; \ - } \ - _mode##Reducer() = default; \ - void feed(const ctype* val) { \ - __stype vval = vld1q_##_stype(val); \ - res = v##_act##q_##_stype(vval, res); \ - } \ - void feed_remain(const ctype* val) { \ - using namespace std; \ - auto op = _op(); \ - remain = op(remain, *val); \ - } \ - void post(ctype* dst) { \ - using namespace std; \ - auto val = v##_act##_##_stype(vget_low_##_stype(res), \ - vget_high_##_stype(res)); \ - auto op = _op(); \ - *dst = op(remain, op(vget_lane_##_stype(val, 0), \ - vget_lane_##_stype(val, 1))); \ - } \ - } - -REDUCER_SUM_PRODUCT_C1(Sum, dt_float32, float, float, f32, float32x4_t, 4, 0, - add, plus); -REDUCER_SUM_PRODUCT_C1(Product, dt_float32, float, float, f32, float32x4_t, 4, - 1.0f, mul, multiplies); -#undef REDUCER_SUM_PRODUCT_C1 - -#define REDUCER_SUM_PRODUCT_C(_mode, _dtype, _ctype, _comp_type, _stype, \ - __stype, _num, _init, _act, _op) \ +#define REDUCER_MAX_MIN_C( \ + _mode, _dtype, _ctype, _comp_type, _stype, __stype, _num, _init) \ template <> \ struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ using ctype = _ctype; \ @@ -572,68 +455,177 @@ REDUCER_SUM_PRODUCT_C1(Product, dt_float32, float, float, f32, float32x4_t, 4, _mode##Reducer() = default; \ void feed(const ctype* val) { \ __stype vval = vld1q_##_stype(val); \ - res = v##_act##q_##_stype(vval, res); \ + res = v##_mode##q_##_stype(vval, res); \ } \ void feed_remain(const ctype* val) { \ using namespace std; \ - auto op = _op(); \ - remain = op(remain, (*val)); \ + remain = _mode(*val, remain); \ } \ void post(ctype* dst) { vst1q_##_stype(dst, res); } \ void post_remain(ctype* dst) { *dst = remain; } \ } -REDUCER_SUM_PRODUCT_C(Sum, dt_float32, float, float, f32, float32x4_t, 4, 0, - add, plus); -REDUCER_SUM_PRODUCT_C(Product, dt_float32, float, float, f32, float32x4_t, 4, 1, - mul, multiplies); +REDUCER_MAX_MIN_C( + max, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C( + min, dt_float32, float, float, f32, float32x4_t, 4, + std::numeric_limits::max()); +#undef REDUCER_MAX_MIN_C #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -REDUCER_SUM_PRODUCT_C(Sum, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 0, add, - plus); -REDUCER_SUM_PRODUCT_C(Product, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 1, - mul, multiplies); +#define REDUCER_MAX_MIN_C( \ + _mode, _dtype, _ctype, _comp_type, _stype, __stype, _num, _init) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + fp16_fix_t remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_mode##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + remain = _mode(*val, static_cast<__fp16>(remain)); \ + } \ + void post(ctype* dst) { vst1q_##_stype(dst, res); } \ + void post_remain(ctype* dst) { *dst = static_cast<__fp16>(remain); } \ + } + +REDUCER_MAX_MIN_C( + max, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::lowest()); +REDUCER_MAX_MIN_C( + min, __fp16, __fp16, __fp16, f16, float16x8_t, 8, + std::numeric_limits::max()); +#undef REDUCER_MAX_MIN_C +#endif + +/***************************Sum Product Reducer***************************/ 
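For readers tracing the macro-generated reducers that follow: this patch only re-wraps REDUCER_SUM_PRODUCT_C1 / REDUCER_SUM_PRODUCT_C, their expansions are untouched. As a rough guide, the float Sum specialization produced by REDUCER_SUM_PRODUCT_C1(Sum, dt_float32, float, float, f32, float32x4_t, 4, 0, add, plus) is approximately the hand-written struct below; this is a sketch that assumes <arm_neon.h> and megdnn's DType are in scope, as they are in this file.

// Approximate expansion of the C == 1 (scalar-output) float Sum reducer.
// Hand-written for illustration only; the real definition comes from the macro.
template <>
struct SumReducer<dt_float32, float, float, true> {
    using ctype = float;
    static constexpr int SIMD_WIDTH = 4;
    float32x4_t res;  // vector accumulator, 4 lanes
    float remain;     // scalar accumulator for the tail elements
    SumReducer(DType, size_t) {
        res = vdupq_n_f32(0);
        remain = 0;
    }
    SumReducer() = default;
    void feed(const float* val) {  // consumes SIMD_WIDTH contiguous floats
        res = vaddq_f32(vld1q_f32(val), res);
    }
    void feed_remain(const float* val) {  // consumes one trailing float
        remain = std::plus<float>()(remain, *val);
    }
    void post(float* dst) {  // horizontal add of the 4 lanes, then merge the tail
        float32x2_t halves = vadd_f32(vget_low_f32(res), vget_high_f32(res));
        *dst = remain + (vget_lane_f32(halves, 0) + vget_lane_f32(halves, 1));
    }
};

The Product variant has the same shape, with vmulq_f32/vmul_f32, std::multiplies and an initial value of 1.0f.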
+template +struct SumReducer; +template +struct ProductReducer; + +#define REDUCER_SUM_PRODUCT_C1( \ + _mode, _dtype, _ctype, _comp_type, _stype, __stype, _num, _init, _act, _op) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + ctype remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_act##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + auto op = _op(); \ + remain = op(remain, *val); \ + } \ + void post(ctype* dst) { \ + using namespace std; \ + auto val = v##_act##_##_stype( \ + vget_low_##_stype(res), vget_high_##_stype(res)); \ + auto op = _op(); \ + *dst = \ + op(remain, \ + op(vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1))); \ + } \ + } + +REDUCER_SUM_PRODUCT_C1( + Sum, dt_float32, float, float, f32, float32x4_t, 4, 0, add, plus); +REDUCER_SUM_PRODUCT_C1( + Product, dt_float32, float, float, f32, float32x4_t, 4, 1.0f, mul, multiplies); +#undef REDUCER_SUM_PRODUCT_C1 + +#define REDUCER_SUM_PRODUCT_C( \ + _mode, _dtype, _ctype, _comp_type, _stype, __stype, _num, _init, _act, _op) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, false> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + ctype remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_act##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + auto op = _op(); \ + remain = op(remain, (*val)); \ + } \ + void post(ctype* dst) { vst1q_##_stype(dst, res); } \ + void post_remain(ctype* dst) { *dst = remain; } \ + } + +REDUCER_SUM_PRODUCT_C(Sum, dt_float32, float, float, f32, float32x4_t, 4, 0, add, plus); +REDUCER_SUM_PRODUCT_C( + Product, dt_float32, float, float, f32, float32x4_t, 4, 1, mul, multiplies); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +REDUCER_SUM_PRODUCT_C(Sum, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 0, add, plus); +REDUCER_SUM_PRODUCT_C( + Product, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 1, mul, multiplies); #endif #undef REDUCER_SUM_PRODUCT_C #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#define REDUCER_SUM_PRODUCT_C1(_mode, _dtype, _ctype, _comp_type, _stype, \ - __stype, _num, _init, _act, _op) \ - template <> \ - struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ - using ctype = _ctype; \ - static constexpr int SIMD_WIDTH = _num; \ - __stype res; \ - fp16_fix_t remain; \ - _mode##Reducer(DType, size_t) { \ - res = vdupq_n_##_stype(_init); \ - remain = _init; \ - } \ - _mode##Reducer() = default; \ - void feed(const ctype* val) { \ - __stype vval = vld1q_##_stype(val); \ - res = v##_act##q_##_stype(vval, res); \ - } \ - void feed_remain(const ctype* val) { \ - using namespace std; \ - auto op = _op(); \ - remain = op(remain, *val); \ - } \ - void post(ctype* dst) { \ - using namespace std; \ - auto val = v##_act##_##_stype(vget_low_##_stype(res), \ - vget_high_##_stype(res)); \ - auto op = _op(); \ - *dst = op(remain, op(op(vget_lane_##_stype(val, 0), \ - vget_lane_##_stype(val, 1)), \ - op(vget_lane_##_stype(val, 2), \ - vget_lane_##_stype(val, 3)))); \ - } \ - } - 
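All of these reducers, whether written out above or generated by the macros, share one four-call protocol: feed() takes SIMD_WIDTH contiguous elements, feed_remain() takes a single tail element, and post()/post_remain() emit the result. The Exec helpers near the end of this file drive them exactly this way; a minimal stand-alone sketch follows, where reduce_1d is a hypothetical name and not part of the source.

// Sketch of the C == 1 reduction loop that Exec<Reducer, true>::do_reduce
// performs for each output element; reduce_1d is a made-up helper name.
template <typename Reducer>
void reduce_1d(
        const typename Reducer::ctype* src, typename Reducer::ctype* dst,
        DType src_dtype, size_t B) {
    Reducer reducer(src_dtype, B);
    size_t b = 0;
    for (; b + Reducer::SIMD_WIDTH <= B; b += Reducer::SIMD_WIDTH) {
        reducer.feed(src + b);  // vectorized body
    }
    for (; b < B; ++b) {
        reducer.feed_remain(src + b);  // scalar tail
    }
    reducer.post(dst);  // horizontal reduce and store a single value
}

The C != 1 specialization follows the same idea per SIMD_WIDTH-wide channel block, using post()/post_remain() to write a block of results (or the leftovers) at a time.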
-REDUCER_SUM_PRODUCT_C1(Sum, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 0, add, - plus); -REDUCER_SUM_PRODUCT_C1(Product, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 1, - mul, multiplies); +#define REDUCER_SUM_PRODUCT_C1( \ + _mode, _dtype, _ctype, _comp_type, _stype, __stype, _num, _init, _act, _op) \ + template <> \ + struct _mode##Reducer<_dtype, _ctype, _comp_type, true> { \ + using ctype = _ctype; \ + static constexpr int SIMD_WIDTH = _num; \ + __stype res; \ + fp16_fix_t remain; \ + _mode##Reducer(DType, size_t) { \ + res = vdupq_n_##_stype(_init); \ + remain = _init; \ + } \ + _mode##Reducer() = default; \ + void feed(const ctype* val) { \ + __stype vval = vld1q_##_stype(val); \ + res = v##_act##q_##_stype(vval, res); \ + } \ + void feed_remain(const ctype* val) { \ + using namespace std; \ + auto op = _op(); \ + remain = op(remain, *val); \ + } \ + void post(ctype* dst) { \ + using namespace std; \ + auto val = v##_act##_##_stype( \ + vget_low_##_stype(res), vget_high_##_stype(res)); \ + auto op = _op(); \ + *dst = op( \ + remain, \ + op(op(vget_lane_##_stype(val, 0), vget_lane_##_stype(val, 1)), \ + op(vget_lane_##_stype(val, 2), vget_lane_##_stype(val, 3)))); \ + } \ + } + +REDUCER_SUM_PRODUCT_C1(Sum, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 0, add, plus); +REDUCER_SUM_PRODUCT_C1( + Product, __fp16, __fp16, __fp16, f16, float16x8_t, 8, 1, mul, multiplies); #undef REDUCER_SUM_PRODUCT_C1 #endif @@ -680,7 +672,7 @@ struct SumSqrReducer { float32x4_t res; float remain; - SumSqrReducer(DType, size_t cnt) : remain(0.0f){ + SumSqrReducer(DType, size_t cnt) : remain(0.0f) { MEGDNN_MARK_USED_VAR(cnt); res = vdupq_n_f32(0.0f); } @@ -690,12 +682,8 @@ struct SumSqrReducer { res = vaddq_f32(vmulq_f32(vval, vval), res); } void feed_remain(const float* val) { remain += (*val) * (*val); } - void post(float* dst) { - vst1q_f32(dst, res); - } - void post_remain(float* dst){ - *dst = remain; - } + void post(float* dst) { vst1q_f32(dst, res); } + void post_remain(float* dst) { *dst = remain; } }; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -709,7 +697,7 @@ struct SumSqrReducer<__fp16, __fp16, __fp16, true> { //! armv7+fp16, it may trigger result error. //! ldr instrucation need alignment of 4bytes, while __fp16 result placed in //! text segments is not satisfied. - //!FIXME: we should check it if we upgrade compiler. + //! FIXME: we should check it if we upgrade compiler. fp16_fix_t result; SumSqrReducer(DType, size_t cnt) : result(0.0f) { res = vdupq_n_f16(0.0f); } SumSqrReducer() = default; @@ -738,23 +726,17 @@ struct SumSqrReducer<__fp16, __fp16, __fp16, false> { //! armv7+fp16, it may trigger result error. //! ldr instrucation need alignment of 4bytes, while __fp16 result placed in //! text segments is not satisfied. - //!FIXME: we should check it if we upgrade compiler. + //! FIXME: we should check it if we upgrade compiler. 
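    //! Note on fp16_fix_t (see the typedef near the top of this file): on
    //! armv7 it is `typedef float fp16_fix_t;`, so this scalar remainder
    //! lives in a 4-byte-aligned slot and the misaligned-ldr problem
    //! described above cannot occur; on other targets it is presumably the
    //! plain __fp16 type. The widening only affects this tail accumulator,
    //! not the float16x8_t vector path.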
fp16_fix_t remain; - SumSqrReducer(DType, size_t cnt) : remain(0.0f){ - res = vdupq_n_f16(0.0f); - } + SumSqrReducer(DType, size_t cnt) : remain(0.0f) { res = vdupq_n_f16(0.0f); } SumSqrReducer() = default; void feed(const __fp16* val) { float16x8_t vval = vld1q_f16(val); res = vaddq_f16(vmulq_f16(vval, vval), res); } void feed_remain(const __fp16* val) { remain += (*val) * (*val); } - void post(__fp16* dst) { - vst1q_f16(dst, res); - } - void post_remain(__fp16* dst){ - *dst = remain; - } + void post(__fp16* dst) { vst1q_f16(dst, res); } + void post_remain(__fp16* dst) { *dst = remain; } }; #endif @@ -762,20 +744,20 @@ struct SumSqrReducer<__fp16, __fp16, __fp16, false> { template struct Exec { - static void do_reduce(const typename Reducer::ctype* src, - const typename Reducer::ctype* dst, DType src_dtype, - size_t A, size_t B, size_t C); + static void do_reduce( + const typename Reducer::ctype* src, const typename Reducer::ctype* dst, + DType src_dtype, size_t A, size_t B, size_t C); }; template struct Exec { - static void do_reduce(const typename Reducer::ctype* src, - typename Reducer::ctype* dst, DType src_dtype, - size_t A, size_t B, size_t) { + static void do_reduce( + const typename Reducer::ctype* src, typename Reducer::ctype* dst, + DType src_dtype, size_t A, size_t B, size_t) { size_t a = 0; for (; a < A; a++) { Reducer reducer0(src_dtype, B); - auto temp_src0 = src + a * B; + auto temp_src0 = src + a * B; size_t b = 0; for (; b + Reducer::SIMD_WIDTH <= B; b += Reducer::SIMD_WIDTH) { reducer0.feed(temp_src0); @@ -793,9 +775,9 @@ struct Exec { template struct Exec { - static void do_reduce(const typename Reducer::ctype* src, - typename Reducer::ctype* dst, DType src_dtype, - size_t A, size_t B, size_t C) { + static void do_reduce( + const typename Reducer::ctype* src, typename Reducer::ctype* dst, + DType src_dtype, size_t A, size_t B, size_t C) { for (size_t a = 0; a < A; a++) { size_t c = 0; for (; c + Reducer::SIMD_WIDTH <= C; c += Reducer::SIMD_WIDTH) { @@ -819,42 +801,38 @@ struct Exec { } // anonymous namespace -void ReduceImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void ReduceImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); size_t A, B, C; reduce::get_ABC(src.layout, A, B, C, param().axis); bool execed = false; using Mode = param::Reduce::Mode; -#define DISPATCH_FUNC(Reducer, dtype, ctype, comp_type) \ - if (C == 1) { \ - using _Reducer = Reducer; \ - std::function \ - do_reduce = Exec<_Reducer, true>::do_reduce; \ - MIDOUT_BEGIN(megdnn_arm_common_reduce, ctype, dtype, comp_type, \ - midout_iv(1)) { \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - do_reduce(reinterpret_cast(src.raw_ptr), \ - reinterpret_cast(dst.raw_ptr), src_type, \ - A, B, C)); \ - execed = true; \ - } \ - MIDOUT_END(); \ - } else { \ - using _Reducer = Reducer; \ - std::function \ - do_reduce = Exec<_Reducer, false>::do_reduce; \ - MIDOUT_BEGIN(megdnn_arm_common_reduce, ctype, dtype, comp_type, \ - midout_iv(1)) { \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - do_reduce(reinterpret_cast(src.raw_ptr), \ - reinterpret_cast(dst.raw_ptr), src_type, \ - A, B, C)); \ - execed = true; \ - } \ - MIDOUT_END(); \ +#define DISPATCH_FUNC(Reducer, dtype, ctype, comp_type) \ + if (C == 1) { \ + using _Reducer = Reducer; \ + std::function \ + do_reduce = Exec<_Reducer, true>::do_reduce; \ + MIDOUT_BEGIN( \ + megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ + 
MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ + reinterpret_cast(src.raw_ptr), \ + reinterpret_cast(dst.raw_ptr), src_type, A, B, C)); \ + execed = true; \ + } \ + MIDOUT_END(); \ + } else { \ + using _Reducer = Reducer; \ + std::function \ + do_reduce = Exec<_Reducer, false>::do_reduce; \ + MIDOUT_BEGIN( \ + megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ + reinterpret_cast(src.raw_ptr), \ + reinterpret_cast(dst.raw_ptr), src_type, A, B, C)); \ + execed = true; \ + } \ + MIDOUT_END(); \ } #define DISPATCH_MODE_QUANTIZED(dtype, ctype, comp_type) \ @@ -906,10 +884,10 @@ void ReduceImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, if (src.layout.dtype.enumv() == DTypeEnum::Quantized8Asymm) { DISPATCH_MODE_QUANTIZED(dt_quint8, uint8_t, int32_t) } - } else if (src.layout.is_contiguous() && - src.layout.dtype.category() == DTypeCategory::FLOAT && - param().data_type == param::Reduce::DataType::DEFAULT) { - + } else if ( + src.layout.is_contiguous() && + src.layout.dtype.category() == DTypeCategory::FLOAT && + param().data_type == param::Reduce::DataType::DEFAULT) { DType src_type = src.layout.dtype; if (src.layout.dtype.enumv() == DTypeEnum::Float32) { DISPATCH_MODE_FLOAT(dt_float32, float, float) diff --git a/dnn/src/arm_common/reduce/opr_impl.h b/dnn/src/arm_common/reduce/opr_impl.h index 5179872f..5b2e3bc4 100644 --- a/dnn/src/arm_common/reduce/opr_impl.h +++ b/dnn/src/arm_common/reduce/opr_impl.h @@ -19,8 +19,9 @@ class ReduceImpl : public fallback::ReduceImpl { public: using fallback::ReduceImpl::ReduceImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; }; } // namespace arm_common diff --git a/dnn/src/arm_common/resize/direct_nchwxx.cpp b/dnn/src/arm_common/resize/direct_nchwxx.cpp index 93edd1c6..a575477a 100644 --- a/dnn/src/arm_common/resize/direct_nchwxx.cpp +++ b/dnn/src/arm_common/resize/direct_nchwxx.cpp @@ -22,8 +22,9 @@ using namespace resize; namespace { template -void resize_direct_nchwxx(const ctype* sptr, ctype* dptr, size_t N, size_t IH, - size_t IW, size_t OH, size_t OW) { +void resize_direct_nchwxx( + const ctype* sptr, ctype* dptr, size_t N, size_t IH, size_t IW, size_t OH, + size_t OW) { using simd_helper = SIMDHelper; constexpr size_t PC = simd_helper::simd_width; using simd_type = typename simd_helper::simd_type; @@ -66,7 +67,7 @@ void resize_direct_nchwxx(const ctype* sptr, ctype* dptr, size_t N, size_t IH, dptr += OH * OW * PC; } } -} +} // namespace void megdnn::arm_common::resize_direct_nearest_nchw44_fp32( const ResizeImpl::KernParam& kern_param) { @@ -89,8 +90,8 @@ void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( auto sptr = reinterpret_cast(kern_param.sptr); auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>( - sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, - kern_param.iw, kern_param.oh, kern_param.ow); + sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, + kern_param.oh, kern_param.ow); } void megdnn::arm_common::resize_direct_linear_nchw88_fp16( @@ -98,8 +99,8 @@ void megdnn::arm_common::resize_direct_linear_nchw88_fp16( auto sptr = reinterpret_cast(kern_param.sptr); auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>( - sptr, dptr, kern_param.n * 
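
resize_direct_nchwxx treats each spatial position as one packed vector of PC channels (4 for NCHW44, 8 for NCHW88), so nearest-neighbour resizing reduces to copying whole SIMD-sized blocks. A scalar-indexed sketch of that idea; the round-to-nearest mapping below is an assumption (the real kernel goes through get_nearest_linear_coord), and the helper name is illustrative:

    #include <algorithm>
    #include <cstddef>
    #include <cstring>

    // Nearest-neighbour resize of one packed-channel image: every output pixel
    // is a contiguous block of PC values copied from the nearest input pixel.
    template <typename T, size_t PC>
    void resize_nearest_packed(const T* src, T* dst, size_t IH, size_t IW,
                               size_t OH, size_t OW) {
        float scale_h = float(IH) / OH, scale_w = float(IW) / OW;
        for (size_t oh = 0; oh < OH; ++oh) {
            size_t ih = std::min<size_t>(size_t(oh * scale_h + 0.5f), IH - 1);
            for (size_t ow = 0; ow < OW; ++ow) {
                size_t iw = std::min<size_t>(size_t(ow * scale_w + 0.5f), IW - 1);
                std::memcpy(dst + (oh * OW + ow) * PC,
                            src + (ih * IW + iw) * PC, PC * sizeof(T));
            }
        }
    }
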
kern_param.c / 8, kern_param.ih, - kern_param.iw, kern_param.oh, kern_param.ow); + sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, + kern_param.oh, kern_param.ow); } #endif diff --git a/dnn/src/arm_common/resize/direct_nchwxx.h b/dnn/src/arm_common/resize/direct_nchwxx.h index aec01a5c..317629e0 100644 --- a/dnn/src/arm_common/resize/direct_nchwxx.h +++ b/dnn/src/arm_common/resize/direct_nchwxx.h @@ -16,11 +16,9 @@ namespace megdnn { namespace arm_common { -void resize_direct_linear_nchw44_fp32( - const ResizeImpl::KernParam& kern_param); +void resize_direct_linear_nchw44_fp32(const ResizeImpl::KernParam& kern_param); -void resize_direct_nearest_nchw44_fp32( - const ResizeImpl::KernParam& kern_param); +void resize_direct_nearest_nchw44_fp32(const ResizeImpl::KernParam& kern_param); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/resize/helper.h b/dnn/src/arm_common/resize/helper.h index 4117024f..29eb1ad5 100644 --- a/dnn/src/arm_common/resize/helper.h +++ b/dnn/src/arm_common/resize/helper.h @@ -29,29 +29,26 @@ struct SIMDHelper { using ctype = float; static constexpr size_t simd_width = 4; - static inline simd_type load(const ctype* src_ptr) { - return vld1q_f32(src_ptr); - } + static inline simd_type load(const ctype* src_ptr) { return vld1q_f32(src_ptr); } static inline void store(ctype* dst_ptr, const simd_type& rdst) { vst1q_f32(dst_ptr, rdst); } - static inline void store2_interleave(ctype* dst_ptr, const simd_type& rdst1, - const simd_type& rdst2) { + static inline void store2_interleave( + ctype* dst_ptr, const simd_type& rdst1, const simd_type& rdst2) { simd_type_x2 rdst; rdst.val[0] = rdst1; rdst.val[1] = rdst2; vst2q_f32(dst_ptr, rdst); } - static inline simd_type fma(const simd_type& a, const simd_type& b, - ctype n) { + static inline simd_type fma(const simd_type& a, const simd_type& b, ctype n) { #if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) return vfmaq_n_f32(a, b, n); #else return vmlaq_n_f32(a, b, n); #endif } - static inline simd_type fma(const simd_type& a, const simd_type& b, - const simd_type& c) { + static inline simd_type fma( + const simd_type& a, const simd_type& b, const simd_type& c) { #if defined(__ARM_FEATURE_FMA) return vfmaq_f32(a, b, c); #else @@ -70,29 +67,26 @@ struct SIMDHelper<__fp16> { using ctype = __fp16; static constexpr size_t simd_width = 8; - static inline simd_type load(const ctype* src_ptr) { - return vld1q_f16(src_ptr); - } + static inline simd_type load(const ctype* src_ptr) { return vld1q_f16(src_ptr); } static inline void store(ctype* dst_ptr, const simd_type& rdst) { vst1q_f16(dst_ptr, rdst); } - static inline void store2_interleave(ctype* dst_ptr, const simd_type& rdst1, - const simd_type& rdst2) { + static inline void store2_interleave( + ctype* dst_ptr, const simd_type& rdst1, const simd_type& rdst2) { simd_type_x2 rdst; rdst.val[0] = rdst1; rdst.val[1] = rdst2; vst2q_f16(dst_ptr, rdst); } - static inline simd_type fma(const simd_type& a, const simd_type& b, - ctype n) { + static inline simd_type fma(const simd_type& a, const simd_type& b, ctype n) { #if defined(__ARM_FEATURE_FMA) && defined(__aarch64__) return vfmaq_n_f16(a, b, n); #else return vaddq_f16(a, vmulq_n_f16(b, n)); #endif } - static inline simd_type fma(const simd_type& a, const simd_type& b, - const simd_type& c) { + static inline simd_type fma( + const simd_type& a, const simd_type& b, const simd_type& c) { return vfmaq_f16(a, b, c); } static inline simd_type dup(float val) { return vdupq_n_f16(val); } @@ -129,6 
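
SIMDHelper::fma above papers over the FMA/non-FMA split: on AArch64 with FMA available it emits a fused vfmaq_n_f32, otherwise it falls back to vmlaq_n_f32 (separate multiply and add). A minimal sketch of the same dispatch outside the helper (assumes <arm_neon.h>; the function name is illustrative):

    #include <arm_neon.h>

    // Computes a + b * n per lane; fused on targets that guarantee FMA.
    static inline float32x4_t madd_n_f32(float32x4_t a, float32x4_t b, float n) {
    #if defined(__ARM_FEATURE_FMA) && defined(__aarch64__)
        return vfmaq_n_f32(a, b, n);   // single rounded fused multiply-add
    #else
        return vmlaq_n_f32(a, b, n);   // multiply then add, two roundings
    #endif
    }
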
+123,6 @@ static inline std::tuple get_nearest_linear_coord( return std::make_tuple(1 - alpha, origin_idx, alpha, origin_idx + 1); } -}; -}; -}; +}; // namespace resize +}; // namespace arm_common +}; // namespace megdnn diff --git a/dnn/src/arm_common/resize/opr_impl.cpp b/dnn/src/arm_common/resize/opr_impl.cpp index 3c421bec..c9225da9 100644 --- a/dnn/src/arm_common/resize/opr_impl.cpp +++ b/dnn/src/arm_common/resize/opr_impl.cpp @@ -24,12 +24,11 @@ MIDOUT_DECL(megdnn_arm_resize) namespace megdnn { namespace arm_common { -void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { +void ResizeImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); - bool is_contiguous = - src.layout.is_contiguous() && dst.layout.is_contiguous(); + bool is_contiguous = src.layout.is_contiguous() && dst.layout.is_contiguous(); bool is_dtype_same = src.layout.dtype == dst.layout.dtype; bool is_dtype_fp32 = src.layout.dtype == dtype::Float32(); bool is_dtype_fp16 = @@ -56,8 +55,7 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, bool usable = is_contiguous && is_dtype_supported && is_imode_supported; if (param().format == param::Resize::Format::NHWC && - (src.layout[3] == 1 || src.layout[3] == 3) && - is_nhwc_contig_wc(src.layout)) { + (src.layout[3] == 1 || src.layout[3] == 3) && is_nhwc_contig_wc(src.layout)) { MEGDNN_DISPATCH_CPU_KERN_OPR(resize_cv_exec(src, dst, param().imode)); } else if (!usable) { fallback::ResizeImpl::exec(src, dst, workspace); @@ -69,16 +67,14 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, if (is_imode_nearest) { MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(0)) { MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_nearest_upsample2_nchw44_fp32( - kern_param)); + resize_nearest_upsample2_nchw44_fp32(kern_param)); } MIDOUT_END(); } else { megdnn_assert(is_imode_linear, "invalid imode"); MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(1)) { MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_linear_upsample2_nchw44_fp32( - kern_param)); + resize_linear_upsample2_nchw44_fp32(kern_param)); } MIDOUT_END(); } @@ -129,16 +125,14 @@ void ResizeImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, if (is_imode_nearest) { MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(6)) { MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_nearest_upsample2_nchw88_fp16( - kern_param)); + resize_nearest_upsample2_nchw88_fp16(kern_param)); } MIDOUT_END(); } else { megdnn_assert(is_imode_linear, "invalid imode"); MIDOUT_BEGIN(megdnn_arm_resize, midout_iv(7)) { MEGDNN_DISPATCH_CPU_KERN_OPR( - resize_linear_upsample2_nchw88_fp16( - kern_param)); + resize_linear_upsample2_nchw88_fp16(kern_param)); } MIDOUT_END(); } diff --git a/dnn/src/arm_common/resize/opr_impl.h b/dnn/src/arm_common/resize/opr_impl.h index 6f223368..58983426 100644 --- a/dnn/src/arm_common/resize/opr_impl.h +++ b/dnn/src/arm_common/resize/opr_impl.h @@ -19,11 +19,11 @@ class ResizeImpl : public fallback::ResizeImpl { public: using fallback::ResizeImpl::ResizeImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, - const TensorLayout&) override { + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { return 0; } }; diff --git a/dnn/src/arm_common/resize/resize_cv.cpp 
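
get_nearest_linear_coord (whose tail is visible above) hands back the two source samples and their blend weights for one output coordinate. A sketch of the usual half-pixel-centre derivation it corresponds to; the exact offset and clamping convention here is an assumption, not copied from the patch, and the function name is illustrative:

    #include <algorithm>
    #include <cmath>
    #include <tuple>

    // For output index dst_i with scale = src_size / dst_size, return
    // (w0, i0, w1, i1) such that value = w0 * src[i0] + w1 * src[i1].
    // Assumes src_size >= 2, as the table builders in this file do.
    static std::tuple<float, int, float, int> linear_coord(
            float scale, int src_size, int dst_i) {
        float fx = (dst_i + 0.5f) * scale - 0.5f;     // half-pixel centres
        fx = std::max(fx, 0.0f);
        int i0 = std::min(int(std::floor(fx)), src_size - 2);
        float alpha = fx - i0;                        // weight of the second tap
        return {1.0f - alpha, i0, alpha, i0 + 1};
    }
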
b/dnn/src/arm_common/resize/resize_cv.cpp index cf2a4d1d..a8bc50c4 100644 --- a/dnn/src/arm_common/resize/resize_cv.cpp +++ b/dnn/src/arm_common/resize/resize_cv.cpp @@ -58,10 +58,10 @@ * * --------------------------------------------------------------------------- */ +#include "src/arm_common/resize/resize_cv.h" #include #include "src/arm_common/handle.h" #include "src/arm_common/resize/opr_impl.h" -#include "src/arm_common/resize/resize_cv.h" #include "src/common/cv/common.h" #include "src/common/cv/helper.h" #include "src/common/utils.h" @@ -184,10 +184,10 @@ void resize_nearest_32f(const Mat32f& src, Mat32f& dst) { } // linear 32f -void build_tabs_linear_32f(const Mat32f& src, const Mat32f& dst, - AlignedVector& tabsx, AlignedVector& tabsy, - AlignedVector& tabrx, - AlignedVector& tabry) { +void build_tabs_linear_32f( + const Mat32f& src, const Mat32f& dst, AlignedVector& tabsx, + AlignedVector& tabsy, AlignedVector& tabrx, + AlignedVector& tabry) { megdnn_assert(src.rows() >= 2); megdnn_assert(src.cols() >= 2); megdnn_assert(dst.rows() >= 2); @@ -226,13 +226,11 @@ void build_tabs_linear_32f(const Mat32f& src, const Mat32f& dst, } } -void calc_cache_linear_32fc1_1(const Mat32f& src, const Mat32f& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, - AlignedVector& cache1) { +void calc_cache_linear_32fc1_1( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const float* psrc1 = src.ptr(tabsx[dx] + 1); size_t dstcols = dst.cols(); @@ -249,13 +247,11 @@ void calc_cache_linear_32fc1_1(const Mat32f& src, const Mat32f& dst, } } -void calc_cache_linear_32fc1_2(const Mat32f& src, const Mat32f& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, - AlignedVector& cache1) { +void calc_cache_linear_32fc1_2( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const float* psrc0 = src.ptr(tabsx[dx] + 0); const float* psrc1 = src.ptr(tabsx[dx] + 1); @@ -267,43 +263,45 @@ void calc_cache_linear_32fc1_2(const Mat32f& src, const Mat32f& dst, float* cache1_ptr = cache1.data(); const float* tabry_ptr = tabry.data(); for (; dy + 4 <= dstcols; dy += 4) { -#define EXPAND(dy) \ - { \ - int t0 = tabsy[dy + 0]; \ - int t1 = tabsy[dy + 1]; \ - int t2 = tabsy[dy + 2]; \ - int t3 = tabsy[dy + 3]; \ - const float pcsrc00[4] = {psrc0[t0 + 0], psrc0[t1 + 0], psrc0[t2 + 0], \ - psrc0[t3 + 0]}; \ - const float pcsrc01[4] = { \ - psrc0[t0 + 1], \ - psrc0[t1 + 1], \ - psrc0[t2 + 1], \ - psrc0[t3 + 1], \ - }; \ - const float pcsrc10[4] = { \ - psrc1[t0 + 0], \ - psrc1[t1 + 0], \ - psrc1[t2 + 0], \ - psrc1[t3 + 0], \ - }; \ - const float pcsrc11[4] = { \ - psrc1[t0 + 1], \ - psrc1[t1 + 1], \ - psrc1[t2 + 1], \ - psrc1[t3 + 1], \ - }; \ - float32x4_t v_pcsrc00 = vld1q_f32(pcsrc00); \ - float32x4_t v_pcsrc01 = vld1q_f32(pcsrc01); \ - float32x4_t v_pcsrc10 = vld1q_f32(pcsrc10); \ - float32x4_t v_pcsrc11 = vld1q_f32(pcsrc11); \ - float32x4_t v_ry = vld1q_f32(tabry_ptr + dy); \ - float32x4_t v_iry = vsubq_f32(vdupq_n_f32(1.0f), v_ry); \ - vst1q_f32(cache0_ptr + dy, \ - 
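
calc_cache_linear_32fc1_* fill two cache rows with horizontally interpolated values of source rows tabsx[dx] and tabsx[dx] + 1; the _1 variant only refreshes cache1 because the previous cache1 can be reused as the new cache0 when the source row advances by exactly one. A scalar sketch of filling one such cache row (illustrative names; the real code vectorises this inner loop as in the EXPAND macro):

    #include <vector>

    // Horizontally interpolate one source row into a cache row:
    // cache[dy] = (1 - ry) * src_row[sy] + ry * src_row[sy + 1].
    static void fill_cache_row(const float* src_row, float* cache, int dstcols,
                               const std::vector<int>& tabsy,
                               const std::vector<float>& tabry) {
        for (int dy = 0; dy < dstcols; ++dy) {
            int sy = tabsy[dy];
            float ry = tabry[dy];
            cache[dy] = src_row[sy] * (1.0f - ry) + src_row[sy + 1] * ry;
        }
    }
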
vmlaq_f32(vmulq_f32(v_pcsrc01, v_ry), v_pcsrc00, v_iry)); \ - vst1q_f32(cache1_ptr + dy, \ - vmlaq_f32(vmulq_f32(v_pcsrc11, v_ry), v_pcsrc10, v_iry)); \ - } \ +#define EXPAND(dy) \ + { \ + int t0 = tabsy[dy + 0]; \ + int t1 = tabsy[dy + 1]; \ + int t2 = tabsy[dy + 2]; \ + int t3 = tabsy[dy + 3]; \ + const float pcsrc00[4] = { \ + psrc0[t0 + 0], psrc0[t1 + 0], psrc0[t2 + 0], psrc0[t3 + 0]}; \ + const float pcsrc01[4] = { \ + psrc0[t0 + 1], \ + psrc0[t1 + 1], \ + psrc0[t2 + 1], \ + psrc0[t3 + 1], \ + }; \ + const float pcsrc10[4] = { \ + psrc1[t0 + 0], \ + psrc1[t1 + 0], \ + psrc1[t2 + 0], \ + psrc1[t3 + 0], \ + }; \ + const float pcsrc11[4] = { \ + psrc1[t0 + 1], \ + psrc1[t1 + 1], \ + psrc1[t2 + 1], \ + psrc1[t3 + 1], \ + }; \ + float32x4_t v_pcsrc00 = vld1q_f32(pcsrc00); \ + float32x4_t v_pcsrc01 = vld1q_f32(pcsrc01); \ + float32x4_t v_pcsrc10 = vld1q_f32(pcsrc10); \ + float32x4_t v_pcsrc11 = vld1q_f32(pcsrc11); \ + float32x4_t v_ry = vld1q_f32(tabry_ptr + dy); \ + float32x4_t v_iry = vsubq_f32(vdupq_n_f32(1.0f), v_ry); \ + vst1q_f32( \ + cache0_ptr + dy, \ + vmlaq_f32(vmulq_f32(v_pcsrc01, v_ry), v_pcsrc00, v_iry)); \ + vst1q_f32( \ + cache1_ptr + dy, \ + vmlaq_f32(vmulq_f32(v_pcsrc11, v_ry), v_pcsrc10, v_iry)); \ + } \ while (0) EXPAND(dy); @@ -321,13 +319,11 @@ void calc_cache_linear_32fc1_2(const Mat32f& src, const Mat32f& dst, } } -void calc_cache_linear_32fc3_1(const Mat32f& src, const Mat32f& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, - AlignedVector& cache1) { +void calc_cache_linear_32fc3_1( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const float* psrc1 = src.ptr(tabsx[dx] + 1); const size_t dstcols = dst.cols(); @@ -346,13 +342,11 @@ void calc_cache_linear_32fc3_1(const Mat32f& src, const Mat32f& dst, } } -void calc_cache_linear_32fc3_2(const Mat32f& src, const Mat32f& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, - AlignedVector& cache1) { +void calc_cache_linear_32fc3_2( + const Mat32f& src, const Mat32f& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const float* psrc0 = src.ptr(tabsx[dx] + 0); const float* psrc1 = src.ptr(tabsx[dx] + 1); @@ -388,11 +382,11 @@ void resize_linear_32f_neon(const Mat32f& src, Mat32f& dst) { for (int dx = 0; dx < dstrows; ++dx) { if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { - calc_cache_linear_32fc1_1(src, dst, tabsx, tabsy, tabrx, - tabry, dx, cache0, cache1); + calc_cache_linear_32fc1_1( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } else { - calc_cache_linear_32fc1_2(src, dst, tabsx, tabsy, tabrx, - tabry, dx, cache0, cache1); + calc_cache_linear_32fc1_2( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } } const float* cache0_ptr = cache0.data(); @@ -404,8 +398,7 @@ void resize_linear_32f_neon(const Mat32f& src, Mat32f& dst) { #define EXPAND(x) \ v_cache0 = vld1q_f32(cache0_ptr + dy + x); \ v_cache1 = vld1q_f32(cache1_ptr + dy + x); \ - vst1q_f32(pdst + dy + x, \ - vmlaq_f32(vmulq_f32(v_rx, v_cache1), v_irx, v_cache0)); + 
vst1q_f32(pdst + dy + x, vmlaq_f32(vmulq_f32(v_rx, v_cache1), v_irx, v_cache0)); float32x4_t v_rx = vdupq_n_f32(rx); float32x4_t v_irx = vdupq_n_f32(irx); for (; dy + 8 <= dstcols; dy += 8) { @@ -433,11 +426,11 @@ void resize_linear_32f_neon(const Mat32f& src, Mat32f& dst) { for (int dx = 0; dx < dstrows; ++dx) { if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { - calc_cache_linear_32fc3_1(src, dst, tabsx, tabsy, tabrx, - tabry, dx, cache0, cache1); + calc_cache_linear_32fc3_1( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } else { - calc_cache_linear_32fc3_2(src, dst, tabsx, tabsy, tabrx, - tabry, dx, cache0, cache1); + calc_cache_linear_32fc3_2( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } } const float* cache0_ptr = cache0.data(); @@ -448,15 +441,15 @@ void resize_linear_32f_neon(const Mat32f& src, Mat32f& dst) { int dy = 0; float32x4_t v_rx = vdupq_n_f32(rx); float32x4_t v_irx = vdupq_n_f32(irx); -#define EXPAND(x) \ - v_cache0 = vld3q_f32(cache0_ptr + dy + (x)*3); \ - v_cache1 = vld3q_f32(cache1_ptr + dy + (x)*3); \ - v_dst.val[0] = vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[0]), v_irx, \ - v_cache0.val[0]); \ - v_dst.val[1] = vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[1]), v_irx, \ - v_cache0.val[1]); \ - v_dst.val[2] = vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[2]), v_irx, \ - v_cache0.val[2]); \ +#define EXPAND(x) \ + v_cache0 = vld3q_f32(cache0_ptr + dy + (x)*3); \ + v_cache1 = vld3q_f32(cache1_ptr + dy + (x)*3); \ + v_dst.val[0] = \ + vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[0]), v_irx, v_cache0.val[0]); \ + v_dst.val[1] = \ + vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[1]), v_irx, v_cache0.val[1]); \ + v_dst.val[2] = \ + vmlaq_f32(vmulq_f32(v_rx, v_cache1.val[2]), v_irx, v_cache0.val[2]); \ vst3q_f32(pdst + dy + (x)*3, v_dst); for (; dy + 8 * 3 <= dstcols; dy += 8 * 3) { float32x4x3_t v_cache0; @@ -494,10 +487,10 @@ void resize_linear_32f(const Mat32f& src, Mat32f& dst) { } // linear 8u -void build_tabs_linear_8u(const Mat8u& src, const Mat8u& dst, - AlignedVector& tabsx, AlignedVector& tabsy, - AlignedVector& tabrx, - AlignedVector& tabry) { +void build_tabs_linear_8u( + const Mat8u& src, const Mat8u& dst, AlignedVector& tabsx, + AlignedVector& tabsy, AlignedVector& tabrx, + AlignedVector& tabry) { megdnn_assert(src.rows() >= 2); megdnn_assert(src.cols() >= 2); megdnn_assert(dst.rows() >= 2); @@ -536,12 +529,11 @@ void build_tabs_linear_8u(const Mat8u& src, const Mat8u& dst, } } -void calc_cache_8uc1_1(const Mat8u& src, const Mat8u& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, AlignedVector& cache1) { +void calc_cache_8uc1_1( + const Mat8u& src, const Mat8u& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const uchar* psrc1 = src.ptr(tabsx[dx] + 1); size_t dstcols = dst.cols(); @@ -558,12 +550,11 @@ void calc_cache_8uc1_1(const Mat8u& src, const Mat8u& dst, } } -void calc_cache_8uc1_2(const Mat8u& src, const Mat8u& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, AlignedVector& cache1) { +void calc_cache_8uc1_2( + const Mat8u& src, const Mat8u& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, 
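
The 8-bit path replaces the float weights with fixed-point ones: each fractional weight is scaled by 1 << SCALE when the tables are built, both interpolation passes multiply by such weights, and the final shift by SCALE + SCALE removes both scale factors at once. A scalar sketch of one output pixel; the value of SCALE below is an assumption mirroring the table-building constant in this file, and the function name is illustrative:

    #include <cstdint>

    constexpr int SCALE = 11;  // assumed fixed-point precision of the weight tables

    // Bilinear blend of four 8-bit samples with fixed-point weights.
    static uint8_t blend_fixedpt(uint8_t s00, uint8_t s01, uint8_t s10,
                                 uint8_t s11, float frow, float fcol) {
        int ry = int(fcol * (1 << SCALE)), iry = (1 << SCALE) - ry;
        int rx = int(frow * (1 << SCALE)), irx = (1 << SCALE) - rx;
        int cache0 = s00 * iry + s01 * ry;   // horizontal blend of the upper row
        int cache1 = s10 * iry + s11 * ry;   // horizontal blend of the lower row
        return uint8_t((rx * cache1 + irx * cache0) >> (SCALE + SCALE));
    }
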
int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const uchar* psrc0 = src.ptr(tabsx[dx] + 0); const uchar* psrc1 = src.ptr(tabsx[dx] + 1); @@ -583,12 +574,11 @@ void calc_cache_8uc1_2(const Mat8u& src, const Mat8u& dst, } } -void calc_cache_8uc3_1(const Mat8u& src, const Mat8u& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, AlignedVector& cache1) { +void calc_cache_8uc3_1( + const Mat8u& src, const Mat8u& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const uchar* psrc1 = src.ptr(tabsx[dx] + 1); size_t dstcols = dst.cols(); @@ -607,12 +597,11 @@ void calc_cache_8uc3_1(const Mat8u& src, const Mat8u& dst, } } -void calc_cache_8uc3_2(const Mat8u& src, const Mat8u& dst, - const AlignedVector& tabsx, - const AlignedVector& tabsy, - const AlignedVector& tabrx, - const AlignedVector& tabry, int dx, - AlignedVector& cache0, AlignedVector& cache1) { +void calc_cache_8uc3_2( + const Mat8u& src, const Mat8u& dst, const AlignedVector& tabsx, + const AlignedVector& tabsy, const AlignedVector& tabrx, + const AlignedVector& tabry, int dx, AlignedVector& cache0, + AlignedVector& cache1) { (void)tabrx; const uchar* psrc0 = src.ptr(tabsx[dx] + 0); const uchar* psrc1 = src.ptr(tabsx[dx] + 1); @@ -650,11 +639,11 @@ void resize_linear_8u_neon(const Mat8u& src, Mat8u& dst) { for (int dx = 0; dx < dstrows; ++dx) { if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { - calc_cache_8uc1_1(src, dst, tabsx, tabsy, tabrx, tabry, dx, - cache0, cache1); + calc_cache_8uc1_1( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } else { - calc_cache_8uc1_2(src, dst, tabsx, tabsy, tabrx, tabry, dx, - cache0, cache1); + calc_cache_8uc1_2( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } } int rx = tabrx[dx]; @@ -687,18 +676,14 @@ void resize_linear_8u_neon(const Mat8u& src, Mat8u& dst) { v_cache1_c = vld1q_s32(cache1_ptr + dy + 0xc); int16x4_t v_ans0, v_ans4, v_ans8, v_ansc; - v_ans0 = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_0), - v_irx, v_cache0_0), - 16); - v_ans4 = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_4), - v_irx, v_cache0_4), - 16); - v_ans8 = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_8), - v_irx, v_cache0_8), - 16); - v_ansc = vqshrn_n_s32(vmlaq_s32(vmulq_s32(v_rx, v_cache1_c), - v_irx, v_cache0_c), - 16); + v_ans0 = vqshrn_n_s32( + vmlaq_s32(vmulq_s32(v_rx, v_cache1_0), v_irx, v_cache0_0), 16); + v_ans4 = vqshrn_n_s32( + vmlaq_s32(vmulq_s32(v_rx, v_cache1_4), v_irx, v_cache0_4), 16); + v_ans8 = vqshrn_n_s32( + vmlaq_s32(vmulq_s32(v_rx, v_cache1_8), v_irx, v_cache0_8), 16); + v_ansc = vqshrn_n_s32( + vmlaq_s32(vmulq_s32(v_rx, v_cache1_c), v_irx, v_cache0_c), 16); int16x8_t v_half16_0, v_half16_1; v_half16_0 = vcombine_s16(v_ans0, v_ans4); // x0 @@ -713,8 +698,7 @@ void resize_linear_8u_neon(const Mat8u& src, Mat8u& dst) { for (; dy < dstcols; ++dy) { uchar* pcdst = pdst + dy; - pcdst[0] = - (rx * cache1[dy] + irx * cache0[dy]) >> (SCALE + SCALE); + pcdst[0] = (rx * cache1[dy] + irx * cache0[dy]) >> (SCALE + SCALE); } } } else if (src.channels() == 3) { @@ -724,11 +708,11 @@ void resize_linear_8u_neon(const Mat8u& src, Mat8u& dst) { for (int dx = 0; dx < dstrows; ++dx) { if (dx == 0 || tabsx[dx] != tabsx[dx - 1]) { if (dx > 0 && tabsx[dx] == tabsx[dx - 1] + 1) { 
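
The vectorised 8-bit store in that loop ends with a saturating narrowing chain: the 32-bit accumulators are shifted down, narrowed to 16 bits, and finally converted to unsigned 8 bits with vqmovun_s16 so out-of-range results clamp instead of wrapping. A minimal sketch of the final stage (assumes <arm_neon.h>; the rounding shift below is equivalent to the add-2-then-shift-by-2 used in the kernel):

    #include <arm_neon.h>

    // Pack 16 int16 results into 16 uint8s with saturation (<0 -> 0, >255 -> 255).
    static inline uint8x16_t pack_sat_u8(int16x8_t lo, int16x8_t hi) {
        return vcombine_u8(vqmovun_s16(lo), vqmovun_s16(hi));
    }

    // Divide by 4 with rounding before packing, as the resize kernel does.
    static inline uint8x16_t round_shift_pack(int16x8_t lo, int16x8_t hi) {
        return pack_sat_u8(vrshrq_n_s16(lo, 2), vrshrq_n_s16(hi, 2));
    }
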
- calc_cache_8uc3_1(src, dst, tabsx, tabsy, tabrx, tabry, dx, - cache0, cache1); + calc_cache_8uc3_1( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } else { - calc_cache_8uc3_2(src, dst, tabsx, tabsy, tabrx, tabry, dx, - cache0, cache1); + calc_cache_8uc3_2( + src, dst, tabsx, tabsy, tabrx, tabry, dx, cache0, cache1); } } int rx = tabrx[dx]; @@ -738,12 +722,12 @@ void resize_linear_8u_neon(const Mat8u& src, Mat8u& dst) { for (; dy < dstcols; dy += 3) { uchar* pcdst = pdst + dy; - pcdst[0] = (rx * cache1[dy + 0] + irx * cache0[dy + 0]) >> - (SCALE + SCALE); - pcdst[1] = (rx * cache1[dy + 1] + irx * cache0[dy + 1]) >> - (SCALE + SCALE); - pcdst[2] = (rx * cache1[dy + 2] + irx * cache0[dy + 2]) >> - (SCALE + SCALE); + pcdst[0] = + (rx * cache1[dy + 0] + irx * cache0[dy + 0]) >> (SCALE + SCALE); + pcdst[1] = + (rx * cache1[dy + 1] + irx * cache0[dy + 1]) >> (SCALE + SCALE); + pcdst[2] = + (rx * cache1[dy + 2] + irx * cache0[dy + 2]) >> (SCALE + SCALE); } } } else { @@ -759,8 +743,9 @@ const int INTER_RESIZE_COEF_BITS = 11; const int INTER_RESIZE_COEF_SCALE = 1 << INTER_RESIZE_COEF_BITS; const float MEGCV_PI = acos(-1); struct HResizeNoVec { - int operator()(const uchar**, uchar**, int, const int*, const uchar*, int, - int, int, int, int) const { + int operator()( + const uchar**, uchar**, int, const int*, const uchar*, int, int, int, int, + int) const { return 0; } }; @@ -775,8 +760,8 @@ struct ResizeAreaFastNoVec { }; struct VResizeCubicVec_32f { - int operator()(const uchar** _src, uchar* _dst, const uchar* _beta, - int width) const { + int operator()( + const uchar** _src, uchar* _dst, const uchar* _beta, int width) const { const float** src = (const float**)_src; const float* beta = (const float*)_beta; const float *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3]; @@ -787,20 +772,22 @@ struct VResizeCubicVec_32f { for (; x <= width - 8; x += 8) { vst1q_f32( - dst + x, - vmlaq_f32(vmlaq_f32(vmlaq_f32(vmulq_f32(v_b0, - vld1q_f32(S0 + x)), - v_b1, vld1q_f32(S1 + x)), - v_b2, vld1q_f32(S2 + x)), - v_b3, vld1q_f32(S3 + x))); + dst + x, vmlaq_f32( + vmlaq_f32( + vmlaq_f32( + vmulq_f32(v_b0, vld1q_f32(S0 + x)), + v_b1, vld1q_f32(S1 + x)), + v_b2, vld1q_f32(S2 + x)), + v_b3, vld1q_f32(S3 + x))); vst1q_f32( dst + x + 4, - vmlaq_f32(vmlaq_f32(vmlaq_f32(vmulq_f32(v_b0, - vld1q_f32(S0 + x + - 4)), - v_b1, vld1q_f32(S1 + x + 4)), - v_b2, vld1q_f32(S2 + x + 4)), - v_b3, vld1q_f32(S3 + x + 4))); + vmlaq_f32( + vmlaq_f32( + vmlaq_f32( + vmulq_f32(v_b0, vld1q_f32(S0 + x + 4)), + v_b1, vld1q_f32(S1 + x + 4)), + v_b2, vld1q_f32(S2 + x + 4)), + v_b3, vld1q_f32(S3 + x + 4))); } return x; @@ -808,8 +795,8 @@ struct VResizeCubicVec_32f { }; struct VResizeLanczos4Vec_32f { - int operator()(const uchar** _src, uchar* _dst, const uchar* _beta, - int width) const { + int operator()( + const uchar** _src, uchar* _dst, const uchar* _beta, int width) const { const float** src = (const float**)_src; const float* beta = (const float*)_beta; const float *S0 = src[0], *S1 = src[1], *S2 = src[2], *S3 = src[3], @@ -823,14 +810,18 @@ struct VResizeLanczos4Vec_32f { for (; x <= width - 4; x += 4) { float32x4_t v_dst0 = vmlaq_f32( - vmlaq_f32(vmlaq_f32(vmulq_f32(v_b0, vld1q_f32(S0 + x)), - v_b1, vld1q_f32(S1 + x)), - v_b2, vld1q_f32(S2 + x)), + vmlaq_f32( + vmlaq_f32( + vmulq_f32(v_b0, vld1q_f32(S0 + x)), v_b1, + vld1q_f32(S1 + x)), + v_b2, vld1q_f32(S2 + x)), v_b3, vld1q_f32(S3 + x)); float32x4_t v_dst1 = vmlaq_f32( - vmlaq_f32(vmlaq_f32(vmulq_f32(v_b4, vld1q_f32(S4 + x)), - v_b5, vld1q_f32(S5 + x)), - 
v_b6, vld1q_f32(S6 + x)), + vmlaq_f32( + vmlaq_f32( + vmulq_f32(v_b4, vld1q_f32(S4 + x)), v_b5, + vld1q_f32(S5 + x)), + v_b6, vld1q_f32(S6 + x)), v_b7, vld1q_f32(S7 + x)); vst1q_f32(dst + x, vaddq_f32(v_dst0, v_dst1)); } @@ -839,8 +830,8 @@ struct VResizeLanczos4Vec_32f { } }; struct VResizeLinearVec_32f { - int operator()(const uchar** _src, uchar* _dst, const uchar* _beta, - int width) const { + int operator()( + const uchar** _src, uchar* _dst, const uchar* _beta, int width) const { const float** src = (const float**)_src; const float* beta = (const float*)_beta; const float *S0 = src[0], *S1 = src[1]; @@ -850,23 +841,19 @@ struct VResizeLinearVec_32f { float32x4_t v_b0 = vdupq_n_f32(beta[0]), v_b1 = vdupq_n_f32(beta[1]); for (; x <= width - 8; x += 8) { - float32x4_t v_src00 = vld1q_f32(S0 + x), - v_src01 = vld1q_f32(S0 + x + 4); - float32x4_t v_src10 = vld1q_f32(S1 + x), - v_src11 = vld1q_f32(S1 + x + 4); - - vst1q_f32(dst + x, - vmlaq_f32(vmulq_f32(v_src00, v_b0), v_src10, v_b1)); - vst1q_f32(dst + x + 4, - vmlaq_f32(vmulq_f32(v_src01, v_b0), v_src11, v_b1)); + float32x4_t v_src00 = vld1q_f32(S0 + x), v_src01 = vld1q_f32(S0 + x + 4); + float32x4_t v_src10 = vld1q_f32(S1 + x), v_src11 = vld1q_f32(S1 + x + 4); + + vst1q_f32(dst + x, vmlaq_f32(vmulq_f32(v_src00, v_b0), v_src10, v_b1)); + vst1q_f32(dst + x + 4, vmlaq_f32(vmulq_f32(v_src01, v_b0), v_src11, v_b1)); } return x; } }; struct VResizeLinearVec_32s8u { - int operator()(const uchar** _src, uchar* dst, const uchar* _beta, - int width) const { + int operator()( + const uchar** _src, uchar* dst, const uchar* _beta, int width) const { const int **src = (const int**)_src, *S0 = src[0], *S1 = src[1]; const short* beta = (const short*)_beta; int x = 0; @@ -879,14 +866,12 @@ struct VResizeLinearVec_32s8u { int32x4_t v_src01 = vshrq_n_s32(vld1q_s32(S0 + x + 4), 4), v_src11 = vshrq_n_s32(vld1q_s32(S1 + x + 4), 4); - int16x8_t v_src0 = - vcombine_s16(vmovn_s32(v_src00), vmovn_s32(v_src01)); - int16x8_t v_src1 = - vcombine_s16(vmovn_s32(v_src10), vmovn_s32(v_src11)); + int16x8_t v_src0 = vcombine_s16(vmovn_s32(v_src00), vmovn_s32(v_src01)); + int16x8_t v_src1 = vcombine_s16(vmovn_s32(v_src10), vmovn_s32(v_src11)); - int16x8_t v_dst0 = - vaddq_s16(vshrq_n_s16(vqdmulhq_s16(v_src0, v_b0), 1), - vshrq_n_s16(vqdmulhq_s16(v_src1, v_b1), 1)); + int16x8_t v_dst0 = vaddq_s16( + vshrq_n_s16(vqdmulhq_s16(v_src0, v_b0), 1), + vshrq_n_s16(vqdmulhq_s16(v_src1, v_b1), 1)); v_dst0 = vshrq_n_s16(vaddq_s16(v_dst0, v_delta), 2); v_src00 = vshrq_n_s32(vld1q_s32(S0 + x + 8), 4); @@ -897,13 +882,12 @@ struct VResizeLinearVec_32s8u { v_src0 = vcombine_s16(vmovn_s32(v_src00), vmovn_s32(v_src01)); v_src1 = vcombine_s16(vmovn_s32(v_src10), vmovn_s32(v_src11)); - int16x8_t v_dst1 = - vaddq_s16(vshrq_n_s16(vqdmulhq_s16(v_src0, v_b0), 1), - vshrq_n_s16(vqdmulhq_s16(v_src1, v_b1), 1)); + int16x8_t v_dst1 = vaddq_s16( + vshrq_n_s16(vqdmulhq_s16(v_src0, v_b0), 1), + vshrq_n_s16(vqdmulhq_s16(v_src1, v_b1), 1)); v_dst1 = vshrq_n_s16(vaddq_s16(v_dst1, v_delta), 2); - vst1q_u8(dst + x, - vcombine_u8(vqmovun_s16(v_dst0), vqmovun_s16(v_dst1))); + vst1q_u8(dst + x, vcombine_u8(vqmovun_s16(v_dst0), vqmovun_s16(v_dst1))); } return x; @@ -929,18 +913,20 @@ public: for (; dx <= w - 16; dx += 16, S0 += 32, S1 += 32, D += 16) { uint8x16x2_t v_row0 = vld2q_u8(S0), v_row1 = vld2q_u8(S1); - uint16x8_t v_dst0 = vaddl_u8(vget_low_u8(v_row0.val[0]), - vget_low_u8(v_row0.val[1])); - v_dst0 = - vaddq_u16(v_dst0, vaddl_u8(vget_low_u8(v_row1.val[0]), - vget_low_u8(v_row1.val[1]))); + uint16x8_t 
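
ResizeAreaFastVec_SIMD_8u handles the common 2x downscale: vld2q_u8 splits even and odd columns, widening adds accumulate the 2x2 block in 16 bits, and the +2 >> 2 produces a rounded average. A sketch of one 16-pixel step for a single channel (assumes <arm_neon.h>; the rounding shift is equivalent to the explicit add of v_2 in the kernel):

    #include <arm_neon.h>

    // Average 2x2 blocks of uint8: reads 32 bytes from each of two source rows,
    // writes 16 averaged bytes to the destination row.
    static inline void avg2x2_u8x16(const uint8_t* row0, const uint8_t* row1,
                                    uint8_t* dst) {
        uint8x16x2_t r0 = vld2q_u8(row0);  // val[0] = even cols, val[1] = odd cols
        uint8x16x2_t r1 = vld2q_u8(row1);

        uint16x8_t lo = vaddl_u8(vget_low_u8(r0.val[0]), vget_low_u8(r0.val[1]));
        lo = vaddq_u16(lo, vaddl_u8(vget_low_u8(r1.val[0]), vget_low_u8(r1.val[1])));
        uint16x8_t hi = vaddl_u8(vget_high_u8(r0.val[0]), vget_high_u8(r0.val[1]));
        hi = vaddq_u16(hi, vaddl_u8(vget_high_u8(r1.val[0]), vget_high_u8(r1.val[1])));

        // (sum + 2) >> 2 == rounded mean of the four samples.
        vst1q_u8(dst, vcombine_u8(vmovn_u16(vrshrq_n_u16(lo, 2)),
                                  vmovn_u16(vrshrq_n_u16(hi, 2))));
    }
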
v_dst0 = vaddl_u8( + vget_low_u8(v_row0.val[0]), vget_low_u8(v_row0.val[1])); + v_dst0 = vaddq_u16( + v_dst0, vaddl_u8( + vget_low_u8(v_row1.val[0]), + vget_low_u8(v_row1.val[1]))); v_dst0 = vshrq_n_u16(vaddq_u16(v_dst0, v_2), 2); - uint16x8_t v_dst1 = vaddl_u8(vget_high_u8(v_row0.val[0]), - vget_high_u8(v_row0.val[1])); - v_dst1 = vaddq_u16(v_dst1, - vaddl_u8(vget_high_u8(v_row1.val[0]), - vget_high_u8(v_row1.val[1]))); + uint16x8_t v_dst1 = vaddl_u8( + vget_high_u8(v_row0.val[0]), vget_high_u8(v_row0.val[1])); + v_dst1 = vaddq_u16( + v_dst1, vaddl_u8( + vget_high_u8(v_row1.val[0]), + vget_high_u8(v_row1.val[1]))); v_dst1 = vshrq_n_u16(vaddq_u16(v_dst1, v_2), 2); vst1q_u8(D, vcombine_u8(vmovn_u16(v_dst0), vmovn_u16(v_dst1))); @@ -961,8 +947,7 @@ struct ResizeAreaFastVec_SIMD_32f { scale_y(_scale_y), cn(_cn), step(_step * sizeof(float)) { - fast_mode = - scale_x == 2 && scale_y == 2 && (cn == 1 || cn == 3 || cn == 4); + fast_mode = scale_x == 2 && scale_y == 2 && (cn == 1 || cn == 3 || cn == 4); } int operator()(const float* S, float* D, int w) const { @@ -1000,18 +985,17 @@ struct DecimateAlpha { float alpha; }; template -using ResizeFunc = void (*)(const Mat& src, Mat& dst, const int* xofs, - const void* alpha, const int* yofs, - const void* beta, int xmin, int xmax, int ksize); +using ResizeFunc = void (*)( + const Mat& src, Mat& dst, const int* xofs, const void* alpha, + const int* yofs, const void* beta, int xmin, int xmax, int ksize); template -using ResizeAreaFastFunc = void (*)(const Mat& src, Mat& dst, - const int* ofs, const int* xofs, - int scale_x, int scale_y); +using ResizeAreaFastFunc = void (*)( + const Mat& src, Mat& dst, const int* ofs, const int* xofs, int scale_x, + int scale_y); template -using ResizeAreaFunc = void (*)(const Mat& src, Mat& dst, - const DecimateAlpha* xtab, int xtab_size, - const DecimateAlpha* ytab, int ytab_size, - const int* yofs); +using ResizeAreaFunc = void (*)( + const Mat& src, Mat& dst, const DecimateAlpha* xtab, int xtab_size, + const DecimateAlpha* ytab, int ytab_size, const int* yofs); static inline void interpolate_cubic(float x, float* coeffs) { const float A = -0.75f; @@ -1052,9 +1036,9 @@ struct HResizeLanczos4 { typedef WT buf_type; typedef AT alpha_type; - void operator()(const T** src, WT** dst, int count, const int* xofs, - const AT* alpha, int swidth, int dwidth, int cn, int xmin, - int xmax) const { + void operator()( + const T** src, WT** dst, int count, const int* xofs, const AT* alpha, + int swidth, int dwidth, int cn, int xmin, int xmax) const { for (int k = 0; k < count; k++) { const T* S = src[k]; WT* D = dst[k]; @@ -1080,13 +1064,10 @@ struct HResizeLanczos4 { break; for (; dx < xmax; dx++, alpha += 8) { int sx = xofs[dx]; - D[dx] = S[sx - 1 * 3] * alpha[0] + - S[sx - 1 * 2] * alpha[1] + + D[dx] = S[sx - 1 * 3] * alpha[0] + S[sx - 1 * 2] * alpha[1] + S[sx - 1] * alpha[2] + S[sx] * alpha[3] + - S[sx + 1] * alpha[4] + - S[sx + 1 * 2] * alpha[5] + - S[sx + 1 * 3] * alpha[6] + - S[sx + 1 * 4] * alpha[7]; + S[sx + 1] * alpha[4] + S[sx + 1 * 2] * alpha[5] + + S[sx + 1 * 3] * alpha[6] + S[sx + 1 * 4] * alpha[7]; } limit = dwidth; } @@ -1112,13 +1093,10 @@ struct HResizeLanczos4 { break; for (; dx < xmax; dx++, alpha += 8) { int sx = xofs[dx]; - D[dx] = S[sx - 3 * 3] * alpha[0] + - S[sx - 3 * 2] * alpha[1] + + D[dx] = S[sx - 3 * 3] * alpha[0] + S[sx - 3 * 2] * alpha[1] + S[sx - 3] * alpha[2] + S[sx] * alpha[3] + - S[sx + 3] * alpha[4] + - S[sx + 3 * 2] * alpha[5] + - S[sx + 3 * 3] * alpha[6] + - S[sx + 3 * 4] * alpha[7]; + S[sx + 
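
interpolate_cubic (shown truncated above) fills the four weights of the Keys bicubic kernel with A = -0.75, evaluated at distances 1+x, x, 1-x and 2-x from the interpolation point. The closed form below is a reconstruction from that kernel definition, not copied from the patch, so the exact arrangement in the file may differ:

    // Keys cubic kernel, A = -0.75:
    //   w(t) = (A+2)|t|^3 - (A+3)|t|^2 + 1          for |t| <= 1
    //   w(t) = A|t|^3 - 5A|t|^2 + 8A|t| - 4A        for 1 < |t| < 2
    static void cubic_coeffs(float x, float* coeffs) {
        const float A = -0.75f;
        coeffs[0] = ((A * (x + 1) - 5 * A) * (x + 1) + 8 * A) * (x + 1) - 4 * A;
        coeffs[1] = ((A + 2) * x - (A + 3)) * x * x + 1;
        coeffs[2] = ((A + 2) * (1 - x) - (A + 3)) * (1 - x) * (1 - x) + 1;
        coeffs[3] = 1.0f - coeffs[0] - coeffs[1] - coeffs[2];  // weights sum to 1
    }
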
3] * alpha[4] + S[sx + 3 * 2] * alpha[5] + + S[sx + 3 * 3] * alpha[6] + S[sx + 3 * 4] * alpha[7]; } limit = dwidth; } @@ -1133,14 +1111,15 @@ struct HResizeLinear { typedef WT buf_type; typedef AT alpha_type; - void operator()(const T** src, WT** dst, int count, const int* xofs, - const AT* alpha, int swidth, int dwidth, int cn, int xmin, - int xmax) const { + void operator()( + const T** src, WT** dst, int count, const int* xofs, const AT* alpha, + int swidth, int dwidth, int cn, int xmin, int xmax) const { int dx, k; VecOp vecOp; - int dx0 = vecOp((const uchar**)src, (uchar**)dst, count, xofs, - (const uchar*)alpha, swidth, dwidth, cn, xmin, xmax); + int dx0 = + vecOp((const uchar**)src, (uchar**)dst, count, xofs, + (const uchar*)alpha, swidth, dwidth, cn, xmin, xmax); for (k = 0; k <= count - 2; k++) { const T *S0 = src[k], *S1 = src[k + 1]; @@ -1180,9 +1159,9 @@ struct HResizeCubic { typedef WT buf_type; typedef AT alpha_type; - void operator()(const T** src, WT** dst, int count, const int* xofs, - const AT* alpha, int swidth, int dwidth, int cn, int xmin, - int xmax) const { + void operator()( + const T** src, WT** dst, int count, const int* xofs, const AT* alpha, + int swidth, int dwidth, int cn, int xmin, int xmax) const { for (int k = 0; k < count; k++) { const T* S = src[k]; WT* D = dst[k]; @@ -1255,14 +1234,12 @@ struct VResizeLanczos4 { void operator()(const WT** src, T* dst, const AT* beta, int width) const { CastOp castOp; VecOp vecOp; - int k, x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, - width); + int k, x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, width); #if MEGCV_ENABLE_UNROLLED for (; x <= width - 4; x += 4) { WT b = beta[0]; const WT* S = src[0]; - WT s0 = S[x] * b, s1 = S[x + 1] * b, s2 = S[x + 2] * b, - s3 = S[x + 3] * b; + WT s0 = S[x] * b, s1 = S[x + 1] * b, s2 = S[x + 2] * b, s3 = S[x + 3] * b; for (k = 1; k < 8; k++) { b = beta[k]; @@ -1281,10 +1258,10 @@ struct VResizeLanczos4 { #endif for (; x < width; x++) { - dst[x] = castOp(src[0][x] * beta[0] + src[1][x] * beta[1] + - src[2][x] * beta[2] + src[3][x] * beta[3] + - src[4][x] * beta[4] + src[5][x] * beta[5] + - src[6][x] * beta[6] + src[7][x] * beta[7]); + dst[x] = castOp( + src[0][x] * beta[0] + src[1][x] * beta[1] + src[2][x] * beta[2] + + src[3][x] * beta[3] + src[4][x] * beta[4] + src[5][x] * beta[5] + + src[6][x] * beta[6] + src[7][x] * beta[7]); } } }; @@ -1299,8 +1276,7 @@ struct VResizeLinear { const WT *S0 = src[0], *S1 = src[1]; CastOp castOp; VecOp vecOp; - int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, - width); + int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, width); #if MEGCV_ENABLE_UNROLLED for (; x <= width - 4; x += 4) { WT t0, t1; @@ -1330,55 +1306,58 @@ struct VResizeCubic { CastOp castOp; VecOp vecOp; - int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, - width); + int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, width); for (; x < width; x++) dst[x] = castOp(S0[x] * b0 + S1[x] * b1 + S2[x] * b2 + S3[x] * b3); } }; template <> -struct VResizeLinear, - VResizeLinearVec_32s8u> { +struct VResizeLinear< + uchar, int, short, FixedPtCast, + VResizeLinearVec_32s8u> { typedef uchar value_type; typedef int buf_type; typedef short alpha_type; - void operator()(const buf_type** src, value_type* dst, - const alpha_type* beta, int width) const { + void operator()( + const buf_type** src, value_type* dst, const alpha_type* beta, + int width) const { alpha_type b0 = beta[0], b1 = beta[1]; const 
buf_type *S0 = src[0], *S1 = src[1]; VResizeLinearVec_32s8u vecOp; - int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, - width); + int x = vecOp((const uchar**)src, (uchar*)dst, (const uchar*)beta, width); #if MEGCV_ENABLE_UNROLLED for (; x <= width - 4; x += 4) { - dst[x + 0] = uchar((((b0 * (S0[x + 0] >> 4)) >> 16) + - ((b1 * (S1[x + 0] >> 4)) >> 16) + 2) >> - 2); - dst[x + 1] = uchar((((b0 * (S0[x + 1] >> 4)) >> 16) + - ((b1 * (S1[x + 1] >> 4)) >> 16) + 2) >> - 2); - dst[x + 2] = uchar((((b0 * (S0[x + 2] >> 4)) >> 16) + - ((b1 * (S1[x + 2] >> 4)) >> 16) + 2) >> - 2); - dst[x + 3] = uchar((((b0 * (S0[x + 3] >> 4)) >> 16) + - ((b1 * (S1[x + 3] >> 4)) >> 16) + 2) >> - 2); + dst[x + 0] = + uchar((((b0 * (S0[x + 0] >> 4)) >> 16) + + ((b1 * (S1[x + 0] >> 4)) >> 16) + 2) >> + 2); + dst[x + 1] = + uchar((((b0 * (S0[x + 1] >> 4)) >> 16) + + ((b1 * (S1[x + 1] >> 4)) >> 16) + 2) >> + 2); + dst[x + 2] = + uchar((((b0 * (S0[x + 2] >> 4)) >> 16) + + ((b1 * (S1[x + 2] >> 4)) >> 16) + 2) >> + 2); + dst[x + 3] = + uchar((((b0 * (S0[x + 3] >> 4)) >> 16) + + ((b1 * (S1[x + 3] >> 4)) >> 16) + 2) >> + 2); } #endif for (; x < width; x++) - dst[x] = uchar((((b0 * (S0[x] >> 4)) >> 16) + - ((b1 * (S1[x] >> 4)) >> 16) + 2) >> - 2); + dst[x] = uchar( + (((b0 * (S0[x] >> 4)) >> 16) + ((b1 * (S1[x] >> 4)) >> 16) + 2) >> + 2); } }; template -void resizeGeneric_(const Mat& src, Mat& dst, const int* xofs, - const void* _alpha, const int* yofs, const void* _beta, - int xmin, int xmax, int ksize) { +void resizeGeneric_( + const Mat& src, Mat& dst, const int* xofs, const void* _alpha, + const int* yofs, const void* _beta, int xmin, int xmax, int ksize) { typedef typename HResize::value_type T; typedef typename HResize::buf_type WT; typedef typename HResize::alpha_type AT; @@ -1430,36 +1409,39 @@ void resizeGeneric_(const Mat& src, Mat& dst, const int* xofs, prev_sy[k] = sy; } if (k0 < ksize) - hresize(srows + k0, rows + k0, ksize - k0, xofs, alpha, swidth, - dwidth, cn, xmin, xmax); + hresize(srows + k0, rows + k0, ksize - k0, xofs, alpha, swidth, dwidth, cn, + xmin, xmax); vresize((const WT**)(rows), dst.ptr(dy), beta, dwidth); } } template -void setup_resize_env(InterpolationMode /* ip */, int& /* ksize */, - bool& /* fixedpt */, ResizeFunc& /* func */) { +void setup_resize_env( + InterpolationMode /* ip */, int& /* ksize */, bool& /* fixedpt */, + ResizeFunc& /* func */) { megdnn_throw(("unimplemented")); } template <> -void setup_resize_env(InterpolationMode ip, int& ksize, bool& fixedpt, - ResizeFunc& func) { +void setup_resize_env( + InterpolationMode ip, int& ksize, bool& fixedpt, ResizeFunc& func) { fixedpt = false; switch (ip) { case IMode::INTER_CUBIC: ksize = 4; func = resizeGeneric_< HResizeCubic, - VResizeCubic, - VResizeCubicVec_32f>, + VResizeCubic< + float, float, float, Cast, + VResizeCubicVec_32f>, float>; break; case IMode::INTER_LANCZOS4: ksize = 8; func = resizeGeneric_< HResizeLanczos4, - VResizeLanczos4, - VResizeLanczos4Vec_32f>, + VResizeLanczos4< + float, float, float, Cast, + VResizeLanczos4Vec_32f>, float>; break; case IMode::INTER_LINEAR: @@ -1467,8 +1449,9 @@ void setup_resize_env(InterpolationMode ip, int& ksize, bool& fixedpt, ksize = 2; func = resizeGeneric_< HResizeLinear, - VResizeLinear, - VResizeLinearVec_32f>, + VResizeLinear< + float, float, float, Cast, + VResizeLinearVec_32f>, float>; break; default: @@ -1476,8 +1459,8 @@ void setup_resize_env(InterpolationMode ip, int& ksize, bool& fixedpt, } } template <> -void setup_resize_env(InterpolationMode ip, int& ksize, 
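
The fixed-point VResizeLinear specialisation above blends two rows of 32-bit intermediates with 16-bit weights: both the intermediates and the weights carry an INTER_RESIZE_COEF_SCALE (2^11) factor, the >> 4 keeps the 16-bit multiply in range, the >> 16 drops most of the combined scale, and the final +2 >> 2 removes the rest with rounding. A scalar restatement of one output pixel (just unpacking the expression already in the code; the function name is illustrative):

    #include <cstdint>

    // s0, s1: horizontally-resized rows, scaled by 2^11.
    // b0, b1: vertical weights, also scaled by 2^11 (b0 + b1 == 2^11).
    static inline uint8_t vblend_8u(int s0, int s1, short b0, short b1) {
        int t0 = (b0 * (s0 >> 4)) >> 16;     // drop 4 + 16 of the 22 scale bits
        int t1 = (b1 * (s1 >> 4)) >> 16;
        return uint8_t((t0 + t1 + 2) >> 2);  // remove the last 2 bits, rounding
    }
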
bool& fixedpt, - ResizeFunc& func) { +void setup_resize_env( + InterpolationMode ip, int& ksize, bool& fixedpt, ResizeFunc& func) { fixedpt = true; switch (ip) { case IMode::INTER_CUBIC: @@ -1504,8 +1487,9 @@ void setup_resize_env(InterpolationMode ip, int& ksize, bool& fixedpt, case IMode::INTER_AREA: ksize = 2; func = resizeGeneric_< - HResizeLinear, + HResizeLinear< + uchar, int, short, INTER_RESIZE_COEF_SCALE, + HResizeLinearVec_8u32s>, VResizeLinear< uchar, int, short, FixedPtCast, @@ -1517,8 +1501,8 @@ void setup_resize_env(InterpolationMode ip, int& ksize, bool& fixedpt, } } -int compute_resize_area_tab(int ssize, int dsize, int cn, double scale, - DecimateAlpha* tab) { +int compute_resize_area_tab( + int ssize, int dsize, int cn, double scale, DecimateAlpha* tab) { int k = 0; for (int dx = 0; dx < dsize; dx++) { double fsx1 = dx * scale; @@ -1549,8 +1533,7 @@ int compute_resize_area_tab(int ssize, int dsize, int cn, double scale, tab[k].di = dx * cn; tab[k].si = sx2 * cn; tab[k++].alpha = - (float)(std::min(std::min(fsx2 - sx2, 1.), cellWidth) / - cellWidth); + (float)(std::min(std::min(fsx2 - sx2, 1.), cellWidth) / cellWidth); } } return k; @@ -1558,8 +1541,9 @@ int compute_resize_area_tab(int ssize, int dsize, int cn, double scale, // resize Area Fast template -void resizeAreaFast_(const Mat& src, Mat& dst, const int* ofs, - const int* xofs, int scale_x, int scale_y) { +void resizeAreaFast_( + const Mat& src, Mat& dst, const int* ofs, const int* xofs, int scale_x, + int scale_y) { // Range range(0, dst.rows); int swidth = src.width(); int sheight = src.height(); @@ -1593,8 +1577,7 @@ void resizeAreaFast_(const Mat& src, Mat& dst, const int* ofs, k = 0; #if MEGCV_ENABLE_UNROLLED for (; k <= area - 4; k += 4) - sum += S[ofs[k]] + S[ofs[k + 1]] + S[ofs[k + 2]] + - S[ofs[k + 3]]; + sum += S[ofs[k]] + S[ofs[k + 1]] + S[ofs[k + 2]] + S[ofs[k + 3]]; #endif for (; k < area; k++) sum += S[ofs[k]]; @@ -1633,8 +1616,7 @@ struct ResizeAreaFastVec { cn(_cn), step(_step), vecOp(_cn, _step) { - fast_mode = - scale_x == 2 && scale_y == 2 && (cn == 1 || cn == 3 || cn == 4); + fast_mode = scale_x == 2 && scale_y == 2 && (cn == 1 || cn == 3 || cn == 4); } int operator()(const T* S, T* D, int w) const { @@ -1647,39 +1629,47 @@ struct ResizeAreaFastVec { if (cn == 1) for (; dx < w; ++dx) { int index = dx * 2; - D[dx] = (T)((S[index] + S[index + 1] + nextS[index] + - nextS[index + 1] + 2) >> + D[dx] = + (T)((S[index] + S[index + 1] + nextS[index] + nextS[index + 1] + + 2) >> 2); } else if (cn == 3) for (; dx < w; dx += 3) { int index = dx * 2; - D[dx] = (T)((S[index] + S[index + 3] + nextS[index] + - nextS[index + 3] + 2) >> + D[dx] = + (T)((S[index] + S[index + 3] + nextS[index] + nextS[index + 3] + + 2) >> + 2); + D[dx + 1] = + (T)((S[index + 1] + S[index + 4] + nextS[index + 1] + + nextS[index + 4] + 2) >> + 2); + D[dx + 2] = + (T)((S[index + 2] + S[index + 5] + nextS[index + 2] + + nextS[index + 5] + 2) >> 2); - D[dx + 1] = (T)((S[index + 1] + S[index + 4] + - nextS[index + 1] + nextS[index + 4] + 2) >> - 2); - D[dx + 2] = (T)((S[index + 2] + S[index + 5] + - nextS[index + 2] + nextS[index + 5] + 2) >> - 2); } else { megdnn_assert(cn == 4); for (; dx < w; dx += 4) { int index = dx * 2; - D[dx] = (T)((S[index] + S[index + 4] + nextS[index] + - nextS[index + 4] + 2) >> + D[dx] = + (T)((S[index] + S[index + 4] + nextS[index] + nextS[index + 4] + + 2) >> + 2); + D[dx + 1] = + (T)((S[index + 1] + S[index + 5] + nextS[index + 1] + + nextS[index + 5] + 2) >> + 2); + D[dx + 2] = + (T)((S[index + 2] + 
S[index + 6] + nextS[index + 2] + + nextS[index + 6] + 2) >> + 2); + D[dx + 3] = + (T)((S[index + 3] + S[index + 7] + nextS[index + 3] + + nextS[index + 7] + 2) >> 2); - D[dx + 1] = (T)((S[index + 1] + S[index + 5] + - nextS[index + 1] + nextS[index + 5] + 2) >> - 2); - D[dx + 2] = (T)((S[index + 2] + S[index + 6] + - nextS[index + 2] + nextS[index + 6] + 2) >> - 2); - D[dx + 3] = (T)((S[index + 3] + S[index + 7] + - nextS[index + 3] + nextS[index + 7] + 2) >> - 2); } } @@ -1706,16 +1696,15 @@ ResizeAreaFastFunc get_resize_area_fast_func() { template <> ResizeAreaFastFunc get_resize_area_fast_func() { - return resizeAreaFast_>; + return resizeAreaFast_< + uchar, int, ResizeAreaFastVec>; } // Resize Area template -static void resizeArea_(const Mat& src, Mat& dst, - const DecimateAlpha* xtab, int xtab_size, - const DecimateAlpha* ytab, int ytab_size, - const int* tabofs) { +static void resizeArea_( + const Mat& src, Mat& dst, const DecimateAlpha* xtab, int xtab_size, + const DecimateAlpha* ytab, int ytab_size, const int* tabofs) { // parallel_for_(Range(0, dst.rows), // ResizeArea_Invoker(src, dst, xtab, xtab_size, ytab, ytab_size, // tabofs), dst.total()/((double)(1 << 16))); @@ -1847,10 +1836,8 @@ void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { ResizeAreaFunc func = get_resize_area_func(); AlignedVector _xytab((swidth + sheight) * 2); DecimateAlpha *xtab = _xytab.data(), *ytab = xtab + swidth * 2; - int xtab_size = - compute_resize_area_tab(swidth, dwidth, cn, scale_x, xtab); - int ytab_size = - compute_resize_area_tab(sheight, dheight, 1, scale_y, ytab); + int xtab_size = compute_resize_area_tab(swidth, dwidth, cn, scale_x, xtab); + int ytab_size = compute_resize_area_tab(sheight, dheight, 1, scale_y, ytab); AlignedVector _tabofs(dheight + 1); int* tabofs = _tabofs.data(); for (k = 0, dy = 0; k < ytab_size; ++k) { @@ -1870,8 +1857,8 @@ void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { bool fixedpt; setup_resize_env(ip, ksize, fixedpt, func); ksize2 = ksize / 2; - AlignedVector _buffer((width + dst.height()) * - (sizeof(int) + sizeof(float) * ksize)); + AlignedVector _buffer( + (width + dst.height()) * (sizeof(int) + sizeof(float) * ksize)); uchar* buffer = _buffer.data(); int* xofs = static_cast(static_cast(buffer)); int* yofs = xofs + width; @@ -1894,8 +1881,7 @@ void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { if (sx < ksize2 - 1) { xmin = dx + 1; - if (sx < 0 && - (ip != IMode::INTER_CUBIC && ip != IMode::INTER_LANCZOS4)) { + if (sx < 0 && (ip != IMode::INTER_CUBIC && ip != IMode::INTER_LANCZOS4)) { fx = 0; sx = 0; } @@ -1925,8 +1911,7 @@ void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { saturate_cast(cbuf[k] * INTER_RESIZE_COEF_SCALE); } for (; k < cn * ksize; ++k) { - ialpha[dx * cn * ksize + k] = - ialpha[dx * cn * ksize + k - ksize]; + ialpha[dx * cn * ksize + k] = ialpha[dx * cn * ksize + k - ksize]; } } else { for (k = 0; k < ksize; ++k) { @@ -1969,8 +1954,8 @@ void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { } func(src, dst, xofs, fixedpt ? static_cast(ialpha) : static_cast(alpha), yofs, - fixedpt ? static_cast(ibeta) : static_cast(beta), xmin, - xmax, ksize); + fixedpt ? 
static_cast(ibeta) : static_cast(beta), xmin, xmax, + ksize); } } // anonymous namespace @@ -1978,8 +1963,7 @@ void resize_opencv(const Mat& src, Mat& dst, InterpolationMode ip) { void megdnn::arm_common::resize_cv_exec( _megdnn_tensor_in src, _megdnn_tensor_out dst, param::Resize::InterpolationMode imode) { - megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3, - "unsupported src channel"); + megdnn_assert(src.layout[3] == 1 || src.layout[3] == 3, "unsupported src channel"); for (size_t i = 0; i < src.layout.shape[0]; ++i) { if (dst.layout.dtype == dtype::Float32()) { MIDOUT_BEGIN(megdnn_arm_resizecv_dtype, midout_iv(0)) { diff --git a/dnn/src/arm_common/resize/resize_cv.h b/dnn/src/arm_common/resize/resize_cv.h index 491237f3..7c8cc19f 100644 --- a/dnn/src/arm_common/resize/resize_cv.h +++ b/dnn/src/arm_common/resize/resize_cv.h @@ -20,8 +20,9 @@ namespace arm_common { * \fn resize_cv_exec * \brief Used if the format is NHWC, transfer from megcv */ -void resize_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - param::Resize::InterpolationMode imode); +void resize_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + param::Resize::InterpolationMode imode); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/resize/upsample2_nchw.cpp b/dnn/src/arm_common/resize/upsample2_nchw.cpp index ac012d40..8f139701 100644 --- a/dnn/src/arm_common/resize/upsample2_nchw.cpp +++ b/dnn/src/arm_common/resize/upsample2_nchw.cpp @@ -22,8 +22,7 @@ using namespace resize; namespace { template -static inline ctype compute_linear_element(const ctype src[4], - const ctype alpha[2]) { +static inline ctype compute_linear_element(const ctype src[4], const ctype alpha[2]) { return src[0] * alpha[0 ^ fh] * alpha[0 ^ fw] + src[1] * alpha[0 ^ fh] * alpha[1 ^ fw] + src[2] * alpha[1 ^ fh] * alpha[0 ^ fw] + @@ -43,9 +42,8 @@ static inline typename simd_helper::simd_type compute_linear_element_simd( } template -static inline void compute_linear_2x2_element(const ctype* src, ctype* dst, - size_t IW, size_t OW, - const ctype alpha[2]) { +static inline void compute_linear_2x2_element( + const ctype* src, ctype* dst, size_t IW, size_t OW, const ctype alpha[2]) { const ctype* src_ptr[4] = {src, src, src, src}; if (has_right) { @@ -77,9 +75,8 @@ static inline void compute_linear_2x2_element(const ctype* src, ctype* dst, template static inline void compute_linear_2x2_element_simd( - const typename simd_helper::ctype* src, - typename simd_helper::ctype* dst, size_t IW, size_t OW, - const typename simd_helper::simd_type alpha[2][2]) { + const typename simd_helper::ctype* src, typename simd_helper::ctype* dst, + size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) { using simd_type = typename simd_helper::simd_type; simd_type rsrc[4]; @@ -99,8 +96,8 @@ static inline void compute_linear_2x2_element_simd( } template -void linear_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, - size_t IH, size_t IW) { +void linear_upsample2_nchw( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { using simd_helper = SIMDHelper; size_t OW = IW * 2; constexpr size_t PC = simd_helper::simd_width; @@ -114,8 +111,8 @@ void linear_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, simd_alpha[1][1] = simd_helper::dup(0.25 * 0.25); for (size_t i = 0; i < N; ++i) { - compute_linear_2x2_element(src_ptr, dst_ptr, IW, - OW, alpha); + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); { for (size_t iw = 0; iw + 1 < IW; ++iw) { 
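
For the fixed 2x upsample, the bilinear weights degenerate to the constants 0.75 and 0.25: every output pixel sits a quarter of a pixel from its nearest source sample, so the 2x2 blend only ever needs products of those two values, which is why the kernel pre-computes 0.75*0.75, 0.75*0.25 and 0.25*0.25. A scalar sketch of one interior 2x2 output block, assuming the half-pixel-centre convention (border handling omitted; the function name is illustrative):

    // src is IH x IW, dst is 2*IH x 2*IW (OW == 2*IW); computes the four outputs
    // that fall between source rows ih, ih+1 and columns iw, iw+1.
    static void upsample2_block(const float* src, float* dst, int IW, int OW,
                                int ih, int iw) {
        float s00 = src[ih * IW + iw],       s01 = src[ih * IW + iw + 1];
        float s10 = src[(ih + 1) * IW + iw], s11 = src[(ih + 1) * IW + iw + 1];
        float* d = dst + (2 * ih + 1) * OW + (2 * iw + 1);
        d[0]      = 0.5625f * s00 + 0.1875f * s01 + 0.1875f * s10 + 0.0625f * s11;
        d[1]      = 0.1875f * s00 + 0.5625f * s01 + 0.0625f * s10 + 0.1875f * s11;
        d[OW]     = 0.1875f * s00 + 0.0625f * s01 + 0.5625f * s10 + 0.1875f * s11;
        d[OW + 1] = 0.0625f * s00 + 0.1875f * s01 + 0.1875f * s10 + 0.5625f * s11;
    }
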
compute_linear_2x2_element( @@ -127,13 +124,12 @@ void linear_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, dst_ptr += OW; for (size_t ih = 0; ih + 1 < IH; ++ih) { - compute_linear_2x2_element(src_ptr, dst_ptr, IW, - OW, alpha); + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); size_t iw = 0; for (; iw + PC < IW; iw += PC) { compute_linear_2x2_element_simd( - src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, - simd_alpha); + src_ptr + iw, dst_ptr + (iw * 2 + 1), IW, OW, simd_alpha); } for (; iw + 1 < IW; ++iw) { compute_linear_2x2_element( @@ -146,8 +142,8 @@ void linear_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, dst_ptr += 2 * OW; } - compute_linear_2x2_element(src_ptr, dst_ptr, IW, - OW, alpha); + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); { for (size_t iw = 0; iw + 1 < IW; ++iw) { compute_linear_2x2_element( @@ -162,8 +158,8 @@ void linear_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, } template -void nearest_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, - size_t IH, size_t IW) { +void nearest_upsample2_nchw( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { using simd_helper = SIMDHelper; size_t OW = IW * 2; constexpr size_t PC = simd_helper::simd_width; @@ -172,8 +168,7 @@ void nearest_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, for (size_t ih = 0; ih < IH; ++ih) { size_t iw = 0; for (; iw + PC - 1 < IW; iw += PC) { - typename simd_helper::simd_type r0 = - simd_helper::load(src_ptr + iw); + typename simd_helper::simd_type r0 = simd_helper::load(src_ptr + iw); simd_helper::store2_interleave(dst_ptr + (iw * 2), r0, r0); simd_helper::store2_interleave(dst_ptr + (OW + iw * 2), r0, r0); @@ -195,16 +190,16 @@ void nearest_upsample2_nchw(const ctype* src_ptr, ctype* dst_ptr, size_t N, void megdnn::arm_common::resize_linear_upsample2_nchw_fp32( const ResizeImpl::KernParam& kern_param) { - linear_upsample2_nchw(kern_param.sptr, kern_param.dptr, - kern_param.n * kern_param.c, kern_param.ih, - kern_param.iw); + linear_upsample2_nchw( + kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c, + kern_param.ih, kern_param.iw); } void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( const ResizeImpl::KernParam& kern_param) { - nearest_upsample2_nchw(kern_param.sptr, kern_param.dptr, - kern_param.n * kern_param.c, kern_param.ih, - kern_param.iw); + nearest_upsample2_nchw( + kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c, + kern_param.ih, kern_param.iw); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -213,16 +208,16 @@ void megdnn::arm_common::resize_linear_upsample2_nchw_fp16( const ResizeImpl::KernParam& kern_param) { auto sptr = reinterpret_cast(kern_param.sptr); auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); - linear_upsample2_nchw(sptr, dptr, kern_param.n * kern_param.c, - kern_param.ih, kern_param.iw); + linear_upsample2_nchw( + sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); } void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16( const ResizeImpl::KernParam& kern_param) { auto sptr = reinterpret_cast(kern_param.sptr); auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); - nearest_upsample2_nchw(sptr, dptr, kern_param.n * kern_param.c, - kern_param.ih, kern_param.iw); + nearest_upsample2_nchw( + sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); } #endif diff --git a/dnn/src/arm_common/resize/upsample2_nchw.h b/dnn/src/arm_common/resize/upsample2_nchw.h index 
3b6aa7ce..0bc83920 100644 --- a/dnn/src/arm_common/resize/upsample2_nchw.h +++ b/dnn/src/arm_common/resize/upsample2_nchw.h @@ -16,11 +16,9 @@ namespace megdnn { namespace arm_common { -void resize_linear_upsample2_nchw_fp32( - const ResizeImpl::KernParam& kern_param); +void resize_linear_upsample2_nchw_fp32(const ResizeImpl::KernParam& kern_param); -void resize_nearest_upsample2_nchw_fp32( - const ResizeImpl::KernParam& kern_param); +void resize_nearest_upsample2_nchw_fp32(const ResizeImpl::KernParam& kern_param); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/dnn/src/arm_common/resize/upsample2_nchwxx.cpp b/dnn/src/arm_common/resize/upsample2_nchwxx.cpp index 7e91c416..db5b2bc7 100644 --- a/dnn/src/arm_common/resize/upsample2_nchwxx.cpp +++ b/dnn/src/arm_common/resize/upsample2_nchwxx.cpp @@ -35,9 +35,8 @@ static inline typename simd_helper::simd_type compute_linear_element( template static inline void compute_linear_2x2_element( - const typename simd_helper::ctype* src, - typename simd_helper::ctype* dst, size_t IW, size_t OW, - const typename simd_helper::simd_type alpha[2][2]) { + const typename simd_helper::ctype* src, typename simd_helper::ctype* dst, + size_t IW, size_t OW, const typename simd_helper::simd_type alpha[2][2]) { constexpr size_t PC = simd_helper::simd_width; const typename simd_helper::ctype* src_ptr[4] = {src, src, src, src}; @@ -75,8 +74,8 @@ static inline void compute_linear_2x2_element( } template -void linear_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N, - size_t IH, size_t IW) { +void linear_upsample2_nchwxx( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { using simd_helper = SIMDHelper; size_t OW = IW * 2; constexpr size_t PC = simd_helper::simd_width; @@ -88,19 +87,17 @@ void linear_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N, alpha[1][1] = simd_helper::dup(0.25 * 0.25); for (size_t i = 0; i < N; ++i) { - compute_linear_2x2_element(src_ptr, dst_ptr, - IW, OW, alpha); + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); { for (size_t iw = 0; iw + 1 < IW; ++iw) { compute_linear_2x2_element( - src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, - alpha); + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha); } } compute_linear_2x2_element( - src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, - alpha); + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha); dst_ptr += OW * PC; for (size_t ih = 0; ih + 1 < IH; ++ih) { @@ -108,38 +105,34 @@ void linear_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N, src_ptr, dst_ptr, IW, OW, alpha); for (size_t iw = 0; iw + 1 < IW; ++iw) { compute_linear_2x2_element( - src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, - alpha); + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha); } compute_linear_2x2_element( - src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, - alpha); + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, alpha); src_ptr += IW * PC; dst_ptr += 2 * OW * PC; } - compute_linear_2x2_element(src_ptr, dst_ptr, - IW, OW, alpha); + compute_linear_2x2_element( + src_ptr, dst_ptr, IW, OW, alpha); { for (size_t iw = 0; iw + 1 < IW; ++iw) { compute_linear_2x2_element( - src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, - alpha); + src_ptr + iw * PC, dst_ptr + (iw * 2 + 1) * PC, IW, OW, alpha); } } compute_linear_2x2_element( - src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, - alpha); + src_ptr + (IW - 1) * PC, dst_ptr + (OW - 1) * PC, IW, OW, 
alpha); src_ptr += IW * PC; dst_ptr += OW * PC; } } template -void nearest_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N, - size_t IH, size_t IW) { +void nearest_upsample2_nchwxx( + const ctype* src_ptr, ctype* dst_ptr, size_t N, size_t IH, size_t IW) { using simd_helper = SIMDHelper; size_t OW = IW * 2; constexpr size_t PC = simd_helper::simd_width; @@ -164,16 +157,16 @@ void nearest_upsample2_nchwxx(const ctype* src_ptr, ctype* dst_ptr, size_t N, void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32( const ResizeImpl::KernParam& kern_param) { - linear_upsample2_nchwxx(kern_param.sptr, kern_param.dptr, - kern_param.n * kern_param.c / 4, kern_param.ih, - kern_param.iw); + linear_upsample2_nchwxx( + kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw); } void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( const ResizeImpl::KernParam& kern_param) { - nearest_upsample2_nchwxx(kern_param.sptr, kern_param.dptr, - kern_param.n * kern_param.c / 4, kern_param.ih, - kern_param.iw); + nearest_upsample2_nchwxx( + kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, + kern_param.ih, kern_param.iw); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -182,16 +175,16 @@ void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16( const ResizeImpl::KernParam& kern_param) { auto sptr = reinterpret_cast(kern_param.sptr); auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); - linear_upsample2_nchwxx(sptr, dptr, kern_param.n * kern_param.c / 8, - kern_param.ih, kern_param.iw); + linear_upsample2_nchwxx( + sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); } void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16( const ResizeImpl::KernParam& kern_param) { auto sptr = reinterpret_cast(kern_param.sptr); auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); - nearest_upsample2_nchwxx(sptr, dptr, kern_param.n * kern_param.c / 8, - kern_param.ih, kern_param.iw); + nearest_upsample2_nchwxx( + sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); } #endif diff --git a/dnn/src/arm_common/separable_conv/opr_impl.cpp b/dnn/src/arm_common/separable_conv/opr_impl.cpp index 9f28185e..fa11a3ee 100644 --- a/dnn/src/arm_common/separable_conv/opr_impl.cpp +++ b/dnn/src/arm_common/separable_conv/opr_impl.cpp @@ -12,39 +12,33 @@ #include "./sep_conv_filter.h" #include "src/common/utils.h" //#include "src/arm_common/profile.h" -#include "src/arm_common/handle.h" #include +#include "src/arm_common/handle.h" namespace megdnn { namespace arm_common { using namespace sep_conv; -void SeparableConvImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ - check_exec(src.layout, filter_x.layout, filter_y.layout, dst.layout, workspace.size); +void SeparableConvImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec( + src.layout, filter_x.layout, filter_y.layout, dst.layout, workspace.size); int ih = src.layout.shape[2]; int iw = src.layout.shape[3]; int oh = dst.layout.shape[2]; int ow = dst.layout.shape[3]; - filter_engine_ = new FilterEngine(ih, iw, oh, ow, - param().ksize_h, param().ksize_w, - param().anchor_h, param().anchor_w, - param().borderMode, param().is_symm_kernel); - - MEGDNN_DISPATCH_CPU_KERN_OPR( - filter_engine_->exec(src, filter_x, filter_y, dst); - ); + filter_engine_ = 
new FilterEngine( + ih, iw, oh, ow, param().ksize_h, param().ksize_w, param().anchor_h, + param().anchor_w, param().borderMode, param().is_symm_kernel); - delete(filter_engine_); + MEGDNN_DISPATCH_CPU_KERN_OPR(filter_engine_->exec(src, filter_x, filter_y, dst);); + delete (filter_engine_); } -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/opr_impl.h b/dnn/src/arm_common/separable_conv/opr_impl.h index 509da5ff..37eef2eb 100644 --- a/dnn/src/arm_common/separable_conv/opr_impl.h +++ b/dnn/src/arm_common/separable_conv/opr_impl.h @@ -9,32 +9,29 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once -#include "megdnn/oprs.h" #include "./sep_conv_filter.h" +#include "megdnn/oprs.h" namespace megdnn { namespace arm_common { using namespace sep_conv; -class SeparableConvImpl: public SeparableConvForward { - public: - //SeparableConvForwardImpl(Handle *handle): SeparableConvForward(handle) {} - using SeparableConvForward::SeparableConvForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; +class SeparableConvImpl : public SeparableConvForward { +public: + // SeparableConvForwardImpl(Handle *handle): SeparableConvForward(handle) {} + using SeparableConvForward::SeparableConvForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override - { - // TODO: deduce the size of ring buffer. - return 0; - } - FilterEngine* filter_engine_; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + // TODO: deduce the size of ring buffer. + return 0; + } + FilterEngine* filter_engine_; }; -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/sep_conv_common.h b/dnn/src/arm_common/separable_conv/sep_conv_common.h index 2f226257..18b53534 100644 --- a/dnn/src/arm_common/separable_conv/sep_conv_common.h +++ b/dnn/src/arm_common/separable_conv/sep_conv_common.h @@ -10,8 +10,8 @@ */ #pragma once -#include "src/common/utils.h" #include "megdnn/oprs.h" +#include "src/common/utils.h" namespace megdnn { namespace arm_common { @@ -25,125 +25,120 @@ using ushort = unsigned short; /////////// helper /////////// -static inline size_t align_size(size_t sz, int n) -{ +static inline size_t align_size(size_t sz, int n) { megdnn_assert((n & (n - 1)) == 0); - return (sz + n-1) & -n; + return (sz + n - 1) & -n; } -static inline int clip(int x, int a, int b) -{ - return x >= a ? (x < b ? x : b-1) : a; +static inline int clip(int x, int a, int b) { + return x >= a ? (x < b ? 
x : b - 1) : a; } -template<typename _Tp> static inline _Tp* align_ptr(_Tp* ptr, int n=(int)sizeof(_Tp)) -{ - return (_Tp*)(((size_t)ptr + n-1) & -n); +template <typename _Tp> +static inline _Tp* align_ptr(_Tp* ptr, int n = (int)sizeof(_Tp)) { + return (_Tp*)(((size_t)ptr + n - 1) & -n); } template <typename T> -T saturate_cast(T x) -{ return x; } +T saturate_cast(T x) { + return x; +} template <typename T> -T saturate_cast(int x) -{ +T saturate_cast(int x) { return static_cast<T>(x); } template <typename T> -T saturate_cast(float x) -{ +T saturate_cast(float x) { return static_cast<T>(x); } template <typename T> -T saturate_cast(double x) -{ +T saturate_cast(double x) { return static_cast<T>(x); } // int -> uchar -template<> unsigned char saturate_cast<unsigned char>(int x); +template <> +unsigned char saturate_cast<unsigned char>(int x); // int -> short -template<> short saturate_cast<short>(int x); +template <> +short saturate_cast<short>(int x); // float -> int -template<> int saturate_cast<int>(float x); +template <> +int saturate_cast<int>(float x); // float -> short -template<> short saturate_cast<short>(float x); +template <> +short saturate_cast<short>(float x); // double -> int -template<> int saturate_cast<int>(double x); +template <> +int saturate_cast<int>(double x); - -template<typename ST, typename DT, int bits> struct FixedPtCast -{ +template <typename ST, typename DT, int bits> +struct FixedPtCast { typedef ST type1; typedef DT rtype; - enum { SHIFT = bits, DELTA = 1 << (bits-1) }; + enum { SHIFT = bits, DELTA = 1 << (bits - 1) }; - DT operator()(ST val) const - { return saturate_cast<DT>
((val + DELTA)>>SHIFT); } + DT operator()(ST val) const { return saturate_cast<DT>
((val + DELTA) >> SHIFT); } }; -template<typename ST, typename DT> struct FixedPtCastEx -{ +template <typename ST, typename DT> +struct FixedPtCastEx { typedef ST type1; typedef DT rtype; FixedPtCastEx() : SHIFT(0), DELTA(0) {} - FixedPtCastEx(int bits) : SHIFT(bits), DELTA(bits ? 1 << (bits-1) : 0) {} + FixedPtCastEx(int bits) : SHIFT(bits), DELTA(bits ? 1 << (bits - 1) : 0) {} DT operator()(ST val) const { return saturate_cast<DT>
(val + DELTA); } int SHIFT, DELTA; }; -template<> struct FixedPtCastEx<int, uchar> -{ +template <> +struct FixedPtCastEx<int, uchar> { typedef int type1; typedef uchar rtype; FixedPtCastEx() : SHIFT(0), DELTA(0) {} - FixedPtCastEx(int bits) : SHIFT(bits), DELTA(bits ? 1 << (bits-1) : 0) {} - uchar operator()(int val) const { return saturate_cast<uchar>((val + DELTA)>>SHIFT); } + FixedPtCastEx(int bits) : SHIFT(bits), DELTA(bits ? 1 << (bits - 1) : 0) {} + uchar operator()(int val) const { + return saturate_cast<uchar>((val + DELTA) >> SHIFT); + } int SHIFT, DELTA; }; - -template<typename ST, typename DT> struct Cast -{ +template <typename ST, typename DT> +struct Cast { typedef ST type1; typedef DT rtype; DT operator()(ST val) const { return saturate_cast<DT>
(val); } }; -static inline int border_interpolate(int p, int len, BorderMode bmode) -{ - if( (unsigned)p < (unsigned)len ) +static inline int border_interpolate(int p, int len, BorderMode bmode) { + if ((unsigned)p < (unsigned)len) ; - else if( bmode == BorderMode::BORDER_REPLICATE ) + else if (bmode == BorderMode::BORDER_REPLICATE) p = p < 0 ? 0 : len - 1; - else if( bmode == BorderMode::BORDER_REFLECT || bmode == BorderMode::BORDER_REFLECT_101 ) - { + else if ( + bmode == BorderMode::BORDER_REFLECT || + bmode == BorderMode::BORDER_REFLECT_101) { int delta = (bmode == BorderMode::BORDER_REFLECT_101); - if( len == 1 ) + if (len == 1) return 0; - do - { - if( p < 0 ) + do { + if (p < 0) p = -p - 1 + delta; else p = len - 1 - (p - len) - delta; - } - while( (unsigned)p >= (unsigned)len ); - } - else if( bmode == BorderMode::BORDER_WRAP ) - { + } while ((unsigned)p >= (unsigned)len); + } else if (bmode == BorderMode::BORDER_WRAP) { megdnn_assert(len > 0); - if( p < 0 ) - p -= ((p-len+1)/len)*len; + if (p < 0) + p -= ((p - len + 1) / len) * len; while (p >= len) { p -= len; } - } - else if( bmode == BorderMode::BORDER_CONSTANT ) + } else if (bmode == BorderMode::BORDER_CONSTANT) p = -1; else megdnn_throw("Unknown/unsupported border type"); @@ -151,8 +146,8 @@ static inline int border_interpolate(int p, int len, BorderMode bmode) } /////////// helper /////////// -} // namespace sep_conv -} // namespace arm_common -} // namespace megdnn +} // namespace sep_conv +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/sep_conv_filter.h b/dnn/src/arm_common/separable_conv/sep_conv_filter.h index 876f68de..b49aaa2c 100644 --- a/dnn/src/arm_common/separable_conv/sep_conv_filter.h +++ b/dnn/src/arm_common/separable_conv/sep_conv_filter.h @@ -17,8 +17,8 @@ namespace sep_conv { //#define BorderMode param::SeparableConv::BorderMode //#define BorderMode SeparableConv::Param::BorderMode using BorderMode = SeparableConv::Param::BorderMode; -//using uchar = unsigned char; -//using ushort = unsigned short; +// using uchar = unsigned char; +// using ushort = unsigned short; class BaseRowFilter { public: @@ -26,22 +26,26 @@ public: BaseRowFilter(); //! the destructor virtual ~BaseRowFilter(); - //! the filtering operator. Must be overridden in the derived classes. The horizontal border interpolation is done outside of the class. - virtual void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) = 0; + //! the filtering operator. Must be overridden in the derived classes. The + //! horizontal border interpolation is done outside of the class. + virtual void operator()( + const uchar* src, uchar* dst, uchar* kernel, int width, int cn) = 0; int ksize; int anchor; }; - class BaseColumnFilter { public: //! the default constructor BaseColumnFilter(); //! the destructor virtual ~BaseColumnFilter(); - //! the filtering operator. Must be overridden in the derived classes. The vertical border interpolation is done outside of the class. - virtual void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int dstcount, int width) = 0; + //! the filtering operator. Must be overridden in the derived classes. The vertical + //! border interpolation is done outside of the class. + virtual void operator()( + const uchar** src, uchar* dst, uchar* kernel, int dststep, int dstcount, + int width) = 0; //! 
resets the internal buffers, if any virtual void reset(); @@ -51,66 +55,64 @@ public: class FilterEngine { public: - //FilterEngine(); + // FilterEngine(); - FilterEngine(const int &ih, const int &iw, - const int &oh, const int &ow, - const int &kh, const int &kw, - const int &anchor_h, const int &anchor_w, - BorderMode borderType = BorderMode::BORDER_CONSTANT, - bool is_symm_kernel = true); + FilterEngine( + const int& ih, const int& iw, const int& oh, const int& ow, const int& kh, + const int& kw, const int& anchor_h, const int& anchor_w, + BorderMode borderType = BorderMode::BORDER_CONSTANT, + bool is_symm_kernel = true); virtual ~FilterEngine(); - void init( const int &ih, const int &iw, - const int &oh, const int &ow, - const int &kh, const int &kw, - const int &anchor_h, const int &anchor_w, - BorderMode borderType, - bool is_symm_kernel); + void init( + const int& ih, const int& iw, const int& oh, const int& ow, const int& kh, + const int& kw, const int& anchor_h, const int& anchor_w, + BorderMode borderType, bool is_symm_kernel); - void exec( const TensorND & src, - const TensorND & kernel_x, - const TensorND & kernel_y, - const TensorND & dst); + void exec( + const TensorND& src, const TensorND& kernel_x, const TensorND& kernel_y, + const TensorND& dst); BaseRowFilter* getSepRowFilter(); BaseColumnFilter* getSepColFilter(); inline int getBorderRowIdx1(int idx); - private: // kernel - int ksize_x_, ksize_y_; - int anchor_x_, anchor_y_; // anchors is useless in this version. - int is_symm_kernel_; // are the kernels symmtric. - - //filter - BaseRowFilter *rowFilter_; - BaseColumnFilter *colFilter_; - - //buffer - std::vector srcRow_; // a buffer of a single appended input row - std::vector ringBuf_; // a buffer of middle results. size = maxBufferRow * (maxWidth + kernel_w - 1) + int ksize_x_, ksize_y_; + int anchor_x_, anchor_y_; // anchors is useless in this version. + int is_symm_kernel_; // are the kernels symmtric. + + // filter + BaseRowFilter* rowFilter_; + BaseColumnFilter* colFilter_; + + // buffer + std::vector srcRow_; // a buffer of a single appended input row + std::vector ringBuf_; // a buffer of middle results. size = maxBufferRow * + // (maxWidth + kernel_w - 1) std::vector row_ptr_; - int rowBuffStride_; // aligned stride of a row in the buffer. - int rowBufferOutputRow_; // each time the buffer is full, we can calculate 'rowBufferOutputRow' out rows at one time. - // In this version rowBufferOutputRow_ = 1. - int maxBufferRow_; // max_size_of buffer row. maxBufferRow_ = ksize_y + (rowBufferOutputRow_ - 1) - // In this version maxBufferRow_ = ksize_y. - - //border + int rowBuffStride_; // aligned stride of a row in the buffer. + int rowBufferOutputRow_; // each time the buffer is full, we can calculate + // 'rowBufferOutputRow' out rows at one time. In this + // version rowBufferOutputRow_ = 1. + int maxBufferRow_; // max_size_of buffer row. maxBufferRow_ = ksize_y + + // (rowBufferOutputRow_ - 1) In this version maxBufferRow_ = + // ksize_y. 
+ + // border BorderMode borderType_; int dx1_, dx2_, dy1_, dy2_; - std::vector borderTab_; // src idx of border elements - std::vector constBorderValue_; // template of append value (out of mat edge) - std::vector constBorderRow_; // a row of srcRow full of border value ---rowFilter---> constBorderRow + std::vector borderTab_; // src idx of border elements + std::vector constBorderValue_; // template of append value (out of mat edge) + std::vector constBorderRow_; // a row of srcRow full of border value + // ---rowFilter---> constBorderRow }; - -} // namespace sep_conv -} // namespace arm_common -} // namespace megdnn +} // namespace sep_conv +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp b/dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp index 009aabbf..13f09cf7 100644 --- a/dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp +++ b/dnn/src/arm_common/separable_conv/sep_conv_filter_engine.cpp @@ -12,8 +12,8 @@ #include #include -#include "src/arm_common/simd_macro/marm_neon.h" #include +#include "src/arm_common/simd_macro/marm_neon.h" namespace megdnn { namespace arm_common { @@ -23,22 +23,16 @@ using uchar = unsigned char; using ushort = unsigned short; ////////////////////////////////////////////// -//vecOp +// vecOp ///////////////////////////////////////////// -struct RowVec_32f -{ - RowVec_32f() - {} - - RowVec_32f(int _len) - { - ksize = _len; - } +struct RowVec_32f { + RowVec_32f() {} - int operator()(const uchar* _src, uchar* _dst, uchar * kernel, int width, int cn) const - { + RowVec_32f(int _len) { ksize = _len; } + int operator()( + const uchar* _src, uchar* _dst, uchar* kernel, int width, int cn) const { int _ksize = ksize; const float* src0 = (const float*)_src; float* dst = (float*)_dst; @@ -47,12 +41,10 @@ struct RowVec_32f int i = 0, k; width *= cn; - for( ; i <= width - 8; i += 8 ) - { + for (; i <= width - 8; i += 8) { const float* src = src0 + i; float32x4_t f, s0 = vdupq_n_f32(0), s1 = s0, x0, x1; - for( k = 0; k < _ksize; k++, src += cn ) - { + for (k = 0; k < _ksize; k++, src += cn) { f = vdupq_n_f32(_kx[k]); x0 = vld1q_f32(src); @@ -63,12 +55,10 @@ struct RowVec_32f vst1q_f32(dst + i, s0); vst1q_f32(dst + i + 4, s1); } - for( ; i <= width - 4; i += 4 ) - { + for (; i <= width - 4; i += 4) { const float* src = src0 + i; float32x4_t f, s0 = vdupq_n_f32(0), x0; - for( k = 0; k < _ksize; k++, src += cn ) - { + for (k = 0; k < _ksize; k++, src += cn) { f = vdupq_n_f32(_kx[k]); x0 = vld1q_f32(src); @@ -81,30 +71,24 @@ struct RowVec_32f int ksize; }; -struct SymmRowSmallVec_32f -{ +struct SymmRowSmallVec_32f { SymmRowSmallVec_32f() {} - SymmRowSmallVec_32f(int _len) - { - ksize = _len; - } + SymmRowSmallVec_32f(int _len) { ksize = _len; } - int operator()(const uchar* _src, uchar* _dst, uchar * kernel, int width, int cn) const - { + int operator()( + const uchar* _src, uchar* _dst, uchar* kernel, int width, int cn) const { int i = 0, _ksize = ksize; float* dst = (float*)_dst; - const float* src = (const float*)_src + (_ksize/2)*cn; - const float* kx = (float*)kernel + _ksize/2; + const float* src = (const float*)_src + (_ksize / 2) * cn; + const float* kx = (float*)kernel + _ksize / 2; width *= cn; { - if( _ksize == 1 ) + if (_ksize == 1) return 0; - if( _ksize == 3 ) - { + if (_ksize == 3) { float32x4_t k0 = vdupq_n_f32(kx[0]), k1 = vdupq_n_f32(kx[1]); - for( ; i <= width - 8; i += 8, src += 8 ) - { + for (; i <= width - 8; i += 8, src += 8) { 
float32x4_t x0, x1, x2, y0, y1, y2; x0 = vld1q_f32(src - cn); x1 = vld1q_f32(src); @@ -120,12 +104,10 @@ struct SymmRowSmallVec_32f vst1q_f32(dst + i, x0); vst1q_f32(dst + i + 4, y0); } - } - else if( _ksize == 5 ) - { - float32x4_t k0 = vdupq_n_f32(kx[0]), k1 = vdupq_n_f32(kx[1]), k2 = vdupq_n_f32(kx[2]); - for( ; i <= width - 8; i += 8, src += 8 ) - { + } else if (_ksize == 5) { + float32x4_t k0 = vdupq_n_f32(kx[0]), k1 = vdupq_n_f32(kx[1]), + k2 = vdupq_n_f32(kx[2]); + for (; i <= width - 8; i += 8, src += 8) { float32x4_t x0, x1, x2, y0, y1, y2; x0 = vld1q_f32(src - cn); x1 = vld1q_f32(src); @@ -139,8 +121,9 @@ struct SymmRowSmallVec_32f x0 = vmlaq_f32(x0, x1, k0); y0 = vmlaq_f32(y0, y1, k0); - x2 = vaddq_f32(vld1q_f32(src + cn*2), vld1q_f32(src - cn*2)); - y2 = vaddq_f32(vld1q_f32(src + cn*2 + 4), vld1q_f32(src - cn*2 + 4)); + x2 = vaddq_f32(vld1q_f32(src + cn * 2), vld1q_f32(src - cn * 2)); + y2 = vaddq_f32( + vld1q_f32(src + cn * 2 + 4), vld1q_f32(src - cn * 2 + 4)); x0 = vmlaq_f32(x0, x2, k2); y0 = vmlaq_f32(y0, y2, k2); @@ -148,64 +131,57 @@ struct SymmRowSmallVec_32f vst1q_f32(dst + i + 4, y0); } } - } return i; } int ksize; }; -struct ColumnVec_32f -{ +struct ColumnVec_32f { ColumnVec_32f() {} - ColumnVec_32f(int _len, int) - { - ksize = _len; - } + ColumnVec_32f(int _len, int) { ksize = _len; } - int operator()(const uchar** _src, uchar* _dst, uchar * kernel, int &, int width) const - { + int operator()( + const uchar** _src, uchar* _dst, uchar* kernel, int&, int width) const { const float* ky = (const float*)kernel; int i = 0, k; const float** src = (const float**)_src; - const float *S; + const float* S; float* dst = (float*)_dst; { - for( ; i <= width - 16; i += 16 ) - { + for (; i <= width - 16; i += 16) { float32x4_t f = vdupq_n_f32(ky[0]); float32x4_t s0, s1, s2, s3; float32x4_t x0, x1; S = src[0] + i; s0 = vld1q_f32(S); - s1 = vld1q_f32(S+4); + s1 = vld1q_f32(S + 4); s0 = vmulq_f32(s0, f); s1 = vmulq_f32(s1, f); - s2 = vld1q_f32(S+8); - s3 = vld1q_f32(S+12); + s2 = vld1q_f32(S + 8); + s3 = vld1q_f32(S + 12); s2 = vmulq_f32(s2, f); s3 = vmulq_f32(s3, f); - for( k = 1; k < ksize; k++ ) - { + for (k = 1; k < ksize; k++) { S = src[k] + i; float32x4_t f = vdupq_n_f32(ky[k]); x0 = vld1q_f32(S); - x1 = vld1q_f32(S+4); + x1 = vld1q_f32(S + 4); s0 = vmlaq_f32(s0, f, x0); s1 = vmlaq_f32(s1, f, x1); - x0 = vld1q_f32(S+8); - x1 = vld1q_f32(S+12); + x0 = vld1q_f32(S + 8); + x1 = vld1q_f32(S + 12); s2 = vmlaq_f32(s2, f, x0); s3 = vmlaq_f32(s3, f, x1); } - s0 = vaddq_f32(s0, vld1q_f32(dst+i)); - s1 = vaddq_f32(s1, vld1q_f32(dst+i+4)); - s2 = vaddq_f32(s2, vld1q_f32(dst+i+8)); - s3 = vaddq_f32(s3, vld1q_f32(dst+i+12)); + s0 = vaddq_f32(s0, vld1q_f32(dst + i)); + s1 = vaddq_f32(s1, vld1q_f32(dst + i + 4)); + s2 = vaddq_f32(s2, vld1q_f32(dst + i + 8)); + s3 = vaddq_f32(s3, vld1q_f32(dst + i + 12)); vst1q_f32(dst + i, s0); vst1q_f32(dst + i + 4, s1); @@ -213,15 +189,13 @@ struct ColumnVec_32f vst1q_f32(dst + i + 12, s3); } - for( ; i <= width - 4; i += 4 ) - { + for (; i <= width - 4; i += 4) { float32x4_t f = vdupq_n_f32(ky[0]); float32x4_t x0, s0 = vld1q_f32(src[0] + i); s0 = vmulq_f32(s0, f); - for( k = 1; k < ksize; k++ ) - { + for (k = 1; k < ksize; k++) { float32x4_t f = vdupq_n_f32(ky[k]); S = src[k] + i; x0 = vld1q_f32(S); @@ -237,17 +211,13 @@ struct ColumnVec_32f int ksize; }; -struct SymmColumnVec_32f -{ +struct SymmColumnVec_32f { SymmColumnVec_32f() {} - SymmColumnVec_32f(int _len, int) - { - ksize = _len; - } + SymmColumnVec_32f(int _len, int) { ksize = _len; } - int 
operator()(const uchar** _src, uchar* _dst, uchar * kernel, int &, int width) const - { - int ksize2 = (ksize)/2; + int operator()( + const uchar** _src, uchar* _dst, uchar* kernel, int&, int width) const { + int ksize2 = (ksize) / 2; const float* ky = (const float*)kernel + ksize2; int i = 0, k; const float** src = (const float**)_src; @@ -255,42 +225,39 @@ struct SymmColumnVec_32f float* dst = (float*)_dst; { - for( ; i <= width - 16; i += 16 ) - { + for (; i <= width - 16; i += 16) { float32x4_t f = vdupq_n_f32(ky[0]); float32x4_t s0, s1, s2, s3; float32x4_t x0, x1; S = src[0] + i; s0 = vld1q_f32(S); - s1 = vld1q_f32(S+4); + s1 = vld1q_f32(S + 4); s0 = vmulq_f32(s0, f); s1 = vmulq_f32(s1, f); - s2 = vld1q_f32(S+8); - s3 = vld1q_f32(S+12); + s2 = vld1q_f32(S + 8); + s3 = vld1q_f32(S + 12); s2 = vmulq_f32(s2, f); s3 = vmulq_f32(s3, f); - for( k = 1; k <= ksize2; k++ ) - { + for (k = 1; k <= ksize2; k++) { S = src[k] + i; S2 = src[-k] + i; float32x4_t f = vdupq_n_f32(ky[k]); x0 = vaddq_f32(vld1q_f32(S), vld1q_f32(S2)); - x1 = vaddq_f32(vld1q_f32(S+4), vld1q_f32(S2+4)); + x1 = vaddq_f32(vld1q_f32(S + 4), vld1q_f32(S2 + 4)); s0 = vmlaq_f32(s0, x0, f); s1 = vmlaq_f32(s1, x1, f); - x0 = vaddq_f32(vld1q_f32(S+8), vld1q_f32(S2+8)); - x1 = vaddq_f32(vld1q_f32(S+12), vld1q_f32(S2+12)); + x0 = vaddq_f32(vld1q_f32(S + 8), vld1q_f32(S2 + 8)); + x1 = vaddq_f32(vld1q_f32(S + 12), vld1q_f32(S2 + 12)); s2 = vmlaq_f32(s2, x0, f); s3 = vmlaq_f32(s3, x1, f); - } - s0 = vaddq_f32(s0, vld1q_f32(dst+i)); - s1 = vaddq_f32(s1, vld1q_f32(dst+i+4)); - s2 = vaddq_f32(s2, vld1q_f32(dst+i+8)); - s3 = vaddq_f32(s3, vld1q_f32(dst+i+12)); + s0 = vaddq_f32(s0, vld1q_f32(dst + i)); + s1 = vaddq_f32(s1, vld1q_f32(dst + i + 4)); + s2 = vaddq_f32(s2, vld1q_f32(dst + i + 8)); + s3 = vaddq_f32(s3, vld1q_f32(dst + i + 12)); vst1q_f32(dst + i, s0); vst1q_f32(dst + i + 4, s1); @@ -298,14 +265,12 @@ struct SymmColumnVec_32f vst1q_f32(dst + i + 12, s3); } - for( ; i <= width - 4; i += 4 ) - { + for (; i <= width - 4; i += 4) { float32x4_t f = vdupq_n_f32(ky[0]); float32x4_t x0, s0 = vld1q_f32(src[0] + i); s0 = vmulq_f32(s0, f); - for( k = 1; k <= ksize2; k++ ) - { + for (k = 1; k <= ksize2; k++) { float32x4_t f = vdupq_n_f32(ky[k]); S = src[k] + i; S2 = src[-k] + i; @@ -318,25 +283,20 @@ struct SymmColumnVec_32f } return i; - } int ksize; }; +struct SymmColumnSmallVec_32f { + SymmColumnSmallVec_32f() {} + SymmColumnSmallVec_32f(int _len, int) { ksize = _len; } -struct SymmColumnSmallVec_32f -{ - SymmColumnSmallVec_32f() { } - SymmColumnSmallVec_32f(int _len, int) - { - ksize = _len; - } - - int operator()(const uchar** _src, uchar* _dst, uchar * kernel, int & count, int width) const - { + int operator()( + const uchar** _src, uchar* _dst, uchar* kernel, int& count, + int width) const { (void)count; - int ksize2 = (ksize)/2; + int ksize2 = (ksize) / 2; const float* ky = (float*)kernel + ksize2; int i = 0; const float** src = (const float**)_src; @@ -344,8 +304,7 @@ struct SymmColumnSmallVec_32f float* dst = (float*)_dst; { float32x4_t k0 = vdupq_n_f32(ky[0]), k1 = vdupq_n_f32(ky[1]); - for( ; i <= width - 8; i += 8 ) - { + for (; i <= width - 8; i += 8) { float32x4_t s0, s1, x0, x1; s0 = vld1q_f32(S1 + i); s1 = vld1q_f32(S1 + i + 4); @@ -371,22 +330,22 @@ struct SymmColumnSmallVec_32f //%RowFilter% ////////////////////////////////////////////////////////////////////////////////////// -BaseRowFilter::BaseRowFilter() { ksize = anchor = -1; } +BaseRowFilter::BaseRowFilter() { + ksize = anchor = -1; +} BaseRowFilter::~BaseRowFilter() {} 
-template struct RowFilter : public BaseRowFilter -{ - RowFilter(int _ksize, int _anchor, const VecOp& _vecOp=VecOp() ) - { +template +struct RowFilter : public BaseRowFilter { + RowFilter(int _ksize, int _anchor, const VecOp& _vecOp = VecOp()) { anchor = _anchor; ksize = _ksize; vecOp = _vecOp; } - void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) - { + void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) { int _ksize = ksize; - const DT* kx = (DT* )kernel; + const DT* kx = (DT*)kernel; const ST* S; DT* D = (DT*)dst; int i, k; @@ -394,32 +353,32 @@ template struct RowFilter : public BaseRo i = vecOp(src, dst, kernel, width, cn); width *= cn; #if MEGCV_ENABLE_UNROLLED - for( ; i <= width - 4; i += 4 ) - { + for (; i <= width - 4; i += 4) { S = (const ST*)src + i; DT f = kx[0]; - DT s0 = f*S[0], s1 = f*S[1], s2 = f*S[2], s3 = f*S[3]; + DT s0 = f * S[0], s1 = f * S[1], s2 = f * S[2], s3 = f * S[3]; - for( k = 1; k < _ksize; k++ ) - { + for (k = 1; k < _ksize; k++) { S += cn; f = kx[k]; - s0 += f*S[0]; s1 += f*S[1]; - s2 += f*S[2]; s3 += f*S[3]; + s0 += f * S[0]; + s1 += f * S[1]; + s2 += f * S[2]; + s3 += f * S[3]; } - D[i] = s0; D[i+1] = s1; - D[i+2] = s2; D[i+3] = s3; + D[i] = s0; + D[i + 1] = s1; + D[i + 2] = s2; + D[i + 3] = s3; } #endif - for( ; i < width; i++ ) - { + for (; i < width; i++) { S = (const ST*)src + i; - DT s0 = kx[0]*S[0]; - for( k = 1; k < _ksize; k++ ) - { + DT s0 = kx[0] * S[0]; + for (k = 1; k < _ksize; k++) { S += cn; - s0 += kx[k]*S[0]; + s0 += kx[k] * S[0]; } D[i] = s0; } @@ -427,18 +386,13 @@ template struct RowFilter : public BaseRo VecOp vecOp; }; +template +struct SymmRowSmallFilter : public RowFilter { + SymmRowSmallFilter(int _ksize, int _anchor, const VecOp& _vecOp = VecOp()) + : RowFilter(_ksize, _anchor, _vecOp) {} -template struct SymmRowSmallFilter : - public RowFilter -{ - SymmRowSmallFilter(int _ksize, int _anchor, - const VecOp& _vecOp = VecOp() ) - : RowFilter( _ksize, _anchor, _vecOp ) - {} - - void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) - { - int ksize2 = this->ksize/2, ksize2n = ksize2*cn; + void operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) { + int ksize2 = this->ksize / 2, ksize2n = ksize2 * cn; const DT* kx = (DT*)kernel + ksize2; DT* D = (DT*)dst; int i = this->vecOp(src, dst, kernel, width, cn), j, k; @@ -446,40 +400,37 @@ template struct SymmRowSmallFilter : width *= cn; { - if( this->ksize == 1 && kx[0] == 1 ) - { - for( ; i <= width - 2; i += 2 ) - { - DT s0 = S[i], s1 = S[i+1]; - D[i] = s0; D[i+1] = s1; + if (this->ksize == 1 && kx[0] == 1) { + for (; i <= width - 2; i += 2) { + DT s0 = S[i], s1 = S[i + 1]; + D[i] = s0; + D[i + 1] = s1; } S += i; - } - else if( this->ksize == 3 ) - { + } else if (this->ksize == 3) { DT k0 = kx[0], k1 = kx[1]; - for( ; i <= width - 2; i += 2, S += 2 ) - { - DT s0 = S[0]*k0 + (S[-cn] + S[cn])*k1, s1 = S[1]*k0 + (S[1-cn] + S[1+cn])*k1; - D[i] = s0; D[i+1] = s1; + for (; i <= width - 2; i += 2, S += 2) { + DT s0 = S[0] * k0 + (S[-cn] + S[cn]) * k1, + s1 = S[1] * k0 + (S[1 - cn] + S[1 + cn]) * k1; + D[i] = s0; + D[i + 1] = s1; } - } - else if( this->ksize == 5 ) - { + } else if (this->ksize == 5) { DT k0 = kx[0], k1 = kx[1], k2 = kx[2]; - for( ; i <= width - 2; i += 2, S += 2 ) - { - DT s0 = S[0]*k0 + (S[-cn] + S[cn])*k1 + (S[-cn*2] + S[cn*2])*k2; - DT s1 = S[1]*k0 + (S[1-cn] + S[1+cn])*k1 + (S[1-cn*2] + S[1+cn*2])*k2; - D[i] = s0; D[i+1] = s1; + for (; i <= width - 2; i += 2, S += 2) { + 
DT s0 = S[0] * k0 + (S[-cn] + S[cn]) * k1 + + (S[-cn * 2] + S[cn * 2]) * k2; + DT s1 = S[1] * k0 + (S[1 - cn] + S[1 + cn]) * k1 + + (S[1 - cn * 2] + S[1 + cn * 2]) * k2; + D[i] = s0; + D[i + 1] = s1; } } - for( ; i < width; i++, S++ ) - { - DT s0 = kx[0]*S[0]; - for( k = 1, j = cn; k <= ksize2; k++, j += cn ) - s0 += kx[k]*(S[j] + S[-j]); + for (; i < width; i++, S++) { + DT s0 = kx[0] * S[0]; + for (k = 1, j = cn; k <= ksize2; k++, j += cn) + s0 += kx[k] * (S[j] + S[-j]); D[i] = s0; } } @@ -487,99 +438,93 @@ template struct SymmRowSmallFilter : }; template - BaseRowFilter * getLinearRowFilter(int ksize, bool is_symm_kernel) - { - // TODO: calculate anchor - int anchor = ksize/2; - if(is_symm_kernel) { - if( ksize <= 5 ) - { - //if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) - return new SymmRowSmallFilter - (ksize, anchor, SymmRowSmallVec_32f(ksize)); - } - - //if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) - return new RowFilter - (ksize, anchor, RowVec_32f(ksize)); - } else { - //if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) - return new RowFilter - (ksize, anchor, RowVec_32f(ksize)); +BaseRowFilter* getLinearRowFilter(int ksize, bool is_symm_kernel) { + // TODO: calculate anchor + int anchor = ksize / 2; + if (is_symm_kernel) { + if (ksize <= 5) { + // if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) + return new SymmRowSmallFilter( + ksize, anchor, SymmRowSmallVec_32f(ksize)); } - //printf("Unsupported combination of source format (=%s), and buffer format (=%s)", - // typeid(T).name(), typeid(T1).name()); - //exit(1); + // if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) + return new RowFilter(ksize, anchor, RowVec_32f(ksize)); + } else { + // if( typeid(T) == typeid(float) && typeid(T1) == typeid(float)) + return new RowFilter(ksize, anchor, RowVec_32f(ksize)); } -////////////////////////////////////////////////////////////////////////////////////// + // printf("Unsupported combination of source format (=%s), and buffer format (=%s)", + // typeid(T).name(), typeid(T1).name()); + // exit(1); +} +////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////// //%BaseColFilter% ////////////////////////////////////////////////////////////////////////////////////// -BaseColumnFilter::BaseColumnFilter() { ksize = anchor = -1; } +BaseColumnFilter::BaseColumnFilter() { + ksize = anchor = -1; +} BaseColumnFilter::~BaseColumnFilter() {} void BaseColumnFilter::reset() {} -template struct ColumnFilter : public BaseColumnFilter -{ +template +struct ColumnFilter : public BaseColumnFilter { typedef typename CastOp::type1 ST; typedef typename CastOp::rtype DT; - ColumnFilter(int _ksize, int _anchor, - const CastOp& _castOp=CastOp(), - const VecOp& _vecOp=VecOp()) - { + ColumnFilter( + int _ksize, int _anchor, const CastOp& _castOp = CastOp(), + const VecOp& _vecOp = VecOp()) { this->anchor = _anchor; this->ksize = _ksize; this->castOp0 = _castOp; this->vecOp = _vecOp; } - void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, int width) - { + void operator()( + const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, + int width) { const ST* ky = (ST*)kernel; int i = 0, k; CastOp castOp = this->castOp0; { - for( ; count > 0; count--, dst += dststep, src++ ) - { + for (; count > 0; count--, dst += dststep, src++) { DT* D = (DT*)dst; i = (this->vecOp)(src, dst, kernel, 
count, width); #if MEGCV_ENABLE_UNROLLED - for( ; i <= width - 4; i += 4 ) - { + for (; i <= width - 4; i += 4) { ST f = ky[0]; const ST* S = (const ST*)src[0] + i; - ST s0 = f*S[0], s1 = f*S[1], - s2 = f*S[2], s3 = f*S[3]; + ST s0 = f * S[0], s1 = f * S[1], s2 = f * S[2], s3 = f * S[3]; - for( k = 1; k < ksize; k++ ) - { + for (k = 1; k < ksize; k++) { S = (const ST*)src[k] + i; f = ky[k]; - s0 += f*S[0]; - s1 += f*S[1]; - s2 += f*S[2]; - s3 += f*S[3]; + s0 += f * S[0]; + s1 += f * S[1]; + s2 += f * S[2]; + s3 += f * S[3]; } - D[i] += castOp(s0); D[i+1] += castOp(s1); - D[i+2] += castOp(s2); D[i+3] += castOp(s3); + D[i] += castOp(s0); + D[i + 1] += castOp(s1); + D[i + 2] += castOp(s2); + D[i + 3] += castOp(s3); } #endif - for( ; i < width; i++ ) - { + for (; i < width; i++) { ST s0 = D[i]; - //ST s0 = ky[0]*((const ST*)src[0])[i]; - for( k = 0; k < ksize; k++ ) { - s0 += ky[k]* ((const ST*)src[k])[i]; + // ST s0 = ky[0]*((const ST*)src[0])[i]; + for (k = 0; k < ksize; k++) { + s0 += ky[k] * ((const ST*)src[k])[i]; } D[i] = castOp(s0); - //D[i] += castOp(s0); + // D[i] += castOp(s0); } } } @@ -588,64 +533,62 @@ template struct ColumnFilter : public BaseColumnFilte VecOp vecOp; }; -template struct SymmColumnFilter : public BaseColumnFilter -{ +template +struct SymmColumnFilter : public BaseColumnFilter { typedef typename CastOp::type1 ST; typedef typename CastOp::rtype DT; - SymmColumnFilter(int _ksize, int _anchor, - const CastOp& _castOp=CastOp(), - const VecOp& _vecOp=VecOp()) - { + SymmColumnFilter( + int _ksize, int _anchor, const CastOp& _castOp = CastOp(), + const VecOp& _vecOp = VecOp()) { this->anchor = _anchor; this->ksize = _ksize; this->castOp0 = _castOp; this->vecOp = _vecOp; } - void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, int width) - { - int ksize2 = this->ksize/2; + void operator()( + const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, + int width) { + int ksize2 = this->ksize / 2; const ST* ky = (ST*)kernel + ksize2; int i, k; CastOp castOp = this->castOp0; src += ksize2; { - for( ; count > 0; count--, dst += dststep, src++ ) - { + for (; count > 0; count--, dst += dststep, src++) { DT* D = (DT*)dst; i = (this->vecOp)(src, dst, kernel, count, width); #if MEGCV_ENABLE_UNROLLED - for( ; i <= width - 4; i += 4 ) - { + for (; i <= width - 4; i += 4) { ST f = ky[0]; - const ST* S = (const ST*)src[0] + i, *S2; - ST s0 = f*S[0], s1 = f*S[1], - s2 = f*S[2], s3 = f*S[3]; + const ST *S = (const ST*)src[0] + i, *S2; + ST s0 = f * S[0], s1 = f * S[1], s2 = f * S[2], s3 = f * S[3]; - for( k = 1; k <= ksize2; k++ ) - { + for (k = 1; k <= ksize2; k++) { S = (const ST*)src[k] + i; S2 = (const ST*)src[-k] + i; f = ky[k]; - s0 += f*(S[0] + S2[0]); - s1 += f*(S[1] + S2[1]); - s2 += f*(S[2] + S2[2]); - s3 += f*(S[3] + S2[3]); + s0 += f * (S[0] + S2[0]); + s1 += f * (S[1] + S2[1]); + s2 += f * (S[2] + S2[2]); + s3 += f * (S[3] + S2[3]); } - D[i] += castOp(s0); D[i+1] += castOp(s1); - D[i+2] += castOp(s2); D[i+3] += castOp(s3); + D[i] += castOp(s0); + D[i + 1] += castOp(s1); + D[i + 2] += castOp(s2); + D[i + 3] += castOp(s3); } #endif - for( ; i < width; i++ ) - { - ST s0 = ky[0]*((const ST*)src[0])[i]; - for( k = 1; k <= ksize2; k++ ) { - s0 += ky[k]*(((const ST*)src[k])[i] + ((const ST*)src[-k])[i]); - //s0 += ky[k]*((const ST*)src[k])[i]; - //s0 += ky[k]*((const ST*)src[-k])[i]; + for (; i < width; i++) { + ST s0 = ky[0] * ((const ST*)src[0])[i]; + for (k = 1; k <= ksize2; k++) { + s0 += ky[k] * + (((const ST*)src[k])[i] + ((const 
ST*)src[-k])[i]); + // s0 += ky[k]*((const ST*)src[k])[i]; + // s0 += ky[k]*((const ST*)src[-k])[i]; } D[i] += castOp(s0); } @@ -656,24 +599,22 @@ template struct SymmColumnFilter : public BaseColumnF VecOp vecOp; }; - -template - struct SymmColumnSmallFilter : public SymmColumnFilter -{ +template +struct SymmColumnSmallFilter : public SymmColumnFilter { typedef typename CastOp::type1 ST; typedef typename CastOp::rtype DT; - SymmColumnSmallFilter( int _ksize, int _anchor, - const CastOp & _castOp=CastOp(), - const VecOp & _vecOp=VecOp()) - : SymmColumnFilter(_ksize, _anchor, _castOp, _vecOp ) - { - megdnn_assert(this->ksize == 3 ); + SymmColumnSmallFilter( + int _ksize, int _anchor, const CastOp& _castOp = CastOp(), + const VecOp& _vecOp = VecOp()) + : SymmColumnFilter(_ksize, _anchor, _castOp, _vecOp) { + megdnn_assert(this->ksize == 3); } - void operator()(const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, int width) - { - int ksize2 = this->ksize/2; + void operator()( + const uchar** src, uchar* dst, uchar* kernel, int dststep, int count, + int width) { + int ksize2 = this->ksize / 2; const ST* ky = (ST*)kernel + ksize2; int i = 0; ST f0 = ky[0], f1 = ky[1]; @@ -686,210 +627,197 @@ template (this->vecOp)(src, dst, kernel, count, width); } */ - for( ; count > 0; count--, dst += dststep, src++ ) - { + for (; count > 0; count--, dst += dststep, src++) { DT* D = (DT*)dst; i = (this->vecOp)(src, dst, kernel, count, width); - if(count == 0) + if (count == 0) break; const ST* S0 = (const ST*)src[-1]; const ST* S1 = (const ST*)src[0]; const ST* S2 = (const ST*)src[1]; { #if MEGCV_ENABLE_UNROLLED - for( ; i <= width - 4; i += 4 ) - { - ST s0 = (S0[i] + S2[i])*f1 + S1[i]*f0; - ST s1 = (S0[i+1] + S2[i+1])*f1 + S1[i+1]*f0; + for (; i <= width - 4; i += 4) { + ST s0 = (S0[i] + S2[i]) * f1 + S1[i] * f0; + ST s1 = (S0[i + 1] + S2[i + 1]) * f1 + S1[i + 1] * f0; D[i] += castOp(s0); - D[i+1] += castOp(s1); + D[i + 1] += castOp(s1); - s0 = (S0[i+2] + S2[i+2])*f1 + S1[i+2]*f0; - s1 = (S0[i+3] + S2[i+3])*f1 + S1[i+3]*f0; - D[i+2] += castOp(s0); - D[i+3] += castOp(s1); + s0 = (S0[i + 2] + S2[i + 2]) * f1 + S1[i + 2] * f0; + s1 = (S0[i + 3] + S2[i + 3]) * f1 + S1[i + 3] * f0; + D[i + 2] += castOp(s0); + D[i + 3] += castOp(s1); } #endif - for( ; i < width; i ++ ) - { - ST s0 = (S0[i] + S2[i])*f1 + S1[i]*f0; + for (; i < width; i++) { + ST s0 = (S0[i] + S2[i]) * f1 + S1[i] * f0; D[i] += castOp(s0); } } } - } }; - -template - BaseColumnFilter * getLinearColumnFilter(int ksize, int /*bits*/, bool is_symm_kernel) +template +BaseColumnFilter* getLinearColumnFilter(int ksize, int /*bits*/, bool is_symm_kernel) { + int anchor = ksize / 2; { - int anchor = ksize/2; - { - if(is_symm_kernel) { - if( ksize == 3 ) - { - - //if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) - return new SymmColumnSmallFilter,SymmColumnSmallVec_32f> - (ksize, anchor, FixedPtCastEx(0), - SymmColumnSmallVec_32f(ksize, 0)); - } - //if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) - return new SymmColumnFilter, SymmColumnVec_32f> - (ksize, anchor, FixedPtCastEx(), - SymmColumnVec_32f(ksize, 0)); - } else { - //if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) - return new ColumnFilter, ColumnVec_32f> - (ksize, anchor, FixedPtCastEx(), - ColumnVec_32f(ksize, 0)); + if (is_symm_kernel) { + if (ksize == 3) { + // if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) + return new SymmColumnSmallFilter< + FixedPtCastEx, SymmColumnSmallVec_32f>( + ksize, anchor, FixedPtCastEx(0), 
+ SymmColumnSmallVec_32f(ksize, 0)); } + // if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) + return new SymmColumnFilter, SymmColumnVec_32f>( + ksize, anchor, FixedPtCastEx(), SymmColumnVec_32f(ksize, 0)); + } else { + // if( typeid(T1) == typeid(float) && typeid(T) == typeid(float) ) + return new ColumnFilter, ColumnVec_32f>( + ksize, anchor, FixedPtCastEx(), ColumnVec_32f(ksize, 0)); } - //printf("Unsupported combination of buffer format (=%s), and destination format (=%s)", - // typeid(T1).name(), typeid(T).name()); - //exit(1); } + // printf("Unsupported combination of buffer format (=%s), and destination format + // (=%s)", + // typeid(T1).name(), typeid(T).name()); + // exit(1); +} ////////////////////////////////////////////////////////////////////////////////////// ////%FilterEngine% ////////////////////////////////////////////////////////////////////////////////////// - FilterEngine::FilterEngine(const int &ih, const int &iw, - const int &oh, const int &ow, - const int &kh, const int &kw, - const int &anchor_h, const int &anchor_w, - BorderMode borderType, - bool is_symm_kernel) { - init(ih, iw, oh, ow, kh, kw, anchor_h, anchor_w, borderType, is_symm_kernel); - } - - - FilterEngine::~FilterEngine() - { - delete rowFilter_; - delete colFilter_; - } - - void FilterEngine::init(const int &ih, const int &iw, - const int &oh, const int &ow, - const int &kh, const int &kw, - const int &anchor_h, const int &anchor_w, - BorderMode borderType, - bool is_symm_kernel) { - // reduce warning - int wrn = ih + iw + oh; ++wrn; - - ksize_x_ = kw; - ksize_y_ = kh; - anchor_x_ = anchor_w; - anchor_y_ = anchor_h; - borderType_ = borderType; - is_symm_kernel_ = is_symm_kernel; - - rowFilter_ = getLinearRowFilter(kw, is_symm_kernel_); - colFilter_ = getLinearColumnFilter(kh, 0, is_symm_kernel_); - - rowBufferOutputRow_ = 1; - maxBufferRow_ = ksize_y_ + rowBufferOutputRow_ - 1; - //int rowBuffStride_ = sizeof(float)*(int)align_size(maxWidth + (ksize_y_ - 1),VEC_ALIGN); - rowBuffStride_ = sizeof(float) * (int)align_size(ow, VEC_ALIGN); - row_ptr_.resize(maxBufferRow_); - ringBuf_.resize(rowBuffStride_ * maxBufferRow_ + VEC_ALIGN); - - // There is no need to use constBorder when padding == 0. 
- //if (borderType_ = BORDER_CONSTANT) { - // constBorderRow.resize(sizeof(int) * (maxWidth + ksize.cols() - 1) + VEC_ALIGN); - //} +FilterEngine::FilterEngine( + const int& ih, const int& iw, const int& oh, const int& ow, const int& kh, + const int& kw, const int& anchor_h, const int& anchor_w, BorderMode borderType, + bool is_symm_kernel) { + init(ih, iw, oh, ow, kh, kw, anchor_h, anchor_w, borderType, is_symm_kernel); +} + +FilterEngine::~FilterEngine() { + delete rowFilter_; + delete colFilter_; +} + +void FilterEngine::init( + const int& ih, const int& iw, const int& oh, const int& ow, const int& kh, + const int& kw, const int& anchor_h, const int& anchor_w, BorderMode borderType, + bool is_symm_kernel) { + // reduce warning + int wrn = ih + iw + oh; + ++wrn; + + ksize_x_ = kw; + ksize_y_ = kh; + anchor_x_ = anchor_w; + anchor_y_ = anchor_h; + borderType_ = borderType; + is_symm_kernel_ = is_symm_kernel; + + rowFilter_ = getLinearRowFilter(kw, is_symm_kernel_); + colFilter_ = getLinearColumnFilter(kh, 0, is_symm_kernel_); + + rowBufferOutputRow_ = 1; + maxBufferRow_ = ksize_y_ + rowBufferOutputRow_ - 1; + // int rowBuffStride_ = sizeof(float)*(int)align_size(maxWidth + (ksize_y_ - + // 1),VEC_ALIGN); + rowBuffStride_ = sizeof(float) * (int)align_size(ow, VEC_ALIGN); + row_ptr_.resize(maxBufferRow_); + ringBuf_.resize(rowBuffStride_ * maxBufferRow_ + VEC_ALIGN); + + // There is no need to use constBorder when padding == 0. + // if (borderType_ = BORDER_CONSTANT) { + // constBorderRow.resize(sizeof(int) * (maxWidth + ksize.cols() - 1) + VEC_ALIGN); + //} +} + +void FilterEngine::exec( + const TensorND& src, const TensorND& kernel_x, const TensorND& kernel_y, + const TensorND& dst) { + // int stride_src = src.layout.stride[1]; + // int stride_dst = dst.layout.stride[1]; + // float *src0 = src.ptr(); + // float *dst0 = dst.ptr(); + float* src_cur_row = src.ptr(); + float* src_cur_step = src.ptr(); + float* dst_cur_chan = dst.ptr(); + int width_src = (int)src.layout.shape[3]; + int width_dst = (int)dst.layout.shape[3]; + int height_src = (int)src.layout.shape[2]; + // int height_dst = dst.layout.shape[2]; + int kernel_chan_stride = (int)kernel_x.layout.stride[1]; + memset(dst.ptr(), 0, sizeof(float) * dst.layout.total_nr_elems()); + + for (int step = 0; step < (int)src.layout.shape[0]; ++step) { + for (int chan_out = 0; chan_out < (int)dst.layout.shape[1]; + ++chan_out, dst_cur_chan += dst.layout.stride[1]) { + float* kx = kernel_x.ptr(); + float* ky = kernel_y.ptr(); + src_cur_row = src_cur_step; + // handle a channel of input + for (int chan_in = 0; chan_in < (int)src.layout.shape[1]; ++chan_in) { + // 1. init row buffer borden + // No need to init row border when padding == 0. + + // 2. fill ring buffer & calculate + int row_count = 0; + int row_ptr_pos = 0; + int dststep = dst.layout.stride[2]; + int bufRows = (int)row_ptr_.size(); + int bi = 0; + float* dst_cur_row = dst_cur_chan; + for (row_count = 0; row_count < height_src; + ++row_count, src_cur_row += width_src) { + // 2.1 Get tab row. No need to do this when padding == 0. + + // 2.2 Calculate a row. 
+ bi = row_count % bufRows; + uchar* brow = + align_ptr(&ringBuf_[0], VEC_ALIGN) + bi * rowBuffStride_; + if (row_count < bufRows - 1) { + row_ptr_[bi] = (float*)brow; + } else { + row_ptr_[bufRows - 1] = (float*)brow; + } + // Get a row & make border + // uchar* row = &srcRow[0]; + // memcpy( row + _dx1*esz, src, (width1 - _dx2 - _dx1)*esz ); + uchar* row = (uchar*)src_cur_row; + (*rowFilter_)(row, brow, (uchar*)kx, width_dst, 1); + // operator()(const uchar* src, uchar* dst, uchar* kernel, int + // width, int cn) + + // Keeping fill the ring_buff until its length is ky + if (row_count < bufRows - 1) { + ++row_ptr_pos; + continue; + } - } + // 2.3 Calculate column + // operator()(const uchar** src, uchar* dst, ST* kernel, int + // dststep, int count, int width) + (*colFilter_)( + (const uchar**)(&row_ptr_[0]), (uchar*)dst_cur_row, + (uchar*)ky, dststep, rowBufferOutputRow_, width_dst); - void FilterEngine::exec( const TensorND & src, - const TensorND & kernel_x, - const TensorND & kernel_y, - const TensorND & dst) { - - //int stride_src = src.layout.stride[1]; - //int stride_dst = dst.layout.stride[1]; - //float *src0 = src.ptr(); - //float *dst0 = dst.ptr(); - float * src_cur_row = src.ptr(); - float * src_cur_step = src.ptr(); - float * dst_cur_chan = dst.ptr(); - int width_src = (int)src.layout.shape[3]; - int width_dst = (int)dst.layout.shape[3]; - int height_src = (int)src.layout.shape[2]; - //int height_dst = dst.layout.shape[2]; - int kernel_chan_stride = (int)kernel_x.layout.stride[1]; - memset(dst.ptr(), 0, sizeof(float) * dst.layout.total_nr_elems()); - - for(int step = 0; step < (int)src.layout.shape[0]; ++step) { - for(int chan_out = 0; chan_out < (int)dst.layout.shape[1]; - ++ chan_out, dst_cur_chan += dst.layout.stride[1]) { - float* kx = kernel_x.ptr(); - float* ky = kernel_y.ptr(); - src_cur_row = src_cur_step; - // handle a channel of input - for(int chan_in = 0; chan_in < (int)src.layout.shape[1]; ++ chan_in) { - // 1. init row buffer borden - // No need to init row border when padding == 0. - - // 2. fill ring buffer & calculate - int row_count = 0; - int row_ptr_pos = 0; - int dststep = dst.layout.stride[2]; - int bufRows = (int)row_ptr_.size(); - int bi = 0; - float* dst_cur_row = dst_cur_chan; - for(row_count = 0; row_count < height_src; - ++row_count, src_cur_row += width_src) { - - //2.1 Get tab row. No need to do this when padding == 0. - - //2.2 Calculate a row. 
- bi = row_count % bufRows; - uchar* brow = align_ptr(&ringBuf_[0], VEC_ALIGN) + bi * rowBuffStride_; - if(row_count < bufRows - 1) { - row_ptr_[bi] = (float*)brow; - } else { - row_ptr_[bufRows - 1] = (float*)brow; - } - - // Get a row & make border - //uchar* row = &srcRow[0]; - //memcpy( row + _dx1*esz, src, (width1 - _dx2 - _dx1)*esz ); - uchar* row = (uchar*)src_cur_row; - (*rowFilter_)(row, brow, (uchar*)kx, width_dst, 1); - // operator()(const uchar* src, uchar* dst, uchar* kernel, int width, int cn) - - // Keeping fill the ring_buff until its length is ky - if(row_count < bufRows - 1) { - ++ row_ptr_pos; - continue; - } - - // 2.3 Calculate column - // operator()(const uchar** src, uchar* dst, ST* kernel, int dststep, int count, int width) - (*colFilter_)((const uchar**)(&row_ptr_[0]), (uchar*)dst_cur_row, - (uchar*)ky, dststep, rowBufferOutputRow_, width_dst); - - // Update row_ptr - for(int i = 0; i< bufRows - 1; ++i) { - row_ptr_[i] = row_ptr_[i+1]; - } - dst_cur_row += width_dst; //dst.layout.stride[2]; + // Update row_ptr + for (int i = 0; i < bufRows - 1; ++i) { + row_ptr_[i] = row_ptr_[i + 1]; } - kx += kernel_chan_stride; - ky += kernel_chan_stride; - } // chan_in - } // chan_out - src_cur_step += src.layout.shape[0]; - } //step_in - } - -} // namespace sep_conv -} // namespace arm_common -} // namespace megdnn + dst_cur_row += width_dst; // dst.layout.stride[2]; + } + kx += kernel_chan_stride; + ky += kernel_chan_stride; + } // chan_in + } // chan_out + src_cur_step += src.layout.shape[0]; + } // step_in +} + +} // namespace sep_conv +} // namespace arm_common +} // namespace megdnn diff --git a/dnn/src/arm_common/separable_filter/filter.h b/dnn/src/arm_common/separable_filter/filter.h index 109703ff..5414d81b 100644 --- a/dnn/src/arm_common/separable_filter/filter.h +++ b/dnn/src/arm_common/separable_filter/filter.h @@ -60,10 +60,10 @@ */ #pragma once -#include "src/common/cv/filter.h" -#include "src/arm_common/simd_macro/marm_neon.h" #include #include +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/cv/filter.h" namespace megdnn { namespace megcv { @@ -198,8 +198,7 @@ struct SymmColumnVec_32s8u { ~SymmColumnVec_32s8u() { free(kernel); } - int operator()(const uchar** _src, uchar* dst, int& count, - int width) const { + int operator()(const uchar** _src, uchar* dst, int& count, int width) const { MEGDNN_MARK_USED_VAR(count); int _ksize = ksize; int ksize2 = _ksize / 2; @@ -378,10 +377,9 @@ struct SymmRowSmallVec_32f { x0 = vmlaq_f32(x0, x1, k0); y0 = vmlaq_f32(y0, y1, k0); - x2 = vaddq_f32(vld1q_f32(src + cn * 2), - vld1q_f32(src - cn * 2)); - y2 = vaddq_f32(vld1q_f32(src + cn * 2 + 4), - vld1q_f32(src - cn * 2 + 4)); + x2 = vaddq_f32(vld1q_f32(src + cn * 2), vld1q_f32(src - cn * 2)); + y2 = vaddq_f32( + vld1q_f32(src + cn * 2 + 4), vld1q_f32(src - cn * 2 + 4)); x0 = vmlaq_f32(x0, x2, k2); y0 = vmlaq_f32(y0, y2, k2); @@ -551,8 +549,7 @@ struct SymmColumnSmallVec_32f { kernel = (float*)_kernel; } - int operator()(const uchar** _src, uchar* _dst, int& count, - int width) const { + int operator()(const uchar** _src, uchar* _dst, int& count, int width) const { MEGDNN_MARK_USED_VAR(count); int ksize2 = (ksize) / 2; const float* ky = (float*)kernel + ksize2; @@ -597,13 +594,13 @@ static BaseColumnFilter* getLinearColumnFilter(Mat& kernel, int bits) { uchar* kernel_str = static_cast(kernel.raw_ptr()); if (SYMM && ksize == 3) { if (std::is_same::value && std::is_same::value) - return new SymmColumnSmallFilter, - SymmColumnVec_32s8u>( + return new 
SymmColumnSmallFilter< + FixedPtCastEx, SymmColumnVec_32s8u>( kernel, anchor, FixedPtCastEx(bits), SymmColumnVec_32s8u(kernel_str, ksize, bits)); if (std::is_same::value && std::is_same::value) - return new SymmColumnSmallFilter, - SymmColumnSmallVec_32f>( + return new SymmColumnSmallFilter< + FixedPtCastEx, SymmColumnSmallVec_32f>( kernel, anchor, FixedPtCastEx(0), SymmColumnSmallVec_32f(kernel_str, ksize, 0)); } @@ -618,8 +615,7 @@ static BaseColumnFilter* getLinearColumnFilter(Mat& kernel, int bits) { kernel, anchor, FixedPtCastEx(), ColumnVec_32f(kernel_str, ksize, 0)); - MegCVException( - "Unsupported combination of source format and buffer format\n"); + MegCVException("Unsupported combination of source format and buffer format\n"); } /*! @@ -644,19 +640,18 @@ static BaseRowFilter* getLinearRowFilter(Mat& kernel) { } if (std::is_same::value && std::is_same::value) - return new RowFilter(kernel, anchor, - RowNoVec(kernel_str, ksize)); + return new RowFilter( + kernel, anchor, RowNoVec(kernel_str, ksize)); if (std::is_same::value && std::is_same::value) - return new RowFilter(kernel, anchor, - RowVec_32f(kernel_str, ksize)); + return new RowFilter( + kernel, anchor, RowVec_32f(kernel_str, ksize)); - MegCVException( - "Unsupported combination of source format and buffer format\n"); + MegCVException("Unsupported combination of source format and buffer format\n"); } } // namespace sep_filter -} // namespace x86 +} // namespace megcv } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_filter/opr_impl.cpp b/dnn/src/arm_common/separable_filter/opr_impl.cpp index bb7ac329..1482cdfd 100644 --- a/dnn/src/arm_common/separable_filter/opr_impl.cpp +++ b/dnn/src/arm_common/separable_filter/opr_impl.cpp @@ -59,12 +59,12 @@ * --------------------------------------------------------------------------- */ #include "src/arm_common/separable_filter/opr_impl.h" -#include "src/arm_common/separable_filter/filter.h" +#include #include "src/arm_common/handle.h" +#include "src/arm_common/separable_filter/filter.h" #include "src/common/cv/common.h" #include "src/common/cv/helper.h" #include "src/common/utils.h" -#include namespace megdnn { namespace arm_common { @@ -72,16 +72,15 @@ using namespace megcv; using namespace sep_filter; using BorderMode = param::SeparableFilter::BorderMode; -void SeparableFilterImpl::separable_filter_exec_8u(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst) { +void SeparableFilterImpl::separable_filter_exec_8u( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst) { megdnn_assert(src.layout.dtype == dtype::Uint8()); - Mat kernel_column(1, filter_y.layout.shape[3], 1, - static_cast(filter_y.raw_ptr)); - Mat kernel_row(1, filter_x.layout.shape[3], 1, - static_cast(filter_x.raw_ptr)); + Mat kernel_column( + 1, filter_y.layout.shape[3], 1, static_cast(filter_y.raw_ptr)); + Mat kernel_row( + 1, filter_x.layout.shape[3], 1, static_cast(filter_x.raw_ptr)); size_t src_channels = src.layout.shape[3]; @@ -104,16 +103,16 @@ void SeparableFilterImpl::separable_filter_exec_8u(_megdnn_tensor_in src, BaseColumnFilter* columnFilter = nullptr; if (param().is_symm_kernel) { rowFilter = getLinearRowFilter(kernel_row_int); - columnFilter = getLinearColumnFilter( - kernel_column_int, bits * 2); + columnFilter = + getLinearColumnFilter(kernel_column_int, bits * 2); } else { rowFilter = getLinearRowFilter(kernel_row_int); - columnFilter = 
getLinearColumnFilter( - kernel_column_int, bits * 2); + columnFilter = + getLinearColumnFilter(kernel_column_int, bits * 2); } - FilterEngine filter(rowFilter, columnFilter, src_channels, - border_value, param().borderMode); + FilterEngine filter( + rowFilter, columnFilter, src_channels, border_value, param().borderMode); megdnn_assert(param().borderMode != BorderMode::BORDER_ISOLATED); for (size_t i = 0; i < src.layout.shape[0]; ++i) { @@ -125,14 +124,13 @@ void SeparableFilterImpl::separable_filter_exec_8u(_megdnn_tensor_in src, } template -void SeparableFilterImpl::separable_filter_exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst) { - Mat kernel_column(1, filter_y.layout.shape[3], 1, - static_cast(filter_y.raw_ptr)); - Mat kernel_row(1, filter_x.layout.shape[3], 1, - static_cast(filter_x.raw_ptr)); +void SeparableFilterImpl::separable_filter_exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst) { + Mat kernel_column( + 1, filter_y.layout.shape[3], 1, static_cast(filter_y.raw_ptr)); + Mat kernel_row( + 1, filter_x.layout.shape[3], 1, static_cast(filter_x.raw_ptr)); size_t src_channels = src.layout.shape[3]; T border_value[4] = {0, 0, 0, 0}; @@ -141,16 +139,14 @@ void SeparableFilterImpl::separable_filter_exec(_megdnn_tensor_in src, BaseColumnFilter* column_filter = nullptr; if (param().is_symm_kernel) { row_filter = getLinearRowFilter(kernel_row); - column_filter = - getLinearColumnFilter(kernel_column, (int)0); + column_filter = getLinearColumnFilter(kernel_column, (int)0); } else { row_filter = getLinearRowFilter(kernel_row); - column_filter = - getLinearColumnFilter(kernel_column, (int)0); + column_filter = getLinearColumnFilter(kernel_column, (int)0); } - FilterEngine filter(row_filter, column_filter, src_channels, - border_value, param().borderMode); + FilterEngine filter( + row_filter, column_filter, src_channels, border_value, param().borderMode); megdnn_assert(param().borderMode != BorderMode::BORDER_ISOLATED); for (size_t i = 0; i < src.layout.shape[0]; ++i) { @@ -160,13 +156,11 @@ void SeparableFilterImpl::separable_filter_exec(_megdnn_tensor_in src, } } -void SeparableFilterImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - check_exec(src.layout, filter_x.layout, filter_y.layout, dst.layout, - workspace.size); +void SeparableFilterImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec( + src.layout, filter_x.layout, filter_y.layout, dst.layout, workspace.size); if (dst.layout.dtype == dtype::Float32()) { MEGDNN_DISPATCH_CPU_KERN_OPR( separable_filter_exec(src, filter_x, filter_y, dst)); @@ -178,7 +172,7 @@ void SeparableFilterImpl::exec(_megdnn_tensor_in src, }; } -} // namespace x86 -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/separable_filter/opr_impl.h b/dnn/src/arm_common/separable_filter/opr_impl.h index b3e4c3b0..0f74018d 100644 --- a/dnn/src/arm_common/separable_filter/opr_impl.h +++ b/dnn/src/arm_common/separable_filter/opr_impl.h @@ -15,28 +15,27 @@ namespace arm_common { class SeparableFilterImpl : public SeparableFilterForward { public: using SeparableFilterForward::SeparableFilterForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, 
- _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&, - const TensorLayout&) override { + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { return 0; } private: template - void separable_filter_exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst); - void separable_filter_exec_8u(_megdnn_tensor_in src, - _megdnn_tensor_in filter_x, - _megdnn_tensor_in filter_y, - _megdnn_tensor_out dst); + void separable_filter_exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst); + void separable_filter_exec_8u( + _megdnn_tensor_in src, _megdnn_tensor_in filter_x, + _megdnn_tensor_in filter_y, _megdnn_tensor_out dst); }; -} // namespace arm_common -} // namespace megdnn +} // namespace arm_common +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/simd_macro/marm_neon.cpp b/dnn/src/arm_common/simd_macro/marm_neon.cpp index 694b2e2b..5fedb82b 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.cpp +++ b/dnn/src/arm_common/simd_macro/marm_neon.cpp @@ -13,4 +13,3 @@ #pragma message \ "remove these functions defined in march_neon.h when these functions defined in the future compiler(arm_neon.h)" - diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index 17d544ee..00174984 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -27,9 +27,8 @@ #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wpragmas" #pragma GCC diagnostic ignored "-Wattributes" -#define __ai \ - static inline \ - __attribute__((__gnu_inline__, __always_inline__, __nodebug__)) +#define __ai \ + static inline __attribute__((__gnu_inline__, __always_inline__, __nodebug__)) #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && !MEGDNN_DISABLE_FLOAT16 #define MEGDNN_INC_ARM_FP16(_x) _x @@ -41,15 +40,13 @@ //! 
copy from arm_neon, as in clang7.0 these function not exists #ifdef __LITTLE_ENDIAN__ -__ai float16x8_t vmlaq_f16(float16x8_t __p0, float16x8_t __p1, - float16x8_t __p2) { +__ai float16x8_t vmlaq_f16(float16x8_t __p0, float16x8_t __p1, float16x8_t __p2) { float16x8_t __ret; __ret = __p0 + __p1 * __p2; return __ret; } #else -__ai float16x8_t vmlaq_f16(float16x8_t __p0, float16x8_t __p1, - float16x8_t __p2) { +__ai float16x8_t vmlaq_f16(float16x8_t __p0, float16x8_t __p1, float16x8_t __p2) { float16x8_t __rev0; __rev0 = __builtin_shufflevector(__p0, __p0, 7, 6, 5, 4, 3, 2, 1, 0); float16x8_t __rev1; @@ -64,35 +61,35 @@ __ai float16x8_t vmlaq_f16(float16x8_t __p0, float16x8_t __p1, #endif #ifdef __LITTLE_ENDIAN__ -#define vmlaq_lane_f16(__p0, __p1, __p2, __p3) \ - __extension__({ \ - float16x8_t __s0 = __p0; \ - float16x8_t __s1 = __p1; \ - float16x4_t __s2 = __p2; \ - float16x8_t __ret; \ - __ret = __s0 + __s1 * __builtin_shufflevector(__s2, __s2, __p3, __p3, \ - __p3, __p3, __p3, __p3, \ - __p3, __p3); \ - __ret; \ +#define vmlaq_lane_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x4_t __s2 = __p2; \ + float16x8_t __ret; \ + __ret = __s0 + __s1 * __builtin_shufflevector( \ + __s2, __s2, __p3, __p3, __p3, __p3, __p3, __p3, \ + __p3, __p3); \ + __ret; \ }) #else -#define vmlaq_lane_f16(__p0, __p1, __p2, __p3) \ - __extension__({ \ - float16x8_t __s0 = __p0; \ - float16x8_t __s1 = __p1; \ - float16x4_t __s2 = __p2; \ - float16x8_t __rev0; \ - __rev0 = __builtin_shufflevector(__s0, __s0, 7, 6, 5, 4, 3, 2, 1, 0); \ - float16x8_t __rev1; \ - __rev1 = __builtin_shufflevector(__s1, __s1, 7, 6, 5, 4, 3, 2, 1, 0); \ - float16x4_t __rev2; \ - __rev2 = __builtin_shufflevector(__s2, __s2, 3, 2, 1, 0); \ - float16x8_t __ret; \ - __ret = __rev0 + __rev1 * __builtin_shufflevector( \ - __rev2, __rev2, __p3, __p3, __p3, \ - __p3, __p3, __p3, __p3, __p3); \ - __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); \ - __ret; \ +#define vmlaq_lane_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x4_t __s2 = __p2; \ + float16x8_t __rev0; \ + __rev0 = __builtin_shufflevector(__s0, __s0, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __rev1; \ + __rev1 = __builtin_shufflevector(__s1, __s1, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x4_t __rev2; \ + __rev2 = __builtin_shufflevector(__s2, __s2, 3, 2, 1, 0); \ + float16x8_t __ret; \ + __ret = __rev0 + __rev1 * __builtin_shufflevector( \ + __rev2, __rev2, __p3, __p3, __p3, __p3, \ + __p3, __p3, __p3, __p3); \ + __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); \ + __ret; \ }) #endif @@ -119,35 +116,35 @@ __ai float16x8_t vdupq_n_f16(float16_t __p0) { #endif #ifdef __LITTLE_ENDIAN__ -#define vmlaq_laneq_f16(__p0, __p1, __p2, __p3) \ - __extension__({ \ - float16x8_t __s0 = __p0; \ - float16x8_t __s1 = __p1; \ - float16x8_t __s2 = __p2; \ - float16x8_t __ret; \ - __ret = __s0 + __s1 * __builtin_shufflevector(__s2, __s2, __p3, __p3, \ - __p3, __p3, __p3, __p3, \ - __p3, __p3); \ - __ret; \ +#define vmlaq_laneq_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x8_t __s2 = __p2; \ + float16x8_t __ret; \ + __ret = __s0 + __s1 * __builtin_shufflevector( \ + __s2, __s2, __p3, __p3, __p3, __p3, __p3, __p3, \ + __p3, __p3); \ + __ret; \ }) #else -#define vmlaq_laneq_f16(__p0, __p1, __p2, __p3) \ - __extension__({ \ - float16x8_t __s0 = __p0; \ - 
float16x8_t __s1 = __p1; \ - float16x8_t __s2 = __p2; \ - float16x8_t __rev0; \ - __rev0 = __builtin_shufflevector(__s0, __s0, 7, 6, 5, 4, 3, 2, 1, 0); \ - float16x8_t __rev1; \ - __rev1 = __builtin_shufflevector(__s1, __s1, 7, 6, 5, 4, 3, 2, 1, 0); \ - float16x8_t __rev2; \ - __rev2 = __builtin_shufflevector(__s2, __s2, 7, 6, 5, 4, 3, 2, 1, 0); \ - float16x8_t __ret; \ - __ret = __rev0 + __rev1 * __builtin_shufflevector( \ - __rev2, __rev2, __p3, __p3, __p3, \ - __p3, __p3, __p3, __p3, __p3); \ - __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); \ - __ret; \ +#define vmlaq_laneq_f16(__p0, __p1, __p2, __p3) \ + __extension__({ \ + float16x8_t __s0 = __p0; \ + float16x8_t __s1 = __p1; \ + float16x8_t __s2 = __p2; \ + float16x8_t __rev0; \ + __rev0 = __builtin_shufflevector(__s0, __s0, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __rev1; \ + __rev1 = __builtin_shufflevector(__s1, __s1, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __rev2; \ + __rev2 = __builtin_shufflevector(__s2, __s2, 7, 6, 5, 4, 3, 2, 1, 0); \ + float16x8_t __ret; \ + __ret = __rev0 + __rev1 * __builtin_shufflevector( \ + __rev2, __rev2, __p3, __p3, __p3, __p3, \ + __p3, __p3, __p3, __p3); \ + __ret = __builtin_shufflevector(__ret, __ret, 7, 6, 5, 4, 3, 2, 1, 0); \ + __ret; \ }) #endif @@ -219,8 +216,7 @@ __ai float16x8_t vdupq_n_f16(__fp16 a) { /////////////////////////////////////////////////////////////////////// #elif MEGDNN_AARCH64 -#define vmlaq_low_lane_f16(__a, __b, __v, __lane) \ - vmlaq_laneq_f16(__a, __b, __v, __lane) +#define vmlaq_low_lane_f16(__a, __b, __v, __lane) vmlaq_laneq_f16(__a, __b, __v, __lane) #define vmlaq_high_lane_f16(__a, __b, __v, __lane) \ vmlaq_laneq_f16(__a, __b, __v, __lane) @@ -347,12 +343,7 @@ __ai uint8x16_t vtranslq_u8(uint8x8_t a) { #ifdef MEGDNN_TEGRA_X1 #define vset_lane_s16_fix_tx1(__elem, __vec, __index) \ - { \ - asm volatile("ins %0.h[" #__index "], %w1\n" \ - : "+w"(__vec) \ - : "r"(__elem) \ - :); \ - } + { asm volatile("ins %0.h[" #__index "], %w1\n" : "+w"(__vec) : "r"(__elem) :); } #else #define vset_lane_s16_fix_tx1(__elem, __vec, __index) \ __vec = vset_lane_s16(__elem, __vec, __index) @@ -362,8 +353,9 @@ __ai uint8x16_t vtranslq_u8(uint8x8_t a) { __ai int32_t vaddlvq_s16(int16x8_t __p0) { int32_t __ret = 0; auto sum = vpaddlq_s16(__p0); - __ret += (vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + - vgetq_lane_s32(sum, 2) + vgetq_lane_s32(sum, 3)); + __ret += + (vgetq_lane_s32(sum, 0) + vgetq_lane_s32(sum, 1) + vgetq_lane_s32(sum, 2) + + vgetq_lane_s32(sum, 3)); return __ret; } @@ -417,13 +409,13 @@ __ai int32_t vaddv_s32(int32x2_t a) { } __ai int32_t vaddvq_s32(int32x4_t a) { - return vgetq_lane_s32(a, 0) + vgetq_lane_s32(a, 1) + - vgetq_lane_s32(a, 2) + vgetq_lane_s32(a, 3); + return vgetq_lane_s32(a, 0) + vgetq_lane_s32(a, 1) + vgetq_lane_s32(a, 2) + + vgetq_lane_s32(a, 3); } __ai float32_t vaddvq_f32(float32x4_t a) { - return vgetq_lane_f32(a, 0) + vgetq_lane_f32(a, 1) + - vgetq_lane_f32(a, 2) + vgetq_lane_f32(a, 3); + return vgetq_lane_f32(a, 0) + vgetq_lane_f32(a, 1) + vgetq_lane_f32(a, 2) + + vgetq_lane_f32(a, 3); } #endif // MEGDNN_ARMV7 @@ -470,15 +462,11 @@ __ai uint64x2_t vmovl_low_u32(uint32x4_t __p0) { #elif MEGDNN_AARCH64 __ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) { - asm volatile("bit %0.16b, %1.16b, %2.16b\n" - : "+w"(dst) - : "w"(v1), "w"(mask) - :); + asm volatile("bit %0.16b, %1.16b, %2.16b\n" : "+w"(dst) : "w"(v1), "w"(mask) :); return dst; } -#define vmlaq_low_lane_f32(__a, __b, __v, __lane) \ - 
vmlaq_laneq_f32(__a, __b, __v, __lane) +#define vmlaq_low_lane_f32(__a, __b, __v, __lane) vmlaq_laneq_f32(__a, __b, __v, __lane) #define vmlaq_high_lane_f32(__a, __b, __v, __lane) \ vmlaq_laneq_f32(__a, __b, __v, __lane) @@ -489,10 +477,9 @@ __ai float64x2_t vbitq_f64(float64x2_t dst, float64x2_t v1, uint64x2_t mask) { __ai int8x16_t vqtbl1q_s8(int8x16_t& a, uint8x16_t& idx) { int8x8_t src_low = vget_low_s8(a); int8x8_t src_high = vget_high_s8(a); - return vcombine_s8(vtbl2_s8({src_low, src_high}, - vget_low_s8(vreinterpretq_s8_u8(idx))), - vtbl2_s8({src_low, src_high}, - vget_high_s8(vreinterpretq_s8_u8(idx)))); + return vcombine_s8( + vtbl2_s8({src_low, src_high}, vget_low_s8(vreinterpretq_s8_u8(idx))), + vtbl2_s8({src_low, src_high}, vget_high_s8(vreinterpretq_s8_u8(idx)))); } namespace { template @@ -578,11 +565,9 @@ struct Vfmsq_laneq_f32_armv7<3> { } }; } // namespace -#define vfmaq_laneq_f32(a, b, v, lane) \ - Vfmaq_laneq_f32_armv7::impl(a, b, v) +#define vfmaq_laneq_f32(a, b, v, lane) Vfmaq_laneq_f32_armv7::impl(a, b, v) -#define vfmsq_laneq_f32(a, b, v, lane) \ - Vfmsq_laneq_f32_armv7::impl(a, b, v) +#define vfmsq_laneq_f32(a, b, v, lane) Vfmsq_laneq_f32_armv7::impl(a, b, v) #if MGB_ENABLE_DOT namespace { @@ -618,8 +603,7 @@ struct Vdotq_laneq_s32_armv7<3> { return vdotq_lane_s32(a, b, vget_high_f32(v), 1); } }; -#define vdotq_laneq_s32(a, b, v, lane) \ - Vdotq_laneq_s32_armv7::impl(a, b, v) +#define vdotq_laneq_s32(a, b, v, lane) Vdotq_laneq_s32_armv7::impl(a, b, v) } // namespace #endif @@ -638,40 +622,28 @@ struct Vfmaq_laneq_f32_armv8 { template <> struct Vfmaq_laneq_f32_armv8<0> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmla %0.4s, %1.4s, %2.s[0]\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; template <> struct Vfmaq_laneq_f32_armv8<1> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmla %0.4s, %1.4s, %2.s[1]\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; template <> struct Vfmaq_laneq_f32_armv8<2> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmla %0.4s, %1.4s, %2.s[2]\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; template <> struct Vfmaq_laneq_f32_armv8<3> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmla %0.4s, %1.4s, %2.s[3]\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmla %0.4s, %1.4s, %2.s[3]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; @@ -683,51 +655,37 @@ struct Vfmsq_laneq_f32_armv8 { template <> struct Vfmsq_laneq_f32_armv8<0> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmls %0.4s, %1.4s, %2.s[0]\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmls %0.4s, %1.4s, %2.s[0]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; template <> struct Vfmsq_laneq_f32_armv8<1> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmls %0.4s, %1.4s, %2.s[1]\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmls %0.4s, %1.4s, %2.s[1]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; template <> struct Vfmsq_laneq_f32_armv8<2> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmls %0.4s, %1.4s, %2.s[2]\n" - : "+w"(a) - : "w"(b), 
"w"(v) - :); + asm volatile("fmls %0.4s, %1.4s, %2.s[2]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; template <> struct Vfmsq_laneq_f32_armv8<3> { __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { - asm volatile("fmls %0.4s, %1.4s, %2.s[3]\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmls %0.4s, %1.4s, %2.s[3]\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } }; } // namespace #undef vfmaq_laneq_f32 -#define vfmaq_laneq_f32(a, b, v, lane) \ - Vfmaq_laneq_f32_armv8::impl(a, b, v) +#define vfmaq_laneq_f32(a, b, v, lane) Vfmaq_laneq_f32_armv8::impl(a, b, v) #undef vfmsq_laneq_f32 -#define vfmsq_laneq_f32(a, b, v, lane) \ - Vfmsq_laneq_f32_armv8::impl(a, b, v) +#define vfmsq_laneq_f32(a, b, v, lane) Vfmsq_laneq_f32_armv8::impl(a, b, v) #endif __ai int8x16_t vld_dup_tbl_s32(const int8_t* ptr, uint8x16_t& idx) { @@ -740,15 +698,13 @@ __ai int8x16_t vldq_tbl_s8(const int8_t* ptr, uint8x16_t& idx) { result = vqtbl1q_s8(result, idx); return result; } -__ai int32x4_t vdotq_s32_h(int8x16_t& a, int8x16_t& b, int32x4_t& c, - int16x8_t& temp) { +__ai int32x4_t vdotq_s32_h(int8x16_t& a, int8x16_t& b, int32x4_t& c, int16x8_t& temp) { temp = vmull_s8(vget_low_s8(a), vget_low_s8(b)); temp = vmlal_high_s8(temp, a, b); c = vpadalq_s16(c, temp); return c; } -__ai int32x4_t vdot2_s32_h(int8x8_t& a, int8x8_t& b, int32x4_t& c, - int16x8_t& temp) { +__ai int32x4_t vdot2_s32_h(int8x8_t& a, int8x8_t& b, int32x4_t& c, int16x8_t& temp) { temp = vmull_s8(a, b); c = vpadalq_s16(c, temp); return c; @@ -759,8 +715,8 @@ __ai int32x4_t vmlal_s16(int32x4_t& a, int16x8_t& b, int16x8_t& c) { } __ai int16x8_t vldq_dup_4s8_8s16(const int8_t* ptr) { - return vmovl_s8(vreinterpret_s8_s32( - vld1_dup_s32(reinterpret_cast(ptr)))); + return vmovl_s8( + vreinterpret_s8_s32(vld1_dup_s32(reinterpret_cast(ptr)))); } __ai int8x8_t vldq_tbl_low_s8(const int8_t* ptr, uint8x16_t idx) { return vget_low_s8(vldq_tbl_s8(ptr, idx)); @@ -772,10 +728,7 @@ __ai int16x8_t vld1_dup_s8_s16(const int8_t* ptr) { //! we add this because we found that cpu=aarch64_android cann't compile fmsq into fmls. //! 
it use dup+fmla instead __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { - asm volatile("fmls %0.4s, %1.4s, %2.4s\n" - : "+w"(a) - : "w"(b), "w"(v) - :); + asm volatile("fmls %0.4s, %1.4s, %2.4s\n" : "+w"(a) : "w"(b), "w"(v) :); return a; } #if MGB_ENABLE_DOT diff --git a/dnn/src/arm_common/simd_macro/neon_helper.h b/dnn/src/arm_common/simd_macro/neon_helper.h index 19d1aeca..759198a5 100644 --- a/dnn/src/arm_common/simd_macro/neon_helper.h +++ b/dnn/src/arm_common/simd_macro/neon_helper.h @@ -10,48 +10,53 @@ */ #include "src/arm_common/simd_macro/marm_neon.h" -#define MEGDNN_SIMD_NAME NEON +#define MEGDNN_SIMD_NAME NEON #define MEGDNN_SIMD_TARGET neon #define MEGDNN_SIMD_ATTRIBUTE_TARGET #define MEGDNN_SIMD_LAMBDA_ATTRIBUTE_TARGET -#define MEGDNN_SIMD_WIDTH 4 -#define MEGDNN_SIMD_TYPE float32x4_t -#define MEGDNN_SIMD_TYPE2 float32x4x2_t -#define MEGDNN_SIMD_LOADU(addr) vld1q_f32(addr) +#define MEGDNN_SIMD_WIDTH 4 +#define MEGDNN_SIMD_TYPE float32x4_t +#define MEGDNN_SIMD_TYPE2 float32x4x2_t +#define MEGDNN_SIMD_LOADU(addr) vld1q_f32(addr) #define MEGDNN_SIMD_LOADU_2(addr) vcombine_f32(vld1_f32(addr), vdup_n_f32(0.f)) -#define MEGDNN_SIMD_LOADU_3(addr) vld1q_lane_f32(addr + 2, vcombine_f32(vld1_f32(addr), vdup_n_f32(0.f)), 2) +#define MEGDNN_SIMD_LOADU_3(addr) \ + vld1q_lane_f32(addr + 2, vcombine_f32(vld1_f32(addr), vdup_n_f32(0.f)), 2) #define MEGDNN_SIMD_STOREU(addr, reg) vst1q_f32(addr, reg) -#define MEGDNN_SIMD_SETZERO() vdupq_n_f32(0.0f) -#define MEGDNN_SIMD_SET1(num) vdupq_n_f32(num) +#define MEGDNN_SIMD_SETZERO() vdupq_n_f32(0.0f) +#define MEGDNN_SIMD_SET1(num) vdupq_n_f32(num) // XXX The order of a, b, c #define MEGDNN_SIMD_FMADD(a, b, c) vmlaq_f32(c, a, b) -#define MEGDNN_SIMD_MAX(a, b) vmaxq_f32(a, b) -#define MEGDNN_SIMD_UZP(s0, s1, d0, d1) do { \ - auto tmp__ = vuzpq_f32(s0, s1); \ - d0 = tmp__.val[0]; \ - d1 = tmp__.val[1]; \ -} while (0) -#define MEGDNN_SIMD_LOAD2(addr) vld2q_f32(addr) -#define MEGDNN_SIMD_EXT(a, b, c) vextq_f32(a, b, c) -#define MEGDNN_SIMD_MUL(a, b) vmulq_f32(a, b) -#define MEGDNN_SIMD_ADD(a, b) vaddq_f32(a, b) -#define MEGDNN_SIMD_SET_LANE(a, b, c) vsetq_lane_f32(a, b, c) -#define MEGDNN_SIMD_GET_LOW(a) vget_low_f32(a) -#define MEGDNN_SIMD_GET_HIGH(a) vget_high_f32(a) +#define MEGDNN_SIMD_MAX(a, b) vmaxq_f32(a, b) +#define MEGDNN_SIMD_UZP(s0, s1, d0, d1) \ + do { \ + auto tmp__ = vuzpq_f32(s0, s1); \ + d0 = tmp__.val[0]; \ + d1 = tmp__.val[1]; \ + } while (0) +#define MEGDNN_SIMD_LOAD2(addr) vld2q_f32(addr) +#define MEGDNN_SIMD_EXT(a, b, c) vextq_f32(a, b, c) +#define MEGDNN_SIMD_MUL(a, b) vmulq_f32(a, b) +#define MEGDNN_SIMD_ADD(a, b) vaddq_f32(a, b) +#define MEGDNN_SIMD_SET_LANE(a, b, c) vsetq_lane_f32(a, b, c) +#define MEGDNN_SIMD_GET_LOW(a) vget_low_f32(a) +#define MEGDNN_SIMD_GET_HIGH(a) vget_high_f32(a) #define MEGDNN_SIMD_VMLAQ_LANE(a, b, c, d) vmlaq_lane_f32(a, b, c, d) #if MEGDNN_ARMV7 -#define MEGDNN_SIMD_FMA_LANE(a, b, c, d) ({ \ - auto ret__ = vdupq_n_f32(vgetq_lane_f32(c, d)); \ - ret__ = vmlaq_f32(a, b, ret__); \ - ret__;}) -#define MEGDNN_SIMD_ADD_VEC(a) ({ \ - auto tmp__ = vadd_f32(vget_low_f32(a), vget_high_f32(a)); \ - tmp__ = vpadd_f32(tmp__, tmp__); \ - auto ret__ = vget_lane_f32(tmp__, 0); \ - ret__;}) +#define MEGDNN_SIMD_FMA_LANE(a, b, c, d) \ + ({ \ + auto ret__ = vdupq_n_f32(vgetq_lane_f32(c, d)); \ + ret__ = vmlaq_f32(a, b, ret__); \ + ret__; \ + }) +#define MEGDNN_SIMD_ADD_VEC(a) \ + ({ \ + auto tmp__ = vadd_f32(vget_low_f32(a), vget_high_f32(a)); \ + tmp__ = vpadd_f32(tmp__, tmp__); \ + auto ret__ = 
vget_lane_f32(tmp__, 0); \ + ret__; \ + }) #else // MEGDNN_AARCH64 #define MEGDNN_SIMD_FMA_LANE(a, b, c, d) vfmaq_laneq_f32(a, b, c, d) -#define MEGDNN_SIMD_ADD_VEC(a) vaddvq_f32(a) +#define MEGDNN_SIMD_ADD_VEC(a) vaddvq_f32(a) #endif - diff --git a/dnn/src/arm_common/simd_macro/neon_helper_fp16.h b/dnn/src/arm_common/simd_macro/neon_helper_fp16.h index abaf81e4..6be60bdd 100644 --- a/dnn/src/arm_common/simd_macro/neon_helper_fp16.h +++ b/dnn/src/arm_common/simd_macro/neon_helper_fp16.h @@ -11,15 +11,15 @@ #include "src/arm_common/simd_macro/marm_neon.h" #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#define MEGDNN_SIMD_NAME NEON +#define MEGDNN_SIMD_NAME NEON #define MEGDNN_SIMD_TARGET neon #define MEGDNN_SIMD_ATTRIBUTE_TARGET -#define MEGDNN_SIMD_WIDTH 4 -#define MEGDNN_SIMD_TYPE float16x8_t -#define MEGDNN_SIMD_TYPE2 float16x8x2_t -#define MEGDNN_SIMD_LOADU(addr) vld1q_f16(addr) +#define MEGDNN_SIMD_WIDTH 4 +#define MEGDNN_SIMD_TYPE float16x8_t +#define MEGDNN_SIMD_TYPE2 float16x8x2_t +#define MEGDNN_SIMD_LOADU(addr) vld1q_f16(addr) #define MEGDNN_SIMD_STOREU(addr, reg) vst1q_f16(addr, reg) -#define MEGDNN_SIMD_SETZERO() vdupq_n_f16(0.0f) -#define MEGDNN_SIMD_SET1(num) vdupq_n_f16(num) +#define MEGDNN_SIMD_SETZERO() vdupq_n_f16(0.0f) +#define MEGDNN_SIMD_SET1(num) vdupq_n_f16(num) #endif diff --git a/dnn/src/arm_common/type_cvt/opr_impl.cpp b/dnn/src/arm_common/type_cvt/opr_impl.cpp index 84c0c7e1..900e16c2 100644 --- a/dnn/src/arm_common/type_cvt/opr_impl.cpp +++ b/dnn/src/arm_common/type_cvt/opr_impl.cpp @@ -45,11 +45,9 @@ struct QuantizedTypeCvter { void cvt(const int32_t* src, int8_t* dst) { float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(src)), vscale); - float32x4_t vitem1 = - vmulq_f32(vcvtq_f32_s32(vld1q_s32(src + 4)), vscale); + float32x4_t vitem1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(src + 4)), vscale); - auto vres = QConverter::convert( - {{vitem0, vitem1}}); + auto vres = QConverter::convert({{vitem0, vitem1}}); vst1_s8(dst, vres); } @@ -75,17 +73,17 @@ struct QuantizedTypeCvter { void cvt(const int8_t* src, int32_t* dst) { int16x8_t vitem = vmovl_s8(vld1_s8(src)); - auto vret0 = QConverter::convert(vmulq_f32( - vcvtq_f32_s32(vmovl_s16(vget_low_s16(vitem))), vscale)); - auto vret1 = QConverter::convert(vmulq_f32( - vcvtq_f32_s32(vmovl_s16(vget_high_s16(vitem))), vscale)); + auto vret0 = QConverter::convert( + vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vitem))), vscale)); + auto vret1 = QConverter::convert( + vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(vitem))), vscale)); vst1q_s32(dst, vret0); vst1q_s32(dst + 4, vret1); } void cvt_remain(const int8_t* src, int32_t* dst) { - *dst = saturate(std::round(*src * scale), -2147483648, - 2147483647); + *dst = saturate( + std::round(*src * scale), -2147483648, 2147483647); } }; @@ -109,8 +107,7 @@ struct QuantizedTypeCvter { float32x4_t vitem0 = vmulq_f32(vld1q_f32(src), vscale); float32x4_t vitem1 = vmulq_f32(vld1q_f32(src + 4), vscale); - auto vres = QConverter::convert( - {{vitem0, vitem1}}); + auto vres = QConverter::convert({{vitem0, vitem1}}); vst1_s8(dst, vres); } @@ -142,8 +139,8 @@ struct QuantizedTypeCvter { } void cvt_remain(const int32_t* src, int32_t* dst) { - *dst = saturate(std::round(*src * scale), -2147483648, - 2147483647); + *dst = saturate( + std::round(*src * scale), -2147483648, 2147483647); } }; @@ -164,13 +161,12 @@ struct QuantizedTypeCvter { void cvt(const int8_t* src, int8_t* dst) { int16x8_t vdata = vmovl_s8(vld1_s8(src)); - float32x4_t vitem0 = vmulq_f32( - 
vcvtq_f32_s32(vmovl_s16(vget_low_s16(vdata))), vscale); - float32x4_t vitem1 = vmulq_f32( - vcvtq_f32_s32(vmovl_s16(vget_high_s16(vdata))), vscale); + float32x4_t vitem0 = + vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vdata))), vscale); + float32x4_t vitem1 = + vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(vdata))), vscale); - auto vres = QConverter::convert( - {{vitem0, vitem1}}); + auto vres = QConverter::convert({{vitem0, vitem1}}); vst1_s8(dst, vres); } @@ -236,8 +232,7 @@ struct QuantizedTypeCvter { void cvt(const int32_t* src, uint8_t* dst) { float32x4_t vitem0 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(src)), vscale); - float32x4_t vitem1 = - vmulq_f32(vcvtq_f32_s32(vld1q_s32(src + 4)), vscale); + float32x4_t vitem1 = vmulq_f32(vcvtq_f32_s32(vld1q_s32(src + 4)), vscale); auto vres = QConverter::convert( {{vitem0, vitem1}}, this->vzp); vst1_u8(dst, vres); @@ -274,10 +269,10 @@ struct QuantizedTypeCvter { void cvt(const uint8_t* src, uint8_t* dst) { int16x8_t vdata = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(src))); vdata = vsubq_s16(vdata, vzp_src); - float32x4_t vitem0 = vmulq_f32( - vcvtq_f32_s32(vmovl_s16(vget_low_s16(vdata))), vscale); - float32x4_t vitem1 = vmulq_f32( - vcvtq_f32_s32(vmovl_s16(vget_high_s16(vdata))), vscale); + float32x4_t vitem0 = + vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vdata))), vscale); + float32x4_t vitem1 = + vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_high_s16(vdata))), vscale); auto vres = QConverter::convert( {{vitem0, vitem1}}, this->vzp_dst); @@ -322,8 +317,7 @@ struct FloatTypeCvter { void cvt(const float* src, __fp16* dst) { float32x4_t vdata0 = vld1q_f32(src); float32x4_t vdata1 = vld1q_f32(src + 4); - float16x8_t vdata = - vcombine_f16(vcvt_f16_f32(vdata0), vcvt_f16_f32(vdata1)); + float16x8_t vdata = vcombine_f16(vcvt_f16_f32(vdata0), vcvt_f16_f32(vdata1)); vst1q_f16(dst, vdata); } @@ -332,9 +326,9 @@ struct FloatTypeCvter { #endif template -void do_typecvt(const typename TypeCvter::stype* src, - typename TypeCvter::dst_type* dst, DType src_dtype, - DType dst_dtype, size_t nr_elems) { +void do_typecvt( + const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst, + DType src_dtype, DType dst_dtype, size_t nr_elems) { TypeCvter typecvt(src_dtype, dst_dtype); size_t i = 0; for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) { @@ -362,44 +356,41 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { bool execed = false; if (src.layout.is_contiguous()) { using namespace dtype; -#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, \ - _midout_iv) \ - if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ - dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ - MIDOUT_BEGIN(megdnn_arm_typecvt_quantized, midout_iv(_midout_iv)) { \ - using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \ - MEGDNN_DISPATCH_CPU_KERN_OPR( \ - do_typecvt<_TypeCvter>(src.compatible_ptr<_stype>(), \ - dst.compatible_ptr<_dtype>(), \ - src_dtype, dst_dtype, nr_elems)); \ - execed = true; \ - } \ - MIDOUT_END(); \ +#define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ + if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ + dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ + MIDOUT_BEGIN(megdnn_arm_typecvt_quantized, midout_iv(_midout_iv)) { \ + using _TypeCvter = QuantizedTypeCvter<_stype, _dtype>; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ + src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(), \ + src_dtype, dst_dtype, 
nr_elems)); \ + execed = true; \ + } \ + MIDOUT_END(); \ } DISPATCH_QUANTIZED(QuantizedS32, int32_t, Quantized8Asymm, uint8_t, 0); DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS8, int8_t, 1); DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS32, int32_t, 2); DISPATCH_QUANTIZED(QuantizedS8, int8_t, QuantizedS8, int8_t, 3); - DISPATCH_QUANTIZED(Quantized8Asymm, uint8_t, Quantized8Asymm, uint8_t, - 4); + DISPATCH_QUANTIZED(Quantized8Asymm, uint8_t, Quantized8Asymm, uint8_t, 4); DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5); DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6); DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ - if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ - dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ - MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \ - using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \ - MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ - reinterpret_cast<_stype*>(src.raw_ptr), \ - reinterpret_cast<_dtype*>(dst.raw_ptr), src_dtype, \ - dst_dtype, nr_elems)); \ - execed = true; \ - } \ - MIDOUT_END(); \ +#define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ + if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ + dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ + MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \ + using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \ + MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ + reinterpret_cast<_stype*>(src.raw_ptr), \ + reinterpret_cast<_dtype*>(dst.raw_ptr), src_dtype, dst_dtype, \ + nr_elems)); \ + execed = true; \ + } \ + MIDOUT_END(); \ } DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); diff --git a/dnn/src/arm_common/utils.cpp b/dnn/src/arm_common/utils.cpp index 18922e84..a629847d 100644 --- a/dnn/src/arm_common/utils.cpp +++ b/dnn/src/arm_common/utils.cpp @@ -18,79 +18,86 @@ using namespace megdnn; namespace { template -void transpose_naive(const dtype *src, dtype *dst, - int lda, int ldb, int n, int m) -{ - rep(i, n) rep(j, m) { - dst[i*ldb + j] = src[j*lda + i]; - } +void transpose_naive(const dtype* src, dtype* dst, int lda, int ldb, int n, int m) { + rep(i, n) rep(j, m) { dst[i * ldb + j] = src[j * lda + i]; } } -void transpose_4x4_neon(const float *src, float *dst, int lda, int ldb) -{ +void transpose_4x4_neon(const float* src, float* dst, int lda, int ldb) { float32x4x2_t a0, a1; - a0.val[0] = vld1q_f32(src + 0*lda); - a0.val[1] = vld1q_f32(src + 1*lda); - a1.val[0] = vld1q_f32(src + 2*lda); - a1.val[1] = vld1q_f32(src + 3*lda); + a0.val[0] = vld1q_f32(src + 0 * lda); + a0.val[1] = vld1q_f32(src + 1 * lda); + a1.val[0] = vld1q_f32(src + 2 * lda); + a1.val[1] = vld1q_f32(src + 3 * lda); float32x4x2_t b0 = vzipq_f32(a0.val[0], a1.val[0]); float32x4x2_t b1 = vzipq_f32(a0.val[1], a1.val[1]); float32x4x2_t c0 = vzipq_f32(b0.val[0], b1.val[0]); float32x4x2_t c1 = vzipq_f32(b0.val[1], b1.val[1]); - vst1q_f32(dst + 0*ldb, c0.val[0]); - vst1q_f32(dst + 1*ldb, c0.val[1]); - vst1q_f32(dst + 2*ldb, c1.val[0]); - vst1q_f32(dst + 3*ldb, c1.val[1]); + vst1q_f32(dst + 0 * ldb, c0.val[0]); + vst1q_f32(dst + 1 * ldb, c0.val[1]); + vst1q_f32(dst + 2 * ldb, c1.val[0]); + vst1q_f32(dst + 3 * ldb, c1.val[1]); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -void transpose_8x8_neon(const 
dt_float16 *src, dt_float16 *dst, int lda, int ldb) -{ +void transpose_8x8_neon(const dt_float16* src, dt_float16* dst, int lda, int ldb) { const __fp16* src_ptr = reinterpret_cast(src); __fp16* dst_ptr = reinterpret_cast<__fp16*>(dst); float16x8x4_t a0, a1; - a0.val[0] = vld1q_f16(src_ptr + 0*lda); // A0A1A2A3A4A5A6A7 - a0.val[1] = vld1q_f16(src_ptr + 1*lda); // B0B1B2B3B4B5B6B7 - a0.val[2] = vld1q_f16(src_ptr + 2*lda); // C0C1C2C3C4C5C6C7 - a0.val[3] = vld1q_f16(src_ptr + 3*lda); // D0D1D2D3D4D5D6D7 - a1.val[0] = vld1q_f16(src_ptr + 4*lda); // E0E1E2E3E4E5E6E7 - a1.val[1] = vld1q_f16(src_ptr + 5*lda); // F0F1F2F3F4F5F6F7 - a1.val[2] = vld1q_f16(src_ptr + 6*lda); // G0G1G2G3G4G5G6G7 - a1.val[3] = vld1q_f16(src_ptr + 7*lda); // H0H1H2H3H4H5H6H7 - - float16x8x2_t b0 = vzipq_f16(a0.val[0], a1.val[0]); // A0E0A1E1A2E2A3E3 A4E4A5E5A6E6A7E7 - float16x8x2_t b1 = vzipq_f16(a0.val[2], a1.val[2]); // C0G0C1G1C2G2C3G3 C4G4C5G5C6G6C7G7 - float16x8x2_t c0 = vzipq_f16(a0.val[1], a1.val[1]); // B0F0B1F1B2F2B3F3 B4F4B5F5B6F6B7F7 - float16x8x2_t c1 = vzipq_f16(a0.val[3], a1.val[3]); // D0H0D1H1D2H2D3H3 D4H4D5H5D6H6D7H7 - - float16x8x2_t d0 = vzipq_f16(b0.val[0], b1.val[0]); // A0C0E0G0A1C1E1G1 A2C2E2G2A3C3E3G3 - float16x8x2_t d1 = vzipq_f16(c0.val[0], c1.val[0]); // B0D0F0H0B1D1F1H1 B2D2F2H2B3D3F3H3 - float16x8x2_t e0 = vzipq_f16(d0.val[0], d1.val[0]); // A0B0C0D0E0F0G0H0 A1B1C1D1E1F1G1H1 - float16x8x2_t e1 = vzipq_f16(d0.val[1], d1.val[1]); // A2B2C2D2E2F2G2H2 A3B3C3D3E3F3G3H3 - - float16x8x2_t f0 = vzipq_f16(b0.val[1], b1.val[1]); // A4C4E4G4A5C5E5G5 A6C6E6G6A7C7E7G7 - float16x8x2_t f1 = vzipq_f16(c0.val[1], c1.val[1]); // B4D4F4H4B5D5F5H5 B6D6E6G6B7D7E7H7 - float16x8x2_t g0 = vzipq_f16(f0.val[0], f1.val[0]); // A4B4C4D4E4F4G4H4 A5B5C5D5E5F5G5H5 - float16x8x2_t g1 = vzipq_f16(f0.val[1], f1.val[1]); // A6B6C6D6E6F6G6H6 A7B7C7D7E7F7G7H7 - - vst1q_f16(dst_ptr + 0*ldb, e0.val[0]); - vst1q_f16(dst_ptr + 1*ldb, e0.val[1]); - vst1q_f16(dst_ptr + 2*ldb, e1.val[0]); - vst1q_f16(dst_ptr + 3*ldb, e1.val[1]); - vst1q_f16(dst_ptr + 4*ldb, g0.val[0]); - vst1q_f16(dst_ptr + 5*ldb, g0.val[1]); - vst1q_f16(dst_ptr + 6*ldb, g1.val[0]); - vst1q_f16(dst_ptr + 7*ldb, g1.val[1]); + a0.val[0] = vld1q_f16(src_ptr + 0 * lda); // A0A1A2A3A4A5A6A7 + a0.val[1] = vld1q_f16(src_ptr + 1 * lda); // B0B1B2B3B4B5B6B7 + a0.val[2] = vld1q_f16(src_ptr + 2 * lda); // C0C1C2C3C4C5C6C7 + a0.val[3] = vld1q_f16(src_ptr + 3 * lda); // D0D1D2D3D4D5D6D7 + a1.val[0] = vld1q_f16(src_ptr + 4 * lda); // E0E1E2E3E4E5E6E7 + a1.val[1] = vld1q_f16(src_ptr + 5 * lda); // F0F1F2F3F4F5F6F7 + a1.val[2] = vld1q_f16(src_ptr + 6 * lda); // G0G1G2G3G4G5G6G7 + a1.val[3] = vld1q_f16(src_ptr + 7 * lda); // H0H1H2H3H4H5H6H7 + + float16x8x2_t b0 = + vzipq_f16(a0.val[0], a1.val[0]); // A0E0A1E1A2E2A3E3 A4E4A5E5A6E6A7E7 + float16x8x2_t b1 = + vzipq_f16(a0.val[2], a1.val[2]); // C0G0C1G1C2G2C3G3 C4G4C5G5C6G6C7G7 + float16x8x2_t c0 = + vzipq_f16(a0.val[1], a1.val[1]); // B0F0B1F1B2F2B3F3 B4F4B5F5B6F6B7F7 + float16x8x2_t c1 = + vzipq_f16(a0.val[3], a1.val[3]); // D0H0D1H1D2H2D3H3 D4H4D5H5D6H6D7H7 + + float16x8x2_t d0 = + vzipq_f16(b0.val[0], b1.val[0]); // A0C0E0G0A1C1E1G1 A2C2E2G2A3C3E3G3 + float16x8x2_t d1 = + vzipq_f16(c0.val[0], c1.val[0]); // B0D0F0H0B1D1F1H1 B2D2F2H2B3D3F3H3 + float16x8x2_t e0 = + vzipq_f16(d0.val[0], d1.val[0]); // A0B0C0D0E0F0G0H0 A1B1C1D1E1F1G1H1 + float16x8x2_t e1 = + vzipq_f16(d0.val[1], d1.val[1]); // A2B2C2D2E2F2G2H2 A3B3C3D3E3F3G3H3 + + float16x8x2_t f0 = + vzipq_f16(b0.val[1], b1.val[1]); // A4C4E4G4A5C5E5G5 A6C6E6G6A7C7E7G7 + float16x8x2_t 
f1 = + vzipq_f16(c0.val[1], c1.val[1]); // B4D4F4H4B5D5F5H5 B6D6E6G6B7D7E7H7 + float16x8x2_t g0 = + vzipq_f16(f0.val[0], f1.val[0]); // A4B4C4D4E4F4G4H4 A5B5C5D5E5F5G5H5 + float16x8x2_t g1 = + vzipq_f16(f0.val[1], f1.val[1]); // A6B6C6D6E6F6G6H6 A7B7C7D7E7F7G7H7 + + vst1q_f16(dst_ptr + 0 * ldb, e0.val[0]); + vst1q_f16(dst_ptr + 1 * ldb, e0.val[1]); + vst1q_f16(dst_ptr + 2 * ldb, e1.val[0]); + vst1q_f16(dst_ptr + 3 * ldb, e1.val[1]); + vst1q_f16(dst_ptr + 4 * ldb, g0.val[0]); + vst1q_f16(dst_ptr + 5 * ldb, g0.val[1]); + vst1q_f16(dst_ptr + 6 * ldb, g1.val[0]); + vst1q_f16(dst_ptr + 7 * ldb, g1.val[1]); } #endif -} // anonymous namespace +} // anonymous namespace namespace megdnn { template <> -void transpose(const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds, - ptrdiff_t ldd) { +void transpose( + const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds, + ptrdiff_t ldd) { if (lds == -1) { lds = n; } @@ -104,46 +111,46 @@ void transpose(const float* src, float* dst, size_t m, size_t n, ptrdiff_t lds, for (; i + 4 <= ie; i += 4) { auto j = js; for (; j + 4 <= je; j += 4) { - transpose_4x4_neon(src + j * lds + i, dst + i * ldd + j, - lds, ldd); + transpose_4x4_neon(src + j * lds + i, dst + i * ldd + j, lds, ldd); } if (j < je) { - transpose_naive(src + j * lds + i, dst + i * ldd + j, lds, - ldd, 4, je - j); + transpose_naive( + src + j * lds + i, dst + i * ldd + j, lds, ldd, 4, je - j); } } if (i < ie) { - transpose_naive(src + js * lds + i, dst + i * ldd + js, lds, - ldd, ie - i, je - js); + transpose_naive( + src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i, + je - js); } } } } -template -void transpose_knc2nsck_helper(const dtype *src, dtype *dst, - size_t k, size_t n, size_t c, size_t n_stride) { +template +void transpose_knc2nsck_helper( + const dtype* src, dtype* dst, size_t k, size_t n, size_t c, size_t n_stride) { if (n_stride == k * c) { // dst is contiguous transpose(src, dst, k, n * c); } else { - for (size_t i = 0; i < n; ++ i) { - transpose(src + i * c, dst + i * n_stride, - k, c, n * c); + for (size_t i = 0; i < n; ++i) { + transpose(src + i * c, dst + i * n_stride, k, c, n * c); } } } template <> -void transpose_knc2nsck(const float *src, float *dst, - size_t k, size_t n, size_t c, size_t n_stride) { +void transpose_knc2nsck( + const float* src, float* dst, size_t k, size_t n, size_t c, size_t n_stride) { transpose_knc2nsck_helper(src, dst, k, n, c, n_stride); } #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC template <> -void transpose(const dt_float16* src, dt_float16* dst, size_t m, size_t n, - ptrdiff_t lds, ptrdiff_t ldd) { +void transpose( + const dt_float16* src, dt_float16* dst, size_t m, size_t n, ptrdiff_t lds, + ptrdiff_t ldd) { if (lds == -1) { lds = n; } @@ -157,28 +164,29 @@ void transpose(const dt_float16* src, dt_float16* dst, size_t m, size_t n, for (; i + 8 <= ie; i += 8) { auto j = js; for (; j + 8 <= je; j += 8) { - transpose_8x8_neon(src + j * lds + i, dst + i * ldd + j, - lds, ldd); + transpose_8x8_neon(src + j * lds + i, dst + i * ldd + j, lds, ldd); } if (j < je) { - transpose_naive(src + j * lds + i, dst + i * ldd + j, lds, - ldd, 8, je - j); + transpose_naive( + src + j * lds + i, dst + i * ldd + j, lds, ldd, 8, je - j); } } if (i < ie) { - transpose_naive(src + js * lds + i, dst + i * ldd + js, lds, - ldd, ie - i, je - js); + transpose_naive( + src + js * lds + i, dst + i * ldd + js, lds, ldd, ie - i, + je - js); } } } } template <> -void transpose_knc2nsck(const dt_float16* src, dt_float16* dst, size_t k, - size_t n, size_t c, 
size_t n_stride) { +void transpose_knc2nsck( + const dt_float16* src, dt_float16* dst, size_t k, size_t n, size_t c, + size_t n_stride) { transpose_knc2nsck_helper(src, dst, k, n, c, n_stride); } #endif -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/utils.h b/dnn/src/arm_common/utils.h index 577231a5..5b6f01cc 100644 --- a/dnn/src/arm_common/utils.h +++ b/dnn/src/arm_common/utils.h @@ -99,9 +99,7 @@ struct Vector<__fp16, 8> { v.value = vld1q_f16(addr); return v; } - static void save(__fp16* addr, const Vector& v) { - vst1q_f16(addr, v.value); - } + static void save(__fp16* addr, const Vector& v) { vst1q_f16(addr, v.value); } void save(__fp16* addr) { save(addr, *this); } Vector operator+(const Vector& lr) { Vector dst; @@ -233,9 +231,7 @@ struct Vector { v.value = vld1q_f32_x2(addr); return v; } - static void save(float* addr, const Vector& v) { - vst1q_f32_x2(addr, v.value); - } + static void save(float* addr, const Vector& v) { vst1q_f32_x2(addr, v.value); } void save(float* addr) { save(addr, *this); } Vector operator+(const Vector& lr) { @@ -318,9 +314,7 @@ struct Vector { v.value = vld1q_s16(addr); return v; } - static void save(int16_t* addr, const Vector& v) { - vst1q_s16(addr, v.value); - } + static void save(int16_t* addr, const Vector& v) { vst1q_s16(addr, v.value); } void save(int16_t* addr) { save(addr, *this); } Vector operator+(const Vector& lr) { Vector dst; @@ -382,9 +376,7 @@ struct Vector { v.value = vld1_s16(addr); return v; } - static void save(int16_t* addr, const Vector& v) { - vst1_s16(addr, v.value); - } + static void save(int16_t* addr, const Vector& v) { vst1_s16(addr, v.value); } void save(int16_t* addr) { save(addr, *this); } Vector operator+(const Vector& lr) { Vector dst; @@ -433,8 +425,6 @@ struct Vector { } }; - - template <> struct Vector { int32x4x2_t value; diff --git a/dnn/src/arm_common/warp_affine/opr_impl.cpp b/dnn/src/arm_common/warp_affine/opr_impl.cpp index bafa6d4b..06ab19d4 100644 --- a/dnn/src/arm_common/warp_affine/opr_impl.cpp +++ b/dnn/src/arm_common/warp_affine/opr_impl.cpp @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
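(Editorial aside on the utils.cpp transpose diff above: every path there — the NEON 4x4 float tile, the 8x8 fp16 tile, and the ragged-edge cleanup — reduces to the contract of transpose_naive, namely dst[i * ldb + j] = src[j * lda + i]. The scalar sketch below restates that contract and checks it on one contiguous 4x4 tile; the transpose_ref name and the main() harness are illustrative assumptions, not part of this patch.)

#include <cassert>
#include <vector>

template <typename T>
void transpose_ref(const T* src, T* dst, int lda, int ldb, int n, int m) {
    // Same contract as transpose_naive in dnn/src/arm_common/utils.cpp:
    // element (i, j) of dst is taken from element (j, i) of src.
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < m; ++j)
            dst[i * ldb + j] = src[j * lda + i];
}

int main() {
    // One contiguous 4x4 tile: lda == ldb == 4, n == m == 4.
    std::vector<float> src(16), dst(16);
    for (int k = 0; k < 16; ++k) src[k] = static_cast<float>(k);
    transpose_ref(src.data(), dst.data(), 4, 4, 4, 4);
    // The NEON tile kernels (two rounds of vzipq) must produce the same result.
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            assert(dst[i * 4 + j] == src[j * 4 + i]);
    return 0;
}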
*/ -#include "src/arm_common/handle.h" #include "src/arm_common/warp_affine/opr_impl.h" +#include "src/arm_common/handle.h" #include "src/arm_common/warp_affine/warp_affine_cv.h" #include "src/common/warp_common.h" @@ -21,14 +21,16 @@ using namespace megdnn; using namespace arm_common; using namespace megcv; -void WarpAffineImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, - _megdnn_tensor_out dst, _megdnn_workspace workspace) { +void WarpAffineImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(src.layout, mat.layout, dst.layout, workspace.size); - if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, - param().format)) { + if (warp::is_cv_available( + src.layout, mat.layout, dst.layout, param().imode, param().format)) { MIDOUT_BEGIN(megdnn_arm_warpaffine, void) { - warp_affine_cv_exec(src, mat, dst, param().border_val, - param().border_mode, param().imode, handle()); + warp_affine_cv_exec( + src, mat, dst, param().border_val, param().border_mode, + param().imode, handle()); } MIDOUT_END(); } else { diff --git a/dnn/src/arm_common/warp_affine/opr_impl.h b/dnn/src/arm_common/warp_affine/opr_impl.h index 4018938f..39ad8205 100644 --- a/dnn/src/arm_common/warp_affine/opr_impl.h +++ b/dnn/src/arm_common/warp_affine/opr_impl.h @@ -17,11 +17,12 @@ namespace arm_common { class WarpAffineImpl : public naive::WarpAffineImpl { public: using naive::WarpAffineImpl::WarpAffineImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, - _megdnn_tensor_in dst, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&) override { + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { return 0; } }; diff --git a/dnn/src/arm_common/warp_affine/warp_affine_cv.cpp b/dnn/src/arm_common/warp_affine/warp_affine_cv.cpp index b66767fe..07ed7783 100644 --- a/dnn/src/arm_common/warp_affine/warp_affine_cv.cpp +++ b/dnn/src/arm_common/warp_affine/warp_affine_cv.cpp @@ -85,8 +85,9 @@ namespace { constexpr size_t BLOCK_SZ = 64_z; template -void warp_affine_cv(const Mat& src, Mat& dst, const float* trans, - const float border_value, size_t task_id) { +void warp_affine_cv( + const Mat& src, Mat& dst, const float* trans, const float border_value, + size_t task_id) { // no extra padding double M[6]; rep(i, 6) M[i] = trans[i]; @@ -123,10 +124,8 @@ void warp_affine_cv(const Mat& src, Mat& dst, const float* trans, for (y1 = 0; y1 < bh; ++y1) { short* xy = XY + y1 * bw * 2; - int X0 = saturate_cast((M[1] * (y + y1) + M[2]) * AB_SCALE) + - round_delta; - int Y0 = saturate_cast((M[4] * (y + y1) + M[5]) * AB_SCALE) + - round_delta; + int X0 = saturate_cast((M[1] * (y + y1) + M[2]) * AB_SCALE) + round_delta; + int Y0 = saturate_cast((M[4] * (y + y1) + M[5]) * AB_SCALE) + round_delta; if (imode == IMode::INTER_NEAREST) { x1 = 0; @@ -136,15 +135,13 @@ void warp_affine_cv(const Mat& src, Mat& dst, const float* trans, int16x8x2_t v_dst; v_dst.val[0] = vcombine_s16( vqmovn_s32(vshrq_n_s32( - vaddq_s32(v_X0, vld1q_s32(adelta + x + x1)), - AB_BITS)), + vaddq_s32(v_X0, vld1q_s32(adelta + x + x1)), AB_BITS)), vqmovn_s32(vshrq_n_s32( vaddq_s32(v_X0, vld1q_s32(adelta + x + x1 + 4)), AB_BITS))); v_dst.val[1] = vcombine_s16( vqmovn_s32(vshrq_n_s32( - vaddq_s32(v_Y0, vld1q_s32(bdelta 
+ x + x1)), - AB_BITS)), + vaddq_s32(v_Y0, vld1q_s32(bdelta + x + x1)), AB_BITS)), vqmovn_s32(vshrq_n_s32( vaddq_s32(v_Y0, vld1q_s32(bdelta + x + x1 + 4)), AB_BITS))); @@ -180,12 +177,12 @@ void warp_affine_cv(const Mat& src, Mat& dst, const float* trans, AB_BITS - INTER_BITS); int16x8x2_t v_xy; - v_xy.val[0] = - vcombine_s16(vqmovn_s32(vshrq_n_s32(v_X0, INTER_BITS)), - vqmovn_s32(vshrq_n_s32(v_X1, INTER_BITS))); - v_xy.val[1] = - vcombine_s16(vqmovn_s32(vshrq_n_s32(v_Y0, INTER_BITS)), - vqmovn_s32(vshrq_n_s32(v_Y1, INTER_BITS))); + v_xy.val[0] = vcombine_s16( + vqmovn_s32(vshrq_n_s32(v_X0, INTER_BITS)), + vqmovn_s32(vshrq_n_s32(v_X1, INTER_BITS))); + v_xy.val[1] = vcombine_s16( + vqmovn_s32(vshrq_n_s32(v_Y0, INTER_BITS)), + vqmovn_s32(vshrq_n_s32(v_Y1, INTER_BITS))); vst2q_s16(xy + (x1 << 1), v_xy); @@ -216,8 +213,7 @@ void warp_affine_cv(const Mat& src, Mat& dst, const float* trans, void megdnn::arm_common::warp_affine_cv_exec( _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, - float border_value, BorderMode bmode, InterpolationMode imode, - Handle* handle) { + float border_value, BorderMode bmode, InterpolationMode imode, Handle* handle) { size_t ch = dst.layout[3]; size_t width = dst.layout[2]; size_t height = dst.layout[1]; @@ -227,60 +223,60 @@ void megdnn::arm_common::warp_affine_cv_exec( size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); - size_t parallelism_batch = div_ceil(height, BLOCK_SZ_H) * - div_ceil(width, BLOCK_SZ_W); + size_t parallelism_batch = + div_ceil(height, BLOCK_SZ_H) * div_ceil(width, BLOCK_SZ_W); - megdnn_assert(ch == 1 || ch == 3 || ch == 2, - "unsupported src channel: %zu, avaiable channel size: 1/2/3", - ch); + megdnn_assert( + ch == 1 || ch == 3 || ch == 2, + "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); const float* trans_ptr = trans.ptr(); if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { -#define cb(_imode, _bmode, _ch) \ - MIDOUT_BEGIN(megdnn_arm_common_warp_affine_cv, midout_iv(_imode), \ - midout_iv(_bmode), midout_iv(_ch), float) { \ - auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ - size_t index, size_t) { \ - size_t batch_id = index / parallelism_batch; \ - size_t task_id = index % parallelism_batch; \ - Mat src_mat = TensorND2Mat(src, batch_id); \ - Mat dst_mat = TensorND2Mat(dst, batch_id); \ - const float* task_trans_ptr = trans_ptr + batch_id * 2 * 3; \ - warp_affine_cv( \ - src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ - MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA \ - border_value, \ - task_id); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast(handle), \ - batch* parallelism_batch, task); \ - } \ +#define cb(_imode, _bmode, _ch) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_warp_affine_cv, midout_iv(_imode), midout_iv(_bmode), \ + midout_iv(_ch), float) { \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 2 * 3; \ + warp_affine_cv< \ + float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode MEGDNN_COMMA _ch>( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), 
batch* parallelism_batch, \ + task); \ + } \ MIDOUT_END(); DISPATCH_IMODE(imode, bmode, ch, cb) } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { #undef cb -#define cb(_imode, _bmode, _ch) \ - MIDOUT_BEGIN(megdnn_arm_common_warp_affine_cv, midout_iv(_imode), \ - midout_iv(_bmode), midout_iv(_ch), uchar) { \ - auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ - size_t index, size_t) { \ - size_t batch_id = index / parallelism_batch; \ - size_t task_id = index % parallelism_batch; \ - Mat src_mat = TensorND2Mat(src, batch_id); \ - Mat dst_mat = TensorND2Mat(dst, batch_id); \ - const float* task_trans_ptr = trans_ptr + batch_id * 2 * 3; \ - warp_affine_cv( \ - src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ - MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA \ - border_value, \ - task_id); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast(handle), \ - batch* parallelism_batch, task); \ - } \ +#define cb(_imode, _bmode, _ch) \ + MIDOUT_BEGIN( \ + megdnn_arm_common_warp_affine_cv, midout_iv(_imode), midout_iv(_bmode), \ + midout_iv(_ch), uchar) { \ + auto task = [src, trans_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + Mat src_mat = TensorND2Mat(src, batch_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 2 * 3; \ + warp_affine_cv< \ + uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode MEGDNN_COMMA _ch>( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, \ + task); \ + } \ MIDOUT_END(); DISPATCH_IMODE(imode, bmode, ch, cb) #undef cb diff --git a/dnn/src/arm_common/warp_affine/warp_affine_cv.h b/dnn/src/arm_common/warp_affine/warp_affine_cv.h index 0b2cf588..9881600c 100644 --- a/dnn/src/arm_common/warp_affine/warp_affine_cv.h +++ b/dnn/src/arm_common/warp_affine/warp_affine_cv.h @@ -20,11 +20,10 @@ namespace arm_common { * \fn warp_affine_cv * \brief Used if the format is NHWC, transfer from megcv */ -void warp_affine_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, - _megdnn_tensor_in dst, float border_value, - param::WarpAffine::BorderMode border_mode, - param::WarpAffine::InterpolationMode imode, - Handle* handle); +void warp_affine_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in dst, + float border_value, param::WarpAffine::BorderMode border_mode, + param::WarpAffine::InterpolationMode imode, Handle* handle); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/warp_perspective/opr_impl.cpp b/dnn/src/arm_common/warp_perspective/opr_impl.cpp index 2422731d..857021a3 100644 --- a/dnn/src/arm_common/warp_perspective/opr_impl.cpp +++ b/dnn/src/arm_common/warp_perspective/opr_impl.cpp @@ -22,16 +22,17 @@ MIDOUT_DECL(megdnn_arm_warpperspective) namespace megdnn { namespace arm_common { -void WarpPerspectiveImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { - check_exec_allow_nhwc_mat_idx(src.layout, mat.layout, mat_idx.layout, - dst.layout, workspace.size); - if (warp::is_cv_available(src.layout, mat.layout, dst.layout, param().imode, - param().format)) { +void WarpPerspectiveImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, + 
_megdnn_tensor_in dst, _megdnn_workspace workspace) { + check_exec_allow_nhwc_mat_idx( + src.layout, mat.layout, mat_idx.layout, dst.layout, workspace.size); + if (warp::is_cv_available( + src.layout, mat.layout, dst.layout, param().imode, param().format)) { MIDOUT_BEGIN(megdnn_arm_warpperspective, void) { - warp_perspective_cv_exec(src, mat, mat_idx, dst, param().border_val, - param().bmode, param().imode, handle()); + warp_perspective_cv_exec( + src, mat, mat_idx, dst, param().border_val, param().bmode, + param().imode, handle()); } MIDOUT_END(); } else { diff --git a/dnn/src/arm_common/warp_perspective/opr_impl.h b/dnn/src/arm_common/warp_perspective/opr_impl.h index affaf246..865603bf 100644 --- a/dnn/src/arm_common/warp_perspective/opr_impl.h +++ b/dnn/src/arm_common/warp_perspective/opr_impl.h @@ -19,9 +19,9 @@ class WarpPerspectiveImpl : public fallback::WarpPerspectiveImpl { public: using fallback::WarpPerspectiveImpl::WarpPerspectiveImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in mat, - _megdnn_tensor_in mat_idx, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; }; } // namespace arm_common diff --git a/dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp index 17a9fd72..795f81f8 100644 --- a/dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp +++ b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.cpp @@ -59,14 +59,14 @@ * --------------------------------------------------------------------------- */ #include "src/arm_common/warp_perspective/warp_perspective_cv.h" +#include +#include #include "src/arm_common/handle.h" #include "src/common/cv/common.h" #include "src/common/cv/helper.h" #include "src/common/cv/interp_helper.h" #include "src/common/utils.h" #include "src/common/warp_common.h" -#include -#include #include "src/arm_common/simd_macro/marm_neon.h" @@ -79,8 +79,9 @@ namespace { constexpr size_t BLOCK_SZ = 32u; template -void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, - const float border_value, size_t task_id) { +void warp_perspective_cv( + const Mat& src, Mat& dst, const float* trans, const float border_value, + size_t task_id) { // no extra padding double M[9]; rep(i, 9) M[i] = trans[i]; @@ -146,12 +147,12 @@ void warp_perspective_cv(const Mat& src, Mat& dst, const float* trans, remap>(src, dpart, _XY, _matA, bvalue); } -} // anonymous namespace +} // anonymous namespace void megdnn::arm_common::warp_perspective_cv_exec( - _megdnn_tensor_in src, _megdnn_tensor_in trans, - _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, float border_value, - BorderMode bmode, InterpolationMode imode, Handle* handle) { + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in mat_idx, + _megdnn_tensor_in dst, float border_value, BorderMode bmode, + InterpolationMode imode, Handle* handle) { size_t ch = dst.layout[3]; size_t width = dst.layout[2]; size_t height = dst.layout[1]; @@ -161,11 +162,11 @@ void megdnn::arm_common::warp_perspective_cv_exec( size_t BLOCK_SZ_W = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_H, width); BLOCK_SZ_H = std::min(BLOCK_SZ * BLOCK_SZ / BLOCK_SZ_W, height); - size_t parallelism_batch = div_ceil(height, BLOCK_SZ_H) * - div_ceil(width, BLOCK_SZ_W); - megdnn_assert(ch == 1 || ch == 3 || ch == 2, - "unsupported src channel: %zu, avaiable channel size: 1/2/3", - ch); + size_t 
parallelism_batch = + div_ceil(height, BLOCK_SZ_H) * div_ceil(width, BLOCK_SZ_W); + megdnn_assert( + ch == 1 || ch == 3 || ch == 2, + "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); const float* trans_ptr = trans.ptr(); const int* midx_ptr = nullptr; if (mat_idx.raw_ptr) { @@ -173,59 +174,57 @@ void megdnn::arm_common::warp_perspective_cv_exec( midx_ptr = mat_idx.ptr(); } if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { -#define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ - parallelism_batch](size_t index, size_t) { \ - size_t batch_id = index / parallelism_batch; \ - size_t task_id = index % parallelism_batch; \ - size_t src_id = batch_id; \ - if (midx_ptr) { \ - src_id = midx_ptr[batch_id]; \ - megdnn_assert( \ - src_id < src.layout.shape[0], \ - "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ - batch_id, src_id, src.layout.shape[0]); \ - } \ - Mat src_mat = TensorND2Mat(src, src_id); \ - Mat dst_mat = TensorND2Mat(dst, batch_id); \ - const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ - warp_perspective_cv( \ - src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ - MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ - task_id); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast(handle), batch* parallelism_batch, \ - task); +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; \ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ + src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv< \ + float MEGDNN_COMMA _imode MEGDNN_COMMA _bmode MEGDNN_COMMA _ch>( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, task); DISPATCH_IMODE(imode, bmode, ch, cb) #undef cb } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { -#define cb(_imode, _bmode, _ch) \ - auto task = [src, trans_ptr, midx_ptr, dst, border_value, \ - parallelism_batch](size_t index, size_t) { \ - size_t batch_id = index / parallelism_batch; \ - size_t task_id = index % parallelism_batch; \ - size_t src_id = batch_id; \ - if (midx_ptr) { \ - src_id = midx_ptr[batch_id]; \ - megdnn_assert( \ - src_id < src.layout.shape[0], \ - "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", \ - batch_id, src_id, src.layout.shape[0]); \ - } \ - Mat src_mat = TensorND2Mat(src, src_id); \ - Mat dst_mat = TensorND2Mat(dst, batch_id); \ - const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ - warp_perspective_cv( \ - src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ - MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ - task_id); \ - }; \ - MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ - static_cast(handle), batch* parallelism_batch, \ - task); +#define cb(_imode, _bmode, _ch) \ + auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ + size_t index, size_t) { \ + size_t batch_id = index / parallelism_batch; \ + size_t task_id = index % parallelism_batch; 
\ + size_t src_id = batch_id; \ + if (midx_ptr) { \ + src_id = midx_ptr[batch_id]; \ + megdnn_assert( \ + src_id < src.layout.shape[0], \ + "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ + src_id, src.layout.shape[0]); \ + } \ + Mat src_mat = TensorND2Mat(src, src_id); \ + Mat dst_mat = TensorND2Mat(dst, batch_id); \ + const float* task_trans_ptr = trans_ptr + batch_id * 3 * 3; \ + warp_perspective_cv< \ + uchar MEGDNN_COMMA _imode MEGDNN_COMMA _bmode MEGDNN_COMMA _ch>( \ + src_mat MEGDNN_COMMA const_cast&>(dst_mat) \ + MEGDNN_COMMA task_trans_ptr MEGDNN_COMMA border_value, \ + task_id); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast(handle), batch* parallelism_batch, task); DISPATCH_IMODE(imode, bmode, ch, cb) #undef cb } else { diff --git a/dnn/src/arm_common/warp_perspective/warp_perspective_cv.h b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.h index 4a4ff95c..a0777dcd 100644 --- a/dnn/src/arm_common/warp_perspective/warp_perspective_cv.h +++ b/dnn/src/arm_common/warp_perspective/warp_perspective_cv.h @@ -20,12 +20,11 @@ namespace arm_common { * \fn warp_perspective_cv * \brief Used if the format is NHWC, transfer from megcv */ -void warp_perspective_cv_exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, - _megdnn_tensor_in mat_idx, _megdnn_tensor_in dst, - float border_value, - param::WarpPerspective::BorderMode border_mode, - param::WarpPerspective::InterpolationMode imode, - Handle* handle); +void warp_perspective_cv_exec( + _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_in mat_idx, + _megdnn_tensor_in dst, float border_value, + param::WarpPerspective::BorderMode border_mode, + param::WarpPerspective::InterpolationMode imode, Handle* handle); } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/armv7/conv_bias/int8/algos.cpp b/dnn/src/armv7/conv_bias/int8/algos.cpp index 3fd1c089..0ccaab89 100644 --- a/dnn/src/armv7/conv_bias/int8/algos.cpp +++ b/dnn/src/armv7/conv_bias/int8/algos.cpp @@ -63,18 +63,19 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( size_t K = IC * FH * FW; size_t N = OH * OW; -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - MIDOUT_BEGIN(megdnn_armv7_conv_bias_int8, 0, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - part2 = megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ - M, N, K, false, false, strategy) \ - .get_workspace_size(); \ - } \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_armv7_conv_bias_int8, 0, _gemm_midout_enum, _bias_midout_enum, \ + _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); \ + } \ MIDOUT_END() DISPATCH_GEMM_BIAS(s8_4x2, 0) @@ -84,8 +85,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( return {nullptr, {part0, part1, part2}}; } -void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, - const NCBKernIndex& ncb_index) { +void ConvBiasImpl::AlgoS8MatrixMul::kimpl( + const NCBKernParam& param, const 
NCBKernIndex& ncb_index) { auto is_xcorr = !param.filter_meta.should_flip; UNPACK_CONV_NCB_KERN_SIZES(param); auto bundle = get_bundle(param); @@ -137,33 +138,32 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); } else { if (is_xcorr) - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); else - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); } } { - Workspace workspace(static_cast(bundle.get(2)), - bundle.get_size(2)); + Workspace workspace( + static_cast(bundle.get(2)), bundle.get_size(2)); size_t M = OC; size_t K = IC * FH * FW; size_t N = OH * OW; -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - MIDOUT_BEGIN(megdnn_armv7_conv_bias_int8, 1, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline> \ - gemm_interleaved(M, N, K, false, false, strategy); \ - gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ - bias); \ - } \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_armv7_conv_bias_int8, 1, _gemm_midout_enum, _bias_midout_enum, \ + _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ + } \ MIDOUT_END() DISPATCH_GEMM_BIAS(s8_4x2, 0) diff --git a/dnn/src/armv7/conv_bias/int8/algos.h b/dnn/src/armv7/conv_bias/int8/algos.h index 3260bd25..39667f02 100644 --- a/dnn/src/armv7/conv_bias/int8/algos.h +++ b/dnn/src/armv7/conv_bias/int8/algos.h @@ -24,18 +24,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { static void kimpl(const NCBKernParam& param, const NCBKernIndex&); public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "S8MATMUL"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override { return get_bundle(param).total_size_in_bytes(); } - SmallVector dispatch_kerns( - const NCBKernSizeParam& param) const override { + SmallVector dispatch_kerns(const NCBKernSizeParam& param) const override { size_t group = param.filter_meta.group; return {{kimpl, {group, 1_z, 1_z}}}; } diff --git a/dnn/src/armv7/conv_bias/int8/strategy.cpp b/dnn/src/armv7/conv_bias/int8/strategy.cpp index 1a51a07e..8ec5555b 100644 --- a/dnn/src/armv7/conv_bias/int8/strategy.cpp +++ b/dnn/src/armv7/conv_bias/int8/strategy.cpp @@ -28,9 +28,10 @@ struct KernCaller; template struct KernCaller { - static void run(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, 
dt_int8* C, size_t LDC, bool is_first_k, - Op op, const dt_int32* bias, dt_int32* workspace) { + static void run( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace) { megdnn_assert(is_first_k); constexpr size_t A_INTERLEAVE = 4; @@ -47,19 +48,18 @@ struct KernCaller { size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, workspace, 2, - is_first_k, 4, 2); - arm_common::ConvBiasMatmul::postprocess(bias, workspace, - output, LDC, op); + matmul_4x2x16::kern_4x2( + packA, cur_packB, K, workspace, 2, is_first_k, 4, 2); + arm_common::ConvBiasMatmul::postprocess( + bias, workspace, output, LDC, op); output += B_INTERLEAVE; cur_packB += K2; } for (; n < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, workspace, 2, - is_first_k, 4, - std::min(N - n, 2)); + matmul_4x2x16::kern_4x2( + packA, cur_packB, K, workspace, 2, is_first_k, 4, + std::min(N - n, 2)); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); @@ -81,14 +81,13 @@ struct KernCaller { size_t n = 0; const dt_int8* cur_packB = packB; for (; n < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, workspace, 2, - is_first_k, std::min(M - m, 4), - std::min(N - n, 2)); + matmul_4x2x16::kern_4x2( + packA, cur_packB, K, workspace, 2, is_first_k, + std::min(M - m, 4), std::min(N - n, 2)); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); - DISPATCH_M(cb, std::min(M - m, 4), - std::min(N - n, 2)); + DISPATCH_M(cb, std::min(M - m, 4), std::min(N - n, 2)); #undef cb output += B_INTERLEAVE; @@ -106,17 +105,16 @@ struct KernCaller { MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x2_nobias_identity) -void gemm_s8_4x2_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax, bool /*transpose*/) const { +void gemm_s8_4x2_nobias_identity::pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool /*transpose*/) const { MEGDNN_MARK_USED_VAR(matmul_4x2x16::gemm_s8_4x2_pack_A_t); - matmul_4x2x16::gemm_s8_4x2_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax); + matmul_4x2x16::gemm_s8_4x2_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); } -void gemm_s8_4x2_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool /*transpose*/) const { +void gemm_s8_4x2_nobias_identity::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool /*transpose*/) const { MEGDNN_MARK_USED_VAR(matmul_4x2x16::gemm_s8_4x2_pack_B_t); matmul_4x2x16::gemm_s8_4x2_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } @@ -125,18 +123,17 @@ size_t gemm_s8_4x2_nobias_identity::get_workspace_size() const { return 4 * 2 * sizeof(dt_int32); } -#define KERN(_bias, _BIAS, _nonline, _OP) \ - void gemm_s8_4x2_##_bias##_##_nonline::kern( \ - const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ - size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ - const dt_int32* bias, dt_int32* workspace) const { \ - float scale_A = A_dtype.param().scale; \ - float scale_B = B_dtype.param().scale; \ - float scale_C = C_dtype.param().scale; \ - DEFINE_OP(_OP); \ - impl::KernCaller<_BIAS, decltype(op), 4, 2>::run( \ - packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ - workspace); \ 
+#define KERN(_bias, _BIAS, _nonline, _OP) \ + void gemm_s8_4x2_##_bias##_##_nonline::kern( \ + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \ + dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ + dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + float scale_B = B_dtype.param().scale; \ + float scale_C = C_dtype.param().scale; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), 4, 2>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \ } #define DEFINE_OP(_Op) \ @@ -147,9 +144,9 @@ KERN(nobias, BiasMode::NO_BIAS, relu, ReluOp) KERN(nobias, BiasMode::NO_BIAS, hswish, HSwishOp) #undef DEFINE_OP -#define DEFINE_OP(_Op) \ - arm_common::_Op op(scale_A* scale_B, \ - scale_A* scale_B, scale_C); +#define DEFINE_OP(_Op) \ + arm_common::_Op op( \ + scale_A* scale_B, scale_A* scale_B, scale_C); KERN(bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) KERN(bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) diff --git a/dnn/src/armv7/conv_bias/int8/strategy.h b/dnn/src/armv7/conv_bias/int8/strategy.h index c722b0b8..0c8722c0 100644 --- a/dnn/src/armv7/conv_bias/int8/strategy.h +++ b/dnn/src/armv7/conv_bias/int8/strategy.h @@ -20,24 +20,23 @@ namespace matmul { * * \name gemm___biasmode_nolinemode */ -MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 2, 16, - false, true, - gemm_s8_4x2_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( + dt_int8, dt_int8, dt_int32, 4, 2, 16, false, true, gemm_s8_4x2_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_nobias_relu, - gemm_s8_4x2_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x2_nobias_relu, gemm_s8_4x2_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_nobias_hswish, - gemm_s8_4x2_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x2_nobias_hswish, gemm_s8_4x2_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_bias_channel_identity, - gemm_s8_4x2_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x2_bias_channel_identity, gemm_s8_4x2_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_bias_channel_relu, - gemm_s8_4x2_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x2_bias_channel_relu, gemm_s8_4x2_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x2_bias_channel_hswish, - gemm_s8_4x2_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_s8_4x2_bias_channel_hswish, gemm_s8_4x2_nobias_identity); } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/conv_bias/opr_impl.cpp b/dnn/src/armv7/conv_bias/opr_impl.cpp index d34cd2c1..caaadec0 100644 --- a/dnn/src/armv7/conv_bias/opr_impl.cpp +++ b/dnn/src/armv7/conv_bias/opr_impl.cpp @@ -12,9 +12,9 @@ #include "src/armv7/conv_bias/opr_impl.h" #include "src/armv7/conv_bias/int8/algos.h" #include "src/armv7/conv_bias/quint8/algos.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/naive/handle.h" -#include "src/common/metahelper.h" #include "src/fallback/convolution/opr_impl.h" @@ -26,6 +26,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoQU8MatrixMul qu8_matrix_mul; fallback::ConvBiasImpl::AlgoBase::Mapper m_all_algos_map; SmallVector m_all_algos; + public: AlgoPack() { m_all_algos.emplace_back(&qu8_matrix_mul); @@ -36,8 +37,7 @@ public: } } - const SmallVector& 
all_algos() - const { + const SmallVector& all_algos() const { return m_all_algos; } const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } @@ -50,14 +50,14 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) -SmallVector -ConvBiasImpl::get_all_packed_algo() { +SmallVector ConvBiasImpl::get_all_packed_algo() { auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, //! and nearly equal in aarch64, because of the waste of register in //! postprocess - algos.insert(algos.end(), algo_pack().all_algos().begin(), - algo_pack().all_algos().end()); + algos.insert( + algos.end(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return std::move(algos); } diff --git a/dnn/src/armv7/conv_bias/quint8/algos.cpp b/dnn/src/armv7/conv_bias/quint8/algos.cpp index e83996d2..c0a1c8bd 100644 --- a/dnn/src/armv7/conv_bias/quint8/algos.cpp +++ b/dnn/src/armv7/conv_bias/quint8/algos.cpp @@ -63,18 +63,19 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( size_t K = IC * FH * FW; size_t N = OH * OW; -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - MIDOUT_BEGIN(megdnn_armv7_conv_bias_quint8, 0, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - part2 = megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ - M, N, K, false, false, strategy) \ - .get_workspace_size(); \ - } \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_armv7_conv_bias_quint8, 0, _gemm_midout_enum, _bias_midout_enum, \ + _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); \ + } \ MIDOUT_END() DISPATCH_GEMM_BIAS(u8_4x8, 0) @@ -84,8 +85,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( return {nullptr, {part0, part1, part2}}; } -void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, - const NCBKernIndex& ncb_index) { +void ConvBiasImpl::AlgoQU8MatrixMul::kimpl( + const NCBKernParam& param, const NCBKernIndex& ncb_index) { auto is_xcorr = !param.filter_meta.should_flip; UNPACK_CONV_NCB_KERN_SIZES(param); auto bundle = get_bundle(param); @@ -139,33 +140,32 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, img2col(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); } else { if (is_xcorr) - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); else - img2col_stride(src2, B, OC, OH, OW, IC, IH2, IW2, FH, - FW, SH, SW); + img2col_stride( + src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); } } { - Workspace workspace(static_cast(bundle.get(2)), - bundle.get_size(2)); + Workspace workspace( + static_cast(bundle.get(2)), bundle.get_size(2)); size_t M = OC; size_t K = IC * FH * FW; size_t N = OH * OW; -#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ - _bias_midout_enum, _nonline, \ - _nonline_midout_enum) \ - 
MIDOUT_BEGIN(megdnn_armv7_conv_bias_quint8, 1, _gemm_midout_enum, \ - _bias_midout_enum, _nonline_midout_enum) { \ - matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ - M, N, K, param.filter_type, param.src_type, param.dst_type); \ - megdnn::matmul::GemmInterleaved< \ - matmul::gemm_##_gemm##_##_bias##_##_nonline> \ - gemm_interleaved(M, N, K, false, false, strategy); \ - gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ - bias); \ - } \ +#define DISPATCH_GEMM_STRATEGY( \ + _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + MIDOUT_BEGIN( \ + megdnn_armv7_conv_bias_quint8, 1, _gemm_midout_enum, _bias_midout_enum, \ + _nonline_midout_enum) { \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ + } \ MIDOUT_END() DISPATCH_GEMM_BIAS(u8_4x8, 0) diff --git a/dnn/src/armv7/conv_bias/quint8/algos.h b/dnn/src/armv7/conv_bias/quint8/algos.h index a3c412d8..8238f4ff 100644 --- a/dnn/src/armv7/conv_bias/quint8/algos.h +++ b/dnn/src/armv7/conv_bias/quint8/algos.h @@ -24,13 +24,12 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { static void kimpl(const NCBKernParam& param, const NCBKernIndex&); public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "QU8MATMUL"; } - bool usable(const NCBKernSizeParam& param, - AlgoSelectionStrategy algo_selection_strategy) const override; + bool usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; size_t get_workspace(const NCBKernSizeParam& param) const override { return get_bundle(param).total_size_in_bytes(); } diff --git a/dnn/src/armv7/conv_bias/quint8/strategy.cpp b/dnn/src/armv7/conv_bias/quint8/strategy.cpp index f30aa761..bb00a1e8 100644 --- a/dnn/src/armv7/conv_bias/quint8/strategy.cpp +++ b/dnn/src/armv7/conv_bias/quint8/strategy.cpp @@ -10,13 +10,13 @@ */ #include "src/armv7/conv_bias/quint8/strategy.h" -#include "src/armv7/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" -#include "src/armv7/matrix_mul/quint8/kernel_4x8x8.h" #include "src/arm_common/conv_bias/matmul_postprocess.h" +#include "src/armv7/matrix_mul/quint8/kernel_4x8x8.h" using namespace megdnn; using namespace armv7; @@ -28,10 +28,10 @@ struct KernCaller; template struct KernCaller { - static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, - size_t N, size_t K, dt_uint8* C, size_t LDC, - bool is_first_k, Op op, const dt_int32* bias, - dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { + static void run( + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, + dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, + dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { megdnn_assert(is_first_k); constexpr size_t A_INTERLEAVE = 4; @@ -47,9 +47,9 @@ struct KernCaller { const dt_uint8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, - is_first_k, std::min(M - m, 4), - 
zp_A, zp_B); + matmul_4x8x8::kern_4x8( + packA, cur_packB, K, workspace, 8, is_first_k, + std::min(M - m, 4), zp_A, zp_B); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); @@ -61,19 +61,18 @@ struct KernCaller { } for (; n < N; n += 4) { - matmul_4x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, - is_first_k, std::min(M - m, 4), - std::min(N - n, 4), zp_A, zp_B); + matmul_4x8x8::kern_4x4( + packA, cur_packB, K, workspace, 4, is_first_k, + std::min(M - m, 4), std::min(N - n, 4), zp_A, + zp_B); #define cb(m, n) \ arm_common::ConvBiasMatmul::postprocess( \ bias, workspace, output, LDC, op); - DISPATCH_M(cb, std::min(M - m, 4), - std::min(N - n, 4)); + DISPATCH_M(cb, std::min(M - m, 4), std::min(N - n, 4)); #undef cb output += 4; cur_packB += K4; - } packA += K4; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { @@ -86,30 +85,27 @@ struct KernCaller { } // namespace impl MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_4x8_nobias_identity); -void gemm_u8_4x8_nobias_identity::pack_A(dt_uint8* outptr, - const dt_uint8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_u8_4x8_nobias_identity::pack_A( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { uint8_t zA = A_dtype.param().zero_point; if (transpose) { - matmul_4x8x8::gemm_u8_4x8_transpose_pack_A_n(outptr, inptr, ldin, y0, - ymax, k0, kmax, zA); + matmul_4x8x8::gemm_u8_4x8_transpose_pack_A_n( + outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } else { - matmul_4x8x8::gemm_u8_4x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax, zA); + matmul_4x8x8::gemm_u8_4x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } } -void gemm_u8_4x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, - int ldin, int x0, int xmax, int k0, - int kmax, bool transpose) const { +void gemm_u8_4x8_nobias_identity::pack_B( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { uint8_t zB = B_dtype.param().zero_point; if (transpose) { - matmul_4x8x8::gemm_u8_4x8_transpose_pack_B_n(out, in, ldin, x0, xmax, - k0, kmax, zB); + matmul_4x8x8::gemm_u8_4x8_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax, zB); } else { - matmul_4x8x8::gemm_u8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, - zB); + matmul_4x8x8::gemm_u8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); } } @@ -117,21 +113,21 @@ size_t gemm_u8_4x8_nobias_identity::get_workspace_size() const { return 4 * 8 * sizeof(dt_int32); } -#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ - void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ - const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ - size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ - const dt_int32* bias, dt_int32* workspace) const { \ - float scale_A = A_dtype.param().scale; \ - uint8_t zp_A = A_dtype.param().zero_point; \ - float scale_B = B_dtype.param().scale; \ - uint8_t zp_B = B_dtype.param().zero_point; \ - float scale_C = C_dtype.param().scale; \ - uint8_t zp_C = C_dtype.param().zero_point; \ - DEFINE_OP(_OP); \ - impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ - packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ - workspace, zp_A, zp_B); \ +#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ + void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ + size_t K, dt_uint8* 
C, size_t LDC, bool is_first_k, const dt_int32* bias, \ + dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + uint8_t zp_A = A_dtype.param().zero_point; \ + float scale_B = B_dtype.param().scale; \ + uint8_t zp_B = B_dtype.param().zero_point; \ + float scale_C = C_dtype.param().scale; \ + uint8_t zp_C = C_dtype.param().zero_point; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \ + zp_B); \ } #define DEFINE_OP(_Op) \ @@ -142,13 +138,12 @@ KERN(4, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp) KERN(4, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) #undef DEFINE_OP -#define DEFINE_OP(_Op) \ - arm_common::_Op op(scale_A* scale_B, \ - scale_A* scale_B, scale_C, zp_C); +#define DEFINE_OP(_Op) \ + arm_common::_Op op( \ + scale_A* scale_B, scale_A* scale_B, scale_C, zp_C); KERN(4, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(4, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) -KERN(4, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, - FuseAddHSwishOp) +KERN(4, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) #undef DEFINE_OP #undef KERN diff --git a/dnn/src/armv7/conv_bias/quint8/strategy.h b/dnn/src/armv7/conv_bias/quint8/strategy.h index 3320086c..8146388f 100644 --- a/dnn/src/armv7/conv_bias/quint8/strategy.h +++ b/dnn/src/armv7/conv_bias/quint8/strategy.h @@ -20,24 +20,24 @@ namespace matmul { * * \name gemm___biasmode_nolinemode */ -MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 4, 8, 8, - false, true, - gemm_u8_4x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( + dt_uint8, dt_uint8, dt_int32, 4, 8, 8, false, true, + gemm_u8_4x8_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_nobias_relu, - gemm_u8_4x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_4x8_nobias_relu, gemm_u8_4x8_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_nobias_hswish, - gemm_u8_4x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_4x8_nobias_hswish, gemm_u8_4x8_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_bias_channel_identity, - gemm_u8_4x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_4x8_bias_channel_identity, gemm_u8_4x8_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_bias_channel_relu, - gemm_u8_4x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_4x8_bias_channel_relu, gemm_u8_4x8_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_4x8_bias_channel_hswish, - gemm_u8_4x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( + gemm_u8_4x8_bias_channel_hswish, gemm_u8_4x8_nobias_identity); } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/handle.cpp b/dnn/src/armv7/handle.cpp index 62fea3a7..8a624729 100644 --- a/dnn/src/armv7/handle.cpp +++ b/dnn/src/armv7/handle.cpp @@ -13,10 +13,10 @@ #include "src/armv7/handle.h" +#include "src/armv7/conv_bias/opr_impl.h" #include "src/armv7/matrix_mul/opr_impl.h" -#include "src/armv7/rotate/opr_impl.h" #include "src/armv7/relayout/opr_impl.h" -#include "src/armv7/conv_bias/opr_impl.h" +#include "src/armv7/rotate/opr_impl.h" namespace megdnn { namespace armv7 { @@ -37,7 +37,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvBias) MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) #pragma GCC diagnostic pop -} // namespace armv7 -} // namespace megdnn 
+} // namespace armv7 +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/handle.h b/dnn/src/armv7/handle.h index 971424c3..3e724088 100644 --- a/dnn/src/armv7/handle.h +++ b/dnn/src/armv7/handle.h @@ -14,21 +14,18 @@ namespace megdnn { namespace armv7 { -class HandleImpl: public arm_common::HandleImpl { - public: - HandleImpl(megcoreComputingHandle_t computing_handle, - HandleType type = HandleType::ARMV7): - arm_common::HandleImpl::HandleImpl(computing_handle, type) - { - } +class HandleImpl : public arm_common::HandleImpl { +public: + HandleImpl( + megcoreComputingHandle_t computing_handle, + HandleType type = HandleType::ARMV7) + : arm_common::HandleImpl::HandleImpl(computing_handle, type) {} - template - std::unique_ptr create_operator(); + template + std::unique_ptr create_operator(); }; -} // namespace armv7 -} // namespace megdnn +} // namespace armv7 +} // namespace megdnn // vim: syntax=cpp.doxygen - - diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp index 8d4f6397..5f395297 100644 --- a/dnn/src/armv7/matrix_mul/algos.cpp +++ b/dnn/src/armv7/matrix_mul/algos.cpp @@ -46,16 +46,14 @@ void f32_kern(const MatrixMulImpl::KernParam& kern_param) { armv7::matmul::sgemm_4x12 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } } // anonymous namespace -bool MatrixMulImpl::AlgoF32::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF32::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::DEFAULT && kern_size_param.B_type == kern_size_param.A_type && @@ -65,10 +63,8 @@ bool MatrixMulImpl::AlgoF32::usable( size_t MatrixMulImpl::AlgoF32::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoF32::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("AlgoF32::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -81,22 +77,19 @@ size_t MatrixMulImpl::AlgoF32::get_workspace( return 0; } -MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( - const KernSizeParam&) const { +MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern(const KernSizeParam&) const { return f32_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, - "AlgoF32Impl"_hash, - armv7::matmul::sgemm_4x12, float, float, - AlgoDataType::FLOAT32, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF32, megdnn_armv7_matmul_kern, "AlgoF32Impl"_hash, + armv7::matmul::sgemm_4x12, float, float, AlgoDataType::FLOAT32, DEFAULT); /* ===================== F32 algo mk4 K4x12 ===================== */ namespace { void f32_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("f32_mk4_pack_4x12_kern"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("f32_mk4_pack_4x12_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = 
kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; @@ -105,12 +98,10 @@ void f32_mk4_pack_4x12_kern(const MatrixMulImpl::KernParam& kern_param) { const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type, - C_type); + armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -125,24 +116,21 @@ bool MatrixMulImpl::AlgoF32MK4Pack4x12::usable( kern_size_param.C_type == kern_size_param.A_type && kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 && - kern_size_param.K % 4 == 0 && !kern_size_param.trA && - !kern_size_param.trB; + kern_size_param.K % 4 == 0 && !kern_size_param.trA && !kern_size_param.trB; } size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoF32MK4Pack4x12::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoF32MK4Pack4x12::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type, - C_type); - return megdnn::matmul::GemmInterleaved< - armv7::matmul::sgemm_mk4_pack_4x12>(M, N, K, trA, trB, - strategy) + armv7::matmul::sgemm_mk4_pack_4x12 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -154,11 +142,9 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern( return f32_mk4_pack_4x12_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12, - megdnn_armv7_matmul_kern, - "AlgoF32MK4Pack4x12"_hash, - armv7::matmul::sgemm_mk4_pack_4x12, float, - float, AlgoDataType::FLOAT32, MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF32MK4Pack4x12, megdnn_armv7_matmul_kern, "AlgoF32MK4Pack4x12"_hash, + armv7::matmul::sgemm_mk4_pack_4x12, float, float, AlgoDataType::FLOAT32, MK4); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /* ===================== F16 K4x16x1 algo ===================== */ @@ -170,22 +156,19 @@ void f16_kern(const MatrixMulImpl::KernParam& kern_param) { auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); armv7::matmul::hgemm_4x16 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } } // anonymous namespace -bool MatrixMulImpl::AlgoF16K4x16x1::usable( - const KernSizeParam& kern_size_param) const { +bool 
MatrixMulImpl::AlgoF16K4x16x1::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::DEFAULT && kern_size_param.C_type == kern_size_param.A_type && @@ -195,10 +178,9 @@ bool MatrixMulImpl::AlgoF16K4x16x1::usable( size_t MatrixMulImpl::AlgoF16K4x16x1::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoF16K4x16x1::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, midout_iv("AlgoF16K4x16x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -216,11 +198,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( return f16_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern, - "AlgoF16K4x16x1"_hash, - armv7::matmul::hgemm_4x16, dt_float16, - dt_float16, AlgoDataType::FLOAT16, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF16K4x16x1, megdnn_armv7_matmul_kern, "AlgoF16K4x16x1"_hash, + armv7::matmul::hgemm_4x16, dt_float16, dt_float16, AlgoDataType::FLOAT16, + DEFAULT); #endif @@ -228,21 +209,18 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern, namespace { void kern_int8x8x32_k4x2x16(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int8x8x32_k4x2x16"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("kern_int8x8x32_k4x2x16"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_s8_4x2 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_s8_4x2 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -260,10 +238,10 @@ bool MatrixMulImpl::AlgoInt8x8x32K4x2x16::preferred( size_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x32K4x2x16::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32K4x2x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -281,31 +259,25 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_kern( return kern_int8x8x32_k4x2x16; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x32K4x2x16"_hash, - armv7::matmul::gemm_s8_4x2, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - DEFAULT); 
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32K4x2x16, megdnn_armv7_matmul_kern, "AlgoInt8x8x32K4x2x16"_hash, + armv7::matmul::gemm_s8_4x2, int8_t, int32_t, AlgoDataType::QINT8X8X32, DEFAULT); /* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */ namespace { void kern_int8x8x32_k4x8x8(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int8x8x32_k4x8x8"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("kern_int8x8x32_k4x8x8"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_s8_4x8 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_s8_4x8 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -323,10 +295,10 @@ bool MatrixMulImpl::AlgoInt8x8x32K4x8x8::preferred( size_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x32K4x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32K4x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -344,31 +316,25 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_kern( return kern_int8x8x32_k4x8x8; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x32K4x8x8"_hash, - armv7::matmul::gemm_s8_4x8, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32K4x8x8, megdnn_armv7_matmul_kern, "AlgoInt8x8x32K4x8x8"_hash, + armv7::matmul::gemm_s8_4x8, int8_t, int32_t, AlgoDataType::QINT8X8X32, DEFAULT); /* ===================== Quint8 Kernel 4x8x8 algo ===================== */ namespace { void kern_quint8_k4x8x8(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_quint8_k4x8x8"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("kern_quint8_k4x8x8"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_u8_4x8 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_u8_4x8 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -385,10 +351,10 @@ bool MatrixMulImpl::AlgoQuint8K4x8x8::usable( size_t 
MatrixMulImpl::AlgoQuint8K4x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoQuint8K4x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoQuint8K4x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -406,30 +372,26 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( return kern_quint8_k4x8x8; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern, - "AlgoQuint8K4x8x8"_hash, - armv7::matmul::gemm_u8_4x8, uint8_t, - int32_t, AlgoDataType::QUINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern, "AlgoQuint8K4x8x8"_hash, + armv7::matmul::gemm_u8_4x8, uint8_t, int32_t, AlgoDataType::QUINT8X8X32, + DEFAULT); /* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */ namespace { void kern_int8x8x16_k2x4x16(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int8x8x16_k2x4x16"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("kern_int8x8x16_k2x4x16"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_s8x8x16_4x2 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_s8x8x16_4x2 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -446,10 +408,10 @@ bool MatrixMulImpl::AlgoInt8x8x16K4x2x16::usable( size_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x16K4x2x16::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x16K4x2x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -472,30 +434,26 @@ bool MatrixMulImpl::AlgoInt8x8x16K4x2x16::preferred( return kern_size_param.K > 128; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x16K4x2x16"_hash, - armv7::matmul::gemm_s8x8x16_4x2, int8_t, - int16_t, AlgoDataType::INT8X8X16, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x16K4x2x16, megdnn_armv7_matmul_kern, "AlgoInt8x8x16K4x2x16"_hash, + armv7::matmul::gemm_s8x8x16_4x2, int8_t, int16_t, AlgoDataType::INT8X8X16, + DEFAULT); /* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */ namespace { void kern_int8x8x16_k4x8x8(const MatrixMulImpl::KernParam& kern_param) { - 
MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int8x8x16_k4x8x8"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("kern_int8x8x16_k4x8x8"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_s8x8x16_4x8 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_s8x8x16_4x8 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -512,10 +470,10 @@ bool MatrixMulImpl::AlgoInt8x8x16K4x8x8::usable( size_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x16K4x8x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x16K4x8x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -538,31 +496,27 @@ bool MatrixMulImpl::AlgoInt8x8x16K4x8x8::preferred( return kern_size_param.K >= 8 && kern_size_param.K <= 128; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x16K4x8x8"_hash, - armv7::matmul::gemm_s8x8x16_4x8, int8_t, - int16_t, AlgoDataType::INT8X8X16, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x16K4x8x8, megdnn_armv7_matmul_kern, "AlgoInt8x8x16K4x8x8"_hash, + armv7::matmul::gemm_s8x8x16_4x8, int8_t, int16_t, AlgoDataType::INT8X8X16, + DEFAULT); /* ===================== Int8x8x16 Kernel 8x8x4 algo ===================== */ namespace { void kern_int8x8x16_k8x8x4(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int8x8x16_k8x8x4"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("kern_int8x8x16_k8x8x4"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_s8x8x16_8x8 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_s8x8x16_8x8 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -579,10 +533,10 @@ bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::usable( size_t MatrixMulImpl::AlgoInt8x8x16K8x8x4::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, 
+ midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -605,32 +559,28 @@ bool MatrixMulImpl::AlgoInt8x8x16K8x8x4::preferred( return kern_size_param.K >= 8 && kern_size_param.K <= 128; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x4, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x16K8x8x4"_hash, - armv7::matmul::gemm_s8x8x16_8x8, int8_t, - int16_t, AlgoDataType::INT8X8X16, DEFAULT); - +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x16K8x8x4, megdnn_armv7_matmul_kern, "AlgoInt8x8x16K8x8x4"_hash, + armv7::matmul::gemm_s8x8x16_8x8, int8_t, int16_t, AlgoDataType::INT8X8X16, + DEFAULT); /* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ namespace { void kern_int8x8x16_mk4_k8x8x4(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int8x8x16_mk4_k8x8x4"_hash)) { + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, midout_iv("kern_int8x8x16_mk4_k8x8x4"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_s8x8x16_mk4_8x8 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_s8x8x16_mk4_8x8 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -642,16 +592,16 @@ bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::usable( return type_ok && kern_size_param.format == param::MatrixMul::Format::MK4 && kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && - !kern_size_param.trA && !kern_size_param.trB && - kern_size_param.M % 4 == 0 && kern_size_param.K % 4 == 0; + !kern_size_param.trA && !kern_size_param.trB && kern_size_param.M % 4 == 0 && + kern_size_param.K % 4 == 0; } size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x16K8x8x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -674,32 +624,27 @@ bool MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::preferred( return kern_size_param.K >= 4; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x16MK4_8x8x4"_hash, - armv7::matmul::gemm_s8x8x16_mk4_8x8, - int8_t, int16_t, int16_t, - AlgoDataType::INT8X8X16, MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( + AlgoInt8x8x16MK4_8x8x4, megdnn_armv7_matmul_kern, "AlgoInt8x8x16MK4_8x8x4"_hash, + armv7::matmul::gemm_s8x8x16_mk4_8x8, int8_t, int16_t, int16_t, + AlgoDataType::INT8X8X16, MK4); /* 
===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ namespace { void kern_int16x16x32K12x4x1(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int16x16x32K12x4x1"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("kern_int16x16x32K12x4x1"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_s16x16x32_12x4 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_s16x16x32_12x4 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -715,10 +660,10 @@ bool MatrixMulImpl::AlgoInt16x16x32K12x4x1::usable( size_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt16x16x32K12x4x1::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt16x16x32K12x4x1::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -741,12 +686,10 @@ bool MatrixMulImpl::AlgoInt16x16x32K12x4x1::preferred( return true; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1, - megdnn_armv7_matmul_kern, - "AlgoInt16x16x32K12x4x1"_hash, - armv7::matmul::gemm_s16x16x32_12x4, - int16_t, int32_t, - AlgoDataType::INT16X16X32, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt16x16x32K12x4x1, megdnn_armv7_matmul_kern, "AlgoInt16x16x32K12x4x1"_hash, + armv7::matmul::gemm_s16x16x32_12x4, int16_t, int32_t, AlgoDataType::INT16X16X32, + DEFAULT); #if MGB_ENABLE_DOT /* ===================== Int8 K6x8x4 algo ===================== */ namespace { @@ -757,14 +700,12 @@ void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); armv7::matmul::gemm_dots8_6x8 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -772,7 +713,7 @@ void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable( const KernSizeParam& kern_size_param) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } return can_be_treated_as_int8x8x32(kern_size_param); @@ -780,10 +721,10 @@ bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable( size_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_workspace( const 
KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x32K6x8x4::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32K6x8x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; @@ -801,31 +742,25 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_kern( return int8_k6x8x4_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x32K6x8x4"_hash, - armv7::matmul::gemm_dots8_6x8, int8_t, - int32_t, AlgoDataType::QINT8X8X32, - DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32K6x8x4, megdnn_armv7_matmul_kern, "AlgoInt8x8x32K6x8x4"_hash, + armv7::matmul::gemm_dots8_6x8, int8_t, int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); /* ===================== Quint8 K4x8x4 algo ===================== */ namespace { void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("quint8_dot_k4x8x4_kern"_hash)) { + MIDOUT_BEGIN(megdnn_armv7_matmul_kern, midout_iv("quint8_dot_k4x8x4_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - armv7::matmul::gemm_dot_quint8_4x8 strategy(M, N, K, A_type, B_type, - C_type); + armv7::matmul::gemm_dot_quint8_4x8 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -833,7 +768,7 @@ void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable( const KernSizeParam& kern_size_param) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && @@ -845,18 +780,16 @@ bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable( size_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoQuint8DotK4x8x4::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoQuint8DotK4x8x4::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - armv7::matmul::gemm_dot_quint8_4x8 strategy(M, N, K, A_type, B_type, - C_type); - return megdnn::matmul::GemmInterleaved< - armv7::matmul::gemm_dot_quint8_4x8>(M, N, K, trA, trB, - strategy) + armv7::matmul::gemm_dot_quint8_4x8 strategy(M, N, K, A_type, B_type, C_type); + return 
megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -867,32 +800,27 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_kern( const KernSizeParam&) const { return quint8_dot_k4x8x4_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4, - megdnn_armv7_matmul_kern, - "AlgoQuint8DotK4x8x4"_hash, - armv7::matmul::gemm_dot_quint8_4x8, - uint8_t, int32_t, - AlgoDataType::QUINT8X8X32, DEFAULT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoQuint8DotK4x8x4, megdnn_armv7_matmul_kern, "AlgoQuint8DotK4x8x4"_hash, + armv7::matmul::gemm_dot_quint8_4x8, uint8_t, int32_t, AlgoDataType::QUINT8X8X32, + DEFAULT); /* ======================== Int8 MK4 8x4x4 dot algo ======================== */ namespace { void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("int8_mk4_8x4x4_dotprod_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, midout_iv("int8_mk4_8x4x4_dotprod_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); - armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, - C_type); + armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } @@ -900,7 +828,7 @@ void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable( const KernSizeParam& kern_size_param) const { - if (!cpuinfo_has_arm_neon_dot()){ + if (!cpuinfo_has_arm_neon_dot()) { return false; } return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && @@ -918,16 +846,13 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace( MIDOUT_BEGIN( megdnn_armv7_matmul_kern, midout_iv("AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, - C_type); - return megdnn::matmul::GemmInterleaved< - armv7::matmul::gemm_mk4_dots8_8x4>(M, N, K, trA, trB, - strategy) + armv7::matmul::gemm_mk4_dots8_8x4 strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -939,11 +864,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern( return int8_mk4_8x4x4_dotprod_kern; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x32MK4_8x4x4DotProd"_hash, - armv7::matmul::gemm_mk4_dots8_8x4, int8_t, - int32_t, AlgoDataType::QINT8X8X32, MK4_DOT); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32MK4_8x4x4DotProd, 
megdnn_armv7_matmul_kern, + "AlgoInt8x8x32MK4_8x4x4DotProd"_hash, armv7::matmul::gemm_mk4_dots8_8x4, int8_t, + int32_t, AlgoDataType::QINT8X8X32, MK4_DOT); #endif /* ===================== F32 algo K4x8 ===================== */ @@ -962,37 +886,33 @@ void f32_mk4_4x8_kern(const MatrixMulImpl::KernParam& kern_param) { armv7::matmul::sgemm_nopack_4x8 strategy(A_type, B_type, C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } } // anonymous namespace -bool MatrixMulImpl::AlgoF32MK4_4x8::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF32MK4_4x8::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.format == param::MatrixMul::Format::MK4 && kern_size_param.B_type == kern_size_param.A_type && kern_size_param.C_type == kern_size_param.A_type && - kern_size_param.A_type == dtype::Float32() && - !kern_size_param.trA && !kern_size_param.trB; + kern_size_param.A_type == dtype::Float32() && !kern_size_param.trA && + !kern_size_param.trB; } size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoF32MK4_4x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, midout_iv("AlgoF32MK4_4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; armv7::matmul::sgemm_nopack_4x8 strategy(A_type, B_type, C_type); - return megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, - strategy) + return megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -1012,23 +932,23 @@ bool MatrixMulImpl::AlgoInt16x16x32MK8_4x8::usable( kern_size_param.format == param::MatrixMul::Format::MK8 && kern_size_param.A_type == dtype::Int16() && kern_size_param.B_type == dtype::Int16() && - kern_size_param.C_type == dtype::Int32() && - !kern_size_param.trA && !kern_size_param.trB; + kern_size_param.C_type == dtype::Int32() && !kern_size_param.trA && + !kern_size_param.trB; } size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt16x16x32MK8_4x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; armv7::matmul::gemm_nopack_s16_4x8 strategy(A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved< - armv7::matmul::gemm_nopack_s16_4x8, false>(M, N, K, trA, - trB, strategy) + armv7::matmul::gemm_nopack_s16_4x8, false>( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -1038,23 +958,21 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace( MatrixMulImpl::kern_t 
MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern( const KernSizeParam&) const { auto kern_mk8_4x8 = [](const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt16x16x32MK8_4x8::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt16x16x32MK8_4x8::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; - const auto Aptr = kern_param.A(), - Bptr = kern_param.B(); + const auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); armv7::matmul::gemm_nopack_s16_4x8 strategy(A_type, B_type, C_type); - megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; @@ -1064,8 +982,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern( #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /* ===================== F16_MK8_4x8 algo ===================== */ -bool MatrixMulImpl::AlgoF16MK8_4x8::usable( - const KernSizeParam& kern_size_param) const { +bool MatrixMulImpl::AlgoF16MK8_4x8::usable(const KernSizeParam& kern_size_param) const { return kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && kern_size_param.C_type == kern_size_param.A_type && kern_size_param.B_type == kern_size_param.A_type && @@ -1076,17 +993,16 @@ bool MatrixMulImpl::AlgoF16MK8_4x8::usable( size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoF16MK8_4x8::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, midout_iv("AlgoF16MK8_4x8::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto trA = kern_size_param.trA, trB = kern_size_param.trB; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, C_type); return megdnn::matmul::GemmInterleaved< - armv7::matmul::gemm_nopack_f16_4x8, false>(M, N, K, trA, - trB, strategy) + armv7::matmul::gemm_nopack_f16_4x8, false>( + M, N, K, trA, trB, strategy) .get_workspace_size(); } MIDOUT_END(); @@ -1096,12 +1012,11 @@ size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace( MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( const KernSizeParam&) const { auto kern_mk8_4x8 = [](const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoF16MK8_4x8::get_kern"_hash)) { + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, midout_iv("AlgoF16MK8_4x8::get_kern"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto trA = kern_param.trA, trB = kern_param.trB; - auto LDA = kern_param.LDA, LDB = kern_param.LDB, - LDC = kern_param.LDC; + auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto A_type = kern_param.A_type, B_type = kern_param.B_type, C_type = kern_param.C_type; const auto Aptr = kern_param.A(), @@ -1109,10 +1024,9 @@ 
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( auto Cptr = kern_param.C(); armv7::matmul::gemm_nopack_f16_4x8 strategy(A_type, B_type, C_type); - megdnn::matmul::GemmInterleaved(M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + megdnn::matmul::GemmInterleaved( + M, N, K, trA, trB, strategy) + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); }; @@ -1124,28 +1038,25 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern( namespace { void kern_int8x8x32_mk4_4x2x16(const MatrixMulImpl::KernParam& kern_param) { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("kern_int8x8x32_mk4_4x2x16"_hash)) { + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, midout_iv("kern_int8x8x32_mk4_4x2x16"_hash)) { auto M = kern_param.M, N = kern_param.N, K = kern_param.K; auto Aptr = kern_param.A(), Bptr = kern_param.B(); auto Cptr = kern_param.C(); auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC; auto trA = kern_param.trA, trB = kern_param.trB; - armv7::matmul::gemm_mk4_s8_4x2 strategy(M, N, K, kern_param.A_type, - kern_param.B_type, - kern_param.C_type); + armv7::matmul::gemm_mk4_s8_4x2 strategy( + M, N, K, kern_param.A_type, kern_param.B_type, kern_param.C_type); megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) - .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, - kern_param.workspace_ptr); + .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); } MIDOUT_END(); } } // anonymous namespace -bool MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::usable( - const KernSizeParam& param) const { +bool MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::usable(const KernSizeParam& param) const { return param.A_type.enumv() == param.B_type.enumv() && (param.A_type.enumv() == DTypeEnum::Int8 || param.A_type.enumv() == DTypeEnum::QuantizedS8) && @@ -1158,10 +1069,10 @@ bool MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::usable( size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_workspace( const KernSizeParam& kern_size_param) const { - MIDOUT_BEGIN(megdnn_armv7_matmul_kern, - midout_iv("AlgoInt8x8x32MK4_4x2x16::get_workspace"_hash)) { - auto M = kern_size_param.M, N = kern_size_param.N, - K = kern_size_param.K; + MIDOUT_BEGIN( + megdnn_armv7_matmul_kern, + midout_iv("AlgoInt8x8x32MK4_4x2x16::get_workspace"_hash)) { + auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K; auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; auto trA = kern_size_param.trA, trB = kern_size_param.trB; @@ -1184,10 +1095,9 @@ bool MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::preferred( return kern_size_param.K > 16; } -MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16, - megdnn_armv7_matmul_kern, - "AlgoInt8x8x32MK4_4x2x16"_hash, - armv7::matmul::gemm_mk4_s8_4x2, int8_t, - int32_t, AlgoDataType::QINT8X8X32, MK4); +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoInt8x8x32MK4_4x2x16, megdnn_armv7_matmul_kern, + "AlgoInt8x8x32MK4_4x2x16"_hash, armv7::matmul::gemm_mk4_s8_4x2, int8_t, int32_t, + AlgoDataType::QINT8X8X32, MK4); // vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 26176770..7ccb9cd8 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -21,9 +21,7 @@ namespace armv7 { class MatrixMulImpl::AlgoF32 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const 
override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_F32"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -35,8 +33,7 @@ public: class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } bool usable(const KernSizeParam&) const override; @@ -48,9 +45,7 @@ public: class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_F32_MK4_4x8"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -63,9 +58,7 @@ public: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class MatrixMulImpl::AlgoF16K4x16x1 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH32_F16_K4X16X1"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -75,9 +68,7 @@ public: }; class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH32_F16_MK8_4X8"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -90,9 +81,7 @@ public: #if MGB_ENABLE_DOT class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH32_INT8_K6X8X4"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -103,9 +92,7 @@ public: class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "AARCH32_QUINT8_K4X8X4"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -116,12 +103,8 @@ public: class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } - const char* name() const override { - return "AARCH32_INT8_MK4_8X4X4_DOTPROD"; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + const char* name() const override { return "AARCH32_INT8_MK4_8X4X4_DOTPROD"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; @@ -130,8 
+113,7 @@ public: }; #endif -class MatrixMulImpl::AlgoF32Gemv final - : public arm_common::MatrixMulImpl::AlgoF32Gemv { +class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv { public: AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { m_handle_type = Handle::HandleType::ARMV7; @@ -141,9 +123,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_INT8X8X32_K4X2X16"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -155,9 +135,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_INT8X8X32_K4X8X8"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -169,9 +147,7 @@ public: class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_QUINT8_K4X8X8"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -182,9 +158,7 @@ public: class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_INT8X8X16_K4X2X16"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -196,9 +170,7 @@ public: class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_INT8X8X16_K4X8X8"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -210,9 +182,7 @@ public: class MatrixMulImpl::AlgoInt8x8x16K8x8x4 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_INT8X8X16_K8X8X4"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -225,8 +195,7 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } bool usable(const KernSizeParam&) const override; @@ -239,9 +208,7 @@ public: class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { public: - 
AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_INT16X16X32_K12X4X1"; } bool usable(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override; @@ -253,9 +220,7 @@ public: class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { public: - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "ARMV7_INT16X16X32_MK8_4X8"; } bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; @@ -268,8 +233,7 @@ public: class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } bool usable(const KernSizeParam&) const override; diff --git a/dnn/src/armv7/matrix_mul/asm/common.h b/dnn/src/armv7/matrix_mul/asm/common.h index 1b665a56..ab654462 100644 --- a/dnn/src/armv7/matrix_mul/asm/common.h +++ b/dnn/src/armv7/matrix_mul/asm/common.h @@ -100,11 +100,9 @@ static inline void prefetch_1x(const void* pfp) { * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[j, i] */ -static inline void interleave_4x1_2_d(const int64_t*& inptr0, - const int64_t*& inptr1, - const int64_t*& inptr2, - const int64_t*& inptr3, - int64_t*& outptr) { +static inline void interleave_4x1_2_d( + const int64_t*& inptr0, const int64_t*& inptr1, const int64_t*& inptr2, + const int64_t*& inptr3, int64_t*& outptr) { asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1 "vld1.32 {d2, d3}, [%[inptr1]]!\n" // B0B1 @@ -119,33 +117,29 @@ static inline void interleave_4x1_2_d(const int64_t*& inptr0, "vst1.32 {d3}, [%[outptr]]!\n" "vst1.32 {d5}, [%[outptr]]!\n" "vst1.32 {d7}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "cc", "memory"); } -static inline void interleave_2x1_4_s(const int32_t*& inptr0, - const int32_t*& inptr1, - int32_t*& outptr) { +static inline void interleave_2x1_4_s( + const int32_t*& inptr0, const int32_t*& inptr1, int32_t*& outptr) { asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 "vld1.32 {d2, d3}, [%[inptr1]]!\n" // A0A1A2A3 "vst1.32 {d0, d1}, [%[outptr]]!\n" "vst1.32 {d2, d3}, [%[outptr]]!\n" - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) : : "d0", "d1", "d2", "d3", "cc", "memory"); } template -static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x8_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, 
"interleave_8x8_1_b only support uint8_t and int8_t"); @@ -167,18 +161,17 @@ static inline void interleave_8x8_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d5}, [%[outptr]]!\n" // F1F2F3F4F5F6F7F8 "vst1.32 {d6}, [%[outptr]]!\n" // G1G2G3G4G5G6G7G8 "vst1.32 {d7}, [%[outptr]]!\n" // H1H2H3H4H5H6H7H8 - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "memory"); } template -static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_4x4_4_b only support uint8_t and int8_t"); @@ -189,35 +182,33 @@ static inline void interleave_4x4_4_b(const T*& inptr0, const T*& inptr1, "vld1.32 {d6, d7}, [%[inptr3]]!\n" // D0D1D2D3 "vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 "vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 - "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 - "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 + "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 + "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 "vst1.32 {d0-d1},[%[outptr]]!\n" "vst1.32 {d2-d3},[%[outptr]]!\n" "vst1.32 {d4-d5},[%[outptr]]!\n" "vst1.32 {d6-d7},[%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "memory"); } template -static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1, - T*& outptr) { +static inline void interleave_2x4_4_b(const T*& inptr0, const T*& inptr1, T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_2x4_4_b only support uint8_t and int8_t"); - interleave_2x1_4_s(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(outptr)); + interleave_2x1_4_s( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(outptr)); } template -static inline void interleave_6x4_4_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - T*& outptr) { +static inline void interleave_6x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_6x4_4_b only support uint8_t and int8_t"); @@ -231,8 +222,8 @@ static inline void interleave_6x4_4_b(const T*& inptr0, const T*& inptr1, "vtrn.32 q0, q1\n" // A0B0A2B2 A1B1A3B3 "vtrn.32 q2, q3\n" // C0D0C2D2 C1D1C3D3 "vtrn.32 q4, q5\n" // E0F0E2F2 E1F1E3F3 - "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 - "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 + "vswp d1, d4 \n" // q0=A0,B0,C0,D0 q2=A2,B2,C2,D2 + "vswp d3, d6 \n" // q1=A1,B1,C1,D1 q3=A3,B3,C3,D3 "vst1.32 {d0-d1},[%[outptr]]!\n" "vst1.32 {d8}, [%[outptr]]!\n" @@ -244,20 +235,18 @@ static inline void interleave_6x4_4_b(const T*& inptr0, const T*& 
inptr1, "vst1.32 {d6-d7},[%[outptr]]!\n" "vst1.32 {d11}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "q4", "q5", "memory"); } template -static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T*& outptr) { +static inline void interleave_8x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_8x4_4_b only support uint8_t and int8_t"); @@ -292,19 +281,17 @@ static inline void interleave_8x4_4_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d6-d7},[%[outptr]]!\n" "vst1.32 {d14-d15},[%[outptr]]!\n" - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "q4", "q5", "memory"); } template -static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - T*& outptr) { +static inline void interleave_6x4_8_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_6x8_4_b only support uint8_t and int8_t"); @@ -342,19 +329,18 @@ static inline void interleave_6x4_8_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d22}, [%[outptr]]! \n" "vst1.32 {d14-d15},[%[outptr]]! \n" "vst1.32 {d23}, [%[outptr]]! 
\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), [outptr] "+r"(outptr) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "cc", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "cc", "memory"); } template -static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x16_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert(sizeof(T) == 1, "only support size == 1"); asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" // d0 = A0A1A2A3 @@ -366,30 +352,29 @@ static inline void interleave_4x16_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d4, d5}, [%[outptr]]!\n" "vst1.32 {d6, d7}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "cc", "memory"); } template -static inline void interleave_4x8_2_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void interleave_4x8_2_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_4x8_2_b only support uint8_t and int8_t"); - interleave_4x1_2_d(reinterpret_cast(inptr0), - reinterpret_cast(inptr1), - reinterpret_cast(inptr2), - reinterpret_cast(inptr3), - reinterpret_cast(outptr)); + interleave_4x1_2_d( + reinterpret_cast(inptr0), + reinterpret_cast(inptr1), + reinterpret_cast(inptr2), + reinterpret_cast(inptr3), + reinterpret_cast(outptr)); } template -static inline void interleave_2x16_1_b(const T*& inptr0, const T*& inptr1, - T*& outptr) { +static inline void interleave_2x16_1_b(const T*& inptr0, const T*& inptr1, T*& outptr) { static_assert(sizeof(T) == 1, "only support size == 2"); asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" @@ -397,18 +382,16 @@ static inline void interleave_2x16_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d0, d1}, [%[outptr]]!\n" "vst1.32 {d2, d3}, [%[outptr]]!\n" - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [outptr] "+r"(outptr) : : "q0", "q1", "cc", "memory"); } template -static inline void interleave_4x4_1_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { - static_assert(sizeof(T) == 2, - "interleave_4x16_1_h only support sizeof(T) == 2"); +static inline void interleave_4x4_1_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 2, "interleave_4x16_1_h only support sizeof(T) == 2"); asm volatile( "vld1.16 {d0}, [%[inptr0]]!\n" "vld1.16 {d1}, [%[inptr1]]!\n" @@ -419,19 +402,17 @@ static inline void interleave_4x4_1_h(const T*& inptr0, const T*& inptr1, "vst1.16 {d1}, [%[outptr]]!\n" "vst1.16 {d2}, [%[outptr]]!\n" "vst1.16 {d3}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] 
"+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "d0", "d1", "d2", "d3", "d4", "memory"); } template -static inline void interleave_4x12_1_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { - static_assert(sizeof(T) == 2, - "interleave_4x12_1_h only support sizeof(T) == 2"); +static inline void interleave_4x12_1_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 2, "interleave_4x12_1_h only support sizeof(T) == 2"); asm volatile( "pld [%[inptr0],#192]\n" "vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 @@ -462,20 +443,18 @@ static inline void interleave_4x12_1_h(const T*& inptr0, const T*& inptr1, "vst1.16 {d9}, [%[outptr]]!\n" // G0G1G2G3 "vst1.16 {d10}, [%[outptr]]!\n" // H0H1H2H3 "vst1.16 {d11}, [%[outptr]]!\n" // H0H1H2H3 - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "memory"); } template -static inline void interleave_4x16_1_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { - static_assert(sizeof(T) == 2, - "interleave_4x16_1_h only support sizeof(T) == 2"); +static inline void interleave_4x16_1_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 2, "interleave_4x16_1_h only support sizeof(T) == 2"); asm volatile( "vld1.16 {d0, d1, d2, d3}, [%[inptr0]]!\n" "vld1.16 {d4, d5, d6, d7}, [%[inptr1]]!\n" @@ -486,20 +465,18 @@ static inline void interleave_4x16_1_h(const T*& inptr0, const T*& inptr1, "vst1.16 {d4, d5, d6, d7}, [%[outptr]]!\n" "vst1.16 {d8, d9, d10, d11}, [%[outptr]]!\n" "vst1.16 {d12, d13, d14, d15}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "memory"); } template -static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { - static_assert(sizeof(T) == 4, - "interleave_4x4_1_s only support sizeof(T) == 4"); +static inline void interleave_4x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_4x4_1_s only support sizeof(T) == 4"); asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 "vld1.32 {d2, d3}, [%[inptr1]]!\n" // A0A1A2A3 @@ -510,17 +487,15 @@ static inline void interleave_4x4_1_s(const T*& inptr0, const T*& inptr1, "vst1.32 {d2, d3}, [%[outptr]]!\n" // E0F0G0H0 "vst1.32 {d4, d5}, [%[outptr]]!\n" // I0J0K0L0 "vst1.32 {d6, d7}, [%[outptr]]!\n" // D0D1D2D3 - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] 
"+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); } template static inline void interleave_1x4_1_h(const T*& inptr0, T*& outptr) { - static_assert(sizeof(T) == 2, - "transpose_1x4_1_h only support sizeof(T) == 2"); + static_assert(sizeof(T) == 2, "transpose_1x4_1_h only support sizeof(T) == 2"); asm volatile( "vld1.16 {d0}, [%[inptr0]]!\n" // A01234567 "vst1.16 {d0}, [%[outptr]]!\n" @@ -530,11 +505,10 @@ static inline void interleave_1x4_1_h(const T*& inptr0, T*& outptr) { } template -static inline void interleave_4x12_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { - static_assert(sizeof(T) == 4, - "interleave_4x12_1_s only support sizeof(T) == 4"); +static inline void interleave_4x12_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { + static_assert(sizeof(T) == 4, "interleave_4x12_1_s only support sizeof(T) == 4"); asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" // A0A1A2A3 "vld1.32 {d2, d3}, [%[inptr0]]!\n" // B0B1B2B3 @@ -561,19 +535,17 @@ static inline void interleave_4x12_1_s(const T*& inptr0, const T*& inptr1, "vst1.32 {d18, d19}, [%[outptr]]!\n" // G0G1G2G3 "vst1.32 {d20, d21}, [%[outptr]]!\n" // H0H1H2H3 "vst1.32 {d22, d23}, [%[outptr]]!\n" // H0H1H2H3 - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "memory"); } template static inline void interleave_1x12_1_h(const T*& inptr0, T*& outptr) { - static_assert(sizeof(T) == 2, - "transpose_1x12_1_h only support sizeof(T) == 2"); + static_assert(sizeof(T) == 2, "transpose_1x12_1_h only support sizeof(T) == 2"); asm volatile( "vld1.16 {d0,d1}, [%[inptr0]]!\n" // A01234567 "vld1.16 {d2} , [%[inptr0]]!\n" // A891011 @@ -586,8 +558,7 @@ static inline void interleave_1x12_1_h(const T*& inptr0, T*& outptr) { template static inline void interleave_1x12_1_s(const T*& inptr0, T*& outptr) { - static_assert(sizeof(T) == 4, - "interleave_1x12_1_s only support sizeof(T) == 4"); + static_assert(sizeof(T) == 4, "interleave_1x12_1_s only support sizeof(T) == 4"); asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" "vld1.32 {d2, d3}, [%[inptr0]]!\n" @@ -602,8 +573,7 @@ static inline void interleave_1x12_1_s(const T*& inptr0, T*& outptr) { template static inline void interleave_1x16_1_h(const T*& inptr0, T*& outptr) { - static_assert(sizeof(T) == 2, - "transpose_1x12_1_h only support sizeof(T) == 2"); + static_assert(sizeof(T) == 2, "transpose_1x12_1_h only support sizeof(T) == 2"); asm volatile( "vld1.16 {d0,d1, d2, d3}, [%[inptr0]]!\n" "vst1.16 {d0,d1, d2, d3}, [%[outptr]]!\n" @@ -614,8 +584,7 @@ static inline void interleave_1x16_1_h(const T*& inptr0, T*& outptr) { template static inline void interleave_1x4_1_s(const T*& inptr0, T*& outptr) { - static_assert(sizeof(T) == 4, - "interleave_1x4_1_s only support sizeof(T) == 4"); + static_assert(sizeof(T) == 4, "interleave_1x4_1_s only support sizeof(T) == 4"); asm volatile( "vld1.32 {d0, d1}, [%[inptr0]]!\n" 
"vst1.32 {d0, d1}, [%[outptr]]\n" @@ -625,8 +594,8 @@ static inline void interleave_1x4_1_s(const T*& inptr0, T*& outptr) { } template -static inline void interleave_helper(const T*& inptr, T*& outptr, int unroll_k, - int ksize, T val = 0) { +static inline void interleave_helper( + const T*& inptr, T*& outptr, int unroll_k, int ksize, T val = 0) { int k = 0; for (; k < ksize; k++) { *outptr++ = *inptr++; @@ -637,8 +606,8 @@ static inline void interleave_helper(const T*& inptr, T*& outptr, int unroll_k, } template -static inline void interleave_1(const T*& inptr0, T*& outptr, int unroll_k, - int ksize, T val = 0) { +static inline void interleave_1( + const T*& inptr0, T*& outptr, int unroll_k, int ksize, T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -646,8 +615,9 @@ static inline void interleave_1(const T*& inptr0, T*& outptr, int unroll_k, } template -static inline void interleave_2(const T*& inptr0, const T*& inptr1, T*& outptr, - int unroll_k, int ksize, T val = 0) { +static inline void interleave_2( + const T*& inptr0, const T*& inptr1, T*& outptr, int unroll_k, int ksize, + T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -656,9 +626,9 @@ static inline void interleave_2(const T*& inptr0, const T*& inptr1, T*& outptr, } template -static inline void interleave_4(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, T*& outptr, - int unroll_k, int ksize, T val = 0) { +static inline void interleave_4( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr, int unroll_k, int ksize, T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -669,10 +639,10 @@ static inline void interleave_4(const T*& inptr0, const T*& inptr1, } template -static inline void interleave_6(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, T*& outptr, - int unroll_k, int ksize, T val = 0) { +static inline void interleave_6( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, T*& outptr, int unroll_k, int ksize, + T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -684,11 +654,10 @@ static inline void interleave_6(const T*& inptr0, const T*& inptr1, } } template -static inline void interleave_8(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, T*& outptr, - int unroll_k, int ksize, T val = 0) { +static inline void interleave_8( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T*& outptr, int unroll_k, int ksize, T val = 0) { for (int k = 0; k < ksize; k += unroll_k) { int size = std::min(unroll_k, ksize - k); interleave_helper(inptr0, outptr, unroll_k, size, val); @@ -713,11 +682,10 @@ static inline void interleave_8(const T*& inptr0, const T*& inptr1, * rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[i, j] */ template -static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const 
T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T* outptr) { +static inline void transpose_8x8_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { static_assert( std::is_same::value || std::is_same::value, "transpose_8x8_1_b only support uint8_t and int8_t"); @@ -754,18 +722,17 @@ static inline void transpose_8x8_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d5}, [%[outptr]]!\n" // A6B6C6D6E6F6G6H6 "vst1.32 {d3}, [%[outptr]]!\n" // A7B7C7D7E7F7G7H7 "vst1.32 {d7}, [%[outptr]]!\n" // A8B8C8D8E8F8G8H8 - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "q3", "cc", "memory"); } template -static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T* outptr) { +static inline void transpose_8x4_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr) { static_assert( std::is_same::value || std::is_same::value, "transpose_8x4_1_b only support uint8_t and int8_t"); @@ -790,21 +757,18 @@ static inline void transpose_8x4_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d2}, [%[outptr]]!\n" "vst1.32 {d1}, [%[outptr]]!\n" "vst1.32 {d3}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "q0", "q1", "memory"); } template -static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - const T*& inptr8, const T*& inptr9, - const T*& inptr10, const T*& inptr11, - int ldin, T*& outptr) { +static inline void transpose_12x4_1_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, const T*& inptr10, const T*& inptr11, + int ldin, T*& outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_12x4_1_h only support uint16_t and int16_t"); @@ -849,14 +813,13 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, "vst1.16 {d3}, [%[outptr]]!\n" // G0G1G2G3 "vst1.16 {d7}, [%[outptr]]!\n" // H0H1H2H3 "vst1.16 {d11}, [%[outptr]]!\n" // H0H1H2H3 - : - [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), - [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), - [inptr9] "+r"(inptr9), [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [inptr8] "+r"(inptr8), + [inptr9] "+r"(inptr9), [outptr] "+r"(outptr) : [ldin_asm] "r"(ldin_asm) - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", 
"memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "memory"); inptr9 -= ldin_asm; inptr9 += 4; inptr10 += 4; @@ -864,11 +827,10 @@ static inline void transpose_12x4_1_h(const T*& inptr0, const T*& inptr1, } template -static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T* outptr) { +static inline void transpose_2x16_1_b_helper( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { static_assert(sizeof(T) == 1, "only support size == 1"); static uint8x8_t shuffle_idx = {0, 2, 4, 6, 1, 3, 5, 7}; asm volatile( @@ -890,21 +852,19 @@ static inline void transpose_2x16_1_b_helper(const T*& inptr0, const T*& inptr1, "vst1.64 d1, [%[outptr]], r0\n" "vst1.64 d3, [%[outptr]]\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), - [outptr] "+r"(outptr), [shuffle_idx] "+w"(shuffle_idx) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr), + [shuffle_idx] "+w"(shuffle_idx) : : "q0", "q1", "q2", "r0", "memory"); } template -static inline void transpose_4x8_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T* outptr) { +static inline void transpose_4x8_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { static uint8x8_t shuffle_idx = {0, 4, 1, 5, 2, 6, 3, 7}; static_assert( std::is_same::value || std::is_same::value, @@ -933,21 +893,19 @@ static inline void transpose_4x8_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d6}, [%[outptr]]!\n" // A2B2C2D2E2F2G2H2 "vst1.32 {d5}, [%[outptr]]!\n" // A3B3C3D3E3F3G3H3 "vst1.32 {d7}, [%[outptr]]!\n" // A4B4C4D4E4F4G4H4 - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), - [outptr] "+r"(outptr), [shuffle_idx] "+w"(shuffle_idx) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr), + [shuffle_idx] "+w"(shuffle_idx) : : "q0", "q1", "q2", "q3", "cc", "memory"); } template -static inline void transpose_4x16_1_b_helper(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - T* outptr) { +static inline void transpose_4x16_1_b_helper( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr) { static_assert(sizeof(T) == 1, "only support size == 1"); static uint8x8_t shuffle_idx = {0, 4, 1, 5, 2, 6, 3, 7}; asm volatile( @@ -976,21 +934,19 @@ static inline void transpose_4x16_1_b_helper(const T*& inptr0, const T*& inptr1, "vst1.64 d5, [%[outptr]], r0\n" "vst1.64 d7, [%[outptr]]\n" 
- : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), - [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), - [outptr] "+r"(outptr), [shuffle_idx] "+w"(shuffle_idx) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr), + [shuffle_idx] "+w"(shuffle_idx) : : "q0", "q1", "q2", "q3", "q4", "r0", "memory"); } template -static inline void transpose_4x4_1_h(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr, int stride = 8) { - static_assert(sizeof(T) == 2, - "transpose_4x4_1_h only support sizeof(T) == 2"); +static inline void transpose_4x4_1_h( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr, int stride = 8) { + static_assert(sizeof(T) == 2, "transpose_4x4_1_h only support sizeof(T) == 2"); asm volatile( "vld1.16 {d0}, [%[inptr0]]!\n" // A0A1A2A3 @@ -999,24 +955,22 @@ static inline void transpose_4x4_1_h(const T*& inptr0, const T*& inptr1, "vld1.16 {d3}, [%[inptr3]]!\n" // D0D1D2D3 "vtrn.16 d0, d1\n" // A0B0A2B2A1B1A3B3 "vtrn.16 d2, d3\n" // C0D0C2D2C1D1C3D3 - "vtrn.32 q0, q1\n" // A0B0C0D0 A1B1C1D1 A2B2C2D2 A3B3C3D3 + "vtrn.32 q0, q1\n" // A0B0C0D0 A1B1C1D1 A2B2C2D2 A3B3C3D3 "vst1.16 {d0}, [%[outptr]], %[stride]\n" // A0B0C0D0 "vst1.16 {d1}, [%[outptr]], %[stride]\n" // A1B1C1D1 "vst1.16 {d2}, [%[outptr]], %[stride]\n" // A2B2C2D2 "vst1.16 {d3}, [%[outptr]], %[stride]\n" // A3B3C3D3 - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : [stride] "r"(stride) : "d0", "d1", "d2", "d3", "memory"); } template -static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr, int stride = 16) { - static_assert(sizeof(T) == 4, - "transpose_4x4_1_s only support sizeof(T) == 4"); +static inline void transpose_4x4_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr, int stride = 16) { + static_assert(sizeof(T) == 4, "transpose_4x4_1_s only support sizeof(T) == 4"); stride -= 8; asm volatile( @@ -1034,19 +988,17 @@ static inline void transpose_4x4_1_s(const T*& inptr0, const T*& inptr1, "vst1.32 {d5}, [%[outptr]], %[stride]\n" "vst1.32 {d3}, [%[outptr]]!\n" "vst1.32 {d7}, [%[outptr]], %[stride]\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr), [stride] "+r"(stride) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr), [stride] "+r"(stride) : : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "memory"); } template -static inline void transpose_4x2_1_s(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T* outptr, int stride = 8) { - static_assert(sizeof(T) == 4, - "transpose_4x2_1_s only support sizeof(T) == 4"); +static inline void transpose_4x2_1_s( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr, int stride = 8) { + static_assert(sizeof(T) == 4, "transpose_4x2_1_s only support sizeof(T) == 4"); stride -= 8; asm volatile( @@ -1060,17 +1012,16 @@ static inline void transpose_4x2_1_s(const T*& inptr0, const 
T*& inptr1, "vst1.32 {d2}, [%[outptr]]!\n" "vst1.32 {d1}, [%[outptr]]!\n" "vst1.32 {d3}, [%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr), [stride] "+r"(stride) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr), [stride] "+r"(stride) : : "d0", "d1", "d2", "d3", "memory"); } template -static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T* outptr) { +static inline void transpose_6x4_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_6x4_1_b only support uint8_t and int8_t"); @@ -1099,17 +1050,16 @@ static inline void transpose_6x4_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d0[1]},[%[outptr]]!\n" "vst1.32 {d1[1]},[%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "memory"); } template -static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T* outptr) { +static inline void transpose_4x4_1_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr) { static_assert( std::is_same::value || std::is_same::value, "interleave_4x4_1_b only support uint8_t and int8_t"); @@ -1135,17 +1085,15 @@ static inline void transpose_4x4_1_b(const T*& inptr0, const T*& inptr1, "vst1.32 {d2[0]},[%[outptr]]!\n" "vst1.32 {d3[0]},[%[outptr]]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) : : "q0", "q1", "q2", "memory"); } template static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { - static_assert(sizeof(T) == 4, - "transpose_1x12_4_s only support sizeof(T) == 4"); + static_assert(sizeof(T) == 4, "transpose_1x12_4_s only support sizeof(T) == 4"); asm volatile( "vld4.32 {d0-d3}, [%[inptr0]]!\n" @@ -1175,14 +1123,13 @@ static inline void transpose_1x12_4_s(const T*& inptr0, T* outptr) { "vst1.32 {d22-d23}, [%[outptr]]! 
\n" : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "memory"); } template static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { - static_assert(sizeof(T) == 4, - "transpose_1x4_4_s only support sizeof(T) == 4"); + static_assert(sizeof(T) == 4, "transpose_1x4_4_s only support sizeof(T) == 4"); asm volatile( "vld4.32 {d0-d3}, [%[inptr0]]!\n" "vld4.32 {d4-d7}, [%[inptr0]]!\n" @@ -1198,9 +1145,9 @@ static inline void transpose_1x4_4_s(const T*& inptr0, T* outptr) { } template -static inline void transpose_4(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, T* outptr, - int interleave, int size, T val = 0) { +static inline void transpose_4( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr, int interleave, int size, T val = 0) { megdnn_assert(size <= interleave); int i = 0; for (; i < size; i++) { @@ -1218,11 +1165,10 @@ static inline void transpose_4(const T*& inptr0, const T*& inptr1, } template -static inline void transpose_8(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, T* outptr, - int interleave, int size, T val = 0) { +static inline void transpose_8( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + T* outptr, int interleave, int size, T val = 0) { megdnn_assert(size <= interleave); int i = 0; for (; i < size; i++) { @@ -1248,9 +1194,9 @@ static inline void transpose_8(const T*& inptr0, const T*& inptr1, } template -static inline void transpose_4x1(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - T*& outptr) { +static inline void transpose_4x1( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T*& outptr) { *outptr++ = *inptr0++; *outptr++ = *inptr1++; *outptr++ = *inptr2++; @@ -1258,13 +1204,11 @@ static inline void transpose_4x1(const T*& inptr0, const T*& inptr1, } template -static inline void transpose_12x1(const T*& inptr0, const T*& inptr1, - const T*& inptr2, const T*& inptr3, - const T*& inptr4, const T*& inptr5, - const T*& inptr6, const T*& inptr7, - const T*& inptr8, const T*& inptr9, - const T*& inptr10, const T*& inptr11, - T*& outptr) { +static inline void transpose_12x1( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + const T*& inptr4, const T*& inptr5, const T*& inptr6, const T*& inptr7, + const T*& inptr8, const T*& inptr9, const T*& inptr10, const T*& inptr11, + T*& outptr) { *outptr++ = *inptr0++; *outptr++ = *inptr1++; *outptr++ = *inptr2++; @@ -1282,13 +1226,11 @@ static inline void transpose_12x1(const T*& inptr0, const T*& inptr1, /***********************************Transpose interleave *************/ //! 
pack form {1, 4(icb), 4(ic), 4(oc)} to {1, 1, 4(oc), 16(ic)} template -static inline void transpose_interleave_4x4_4_b(const T*& inptr0, - const T*& inptr1, - const T*& inptr2, - const T*& inptr3, T* outptr, - int stride = 64) { - static_assert(sizeof(T) == 1, - "transpose_interleave_4x4_4_b only support sizeof(T) == 1"); +static inline void transpose_interleave_4x4_4_b( + const T*& inptr0, const T*& inptr1, const T*& inptr2, const T*& inptr3, + T* outptr, int stride = 64) { + static_assert( + sizeof(T) == 1, "transpose_interleave_4x4_4_b only support sizeof(T) == 1"); asm volatile( "add r1, %[outptr], %[stride]\n" @@ -1338,19 +1280,18 @@ static inline void transpose_interleave_4x4_4_b(const T*& inptr0, "vst1.8 d30,[r3]!\n" "vst1.8 d27,[r3]!\n" "vst1.8 d31,[r3]!\n" - : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), - [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), - [outptr] "+r"(outptr), [stride] "+r"(stride) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr), [stride] "+r"(stride) : - : "r1", "r2", "r3", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q14", "q15", "memory"); + : "r1", "r2", "r3", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q14", "q15", "memory"); } template -static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, - int stride = 64) { - static_assert(sizeof(T) == 1, - "transpose_interleave_1x4_4_b only support sizeof(T) == 1"); +static inline void transpose_interleave_1x4_4_b( + const T*& inptr0, T* outptr, int stride = 64) { + static_assert( + sizeof(T) == 1, "transpose_interleave_1x4_4_b only support sizeof(T) == 1"); asm volatile( "vld4.8 {d0-d3},[%[inptr0]]!\n" @@ -1364,15 +1305,13 @@ static inline void transpose_interleave_1x4_4_b(const T*& inptr0, T* outptr, "vst1.8 d6, [%[outptr]]!\n" "vst1.8 d3, [%[outptr]]!\n" "vst1.8 d7, [%[outptr]]!\n" - : - [inptr0] "+r"(inptr0), [outptr] "+r"(outptr), [stride] "+r"(stride) + : [inptr0] "+r"(inptr0), [outptr] "+r"(outptr), [stride] "+r"(stride) : : "q0", "q1", "q2", "q3", "memory"); } -static inline void interleave_4x4_8x4_s8_s16(const int8_t* inptr0, - const int8_t* inptr1, - int16_t* outptr) { +static inline void interleave_4x4_8x4_s8_s16( + const int8_t* inptr0, const int8_t* inptr1, int16_t* outptr) { int8x16_t row0 = vld1q_s8(inptr0); int16x8_t row0_01 = vmovl_low_s8(row0); int16x8_t row0_23 = vmovl_high_s8(row0); @@ -1406,8 +1345,7 @@ static inline void transpos_8x4_int8(const int8_t* inptr0, int8_t* outptr) { vst1_s8(outptr + 2 * 8, input.val[2]); vst1_s8(outptr + 3 * 8, input.val[3]); } -static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, - int count) { +static inline void memcpy_s8_s16(const int8_t* inptr, int16_t* outptr, int count) { for (; count >= 32; count -= 32) { int8x8_t in0 = vld1_s8(inptr); int8x8_t in1 = vld1_s8(inptr + 1 * 8); diff --git a/dnn/src/armv7/matrix_mul/fp16/strategy.cpp b/dnn/src/armv7/matrix_mul/fp16/strategy.cpp index 63875be5..25540336 100644 --- a/dnn/src/armv7/matrix_mul/fp16/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/fp16/strategy.cpp @@ -10,8 +10,8 @@ */ #include "src/armv7/matrix_mul/fp16/strategy.h" -#include "src/armv7/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" #include "src/common/utils.h" #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -45,8 +45,9 @@ namespace { // +--+--+ - - - - +--------+--------+ // // Accumulator -void 
kern_4x16(const dt_float16* packA, const dt_float16* packB, int K, - dt_float16* output, int LDC, bool is_first_k, int m_remain) { +void kern_4x16( + const dt_float16* packA, const dt_float16* packB, int K, dt_float16* output, + int LDC, bool is_first_k, int m_remain) { const __fp16* a_ptr = reinterpret_cast(packA); const __fp16* b_ptr = reinterpret_cast(packB); int oddk = (K & 1); @@ -83,109 +84,107 @@ void kern_4x16(const dt_float16* packA, const dt_float16* packB, int K, STORE_LINE("18", "19", "20", "21", "2") \ STORE_LINE("22", "23", "24", "25", "3") \ "101:\n" - // clang-format on - - asm volatile( - // load accumulator C - "add r1, r0, %[LDC]\n" - "add r2, r1, %[LDC]\n" - "add r3, r2, %[LDC]\n" - - "cmp %[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "b 2f\n" - - "1:\n" - "veor.32 q5, q5, q5\n" - "veor.32 q6, q6, q6\n" - "veor.32 q7, q7, q7\n" - "veor.32 q8, q8, q8\n" - "veor.32 q9, q9, q9\n" - "veor.32 q10, q10, q10\n" - "veor.32 q11, q11, q11\n" - "veor.32 q12, q12, q12\n" - - "2: \n" - "vld1.16 {d2, d3, d4, d5}, [%[b_ptr]]!\n" - - "cmp %[K], #0\n" - "beq 4f\n" - - "3:\n" - "vld1.16 {d0, d1}, [%[a_ptr]]!\n" - "vld1.16 {d6, d7, d8, d9}, [%[b_ptr]]!\n" - "vmla.f16 q5, q1, d0[0]\n" - "vmla.f16 q6, q2, d0[0]\n" - "vmla.f16 q7, q1, d0[1]\n" - "vmla.f16 q8, q2, d0[1]\n" - "vmla.f16 q9, q1, d0[2]\n" - "vmla.f16 q10, q2, d0[2]\n" - "vmla.f16 q11, q1, d0[3]\n" - "vmla.f16 q12, q2, d0[3]\n" - - "vmla.f16 q5, q3, d1[0]\n" - "vmla.f16 q6, q4, d1[0]\n" - "vmla.f16 q7, q3, d1[1]\n" - "vmla.f16 q8, q4, d1[1]\n" - "vmla.f16 q9, q3, d1[2]\n" - "vmla.f16 q10, q4, d1[2]\n" - "vmla.f16 q11, q3, d1[3]\n" - "vmla.f16 q12, q4, d1[3]\n" - - "vld1.16 {d2, d3, d4, d5}, [%[b_ptr]]!\n" - "subs %[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %[oddk], #1\n" - "beq 5f\n" - - // Even tail - "vld1.16 {d0, d1}, [%[a_ptr]]!\n" - "vld1.16 {d6, d7, d8, d9}, [%[b_ptr]]!\n" - "vmla.f16 q5, q1, d0[0]\n" - "vmla.f16 q6, q2, d0[0]\n" - "vmla.f16 q7, q1, d0[1]\n" - "vmla.f16 q8, q2, d0[1]\n" - "vmla.f16 q9, q1, d0[2]\n" - "vmla.f16 q10, q2, d0[2]\n" - "vmla.f16 q11, q1, d0[3]\n" - "vmla.f16 q12, q2, d0[3]\n" - - "vmla.f16 q5, q3, d1[0]\n" - "vmla.f16 q6, q4, d1[0]\n" - "vmla.f16 q7, q3, d1[1]\n" - "vmla.f16 q8, q4, d1[1]\n" - "vmla.f16 q9, q3, d1[2]\n" - "vmla.f16 q10, q4, d1[2]\n" - "vmla.f16 q11, q3, d1[3]\n" - "vmla.f16 q12, q4, d1[3]\n" - "b 6f\n" - - // odd tail - "5:\n" - "vld1.16 {d0}, [%[a_ptr]]!\n" - "vmla.f16 q5, q1, d0[0]\n" - "vmla.f16 q6, q2, d0[0]\n" - "vmla.f16 q7, q1, d0[1]\n" - "vmla.f16 q8, q2, d0[1]\n" - "vmla.f16 q9, q1, d0[2]\n" - "vmla.f16 q10, q2, d0[2]\n" - "vmla.f16 q11, q1, d0[3]\n" - "vmla.f16 q12, q2, d0[3]\n" - - "6:\n" STORE_C - - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), - [outptr] "+r"(outptr) - : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", - "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", - "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", - "d25", "r1", "r2", "r3", "r10", "cc", "memory"); + // clang-format on + + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.32 q5, q5, q5\n" + "veor.32 q6, q6, q6\n" + "veor.32 q7, q7, q7\n" + "veor.32 q8, q8, q8\n" + "veor.32 q9, q9, q9\n" + "veor.32 q10, q10, q10\n" + "veor.32 q11, q11, q11\n" + "veor.32 q12, q12, q12\n" + + "2: \n" + "vld1.16 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + + 
"cmp %[K], #0\n" + "beq 4f\n" + + "3:\n" + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d6, d7, d8, d9}, [%[b_ptr]]!\n" + "vmla.f16 q5, q1, d0[0]\n" + "vmla.f16 q6, q2, d0[0]\n" + "vmla.f16 q7, q1, d0[1]\n" + "vmla.f16 q8, q2, d0[1]\n" + "vmla.f16 q9, q1, d0[2]\n" + "vmla.f16 q10, q2, d0[2]\n" + "vmla.f16 q11, q1, d0[3]\n" + "vmla.f16 q12, q2, d0[3]\n" + + "vmla.f16 q5, q3, d1[0]\n" + "vmla.f16 q6, q4, d1[0]\n" + "vmla.f16 q7, q3, d1[1]\n" + "vmla.f16 q8, q4, d1[1]\n" + "vmla.f16 q9, q3, d1[2]\n" + "vmla.f16 q10, q4, d1[2]\n" + "vmla.f16 q11, q3, d1[3]\n" + "vmla.f16 q12, q4, d1[3]\n" + + "vld1.16 {d2, d3, d4, d5}, [%[b_ptr]]!\n" + "subs %[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %[oddk], #1\n" + "beq 5f\n" + + // Even tail + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d6, d7, d8, d9}, [%[b_ptr]]!\n" + "vmla.f16 q5, q1, d0[0]\n" + "vmla.f16 q6, q2, d0[0]\n" + "vmla.f16 q7, q1, d0[1]\n" + "vmla.f16 q8, q2, d0[1]\n" + "vmla.f16 q9, q1, d0[2]\n" + "vmla.f16 q10, q2, d0[2]\n" + "vmla.f16 q11, q1, d0[3]\n" + "vmla.f16 q12, q2, d0[3]\n" + + "vmla.f16 q5, q3, d1[0]\n" + "vmla.f16 q6, q4, d1[0]\n" + "vmla.f16 q7, q3, d1[1]\n" + "vmla.f16 q8, q4, d1[1]\n" + "vmla.f16 q9, q3, d1[2]\n" + "vmla.f16 q10, q4, d1[2]\n" + "vmla.f16 q11, q3, d1[3]\n" + "vmla.f16 q12, q4, d1[3]\n" + "b 6f\n" + + // odd tail + "5:\n" + "vld1.16 {d0}, [%[a_ptr]]!\n" + "vmla.f16 q5, q1, d0[0]\n" + "vmla.f16 q6, q2, d0[0]\n" + "vmla.f16 q7, q1, d0[1]\n" + "vmla.f16 q8, q2, d0[1]\n" + "vmla.f16 q9, q1, d0[2]\n" + "vmla.f16 q10, q2, d0[2]\n" + "vmla.f16 q11, q1, d0[3]\n" + "vmla.f16 q12, q2, d0[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "r1", "r2", "r3", "r10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -193,7 +192,6 @@ void kern_4x16(const dt_float16* packA, const dt_float16* packB, int K, #undef STORE_C } - // Overview of register layout: // // A 2x4 cell of Rhs is stored in 16bit in q1 @@ -218,9 +216,9 @@ void kern_4x16(const dt_float16* packA, const dt_float16* packB, int K, // +--+--+ - - - - +--------+ // // Accumulator -void kern_4x4(const dt_float16* packA, const dt_float16* packB, int K, - dt_float16* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +void kern_4x4( + const dt_float16* packA, const dt_float16* packB, int K, dt_float16* output, + int LDC, bool is_first_k, int m_remain, int n_remain) { const __fp16* a_ptr = reinterpret_cast(packA); const __fp16* b_ptr = reinterpret_cast(packB); int oddk = (K & 1); @@ -287,90 +285,91 @@ void kern_4x4(const dt_float16* packA, const dt_float16* packB, int K, STORE_LINE("8", "2") \ STORE_LINE("10", "3") \ "105:\n" - // clang-format on - - asm volatile( - // load accumulator C - "add r1, r0, %[LDC]\n" - "add r2, r1, %[LDC]\n" - "add r3, r2, %[LDC]\n" - - "cmp %[is_first_k], #1\n" - "beq 1f\n" LOAD_C - - "b 2f\n" - - "1:\n" - "veor.32 q2, q2, q2\n" - "veor.32 q3, q3, q3\n" - "veor.32 q4, q4, q4\n" - "veor.32 q5, q5, q5\n" - - "2: \n" - "cmp %[K], #0\n" - "beq 4f\n" - - "3:\n" - "vld1.16 {d0, d1}, [%[a_ptr]]!\n" - "vld1.16 {d2, d3}, [%[b_ptr]]!\n" - "vmla.f16 d4, d2, d0[0]\n" - "vmla.f16 d6, d2, d0[1]\n" - "vmla.f16 d8, d2, d0[2]\n" - "vmla.f16 d10, d2, d0[3]\n" - - "vmla.f16 d4, d3, d1[0]\n" - "vmla.f16 d6, d3, 
d1[1]\n" - "vmla.f16 d8, d3, d1[2]\n" - "vmla.f16 d10, d3, d1[3]\n" - - "subs %[K], #1\n" - "bne 3b\n" - - "4:\n" - "cmp %[oddk], #1\n" - "beq 5f\n" - - // Even tail - "vld1.16 {d0, d1}, [%[a_ptr]]!\n" - "vld1.16 {d2, d3}, [%[b_ptr]]!\n" - "vmla.f16 d4, d2, d0[0]\n" - "vmla.f16 d6, d2, d0[1]\n" - "vmla.f16 d8, d2, d0[2]\n" - "vmla.f16 d10, d2, d0[3]\n" - - "vmla.f16 d4, d3, d1[0]\n" - "vmla.f16 d6, d3, d1[1]\n" - "vmla.f16 d8, d3, d1[2]\n" - "vmla.f16 d10, d3, d1[3]\n" - - "b 6f\n" - - // odd tail - "5:\n" - "vld1.16 {d0}, [%[a_ptr]]!\n" - "vld1.16 {d2}, [%[b_ptr]]!\n" - "vmla.f16 d4, d2, d0[0]\n" - "vmla.f16 d6, d2, d0[1]\n" - "vmla.f16 d8, d2, d0[2]\n" - "vmla.f16 d10, d2, d0[3]\n" - - "6:\n" STORE_C - - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) - : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", - "d9", "d10", "r1", "r2", "r3", "r10", "cc", "memory"); + // clang-format on + + asm volatile( + // load accumulator C + "add r1, r0, %[LDC]\n" + "add r2, r1, %[LDC]\n" + "add r3, r2, %[LDC]\n" + + "cmp %[is_first_k], #1\n" + "beq 1f\n" LOAD_C + + "b 2f\n" + + "1:\n" + "veor.32 q2, q2, q2\n" + "veor.32 q3, q3, q3\n" + "veor.32 q4, q4, q4\n" + "veor.32 q5, q5, q5\n" + + "2: \n" + "cmp %[K], #0\n" + "beq 4f\n" + + "3:\n" + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d2, d3}, [%[b_ptr]]!\n" + "vmla.f16 d4, d2, d0[0]\n" + "vmla.f16 d6, d2, d0[1]\n" + "vmla.f16 d8, d2, d0[2]\n" + "vmla.f16 d10, d2, d0[3]\n" + + "vmla.f16 d4, d3, d1[0]\n" + "vmla.f16 d6, d3, d1[1]\n" + "vmla.f16 d8, d3, d1[2]\n" + "vmla.f16 d10, d3, d1[3]\n" + + "subs %[K], #1\n" + "bne 3b\n" + + "4:\n" + "cmp %[oddk], #1\n" + "beq 5f\n" + + // Even tail + "vld1.16 {d0, d1}, [%[a_ptr]]!\n" + "vld1.16 {d2, d3}, [%[b_ptr]]!\n" + "vmla.f16 d4, d2, d0[0]\n" + "vmla.f16 d6, d2, d0[1]\n" + "vmla.f16 d8, d2, d0[2]\n" + "vmla.f16 d10, d2, d0[3]\n" + + "vmla.f16 d4, d3, d1[0]\n" + "vmla.f16 d6, d3, d1[1]\n" + "vmla.f16 d8, d3, d1[2]\n" + "vmla.f16 d10, d3, d1[3]\n" + + "b 6f\n" + + // odd tail + "5:\n" + "vld1.16 {d0}, [%[a_ptr]]!\n" + "vld1.16 {d2}, [%[b_ptr]]!\n" + "vmla.f16 d4, d2, d0[0]\n" + "vmla.f16 d6, d2, d0[1]\n" + "vmla.f16 d8, d2, d0[2]\n" + "vmla.f16 d10, d2, d0[3]\n" + + "6:\n" STORE_C + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), + [outptr] "+r"(outptr) + : + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "r1", + "r2", "r3", "r10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C } -void hgemm_4x16_pack_A_n(__fp16* outptr, const __fp16* inptr, int ldin, int y0, - int ymax, int k0, int kmax) { +void hgemm_4x16_pack_A_n( + __fp16* outptr, const __fp16* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { __fp16 zerobuff[16]; std::memset(zerobuff, 0, sizeof(__fp16) * 8); @@ -445,8 +444,8 @@ void hgemm_4x16_pack_A_n(__fp16* outptr, const __fp16* inptr, int ldin, int y0, } } -void hgemm_4x16_pack_A_t(__fp16* out, const __fp16* in, int ldin, int x0, - int xmax, int k0, int kmax) { +void hgemm_4x16_pack_A_t( + __fp16* out, const __fp16* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize4 = (ksize << 2); __fp16* outptr_base = reinterpret_cast<__fp16*>(out); @@ -467,8 +466,7 @@ void hgemm_4x16_pack_A_t(__fp16* out, const __fp16* in, int 
ldin, int x0, auto outptr = outptr_base; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_h(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_h(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } @@ -480,8 +478,7 @@ void hgemm_4x16_pack_A_t(__fp16* out, const __fp16* in, int ldin, int x0, } for (; k < kmax; k++) { - const __fp16* inptr = - reinterpret_cast(in + k * ldin + x0); + const __fp16* inptr = reinterpret_cast(in + k * ldin + x0); prefetch_3x(inptr); int x = x0; auto outptr = outptr_base; @@ -497,12 +494,10 @@ void hgemm_4x16_pack_A_t(__fp16* out, const __fp16* in, int ldin, int x0, outptr_base += 4; } - } - -void hgemm_4x16_pack_B_n(__fp16* out, const __fp16* in, int ldin, - int x0, int xmax, int k0, int kmax) { +void hgemm_4x16_pack_B_n( + __fp16* out, const __fp16* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize16 = (ksize << 4); int ksize4 = (ksize << 2); @@ -525,15 +520,13 @@ void hgemm_4x16_pack_B_n(__fp16* out, const __fp16* in, int ldin, auto outptr = outptr_base; for (; x + 16 <= xmax; x += 16) { auto outptr_interleave = outptr; - interleave_4x16_1_h(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x16_1_h(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize16; } outptr = outptr_base4; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_h(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_h(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } @@ -546,8 +539,7 @@ void hgemm_4x16_pack_B_n(__fp16* out, const __fp16* in, int ldin, } for (; k < kmax; k++) { - const __fp16* inptr = - reinterpret_cast(in + k * ldin + x0); + const __fp16* inptr = reinterpret_cast(in + k * ldin + x0); prefetch_3x(inptr); int x = x0; auto outptr = outptr_base; @@ -572,8 +564,8 @@ void hgemm_4x16_pack_B_n(__fp16* out, const __fp16* in, int ldin, } } -void hgemm_4x16_pack_B_t(__fp16* out, const __fp16* in, int ldin, - int y0, int ymax, int k0, int kmax) { +void hgemm_4x16_pack_B_t( + __fp16* out, const __fp16* in, int ldin, int y0, int ymax, int k0, int kmax) { __fp16* outptr = out; const __fp16* inptr = in; __fp16 zerobuff[16]; @@ -598,8 +590,7 @@ void hgemm_4x16_pack_B_t(__fp16* out, const __fp16* in, int ldin, int x = (kmax - k0); for (; x > 3; x -= 4) { - transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr_inner, - 32); + transpose_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr_inner, 32); } for (; x > 0; x--) { *outptr_inner++ = *inptr0++; @@ -669,38 +660,41 @@ void hgemm_4x16_pack_B_t(__fp16* out, const __fp16* in, int ldin, MEGDNN_REG_GEMM_STRATEGY_IMPL(hgemm_4x16); -void hgemm_4x16::pack_A(dt_float16* out, const dt_float16* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose_A) const { +void hgemm_4x16::pack_A( + dt_float16* out, const dt_float16* in, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose_A) const { if (transpose_A) { - hgemm_4x16_pack_A_t(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, y0, ymax, - k0, kmax); + hgemm_4x16_pack_A_t( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, y0, ymax, k0, kmax); } else { - hgemm_4x16_pack_A_n(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, y0, ymax, - k0, kmax); + hgemm_4x16_pack_A_n( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, y0, ymax, k0, kmax); } } -void hgemm_4x16::pack_B(dt_float16* out, const dt_float16* in, int ldin, int x0, - int 
xmax, int k0, int kmax, bool transpose_B) const { +void hgemm_4x16::pack_B( + dt_float16* out, const dt_float16* in, int ldin, int x0, int xmax, int k0, + int kmax, bool transpose_B) const { if (transpose_B) { - hgemm_4x16_pack_B_t(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, x0, xmax, - k0, kmax); + hgemm_4x16_pack_B_t( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, x0, xmax, k0, kmax); } else { - hgemm_4x16_pack_B_n(reinterpret_cast<__fp16*>(out), - reinterpret_cast(in), ldin, x0, xmax, - k0, kmax); + hgemm_4x16_pack_B_n( + reinterpret_cast<__fp16*>(out), reinterpret_cast(in), + ldin, x0, xmax, k0, kmax); } } -void hgemm_4x16::kern(const dt_float16* packA, const dt_float16* packB, - size_t M, size_t N, size_t K, dt_float16* C, size_t LDC, - bool is_first_k, const dt_float16*, dt_float16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == C_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Float16); +void hgemm_4x16::kern( + const dt_float16* packA, const dt_float16* packB, size_t M, size_t N, size_t K, + dt_float16* C, size_t LDC, bool is_first_k, const dt_float16*, + dt_float16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float16); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -717,15 +711,17 @@ void hgemm_4x16::kern(const dt_float16* packA, const dt_float16* packB, size_t n = 0; const dt_float16* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - kern_4x16(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + kern_4x16( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K16; } for (; n < N; n += 4) { - kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), std::min(N - n, 4)); + kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } diff --git a/dnn/src/armv7/matrix_mul/fp16/strategy.h b/dnn/src/armv7/matrix_mul/fp16/strategy.h index f25eeb5d..a45def79 100644 --- a/dnn/src/armv7/matrix_mul/fp16/strategy.h +++ b/dnn/src/armv7/matrix_mul/fp16/strategy.h @@ -16,11 +16,11 @@ namespace megdnn { namespace armv7 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 4, 16, 1, false, - true, hgemm_4x16); +MEGDNN_REG_GEMM_STRATEGY( + dt_float16, dt_float16, dt_float16, 4, 16, 1, false, true, hgemm_4x16); -MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 4, 8, 1, - false, true, gemm_nopack_f16_4x8); +MEGDNN_REG_GEMM_STRATEGY_NOPACK( + dt_float16, dt_float16, dt_float16, 4, 8, 1, false, true, gemm_nopack_f16_4x8); } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp b/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp index d4c951da..9c19173b 100644 --- a/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp +++ b/dnn/src/armv7/matrix_mul/fp16/strategy_mk8_4x8.cpp @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
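To make the inline-assembly micro-kernels above easier to follow, here is a minimal scalar sketch of what a 4x16 kernel such as kern_4x16 computes over the packed panels. The panel layout, the use of float instead of dt_float16, and the name ref_kern_4x16 are illustrative assumptions rather than MegDNN's actual API; the real kernel also unrolls K by two and handles an odd tail separately.

void ref_kern_4x16(const float* packA, const float* packB, int K, float* output,
                   int LDC, bool is_first_k, int m_remain) {
    // acc models the q5..q12 accumulator tile: 4 rows x 16 columns.
    float acc[4][16] = {};
    if (!is_first_k) {
        // Not the first K slice: start from the partial results already in C.
        for (int m = 0; m < m_remain; ++m)
            for (int n = 0; n < 16; ++n)
                acc[m][n] = output[m * LDC + n];
    }
    for (int k = 0; k < K; ++k) {
        // One k step: 4 values from packed A times 16 values from packed B.
        const float* a = packA + k * 4;
        const float* b = packB + k * 16;
        for (int m = 0; m < 4; ++m)
            for (int n = 0; n < 16; ++n)
                acc[m][n] += a[m] * b[n];
    }
    // Only the first m_remain rows are valid when M is not a multiple of 4.
    for (int m = 0; m < m_remain; ++m)
        for (int n = 0; n < 16; ++n)
            output[m * LDC + n] = acc[m][n];
}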
*/ -#include "src/armv7/matrix_mul/fp16/strategy.h" -#include "src/armv7/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/fp16/strategy.h" #include "src/common/utils.h" #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -21,8 +21,9 @@ using namespace armv7::matmul; namespace { -void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, - dt_float16* output) { +void kern_8x1( + const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { LDB = (LDB - 4) * sizeof(dt_float16); asm volatile( "subs %[K], #8\n" @@ -76,9 +77,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", - "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", - "d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); + : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", + "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", + "d27", "d28", "d29", "d30", "d31", "cc", "memory"); } // Overview of register layout: @@ -105,8 +106,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, // | v3[0-7]| |v15[0-7]| // +--------+ +--------+--------+ // Accumulator -void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, - dt_float16* output) { +void kern_8x4( + const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, + dt_float16* output) { //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 //! here. LDB = (LDB - 16) * sizeof(dt_float16); @@ -221,21 +223,20 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "cc", "memory"); } } // anonymous namespace MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_4x8); -void gemm_nopack_f16_4x8::kern(const dt_float16* A, size_t LDA, - const dt_float16* B, size_t LDB, dt_float16* C, - size_t LDC, size_t M, size_t K, size_t N, - const dt_float16*, void*, bool trA, - bool trB) const { +void gemm_nopack_f16_4x8::kern( + const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C, + size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA, + bool trB) const { constexpr static size_t MB = 8; constexpr static size_t KB = 8; constexpr static size_t NB = 4; diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy.h b/dnn/src/armv7/matrix_mul/fp32/strategy.h index 7bba0a83..16c170c4 100644 --- a/dnn/src/armv7/matrix_mul/fp32/strategy.h +++ b/dnn/src/armv7/matrix_mul/fp32/strategy.h @@ -15,14 +15,13 @@ namespace megdnn { namespace armv7 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, - sgemm_4x12); +MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, true, 
sgemm_4x12); -MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 12, 1, false, false, - sgemm_mk4_pack_4x12); +MEGDNN_REG_GEMM_STRATEGY( + float, float, float, 4, 12, 1, false, false, sgemm_mk4_pack_4x12); -MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 8, 1, false, true, - sgemm_nopack_4x8); +MEGDNN_REG_GEMM_STRATEGY_NOPACK( + float, float, float, 4, 8, 1, false, true, sgemm_nopack_4x8); } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp b/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp index c8c00643..ec21f210 100644 --- a/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp @@ -42,8 +42,9 @@ namespace { // +--+ - - - - +--------+--------+--------+ // // Accumulator -void kern_4x12(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int m_remain) { +void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -208,15 +209,14 @@ void kern_4x12(const float* packA, const float* packB, int K, float* output, "6:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "r2", "r3", "r9", "r10", "cc", - "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "r2", "r3", "r9", "r10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -248,8 +248,9 @@ void kern_4x12(const float* packA, const float* packB, int K, float* output, // +--+--+ - - - - +--------+ // // Accumulator -void kern_4x4(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int m_remain, int n_remain) { +void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { const float* a_ptr = packA; const float* b_ptr = packB; int oddk = (K & 1); @@ -386,22 +387,22 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, "6:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "r1", "r2", "r3", "r10", "cc", - "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "r1", "r2", "r3", "r10", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C } -void sgemm_4x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0, - int ymax, int k0, int kmax) { +void sgemm_4x12_pack_A_n( 
+ float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { float zerobuff[4]; std::memset(zerobuff, 0, sizeof(float) * 4); @@ -480,8 +481,8 @@ void sgemm_4x12_pack_A_n(float* outptr, const float* inptr, int ldin, int y0, } } -void sgemm_4x12_pack_A_t(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax) { +void sgemm_4x12_pack_A_t( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize4 = (ksize << 2); float* outptr_base = out; @@ -502,8 +503,7 @@ void sgemm_4x12_pack_A_t(float* out, const float* in, int ldin, int x0, auto outptr = outptr_base; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } @@ -533,8 +533,8 @@ void sgemm_4x12_pack_A_t(float* out, const float* in, int ldin, int x0, } } -void sgemm_4x12_pack_B_n(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax) { +void sgemm_4x12_pack_B_n( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { int ksize = kmax - k0; int ksize12 = ksize * 12; int ksize4 = (ksize << 2); @@ -557,15 +557,13 @@ void sgemm_4x12_pack_B_n(float* out, const float* in, int ldin, int x0, auto outptr = outptr_base; for (; x + 12 <= xmax; x += 12) { auto outptr_interleave = outptr; - interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize12; } outptr = outptr_base4; for (; x + 4 <= xmax; x += 4) { auto outptr_interleave = outptr; - interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } @@ -603,8 +601,8 @@ void sgemm_4x12_pack_B_n(float* out, const float* in, int ldin, int x0, } } -void sgemm_4x12_pack_B_t(float* out, const float* in, int ldin, int y0, - int ymax, int k0, int kmax) { +void sgemm_4x12_pack_B_t( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { float* outptr = out; const float* inptr = in; float zerobuff[4]; @@ -629,8 +627,7 @@ void sgemm_4x12_pack_B_t(float* out, const float* in, int ldin, int y0, int x = (kmax - k0); for (; x > 3; x -= 4) { - transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, - 48); + transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 48); } for (; x > 0; x--) { *outptr_inner++ = *inptr0++; @@ -704,8 +701,9 @@ void sgemm_4x12_pack_B_t(float* out, const float* in, int ldin, int y0, MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x12); -void sgemm_4x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, - int k0, int kmax, bool transpose_A) const { +void sgemm_4x12::pack_A( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose_A) const { if (transpose_A) { sgemm_4x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); } else { @@ -713,8 +711,9 @@ void sgemm_4x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, } } -void sgemm_4x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, - int k0, int kmax, bool transpose_B) const { +void sgemm_4x12::pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose_B) const { if (transpose_B) { sgemm_4x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { @@ -722,12 +721,12 @@ void 
sgemm_4x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, } } -void sgemm_4x12::kern(const float* packA, const float* packB, size_t M, - size_t N, size_t K, float* C, size_t LDC, bool is_first_k, - const float*, float*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == C_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Float32); +void sgemm_4x12::kern( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k, const float*, float*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); MEGDNN_MARK_USED_VAR(C_dtype); @@ -744,15 +743,17 @@ void sgemm_4x12::kern(const float* packA, const float* packB, size_t M, size_t n = 0; const float* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + kern_4x12( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K12; } for (; n < N; n += 4) { - kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), std::min(N - n, 4)); + kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp b/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp index e7fcf67f..6dc6f290 100644 --- a/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_mk4_4x8.cpp @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
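The pack_B routines above rearrange B into 12-wide column panels (with a 4-wide padded tail) so the 4x12 kernel can stream them linearly. A rough sketch of that layout, under the assumption that each panel stores its K rows of 12 consecutive columns back to back, follows; pack_B_12 is a hypothetical helper for illustration, not the library's function.

#include <vector>

// Illustrative packing of B (K x N, row stride ldin) into 12-wide panels:
// all full 12-column panels come first, each stored k-major, then the
// remaining columns are packed as zero-padded 4-wide panels.
std::vector<float> pack_B_12(const float* B, int ldin, int K, int N) {
    std::vector<float> out;
    int n = 0;
    for (; n + 12 <= N; n += 12)
        for (int k = 0; k < K; ++k)
            out.insert(out.end(), B + k * ldin + n, B + k * ldin + n + 12);
    for (; n < N; n += 4)  // ragged edge: 4-wide panels, padded with zeros
        for (int k = 0; k < K; ++k)
            for (int j = 0; j < 4; ++j)
                out.push_back(n + j < N ? B[k * ldin + n + j] : 0.0f);
    return out;
}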
*/ -#include "src/armv7/matrix_mul/fp32/strategy.h" -#include "src/armv7/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/fp32/strategy.h" #include "src/common/utils.h" using namespace megdnn; @@ -65,11 +65,10 @@ void kern_4x1(const float* A, const float* B, size_t LDB, size_t K, float* C) { "vst1.32 {d16, d17}, [%[C]]!\n" - : [ A ] "+r"(A), [ B ] "+r"(B), [ K ] "+r"(K), [ C ] "+r"(C) - : [ LDB ] "r"(LDB) - : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", - "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "cc", - "memory"); + : [A] "+r"(A), [B] "+r"(B), [K] "+r"(K), [C] "+r"(C) + : [LDB] "r"(LDB) + : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", + "d17", "d18", "d19", "d20", "d21", "d22", "d23", "cc", "memory"); } // Overview of register layout: @@ -171,9 +170,9 @@ void kern_4x4(const float* A, const float* B, size_t LDB, size_t K, float* C) { : [A] "+r"(A), [B] "+r"(B), [K] "+r"(K), [C] "+r"(C) : [LDB] "r"(LDB) - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "cc", "memory"); } // Overview of register layout: @@ -303,20 +302,19 @@ void kern_4x8(const float* A, const float* B, size_t LDB, size_t K, float* C) { "vst1.32 {d28, d29, d30, d31}, [%[C]]!\n" : [A] "+r"(A), [B] "+r"(B), [K] "+r"(K), [C] "+r"(C) : [LDB] "r"(LDB) - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "cc", "memory"); } } // namespace MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x8); -void sgemm_nopack_4x8::kern(const float* A, size_t LDA, const float* B, - size_t LDB, float* C, size_t LDC, size_t M, - size_t K, size_t N, const float*, void*, bool trA, - bool trB) const { +void sgemm_nopack_4x8::kern( + const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, + size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { constexpr size_t MB = 4; constexpr size_t KB = 4; constexpr size_t NB = 8; diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp b/dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp index 233eb302..1908c730 100644 --- a/dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_mk_4x12.cpp @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "src/armv7/matrix_mul/fp32/strategy.h" -#include "src/armv7/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/fp32/strategy.h" #include "src/common/utils.h" using namespace megdnn; @@ -42,8 +42,9 @@ namespace { // +--+ - - - - +--------+--------+--------+ // // Accumulator -void kern_4x12(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k) { +void kern_4x12( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -193,17 +194,13 @@ void kern_4x12(const float* packA, const float* packB, int K, float* output, "vst1.32 {d28-d31}, [%[output0]]!\n" "6:\n" - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), - [ output0 ] "+r"(output0) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output0] "+r"(output0) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "q14", "q15", "r1", "cc", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r1", "cc", "memory"); } - - - // Overview of register layout: // // A 2x4 cell of Rhs is stored in 32bit in v2 - v3 @@ -227,8 +224,9 @@ void kern_4x12(const float* packA, const float* packB, int K, float* output, // +--+ --- - +--------+ // // Accumulator -void kern_4x4(const float* packA, const float* packB, int K, float* output, - int LDC, bool is_first_k, int n_remain) { +void kern_4x4( + const float* packA, const float* packB, int K, float* output, int LDC, + bool is_first_k, int n_remain) { MEGDNN_MARK_USED_VAR(LDC); const float* a_ptr = packA; const float* b_ptr = packB; @@ -278,7 +276,7 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, "23:\n" \ "vst1.32 {d8-d9}, [%[output]]!\n" \ "24:\n" -//clang-format on + //clang-format on asm volatile( "cmp %[is_first_k], #1\n" @@ -344,12 +342,11 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, "vmla.f32 q7, q0, d5[1]\n" "6:\n" STORE_C - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ is_first_k ] "+r"(is_first_k), [ oddk ] "+r"(oddk), - [ output ] "+r"(output), [ n_remain ] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), + [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), [output] "+r"(output), + [n_remain] "+r"(n_remain) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "r1", "cc", - "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "r1", "cc", "memory"); #undef LOAD_C #undef STORE_C } @@ -359,8 +356,9 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_pack_4x12); //! Now no matmul mode of only packB support in conv1x1 and im2col, so just copy //! 
the weight -void sgemm_mk4_pack_4x12::pack_A(float* out, const float* in, int ldin, int y0, - int ymax, int k0, int kmax, bool) const { +void sgemm_mk4_pack_4x12::pack_A( + float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, + bool) const { megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int PACK_C_SIZE = 4; @@ -372,9 +370,9 @@ void sgemm_mk4_pack_4x12::pack_A(float* out, const float* in, int ldin, int y0, } } -void sgemm_mk4_pack_4x12::pack_B(float* out, const float* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose_B) const { +void sgemm_mk4_pack_4x12::pack_B( + float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose_B) const { megdnn_assert(!transpose_B); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); float tmpbuff[16] = {0.0f}; @@ -416,12 +414,12 @@ void sgemm_mk4_pack_4x12::pack_B(float* out, const float* in, int ldin, int x0, } } -void sgemm_mk4_pack_4x12::kern(const float* packA, const float* packB, size_t M, - size_t N, size_t K, float* C, size_t LDC, - bool is_first_k, const float*, float*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == C_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Float32); +void sgemm_mk4_pack_4x12::kern( + const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, + size_t LDC, bool is_first_k, const float*, float*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Float32); constexpr int PACK_C_SIZE = 4; constexpr size_t A_INTERLEAVE = 4; constexpr size_t B_INTERLEAVE = 12; @@ -439,8 +437,9 @@ void sgemm_mk4_pack_4x12::kern(const float* packA, const float* packB, size_t M, cur_packB += K12; } for (; n < N; n += 4) { - kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 4)); + kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 4)); output += PACK_C_SIZE * 4; cur_packB += K4; } diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h b/dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h index ac4e0236..6b3ff057 100644 --- a/dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h +++ b/dnn/src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h @@ -45,8 +45,9 @@ namespace matmul_12x4x1 { * * Accumulator */ -static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_12x4( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -184,14 +185,14 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, [outptr_row4] "+r"(outptr_row4), [outptr_row6] "+r"(outptr_row6), [outptr_row8] "+r"(outptr_row8), [outptr_row10] "+r"(outptr_row10) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", - "d12", "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", - "d24", "d26", "d28", "d30", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", "d12", + "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", "d24", "d26", + "d28", "d30", "cc", "memory"); } -static void kern_12x123(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t n_remain) { +static void kern_12x123( + const int16_t* packA, const 
int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t n_remain) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; int asmLDC = LDC * sizeof(int32_t); @@ -525,14 +526,15 @@ static void kern_12x123(const int16_t* packA, const int16_t* packB, int K, [outptr_row6] "+r"(outptr_row6), [outptr_row8] "+r"(outptr_row8), [outptr_row10] "+r"(outptr_row10) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", - "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", - "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", - "d30", "d31", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", "d12", + "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", "d22", + "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", + "memory"); } -static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_4x4( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -574,14 +576,14 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, [outptr_row1] "+r"(outptr_row1), [outptr_row2] "+r"(outptr_row2), [outptr_row3] "+r"(outptr_row3) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", - "d12", "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", - "d24", "d26", "d28", "d30", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", "d12", + "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", "d24", "d26", + "d28", "d30", "cc", "memory"); } -static void kern_4x123(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - int n_remain) { +static void kern_4x123( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain) { const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -674,13 +676,14 @@ static void kern_4x123(const int16_t* packA, const int16_t* packB, int K, [outptr_row0] "+r"(outptr_row0), [outptr_row1] "+r"(outptr_row1), [outptr_row2] "+r"(outptr_row2), [outptr_row3] "+r"(outptr_row3) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", - "d12", "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", - "d24", "d26", "d28", "d30", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d8", "d9", "d10", "d11", "d12", + "d13", "d14", "d15", "d16", "d18", "d20", "d21", "d22", "d24", "d26", + "d28", "d30", "cc", "memory"); } -static void kern_1x4(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k) { +static void kern_1x4( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k) { MEGDNN_MARK_USED_VAR(LDC); const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -713,9 +716,9 @@ static void kern_1x4(const int16_t* packA, const int16_t* packB, int K, *this kern can hanle 1xk mul kx1 kx2 kx3 get 1x1 1x2 1x3 *123 stands for n remain 1 2 3 ************************************************/ -static void kern_1x123(const int16_t* packA, const int16_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - int n_remain) { +static void kern_1x123( + const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain) { MEGDNN_MARK_USED_VAR(LDC); const int16_t* a_ptr = packA; const int16_t* b_ptr = packB; @@ -771,9 
+774,9 @@ static void kern_1x123(const int16_t* packA, const int16_t* packB, int K, : "d0", "d3", "d8", "d9", "cc", "memory"); } -static void gemm_s16x16x32_12x4_pack_A_n(dt_int16* outptr, - const dt_int16* inptr, int ldin, - int y0, int ymax, int k0, int kmax) { +static void gemm_s16x16x32_12x4_pack_A_n( + dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int y = y0; int K = kmax - k0; for (; y + 11 < ymax; y += 12) { @@ -792,15 +795,15 @@ static void gemm_s16x16x32_12x4_pack_A_n(dt_int16* outptr, int k = k0; for (; k + 3 < kmax; k += 4) { - transpose_12x4_1_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, - ldin, outptr); + transpose_12x4_1_h( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + inptr8, inptr9, inptr10, inptr11, ldin, outptr); } for (; k < kmax; k++) { - transpose_12x1(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, - outptr); + transpose_12x1( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + inptr8, inptr9, inptr10, inptr11, outptr); } } @@ -825,10 +828,9 @@ static void gemm_s16x16x32_12x4_pack_A_n(dt_int16* outptr, } } -static void gemm_s16x16x32_12x4_transpose_pack_A_n(dt_int16* out, - const dt_int16* in, int ldin, - int x0, int xmax, int k0, - int kmax) { +static void gemm_s16x16x32_12x4_transpose_pack_A_n( + dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, + int kmax) { const int ksize = kmax - k0; const int ksize12 = ksize * 12; const int ksize4 = ksize * 4; @@ -836,8 +838,7 @@ static void gemm_s16x16x32_12x4_transpose_pack_A_n(dt_int16* out, int16_t* outptr_interleave = out; int16_t* outptr_base = out; int16_t* outptr_times4_base = out + (xmax - x0) / 12 * ksize12; - int16_t* outptr_times1_base = - outptr_times4_base + ((xmax - x0) % 12) / 4 * ksize4; + int16_t* outptr_times1_base = outptr_times4_base + ((xmax - x0) % 12) / 4 * ksize4; int k = k0; for (; k + 3 < kmax; k += 4) { const int16_t* inptr0 = in + k * ldin + x0; @@ -851,15 +852,13 @@ static void gemm_s16x16x32_12x4_transpose_pack_A_n(dt_int16* out, for (; x + 11 < xmax; x += 12) { outptr_interleave = outptr; - interleave_4x12_1_h(inptr0, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x12_1_h(inptr0, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize12; } outptr = outptr_times4_base; for (; x + 3 < xmax; x += 4) { outptr_interleave = outptr; - interleave_4x4_1_h(inptr0, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } @@ -904,9 +903,9 @@ static void gemm_s16x16x32_12x4_transpose_pack_A_n(dt_int16* out, } } -static void gemm_s16x16x32_12x4_pack_B_n(dt_int16* out, const dt_int16* in, - int ldin, int x0, int xmax, int k0, - int kmax) { +static void gemm_s16x16x32_12x4_pack_B_n( + dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, + int kmax) { const int ksize = kmax - k0; const int ksize4 = ksize * 4; int16_t* outptr = out; @@ -923,14 +922,13 @@ static void gemm_s16x16x32_12x4_pack_B_n(dt_int16* out, const dt_int16* in, outptr = outptr_base; for (; x + 3 < xmax; x += 4) { outptr_interleave = outptr; - interleave_4x4_1_h(inptr0, inptr1, inptr2, inptr3, - outptr_interleave); + interleave_4x4_1_h(inptr0, inptr1, inptr2, inptr3, outptr_interleave); outptr += ksize4; } if (x < xmax) { outptr_interleave = outptr; - interleave_4(inptr0, inptr1, inptr2, inptr3, outptr_interleave, 4, - xmax 
- x); + interleave_4( + inptr0, inptr1, inptr2, inptr3, outptr_interleave, 4, xmax - x); outptr += ksize4; } outptr_base += 4 * 4; @@ -958,10 +956,9 @@ static void gemm_s16x16x32_12x4_pack_B_n(dt_int16* out, const dt_int16* in, } } -static void gemm_s16x16x32_12x4_transpose_pack_B_n(dt_int16* outptr, - const dt_int16* inptr, - int ldin, int y0, int ymax, - int k0, int kmax) { +static void gemm_s16x16x32_12x4_transpose_pack_B_n( + dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int K = kmax - k0; int y = y0; int16_t* out = outptr; diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp index f5b470bc..de4b7b36 100644 --- a/dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.cpp @@ -9,10 +9,10 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "src/armv7/matrix_mul/int16x16x32/strategy.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/armv7/matrix_mul/asm/common.h" #include "src/armv7/matrix_mul/int16x16x32/kernel_12x4x1.h" -#include "src/armv7/matrix_mul/int16x16x32/strategy.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_common.h" @@ -23,39 +23,36 @@ using namespace armv7::matmul; // ===========================gemm_s16x16x32_4x4================================= MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16x16x32_12x4); -void gemm_s16x16x32_12x4::pack_A(dt_int16* out, const dt_int16* in, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s16x16x32_12x4::pack_A( + dt_int16* out, const dt_int16* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_12x4x1::gemm_s16x16x32_12x4_transpose_pack_A_n(out, in, ldin, y0, - ymax, k0, kmax); + matmul_12x4x1::gemm_s16x16x32_12x4_transpose_pack_A_n( + out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_12x4x1::gemm_s16x16x32_12x4_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_12x4x1::gemm_s16x16x32_12x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s16x16x32_12x4::pack_B(dt_int16* out, const dt_int16* in, int ldin, - int x0, int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s16x16x32_12x4::pack_B( + dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_12x4x1::gemm_s16x16x32_12x4_transpose_pack_B_n(out, in, ldin, x0, - xmax, k0, kmax); + matmul_12x4x1::gemm_s16x16x32_12x4_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_12x4x1::gemm_s16x16x32_12x4_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_12x4x1::gemm_s16x16x32_12x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s16x16x32_12x4::kern(const dt_int16* packA, const dt_int16* packB, - size_t M, size_t N, size_t K, dt_int32* C, - size_t LDC, bool is_first_k, const dt_int32*, - dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int16 && - C_dtype.enumv() == DTypeEnum::Int32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s16x16x32_12x4::kern( + const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int16 && + C_dtype.enumv() == 
DTypeEnum::Int32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -73,18 +70,16 @@ void gemm_s16x16x32_12x4::kern(const dt_int16* packA, const dt_int16* packB, size_t n = 0; const dt_int16* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_12x4x1::kern_12x4(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_12x4x1::kern_12x4(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K4; } - if (n < N ){ - matmul_12x4x1::kern_12x123(packA, cur_packB, K, output, LDC, - is_first_k, (N-n)); - output += (N-n); + if (n < N) { + matmul_12x4x1::kern_12x123( + packA, cur_packB, K, output, LDC, is_first_k, (N - n)); + output += (N - n); cur_packB += K4; - } packA += K12; @@ -96,16 +91,15 @@ void gemm_s16x16x32_12x4::kern(const dt_int16* packA, const dt_int16* packB, size_t n = 0; const dt_int16* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_12x4x1::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_12x4x1::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K4; } - if (n < N){ + if (n < N) { int remain = N - n; - matmul_12x4x1::kern_4x123(packA, cur_packB, K, output, LDC, - is_first_k,remain); + matmul_12x4x1::kern_4x123( + packA, cur_packB, K, output, LDC, is_first_k, remain); output += remain; cur_packB += K4; } @@ -118,17 +112,16 @@ void gemm_s16x16x32_12x4::kern(const dt_int16* packA, const dt_int16* packB, size_t n = 0; const dt_int16* cur_packB = packB; - for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_12x4x1::kern_1x4(packA, cur_packB, K, output, LDC, - is_first_k); + for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { + matmul_12x4x1::kern_1x4(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K4; } if (n < N) { int remain = N - n; - matmul_12x4x1::kern_1x123(packA, cur_packB, K, output, LDC, - is_first_k,remain); + matmul_12x4x1::kern_1x123( + packA, cur_packB, K, output, LDC, is_first_k, remain); output += remain; cur_packB += K4; } diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/strategy.h b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.h index 6a2b0a1f..bafba13c 100644 --- a/dnn/src/armv7/matrix_mul/int16x16x32/strategy.h +++ b/dnn/src/armv7/matrix_mul/int16x16x32/strategy.h @@ -15,11 +15,11 @@ namespace megdnn { namespace armv7 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(int16_t, int32_t, int32_t, 12, 4, 1, false, true, - gemm_s16x16x32_12x4); +MEGDNN_REG_GEMM_STRATEGY( + int16_t, int32_t, int32_t, 12, 4, 1, false, true, gemm_s16x16x32_12x4); -MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 4, 8, 1, false, - true, gemm_nopack_s16_4x8); +MEGDNN_REG_GEMM_STRATEGY_NOPACK( + dt_int16, dt_int32, dt_int32, 4, 8, 1, false, true, gemm_nopack_s16_4x8); } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp b/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp index 8982a269..e4361aa2 100644 --- a/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp +++ b/dnn/src/armv7/matrix_mul/int16x16x32/strategy_mk8_4x8.cpp @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
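The s16x16x32 driver above walks A in 12-, 4- and 1-row tiers and uses the *_123 kernels for a 1 to 3 column remainder of B. A scalar model of that remainder computation, with the int16 to int32 widening accumulate made explicit, might look as follows; the template, the helper name, and the assumption that packed B keeps a stride of 4 per k step are illustrative only.

#include <cstdint>

// Scalar model of the *_123 remainder kernels: ROWS packed int16 rows of A
// times n_remain (1..3) columns of a 4-wide, zero-padded int16 B panel,
// accumulating into int32 C. Layout and names are assumptions for clarity.
template <int ROWS>
void ref_kern_s16_remainder(const int16_t* packA, const int16_t* packB, int K,
                            int32_t* output, int LDC, bool is_first_k,
                            int n_remain) {
    for (int m = 0; m < ROWS; ++m)
        for (int n = 0; n < n_remain; ++n) {
            int32_t acc = is_first_k ? 0 : output[m * LDC + n];
            for (int k = 0; k < K; ++k)
                acc += int32_t(packA[k * ROWS + m]) * int32_t(packB[k * 4 + n]);
            output[m * LDC + n] = acc;
        }
}
// ref_kern_s16_remainder<12> mirrors kern_12x123, <4> mirrors kern_4x123,
// and <1> mirrors kern_1x123.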
*/ -#include "src/armv7/matrix_mul/int16x16x32/strategy.h" -#include "src/armv7/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" +#include "src/armv7/matrix_mul/asm/common.h" +#include "src/armv7/matrix_mul/int16x16x32/strategy.h" #include "src/common/utils.h" using namespace megdnn; @@ -20,8 +20,9 @@ using namespace armv7::matmul; namespace { -void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, - dt_int32* output) { +void kern_8x1( + const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { //! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 //! here. LDB = (LDB - 4) * sizeof(dt_int16); @@ -100,9 +101,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", - "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", - "d25", "d26", "d27", "d28", "d29", "d30", "d31", "cc", "memory"); + : "d0", "d1", "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", + "d17", "d18", "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", + "d27", "d28", "d29", "d30", "d31", "cc", "memory"); } // Overview of register layout: @@ -125,8 +126,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, // | q3[0-7]| |q14[0-3]|v15[0-3]| // +--------+ +--------+--------+ // Accumulator -void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, - dt_int32* output) { +void kern_8x4( + const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, + dt_int32* output) { //! As each load 16 number from B, but the pos add 16 * 2, so we minus 16 //! here. 
LDB = (LDB - 16) * sizeof(dt_int16); @@ -315,20 +317,20 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [output] "+r"(output), [LDB] "+r"(LDB) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "cc", "memory"); } } // anonymous namespace MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_4x8); -void gemm_nopack_s16_4x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, - size_t LDB, dt_int32* C, size_t LDC, size_t M, - size_t K, size_t N, const dt_int32*, void*, - bool trA, bool trB) const { +void gemm_nopack_s16_4x8::kern( + const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C, + size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA, + bool trB) const { constexpr static size_t MB = 8; constexpr static size_t KB = 8; constexpr static size_t NB = 4; diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h b/dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h index d06fa584..ce3772b9 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_4x2x16.h @@ -64,9 +64,9 @@ namespace matmul_4x2x16 { * Accumulator */ -static void kern_4x2(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_4x2( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { MEGDNN_MARK_USED_VAR(m_remain); MEGDNN_MARK_USED_VAR(n_remain); K /= 16; diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h b/dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h index e664df95..a58f3ed1 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_4x8x8.h @@ -17,9 +17,9 @@ namespace megdnn { namespace armv7 { namespace matmul_4x8x8 { -static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t m_remain) { +static void kern_4x8( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -183,14 +183,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, "bne 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [x0] "+r"(x0), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), + [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", 
"d26", "d27", "d28", "d29", "d30", "d31", + "r1", "r2", "r3", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -198,9 +198,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -365,13 +365,12 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "3:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), - [x0] "+r"(x0), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain) + [x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -379,8 +378,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8_4x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, int kmax) { +static void gemm_s8_4x8_pack_A_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -438,9 +438,8 @@ static void gemm_s8_4x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_s8_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, int k0, - int kmax) { +static void gemm_s8_4x8_transpose_pack_A_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -500,8 +499,9 @@ static void gemm_s8_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, } } - transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_4x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += ksize4; } @@ -534,16 +534,17 @@ static void gemm_s8_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, xmax - x); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, xmax - x); } outptr_base += 4 * 8; } } -static void gemm_s8_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8_4x8_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -605,8 +606,9 @@ static void gemm_s8_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } outptr_interleave = outptr; - interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, 
outptr_interleave); + interleave_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr += ksize8; } @@ -641,8 +643,9 @@ static void gemm_s8_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, 4); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, 4); outptr += ksize4; } @@ -676,8 +679,9 @@ static void gemm_s8_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, xmax - x); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, xmax - x); } outptr_base += 8 * 8; @@ -685,9 +689,9 @@ static void gemm_s8_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8_4x8_transpose_pack_B_n(dt_int8* outptr, - const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, int kmax) { +static void gemm_s8_4x8_transpose_pack_B_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); constexpr int interleave4 = 32; @@ -715,14 +719,16 @@ static void gemm_s8_4x8_transpose_pack_B_n(dt_int8* outptr, int K = kmax - k0; for (; K > 7; K -= 8) { - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += interleave8; } if (K > 0) { - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K); outptr += interleave8; } } diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h b/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h index 0b7de2e7..da41ebbf 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h @@ -44,9 +44,9 @@ namespace matmul_dot_6x8x4 { // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t m_remain = 6) { +static void kern_6x8( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain = 6) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -235,15 +235,13 @@ static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, "vsdot.s8 q13, q3, d3[0]\n" "vsdot.s8 q15, q3, d3[1]\n" STORE_C - : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), - [m_remain] "+r"(m_remain), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), - [outptr5] "+r"(outptr5) + : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), + [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), + [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "q14", "q15", "r12", "cc", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", 
"q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r12", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -276,9 +274,9 @@ static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_6x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, - size_t n_remain = 8, size_t m_remain = 6) { +static void kern_6x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t n_remain = 8, size_t m_remain = 6) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -466,13 +464,12 @@ static void kern_6x4(const int8_t* packA, const int8_t* packB, int K, : [k] "+r"(K), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), - [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), - [outptr3] "+r"(outptr3), [outptr4] "+r"(outptr4), - [outptr5] "+r"(outptr5), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain) + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3), + [outptr4] "+r"(outptr4), [outptr5] "+r"(outptr5), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "q14", "q15", "cc", "r12", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "cc", "r12", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -480,8 +477,9 @@ static void kern_6x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8_6x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, int kmax) { +static void gemm_s8_6x8_pack_A_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -518,8 +516,7 @@ static void gemm_s8_6x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, megdnn_assert(0); } } - interleave_6x4_8_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - outptr); + interleave_6x4_8_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr); } for (; K > 15; K -= 16) { if (y + 5 >= ymax) { @@ -539,8 +536,7 @@ static void gemm_s8_6x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, megdnn_assert(0); } } - interleave_6x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - outptr); + interleave_6x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr); } if (K > 0) { if (y + 5 >= ymax) { @@ -560,14 +556,13 @@ static void gemm_s8_6x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, megdnn_assert(0); } } - interleave_6(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr, - 4, K); + interleave_6(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, outptr, 4, K); } } } -static void gemm_s8_6x8_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8_6x8_pack_A_t( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -628,8 +623,8 @@ static void gemm_s8_6x8_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8_6x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8_6x8_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; 
std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -715,8 +710,9 @@ static void gemm_s8_6x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8_6x8_pack_B_t(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, int kmax) { +static void gemm_s8_6x8_pack_B_t( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -743,12 +739,14 @@ static void gemm_s8_6x8_pack_B_t(dt_int8* outptr, const dt_int8* inptr, int K = kmax - k0; //! read 12 * 4 in each row for (; K > 15; K -= 16) { - interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x4_4_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, K); } } for (; y < ymax; y += 4) { diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h index a29970d1..10985f91 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_4x2x16.h @@ -64,8 +64,9 @@ namespace matmul_mk4_4x2x16 { * Accumulator */ -static void kern_4x2(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, bool is_first_k, int n_remain) { +static void kern_4x2( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, + bool is_first_k, int n_remain) { MEGDNN_MARK_USED_VAR(n_remain); K /= 16; const int8_t* a_ptr = packA; @@ -201,24 +202,25 @@ static void kern_4x2(const int8_t* packA, const int8_t* packB, int K, "vstr d8, [%[outptr]]\n" "vstr d9, [%[outptr], #8]\n" - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [is_first_k] "+r"(is_first_k), [K] "+r"(K), [outptr] "+r"(output), - [n_remain] "+r"(n_remain) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), + [K] "+r"(K), [outptr] "+r"(output), [n_remain] "+r"(n_remain) : - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q11", "q12", "q13", "q14", "q15"); } -static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_mk4_s8_4x2_pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { //! 
pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} int8_t zerobuff[4][64]; std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); - megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, - "mk4 matmul with m is not times of 4"); - megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, - "mk4 matmul with k is not times of 4"); + megdnn_assert( + ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, + "mk4 matmul with m is not times of 4"); + megdnn_assert( + kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, + "mk4 matmul with k is not times of 4"); size_t roundk = round_up(kmax - k0, 16); size_t out_offset = roundk * 4; int y = y0; @@ -235,8 +237,8 @@ static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr, prefetch_2x(inptr3); int K = kmax - k0; for (; K > 15; K -= 16) { - transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, - out_offset); + transpose_interleave_4x4_4_b( + inptr0, inptr1, inptr2, inptr3, output, out_offset); output += 64; } if (K > 0) { @@ -248,8 +250,8 @@ static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr, inptr1 = zerobuff[1]; inptr2 = zerobuff[2]; inptr3 = zerobuff[3]; - transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, - out_offset); + transpose_interleave_4x4_4_b( + inptr0, inptr1, inptr2, inptr3, output, out_offset); output += 64; } } @@ -271,21 +273,21 @@ static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_mk4_s8_4x2_pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_mk4_s8_4x2_pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int32_t zerobuff[4]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; const int ICB = (ksize) / 4; const int ksize2 = round_up(ICB, 4) * 2; int32_t* outptr = reinterpret_cast(out); - megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, - "mk4 matmul with k is not times of 4"); + megdnn_assert( + kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, + "mk4 matmul with k is not times of 4"); int k = k0 / 4; for (; k + 3 < ICB; k += 4) { - const int32_t* inptr0 = - reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr0 = reinterpret_cast(in + k * ldin + x0); const int32_t* inptr1 = reinterpret_cast(in + (k + 1) * ldin + x0); const int32_t* inptr2 = @@ -308,8 +310,7 @@ static void gemm_mk4_s8_4x2_pack_B(dt_int8* out, const dt_int8* in, int ldin, outptr += 4 * 2; } if (k < ICB) { - const int32_t* inptr0 = - reinterpret_cast(in + k * ldin + x0); + const int32_t* inptr0 = reinterpret_cast(in + k * ldin + x0); const int32_t* inptr1 = reinterpret_cast(in + (k + 1) * ldin + x0); const int32_t* inptr2 = diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h index d28ec559..d7cf8c4a 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h @@ -43,8 +43,9 @@ namespace matmul_mk4_dot_8x4x4 { // +--------+ // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int n_remain) { +static void kern_8x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -184,8 +185,8 @@ static void kern_8x4(const int8_t* 
packA, const int8_t* packB, int K, [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), [x0] "+r"(x0) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q14", "cc", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q14", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -212,8 +213,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, int n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, int n_remain) { K /= 4; const int32_t* a_ptr = reinterpret_cast(packA); const int32_t* b_ptr = reinterpret_cast(packB); @@ -324,9 +326,8 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "4:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), - [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), - [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k), - [x0] "+r"(x0) + [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC), + [outptr0] "+r"(outptr0), [k] "+r"(k), [x0] "+r"(x0) : : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "cc", "memory"); @@ -336,9 +337,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_dots8_8x4_pack_A(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_dots8_8x4_pack_A( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int y = y0, y_start = y0 / 4; for (; y + 7 < ymax; y += 8, y_start += 2) { const int8_t* inptr0 = inptr + y_start * ldin + k0 * 4; @@ -357,8 +358,8 @@ static void gemm_dots8_8x4_pack_A(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_dots8_8x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_dots8_8x4_pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { const int ksize = kmax - k0; const int ksize4 = ksize * 4; int8_t* outptr = out; diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.cpp b/dnn/src/armv7/matrix_mul/int8/strategy.cpp index 69be5248..2d84f6aa 100644 --- a/dnn/src/armv7/matrix_mul/int8/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/int8/strategy.cpp @@ -28,8 +28,9 @@ MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x2); // ===========================gemm_s8_4x2====================================== -void gemm_s8_4x2::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose) const { +void gemm_s8_4x2::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { matmul_4x2x16::gemm_s8_4x2_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); } else { @@ -37,8 +38,9 @@ void gemm_s8_4x2::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, } } -void gemm_s8_4x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_s8_4x2::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { matmul_4x2x16::gemm_s8_4x2_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { @@ -46,16 +48,16 @@ void 
gemm_s8_4x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, } } -void gemm_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8_4x2::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -75,15 +77,15 @@ void gemm_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, - is_first_k, 4, 2); + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, is_first_k, 4, 2); output += B_INTERLEAVE; cur_packB += K2; } for (; n < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, - is_first_k, 4, std::min(N - n, 2)); + matmul_4x2x16::kern_4x2( + packA, cur_packB, K, output, LDC, is_first_k, 4, + std::min(N - n, 2)); output += B_INTERLEAVE; cur_packB += K2; } @@ -97,9 +99,9 @@ void gemm_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t n = 0; const dt_int8* cur_packB = packB; for (; n < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, - is_first_k, std::min(M - m, 4), - std::min(N - n, 2)); + matmul_4x2x16::kern_4x2( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 2)); output += B_INTERLEAVE; cur_packB += K2; } @@ -110,36 +112,36 @@ void gemm_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, // ===========================gemm_s8_4x4====================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x8); -void gemm_s8_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose) const { +void gemm_s8_4x8::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_4x8x8::gemm_s8_4x8_transpose_pack_A_n(out, in, ldin, y0, ymax, - k0, kmax); + matmul_4x8x8::gemm_s8_4x8_transpose_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } else { matmul_4x8x8::gemm_s8_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s8_4x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_s8_4x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_4x8x8::gemm_s8_4x8_transpose_pack_B_n(out, in, ldin, x0, xmax, - k0, kmax); + matmul_4x8x8::gemm_s8_4x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } else { matmul_4x8x8::gemm_s8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void 
gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8_4x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -158,16 +160,17 @@ void gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + matmul_4x8x8::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_4x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4)); + matmul_4x8x8::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -178,32 +181,30 @@ void gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, #if MGB_ENABLE_DOT // ===========================gemm_s8_6x8====================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dots8_6x8); -void gemm_dots8_6x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose) const { +void gemm_dots8_6x8::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_dot_6x8x4::gemm_s8_6x8_pack_A_t(out, in, ldin, y0, ymax, k0, - kmax); + matmul_dot_6x8x4::gemm_s8_6x8_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_dot_6x8x4::gemm_s8_6x8_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_dot_6x8x4::gemm_s8_6x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_dots8_6x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_dots8_6x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_dot_6x8x4::gemm_s8_6x8_pack_B_t(out, in, ldin, x0, xmax, k0, - kmax); + matmul_dot_6x8x4::gemm_s8_6x8_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_dot_6x8x4::gemm_s8_6x8_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_dot_6x8x4::gemm_s8_6x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32* bias, - dt_int32* workspace) const { +void gemm_dots8_6x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t 
LDC, bool is_first_k, const dt_int32* bias, + dt_int32* workspace) const { MEGDNN_MARK_USED_VAR(bias); constexpr size_t A_INTERLEAVE = 6; constexpr size_t B_INTERLEAVE = 8; @@ -219,15 +220,14 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_dot_6x8x4::kern_6x8(packA, cur_packB, K, output, LDC, - is_first_k); + matmul_dot_6x8x4::kern_6x8(packA, cur_packB, K, output, LDC, is_first_k); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_dot_6x8x4::kern_6x4(packA, cur_packB, K, output, LDC, - is_first_k, n_remain); + matmul_dot_6x8x4::kern_6x4( + packA, cur_packB, K, output, LDC, is_first_k, n_remain); output += n_remain; cur_packB += K4; } @@ -239,15 +239,15 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t m_remain = std::min(M - m, 6); size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_dot_6x8x4::kern_6x8(packA, cur_packB, K, output, LDC, - is_first_k, m_remain); + matmul_dot_6x8x4::kern_6x8( + packA, cur_packB, K, output, LDC, is_first_k, m_remain); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_dot_6x8x4::kern_6x4(packA, cur_packB, K, output, LDC, - is_first_k, n_remain, m_remain); + matmul_dot_6x8x4::kern_6x4( + packA, cur_packB, K, output, LDC, is_first_k, n_remain, m_remain); output += n_remain; cur_packB += K4; } @@ -257,34 +257,35 @@ void gemm_dots8_6x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, // ===========================gemm_mk4_dots8_8x4====================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_dots8_8x4); -void gemm_mk4_dots8_8x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { - megdnn_assert(!transpose, - "matrix mul mk4 with transposed matrix A is not supported."); - megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, - "mk4 format matmul with m is not times of 4."); - megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, - "mk4 format matmul with k is not times of 4."); - matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_A(out, in, ldin, y0, ymax, k0, - kmax); +void gemm_mk4_dots8_8x4::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { + megdnn_assert( + !transpose, "matrix mul mk4 with transposed matrix A is not supported."); + megdnn_assert( + ymax % 4 == 0 && y0 % 4 == 0, + "mk4 format matmul with m is not times of 4."); + megdnn_assert( + kmax % 4 == 0 && k0 % 4 == 0, + "mk4 format matmul with k is not times of 4."); + matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_A(out, in, ldin, y0, ymax, k0, kmax); } -void gemm_mk4_dots8_8x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax, - bool transpose) const { - megdnn_assert(!transpose, - "matrix mul mk4 with transposed matrix B is not supported"); - megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, - "mk4 format matmul with k is not times of 4."); - matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_B(out, in, ldin, x0, xmax, k0, - kmax); +void gemm_mk4_dots8_8x4::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { + megdnn_assert( + !transpose, "matrix mul mk4 with transposed matrix B is not supported"); + megdnn_assert( + kmax % 4 == 0 && k0 % 4 == 0, + "mk4 format 
matmul with k is not times of 4."); + matmul_mk4_dot_8x4x4::gemm_dots8_8x4_pack_B(out, in, ldin, x0, xmax, k0, kmax); } -void gemm_mk4_dots8_8x4::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int32* C, - size_t LDC, bool is_first_k, const dt_int32* bias, - dt_int32* workspace) const { +void gemm_mk4_dots8_8x4::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32* bias, + dt_int32* workspace) const { MEGDNN_MARK_USED_VAR(bias); constexpr size_t A_INTERLEAVE = 8; //! K is packed to times of 4 @@ -298,8 +299,8 @@ void gemm_mk4_dots8_8x4::kern(const dt_int8* packA, const dt_int8* packB, const dt_int8* cur_packB = packB; for (size_t n = 0; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_mk4_dot_8x4x4::kern_8x4(packA, cur_packB, K, output, LDC, - is_first_k, n_remain); + matmul_mk4_dot_8x4x4::kern_8x4( + packA, cur_packB, K, output, LDC, is_first_k, n_remain); output += 16; cur_packB += K4; } @@ -310,8 +311,8 @@ void gemm_mk4_dots8_8x4::kern(const dt_int8* packA, const dt_int8* packB, const dt_int8* cur_packB = packB; for (size_t n = 0; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_mk4_dot_8x4x4::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, n_remain); + matmul_mk4_dot_8x4x4::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, n_remain); output += 16; cur_packB += K4; } @@ -323,30 +324,30 @@ void gemm_mk4_dots8_8x4::kern(const dt_int8* packA, const dt_int8* packB, // ===========================gemm_mk4_s8_4x2====================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x2); -void gemm_mk4_s8_4x2::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, bool transpose) const { +void gemm_mk4_s8_4x2::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { megdnn_assert(!transpose); - matmul_mk4_4x2x16::gemm_mk4_s8_4x2_pack_A(out, in, ldin, y0, ymax, k0, - kmax); + matmul_mk4_4x2x16::gemm_mk4_s8_4x2_pack_A(out, in, ldin, y0, ymax, k0, kmax); } -void gemm_mk4_s8_4x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_mk4_s8_4x2::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { megdnn_assert(!transpose); - matmul_mk4_4x2x16::gemm_mk4_s8_4x2_pack_B(out, in, ldin, x0, xmax, k0, - kmax); + matmul_mk4_4x2x16::gemm_mk4_s8_4x2_pack_B(out, in, ldin, x0, xmax, k0, kmax); } -void gemm_mk4_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int32) || - (A_dtype.enumv() == DTypeEnum::QuantizedS8 && - C_dtype.enumv() == DTypeEnum::QuantizedS32)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_mk4_s8_4x2::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int32) || + (A_dtype.enumv() == DTypeEnum::QuantizedS8 && + C_dtype.enumv() == DTypeEnum::QuantizedS32)), + "A: %s B: %s C: %s", 
A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -367,8 +368,9 @@ void gemm_mk4_s8_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t M, size_t n = 0; const dt_int8* cur_packB = packB; for (; n < N; n += B_INTERLEAVE) { - matmul_mk4_4x2x16::kern_4x2(packA, cur_packB, K, output, is_first_k, - std::min(N - n, 2)); + matmul_mk4_4x2x16::kern_4x2( + packA, cur_packB, K, output, is_first_k, + std::min(N - n, 2)); output += B_INTERLEAVE * 4; cur_packB += K2; } diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.h b/dnn/src/armv7/matrix_mul/int8/strategy.h index 66ddcdd3..e76ff43f 100644 --- a/dnn/src/armv7/matrix_mul/int8/strategy.h +++ b/dnn/src/armv7/matrix_mul/int8/strategy.h @@ -15,20 +15,20 @@ namespace megdnn { namespace armv7 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, true, - gemm_s8_4x2); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 4, 2, 16, false, true, gemm_s8_4x2); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 8, 8, false, true, - gemm_s8_4x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 4, 8, 8, false, true, gemm_s8_4x8); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, - gemm_mk4_s8_4x2); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, gemm_mk4_s8_4x2); #if MGB_ENABLE_DOT -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, - gemm_dots8_6x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, gemm_dots8_6x8); -MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 4, 4, false, false, - gemm_mk4_dots8_8x4); +MEGDNN_REG_GEMM_STRATEGY( + dt_int8, dt_int32, dt_int32, 8, 4, 4, false, false, gemm_mk4_dots8_8x4); #endif } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h index dffbcf09..8df761ec 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x2x16.h @@ -46,9 +46,9 @@ namespace matmul_4x2x16 { * Accumulator */ -static void kern_4x2(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int m_remain, - int n_remain) { +static void kern_4x2( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int m_remain, int n_remain) { MEGDNN_MARK_USED_VAR(m_remain); MEGDNN_MARK_USED_VAR(n_remain); K /= 16; diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h index 1cd6bdf2..2b42d4ca 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_4x8x8.h @@ -38,9 +38,9 @@ namespace matmul_4x8x8 { * * Accumulator */ -static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, - size_t m_remain) { +static void kern_4x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, size_t m_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -166,14 +166,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, "bne 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) + : [a_ptr] "+r"(a_ptr), 
[b_ptr] "+r"(b_ptr), [K] "+r"(K), [x0] "+r"(x0), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), + [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "r2", "r3", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -202,9 +202,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, * * Accumulator */ -static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain) { +static void kern_4x4( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain) { K /= 8; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -369,13 +369,12 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, "3:\n" STORE_C : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), - [x0] "+r"(x0), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain) + [x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -383,9 +382,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, #undef STORE_C } -static void gemm_s8x8x16_4x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8x8x16_4x8_pack_A_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -443,9 +442,8 @@ static void gemm_s8x8x16_4x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, } } -static void gemm_s8x8x16_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, - int ldin, int x0, int xmax, - int k0, int kmax) { +static void gemm_s8x8x16_4x8_transpose_pack_A_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); @@ -505,8 +503,9 @@ static void gemm_s8x8x16_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, } } - transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_4x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += ksize4; } @@ -539,16 +538,17 @@ static void gemm_s8x8x16_4x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, xmax - x); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, 
inptr6, inptr7, + outptr, 4, xmax - x); } outptr_base += 4 * 8; } } -static void gemm_s8x8x16_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8x8x16_4x8_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); const int ksize = kmax - k0; @@ -610,8 +610,9 @@ static void gemm_s8x8x16_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } outptr_interleave = outptr; - interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr += ksize8; } @@ -646,8 +647,9 @@ static void gemm_s8x8x16_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, 4); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, 4); outptr += ksize4; } @@ -681,8 +683,9 @@ static void gemm_s8x8x16_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, xmax - x); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, xmax - x); } outptr_base += 8 * 8; @@ -690,10 +693,9 @@ static void gemm_s8x8x16_4x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8x8x16_4x8_transpose_pack_B_n(dt_int8* outptr, - const dt_int8* inptr, int ldin, - int y0, int ymax, int k0, - int kmax) { +static void gemm_s8x8x16_4x8_transpose_pack_B_n( + dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); constexpr int interleave4 = 32; @@ -721,14 +723,16 @@ static void gemm_s8x8x16_4x8_transpose_pack_B_n(dt_int8* outptr, int K = kmax - k0; for (; K > 7; K -= 8) { - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += interleave8; } if (K > 0) { - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K); outptr += interleave8; } } diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h index 0550e978..b0ba5c76 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_8x8x4.h @@ -38,9 +38,9 @@ namespace matmul_8x8x4 { * */ -static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, - size_t n_remain) { +static void kern_8x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, size_t n_remain) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -185,12 +185,11 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, "bne 2b\n" "3:\n" STORE_C - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ LDC ] "+r"(LDC), [ is_first_k ] "+r"(is_first_k), - [ outptr ] "+r"(outptr), [ nr ] "+r"(nr) + : [a_ptr] 
"+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), [nr] "+r"(nr) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "q14", "q15", "r1", "r2", "cc", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14", "q15", "r1", "r2", "cc", "memory"); #undef LOAD_C #undef STORE_LINE #undef STORE_C @@ -213,9 +212,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, * +----+-----------------+ +--------+---------------------------------+ * */ -static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain) { +static void kern_4x8( + const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain) { K /= 4; const int8_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -341,20 +340,20 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, "bne 2b\n" "3:\n" STORE_C - : [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), [ K ] "+r"(K), - [ LDC ] "+r"(LDC), [ is_first_k ] "+r"(is_first_k), - [ outptr ] "+r"(outptr), [ mr ] "+r"(mr), [ nr ] "+r"(nr) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), + [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), [mr] "+r"(mr), + [nr] "+r"(nr) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "r1", - "r2", "cc", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "r1", "r2", + "cc", "memory"); #undef LOAD_C #undef STORE_LINE #undef STORE_C } -static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* out, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8x8x16_8x8_pack_A_n( + dt_int8* out, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); int8_t* outptr = out; @@ -380,14 +379,16 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* out, const dt_int8* inptr, int K = kmax - k0; for (; K > 15; K -= 16) { - interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x4_4_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { for (; K > 0; K -= 4) - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr, 4, std::min(K, 4)); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, std::min(K, 4)); } } for (; y < ymax; y += 4) { @@ -418,14 +419,13 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* out, const dt_int8* inptr, megdnn_assert(0); } } - interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, - std::min(K, 4)); + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 4, std::min(K, 4)); } } } -static void gemm_s8x8x16_8x8_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8x8x16_8x8_pack_A_t( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); int8_t* outbase = out; @@ -494,8 +494,8 @@ static void gemm_s8x8x16_8x8_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_s8x8x16_8x8_pack_B_n( + 
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); int8_t* outbase = out; @@ -541,9 +541,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, "vst1.32 {d1}, [%[out_interleave]]!\n" "vst1.32 {d2}, [%[out_interleave]]!\n" "vst1.32 {d3}, [%[out_interleave]]!\n" - : [ inptr0 ] "+r"(inptr0), [ inptr1 ] "+r"(inptr1), - [ inptr2 ] "+r"(inptr2), [ inptr3 ] "+r"(inptr3), - [ out_interleave ] "+r"(out_interleave) + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), + [inptr2] "+r"(inptr2), [inptr3] "+r"(inptr3), + [out_interleave] "+r"(out_interleave) : : "q0", "q1", "cc", "memory"); outptr += K8; @@ -572,9 +572,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } } -static void gemm_s8x8x16_8x8_pack_B_t(dt_int8* out, const dt_int8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_s8x8x16_8x8_pack_B_t( + dt_int8* out, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { int8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(int8_t) * 16); int8_t* outptr = out; @@ -629,8 +629,9 @@ static void gemm_s8x8x16_8x8_pack_B_t(dt_int8* out, const dt_int8* inptr, megdnn_assert(0); } } - transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_4x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += 4 * 8; } @@ -662,8 +663,9 @@ static void gemm_s8x8x16_8x8_pack_B_t(dt_int8* out, const dt_int8* inptr, megdnn_assert(0); } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, kmax - k); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, kmax - k); outptr += 4 * 8; } } diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h index 9ac25ee5..5e0c6eb8 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h +++ b/dnn/src/armv7/matrix_mul/int8x8x16/kernel_mk4_8x8x4.h @@ -44,8 +44,9 @@ namespace matmul_mk4_8x8x4 { * * Accumulator */ -static void kern_8x8(const int16_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int remain_n) { +static void kern_8x8( + const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int remain_n) { K /= 4; const int16_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -171,14 +172,14 @@ static void kern_8x8(const int16_t* packA, const int8_t* packB, int K, "b 101f\n" "4:\n " STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [x0] "+r"(x0), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), + [remain_n] "+r"(remain_n) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "r2", "r3", "cc", "memory"); #undef STORE_C #undef STORE_LINE } @@ 
-204,8 +205,9 @@ static void kern_8x8(const int16_t* packA, const int8_t* packB, int K, * * Accumulator */ -static void kern_4x8(const int16_t* packA, const int8_t* packB, int K, - int16_t* output, int LDC, bool is_first_k, int remain_n) { +static void kern_4x8( + const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, + bool is_first_k, int remain_n) { K /= 4; const int16_t* a_ptr = packA; const int8_t* b_ptr = packB; @@ -321,21 +323,21 @@ static void kern_4x8(const int16_t* packA, const int8_t* packB, int K, "4:\n " STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [x0] "+r"(x0), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), + [remain_n] "+r"(remain_n) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "r2", "r3", "cc", "memory"); #undef STORE_C #undef STORE_LINE } -static void gemm_s8x8x16_mk4_8x8_pack_A_n(dt_int16* outptr, - const dt_int8* inptr, int ldin, - int m0, int mmax, int k0, int kmax) { +static void gemm_s8x8x16_mk4_8x8_pack_A_n( + dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, + int kmax) { megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); constexpr int pack_m = 8; @@ -365,9 +367,8 @@ static void gemm_s8x8x16_mk4_8x8_pack_A_n(dt_int16* outptr, } } -static void gemm_s8x8x16_mk4_8x8_pack_B_n(dt_int8* out, const dt_int8* in, - int ldin, int n0, int nmax, int k0, - int kmax) { +static void gemm_s8x8x16_mk4_8x8_pack_B_n( + dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); int8_t tmpbuff[32] = {0}; diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp index 6dd19968..0a840eac 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.cpp @@ -27,39 +27,34 @@ using namespace armv7::matmul; // ===========================gemm_s8x8x16_4x2================================= MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x2); -void gemm_s8x8x16_4x2::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_4x2::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_4x2x16::gemm_s8x8x16_4x2_pack_A_t(out, in, ldin, y0, ymax, k0, - kmax); + matmul_4x2x16::gemm_s8x8x16_4x2_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_4x2x16::gemm_s8x8x16_4x2_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_4x2x16::gemm_s8x8x16_4x2_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s8x8x16_4x2::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_4x2::pack_B( + 
dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_4x2x16::gemm_s8x8x16_4x2_pack_B_t(out, in, ldin, x0, xmax, k0, - kmax); + matmul_4x2x16::gemm_s8x8x16_4x2_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_4x2x16::gemm_s8x8x16_4x2_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_4x2x16::gemm_s8x8x16_4x2_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int16)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8x8x16_4x2::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -79,15 +74,15 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, - is_first_k, 4, 2); + matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, is_first_k, 4, 2); output += B_INTERLEAVE; cur_packB += K2; } for (; n < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, - is_first_k, 4, std::min(N - n, 2)); + matmul_4x2x16::kern_4x2( + packA, cur_packB, K, output, LDC, is_first_k, 4, + std::min(N - n, 2)); output += B_INTERLEAVE; cur_packB += K2; } @@ -101,9 +96,9 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, size_t n = 0; const dt_int8* cur_packB = packB; for (; n < N; n += B_INTERLEAVE) { - matmul_4x2x16::kern_4x2(packA, cur_packB, K, output, LDC, - is_first_k, std::min(M - m, 4), - std::min(N - n, 2)); + matmul_4x2x16::kern_4x2( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 2)); output += B_INTERLEAVE; cur_packB += K2; } @@ -114,39 +109,36 @@ void gemm_s8x8x16_4x2::kern(const dt_int8* packA, const dt_int8* packB, // ===========================gemm_s8x8x16_4x8================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x8); -void gemm_s8x8x16_4x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_4x8::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_4x8x8::gemm_s8x8x16_4x8_transpose_pack_A_n(out, in, ldin, y0, - ymax, k0, kmax); + matmul_4x8x8::gemm_s8x8x16_4x8_transpose_pack_A_n( + out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_4x8x8::gemm_s8x8x16_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_4x8x8::gemm_s8x8x16_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s8x8x16_4x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_4x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int 
kmax, + bool transpose) const { if (transpose) { - matmul_4x8x8::gemm_s8x8x16_4x8_transpose_pack_B_n(out, in, ldin, x0, - xmax, k0, kmax); + matmul_4x8x8::gemm_s8x8x16_4x8_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_4x8x8::gemm_s8x8x16_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_4x8x8::gemm_s8x8x16_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int16)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8x8x16_4x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -165,16 +157,17 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, const dt_int8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4)); + matmul_4x8x8::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4)); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_4x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4)); + matmul_4x8x8::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4)); output += 4; cur_packB += K4; } @@ -185,39 +178,34 @@ void gemm_s8x8x16_4x8::kern(const dt_int8* packA, const dt_int8* packB, // ===========================gemm_s8x8x16_8x8================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); -void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, - int ymax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_8x8::pack_A( + dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_t(out, in, ldin, y0, ymax, k0, - kmax); + matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_8x8x4::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, - int xmax, int k0, int kmax, - bool transpose) const { +void gemm_s8x8x16_8x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_t(out, in, ldin, x0, xmax, k0, - kmax); + matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_8x8x4::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, 
dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - ((A_dtype.enumv() == DTypeEnum::Int8 && - C_dtype.enumv() == DTypeEnum::Int16)), - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_s8x8x16_8x8::kern( + const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + ((A_dtype.enumv() == DTypeEnum::Int8 && + C_dtype.enumv() == DTypeEnum::Int16)), + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -233,8 +221,9 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, const dt_int8* cur_packB = packB; size_t n = 0; for (; n < N; n += B_INTERLEAVE) { - matmul_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(N - n, 8)); + matmul_8x8x4::kern_8x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(N - n, 8)); output += B_INTERLEAVE; cur_packB += K * 8; } @@ -245,9 +234,9 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, const dt_int8* cur_packB = packB; size_t n = 0; for (; n < N; n += B_INTERLEAVE) { - matmul_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 8)); + matmul_8x8x4::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 8)); output += B_INTERLEAVE; cur_packB += K * 8; } @@ -258,27 +247,24 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, // ===========================gemm_s8x8x16_mk4_8x8================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8); -void gemm_s8x8x16_mk4_8x8::pack_A(dt_int16* out, const dt_int8* in, int ldin, - int y0, int ymax, int k0, int kmax, - bool) const { - matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); +void gemm_s8x8x16_mk4_8x8::pack_A( + dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool) const { + matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } -void gemm_s8x8x16_mk4_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, - int x0, int xmax, int k0, int kmax, - bool) const { - matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); +void gemm_s8x8x16_mk4_8x8::pack_B( + dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool) const { + matmul_mk4_8x8x4::gemm_s8x8x16_mk4_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } -void gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB, - size_t M, size_t N, size_t K, dt_int16* C, - size_t LDC, bool is_first_k, const dt_int16*, - dt_int16*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - C_dtype.enumv() == DTypeEnum::Int16 && - A_dtype.enumv() == DTypeEnum::Int8); +void gemm_s8x8x16_mk4_8x8::kern( + const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K, + dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && + A_dtype.enumv() == DTypeEnum::Int8); megdnn_assert(is_first_k == true, "only impl is_first_k"); MEGDNN_MARK_USED_VAR(A_dtype); MEGDNN_MARK_USED_VAR(B_dtype); @@ -298,14 +284,14 @@ void 
gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB, size_t n_idx = 0; const int8_t* cur_packB = packB; for (; n_idx + pack_n <= N; n_idx += pack_n) { - matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, - is_first_k, pack_n); + matmul_mk4_8x8x4::kern_8x8( + packA, cur_packB, K, output, LDC, is_first_k, pack_n); output += pack_n * pack_size; cur_packB += pack_n * K; } if (remain_n > 0) { - matmul_mk4_8x8x4::kern_8x8(packA, cur_packB, K, output, LDC, - is_first_k, remain_n); + matmul_mk4_8x8x4::kern_8x8( + packA, cur_packB, K, output, LDC, is_first_k, remain_n); output += remain_n * pack_size; cur_packB += pack_n * K; } @@ -316,14 +302,14 @@ void gemm_s8x8x16_mk4_8x8::kern(const dt_int16* packA, const dt_int8* packB, size_t n_idx = 0; const int8_t* cur_packB = packB; for (; n_idx + pack_n <= N; n_idx += pack_n) { - matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, - is_first_k, pack_n); + matmul_mk4_8x8x4::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, pack_n); output += pack_n * pack_size; cur_packB += pack_n * K; } if (remain_n > 0) { - matmul_mk4_8x8x4::kern_4x8(packA, cur_packB, K, output, LDC, - is_first_k, remain_n); + matmul_mk4_8x8x4::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, remain_n); output += remain_n * pack_size; cur_packB += pack_n * K; } diff --git a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h index d17bd647..55820b06 100644 --- a/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h +++ b/dnn/src/armv7/matrix_mul/int8x8x16/strategy.h @@ -16,18 +16,17 @@ namespace megdnn { namespace armv7 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 2, 16, false, true, - gemm_s8x8x16_4x2); +MEGDNN_REG_GEMM_STRATEGY( + int8_t, int16_t, int16_t, 4, 2, 16, false, true, gemm_s8x8x16_4x2); -MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 4, 8, 8, false, true, - gemm_s8x8x16_4x8); +MEGDNN_REG_GEMM_STRATEGY( + int8_t, int16_t, int16_t, 4, 8, 8, false, true, gemm_s8x8x16_4x8); -MEGDNN_REG_GEMM_STRATEGY(int8_t, int16_t, int16_t, 8, 8, 4, false, true, - gemm_s8x8x16_8x8); +MEGDNN_REG_GEMM_STRATEGY( + int8_t, int16_t, int16_t, 8, 8, 4, false, true, gemm_s8x8x16_8x8); -MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE(int8_t, int16_t, int16_t, int16_t, 8, - 8, 4, false, false, - gemm_s8x8x16_mk4_8x8); +MEGDNN_REG_GEMM_STRATEGY_WITH_PACK_A_TYPE( + int8_t, int16_t, int16_t, int16_t, 8, 8, 4, false, false, gemm_s8x8x16_mk4_8x8); } // namespace matmul } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp index a24ea6d0..bfb1fe72 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.cpp +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#include "src/armv7/matrix_mul/algos.h" #include "src/armv7/matrix_mul/opr_impl.h" +#include "src/armv7/matrix_mul/algos.h" #include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_impl.h" @@ -90,11 +90,11 @@ const MatrixMulImpl::AlgoPack& MatrixMulImpl::algo_pack() { return algo_pack; } -SmallVector -MatrixMulImpl::get_all_packed_algo() { +SmallVector MatrixMulImpl::get_all_packed_algo() { auto algos = arm_common::MatrixMulImpl::get_all_packed_algo(); - algos.insert(algos.begin(), algo_pack().all_algos().begin(), - algo_pack().all_algos().end()); + algos.insert( + algos.begin(), algo_pack().all_algos().begin(), + algo_pack().all_algos().end()); return algos; } diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h index 39b60346..7068f6bb 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.h +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -25,8 +25,7 @@ public: } }; - SmallVector get_all_packed_algo() - override; + SmallVector get_all_packed_algo() override; MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(MatrixMulImpl); diff --git a/dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h b/dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h index 3556a186..ba54a223 100644 --- a/dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h +++ b/dnn/src/armv7/matrix_mul/quint8/kernel_4x8x8.h @@ -39,9 +39,9 @@ namespace matmul_4x8x8 { * * Accumulator */ -static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t m_remain, - uint8_t za, uint8_t zb) { +static void kern_4x8( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain, uint8_t za, uint8_t zb) { K /= 8; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -207,15 +207,14 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, "bne 2b\n" "3:\n" STORE_C - : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), - [x0] "+r"(x0), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [m_remain] "+r"(m_remain), [za] "+r"(za), [zb] "+r"(zb), - [outptr] "+r"(outptr) + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [x0] "+r"(x0), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), + [za] "+r"(za), [zb] "+r"(zb), [outptr] "+r"(outptr) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "r2", "r3", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "r2", "r3", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -247,9 +246,9 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, * * Accumulator */ -static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, size_t m_remain, - size_t n_remain, uint8_t za, uint8_t zb) { +static void kern_4x4( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, size_t m_remain, size_t n_remain, uint8_t za, uint8_t zb) { K /= 8; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -416,13 +415,13 @@ static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, "3:\n" STORE_C : [a_ptr] 
"+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), - [x0] "+r"(x0), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain), [za] "+r"(za), [zb] "+r"(zb) + [x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), + [za] "+r"(za), [zb] "+r"(zb) : - : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", - "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", - "d29", "d30", "d31", "r1", "cc", "memory"); + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31", + "r1", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C @@ -430,9 +429,9 @@ static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, #undef STORE_C } -static void gemm_u8_4x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, - int ldin, int y0, int ymax, int k0, int kmax, - uint8_t zero_point) { +static void gemm_u8_4x8_pack_A_n( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); @@ -485,15 +484,14 @@ static void gemm_u8_4x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, megdnn_assert(0); } } - interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, - zero_point); + interleave_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, zero_point); } } } -static void gemm_u8_4x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, - int ldin, int x0, int xmax, int k0, - int kmax, uint8_t zero_point) { +static void gemm_u8_4x8_transpose_pack_A_n( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); const int ksize = kmax - k0; @@ -552,8 +550,9 @@ static void gemm_u8_4x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, } } - transpose_4x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_4x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += ksize4; } @@ -586,17 +585,18 @@ static void gemm_u8_4x8_transpose_pack_A_n(dt_uint8* out, const dt_uint8* in, } } - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, xmax - x, zero_point); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, xmax - x, zero_point); } outptr_base += 4 * 8; } } -static void gemm_u8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, - int x0, int xmax, int k0, int kmax, - uint8_t zero_point) { +static void gemm_u8_4x8_pack_B_n( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); const int ksize = kmax - k0; @@ -658,8 +658,9 @@ static void gemm_u8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } } outptr_interleave = outptr; - interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr_interleave); + interleave_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave); outptr += ksize8; } @@ -694,8 +695,9 @@ static void gemm_u8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } outptr_interleave = outptr; 
- interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, 4, zero_point); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, 4, zero_point); outptr += ksize4; } @@ -729,8 +731,9 @@ static void gemm_u8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } outptr_interleave = outptr; - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr_interleave, 4, xmax - x, zero_point); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr_interleave, 4, xmax - x, zero_point); } outptr_base += 8 * 8; @@ -738,10 +741,9 @@ static void gemm_u8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, int ldin, } } -static void gemm_u8_4x8_transpose_pack_B_n(dt_uint8* outptr, - const dt_uint8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - uint8_t zero_point) { +static void gemm_u8_4x8_transpose_pack_B_n( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, uint8_t zero_point) { uint8_t zerobuff[16]; std::fill(zerobuff, zerobuff + 16, zero_point); constexpr int interleave4 = 32; @@ -769,14 +771,16 @@ static void gemm_u8_4x8_transpose_pack_B_n(dt_uint8* outptr, int K = kmax - k0; for (; K > 7; K -= 8) { - transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + transpose_8x8_1_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); outptr += interleave8; } if (K > 0) { - transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 8, K, zero_point); + transpose_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 8, K, zero_point); outptr += interleave8; } } @@ -830,8 +834,7 @@ static void gemm_u8_4x8_transpose_pack_B_n(dt_uint8* outptr, megdnn_assert(0); } } - transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, - zero_point); + transpose_4(inptr0, inptr1, inptr2, inptr3, outptr, 8, K, zero_point); outptr += interleave4; } } diff --git a/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h b/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h index 35439e14..a6307cca 100644 --- a/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h +++ b/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h @@ -42,9 +42,9 @@ namespace matmul_dot_4x8x4 { // // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, uint8_t zA, - uint8_t zB, uint32_t zAB, size_t m_remain = 4) { +static void kern_4x8( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, uint8_t zA, uint8_t zB, uint32_t zAB, size_t m_remain = 4) { K /= 4; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -221,14 +221,13 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, STORE_C - : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), - [m_remain] "+r"(m_remain), [za] "+w"(za), [zb] "+w"(zb), - [zab] "+r"(zAB), [outptr0] "+r"(outptr0), [outptr1] "+r"(outptr1), - [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) + : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), + [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), [za] "+w"(za), + [zb] "+w"(zb), [zab] "+r"(zAB), [outptr0] "+r"(outptr0), + [outptr1] "+r"(outptr1), [outptr2] "+r"(outptr2), [outptr3] "+r"(outptr3) : - : "q0", 
"q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "cc", "r12", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "cc", "r12", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE @@ -258,10 +257,10 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, // // Accumulator MEGDNN_ATTRIBUTE_TARGET("dotprod") -static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, - int32_t* output, int LDC, bool is_first_k, uint8_t zA, - uint8_t zB, uint32_t zAB, size_t m_remain = 4, - size_t n_remain = 4) { +static void kern_4x4( + const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, + bool is_first_k, uint8_t zA, uint8_t zB, uint32_t zAB, size_t m_remain = 4, + size_t n_remain = 4) { K /= 4; const uint8_t* a_ptr = packA; const uint8_t* b_ptr = packB; @@ -432,23 +431,22 @@ static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, "vsub.s32 q8, q8, q12\n" "vsub.s32 q10, q10, q12\n" STORE_C - : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), - [oddk] "+r"(oddk), [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), - [za] "+w"(za), [zb] "+w"(zb), [zab] "+r"(zAB), - [outptr0] "+r"(outptr0), [m_remain] "+r"(m_remain), - [n_remain] "+r"(n_remain), [x0] "+r"(x0) + : [k] "+r"(k), [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), + [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [za] "+w"(za), + [zb] "+w"(zb), [zab] "+r"(zAB), [outptr0] "+r"(outptr0), + [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), [x0] "+r"(x0) : - : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", - "q11", "q12", "q13", "r0", "r1", "cc", "memory"); + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "r0", "r1", "cc", "memory"); #undef LOAD_LINE #undef LOAD_C #undef STORE_LINE #undef STORE_C } - -static void gemm_quint8_4x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, - int ldin, int y0, int ymax, int k0, int kmax) { +static void gemm_quint8_4x8_pack_A_n( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { uint8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(uint8_t) * 16); @@ -498,8 +496,9 @@ static void gemm_quint8_4x8_pack_A_n(dt_uint8* outptr, const dt_uint8* inptr, } } -static void gemm_quint8_4x8_pack_A_t(dt_uint8* out, const dt_uint8* in, int ldin, - int x0, int xmax, int k0, int kmax) { +static void gemm_quint8_4x8_pack_A_t( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, + int kmax) { uint8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(uint8_t) * 16); const int ksize = kmax - k0; @@ -558,9 +557,9 @@ static void gemm_quint8_4x8_pack_A_t(dt_uint8* out, const dt_uint8* in, int ldin } } -static void gemm_quint8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, - int ldin, int x0, int xmax, int k0, - int kmax) { +static void gemm_quint8_4x8_pack_B_n( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, + int kmax) { uint8_t zerobuff[16]; std::memset(zerobuff, 0, sizeof(uint8_t) * 16); const int ksize = kmax - k0; @@ -646,9 +645,9 @@ static void gemm_quint8_4x8_pack_B_n(dt_uint8* out, const dt_uint8* in, } } -static void gemm_quint8_4x8_pack_B_t(dt_uint8* outptr, const dt_uint8* inptr, - int ldin, int y0, int ymax, int k0, - int kmax) { +static void gemm_quint8_4x8_pack_B_t( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax) { uint8_t zerobuff[16]; std::memset(zerobuff, 
0, sizeof(uint8_t) * 16); @@ -675,12 +674,14 @@ static void gemm_quint8_4x8_pack_B_t(dt_uint8* outptr, const dt_uint8* inptr, int K = kmax - k0; //! read 12 * 4 in each row for (; K > 15; K -= 16) { - interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, - inptr6, inptr7, outptr); + interleave_8x4_4_b( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr); } if (K > 0) { - interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, - inptr7, outptr, 4, K); + interleave_8( + inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, + outptr, 4, K); } } for (; y < ymax; y += 4) { diff --git a/dnn/src/armv7/matrix_mul/quint8/strategy.cpp b/dnn/src/armv7/matrix_mul/quint8/strategy.cpp index d32f186f..edce5f87 100644 --- a/dnn/src/armv7/matrix_mul/quint8/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/quint8/strategy.cpp @@ -22,39 +22,38 @@ using namespace armv7::matmul; MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_4x8); -void gemm_u8_4x8::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_u8_4x8::pack_A( + dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, + int kmax, bool transpose) const { uint8_t zA = A_dtype.param().zero_point; if (transpose) { - matmul_4x8x8::gemm_u8_4x8_transpose_pack_A_n(outptr, inptr, ldin, y0, - ymax, k0, kmax, zA); + matmul_4x8x8::gemm_u8_4x8_transpose_pack_A_n( + outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } else { - matmul_4x8x8::gemm_u8_4x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, - kmax, zA); + matmul_4x8x8::gemm_u8_4x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); } } -void gemm_u8_4x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_u8_4x8::pack_B( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { uint8_t zB = B_dtype.param().zero_point; if (transpose) { - matmul_4x8x8::gemm_u8_4x8_transpose_pack_B_n(out, in, ldin, x0, xmax, - k0, kmax, zB); + matmul_4x8x8::gemm_u8_4x8_transpose_pack_B_n( + out, in, ldin, x0, xmax, k0, kmax, zB); } else { - matmul_4x8x8::gemm_u8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, - zB); + matmul_4x8x8::gemm_u8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); } } -void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, - size_t N, size_t K, dt_int32* C, size_t LDC, - bool is_first_k, const dt_int32*, dt_int32*) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Quantized8Asymm && - C_dtype.enumv() == DTypeEnum::QuantizedS32, - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_u8_4x8::kern( + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32, + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); uint8_t zA = A_dtype.param().zero_point; uint8_t zB = B_dtype.param().zero_point; @@ -71,16 +70,17 @@ void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, const dt_uint8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_4x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), zA, 
zB); + matmul_4x8x8::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), zA, zB); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { - matmul_4x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, - std::min(M - m, 4), - std::min(N - n, 4), zA, zB); + matmul_4x8x8::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, + std::min(M - m, 4), std::min(N - n, 4), zA, zB); output += 4; cur_packB += K4; } @@ -91,42 +91,38 @@ void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, #if MGB_ENABLE_DOT // ===========================gemm_dot_quint8_4x8====================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dot_quint8_4x8); -void gemm_dot_quint8_4x8::pack_A(dt_uint8* out, const dt_uint8* in, int ldin, - int y0, int ymax, int k0, int kmax, - bool transpose) const { +void gemm_dot_quint8_4x8::pack_A( + dt_uint8* out, const dt_uint8* in, int ldin, int y0, int ymax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_dot_4x8x4::gemm_quint8_4x8_pack_A_t(out, in, ldin, y0, ymax, k0, - kmax); + matmul_dot_4x8x4::gemm_quint8_4x8_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); } else { - matmul_dot_4x8x4::gemm_quint8_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, - kmax); + matmul_dot_4x8x4::gemm_quint8_4x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); } } -void gemm_dot_quint8_4x8::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, - int xmax, int k0, int kmax, bool transpose) const { +void gemm_dot_quint8_4x8::pack_B( + dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, + bool transpose) const { if (transpose) { - matmul_dot_4x8x4::gemm_quint8_4x8_pack_B_t(out, in, ldin, x0, xmax, k0, - kmax); + matmul_dot_4x8x4::gemm_quint8_4x8_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); } else { - matmul_dot_4x8x4::gemm_quint8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, - kmax); + matmul_dot_4x8x4::gemm_quint8_4x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); } } -void gemm_dot_quint8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, - size_t M, size_t N, size_t K, dt_int32* C, - size_t LDC, bool is_first_k, const dt_int32*, - dt_int32* workspace) const { - megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && - A_dtype.enumv() == DTypeEnum::Quantized8Asymm && - C_dtype.enumv() == DTypeEnum::QuantizedS32, - "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), - C_dtype.name()); +void gemm_dot_quint8_4x8::kern( + const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, + dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, + dt_int32* workspace) const { + megdnn_assert( + A_dtype.enumv() == B_dtype.enumv() && + A_dtype.enumv() == DTypeEnum::Quantized8Asymm && + C_dtype.enumv() == DTypeEnum::QuantizedS32, + "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); uint8_t zA = A_dtype.param().zero_point; uint8_t zB = B_dtype.param().zero_point; - const uint32_t zAB = - static_cast(zA) * static_cast(zB) * K; + const uint32_t zAB = static_cast(zA) * static_cast(zB) * K; constexpr size_t A_INTERLEAVE = 4; constexpr size_t B_INTERLEAVE = 8; @@ -140,37 +136,39 @@ void gemm_dot_quint8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, const dt_uint8* cur_packB = packB; size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_dot_4x8x4::kern_4x8(packA, cur_packB, K, output, LDC, - is_first_k, zA, zB, zAB); + matmul_dot_4x8x4::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, zA, zB, zAB); output += B_INTERLEAVE; 
cur_packB += K8; } for (; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_dot_4x8x4::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, zA, zB, zAB, 4, n_remain); + matmul_dot_4x8x4::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, zA, zB, zAB, 4, + n_remain); output += n_remain; cur_packB += K4; } packA += K4; } - if(m(M - m, 4); size_t n = 0; for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { - matmul_dot_4x8x4::kern_4x8(packA, cur_packB, K, output, LDC, - is_first_k, zA, zB, zAB, m_remain); + matmul_dot_4x8x4::kern_4x8( + packA, cur_packB, K, output, LDC, is_first_k, zA, zB, zAB, + m_remain); output += B_INTERLEAVE; cur_packB += K8; } for (; n < N; n += 4) { size_t n_remain = std::min(N - n, 4); - matmul_dot_4x8x4::kern_4x4(packA, cur_packB, K, output, LDC, - is_first_k, zA, zB, zAB, m_remain, - n_remain); + matmul_dot_4x8x4::kern_4x4( + packA, cur_packB, K, output, LDC, is_first_k, zA, zB, zAB, m_remain, + n_remain); output += n_remain; cur_packB += K4; } diff --git a/dnn/src/armv7/matrix_mul/quint8/strategy.h b/dnn/src/armv7/matrix_mul/quint8/strategy.h index b8833698..2b5ac19d 100644 --- a/dnn/src/armv7/matrix_mul/quint8/strategy.h +++ b/dnn/src/armv7/matrix_mul/quint8/strategy.h @@ -15,11 +15,11 @@ namespace megdnn { namespace armv7 { namespace matmul { -MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true, - gemm_u8_4x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true, gemm_u8_4x8); #if MGB_ENABLE_DOT -MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false, - gemm_dot_quint8_4x8); +MEGDNN_REG_GEMM_STRATEGY( + dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false, gemm_dot_quint8_4x8); #endif } // namespace matmul diff --git a/dnn/src/armv7/relayout/opr_impl.cpp b/dnn/src/armv7/relayout/opr_impl.cpp index 074d1a52..38f54d10 100644 --- a/dnn/src/armv7/relayout/opr_impl.cpp +++ b/dnn/src/armv7/relayout/opr_impl.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "src/common/utils.h" #include "src/common/relayout_helper.h" +#include "src/common/utils.h" #include "src/armv7/handle.h" #include "src/armv7/relayout/opr_impl.h" @@ -23,89 +23,83 @@ struct TransposeByte { uint8_t v; }; -void trans_16x16_u8(const void* src, void* dst, const size_t src_step, - const size_t dst_step) { +void trans_16x16_u8( + const void* src, void* dst, const size_t src_step, const size_t dst_step) { // 16x16 asm volatile( - "\n" - "vld1.8 {d0, d1}, [%[src]], %[src_step] \n" - "vld1.8 {d2, d3}, [%[src]], %[src_step] \n" - "vld1.8 {d4, d5}, [%[src]], %[src_step] \n" - "vld1.8 {d6, d7}, [%[src]], %[src_step] \n" - "vld1.8 {d8, d9}, [%[src]], %[src_step] \n" - "vld1.8 {d10, d11}, [%[src]], %[src_step] \n" - "vld1.8 {d12, d13}, [%[src]], %[src_step] \n" - "vld1.8 {d14, d15}, [%[src]], %[src_step] \n" - "vld1.8 {d16, d17}, [%[src]], %[src_step] \n" - "vld1.8 {d18, d19}, [%[src]], %[src_step] \n" - "vld1.8 {d20, d21}, [%[src]], %[src_step] \n" - "vld1.8 {d22, d23}, [%[src]], %[src_step] \n" - "vld1.8 {d24, d25}, [%[src]], %[src_step] \n" - "vld1.8 {d26, d27}, [%[src]], %[src_step] \n" - "vld1.8 {d28, d29}, [%[src]], %[src_step] \n" - "vld1.8 {d30, d31}, [%[src]], %[src_step] \n" - "vtrn.8 q0, q1 \n" - "vtrn.8 q2, q3 \n" - "vtrn.8 q4, q5 \n" - "vtrn.8 q6, q7 \n" - "vtrn.8 q8, q9 \n" - "vtrn.8 q10, q11 \n" - "vtrn.8 q12, q13 \n" - "vtrn.8 q14, q15 \n" - "vtrn.16 q0, q2 \n" - "vtrn.16 q1, q3 \n" - "vtrn.16 q4, q6 \n" - "vtrn.16 q5, q7 \n" - "vtrn.16 q8, q10 \n" - "vtrn.16 q9, q11 \n" - "vtrn.16 q12, q14 \n" - "vtrn.16 q13, q15 \n" - "vtrn.32 q0, q4 \n" - "vtrn.32 q1, q5 \n" - "vtrn.32 q2, q6 \n" - "vtrn.32 q3, q7 \n" - "vtrn.32 q8, q12 \n" - "vtrn.32 q9, q13 \n" - "vtrn.32 q10, q14 \n" - "vtrn.32 q11, q15 \n" - "vswp d1, d16 \n" - "vswp d3, d18 \n" - "vswp d5, d20 \n" - "vswp d7, d22 \n" - "vswp d9, d24 \n" - "vswp d11, d26 \n" - "vswp d13, d28 \n" - "vswp d15, d30 \n" - "vst1.8 {d0, d1}, [%[dst]], %[dst_step] \n" - "vst1.8 {d2, d3}, [%[dst]], %[dst_step] \n" - "vst1.8 {d4, d5}, [%[dst]], %[dst_step] \n" - "vst1.8 {d6, d7}, [%[dst]], %[dst_step] \n" - "vst1.8 {d8, d9}, [%[dst]], %[dst_step] \n" - "vst1.8 {d10, d11}, [%[dst]], %[dst_step] \n" - "vst1.8 {d12, d13}, [%[dst]], %[dst_step] \n" - "vst1.8 {d14, d15}, [%[dst]], %[dst_step] \n" - "vst1.8 {d16, d17}, [%[dst]], %[dst_step] \n" - "vst1.8 {d18, d19}, [%[dst]], %[dst_step] \n" - "vst1.8 {d20, d21}, [%[dst]], %[dst_step] \n" - "vst1.8 {d22, d23}, [%[dst]], %[dst_step] \n" - "vst1.8 {d24, d25}, [%[dst]], %[dst_step] \n" - "vst1.8 {d26, d27}, [%[dst]], %[dst_step] \n" - "vst1.8 {d28, d29}, [%[dst]], %[dst_step] \n" - "vst1.8 {d30, d31}, [%[dst]], %[dst_step] \n" - : - [src] "+r" (src), - [dst] "+r" (dst) - : - [src_step] "r" (src_step), - [dst_step] "r" (dst_step) - : - "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", - "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", - "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", - "d31"); + "\n" + "vld1.8 {d0, d1}, [%[src]], %[src_step] \n" + "vld1.8 {d2, d3}, [%[src]], %[src_step] \n" + "vld1.8 {d4, d5}, [%[src]], %[src_step] \n" + "vld1.8 {d6, d7}, [%[src]], %[src_step] \n" + "vld1.8 {d8, d9}, [%[src]], %[src_step] \n" + "vld1.8 {d10, d11}, [%[src]], %[src_step] \n" + "vld1.8 {d12, d13}, [%[src]], %[src_step] \n" + "vld1.8 {d14, d15}, [%[src]], %[src_step] \n" + "vld1.8 {d16, d17}, [%[src]], %[src_step] \n" + "vld1.8 {d18, d19}, [%[src]], %[src_step] \n" + "vld1.8 {d20, d21}, [%[src]], %[src_step] \n" + "vld1.8 {d22, 
d23}, [%[src]], %[src_step] \n" + "vld1.8 {d24, d25}, [%[src]], %[src_step] \n" + "vld1.8 {d26, d27}, [%[src]], %[src_step] \n" + "vld1.8 {d28, d29}, [%[src]], %[src_step] \n" + "vld1.8 {d30, d31}, [%[src]], %[src_step] \n" + "vtrn.8 q0, q1 \n" + "vtrn.8 q2, q3 \n" + "vtrn.8 q4, q5 \n" + "vtrn.8 q6, q7 \n" + "vtrn.8 q8, q9 \n" + "vtrn.8 q10, q11 \n" + "vtrn.8 q12, q13 \n" + "vtrn.8 q14, q15 \n" + "vtrn.16 q0, q2 \n" + "vtrn.16 q1, q3 \n" + "vtrn.16 q4, q6 \n" + "vtrn.16 q5, q7 \n" + "vtrn.16 q8, q10 \n" + "vtrn.16 q9, q11 \n" + "vtrn.16 q12, q14 \n" + "vtrn.16 q13, q15 \n" + "vtrn.32 q0, q4 \n" + "vtrn.32 q1, q5 \n" + "vtrn.32 q2, q6 \n" + "vtrn.32 q3, q7 \n" + "vtrn.32 q8, q12 \n" + "vtrn.32 q9, q13 \n" + "vtrn.32 q10, q14 \n" + "vtrn.32 q11, q15 \n" + "vswp d1, d16 \n" + "vswp d3, d18 \n" + "vswp d5, d20 \n" + "vswp d7, d22 \n" + "vswp d9, d24 \n" + "vswp d11, d26 \n" + "vswp d13, d28 \n" + "vswp d15, d30 \n" + "vst1.8 {d0, d1}, [%[dst]], %[dst_step] \n" + "vst1.8 {d2, d3}, [%[dst]], %[dst_step] \n" + "vst1.8 {d4, d5}, [%[dst]], %[dst_step] \n" + "vst1.8 {d6, d7}, [%[dst]], %[dst_step] \n" + "vst1.8 {d8, d9}, [%[dst]], %[dst_step] \n" + "vst1.8 {d10, d11}, [%[dst]], %[dst_step] \n" + "vst1.8 {d12, d13}, [%[dst]], %[dst_step] \n" + "vst1.8 {d14, d15}, [%[dst]], %[dst_step] \n" + "vst1.8 {d16, d17}, [%[dst]], %[dst_step] \n" + "vst1.8 {d18, d19}, [%[dst]], %[dst_step] \n" + "vst1.8 {d20, d21}, [%[dst]], %[dst_step] \n" + "vst1.8 {d22, d23}, [%[dst]], %[dst_step] \n" + "vst1.8 {d24, d25}, [%[dst]], %[dst_step] \n" + "vst1.8 {d26, d27}, [%[dst]], %[dst_step] \n" + "vst1.8 {d28, d29}, [%[dst]], %[dst_step] \n" + "vst1.8 {d30, d31}, [%[dst]], %[dst_step] \n" + : [src] "+r"(src), [dst] "+r"(dst) + : [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", "d11", + "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", "d21", + "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"); } -} // anonymous namespace +} // anonymous namespace namespace megdnn { namespace relayout { @@ -116,9 +110,9 @@ struct transpose_traits { }; template <> -void transpose_block(const TransposeByte* src, - TransposeByte* dst, const size_t src_stride, - const size_t dst_stride) { +void transpose_block( + const TransposeByte* src, TransposeByte* dst, const size_t src_stride, + const size_t dst_stride) { trans_16x16_u8(src, dst, src_stride, dst_stride); } @@ -126,10 +120,8 @@ void transpose_block(const TransposeByte* src, } // namespace relayout } // namespace megdnn - -void armv7::RelayoutForwardImpl::exec(_megdnn_tensor_in src0, - _megdnn_tensor_out dst0, - Handle* src_handle) { +void armv7::RelayoutForwardImpl::exec( + _megdnn_tensor_in src0, _megdnn_tensor_out dst0, Handle* src_handle) { check_cpu_handle(src_handle); TensorND src = src0, dst = dst0; check_layout_and_canonize(src.layout, dst.layout); @@ -146,10 +138,8 @@ void armv7::RelayoutForwardImpl::exec(_megdnn_tensor_in src0, if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { auto sptr = static_cast(src.raw_ptr), dptr = static_cast(dst.raw_ptr); - MEGDNN_DISPATCH_CPU_KERN_OPR( - transpose_fallback::transpose( - trans_param.batch, trans_param.m, trans_param.n, sptr, - dptr)); + MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose( + trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); return; } exec_after_preprocess(src, dst, trans ? 
&trans_param : nullptr); diff --git a/dnn/src/armv7/relayout/opr_impl.h b/dnn/src/armv7/relayout/opr_impl.h index 027f15fe..4edfe39c 100644 --- a/dnn/src/armv7/relayout/opr_impl.h +++ b/dnn/src/armv7/relayout/opr_impl.h @@ -16,11 +16,11 @@ namespace megdnn { namespace armv7 { class RelayoutForwardImpl final : public fallback::RelayoutForwardImpl { - public: +public: using fallback::RelayoutForwardImpl::RelayoutForwardImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - Handle *src_handle) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, Handle* src_handle) override; bool is_thread_safe() const override { return true; } }; diff --git a/dnn/src/armv7/rotate/opr_impl.cpp b/dnn/src/armv7/rotate/opr_impl.cpp index a22a8133..4590a2ee 100644 --- a/dnn/src/armv7/rotate/opr_impl.cpp +++ b/dnn/src/armv7/rotate/opr_impl.cpp @@ -11,8 +11,8 @@ #include -#include "src/armv7/rotate/opr_impl.h" #include "src/armv7/handle.h" +#include "src/armv7/rotate/opr_impl.h" #include "src/common/cv/common.h" #include "src/common/cv/helper.h" #include "src/common/utils.h" @@ -20,11 +20,10 @@ namespace megdnn { namespace megcv { -void rotate_8uc1_clockwise_16x16(const uchar *src, - uchar *dst, - size_t src_step, size_t dst_step) -{ - asm volatile ("\n" +void rotate_8uc1_clockwise_16x16( + const uchar* src, uchar* dst, size_t src_step, size_t dst_step) { + asm volatile( + "\n" "vld1.8 {d0, d1}, [%[src]], %[src_step] \n" "vld1.8 {d2, d3}, [%[src]], %[src_step] \n" "vld1.8 {d4, d5}, [%[src]], %[src_step] \n" @@ -121,27 +120,18 @@ void rotate_8uc1_clockwise_16x16(const uchar *src, "vst1.8 {d26, d27}, [%[dst]], %[dst_step] \n" "vst1.8 {d28, d29}, [%[dst]], %[dst_step] \n" "vst1.8 {d30, d31}, [%[dst]], %[dst_step] \n" - : - [src] "+r" (src), - [dst] "+r" (dst) - : - [src_step] "r" (src_step), - [dst_step] "r" (dst_step) - : - "r0", "r1", "r2", "r3", - "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", - "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", - "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", - "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31" - ); - + : [src] "+r"(src), [dst] "+r"(dst) + : [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r0", "r1", "r2", "r3", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", + "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", + "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31"); } -void rotate_8uc1_counterclockwise_16x16(const uchar *src, - uchar *dst, - size_t src_step, size_t dst_step) -{ - asm volatile ("\n" +void rotate_8uc1_counterclockwise_16x16( + const uchar* src, uchar* dst, size_t src_step, size_t dst_step) { + asm volatile( + "\n" "vld1.8 {d0, d1}, [%[src]], %[src_step] \n" "vld1.8 {d2, d3}, [%[src]], %[src_step] \n" "vld1.8 {d4, d5}, [%[src]], %[src_step] \n" @@ -206,25 +196,17 @@ void rotate_8uc1_counterclockwise_16x16(const uchar *src, "vst1.8 {d4, d5}, [%[dst]], %[dst_step] \n" "vst1.8 {d2, d3}, [%[dst]], %[dst_step] \n" "vst1.8 {d0, d1}, [%[dst]], %[dst_step] \n" - : - [src] "+r" (src), - [dst] "+r" (dst) - : - [src_step] "r" (src_step), - [dst_step] "r" (dst_step) - : - "r0", "r1", "r2", "r3", - "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", - "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", - "d16", "d17", "d18", "d19", "d20", "d21", "d22", "d23", - "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31" - ); + : [src] "+r"(src), [dst] "+r"(dst) + : [src_step] "r"(src_step), [dst_step] "r"(dst_step) + : "r0", "r1", "r2", "r3", "d0", 
"d1", "d2", "d3", "d4", "d5", "d6", "d7", + "d8", "d9", "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", + "d19", "d20", "d21", "d22", "d23", "d24", "d25", "d26", "d27", "d28", + "d29", "d30", "d31"); } -void rotate_8uc1_clockwise(const uchar *src, uchar *dst, - const size_t rows, const size_t cols, - const size_t src_step, const size_t dst_step) -{ +void rotate_8uc1_clockwise( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { const size_t block = 16; (void)block; size_t i = 0; @@ -232,28 +214,27 @@ void rotate_8uc1_clockwise(const uchar *src, uchar *dst, for (; i + block <= rows; i += block) { size_t j = 0; for (; j + block <= cols; j += block) { - rotate_8uc1_clockwise_16x16(src + i*src_step + j, - dst + j*dst_step + (rows-(i+block)), + rotate_8uc1_clockwise_16x16( + src + i * src_step + j, dst + j * dst_step + (rows - (i + block)), src_step, dst_step); } for (; j < cols; ++j) { for (size_t k = 0; k < block; ++k) { - dst[j*dst_step + (rows-1-(i+k))] = src[(i+k)*src_step + j]; + dst[j * dst_step + (rows - 1 - (i + k))] = src[(i + k) * src_step + j]; } } } for (; i < rows; ++i) { for (size_t j = 0; j < cols; ++j) { - dst[j*dst_step + (rows-1-i)] = src[i*src_step + j]; + dst[j * dst_step + (rows - 1 - i)] = src[i * src_step + j]; } } } -void rotate_8uc1_counterclockwise(const uchar *src, uchar *dst, - const size_t rows, const size_t cols, - const size_t src_step, const size_t dst_step) -{ +void rotate_8uc1_counterclockwise( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { const size_t block = 16; (void)block; size_t i = 0; @@ -261,37 +242,35 @@ void rotate_8uc1_counterclockwise(const uchar *src, uchar *dst, for (; i + block <= rows; i += block) { size_t j = 0; for (; j + block <= cols; j += block) { - rotate_8uc1_counterclockwise_16x16(src + i*src_step + j, - dst + (cols-(j+block))*dst_step + i, + rotate_8uc1_counterclockwise_16x16( + src + i * src_step + j, dst + (cols - (j + block)) * dst_step + i, src_step, dst_step); } for (; j < cols; ++j) { for (size_t k = 0; k < block; ++k) { - dst[(cols-1-j)*dst_step + (i+k)] = src[(i+k)*src_step + j]; + dst[(cols - 1 - j) * dst_step + (i + k)] = src[(i + k) * src_step + j]; } } } for (; i < rows; ++i) { for (size_t j = 0; j < cols; ++j) { - dst[(cols-1-j)*dst_step + i] = src[i*src_step + j]; + dst[(cols - 1 - j) * dst_step + i] = src[i * src_step + j]; } } } -void rotate(const Mat &src, Mat &dst, - bool clockwise) -{ +void rotate(const Mat& src, Mat& dst, bool clockwise) { megdnn_assert(src.rows() == dst.cols()); megdnn_assert(src.cols() == dst.rows()); megdnn_assert(src.channels() == dst.channels()); megdnn_assert(src.channels() == 1_z); if (clockwise) { - rotate_8uc1_clockwise(src.ptr(), dst.ptr(), src.rows(), src.cols(), - src.step(), dst.step()); + rotate_8uc1_clockwise( + src.ptr(), dst.ptr(), src.rows(), src.cols(), src.step(), dst.step()); } else { - rotate_8uc1_counterclockwise(src.ptr(), dst.ptr(), src.rows(), - src.cols(), src.step(), dst.step()); + rotate_8uc1_counterclockwise( + src.ptr(), dst.ptr(), src.rows(), src.cols(), src.step(), dst.step()); } } @@ -299,8 +278,8 @@ void rotate(const Mat &src, Mat &dst, namespace armv7 { -void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { +void RotateImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { using namespace megcv; check_exec(src.layout, dst.layout, 
workspace.size); @@ -316,7 +295,6 @@ void RotateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, rotate(src_mat, dst_mat, param().clockwise); } }); - } } // namespace armv7 diff --git a/dnn/src/armv7/rotate/opr_impl.h b/dnn/src/armv7/rotate/opr_impl.h index 8bccd31e..9adf32a1 100644 --- a/dnn/src/armv7/rotate/opr_impl.h +++ b/dnn/src/armv7/rotate/opr_impl.h @@ -17,15 +17,15 @@ namespace megdnn { namespace armv7 { class RotateImpl : public fallback::RotateImpl { - public: - using fallback::RotateImpl::RotateImpl; +public: + using fallback::RotateImpl::RotateImpl; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, - const TensorLayout&) override { - return 0; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { + return 0; } }; diff --git a/dnn/src/atlas/checksum/opr_impl.cpp b/dnn/src/atlas/checksum/opr_impl.cpp index 88d19d4b..880336ad 100644 --- a/dnn/src/atlas/checksum/opr_impl.cpp +++ b/dnn/src/atlas/checksum/opr_impl.cpp @@ -13,8 +13,8 @@ #include "src/atlas/utils.h" #include "src/naive/handle.h" -#include "src/common/utils.h" #include "src/common/opr_delegate.h" +#include "src/common/utils.h" #include @@ -25,8 +25,8 @@ size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout&) { return 0; } -ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data, - _megdnn_workspace workspace) { +ChecksumForward::Result ChecksumForwardImpl::exec( + _megdnn_tensor_in data, _megdnn_workspace workspace) { check_exec(data.layout, workspace.size); //! FIXME currently the cce programming interface is not so stable, here i //! just allocate some memory of cpu here and compute the result in cpu @@ -35,8 +35,9 @@ ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data, megcoreDeviceHandle_t dev_handle; megcoreComputingHandle_t comp_handle = handle()->megcore_computing_handle(); megcoreGetDeviceHandle(comp_handle, &dev_handle); - megcoreMemcpy(comp_handle, cpu_data.data(), data.raw_ptr, cpu_data.size(), - megcoreMemcpyDeviceToHost); + megcoreMemcpy( + comp_handle, cpu_data.data(), data.raw_ptr, cpu_data.size(), + megcoreMemcpyDeviceToHost); megcoreSynchronize(comp_handle); auto opr = inplace_cpu_handle()->create_operator(); diff --git a/dnn/src/atlas/checksum/opr_impl.h b/dnn/src/atlas/checksum/opr_impl.h index cd22c393..c8583459 100644 --- a/dnn/src/atlas/checksum/opr_impl.h +++ b/dnn/src/atlas/checksum/opr_impl.h @@ -28,7 +28,7 @@ public: Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override; }; -} // namespace naive +} // namespace atlas } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/atlas/handle.cpp b/dnn/src/atlas/handle.cpp index b50187a5..c8e9f568 100644 --- a/dnn/src/atlas/handle.cpp +++ b/dnn/src/atlas/handle.cpp @@ -10,10 +10,10 @@ * implied. 
*/ -#include "megcore_atlas.h" -#include "src/common/handle_impl.h" #include "src/atlas/handle.h" +#include "megcore_atlas.h" #include "src/atlas/checksum/opr_impl.h" +#include "src/common/handle_impl.h" #include diff --git a/dnn/src/atlas/handle.h b/dnn/src/atlas/handle.h index cad6d2ec..f8688a02 100644 --- a/dnn/src/atlas/handle.h +++ b/dnn/src/atlas/handle.h @@ -15,10 +15,10 @@ #include "megdnn/handle.h" #include "megdnn/oprs/general.h" +#include "src/atlas/megcore/device_context.hpp" #include "src/common/handle_impl.h" #include "src/common/megcore/common/device_context.hpp" #include "src/common/utils.h" -#include "src/atlas/megcore/device_context.hpp" #include #include @@ -38,9 +38,7 @@ public: template std::unique_ptr create_operator(); - const megcore::AtlasContext& megcore_context() const { - return m_megcore_context; - } + const megcore::AtlasContext& megcore_context() const { return m_megcore_context; } int device_id() const { return m_device_id; } diff --git a/dnn/src/atlas/megcore/computing_context.cpp b/dnn/src/atlas/megcore/computing_context.cpp index cad6e808..fa25034b 100644 --- a/dnn/src/atlas/megcore/computing_context.cpp +++ b/dnn/src/atlas/megcore/computing_context.cpp @@ -18,9 +18,8 @@ using namespace megcore; using namespace megcore::atlas; -AtlasComputingContext::AtlasComputingContext(megcoreDeviceHandle_t dev_handle, - unsigned int flags, - const AtlasContext& ctx) +AtlasComputingContext::AtlasComputingContext( + megcoreDeviceHandle_t dev_handle, unsigned int flags, const AtlasContext& ctx) : ComputingContext(dev_handle, flags), m_own_stream{ctx.stream == nullptr}, m_ctx{ctx} { @@ -38,22 +37,22 @@ AtlasComputingContext::~AtlasComputingContext() { } } -void AtlasComputingContext::memcpy(void* dst, const void* src, - size_t size_in_bytes, - megcoreMemcpyKind_t kind) { +void AtlasComputingContext::memcpy( + void* dst, const void* src, size_t size_in_bytes, megcoreMemcpyKind_t kind) { switch (kind) { case megcoreMemcpyDeviceToHost: - acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, - ACL_MEMCPY_DEVICE_TO_HOST)); + acl_check(aclrtMemcpy( + dst, size_in_bytes, src, size_in_bytes, ACL_MEMCPY_DEVICE_TO_HOST)); break; case megcoreMemcpyHostToDevice: - acl_check(aclrtMemcpy(dst, size_in_bytes, src, size_in_bytes, - ACL_MEMCPY_HOST_TO_DEVICE)); + acl_check(aclrtMemcpy( + dst, size_in_bytes, src, size_in_bytes, ACL_MEMCPY_HOST_TO_DEVICE)); break; case megcoreMemcpyDeviceToDevice: // async d2d is always faster than sync d2d because of SDMA - acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes, - ACL_MEMCPY_DEVICE_TO_DEVICE, m_ctx.stream)); + acl_check(aclrtMemcpyAsync( + dst, size_in_bytes, src, size_in_bytes, ACL_MEMCPY_DEVICE_TO_DEVICE, + m_ctx.stream)); break; default: megdnn_throw("bad atlas memcpy kind"); diff --git a/dnn/src/atlas/megcore/device_context.cpp b/dnn/src/atlas/megcore/device_context.cpp index f3624923..13c4624c 100644 --- a/dnn/src/atlas/megcore/device_context.cpp +++ b/dnn/src/atlas/megcore/device_context.cpp @@ -21,8 +21,8 @@ using namespace megcore; using namespace atlas; -AtlasDeviceContext::AtlasDeviceContext(int device_id, unsigned int flags, - bool global_initialized) +AtlasDeviceContext::AtlasDeviceContext( + int device_id, unsigned int flags, bool global_initialized) : DeviceContext(megcorePlatformAtlas, device_id, flags) { if (!global_initialized) init_status.init(); diff --git a/dnn/src/atlas/megcore/public_api/computing.cpp b/dnn/src/atlas/megcore/public_api/computing.cpp index 4dabe1ae..990d24f1 100644 --- 
a/dnn/src/atlas/megcore/public_api/computing.cpp +++ b/dnn/src/atlas/megcore/public_api/computing.cpp @@ -35,24 +35,24 @@ megcoreStatus_t megcore::createComputingHandleWithAtlasContext( unsigned int flags, const AtlasContext& ctx) { MEGDNN_MARK_USED_VAR(flags); megdnn_assert(flags == 0); - auto content = megdnn::make_unique( - devHandle, flags, ctx); + auto content = + megdnn::make_unique(devHandle, flags, ctx); auto& H = *compHandle; H = new megcoreComputingContext; H->content = std::move(content); return megcoreSuccess; } -megcoreStatus_t megcore::getAtlasContext(megcoreComputingHandle_t handle, - AtlasContext* ctx) { +megcoreStatus_t megcore::getAtlasContext( + megcoreComputingHandle_t handle, AtlasContext* ctx) { auto&& H = handle; megdnn_assert(H); megcoreDeviceHandle_t dev_handle = H->content->dev_handle(); megcorePlatform_t platform; megcoreGetPlatform(dev_handle, &platform); megdnn_assert(platform == megcorePlatformAtlas); - auto context = static_cast( - H->content.get()); + auto context = + static_cast(H->content.get()); *ctx = context->context(); return megcoreSuccess; } diff --git a/dnn/src/atlas/utils.cpp b/dnn/src/atlas/utils.cpp index 037ca373..a929bb3b 100644 --- a/dnn/src/atlas/utils.cpp +++ b/dnn/src/atlas/utils.cpp @@ -18,8 +18,9 @@ using namespace megdnn; using namespace atlas; void atlas::__throw_acl_error__(aclError err, const char* msg) { - auto s = ssprintf("acl return %s(%d) occurred; expr: %s", - megcore::atlas::get_error_str(err), int(err), msg); + auto s = ssprintf( + "acl return %s(%d) occurred; expr: %s", megcore::atlas::get_error_str(err), + int(err), msg); megdnn_throw(s.c_str()); } diff --git a/dnn/src/cambricon/checksum/checksum.mlu.h b/dnn/src/cambricon/checksum/checksum.mlu.h index 6da2266a..b024b20e 100644 --- a/dnn/src/cambricon/checksum/checksum.mlu.h +++ b/dnn/src/cambricon/checksum/checksum.mlu.h @@ -23,5 +23,3 @@ void checksum_kernel_union4(uint32_t* dst, const uint32_t* src, int num_elems); #endif // vim: ft=cpp syntax=cpp.doxygen - - diff --git a/dnn/src/cambricon/checksum/opr_impl.cpp b/dnn/src/cambricon/checksum/opr_impl.cpp index ac3a6bbd..458eceb5 100644 --- a/dnn/src/cambricon/checksum/opr_impl.cpp +++ b/dnn/src/cambricon/checksum/opr_impl.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
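The __throw_acl_error__ helper above, like __throw_cnrt_error__ for cambricon later in this patch, is the back end of a check macro (acl_check / cnrt_check in the calls shown throughout this diff) that wraps each runtime call and turns a failing status into an exception carrying the stringified expression. A standalone sketch of that idiom, using stand-in names because the real ACL/CNRT status types are not reproduced here:

#include <cstddef>
#include <cstdio>
#include <stdexcept>

// Stand-in for the vendor status type; the real code uses aclError (atlas)
// and cnrtRet_t (cambricon).
using fake_status_t = int;
constexpr fake_status_t FAKE_OK = 0;

[[noreturn]] void throw_fake_error(fake_status_t err, const char* msg) {
    char buf[256];
    std::snprintf(buf, sizeof(buf), "runtime call returned %d; expr: %s", err, msg);
    throw std::runtime_error(buf);
}

// Same shape as acl_check / cnrt_check: #expr preserves the original call
// text, so the exception message names the API call that failed.
#define fake_check(expr)                   \
    do {                                   \
        fake_status_t _err = (expr);       \
        if (_err != FAKE_OK)               \
            throw_fake_error(_err, #expr); \
    } while (0)

// A deliberately failing call so the example exercises the error path.
fake_status_t fake_memcpy(void*, const void*, std::size_t) { return 1; }

int main() {
    char dst[8];
    char src[8] = {};
    try {
        fake_check(fake_memcpy(dst, src, sizeof(dst)));
    } catch (const std::runtime_error& e) {
        std::printf("%s\n", e.what());
    }
}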
*/ -#include "src/cambricon/checksum/checksum.mlu.h" #include "src/cambricon/checksum/opr_impl.h" +#include "src/cambricon/checksum/checksum.mlu.h" #include "src/cambricon/utils.h" @@ -20,8 +20,9 @@ using namespace megdnn; using namespace cambricon; namespace { -void bang_c_wrapper(uint32_t* dst, const uint32_t* src, int nr_elems, - cnrtQueue_t queue, cnrtCoreVersion_t core_version) { +void bang_c_wrapper( + uint32_t* dst, const uint32_t* src, int nr_elems, cnrtQueue_t queue, + cnrtCoreVersion_t core_version) { cnrtKernelParamsBuffer_t params; cnrt_check(cnrtGetKernelParamsBuffer(¶ms)); cnrt_check(cnrtKernelParamsBufferAddParam(params, &dst, sizeof(uint32_t*))); @@ -33,16 +34,16 @@ void bang_c_wrapper(uint32_t* dst, const uint32_t* src, int nr_elems, dim.y = 1; dim.z = 1; cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION4; - cnrt_check(cnrtInvokeKernel_V2((void*)&checksum_kernel_union4, dim, - params, c, queue)); + cnrt_check(cnrtInvokeKernel_V2( + (void*)&checksum_kernel_union4, dim, params, c, queue)); } else if (core_version == CNRT_MLU220) { cnrtDim3_t dim; dim.x = 4; dim.y = 1; dim.z = 1; cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION1; - cnrt_check(cnrtInvokeKernel_V2((void*)&checksum_kernel_union1, dim, - params, c, queue)); + cnrt_check(cnrtInvokeKernel_V2( + (void*)&checksum_kernel_union1, dim, params, c, queue)); } after_kernel_launch(); cnrt_check(cnrtDestroyKernelParamsBuffer(params)); @@ -54,32 +55,31 @@ size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& /* data * return ws_size; } -ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data, - _megdnn_workspace workspace) { +ChecksumForward::Result ChecksumForwardImpl::exec( + _megdnn_tensor_in data, _megdnn_workspace workspace) { Result result; memset(&result, 0, sizeof(result)); check_exec(data.layout, workspace.size); auto queue = cnrt_queue(handle()); auto ptr = static_cast(data.raw_ptr); - size_t size_all = data.layout.shape[0], - size_ints = size_all / sizeof(uint32_t); + size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); auto last_val_size = std::min(size_all, 4); - cnrt_check(cnrtMemcpyAsync(&result.last_val, ptr + size_all - last_val_size, - last_val_size, queue, - CNRT_MEM_TRANS_DIR_DEV2HOST)); + cnrt_check(cnrtMemcpyAsync( + &result.last_val, ptr + size_all - last_val_size, last_val_size, queue, + CNRT_MEM_TRANS_DIR_DEV2HOST)); if (size_ints) { auto&& device_info = current_device_info(); - bang_c_wrapper(reinterpret_cast(workspace.raw_ptr), - static_cast(data.raw_ptr), size_ints, queue, - device_info.core_version); - cnrt_check(cnrtMemcpyAsync(&result.checksum, workspace.raw_ptr, - sizeof(result.checksum), queue, - CNRT_MEM_TRANS_DIR_DEV2HOST)); + bang_c_wrapper( + reinterpret_cast(workspace.raw_ptr), + static_cast(data.raw_ptr), size_ints, queue, + device_info.core_version); + cnrt_check(cnrtMemcpyAsync( + &result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue, + CNRT_MEM_TRANS_DIR_DEV2HOST)); } cnrt_check(cnrtSyncQueue(queue)); return result; } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cambricon/checksum/opr_impl.h b/dnn/src/cambricon/checksum/opr_impl.h index de410fbd..45ba190a 100644 --- a/dnn/src/cambricon/checksum/opr_impl.h +++ b/dnn/src/cambricon/checksum/opr_impl.h @@ -32,5 +32,3 @@ public: } // namespace megdnn // vim: syntax=cpp.doxygen - - diff --git a/dnn/src/cambricon/handle.cpp b/dnn/src/cambricon/handle.cpp index 5cc8b736..69b973c9 100644 --- a/dnn/src/cambricon/handle.cpp +++ b/dnn/src/cambricon/handle.cpp @@ -15,8 +15,8 @@ 
#include "src/cambricon/handle.h" #include "src/cambricon/utils.h" -#include "src/cambricon/checksum/opr_impl.h" #include +#include "src/cambricon/checksum/opr_impl.h" namespace megdnn { namespace cambricon { @@ -61,8 +61,7 @@ MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) } // namespace cambricon } // namespace megdnn -MEGDNN_VERSION_SYMBOL3(CNRT, CNRT_MAJOR_VERSION, CNRT_MINOR_VERSION, - CNRT_PATCH_VERSION); +MEGDNN_VERSION_SYMBOL3( + CNRT, CNRT_MAJOR_VERSION, CNRT_MINOR_VERSION, CNRT_PATCH_VERSION); // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cambricon/handle.h b/dnn/src/cambricon/handle.h index a57f6661..91fa6188 100644 --- a/dnn/src/cambricon/handle.h +++ b/dnn/src/cambricon/handle.h @@ -62,4 +62,3 @@ private: } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cambricon/megcore/cambricon_computing_context.cpp b/dnn/src/cambricon/megcore/cambricon_computing_context.cpp index 2deb23e2..73ee9df8 100644 --- a/dnn/src/cambricon/megcore/cambricon_computing_context.cpp +++ b/dnn/src/cambricon/megcore/cambricon_computing_context.cpp @@ -38,9 +38,8 @@ CambriconComputingContext::~CambriconComputingContext() { } } -void CambriconComputingContext::memcpy(void* dst, const void* src, - size_t size_in_bytes, - megcoreMemcpyKind_t kind) { +void CambriconComputingContext::memcpy( + void* dst, const void* src, size_t size_in_bytes, megcoreMemcpyKind_t kind) { cnrtMemTransDir_t dir; switch (kind) { case megcoreMemcpyDeviceToHost: @@ -60,12 +59,11 @@ void CambriconComputingContext::memcpy(void* dst, const void* src, cnrt_check(cnrtMemcpy(dst, const_cast(src), size_in_bytes, dir)); return; } - cnrt_check(cnrtMemcpyAsync(dst, const_cast(src), size_in_bytes, - context_.queue, dir)); + cnrt_check(cnrtMemcpyAsync( + dst, const_cast(src), size_in_bytes, context_.queue, dir)); } -void CambriconComputingContext::memset(void* dst, int value, - size_t size_in_bytes) { +void CambriconComputingContext::memset(void* dst, int value, size_t size_in_bytes) { cnrt_check(cnrtSyncQueue(context_.queue)); cnrt_check(cnrtMemset(dst, value, size_in_bytes)); } @@ -75,4 +73,3 @@ void CambriconComputingContext::synchronize() { } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cambricon/megcore/cambricon_device_context.cpp b/dnn/src/cambricon/megcore/cambricon_device_context.cpp index 8a2a298e..629f1549 100644 --- a/dnn/src/cambricon/megcore/cambricon_device_context.cpp +++ b/dnn/src/cambricon/megcore/cambricon_device_context.cpp @@ -17,7 +17,7 @@ #include "src/cambricon/megcore/cambricon_device_context.hpp" #define STR_HELPER(x) #x -#define STR(x) STR_HELPER(x) +#define STR(x) STR_HELPER(x) #define CNRT_VERSION_STR \ STR(CNRT_MAJOR_VERSION) \ @@ -31,23 +31,21 @@ using namespace megcore; using namespace cambricon; -CambriconDeviceContext::CambriconDeviceContext(int device_id, - unsigned int flags, - bool global_initialized) +CambriconDeviceContext::CambriconDeviceContext( + int device_id, unsigned int flags, bool global_initialized) : DeviceContext(megcorePlatformCambricon, device_id, flags) { if (!global_initialized) init_status.init(); unsigned int version; cnrt_check(cnrtGetVersion(&version)); - megdnn_assert(version == CNRT_VERSION, - "megcore compiled with cnrt %d, get %d at runtime", - CNRT_VERSION, version); + megdnn_assert( + version == CNRT_VERSION, "megcore compiled with cnrt %d, get %d at runtime", + CNRT_VERSION, version); unsigned int dev_num; cnrt_check(cnrtGetDeviceCount(&dev_num)); MEGDNN_MARK_USED_VAR(dev_num); // check validity of device_id - megdnn_assert(device_id >= 0 && - 
static_cast(device_id) < dev_num); + megdnn_assert(device_id >= 0 && static_cast(device_id) < dev_num); cnrt_check(cnrtGetDeviceInfo(&device_info, device_id)); } @@ -77,4 +75,3 @@ void CambriconDeviceContext::free(void* ptr) { CambriconDeviceContext::InitStatus CambriconDeviceContext::init_status; // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cambricon/megcore/public_api/computing.cpp b/dnn/src/cambricon/megcore/public_api/computing.cpp index 363ebb72..edb9e033 100644 --- a/dnn/src/cambricon/megcore/public_api/computing.cpp +++ b/dnn/src/cambricon/megcore/public_api/computing.cpp @@ -40,8 +40,8 @@ megcoreStatus_t megcore::createComputingHandleWithCambriconContext( return megcoreSuccess; } -megcoreStatus_t megcore::getCambriconContext(megcoreComputingHandle_t handle, - CambriconContext* ctx) { +megcoreStatus_t megcore::getCambriconContext( + megcoreComputingHandle_t handle, CambriconContext* ctx) { auto&& H = handle; megdnn_assert(H); megcoreDeviceHandle_t dev_handle = H->content->dev_handle(); @@ -55,5 +55,3 @@ megcoreStatus_t megcore::getCambriconContext(megcoreComputingHandle_t handle, } // vim: syntax=cpp.doxygen - - diff --git a/dnn/src/cambricon/utils.cpp b/dnn/src/cambricon/utils.cpp index 02c6d3a5..af174c92 100644 --- a/dnn/src/cambricon/utils.cpp +++ b/dnn/src/cambricon/utils.cpp @@ -33,8 +33,9 @@ DeviceInfoRecord device_info_rec[MAX_NR_DEVICE]; } // namespace void cambricon::__throw_cnrt_error__(cnrtRet_t err, const char* msg) { - auto s = ssprintf("cnrt return %s(%d) occurred; expr: %s", - cnrtGetErrorStr(err), int(err), msg); + auto s = ssprintf( + "cnrt return %s(%d) occurred; expr: %s", cnrtGetErrorStr(err), int(err), + msg); megdnn_throw(s.c_str()); } @@ -72,4 +73,3 @@ cnrtDeviceInfo_t cambricon::current_device_info() { } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cambricon/utils.h b/dnn/src/cambricon/utils.h index 1e3f5804..b5b08f77 100644 --- a/dnn/src/cambricon/utils.h +++ b/dnn/src/cambricon/utils.h @@ -37,4 +37,3 @@ cnrtDeviceInfo_t current_device_info(); } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cambricon/utils.mlu.h b/dnn/src/cambricon/utils.mlu.h index 1355c65b..326cd56d 100644 --- a/dnn/src/cambricon/utils.mlu.h +++ b/dnn/src/cambricon/utils.mlu.h @@ -39,4 +39,3 @@ MEGDNN_NORETURN void __throw_cnrt_error__(cnrtRet_t err, const char* msg); } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/adaptive_pooling.cpp b/dnn/src/common/adaptive_pooling.cpp index 3f028909..1a9f0d29 100644 --- a/dnn/src/common/adaptive_pooling.cpp +++ b/dnn/src/common/adaptive_pooling.cpp @@ -18,8 +18,7 @@ namespace megdnn { param::Pooling AdaptivePoolingBase::deduce_pooling_param( const TensorLayout& src, const TensorLayout& dst) { megdnn_assert(param().format == param::AdaptivePooling::Format::NCHW); - size_t IH = src.shape[2], IW = src.shape[3], OH = dst.shape[2], - OW = dst.shape[3]; + size_t IH = src.shape[2], IW = src.shape[3], OH = dst.shape[2], OW = dst.shape[3]; param::Pooling ret; ret.mode = param().mode; diff --git a/dnn/src/common/add_update.cpp b/dnn/src/common/add_update.cpp index b0eb7b3d..d9bbcd53 100644 --- a/dnn/src/common/add_update.cpp +++ b/dnn/src/common/add_update.cpp @@ -16,13 +16,13 @@ namespace megdnn { -void AddUpdateForward::check_exec(const TensorLayout& dst, - const TensorLayout& delta) { +void AddUpdateForward::check_exec(const TensorLayout& dst, const TensorLayout& delta) { // delta can not be broadcasted to dst if dst.total_nr_elems() < // delta.total_nr_elems() - megdnn_assert(dst.dtype == 
delta.dtype && - dst.total_nr_elems() >= delta.total_nr_elems() && - dst.is_non_overlapping_strong()); + megdnn_assert( + dst.dtype == delta.dtype && + dst.total_nr_elems() >= delta.total_nr_elems() && + dst.is_non_overlapping_strong()); if (dst.dtype.category() == DTypeCategory::INT) { auto check_fv = [](float fv) { int iv = fv; diff --git a/dnn/src/common/add_update_helper.h b/dnn/src/common/add_update_helper.h index 0f6af07b..4b3785f3 100644 --- a/dnn/src/common/add_update_helper.h +++ b/dnn/src/common/add_update_helper.h @@ -19,8 +19,7 @@ class AddUpdateForwardHelper : public AddUpdateForward { using AddUpdateForward::AddUpdateForward; protected: - ElemwiseOpParamN<2> make_param(_megdnn_tensor_inout dst, - _megdnn_tensor_in delta); + ElemwiseOpParamN<2> make_param(_megdnn_tensor_inout dst, _megdnn_tensor_in delta); }; } // namespace megdnn diff --git a/dnn/src/common/algo_base.cpp b/dnn/src/common/algo_base.cpp index c71b10f6..9b412208 100644 --- a/dnn/src/common/algo_base.cpp +++ b/dnn/src/common/algo_base.cpp @@ -15,12 +15,9 @@ using namespace megdnn; -#define FOREACH_ALGO_ATTRIBUTE(cb) \ - cb(DEFAULT) \ - cb(REPRODUCIBLE) \ - cb(NAIVE) \ - cb(USABLE_DEPEND_ON_SHAPE) \ - cb(ACCURACY_DEPEND_ON_BATCH) +#define FOREACH_ALGO_ATTRIBUTE(cb) \ + cb(DEFAULT) cb(REPRODUCIBLE) cb(NAIVE) cb(USABLE_DEPEND_ON_SHAPE) \ + cb(ACCURACY_DEPEND_ON_BATCH) namespace { inline const char* attr_str(const AlgoAttribute& attr) { @@ -36,7 +33,7 @@ inline const char* attr_str(const AlgoAttribute& attr) { std::string Algorithm::attribute_str(const Attribute& attr) { std::string ret; uint32_t attr_val = static_cast(attr); - while(attr_val) { + while (attr_val) { uint32_t mask = ~(attr_val & (attr_val - 1)); Attribute sub_attr = static_cast(mask & attr_val); if (!ret.empty()) { @@ -59,16 +56,17 @@ bool Algorithm::contain_attribute_any(const Attribute& attr) const { return static_cast(attribute() & attr); } -void Algorithm::check_attribute(const Attribute& positive_attr, - const Attribute& negative_attr) const { - megdnn_assert(contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr), - "require algorithm with attribute(%s) and without " - "attribute(%s), but get" - "algorithm(%s) with attribute(%s) ", - Algorithm::attribute_str(positive_attr).c_str(), - Algorithm::attribute_str(negative_attr).c_str(), name(), - Algorithm::attribute_str(attribute()).c_str()); +void Algorithm::check_attribute( + const Attribute& positive_attr, const Attribute& negative_attr) const { + megdnn_assert( + contain_attribute_all(positive_attr) && + !contain_attribute_any(negative_attr), + "require algorithm with attribute(%s) and without " + "attribute(%s), but get" + "algorithm(%s) with attribute(%s) ", + Algorithm::attribute_str(positive_attr).c_str(), + Algorithm::attribute_str(negative_attr).c_str(), name(), + Algorithm::attribute_str(attribute()).c_str()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/algo_base.h b/dnn/src/common/algo_base.h index 91c3c178..21aac780 100644 --- a/dnn/src/common/algo_base.h +++ b/dnn/src/common/algo_base.h @@ -21,30 +21,28 @@ namespace megdnn { -#define MEGDNN_DECL_ALGO_TYPE(_type) \ - uint32_t type() const override { \ - return static_cast::type>( \ - AlgoType::_type); \ +#define MEGDNN_DECL_ALGO_TYPE(_type) \ + uint32_t type() const override { \ + return static_cast::type>(AlgoType::_type); \ } -#define MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \ - static fallback::_opr::AlgoBase* get_algo_from_desc( \ - const AlgorithmDesc& desc) +#define 
MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(_opr) \ + static fallback::_opr::AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc) -#define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \ - fallback::_opr::AlgoBase* _opr::get_algo_from_desc( \ - const AlgorithmDesc& desc) { \ - megdnn_assert(algo_pack().all_algos_map().find(desc) != \ - algo_pack().all_algos_map().end()); \ - return algo_pack().all_algos_map().at(desc); \ +#define MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(_opr) \ + fallback::_opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ + megdnn_assert( \ + algo_pack().all_algos_map().find(desc) != \ + algo_pack().all_algos_map().end()); \ + return algo_pack().all_algos_map().at(desc); \ } -#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ - _opr::Algorithm* _opr::get_algorithm_from_desc( \ - const AlgorithmDesc& desc) { \ - megdnn_assert(algo_pack().all_algos_map().find(desc) != \ - algo_pack().all_algos_map().end()); \ - return algo_pack().all_algos_map().at(desc); \ +#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ + _opr::Algorithm* _opr::get_algorithm_from_desc(const AlgorithmDesc& desc) { \ + megdnn_assert( \ + algo_pack().all_algos_map().find(desc) != \ + algo_pack().all_algos_map().end()); \ + return algo_pack().all_algos_map().at(desc); \ } #define MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) \ @@ -61,8 +59,7 @@ protected: public: //! construct the algo which described by desc, and return the instance - AlgoBase* construct_and_get_algo( - const detail::Algorithm::Info::Desc& desc) { + AlgoBase* construct_and_get_algo(const detail::Algorithm::Info::Desc& desc) { auto iter = m_all_algos_map.find(desc); if (iter != m_all_algos_map.end()) { return m_all_algos_map.at(desc); @@ -80,14 +77,12 @@ public: m_refhold.clear(); } - const typename AlgoBase::Mapper& all_algos_map() const { - return m_all_algos_map; - } + const typename AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; template -inline typename std::enable_if::type -set_sub_execution_policy(const Opr*, std::tuple&) {} +inline typename std::enable_if::type set_sub_execution_policy( + const Opr*, std::tuple&) {} template inline typename std::enable_if < @@ -101,8 +96,7 @@ template void set_execution_policy(const Opr* opr, SubOpr... sub_oprs) { if (opr->execution_policy().algo.valid() && !opr->execution_policy().sub_policy.empty()) { - megdnn_assert(opr->execution_policy().sub_policy.size() == - sizeof...(sub_oprs)); + megdnn_assert(opr->execution_policy().sub_policy.size() == sizeof...(sub_oprs)); auto&& sub = std::make_tuple(sub_oprs...); set_sub_execution_policy<0, Opr, SubOpr...>(opr, sub); } @@ -113,8 +107,7 @@ void set_execution_policy(const Opr* opr, SubOpr... sub_oprs) { namespace std { template <> struct hash { - std::size_t operator()( - const megdnn::detail::Algorithm::Info::Desc& desc) const { + std::size_t operator()(const megdnn::detail::Algorithm::Info::Desc& desc) const { return megdnn::hash_combine( megdnn::hash_combine( std::hash()(desc.name), diff --git a/dnn/src/common/algo_chooser.h b/dnn/src/common/algo_chooser.h index 12c3e6e2..6595f66a 100644 --- a/dnn/src/common/algo_chooser.h +++ b/dnn/src/common/algo_chooser.h @@ -26,16 +26,14 @@ namespace megdnn { template size_t get_dnn_workspace(Opr* opr, Args&&... 
args) { TensorLayoutArray layouts{{args...}}; - HeuristicCache::Key key{opr->handle(), opr->get_opr_type(), - layouts.data(), layouts.size(), &opr->param(), - sizeof(opr->param())}; + HeuristicCache::Key key{opr->handle(), opr->get_opr_type(), layouts.data(), + layouts.size(), &opr->param(), sizeof(opr->param())}; auto rst = HeuristicCache::instance().get(key); if (rst.policy.algo.valid()) { return rst.workspace; } - typename Opr::AlgoBase::SizeArgs size_args(opr, - std::forward(args)...); + typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward(args)...); return get_algorithm(opr, std::forward(args)...) ->get_workspace_in_bytes(size_args); } @@ -51,22 +49,21 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { ret = set; } else { TensorLayoutArray layouts{{args...}}; - HeuristicCache::Key key{opr->handle(), opr->get_opr_type(), - layouts.data(), layouts.size(), &opr->param(), - sizeof(opr->param())}; + HeuristicCache::Key key{opr->handle(), opr->get_opr_type(), + layouts.data(), layouts.size(), + &opr->param(), sizeof(opr->param())}; auto rst = HeuristicCache::instance().get(key); if (rst.policy.algo.valid()) { ret = rst.policy.algo; } else { ret = opr->get_algorithm_info_heuristic( std::forward(args)..., - std::numeric_limits::max(), - AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT) + std::numeric_limits::max(), AlgoAttribute::DEFAULT, + AlgoAttribute::DEFAULT) .desc; } } - return static_cast( - opr->get_algorithm_from_desc(ret)); + return static_cast(opr->get_algorithm_from_desc(ret)); } /*! @@ -79,11 +76,9 @@ typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { if (set.valid()) { return opr->algo_pack().construct_and_get_algo(set); } else { - return static_cast( - opr->get_algorithm_heuristic(std::forward(args)..., - std::numeric_limits::max(), - AlgoAttribute::DEFAULT, - AlgoAttribute::DEFAULT)); + return static_cast(opr->get_algorithm_heuristic( + std::forward(args)..., std::numeric_limits::max(), + AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT)); } } @@ -106,8 +101,7 @@ template std::vector get_all_algorithms_safe( const typename Opr::AlgoBase::SizeArgs& args) { auto ret_safe = get_all_algorithms(args); - megdnn_assert(!ret_safe.empty(), "no algorithm for %s", - args.to_string().c_str()); + megdnn_assert(!ret_safe.empty(), "no algorithm for %s", args.to_string().c_str()); return ret_safe; } @@ -130,24 +124,23 @@ typename Opr::Algorithm* get_algo_match_attribute( template typename Opr::Algorithm* get_algo_match_attribute( const std::vector& algos, - const typename Opr::AlgoBase::SizeArgs& args, - size_t workspace_limit_in_bytes, const char* name, + const typename Opr::AlgoBase::SizeArgs& args, size_t workspace_limit_in_bytes, + const char* name, const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { size_t min_workspace_limit_in_bytes = std::numeric_limits::max(); bool available_but_limited_by_workspace = false; bool available_but_attribute_mismatch = false; for (auto i : algos) { - if (i->is_available_attribute(args, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + if (i->is_available_attribute( + args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return i; } if (i->is_available_attribute(args, positive_attr, negative_attr)) { if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) { available_but_limited_by_workspace = true; - min_workspace_limit_in_bytes = - std::min(min_workspace_limit_in_bytes, - 
i->get_workspace_in_bytes(args)); + min_workspace_limit_in_bytes = std::min( + min_workspace_limit_in_bytes, i->get_workspace_in_bytes(args)); } } if (i->is_available(args)) { @@ -159,14 +152,14 @@ typename Opr::Algorithm* get_algo_match_attribute( MEGDNN_MARK_USED_VAR(name); if (available_but_limited_by_workspace) { - megdnn_throw( - ssprintf("no %s algorithm without attribute(%s) with " - "attribute(%s) : %s workspace limit %zu is " - "less than mini workspace limit %zu", - name, Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes, - min_workspace_limit_in_bytes)); + megdnn_throw(ssprintf( + "no %s algorithm without attribute(%s) with " + "attribute(%s) : %s workspace limit %zu is " + "less than mini workspace limit %zu", + name, Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), + args.to_string().c_str(), workspace_limit_in_bytes, + min_workspace_limit_in_bytes)); } else if (available_but_attribute_mismatch) { megdnn_throw(ssprintf( "no %s algorithm without attribute(%s) with attribute(%s)", name, diff --git a/dnn/src/common/argmxx/base_impl.cpp b/dnn/src/common/argmxx/base_impl.cpp index 11ca35ab..f717692f 100644 --- a/dnn/src/common/argmxx/base_impl.cpp +++ b/dnn/src/common/argmxx/base_impl.cpp @@ -14,9 +14,7 @@ namespace megdnn { -void ArgmxxBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ +void ArgmxxBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst); }; @@ -25,8 +23,8 @@ void ArgmxxBase::check_layout_fwd(const TensorLayout &src, megdnn_assert_contiguous(dst); megdnn_assert(src.ndim > 0_z, "%s", errmsg().c_str()); megdnn_assert(src.ndim == dst.ndim, "%s", errmsg().c_str()); - megdnn_assert(param().axis < static_cast(src.ndim), "%s", - errmsg().c_str()); + megdnn_assert( + param().axis < static_cast(src.ndim), "%s", errmsg().c_str()); for (size_t i = 0; i < src.ndim; ++i) { if (i != static_cast(param().axis)) { megdnn_assert_eq_size_t(src.shape[i], dst.shape[i]); @@ -37,42 +35,34 @@ void ArgmxxBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(dst.dtype == dtype::Int32()); } -void ArgmaxForward::deduce_layout(const TensorLayout &src, - TensorLayout &dst) -{ +void ArgmaxForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst = src; dst.shape[param().axis] = 1; dst.dtype = dtype::Int32(); dst.init_contiguous_stride(); } -void ArgmaxForward::check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void ArgmaxForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void ArgminForward::deduce_layout(const TensorLayout &src, - TensorLayout &dst) -{ +void ArgminForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst = src; dst.shape[param().axis] = 1; dst.dtype = dtype::Int32(); dst.init_contiguous_stride(); } -void ArgminForward::check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void ArgminForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = 
get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/argmxx_helper.h b/dnn/src/common/argmxx_helper.h index 64e7cc35..bcd7a2f3 100644 --- a/dnn/src/common/argmxx_helper.h +++ b/dnn/src/common/argmxx_helper.h @@ -24,66 +24,60 @@ struct ArgmxxOp { struct wtype { stype_ key; dt_int32 val; - MEGDNN_HOST MEGDNN_DEVICE wtype() - {} - MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val): - key(key), val(val) - {} - MEGDNN_HOST MEGDNN_DEVICE wtype(wtype &rhs): - key(rhs.key), - val(rhs.val) - {} - MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype &rhs): - key(rhs.key), - val(rhs.val) - {} - MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype &rhs): - key(rhs.key), - val(rhs.val) - {} - MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype &rhs): - key(rhs.key), - val(rhs.val) - {} - MEGDNN_HOST MEGDNN_DEVICE volatile wtype &operator=(const wtype &rhs) volatile - { + MEGDNN_HOST MEGDNN_DEVICE wtype() {} + MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val) + : key(key), val(val) {} + MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {} + MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs) + : key(rhs.key), val(rhs.val) {} + MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs) + : key(rhs.key), val(rhs.val) {} + MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs) + : key(rhs.key), val(rhs.val) {} + MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(const wtype& rhs) volatile { this->key = rhs.key; this->val = rhs.val; return *this; } }; MEGDNN_HOST MEGDNN_DEVICE - ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C): - src(src), dst(dst), A(A), B(B), C(C), - INIT(wtype(is_max ? DTypeTrait::min() : - DTypeTrait::max(), 0)) - { - } - MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) - { + ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C) + : src(src), + dst(dst), + A(A), + B(B), + C(C), + INIT(wtype( + is_max ? 
DTypeTrait::min() : DTypeTrait::max(), + 0)) {} + MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { wtype res; res.key = src[idx]; res.val = idx / C % B; return res; } - MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) - { + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val.val; } - static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) - { + static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { if (is_max) { - if (lhs.key > rhs.key) return lhs; else return rhs; + if (lhs.key > rhs.key) + return lhs; + else + return rhs; } else { - if (lhs.key < rhs.key) return lhs; else return rhs; + if (lhs.key < rhs.key) + return lhs; + else + return rhs; } } - stype_ *src; - dt_int32 *dst; + stype_* src; + dt_int32* dst; uint32_t A, B, C; const wtype INIT; }; -} // namespace argmxx -} // namespace megdnn +} // namespace argmxx +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/argsort.cpp b/dnn/src/common/argsort.cpp index 5ad5c7e6..583aac60 100644 --- a/dnn/src/common/argsort.cpp +++ b/dnn/src/common/argsort.cpp @@ -15,19 +15,19 @@ using namespace megdnn; -void ArgsortForward::deduce_layout(const TensorLayout& src, TensorLayout& dst, - TensorLayout& indices) { - megdnn_assert(src.ndim == 2 && src.is_contiguous(), - "invalid src layout: %s", src.to_string().c_str()); +void ArgsortForward::deduce_layout( + const TensorLayout& src, TensorLayout& dst, TensorLayout& indices) { + megdnn_assert( + src.ndim == 2 && src.is_contiguous(), "invalid src layout: %s", + src.to_string().c_str()); dst = src; indices = src; indices.dtype = dtype::Int32(); } -void ArgsortForward::check_exec(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& indices, - size_t workspace_in_bytes) { +void ArgsortForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& indices, + size_t workspace_in_bytes) { auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + megdnn_layout_msg(indices); @@ -42,26 +42,21 @@ void ArgsortForward::check_exec(const TensorLayout& src, megdnn_assert(src.dtype == dst.dtype); megdnn_assert(indices.dtype == dtype::Int32()); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, dst, indices); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst, indices); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void ArgsortBackward::check_exec(const TensorLayout& diff, - const TensorLayout& indices, - const TensorLayout& grad, - size_t workspace_in_bytes) { - megdnn_assert(diff.eq_shape(indices) && diff.dtype == grad.dtype && - indices.dtype == dtype::Int32{} && - diff.is_contiguous() && indices.is_contiguous() && - grad.is_contiguous() && diff.ndim == 2 && - grad.ndim == 2 && diff[0] == grad[0] && - diff[1] <= grad[1], - "invalid layouts: diff=%s indices=%s grad=%s", - diff.to_string().c_str(), indices.to_string().c_str(), - grad.to_string().c_str()); - auto required_workspace_in_bytes = - get_workspace_in_bytes(diff, indices, grad); +void ArgsortBackward::check_exec( + const TensorLayout& diff, const TensorLayout& indices, const TensorLayout& grad, + size_t workspace_in_bytes) { + megdnn_assert( + diff.eq_shape(indices) && diff.dtype == grad.dtype && + indices.dtype == dtype::Int32{} && diff.is_contiguous() && + indices.is_contiguous() && grad.is_contiguous() && diff.ndim == 2 && + grad.ndim == 2 && diff[0] == grad[0] && diff[1] <= grad[1], + "invalid layouts: diff=%s indices=%s 
grad=%s", diff.to_string().c_str(), + indices.to_string().c_str(), grad.to_string().c_str()); + auto required_workspace_in_bytes = get_workspace_in_bytes(diff, indices, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } diff --git a/dnn/src/common/asm_common_defs.h b/dnn/src/common/asm_common_defs.h index d8557116..3ab86173 100644 --- a/dnn/src/common/asm_common_defs.h +++ b/dnn/src/common/asm_common_defs.h @@ -11,9 +11,9 @@ #pragma once #if defined(__WIN32__) || defined(__APPLE__) -# define cdecl(s) _##s +#define cdecl(s) _##s #else -# define cdecl(s) s +#define cdecl(s) s #endif #if !defined(__APPLE__) @@ -23,7 +23,5 @@ #endif #if defined(__linux__) && defined(__ELF__) && (defined(__arm__) || defined(__aarch64__)) -.pushsection .note.GNU-stack,"",%progbits -.popsection +.pushsection.note.GNU - stack, "", % progbits.popsection #endif - diff --git a/dnn/src/common/basic_types.cpp b/dnn/src/common/basic_types.cpp index d5b93e58..9f3dca61 100644 --- a/dnn/src/common/basic_types.cpp +++ b/dnn/src/common/basic_types.cpp @@ -104,8 +104,9 @@ LogHandler g_log_handler = nullptr; } // anonymous namespace #if MEGDNN_ENABLE_LOGGING -void megdnn::__log__(LogLevel level, const char* file, const char* func, - int line, const char* fmt, ...) { +void megdnn::__log__( + LogLevel level, const char* file, const char* func, int line, const char* fmt, + ...) { if (!g_log_handler) return; va_list ap; @@ -124,10 +125,11 @@ LogHandler megdnn::set_log_handler(LogHandler handler) { /* ===================== TensorShape ===================== */ TensorShape::TensorShape(const SmallVector& init_shape) { - megdnn_assert(init_shape.size() <= MAX_NDIM, - "Illegal to construct a TensorShape with " - "more than MAX_NDIM(%zu) axes; init_shape is %s", - MAX_NDIM, vec2str(init_shape).c_str()); + megdnn_assert( + init_shape.size() <= MAX_NDIM, + "Illegal to construct a TensorShape with " + "more than MAX_NDIM(%zu) axes; init_shape is %s", + MAX_NDIM, vec2str(init_shape).c_str()); ndim = init_shape.size(); memcpy(this->shape, init_shape.data(), sizeof(size_t) * ndim); } @@ -195,8 +197,7 @@ bool TensorShape::is_empty() const { /* ===================== TensorLayout ===================== */ TensorLayout::TensorLayout() = default; -TensorLayout::TensorLayout(DType dtype_) - : dtype{dtype_}, format{Format(dtype)} {} +TensorLayout::TensorLayout(DType dtype_) : dtype{dtype_}, format{Format(dtype)} {} TensorLayout::TensorLayout(DType dtype_, Format format_) : dtype{dtype_}, format{format_} {} @@ -204,19 +205,18 @@ TensorLayout::TensorLayout(DType dtype_, Format format_) TensorLayout::TensorLayout(const TensorShape& shape, DType dtype) : TensorLayout(shape, dtype, Format(dtype)) {} -TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, - TensorFormat format_) +TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, TensorFormat format_) : TensorShape(shape), dtype{dtype}, format{format_} { init_contiguous_stride(); } -TensorLayout::TensorLayout(const TensorShape& shape, - const std::vector& stride, DType dtype) +TensorLayout::TensorLayout( + const TensorShape& shape, const std::vector& stride, DType dtype) : TensorLayout(shape, stride, dtype, Format(dtype)) {} -TensorLayout::TensorLayout(const TensorShape& shape, - const std::vector& stride, DType dtype, - TensorFormat format_) +TensorLayout::TensorLayout( + const TensorShape& shape, const std::vector& stride, DType dtype, + TensorFormat format_) : TensorShape(shape), dtype{dtype}, format{format_} { megdnn_assert_eq_size_t(stride.size(), 
ndim); for (size_t i = 0; i < shape.ndim; ++i) @@ -232,8 +232,8 @@ size_t TensorLayout::init_contiguous_stride(const TensorShape& shape) { return init_contiguous_stride(); } -size_t TensorLayout::init_contiguous_stride(const TensorShape& shape, - TensorFormat format_) { +size_t TensorLayout::init_contiguous_stride( + const TensorShape& shape, TensorFormat format_) { this->TensorShape::operator=(shape); this->format = format_; return init_contiguous_stride(); @@ -268,11 +268,11 @@ void TensorLayout::remove_axis_inplace(size_t axis) { } } -void TensorLayout::add_axis_inplace(size_t axis, size_t shape, - ptrdiff_t stride) { - megdnn_assert(ndim + 1 <= MAX_NDIM && axis <= ndim && shape, - "can not add axis at %zu (current ndim %zu, MAX_NDIM %zu)", - axis, ndim, MAX_NDIM); +void TensorLayout::add_axis_inplace(size_t axis, size_t shape, ptrdiff_t stride) { + megdnn_assert( + ndim + 1 <= MAX_NDIM && axis <= ndim && shape, + "can not add axis at %zu (current ndim %zu, MAX_NDIM %zu)", axis, ndim, + MAX_NDIM); ndim++; for (size_t i = ndim - 1; i > axis; i--) { this->shape[i] = this->shape[i - 1]; @@ -307,8 +307,8 @@ bool TensorLayout::is_abs_monotonous_allow_brdcst() const { return false; if (ndim == 1) return true; - ptrdiff_t last = std::abs(stride[ndim - 1]) * - static_cast(shape[ndim - 1]); + ptrdiff_t last = + std::abs(stride[ndim - 1]) * static_cast(shape[ndim - 1]); for (int i = ndim - 2; i >= 0; --i) { if (!stride[i] || shape[i] == 1) continue; @@ -375,13 +375,13 @@ bool TensorLayout::is_non_overlapping_strong() const { } bool TensorLayout::eq_layout(const TensorLayout& rhs) const { - megdnn_assert(dtype == rhs.dtype, - "could not compare layout on different dtypes: %s vs %s", - dtype.name(), rhs.dtype.name()); + megdnn_assert( + dtype == rhs.dtype, + "could not compare layout on different dtypes: %s vs %s", dtype.name(), + rhs.dtype.name()); MEGDNN_STATIC_ASSERT(MAX_NDIM == 7, "please update the code"); - auto ax = [](size_t shape0, size_t shape1, ptrdiff_t stride0, - ptrdiff_t stride1) { + auto ax = [](size_t shape0, size_t shape1, ptrdiff_t stride0, ptrdiff_t stride1) { return (shape0 == shape1) & ((shape0 == 1) | (stride0 == stride1)); }; if (ndim == rhs.ndim) { @@ -447,8 +447,9 @@ size_t TensorLayout::access_bytes() const { } TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { - megdnn_throw_if(!ndim || !tshape.ndim, tensor_reshape_error, - "broadcast involves empty tensor"); + megdnn_throw_if( + !ndim || !tshape.ndim, tensor_reshape_error, + "broadcast involves empty tensor"); if (is_scalar()) { TensorLayout result{dtype, format}; @@ -460,10 +461,12 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { return result; } - megdnn_throw_if(tshape.ndim < ndim, tensor_reshape_error, - ssprintf("dimension for broadcast less than " - "dst_shape: src_shape=%s dst_shape=%s", - to_string().c_str(), tshape.to_string().c_str())); + megdnn_throw_if( + tshape.ndim < ndim, tensor_reshape_error, + ssprintf( + "dimension for broadcast less than " + "dst_shape: src_shape=%s dst_shape=%s", + to_string().c_str(), tshape.to_string().c_str())); TensorLayout result{dtype, format}; for (size_t i = 0; i < tshape.ndim; ++i) { int target_idx = tshape.ndim - i - 1; @@ -473,9 +476,10 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { if (tshape.shape[target_idx] != cur_shape) { megdnn_throw_if( cur_shape != 1 && cur_stride != 0, tensor_reshape_error, - ssprintf("broadcast on dim with shape not equal to 1: " - "src_shape=%s dst_shape=%s", - 
to_string().c_str(), tshape.to_string().c_str())); + ssprintf( + "broadcast on dim with shape not equal to 1: " + "src_shape=%s dst_shape=%s", + to_string().c_str(), tshape.to_string().c_str())); result.shape[target_idx] = tshape.shape[target_idx]; result.stride[target_idx] = 0; } else { @@ -487,8 +491,7 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { return result; } -bool TensorLayout::try_reshape(TensorLayout& result, - const TensorShape& tshp) const { +bool TensorLayout::try_reshape(TensorLayout& result, const TensorShape& tshp) const { megdnn_assert(tshp.ndim); bool is_empty_shape = false; @@ -505,10 +508,11 @@ bool TensorLayout::try_reshape(TensorLayout& result, megdnn_throw_if( !tshp.ndim || total_nr_elems() != tshp.total_nr_elems(), tensor_reshape_error, - ssprintf("number of elements do not match " - "in reshape: src=%s dest=%s", - static_cast(*this).to_string().c_str(), - tshp.to_string().c_str())); + ssprintf( + "number of elements do not match " + "in reshape: src=%s dest=%s", + static_cast(*this).to_string().c_str(), + tshp.to_string().c_str())); auto cont = collapse_contiguous(); result.dtype = this->dtype; @@ -547,9 +551,11 @@ bool TensorLayout::try_reshape(TensorLayout& result, TensorLayout TensorLayout::reshape(const TensorShape& shape) const { TensorLayout ret; auto succ = try_reshape(ret, shape); - megdnn_throw_if(!succ, tensor_reshape_error, - ssprintf("can not reshape from %s to %s", - to_string().c_str(), shape.to_string().c_str())); + megdnn_throw_if( + !succ, tensor_reshape_error, + ssprintf( + "can not reshape from %s to %s", to_string().c_str(), + shape.to_string().c_str())); return ret; } @@ -591,8 +597,7 @@ std::string TensorLayout::serialize() const { MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) #undef cb default: - megdnn_assert(false, - "cannot serialize unknown parameterized DType"); + megdnn_assert(false, "cannot serialize unknown parameterized DType"); break; } } diff --git a/dnn/src/common/batch_conv_bias.cpp b/dnn/src/common/batch_conv_bias.cpp index b3ecb8c0..061dc27e 100644 --- a/dnn/src/common/batch_conv_bias.cpp +++ b/dnn/src/common/batch_conv_bias.cpp @@ -13,17 +13,15 @@ #include "src/common/utils.h" namespace megdnn { -void BatchConvBiasForward::deduce_dtype(DType src, DType filter, - DType /* bias */, DType /* z */, - DType& dst) { +void BatchConvBiasForward::deduce_dtype( + DType src, DType filter, DType /* bias */, DType /* z */, DType& dst) { check_or_deduce_dtype_fwd(src, filter, dst); } -void BatchConvBiasForward::deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& /* bias */, - const TensorLayout& /* z */, - TensorLayout& dst) { +void BatchConvBiasForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& /* bias */, const TensorLayout& /* z */, + TensorLayout& dst) { TensorLayout non_batch_filter; non_batch_filter.ndim = filter.ndim - 1; non_batch_filter.dtype = filter.dtype; @@ -36,12 +34,12 @@ void BatchConvBiasForward::deduce_layout(const TensorLayout& src, } BatchConvBiasForward::CanonizedFilterMeta BatchConvBiasForward::check_exec( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_in_bytes) { - megdnn_assert(src.dtype.enumv() == filter.dtype.enumv() && - src.dtype.enumv() == DTypeEnum::QuantizedS8, - "batch conv only support qint8"); + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, 
const TensorLayout& dst, size_t workspace_in_bytes) { + megdnn_assert( + src.dtype.enumv() == filter.dtype.enumv() && + src.dtype.enumv() == DTypeEnum::QuantizedS8, + "batch conv only support qint8"); float scale_src = src.dtype.param().scale; float scale_filter = filter.dtype.param().scale; float scale_bias = bias.dtype.param().scale; @@ -76,8 +74,9 @@ BatchConvBiasForward::CanonizedFilterMeta BatchConvBiasForward::check_exec( return ret; if (param().format == param::BatchConvBias::Format::NCHW4) { megdnn_assert(bias.shape[0] == 1); - megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 4); diff --git a/dnn/src/common/batch_normalization.cpp b/dnn/src/common/batch_normalization.cpp index 13367022..ec7fa656 100644 --- a/dnn/src/common/batch_normalization.cpp +++ b/dnn/src/common/batch_normalization.cpp @@ -14,10 +14,10 @@ namespace megdnn { -void BNForward::deduce_layout(const TensorLayout& src, const TensorLayout&, - const TensorLayout&, TensorLayout&, TensorLayout&, - TensorLayout&, TensorLayout&, - TensorLayout& reserve, TensorLayout& dst) { +void BNForward::deduce_layout( + const TensorLayout& src, const TensorLayout&, const TensorLayout&, + TensorLayout&, TensorLayout&, TensorLayout&, TensorLayout&, + TensorLayout& reserve, TensorLayout& dst) { reserve = {{get_reserve_in_bytes(src)}, dtype::Byte()}; dst = src; } @@ -35,21 +35,19 @@ void BNForward::check_exec( megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); auto required_workspace_in_bytes = get_workspace_in_bytes( - src, bn_scale, bn_bias, mean, variance, batch_mean, - batch_inv_variance, {}, dst); + src, bn_scale, bn_bias, mean, variance, batch_mean, batch_inv_variance, {}, + dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); auto required_reserve_in_bytes = get_reserve_in_bytes(src); megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); } -void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, - const TensorLayout& saved_batch_mean, - const TensorLayout& saved_batch_variance, - const TensorLayout& bn_scale, - const TensorLayout& d_bn_scale, - const TensorLayout& d_bn_bias, - const TensorLayout& dx, size_t workspace_in_bytes, - size_t reserve_in_bytes) { +void BNBackward::check_exec( + const TensorLayout& x, const TensorLayout& dy, + const TensorLayout& saved_batch_mean, const TensorLayout& saved_batch_variance, + const TensorLayout& bn_scale, const TensorLayout& d_bn_scale, + const TensorLayout& d_bn_bias, const TensorLayout& dx, + size_t workspace_in_bytes, size_t reserve_in_bytes) { megdnn_assert_contiguous(x); megdnn_assert_eq_layout(x, dy); megdnn_assert_eq_layout(x, dx); @@ -60,13 +58,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT); megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); auto required_workspace_in_bytes = get_workspace_in_bytes( - x, dy, saved_batch_mean, saved_batch_variance, bn_scale, {}, - d_bn_scale, d_bn_bias, dx); + x, dy, saved_batch_mean, saved_batch_variance, bn_scale, {}, d_bn_scale, + d_bn_bias, dx); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); auto 
required_reserve_in_bytes = get_reserve_in_bytes(x); megdnn_assert(reserve_in_bytes >= required_reserve_in_bytes); - megdnn_assert(param().fwd_mode == Param::FwdMode::TRAINING, - "BNBackward only support TRAINING mode"); + megdnn_assert( + param().fwd_mode == Param::FwdMode::TRAINING, + "BNBackward only support TRAINING mode"); } } // namespace megdnn diff --git a/dnn/src/common/batched_matrix_mul.cpp b/dnn/src/common/batched_matrix_mul.cpp index e4d9c4cc..85caae25 100644 --- a/dnn/src/common/batched_matrix_mul.cpp +++ b/dnn/src/common/batched_matrix_mul.cpp @@ -13,7 +13,7 @@ namespace megdnn { -void BatchedMatrixMulForward::deduce_dtype(DType A, DType B, DType &C) { +void BatchedMatrixMulForward::deduce_dtype(DType A, DType B, DType& C) { DType C_candi, C_candi2; if (A.category() == DTypeCategory::FLOAT) { C_candi = A; @@ -30,13 +30,12 @@ void BatchedMatrixMulForward::deduce_dtype(DType A, DType B, DType &C) { if (!C.valid()) { C = C_candi; } - megdnn_assert(C.valid() && (C == C_candi || C == C_candi2), - "unsupported BatchedMatMul(%s, %s) -> %s", A.name(), B.name(), - C.name()); + megdnn_assert( + C.valid() && (C == C_candi || C == C_candi2), + "unsupported BatchedMatMul(%s, %s) -> %s", A.name(), B.name(), C.name()); } -void BatchedMatrixMulForward::deduce_layout(const TensorLayout& A, - const TensorLayout& B, - TensorLayout& C) { +void BatchedMatrixMulForward::deduce_layout( + const TensorLayout& A, const TensorLayout& B, TensorLayout& C) { auto errmsg = [&]() { std::string msg; msg.append("A="); @@ -59,8 +58,7 @@ void BatchedMatrixMulForward::deduce_layout(const TensorLayout& A, return l.ndim == 3 && l.stride[2] == 1 && l.stride[1] >= static_cast(l.shape[2]) && (l.shape[0] == 1 || - l.stride[0] >= - static_cast(l.shape[1]) * l.stride[1] || + l.stride[0] >= static_cast(l.shape[1]) * l.stride[1] || l.stride[0] == 0); }; size_t A0, A1, B0, B1; @@ -73,24 +71,26 @@ void BatchedMatrixMulForward::deduce_layout(const TensorLayout& A, if (m_param.transposeB) std::swap(B0, B1); deduce_dtype(A.dtype, B.dtype, C.dtype); - megdnn_assert(good_layout(A) && good_layout(B) && A1 == B0 && - A[0] == B[0] && A.dtype.enumv() == B.dtype.enumv(), - "bad input layouts: %s", errmsg().c_str()); + megdnn_assert( + good_layout(A) && good_layout(B) && A1 == B0 && A[0] == B[0] && + A.dtype.enumv() == B.dtype.enumv(), + "bad input layouts: %s", errmsg().c_str()); C = TensorLayout(TensorShape({A[0], A0, B1}), C.dtype); } -void BatchedMatrixMulForward::check_exec(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C, - size_t workspace_in_bytes) { +void BatchedMatrixMulForward::check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes) { TensorLayout C_expect; deduce_layout(A, B, C_expect); - megdnn_assert(C_expect.eq_layout(C), "bad layout for C: expect=%s got=%s", - C_expect.to_string().c_str(), C.to_string().c_str()); + megdnn_assert( + C_expect.eq_layout(C), "bad layout for C: expect=%s got=%s", + C_expect.to_string().c_str(), C.to_string().c_str()); auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C); - megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes, - "needed workspace: %zu; got: %zu", - required_workspace_in_bytes, workspace_in_bytes); + megdnn_assert( + workspace_in_bytes >= required_workspace_in_bytes, + "needed workspace: %zu; got: %zu", required_workspace_in_bytes, + workspace_in_bytes); } } // namespace megdnn diff --git a/dnn/src/common/check_non_finite.cpp b/dnn/src/common/check_non_finite.cpp index 
64e4657d..e6ea7b28 100644 --- a/dnn/src/common/check_non_finite.cpp +++ b/dnn/src/common/check_non_finite.cpp @@ -14,8 +14,8 @@ namespace megdnn { -void CheckNonFinite::check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes) { +void CheckNonFinite::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { megdnn_assert_contiguous(src); megdnn_assert_contiguous(dst); megdnn_assert(src.ndim == 1); diff --git a/dnn/src/common/checksum.cpp b/dnn/src/common/checksum.cpp index 76123e43..353dd18e 100644 --- a/dnn/src/common/checksum.cpp +++ b/dnn/src/common/checksum.cpp @@ -14,15 +14,14 @@ using namespace megdnn; -void megdnn::ChecksumForward::check_exec(const TensorLayout &layout, - size_t workspace_in_bytes) { - megdnn_assert(layout.is_contiguous() && - layout.ndim == 1 && - layout.dtype == dtype::Byte() && - layout.shape[0], "%s", layout.to_string().c_str()); +void megdnn::ChecksumForward::check_exec( + const TensorLayout& layout, size_t workspace_in_bytes) { + megdnn_assert( + layout.is_contiguous() && layout.ndim == 1 && + layout.dtype == dtype::Byte() && layout.shape[0], + "%s", layout.to_string().c_str()); auto required_workspace_in_bytes = get_workspace_in_bytes(layout); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/concat_split.cpp b/dnn/src/common/concat_split.cpp index 77cc3700..684a5cd5 100644 --- a/dnn/src/common/concat_split.cpp +++ b/dnn/src/common/concat_split.cpp @@ -16,98 +16,86 @@ namespace megdnn { -ConcatSplitBase::ConcatSplitBase(Handle *handle): - OperatorBase(handle), - m_get_layout([](const TensorND &tensor) { return tensor.layout; }), - m_get_shape([](const TensorLayout &layout) { return TensorShape(layout); }) -{ -} +ConcatSplitBase::ConcatSplitBase(Handle* handle) + : OperatorBase(handle), + m_get_layout([](const TensorND& tensor) { return tensor.layout; }), + m_get_shape([](const TensorLayout& layout) { return TensorShape(layout); }) {} -void ConcatSplitBase::check_layout_common(const TensorLayoutArray &srcs, - const TensorLayout &dst) -{ +void ConcatSplitBase::check_layout_common( + const TensorLayoutArray& srcs, const TensorLayout& dst) { // ensure same data type - for (auto &&src: srcs) { + for (auto&& src : srcs) { megdnn_assert(src.dtype == dst.dtype); } // ensure all layouts are contiguous - for (auto &&src: srcs) { + for (auto&& src : srcs) { megdnn_assert_contiguous(src); - } + } megdnn_assert_contiguous(dst); // ensure all layouts have the same ndim auto ndim = dst.ndim; - for (auto &&src: srcs) { + for (auto&& src : srcs) { megdnn_assert_eq_size_t(src.ndim, ndim); - } - // ensure param().axis is correct - auto errmsg = "param().axis=" + - std::to_string(param().axis) + ", ndim=" + - std::to_string(ndim); + } + // ensure param().axis is correct + auto errmsg = "param().axis=" + std::to_string(param().axis) + + ", ndim=" + std::to_string(ndim); MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(param().axis < static_cast(ndim), "%s", - errmsg.c_str()); + megdnn_assert(param().axis < static_cast(ndim), "%s", errmsg.c_str()); // ensure shape size for each axis is correct for (size_t i = 0; i < ndim; ++i) { if (i == static_cast(param().axis)) { size_t sum = 0_z; - for (auto &&src: srcs) sum += src.shape[i]; + for (auto&& src : srcs) + sum += src.shape[i]; megdnn_assert_eq_size_t(sum, dst.shape[i]); } else { - for (auto &&src: srcs) { - megdnn_assert(src.shape[i] == dst.shape[i]); + for (auto&& src : srcs) { + 
megdnn_assert(src.shape[i] == dst.shape[i]); megdnn_assert_eq_size_t(src.shape[i], dst.shape[i]); - } + } } } } -void ConcatSplitBase::get_ABC(const TensorShapeArray &srcs, - size_t &A, - size_t *B, - size_t &C) -{ +void ConcatSplitBase::get_ABC( + const TensorShapeArray& srcs, size_t& A, size_t* B, size_t& C) { auto axis = param().axis; auto shape_arr = srcs[0].shape; auto ndim = srcs[0].ndim; - A = std::accumulate(shape_arr, shape_arr + axis, - 1_z, SafeMultiplies()); + A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies()); for (size_t i = 0u; i < srcs.size(); ++i) { B[i] = srcs[i].shape[axis]; } - C = std::accumulate(shape_arr + (axis+1), shape_arr + ndim, - 1_z, SafeMultiplies()); + C = std::accumulate( + shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies()); } -void ConcatForward::deduce_layout(const TensorLayoutArray &srcs, - TensorLayout &dst) -{ +void ConcatForward::deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst) { dst = srcs[0]; auto i = param().axis; dst.shape[i] = 0u; - for (auto &&src: srcs) { + for (auto&& src : srcs) { dst.shape[i] += src.shape[i]; } dst.init_contiguous_stride(); } -void ConcatForward::check_exec(const TensorLayoutArray &srcs, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void ConcatForward::check_exec( + const TensorLayoutArray& srcs, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_common(srcs, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(srcs, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void SplitForward::check_exec(const TensorLayout &src, - const TensorLayoutArray &dsts, - size_t workspace_in_bytes) -{ +void SplitForward::check_exec( + const TensorLayout& src, const TensorLayoutArray& dsts, + size_t workspace_in_bytes) { check_layout_common(dsts, src); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dsts); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/cond_take/opr_impl.cpp b/dnn/src/common/cond_take/opr_impl.cpp index dfa744fd..6381e139 100644 --- a/dnn/src/common/cond_take/opr_impl.cpp +++ b/dnn/src/common/cond_take/opr_impl.cpp @@ -14,17 +14,14 @@ using namespace megdnn; -size_t CondTake::check_exec_get_size(const TensorLayout& data, - const TensorLayout& mask, - size_t workspace_in_bytes) { - megdnn_assert(data.eq_shape(mask), - "CondTake shape differs: data=%s mask=%s", - data.TensorShape::to_string().c_str(), - mask.TensorShape::to_string().c_str()); - megdnn_assert(data.is_physical_contiguous() && - mask.is_physical_contiguous()); - megdnn_assert(m_param.eps > 0, "eps must be non-negative; got: %g", - m_param.eps); +size_t CondTake::check_exec_get_size( + const TensorLayout& data, const TensorLayout& mask, size_t workspace_in_bytes) { + megdnn_assert( + data.eq_shape(mask), "CondTake shape differs: data=%s mask=%s", + data.TensorShape::to_string().c_str(), + mask.TensorShape::to_string().c_str()); + megdnn_assert(data.is_physical_contiguous() && mask.is_physical_contiguous()); + megdnn_assert(m_param.eps > 0, "eps must be non-negative; got: %g", m_param.eps); megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(data)); return data.total_nr_elems(); } diff --git a/dnn/src/common/cond_take/predicate.cuh b/dnn/src/common/cond_take/predicate.cuh index c415b9a4..83a5c57c 100644 --- a/dnn/src/common/cond_take/predicate.cuh +++ b/dnn/src/common/cond_take/predicate.cuh @@ -11,8 +11,8 @@ 
#pragma once -#include "src/common/opr_param_defs_enumv.cuh" #include "megdnn/arch.h" +#include "src/common/opr_param_defs_enumv.cuh" #if MEGDNN_CC_HOST #include "megdnn/opr_param_defs.h" @@ -28,85 +28,76 @@ namespace megdnn { namespace cond_take { - typedef param_enumv::CondTake::Mode PEnum; +typedef param_enumv::CondTake::Mode PEnum; - struct KParam { - float val, eps; +struct KParam { + float val, eps; #if MEGDNN_CC_HOST - KParam(const param::CondTake &p): - val(p.val), eps(p.eps) - {} + KParam(const param::CondTake& p) : val(p.val), eps(p.eps) {} #endif +}; + +template +struct Pred; + +#define do_inst_eq_f(_ct) \ + template <> \ + struct Pred { \ + typedef _ct ctype; \ + ctype val, eps; \ + Pred(const KParam& p) : val(p.val), eps(p.eps) {} \ + __device__ __host__ bool operator()(ctype x) const { \ + return fabsf(val - x) < eps; \ + } \ }; - template - struct Pred; - -#define do_inst_eq_f(_ct) \ - template<> \ - struct Pred { \ - typedef _ct ctype; \ - ctype val, eps; \ - Pred(const KParam &p): val(p.val), eps(p.eps) {} \ - __device__ __host__ bool operator() (ctype x) const { \ - return fabsf(val - x) < eps; \ - } \ - }; - -#define do_inst_eq_i(_ct) \ - template<> \ - struct Pred { \ - typedef _ct ctype; \ - ctype val; \ - Pred(const KParam &p): val(p.val) {} \ - __device__ __host__ bool operator() (ctype x) const { \ - return val == x; \ - } \ +#define do_inst_eq_i(_ct) \ + template <> \ + struct Pred { \ + typedef _ct ctype; \ + ctype val; \ + Pred(const KParam& p) : val(p.val) {} \ + __device__ __host__ bool operator()(ctype x) const { return val == x; } \ }; #define inst_eq_f(_dt) do_inst_eq_f(DTypeTrait<_dt>::ctype) #define inst_eq_i(_dt) do_inst_eq_i(DTypeTrait<_dt>::ctype) - MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(inst_eq_f) - MEGDNN_FOREACH_COMPUTING_DTYPE_INT(inst_eq_i) - inst_eq_i(::megdnn::dtype::Bool) +MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(inst_eq_f) +MEGDNN_FOREACH_COMPUTING_DTYPE_INT(inst_eq_i) +inst_eq_i(::megdnn::dtype::Bool) #undef inst_eq_f #undef inst_eq_i - template - struct Pred { - typedef ctype_ ctype; - Pred eq; + template + struct Pred { + typedef ctype_ ctype; + Pred eq; - Pred(const KParam &p): eq(p) {} + Pred(const KParam& p) : eq(p) {} - __device__ __host__ bool operator() (ctype x) const { - return !this->eq(x); - } - }; + __device__ __host__ bool operator()(ctype x) const { return !this->eq(x); } +}; -#define DEF_OP(_name, _op) \ - template \ - struct Pred { \ - typedef ctype_ ctype; \ - ctype val; \ - Pred(const KParam &p): val(p.val) {} \ - __device__ __host__ bool operator() (ctype x) const { \ - return x _op val; \ - } \ +#define DEF_OP(_name, _op) \ + template \ + struct Pred { \ + typedef ctype_ ctype; \ + ctype val; \ + Pred(const KParam& p) : val(p.val) {} \ + __device__ __host__ bool operator()(ctype x) const { return x _op val; } \ } - DEF_OP(LT, < ); - DEF_OP(LEQ, <= ); - DEF_OP(GT, > ); - DEF_OP(GEQ, >= ); +DEF_OP(LT, <); +DEF_OP(LEQ, <=); +DEF_OP(GT, >); +DEF_OP(GEQ, >=); #undef DEF_OP -#define MEGDNN_FOREACH_COND_TAKE_MODE(cb) \ - cb(EQ) cb(NEQ) cb(LT) cb(LEQ) cb(GT) cb(GEQ) +#define MEGDNN_FOREACH_COND_TAKE_MODE(cb) cb(EQ) cb(NEQ) cb(LT) cb(LEQ) cb(GT) cb(GEQ) -} // namespace cond_take -} // namespace megdnn +} // namespace cond_take +} // namespace megdnn #ifdef def_device #undef __device__ diff --git a/dnn/src/common/conv_bias.cpp b/dnn/src/common/conv_bias.cpp index c747c7e9..642f9588 100644 --- a/dnn/src/common/conv_bias.cpp +++ b/dnn/src/common/conv_bias.cpp @@ -11,41 +11,42 @@ */ #include "src/common/conv_bias.h" -#include 
"src/common/utils.h" #include "src/common/opr_delegate.h" +#include "src/common/utils.h" namespace megdnn { namespace { void do_check_exec_common( - ConvBiasForward* opr, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& bias, - const TensorLayout& z, const TensorLayout& dst, + ConvBiasForward* opr, const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, size_t workspace_in_bytes, const ConvBiasForward::PreprocessedFilter* preprocessed_filter) { - megdnn_assert((src.dtype.enumv() == filter.dtype.enumv()) || - (src.dtype.enumv() == DTypeEnum::Quantized4Asymm && - filter.dtype.enumv() == DTypeEnum::QuantizedS4)); + megdnn_assert( + (src.dtype.enumv() == filter.dtype.enumv()) || + (src.dtype.enumv() == DTypeEnum::Quantized4Asymm && + filter.dtype.enumv() == DTypeEnum::QuantizedS4)); // check compatibility of bias's scale if (src.dtype.category() == DTypeCategory::QUANTIZED) { if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) { float scale_expected = mul_scale(src.dtype, filter.dtype); float scale_bias = bias.dtype.param().scale; - megdnn_assert(std::abs(scale_expected - scale_bias) < 1e-6, - "scale_src: %f scale_filter: %f scale_bias: %f", - get_scale(src.dtype), get_scale(filter.dtype), - scale_bias); + megdnn_assert( + std::abs(scale_expected - scale_bias) < 1e-6, + "scale_src: %f scale_filter: %f scale_bias: %f", + get_scale(src.dtype), get_scale(filter.dtype), scale_bias); } else { megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32); } } megdnn_assert_contiguous(bias); - auto required_workspace_in_bytes = opr->get_workspace_in_bytes( - src, filter, bias, z, dst, preprocessed_filter); - megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes, - "worksapce have size of %zu, but need %zu", - workspace_in_bytes, required_workspace_in_bytes); + auto required_workspace_in_bytes = + opr->get_workspace_in_bytes(src, filter, bias, z, dst, preprocessed_filter); + megdnn_assert( + workspace_in_bytes >= required_workspace_in_bytes, + "worksapce have size of %zu, but need %zu", workspace_in_bytes, + required_workspace_in_bytes); if (bias.ndim != 0) { //! 
bias.layout == dst.layout failed, no assert information auto check_eq = [](const TensorLayout& bias, const TensorLayout& dst) { @@ -61,76 +62,83 @@ void do_check_exec_common( if (opr->param().format == param::ConvBias::Format::NCHW || opr->param().format == param::ConvBias::Format::NCHW4_NCHW) { megdnn_assert(bias.shape[0] == 1); - megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); - } else if (opr->param().format == param::ConvBias::Format::NHWC || - opr->param().format == param::ConvBias::Format::NCHW4_NHWC) { + } else if ( + opr->param().format == param::ConvBias::Format::NHWC || + opr->param().format == param::ConvBias::Format::NCHW4_NHWC) { megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[1] == 1); megdnn_assert(bias.shape[2] == 1); - megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); - } else if (opr->param().format == param::ConvBias::Format::NCHW4 || - opr->param().format == param::ConvBias::Format::NCHW44 || - opr->param().format == param::ConvBias::Format::NCHW44_DOT || - opr->param().format == - param::ConvBias::Format::NCHW32_NCHW4) { + megdnn_assert( + bias.shape[3] == dst.shape[3], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); + } else if ( + opr->param().format == param::ConvBias::Format::NCHW4 || + opr->param().format == param::ConvBias::Format::NCHW44 || + opr->param().format == param::ConvBias::Format::NCHW44_DOT || + opr->param().format == param::ConvBias::Format::NCHW32_NCHW4) { megdnn_assert(bias.shape[0] == 1); - megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 4); - } else if (opr->param().format == param::ConvBias::Format::NCHW8 || - opr->param().format == param::ConvBias::Format::NCHW88) { + } else if ( + opr->param().format == param::ConvBias::Format::NCHW8 || + opr->param().format == param::ConvBias::Format::NCHW88) { megdnn_assert(bias.shape[0] == 1); - megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 8); - } else if (opr->param().format == param::ConvBias::Format::NCHW32 || - opr->param().format == - param::ConvBias::Format::NCHW4_NCHW32) { + } else if ( + opr->param().format == param::ConvBias::Format::NCHW32 || + opr->param().format == param::ConvBias::Format::NCHW4_NCHW32) { megdnn_assert(bias.shape[0] == 1); - megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 32); } else if (opr->param().format == param::ConvBias::Format::CHWN4) { - 
megdnn_assert(bias.shape[0] == dst.shape[0], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[0] == dst.shape[0], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[1] == 1); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 4); } else if (opr->param().format == param::ConvBias::Format::NCHW64) { megdnn_assert(bias.shape[0] == 1); - megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 64); } else { - megdnn_assert(opr->param().format == - param::ConvBias::Format::NHWCD4); + megdnn_assert(opr->param().format == param::ConvBias::Format::NHWCD4); megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[1] == 1); - megdnn_assert(bias.shape[2] == dst.shape[2], "bias:%s, dst:%s", - bias.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + bias.shape[2] == dst.shape[2], "bias:%s, dst:%s", + bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[3] == 1); megdnn_assert(bias.shape[4] == 4); } } if (z.ndim != 0) { - megdnn_assert(opr->param().format != - param::ConvBias::Format::NCHW4_NCHW32); - megdnn_assert(opr->param().format != - param::ConvBias::Format::NCHW32_NCHW4); + megdnn_assert(opr->param().format != param::ConvBias::Format::NCHW4_NCHW32); + megdnn_assert(opr->param().format != param::ConvBias::Format::NCHW32_NCHW4); megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); megdnn_assert(z.eq_shape(dst)); } @@ -138,38 +146,34 @@ void do_check_exec_common( } // namespace -void ConvBiasForward::deduce_dtype(DType src, DType filter, DType /* bias */, - DType /* z */, DType& dst) { +void ConvBiasForward::deduce_dtype( + DType src, DType filter, DType /* bias */, DType /* z */, DType& dst) { check_or_deduce_dtype_fwd(src, filter, dst); } -void ConvBiasForward::deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& /* bias */, - const TensorLayout& /* z */, - TensorLayout& dst) { +void ConvBiasForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& /* bias */, const TensorLayout& /* z */, + TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); } ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_in_bytes, + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst, size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) { - do_check_exec_common(this, src, filter, bias, z, dst, workspace_in_bytes, - preprocessed_filter); + do_check_exec_common( + this, src, filter, bias, z, dst, workspace_in_bytes, preprocessed_filter); auto ret = check_layout_fwd(src, filter, dst); return ret; } -ConvBiasForward::CanonizedFilterMeta -ConvBiasForward::check_exec_allow_noncontiguous( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_in_bytes, +ConvBiasForward::CanonizedFilterMeta 
ConvBiasForward::check_exec_allow_noncontiguous( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst, size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) { - do_check_exec_common(this, src, filter, bias, z, dst, workspace_in_bytes, - preprocessed_filter); + do_check_exec_common( + this, src, filter, bias, z, dst, workspace_in_bytes, preprocessed_filter); TensorLayout dst_expected; dst_expected.dtype = dst.dtype; auto ret = deduce_layout_fwd(src, filter, dst_expected); @@ -184,19 +188,20 @@ template struct NCHW44ParamTrait; std::string ConvBias::WinogradParam::to_string() const { - return ssprintf("%u:%u:%u", channel_block_size, output_block_size, - tile_size); + return ssprintf("%u:%u:%u", channel_block_size, output_block_size, tile_size); } template -std::string ConvBias::algo_name(const std::string& base, const T& p, - param::ConvBias::Format format) { +std::string ConvBias::algo_name( + const std::string& base, const T& p, param::ConvBias::Format format) { if (format == param::ConvBias::Format::NCHW) { - return ssprintf("%s:%s:%s", NCHWParamTrait::category.c_str(), - base.c_str(), p.to_string().c_str()); + return ssprintf( + "%s:%s:%s", NCHWParamTrait::category.c_str(), base.c_str(), + p.to_string().c_str()); } else if (format == param::ConvBias::Format::NCHW44) { - return ssprintf("%s:%s:%s", NCHW44ParamTrait::category.c_str(), - base.c_str(), p.to_string().c_str()); + return ssprintf( + "%s:%s:%s", NCHW44ParamTrait::category.c_str(), base.c_str(), + p.to_string().c_str()); } megdnn_throw("Invalid format"); return ""; @@ -225,8 +230,7 @@ cb(MatmulParam, "MATMUL"); cb(DefaultParam, "DEFAULT"); #undef cb -const std::string NCHWParamTrait::category = - "WINOGRAD"; +const std::string NCHWParamTrait::category = "WINOGRAD"; const std::string NCHW44ParamTrait::category = "WINOGRAD_NCHW44"; @@ -237,18 +241,15 @@ const std::string NCHW44ParamTrait::category = FOREACH_CONV_BIAS_PARAM(cb) #undef cb -ConvBias::WinogradParam ConvBias::parse_winograd_name( - const std::string& algo_name) { +ConvBias::WinogradParam ConvBias::parse_winograd_name(const std::string& algo_name) { ConvBias::WinogradParam ret = INVALID_WINOGRAD_PARAM; char base[128]; char name[128]; - auto parse = [&](const std::string& algo_name, - const std::string& pre) -> auto { + auto parse = [&](const std::string& algo_name, const std::string& pre) -> auto { memset(name, 0, 128); sscanf(algo_name.c_str(), "%[^:]:%[^:]:%u:%u:%u", name, base, - &(ret.channel_block_size), &(ret.output_block_size), - &(ret.tile_size)); + &(ret.channel_block_size), &(ret.output_block_size), &(ret.tile_size)); if (strcmp(name, pre.c_str())) { ret = INVALID_WINOGRAD_PARAM; return false; @@ -271,53 +272,46 @@ ConvBias::WinogradParam ConvBias::parse_winograd_name( constexpr ConvBias::WinogradParam ConvBias::INVALID_WINOGRAD_PARAM; -void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args, - const TensorND* conv_dst_tensor, - const TensorND* dst_tensor, - const TensorND* bias_tensor) { +void handle_bias_and_nonlinear( + Handle* handle, param::ConvBias args, const TensorND* conv_dst_tensor, + const TensorND* dst_tensor, const TensorND* bias_tensor) { using NonlineMode = param::ConvBias::NonlineMode; switch (args.nonlineMode) { -#define cb(_mode) \ - case NonlineMode::_mode: { \ - if (conv_dst_tensor->layout.dtype.category() != \ - DTypeCategory::QUANTIZED) { \ - auto nonlinear = handle->create_operator(); \ - if (bias_tensor->layout.ndim > 0) { \ - 
nonlinear->param().mode = \ - Elemwise::Param::Mode::FUSE_ADD_##_mode; \ - nonlinear->exec({*conv_dst_tensor, *bias_tensor}, \ - *dst_tensor); \ - } else { \ - nonlinear->param().mode = Elemwise::Param::Mode::_mode; \ - nonlinear->exec({*conv_dst_tensor}, *dst_tensor); \ - } \ - } else { \ - auto nonlinear = handle->create_operator(); \ - if (bias_tensor->layout.ndim > 0) { \ - nonlinear->param().mode = \ - ElemwiseMultiType::Param::Mode::QFUSE_ADD_##_mode; \ - nonlinear->exec({*conv_dst_tensor, *bias_tensor}, \ - *dst_tensor); \ - } else { \ - nonlinear->param().mode = \ - ElemwiseMultiType::Param::Mode::Q##_mode; \ - nonlinear->exec({*conv_dst_tensor}, *dst_tensor); \ - } \ - } \ - break; \ +#define cb(_mode) \ + case NonlineMode::_mode: { \ + if (conv_dst_tensor->layout.dtype.category() != DTypeCategory::QUANTIZED) { \ + auto nonlinear = handle->create_operator(); \ + if (bias_tensor->layout.ndim > 0) { \ + nonlinear->param().mode = Elemwise::Param::Mode::FUSE_ADD_##_mode; \ + nonlinear->exec({*conv_dst_tensor, *bias_tensor}, *dst_tensor); \ + } else { \ + nonlinear->param().mode = Elemwise::Param::Mode::_mode; \ + nonlinear->exec({*conv_dst_tensor}, *dst_tensor); \ + } \ + } else { \ + auto nonlinear = handle->create_operator(); \ + if (bias_tensor->layout.ndim > 0) { \ + nonlinear->param().mode = \ + ElemwiseMultiType::Param::Mode::QFUSE_ADD_##_mode; \ + nonlinear->exec({*conv_dst_tensor, *bias_tensor}, *dst_tensor); \ + } else { \ + nonlinear->param().mode = ElemwiseMultiType::Param::Mode::Q##_mode; \ + nonlinear->exec({*conv_dst_tensor}, *dst_tensor); \ + } \ + } \ + break; \ } cb(RELU); cb(H_SWISH); #undef cb case NonlineMode::SIGMOID: { - megdnn_assert(conv_dst_tensor->layout.dtype.category() != - DTypeCategory::QUANTIZED); + megdnn_assert( + conv_dst_tensor->layout.dtype.category() != + DTypeCategory::QUANTIZED); auto nonlinear = handle->create_operator(); if (bias_tensor->layout.ndim > 0) { - nonlinear->param().mode = - Elemwise::Param::Mode::FUSE_ADD_SIGMOID; - nonlinear->exec({*conv_dst_tensor, *bias_tensor}, - *conv_dst_tensor); + nonlinear->param().mode = Elemwise::Param::Mode::FUSE_ADD_SIGMOID; + nonlinear->exec({*conv_dst_tensor, *bias_tensor}, *conv_dst_tensor); } else { nonlinear->param().mode = Elemwise::Param::Mode::SIGMOID; nonlinear->exec({*conv_dst_tensor}, *conv_dst_tensor); @@ -326,24 +320,19 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args, } case NonlineMode::IDENTITY: { if (bias_tensor->layout.ndim > 0) { - if (dst_tensor->layout.dtype.category() == - DTypeCategory::QUANTIZED) { - auto nonlinear = - handle->create_operator(); - nonlinear->param().mode = - ElemwiseMultiType::Param::Mode::QADD; - nonlinear->exec({*conv_dst_tensor, *bias_tensor}, - *dst_tensor); + if (dst_tensor->layout.dtype.category() == DTypeCategory::QUANTIZED) { + auto nonlinear = handle->create_operator(); + nonlinear->param().mode = ElemwiseMultiType::Param::Mode::QADD; + nonlinear->exec({*conv_dst_tensor, *bias_tensor}, *dst_tensor); } else { auto nonlinear = handle->create_operator(); nonlinear->param().mode = Elemwise::Param::Mode::ADD; - nonlinear->exec({*conv_dst_tensor, *bias_tensor}, - *dst_tensor); + nonlinear->exec({*conv_dst_tensor, *bias_tensor}, *dst_tensor); } } else { if (conv_dst_tensor->layout.dtype != dst_tensor->layout.dtype) { - handle->create_operator()->exec({*conv_dst_tensor}, - *dst_tensor); + handle->create_operator()->exec( + {*conv_dst_tensor}, *dst_tensor); } } break; diff --git a/dnn/src/common/conv_bias.h b/dnn/src/common/conv_bias.h index 
84489c87..591680d8 100644 --- a/dnn/src/common/conv_bias.h +++ b/dnn/src/common/conv_bias.h @@ -18,10 +18,9 @@ namespace megdnn { -void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args, - const TensorND* conv_dst_tensor, - const TensorND* dst_tensor, - const TensorND* bias_tensor); +void handle_bias_and_nonlinear( + Handle* handle, param::ConvBias args, const TensorND* conv_dst_tensor, + const TensorND* dst_tensor, const TensorND* bias_tensor); } // namespace megdnn diff --git a/dnn/src/common/conv_pooling.cpp b/dnn/src/common/conv_pooling.cpp index 92a12667..6a31bbfe 100644 --- a/dnn/src/common/conv_pooling.cpp +++ b/dnn/src/common/conv_pooling.cpp @@ -11,7 +11,4 @@ #include "megdnn.h" #include "src/common/utils.h" -namespace megdnn { - - -} // namespace megdnn \ No newline at end of file +namespace megdnn {} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp index f0aa61d8..0fc9afbb 100644 --- a/dnn/src/common/convolution.cpp +++ b/dnn/src/common/convolution.cpp @@ -17,18 +17,18 @@ using namespace megdnn; namespace { template -std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, const Param& param) { +std::string get_errmsg( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + const Param& param) { MEGDNN_MARK_USED_VAR(src); MEGDNN_MARK_USED_VAR(filter); MEGDNN_MARK_USED_VAR(dst); return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " + megdnn_layout_msg(dst) + ", " + "is_nchw=" + - std::to_string(param.format == param::Convolution::Format::NCHW) + - ", " + "is_xcorr=" + - std::to_string( - (param.mode == Convolution::Mode::CROSS_CORRELATION)) + - ", " + "pad_h=" + std::to_string(param.pad_h) + ", " + + std::to_string(param.format == param::Convolution::Format::NCHW) + ", " + + "is_xcorr=" + + std::to_string((param.mode == Convolution::Mode::CROSS_CORRELATION)) + ", " + + "pad_h=" + std::to_string(param.pad_h) + ", " + "pad_w=" + std::to_string(param.pad_w) + ", " + "stride_h=" + std::to_string(param.stride_h) + ", " + "stride_w=" + std::to_string(param.stride_w) + ", " + @@ -45,8 +45,8 @@ template void make_canonized_filter_meta_nchw_nhwc( size_t src_ndim, const TensorLayout& filter, const Param& param, typename ConvolutionBase::CanonizedFilterMeta& ret) { - megdnn_assert(param.format == Param::Format::NCHW || - param.format == Param::Format::NHWC); + megdnn_assert( + param.format == Param::Format::NCHW || param.format == Param::Format::NHWC); auto img_ndim = src_ndim - 2; size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos; if (param.sparse == Param::Sparse::DENSE) { @@ -58,8 +58,9 @@ void make_canonized_filter_meta_nchw_nhwc( ret.group = 1; flt_start = 0; } else { - megdnn_assert(param.sparse == Param::Sparse::GROUP, - "invalid convolution sparse type"); + megdnn_assert( + param.sparse == Param::Sparse::GROUP, + "invalid convolution sparse type"); megdnn_assert( filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5, "bad filter ndim for group convolution: " @@ -78,8 +79,8 @@ void make_canonized_filter_meta_nchw_nhwc( ocpg_pos = 0; icpg_pos = 1; } else { - megdnn_assert(param.format == Param::Format::NHWC, - "invalid conv tensor format"); + megdnn_assert( + param.format == Param::Format::NHWC, "invalid conv tensor format"); // filter should be (oc, fh, fw, ic) flt_spatial_start = 1; ocpg_pos = 0; @@ -95,9 +96,9 @@ void make_canonized_filter_meta_nchw_nhwc( ret.icpg = filter[flt_start + icpg_pos] 
* ic_block_size; auto dilation = ret.dilation; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(dilation[i] > 0, - "invalid dilation on spatial dim %zu: %u", i, - dilation[i]); + megdnn_assert( + dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i, + dilation[i]); ret.spatial[i] = spatial_getter( filter[i + flt_start + flt_spatial_start], param); ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1; @@ -120,16 +121,18 @@ void make_canonized_filter_meta_nhwcd4( size_t flt_start = 0, flt_spatial_start = 1; bool is_chanwise = false; if (param.sparse == Param::Sparse::DENSE) { - megdnn_assert(filter.ndim == img_ndim + 3, - "bad filter ndim for dense convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + filter.ndim == img_ndim + 3, + "bad filter ndim for dense convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); // oc, ic, dims[] ret.group = 1; flt_start = 0; } else { - megdnn_assert(param.sparse == Param::Sparse::GROUP, - "invalid convolution sparse type"); + megdnn_assert( + param.sparse == Param::Sparse::GROUP, + "invalid convolution sparse type"); megdnn_assert( filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 4, "bad filter ndim for group convolution: " @@ -158,9 +161,9 @@ void make_canonized_filter_meta_nhwcd4( } auto dilation = ret.dilation; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(dilation[i] > 0, - "invalid dilation on spatial dim %zu: %u", i, - dilation[i]); + megdnn_assert( + dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i, + dilation[i]); ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1; } @@ -182,16 +185,18 @@ void make_canonized_filter_meta_nhwcd4_dot( size_t flt_start = 0, flt_spatial_start = 1; bool is_chanwise = false; if (param.sparse == Param::Sparse::DENSE) { - megdnn_assert(filter.ndim == img_ndim + 4, - "bad filter ndim for dense convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + filter.ndim == img_ndim + 4, + "bad filter ndim for dense convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); // oc, ic, dims[] ret.group = 1; flt_start = 0; } else { - megdnn_assert(param.sparse == Param::Sparse::GROUP, - "invalid convolution sparse type"); + megdnn_assert( + param.sparse == Param::Sparse::GROUP, + "invalid convolution sparse type"); megdnn_assert( filter.ndim == img_ndim + 3 || filter.ndim == img_ndim + 5, "bad filter ndim for group convolution: " @@ -221,9 +226,9 @@ void make_canonized_filter_meta_nhwcd4_dot( } auto dilation = ret.dilation; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(dilation[i] > 0, - "invalid dilation on spatial dim %zu: %u", i, - dilation[i]); + megdnn_assert( + dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i, + dilation[i]); ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1; } @@ -254,9 +259,10 @@ void make_canonized_filter_meta_nchwxx( * */ - megdnn_assert(param.format == Param::Format::NCHW88 || - param.format == Param::Format::NCHW44 || - param.format == Param::Format::NCHW44_DOT); + megdnn_assert( + param.format == Param::Format::NCHW88 || + param.format == Param::Format::NCHW44 || + param.format == Param::Format::NCHW44_DOT); size_t img_ndim = 2; size_t flt_start = 0; size_t flt_spatial_start = 2; @@ -264,13 +270,13 @@ void 
make_canonized_filter_meta_nchwxx( if (param.sparse == Param::Sparse::DENSE) { if (filter.ndim == img_ndim + 4) { // oihw8i8o case - megdnn_assert((filter[filter.ndim - 2] == pack_size && - filter[filter.ndim - 1] == pack_size) || - (filter[filter.ndim - 2] == 2 * pack_size && - filter[filter.ndim - 1] == 2 * pack_size), - "last 2 dim of filter must be %zu, but got %zu, %zu", - pack_size, filter[filter.ndim - 2], - filter[filter.ndim - 1]); + megdnn_assert( + (filter[filter.ndim - 2] == pack_size && + filter[filter.ndim - 1] == pack_size) || + (filter[filter.ndim - 2] == 2 * pack_size && + filter[filter.ndim - 1] == 2 * pack_size), + "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size, + filter[filter.ndim - 2], filter[filter.ndim - 1]); ret.group = 1; flt_start = 0; if (filter[filter.ndim - 2] == 2 * pack_size && @@ -290,41 +296,44 @@ void make_canonized_filter_meta_nchwxx( ret.icpg = filter[flt_start + 3]; } else { - megdnn_assert(0, "not support nchwxx filter dim = %zu", - filter.ndim); + megdnn_assert(0, "not support nchwxx filter dim = %zu", filter.ndim); } } else { - megdnn_assert(param.sparse == Param::Sparse::GROUP, - "invalid convolution sparse type"); + megdnn_assert( + param.sparse == Param::Sparse::GROUP, + "invalid convolution sparse type"); flt_start = 1; auto filter_oc = filter[flt_start]; auto filter_ic = filter[flt_start + 1]; if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4)) { // Depthwise case goihw8g - megdnn_assert(filter.ndim == img_ndim + 4, - "bad filter ndim for group convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); - megdnn_assert(filter[filter.ndim - 1] == pack_size, - "last dim of filter must be %zu, but %zu", pack_size, - filter[filter.ndim - 1]); + megdnn_assert( + filter.ndim == img_ndim + 4, + "bad filter ndim for group convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); + megdnn_assert( + filter[filter.ndim - 1] == pack_size, + "last dim of filter must be %zu, but %zu", pack_size, + filter[filter.ndim - 1]); ret.group = filter[0] * pack_size; ret.ocpg = filter_oc; ret.icpg = filter_ic; } else { // norm group case goihw8i8o - megdnn_assert(filter.ndim == img_ndim + 5, - "bad filter ndim for group convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); - megdnn_assert((filter[filter.ndim - 1] == pack_size && - filter[filter.ndim - 2] == pack_size) || - (filter[filter.ndim - 1] == 2 * pack_size && - filter[filter.ndim - 2] == 2 * pack_size), - "last 2 dim of filter must be %zu, but got %zu, %zu", - pack_size, filter[filter.ndim - 2], - filter[filter.ndim - 1]); + megdnn_assert( + filter.ndim == img_ndim + 5, + "bad filter ndim for group convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); + megdnn_assert( + (filter[filter.ndim - 1] == pack_size && + filter[filter.ndim - 2] == pack_size) || + (filter[filter.ndim - 1] == 2 * pack_size && + filter[filter.ndim - 2] == 2 * pack_size), + "last 2 dim of filter must be %zu, but got %zu, %zu", pack_size, + filter[filter.ndim - 2], filter[filter.ndim - 1]); ret.group = filter[0]; if (filter[filter.ndim - 2] == 2 * pack_size && @@ -338,18 +347,20 @@ void make_canonized_filter_meta_nchwxx( } } ret.spatial_ndim = 2; - megdnn_assert(ret.spatial_ndim == 2, - "only 2D convolution is supported, and input should be 5-dim " - "for nchwxx; " - "got input dim = %zu", - src_ndim); + megdnn_assert( + ret.spatial_ndim == 2, + "only 2D convolution is supported, and input should be 5-dim " + "for 
nchwxx; " + "got input dim = %zu", + src_ndim); auto dilation = ret.dilation; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(dilation[i] == 1, - "NCHWXX has invalid dilation on spatial dim %zu: %u, " - "require to be 1", - i, dilation[i]); + megdnn_assert( + dilation[i] == 1, + "NCHWXX has invalid dilation on spatial dim %zu: %u, " + "require to be 1", + i, dilation[i]); ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1; } @@ -365,48 +376,54 @@ void make_canonized_filter_meta_nchwx( * OC, IC/pack_size, FH, FW, pack_size [dense] * GROUP, OC, IC/pack_size, FH, FW, pack_size [group] */ - megdnn_assert(param.format == Param::Format::NCHW4 || - param.format == Param::Format::NCHW8 || - param.format == Param::Format::NCHW32 || - param.format == Param::Format::NCHW4_NCHW || - param.format == Param::Format::NCHW4_NHWC || - param.format == Param::Format::NCHW4_NCHW32 || - param.format == Param::Format::NCHW32_NCHW4 || - param.format == Param::Format::NCHW64); + megdnn_assert( + param.format == Param::Format::NCHW4 || + param.format == Param::Format::NCHW8 || + param.format == Param::Format::NCHW32 || + param.format == Param::Format::NCHW4_NCHW || + param.format == Param::Format::NCHW4_NHWC || + param.format == Param::Format::NCHW4_NCHW32 || + param.format == Param::Format::NCHW32_NCHW4 || + param.format == Param::Format::NCHW64); auto img_ndim = src_ndim - 3; size_t flt_start = 0, flt_spatial_start = 2; if (param.sparse == Param::Sparse::DENSE) { - megdnn_assert(filter.ndim == img_ndim + 3, - "bad filter ndim for dense convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + filter.ndim == img_ndim + 3, + "bad filter ndim for dense convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); // oc, ic, dims[] ret.group = 1; flt_start = 0; } else { - megdnn_assert(param.sparse == Param::Sparse::GROUP, - "invalid convolution sparse type"); - megdnn_assert(filter.ndim == img_ndim + 4, - "bad filter ndim for group convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + param.sparse == Param::Sparse::GROUP, + "invalid convolution sparse type"); + megdnn_assert( + filter.ndim == img_ndim + 4, + "bad filter ndim for group convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); ret.group = filter[0]; flt_start = 1; } ret.spatial_ndim = src_ndim - 3; - megdnn_assert(ret.spatial_ndim == 2, - "only 2D convolution is supported, and input should be 5-dim " - "for nchw4; " - "got input dim = %zu", - src_ndim); + megdnn_assert( + ret.spatial_ndim == 2, + "only 2D convolution is supported, and input should be 5-dim " + "for nchw4; " + "got input dim = %zu", + src_ndim); ret.ocpg = filter[flt_start]; ret.icpg = filter[flt_start + 1] * pack_size; auto dilation = ret.dilation; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(dilation[i] == 1, - "NCHW4 has invalid dilation on spatial dim %zu: %u, " - "require to be 1", - i, dilation[i]); + megdnn_assert( + dilation[i] == 1, + "NCHW4 has invalid dilation on spatial dim %zu: %u, " + "require to be 1", + i, dilation[i]); ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1; } @@ -427,20 +444,23 @@ void make_canonized_filter_meta_chwnx( auto img_ndim = src_ndim - 3; size_t flt_start = 0, flt_spatial_start = 1; if (param.sparse == Param::Sparse::DENSE) { - 
megdnn_assert(filter.ndim == img_ndim + 3, - "bad filter ndim for dense convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + filter.ndim == img_ndim + 3, + "bad filter ndim for dense convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); // oc, ic, dims[] ret.group = 1; flt_start = 0; } else { - megdnn_assert(param.sparse == Param::Sparse::GROUP, - "invalid convolution sparse type"); - megdnn_assert(filter.ndim == img_ndim + 4, - "bad filter ndim for group convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + param.sparse == Param::Sparse::GROUP, + "invalid convolution sparse type"); + megdnn_assert( + filter.ndim == img_ndim + 4, + "bad filter ndim for group convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); ret.group = filter[0]; flt_start = 1; } @@ -454,10 +474,11 @@ void make_canonized_filter_meta_chwnx( ret.ocpg = filter[flt_start + 3]; auto dilation = ret.dilation; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(dilation[i] == 1, - "CHWNx has invalid dilation on spatial dim %zu: %u, " - "require to be 1", - i, dilation[i]); + megdnn_assert( + dilation[i] == 1, + "CHWNx has invalid dilation on spatial dim %zu: %u, " + "require to be 1", + i, dilation[i]); ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1; } @@ -467,9 +488,8 @@ void make_canonized_filter_meta_chwnx( namespace megdnn { template -typename ConvolutionBase::CanonizedFilterMeta -ConvolutionBase::make_canonized_filter_meta( - size_t src_ndim, const TensorLayout& filter) const { +typename ConvolutionBase::CanonizedFilterMeta ConvolutionBase:: + make_canonized_filter_meta(size_t src_ndim, const TensorLayout& filter) const { megdnn_assert_contiguous(filter); CanonizedFilterMeta ret; ret.dtype = filter.dtype; @@ -477,8 +497,7 @@ ConvolutionBase::make_canonized_filter_meta( if (param().mode == Mode::CONVOLUTION) { ret.should_flip = true; } else { - megdnn_assert(param().mode == Mode::CROSS_CORRELATION, - "invalid conv mode"); + megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode"); ret.should_flip = false; } ret.stride[0] = param().stride_h; @@ -491,51 +510,46 @@ ConvolutionBase::make_canonized_filter_meta( if (param().format == Param::Format::NHWCD4) { if (filter.dtype.enumv() == DTypeEnum::QuantizedS8 || filter.dtype.enumv() == DTypeEnum::Quantized8Asymm) { - make_canonized_filter_meta_nhwcd4_dot(src_ndim, filter, - param(), ret); + make_canonized_filter_meta_nhwcd4_dot( + src_ndim, filter, param(), ret); } else { - make_canonized_filter_meta_nhwcd4(src_ndim, filter, - param(), ret); + make_canonized_filter_meta_nhwcd4( + src_ndim, filter, param(), ret); } - } else if (param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW4_NCHW || - param().format == Param::Format::NCHW4_NHWC || - param().format == Param::Format::NCHW4_NCHW32) { - make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, - param(), ret); + } else if ( + param().format == Param::Format::NCHW4 || + param().format == Param::Format::NCHW4_NCHW || + param().format == Param::Format::NCHW4_NHWC || + param().format == Param::Format::NCHW4_NCHW32) { + make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret); } else if (param().format == Param::Format::NCHW8) { - make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, - param(), ret); + 
make_canonized_filter_meta_nchwx<8, Parameter>(src_ndim, filter, param(), ret); } else if (param().format == Param::Format::NCHW88) { - make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, - param(), ret); - } else if (param().format == Param::Format::NCHW44 || - param().format == Param::Format::NCHW44_DOT) { - make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, - param(), ret); - } else if (param().format == Param::Format::NCHW32 || - param().format == Param::Format::NCHW32_NCHW4) { - make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, - param(), ret); + make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, param(), ret); + } else if ( + param().format == Param::Format::NCHW44 || + param().format == Param::Format::NCHW44_DOT) { + make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, param(), ret); + } else if ( + param().format == Param::Format::NCHW32 || + param().format == Param::Format::NCHW32_NCHW4) { + make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, param(), ret); } else if (param().format == Param::Format::CHWN4) { - make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, - param(), ret); + make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter, param(), ret); } else if (param().format == Param::Format::NCHW64) { - make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, - param(), ret); + make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter, param(), ret); } else { - megdnn_assert(param().format == Param::Format::NHWC || - param().format == Param::Format::NCHW); - make_canonized_filter_meta_nchw_nhwc(src_ndim, filter, - param(), ret); + megdnn_assert( + param().format == Param::Format::NHWC || + param().format == Param::Format::NCHW); + make_canonized_filter_meta_nchw_nhwc(src_ndim, filter, param(), ret); } return ret; } template -void ConvolutionBase::check_or_deduce_dtype_fwd(DType src, - DType filter, - DType& dst) const { +void ConvolutionBase::check_or_deduce_dtype_fwd( + DType src, DType filter, DType& dst) const { // The first one will be the default choice. SmallVector supported_dst_dtype; // We rely on megdnn_assert(src.enumv() == filter.enumv()) here. 
@@ -543,20 +557,19 @@ void ConvolutionBase::check_or_deduce_dtype_fwd(DType src, supported_dst_dtype.push_back(src); } else if (src.enumv() == DTypeEnum::Int8) { supported_dst_dtype = {dtype::Int32(), dtype::Int16()}; - } else if (src.enumv() == DTypeEnum::QuantizedS8 || - src.enumv() == DTypeEnum::Quantized8Asymm || - src.enumv() == DTypeEnum::QuantizedS4 || - src.enumv() == DTypeEnum::Quantized4Asymm) { - supported_dst_dtype.push_back( - dtype::QuantizedS32(mul_scale(src, filter))); - bool cond_dst = - dst.valid() && (dst.enumv() == src.enumv() || - ((dst.enumv() == DTypeEnum::QuantizedS4 || - dst.enumv() == DTypeEnum::Quantized4Asymm) && - src.enumv() == DTypeEnum::QuantizedS8) || - ((src.enumv() == DTypeEnum::QuantizedS4 || - src.enumv() == DTypeEnum::Quantized4Asymm) && - dst.enumv() == DTypeEnum::QuantizedS8)); + } else if ( + src.enumv() == DTypeEnum::QuantizedS8 || + src.enumv() == DTypeEnum::Quantized8Asymm || + src.enumv() == DTypeEnum::QuantizedS4 || + src.enumv() == DTypeEnum::Quantized4Asymm) { + supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(src, filter))); + bool cond_dst = dst.valid() && (dst.enumv() == src.enumv() || + ((dst.enumv() == DTypeEnum::QuantizedS4 || + dst.enumv() == DTypeEnum::Quantized4Asymm) && + src.enumv() == DTypeEnum::QuantizedS8) || + ((src.enumv() == DTypeEnum::QuantizedS4 || + src.enumv() == DTypeEnum::Quantized4Asymm) && + dst.enumv() == DTypeEnum::QuantizedS8)); if (cond_dst) { supported_dst_dtype.push_back(dst); } @@ -566,12 +579,13 @@ void ConvolutionBase::check_or_deduce_dtype_fwd(DType src, } else if (src.enumv() == DTypeEnum::QuantizedS32) { //! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src) megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8); - supported_dst_dtype.push_back( - dtype::QuantizedS8(src.param().scale / - filter.param().scale)); - }else { - megdnn_throw(ssprintf("unsupported input / filter DType: %s x %s", - src.name(), filter.name())); + supported_dst_dtype.push_back(dtype::QuantizedS8( + src.param().scale / + filter.param().scale)); + } else { + megdnn_throw(ssprintf( + "unsupported input / filter DType: %s x %s", src.name(), + filter.name())); } if (!dst.valid()) { dst = supported_dst_dtype.at(0); @@ -584,54 +598,59 @@ void ConvolutionBase::check_or_deduce_dtype_fwd(DType src, } } MEGDNN_MARK_USED_VAR(dst_supported); - megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s", - src.name(), filter.name(), dst.name()); + megdnn_assert( + dst_supported, "unsupported Conv(%s, %s) -> %s", src.name(), + filter.name(), dst.name()); } - megdnn_assert((param().compute_mode == Param::ComputeMode::FLOAT32 || - param().compute_mode == Param::ComputeMode::DEFAULT) + megdnn_assert( + (param().compute_mode == Param::ComputeMode::FLOAT32 || + param().compute_mode == Param::ComputeMode::DEFAULT) #if !MEGDNN_DISABLE_FLOAT16 - || src.enumv() == DTypeEnum::Float16 || - src.enumv() == DTypeEnum::BFloat16 + || src.enumv() == DTypeEnum::Float16 || + src.enumv() == DTypeEnum::BFloat16 #endif - , - "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " - "input / output."); + , + "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " + "input / output."); } template -typename ConvolutionBase::CanonizedFilterMeta -ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& filter, - TensorLayout& dst) const { +typename ConvolutionBase::CanonizedFilterMeta ConvolutionBase:: + deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + TensorLayout& dst) const { auto 
errmsg = [&]() { return get_errmsg(src, filter, dst, param()); }; MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert(src.ndim >= 3_z, "%s", errmsg().c_str()); - megdnn_assert(((src.dtype.enumv() == filter.dtype.enumv()) || - (src.dtype.enumv() == DTypeEnum::Quantized4Asymm && - filter.dtype.enumv() == DTypeEnum::QuantizedS4)), - "%s", errmsg().c_str()); + megdnn_assert( + ((src.dtype.enumv() == filter.dtype.enumv()) || + (src.dtype.enumv() == DTypeEnum::Quantized4Asymm && + filter.dtype.enumv() == DTypeEnum::QuantizedS4)), + "%s", errmsg().c_str()); check_or_deduce_dtype_fwd(src.dtype, filter.dtype, dst.dtype); size_t img_dim; if (param().format == Param::Format::NCHW || param().format == Param::Format::NHWC) { img_dim = src.ndim - 2; - megdnn_assert(filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, - "%s", errmsg().c_str()); + megdnn_assert( + filter.ndim >= img_dim + 2 && filter.ndim <= img_dim + 6, "%s", + errmsg().c_str()); } else { - megdnn_assert(param().format == Param::Format::NHWCD4 || - param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW4_NCHW || - param().format == Param::Format::NCHW4_NHWC || - param().format == Param::Format::NCHW4_NCHW32 || - param().format == Param::Format::NCHW44 || - param().format == Param::Format::NCHW44_DOT || - param().format == Param::Format::NCHW8 || - param().format == Param::Format::NCHW32 || - param().format == Param::Format::NCHW32_NCHW4 || - param().format == Param::Format::NCHW88 || - param().format == Param::Format::CHWN4 || - param().format == Param::Format::NCHW64); + megdnn_assert( + param().format == Param::Format::NHWCD4 || + param().format == Param::Format::NCHW4 || + param().format == Param::Format::NCHW4_NCHW || + param().format == Param::Format::NCHW4_NHWC || + param().format == Param::Format::NCHW4_NCHW32 || + param().format == Param::Format::NCHW44 || + param().format == Param::Format::NCHW44_DOT || + param().format == Param::Format::NCHW8 || + param().format == Param::Format::NCHW32 || + param().format == Param::Format::NCHW32_NCHW4 || + param().format == Param::Format::NCHW88 || + param().format == Param::Format::CHWN4 || + param().format == Param::Format::NCHW64); img_dim = src.ndim - 3; if ((param().format == Param::Format::NCHW88 || param().format == Param::Format::NCHW44_DOT || @@ -639,35 +658,34 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, filter.ndim == 5) { img_dim = src.ndim - 2; } - megdnn_assert(filter.ndim == img_dim + 3 || - (filter.ndim == img_dim + 2 && - (param().format == Param::Format::NCHW88 || - param().format == Param::Format::NCHW44_DOT || - param().format == Param::Format::NCHW44)) || - filter.ndim == img_dim + 4 || - filter.ndim == img_dim + 5, - "%s", errmsg().c_str()); + megdnn_assert( + filter.ndim == img_dim + 3 || + (filter.ndim == img_dim + 2 && + (param().format == Param::Format::NCHW88 || + param().format == Param::Format::NCHW44_DOT || + param().format == Param::Format::NCHW44)) || + filter.ndim == img_dim + 4 || filter.ndim == img_dim + 5, + "%s", errmsg().c_str()); if (param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW4_NCHW || param().format == Param::Format::NCHW4_NCHW32) { - megdnn_assert(src.ndim == 5 && - (filter.ndim == 5 || filter.ndim == 6 || - filter.ndim == 7) && - src[src.ndim - 1] == 4 && - filter[filter.ndim - 1] == 4, - "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and " - "filter's ndim is " - "5 or 6, and " - "last shape " - "is 4 " - "but got src %s, filter %s", - src.to_string().c_str(), 
filter.to_string().c_str()); + megdnn_assert( + src.ndim == 5 && + (filter.ndim == 5 || filter.ndim == 6 || + filter.ndim == 7) && + src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4, + "NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and " + "filter's ndim is " + "5 or 6, and " + "last shape " + "is 4 " + "but got src %s, filter %s", + src.to_string().c_str(), filter.to_string().c_str()); } if (param().format == Param::Format::NCHW8) { megdnn_assert( src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && - src[src.ndim - 1] == 8 && - filter[filter.ndim - 1] == 8, + src[src.ndim - 1] == 8 && filter[filter.ndim - 1] == 8, "NCHW8 require src and filter's ndim is 5 or 6, and last " "shape is 8 " "but got src %s, filter %s", @@ -675,72 +693,67 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, } if (param().format == Param::Format::NCHW32 || param().format == Param::Format::NCHW32_NCHW4) { - megdnn_assert(src.ndim == 5 && - (filter.ndim == 5 || filter.ndim == 6) && - src[src.ndim - 1] == 32 && - filter[filter.ndim - 1] == 32, - "NCHW32/NCHW32_NCHW4 require src and filter's ndim " - "is 5 or 6, and last " - "shape is 32 " - "but got src %s, filter %s", - src.to_string().c_str(), filter.to_string().c_str()); + megdnn_assert( + src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && + src[src.ndim - 1] == 32 && filter[filter.ndim - 1] == 32, + "NCHW32/NCHW32_NCHW4 require src and filter's ndim " + "is 5 or 6, and last " + "shape is 32 " + "but got src %s, filter %s", + src.to_string().c_str(), filter.to_string().c_str()); } if (param().format == Param::Format::NCHW88) { - megdnn_assert((src.ndim == 4 && filter.ndim == 5 && - filter[filter.ndim - 1] == 8) || - (src.ndim == 5 && - ((filter.ndim == 6 && - filter[filter.ndim - 1] == 8) || - (filter.ndim == 7 && - filter[filter.ndim - 1] == 8 && - filter[filter.ndim - 2] == 8)) && - src[src.ndim - 1] == 8), - "NCHW88 require src ndim is 5 and filter's ndim is 6 " - ", and last shape two is 8 but got src %s, filter %s", - src.to_string().c_str(), filter.to_string().c_str()); + megdnn_assert( + (src.ndim == 4 && filter.ndim == 5 && + filter[filter.ndim - 1] == 8) || + (src.ndim == 5 && + ((filter.ndim == 6 && filter[filter.ndim - 1] == 8) || + (filter.ndim == 7 && filter[filter.ndim - 1] == 8 && + filter[filter.ndim - 2] == 8)) && + src[src.ndim - 1] == 8), + "NCHW88 require src ndim is 5 and filter's ndim is 6 " + ", and last shape two is 8 but got src %s, filter %s", + src.to_string().c_str(), filter.to_string().c_str()); } if (param().format == Param::Format::NCHW44 || param().format == Param::Format::NCHW44_DOT) { //! support nchw44 filter change to 88 for int8 winogradf23_88 using //! 
MK8 mamtul - megdnn_assert((src.ndim == 4 && filter.ndim == 5 && - filter[filter.ndim - 1] == 4) || - (src.ndim == 5 && - ((filter.ndim == 6 && - (filter[filter.ndim - 1] == 4 || - filter[filter.ndim - 1] == 8)) || - (filter.ndim == 7 && - (filter[filter.ndim - 1] == 4 || - filter[filter.ndim - 1] == 8) && - (filter[filter.ndim - 2] == 4 || - filter[filter.ndim - 2] == 8))) && - src[src.ndim - 1] == 4), - "NCHW44 require src ndim is 5 and filter's ndim is 6 " - ", and last shape two is 4 but got src %s, filter %s", - src.to_string().c_str(), filter.to_string().c_str()); + megdnn_assert( + (src.ndim == 4 && filter.ndim == 5 && + filter[filter.ndim - 1] == 4) || + (src.ndim == 5 && + ((filter.ndim == 6 && (filter[filter.ndim - 1] == 4 || + filter[filter.ndim - 1] == 8)) || + (filter.ndim == 7 && + (filter[filter.ndim - 1] == 4 || + filter[filter.ndim - 1] == 8) && + (filter[filter.ndim - 2] == 4 || + filter[filter.ndim - 2] == 8))) && + src[src.ndim - 1] == 4), + "NCHW44 require src ndim is 5 and filter's ndim is 6 " + ", and last shape two is 4 but got src %s, filter %s", + src.to_string().c_str(), filter.to_string().c_str()); } if (param().format == Param::Format::CHWN4) { megdnn_assert( src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && - src[src.ndim - 1] == 4 && - filter[filter.ndim - 1] == 4, + src[src.ndim - 1] == 4 && filter[filter.ndim - 1] == 4, "CHWN4 require src and filter's ndim is 5 or 6, and last " "shape is 4 " "but got src %s, filter %s", src.to_string().c_str(), filter.to_string().c_str()); } if (param().format == Param::Format::NCHW64) { - megdnn_assert(src.ndim == 5 && - (filter.ndim == 5 || filter.ndim == 6) && - src[src.ndim - 1] == 64 && - filter[filter.ndim - 1] == 64, - "NCHW64 require src and filter's ndim is 5 or 6, and " - "last shape is 64 but got src %s, filter %s", - src.to_string().c_str(), filter.to_string().c_str()); + megdnn_assert( + src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && + src[src.ndim - 1] == 64 && filter[filter.ndim - 1] == 64, + "NCHW64 require src and filter's ndim is 5 or 6, and " + "last shape is 64 but got src %s, filter %s", + src.to_string().c_str(), filter.to_string().c_str()); } } - megdnn_assert(img_dim == 2, - "currently only convolution on 2D image is supported"); + megdnn_assert(img_dim == 2, "currently only convolution on 2D image is supported"); auto cflt = make_canonized_filter_meta(src.ndim, filter); if (param().format == Param::Format::NCHW || param().format == Param::Format::NHWC) { @@ -750,13 +763,13 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, src_or_dst_c_pos = 1; src_or_dst_spatial_start = 2; } else { - megdnn_assert(param().format == Param::Format::NHWC, - "invalid conv format"); + megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format"); src_or_dst_c_pos = 3; src_or_dst_spatial_start = 1; } - megdnn_assert(cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", - errmsg().c_str()); + megdnn_assert( + cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", + errmsg().c_str()); dst.ndim = src.ndim; dst[0] = src[0]; dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group; @@ -766,222 +779,222 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, cflt.stride[i], cflt.padding[i]); } } else if (param().format == Param::Format::NCHW4) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW4, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + 
src.ndim == 5, "invalid src ndim for NCHW4, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 4 == 0); dst[1] = oc / 4; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 4; } else if (param().format == Param::Format::NCHW8) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW8, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 8, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for NCHW8, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 8, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 8 == 0); dst[1] = oc / 8; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 8; } else if (param().format == Param::Format::NCHW32) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW32, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 32, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for NCHW32, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 32 == 0); dst[1] = oc / 32; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 32; } else if (param().format == Param::Format::NCHW88) { - megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 8), - "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", - src.ndim); + megdnn_assert( + src.ndim == 5 || (src.ndim == 4 && src[1] <= 8), + "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim); dst.ndim = 5; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 8 == 0); dst[1] = oc / 8; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], 
cflt.stride[1], cflt.padding[1]); dst[4] = 8; if (cflt.group == 1) { - megdnn_assert(cflt.icpg * cflt.group == src[1] * 8 || - (cflt.icpg * cflt.group == src[1]), - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 8 || + (cflt.icpg * cflt.group == src[1]), + "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group); } - } else if (param().format == Param::Format::NCHW44 || - param().format == Param::Format::NCHW44_DOT) { - megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 4), - "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", - src.ndim); + } else if ( + param().format == Param::Format::NCHW44 || + param().format == Param::Format::NCHW44_DOT) { + megdnn_assert( + src.ndim == 5 || (src.ndim == 4 && src[1] <= 4), + "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim); dst.ndim = 5; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 4 == 0); dst[1] = oc / 4; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 4; if (cflt.group == 1) { - megdnn_assert(cflt.icpg * cflt.group == src[1] * 4 || - (cflt.icpg * cflt.group == src[1]), - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 4 || + (cflt.icpg * cflt.group == src[1]), + "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group); } } else if (param().format == Param::Format::CHWN4) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for CHWN4, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[0] * 4, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for CHWN4, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[0] * 4, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[3] = src[3]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 4 == 0); dst[0] = oc / 4; - dst[1] = infer_conv_shape(src[1], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[1] = infer_conv_shape( + src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 4; } else if (param().format == Param::Format::NCHW4_NCHW) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for NCHW4_NCHW, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = 4; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; dst[1] = oc; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] 
= infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); } else if (param().format == Param::Format::NCHW4_NHWC) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = 4; dst[0] = src[0]; - dst[1] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[2] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[1] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[2] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); auto oc = cflt.ocpg * cflt.group; dst[3] = oc; } else if (param().format == Param::Format::NCHW4_NCHW32) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 4, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 32 == 0); dst[1] = oc / 32; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 32; } else if (param().format == Param::Format::NCHW32_NCHW4) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 32, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 32, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 4 == 0); dst[1] = oc / 4; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 4; } else if (param().format == Param::Format::NCHW64) { - megdnn_assert(src.ndim == 5, - "invalid src ndim for NCHW64, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[1] * 64, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, 
"invalid src ndim for NCHW64, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[1] * 64, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 64 == 0); dst[1] = oc / 64; - dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[2] = infer_conv_shape( + src[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); dst[4] = 64; } else { megdnn_assert(param().format == Param::Format::NHWCD4); - megdnn_assert(src.ndim == 5, - "invalid src ndim for NHWCD4, expected=5, got=%zu", - src.ndim); - megdnn_assert(cflt.icpg * cflt.group == src[2] * 4, - "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, - cflt.group); + megdnn_assert( + src.ndim == 5, "invalid src ndim for NHWCD4, expected=5, got=%zu", + src.ndim); + megdnn_assert( + cflt.icpg * cflt.group == src[2] * 4, "%s icpg=%u group=%u", + errmsg().c_str(), cflt.icpg, cflt.group); dst.ndim = src.ndim; dst[0] = src[0]; auto oc = cflt.ocpg * cflt.group; megdnn_assert(oc % 4 == 0); dst[2] = oc / 4; - dst[1] = infer_conv_shape(src[1], cflt.dilated_spatial[0], - cflt.stride[0], cflt.padding[0]); - dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], - cflt.stride[1], cflt.padding[1]); + dst[1] = infer_conv_shape( + src[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + dst[3] = infer_conv_shape( + src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); megdnn_assert(src[4] == 4); dst[4] = 4; } - if (!src.format.is_default() && - !src.format.is_lowbit_aligned()) { // propagate + if (!src.format.is_default() && !src.format.is_lowbit_aligned()) { // propagate dst.format = src.format; } else { // determined by dtype dst.format = TensorFormat(dst.dtype); - } + } dst.init_contiguous_stride(); return cflt; } @@ -996,10 +1009,11 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, * https://stackoverflow.com/questions/25594644/warning-specialization-of-template-in-different-namespace */ template <> -ConvolutionBase::CanonizedFilterMeta -ConvolutionBase::check_layout_fwd( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst) const { +ConvolutionBase::CanonizedFilterMeta ConvolutionBase< + param::Convolution>:: + check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) const { megdnn_assert_contiguous(src); megdnn_assert_contiguous(filter); TensorLayout dst_expected; @@ -1011,10 +1025,10 @@ ConvolutionBase::check_layout_fwd( } template <> -ConvolutionBase::CanonizedFilterMeta -ConvolutionBase::check_layout_fwd( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst) const { +ConvolutionBase::CanonizedFilterMeta ConvolutionBase:: + check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) const { megdnn_assert_contiguous(src); megdnn_assert_contiguous(filter); TensorLayout dst_expected; @@ -1026,10 +1040,11 @@ ConvolutionBase::check_layout_fwd( } template <> -ConvolutionBase::CanonizedFilterMeta -ConvolutionBase::check_layout_fwd( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst) const { +ConvolutionBase::CanonizedFilterMeta ConvolutionBase< + param::BatchConvBias>:: + 
check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) const { megdnn_assert_contiguous(src); megdnn_assert_contiguous(filter); TensorLayout dst_expected; @@ -1044,16 +1059,14 @@ void ConvolutionForward::deduce_dtype(DType src, DType filter, DType& dst) { check_or_deduce_dtype_fwd(src, filter, dst); } -void ConvolutionForward::deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - TensorLayout& dst) { +void ConvolutionForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); } ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_in_bytes, - const PreprocessedFilter* preprocessed_filter) { + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_in_bytes, const PreprocessedFilter* preprocessed_filter) { auto ret = check_layout_fwd(src, filter, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst, preprocessed_filter); @@ -1061,11 +1074,9 @@ ConvolutionForward::CanonizedFilterMeta ConvolutionForward::check_exec( return ret; } -ConvolutionBackwardData::CanonizedFilterMeta -ConvolutionBackwardData::check_exec(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { +ConvolutionBackwardData::CanonizedFilterMeta ConvolutionBackwardData::check_exec( + const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { auto grad_fwd = grad; auto filter_fwd = filter; auto diff_fwd = diff; @@ -1075,67 +1086,64 @@ ConvolutionBackwardData::check_exec(const TensorLayout& filter, grad_fwd.init_contiguous_stride(); diff_fwd.init_contiguous_stride(); auto ret = check_layout_fwd(grad_fwd, filter_fwd, diff_fwd); - auto required_workspace_in_bytes = - get_workspace_in_bytes(filter, diff, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); return ret; } -void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, - DType& grad) { +void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, DType& grad) { SmallVector supported_dst_dtype; if (filter.category() == diff.category() && filter.category() == DTypeCategory::FLOAT) { supported_dst_dtype.push_back(filter); } else if (filter.enumv() == DTypeEnum::Int8 && diff == filter) { supported_dst_dtype.push_back(dtype::Int32()); - } else if ((filter.enumv() == DTypeEnum::QuantizedS8 && - diff.enumv() == DTypeEnum::QuantizedS8) || - (filter.enumv() == DTypeEnum::Quantized8Asymm && - diff.enumv() == DTypeEnum::Quantized8Asymm)) { - supported_dst_dtype.push_back( - dtype::QuantizedS32(mul_scale(filter, diff))); + } else if ( + (filter.enumv() == DTypeEnum::QuantizedS8 && + diff.enumv() == DTypeEnum::QuantizedS8) || + (filter.enumv() == DTypeEnum::Quantized8Asymm && + diff.enumv() == DTypeEnum::Quantized8Asymm)) { + supported_dst_dtype.push_back(dtype::QuantizedS32(mul_scale(filter, diff))); if (grad.valid() && grad.enumv() == diff.enumv()) { supported_dst_dtype.push_back(grad); } } else { - megdnn_throw(ssprintf("unsupported input / diff DType: %s x %s", - filter.name(), diff.name())); + megdnn_throw(ssprintf( + "unsupported input / diff DType: %s x %s", filter.name(), diff.name())); } if (!grad.valid()) { grad = 
supported_dst_dtype.at(0); } else { - megdnn_assert(vec_contains(supported_dst_dtype, grad), - "unsupported ConvBwd(%s, %s) -> %s", filter.name(), - diff.name(), grad.name()); + megdnn_assert( + vec_contains(supported_dst_dtype, grad), + "unsupported ConvBwd(%s, %s) -> %s", filter.name(), diff.name(), + grad.name()); } - megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 + megdnn_assert( + param().compute_mode != Param::ComputeMode::FLOAT32 #if !MEGDNN_DISABLE_FLOAT16 - || filter.enumv() == DTypeEnum::Float16 || - filter.enumv() == DTypeEnum::BFloat16 + || filter.enumv() == DTypeEnum::Float16 || + filter.enumv() == DTypeEnum::BFloat16 #endif - , - "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " - "input / output."); + , + "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " + "input / output."); } -void ConvolutionBackwardData::deduce_layout(const TensorLayout& filter, - const TensorLayout& diff, - TensorLayout& grad) { +void ConvolutionBackwardData::deduce_layout( + const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) { auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); }; MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert_contiguous(filter); megdnn_assert_contiguous(diff); - megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", - errmsg().c_str()); + megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str()); megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str()); deduce_dtype(filter.dtype, diff.dtype, grad.dtype); auto cflt = make_canonized_filter_meta(diff.ndim, filter); - auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, - size_t pad) { + auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) { MEGDNN_MARK_USED_VAR(errmsg); auto i = (out - 1) * stride + filter; megdnn_assert(i > pad * 2, "%s", errmsg().c_str()); @@ -1150,55 +1158,53 @@ void ConvolutionBackwardData::deduce_layout(const TensorLayout& filter, src_or_dst_c_pos = 1; src_or_dst_spatial_start = 2; } else { - megdnn_assert(param().format == Param::Format::NHWC, - "invalid conv format"); + megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format"); src_or_dst_c_pos = 3; src_or_dst_spatial_start = 1; } - megdnn_assert(cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s", - errmsg().c_str()); + megdnn_assert( + cflt.ocpg * cflt.group == diff[src_or_dst_c_pos], "%s", + errmsg().c_str()); grad.ndim = diff.ndim; grad[0] = diff[0]; grad[src_or_dst_c_pos] = cflt.icpg * cflt.group; for (size_t i = 0; i < cflt.spatial_ndim; ++i) { - grad[i + src_or_dst_spatial_start] = deduce( - diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i], - cflt.stride[i], cflt.padding[i]); + grad[i + src_or_dst_spatial_start] = + deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i], + cflt.stride[i], cflt.padding[i]); } } else if (param().format == Param::Format::NCHW4) { - megdnn_assert(diff.ndim == 5, - "valid diff ndim for NCHW4, expected=5, got=%zu", - diff.ndim); + megdnn_assert( + diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu", + diff.ndim); megdnn_assert(cflt.group == 1, "%s", errmsg().c_str()); - megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", - errmsg().c_str()); + megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str()); grad.ndim = diff.ndim; grad[0] = diff[0]; auto ic = cflt.icpg * cflt.group; megdnn_assert(ic % 4 == 0); grad[1] = ic / 4; - grad[2] = deduce(diff[2], cflt.dilated_spatial[0], 
cflt.stride[0], - cflt.padding[0]); - grad[3] = deduce(diff[3], cflt.dilated_spatial[1], cflt.stride[1], - cflt.padding[1]); + grad[2] = deduce( + diff[2], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + grad[3] = deduce( + diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); megdnn_assert(diff[4] == 4); grad[4] = 4; } else { megdnn_assert(param().format == Param::Format::NHWCD4); - megdnn_assert(diff.ndim == 5, - "valid diff ndim for NHWCD4, expected=5, got=%zu", - diff.ndim); - megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", - errmsg().c_str()); + megdnn_assert( + diff.ndim == 5, "valid diff ndim for NHWCD4, expected=5, got=%zu", + diff.ndim); + megdnn_assert(cflt.ocpg * cflt.group == diff[2] * 4, "%s", errmsg().c_str()); grad.ndim = diff.ndim; grad[0] = diff[0]; auto ic = cflt.icpg * cflt.group; megdnn_assert(ic % 4 == 0); grad[2] = ic / 4; - grad[1] = deduce(diff[1], cflt.dilated_spatial[0], cflt.stride[0], - cflt.padding[0]); - grad[3] = deduce(diff[3], cflt.dilated_spatial[1], cflt.stride[1], - cflt.padding[1]); + grad[1] = deduce( + diff[1], cflt.dilated_spatial[0], cflt.stride[0], cflt.padding[0]); + grad[3] = deduce( + diff[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); megdnn_assert(diff[4] == 4); grad[4] = 4; } @@ -1206,15 +1212,14 @@ void ConvolutionBackwardData::deduce_layout(const TensorLayout& filter, grad.init_contiguous_stride(); } -ConvolutionBackwardFilter::CanonizedFilterMeta -ConvolutionBackwardFilter::check_exec(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { - megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT && - diff.dtype.category() == DTypeCategory::FLOAT && - grad.dtype.category() == DTypeCategory::FLOAT, - "only float type is supported for conv backward filter"); +ConvolutionBackwardFilter::CanonizedFilterMeta ConvolutionBackwardFilter::check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { + megdnn_assert( + src.dtype.category() == DTypeCategory::FLOAT && + diff.dtype.category() == DTypeCategory::FLOAT && + grad.dtype.category() == DTypeCategory::FLOAT, + "only float type is supported for conv backward filter"); auto src_fwd = src; auto diff_fwd = diff; diff --git a/dnn/src/common/convolution3d.cpp b/dnn/src/common/convolution3d.cpp index 5f2be1f7..f1e8c073 100644 --- a/dnn/src/common/convolution3d.cpp +++ b/dnn/src/common/convolution3d.cpp @@ -15,18 +15,17 @@ using namespace megdnn; namespace { -std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, - const Convolution3D::Param& param) { +std::string get_errmsg( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + const Convolution3D::Param& param) { MEGDNN_MARK_USED_VAR(src); MEGDNN_MARK_USED_VAR(filter); MEGDNN_MARK_USED_VAR(dst); return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " + megdnn_layout_msg(dst) + ", " + "is_ncdhw=" + - std::to_string(param.format == param::Convolution3D::Format::NCDHW) + - ", " + "is_xcorr=" + - std::to_string( - (param.mode == Convolution3D::Mode::CROSS_CORRELATION)) + + std::to_string(param.format == param::Convolution3D::Format::NCDHW) + ", " + + "is_xcorr=" + + std::to_string((param.mode == Convolution3D::Mode::CROSS_CORRELATION)) + ", " + "pad_d=" + std::to_string(param.pad_d) + ", " + "pad_h=" + std::to_string(param.pad_h) + ", " + "pad_w=" + std::to_string(param.pad_w) + ", " + @@ 
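The deduce lambda used by ConvolutionBackwardData::deduce_layout above runs the same rule in reverse: given the diff (forward output) extent it recovers the smallest grad (forward input) extent consistent with it. A sketch under the same caveats (illustrative helper, assert instead of the real error message):

#include <cassert>
#include <cstddef>

// Inverse of the forward output-size rule: minimal input extent for a given output.
static size_t deconv_in_extent(size_t out, size_t dilated_filter, size_t stride, size_t pad) {
    size_t i = (out - 1) * stride + dilated_filter;
    assert(i > 2 * pad);
    return i - 2 * pad;
}

// Round trip: with filter 3, stride 2, padding 1, a 32-wide input gives a 16-wide
// output, and deconv_in_extent(16, 3, 2, 1) == 31 -- the smallest input mapping to
// 16, since stride 2 makes inputs of width 31 and 32 produce the same output width.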
-39,8 +38,7 @@ std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter, } } // namespace -Convolution3DBase::CanonizedFilterMeta -Convolution3DBase::make_canonized_filter_meta( +Convolution3DBase::CanonizedFilterMeta Convolution3DBase::make_canonized_filter_meta( size_t src_ndim, const TensorLayout& filter) const { megdnn_assert_contiguous(filter); auto img_ndim = src_ndim - 2; @@ -50,8 +48,7 @@ Convolution3DBase::make_canonized_filter_meta( if (param().mode == Mode::CONVOLUTION) { ret.should_flip = true; } else { - megdnn_assert(param().mode == Mode::CROSS_CORRELATION, - "invalid conv mode"); + megdnn_assert(param().mode == Mode::CROSS_CORRELATION, "invalid conv mode"); ret.should_flip = false; } size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos; @@ -60,19 +57,22 @@ Convolution3DBase::make_canonized_filter_meta( MEGDNN_MARK_USED_VAR(icpg_pos); if (param().sparse == Param::Sparse::DENSE) { - megdnn_assert(filter.ndim == img_ndim + 2, - "bad filter ndim for dense convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + filter.ndim == img_ndim + 2, + "bad filter ndim for dense convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); ret.group = 1; flt_start = 0; } else { - megdnn_assert(param().sparse == Param::Sparse::GROUP, - "invalid convolution sparse type"); - megdnn_assert(filter.ndim == img_ndim + 3, - "bad filter ndim for group convolution: " - "spatial_ndim=%zu filter_ndim=%zu", - img_ndim, filter.ndim); + megdnn_assert( + param().sparse == Param::Sparse::GROUP, + "invalid convolution sparse type"); + megdnn_assert( + filter.ndim == img_ndim + 3, + "bad filter ndim for group convolution: " + "spatial_ndim=%zu filter_ndim=%zu", + img_ndim, filter.ndim); ret.group = filter[0]; flt_start = 1; } @@ -83,8 +83,8 @@ Convolution3DBase::make_canonized_filter_meta( ocpg_pos = 0; icpg_pos = 1; } else { - megdnn_assert(param().format == Param::Format::NDHWC, - "invalid conv tensor format"); + megdnn_assert( + param().format == Param::Format::NDHWC, "invalid conv tensor format"); // filter should be (oc, fd, fh, fw, ic) flt_spatial_start = 1; ocpg_pos = 0; @@ -108,9 +108,9 @@ Convolution3DBase::make_canonized_filter_meta( ret.ocpg = filter[flt_start + ocpg_pos]; ret.icpg = filter[flt_start + icpg_pos]; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(ret.dilation[i] > 0, - "invalid dilation on spatial dim %zu: %u", i, - ret.dilation[i]); + megdnn_assert( + ret.dilation[i] > 0, "invalid dilation on spatial dim %zu: %u", i, + ret.dilation[i]); ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.dilated_spatial[i] = (ret.spatial[i] - 1) * ret.dilation[i] + 1; } @@ -118,27 +118,29 @@ Convolution3DBase::make_canonized_filter_meta( } Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd( - const TensorLayout& src, const TensorLayout& filter, - TensorLayout& dst) const { + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) const { auto errmsg = [&]() { return get_errmsg(src, filter, dst, param()); }; MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert(src.ndim >= 5_z, "%s", errmsg().c_str()); megdnn_assert(src.dtype == filter.dtype, "%s", errmsg().c_str()); if (param().data_type == Param::DataType::FLOAT) { - megdnn_assert(src.dtype == dtype::Float32() DNN_INC_FLOAT16( - || src.dtype == dtype::Float16()), - "invalid src dtype for conv: %s", src.dtype.name()); + megdnn_assert( + src.dtype == dtype::Float32() + DNN_INC_FLOAT16(|| src.dtype == 
dtype::Float16()), + "invalid src dtype for conv: %s", src.dtype.name()); dst.dtype = src.dtype; } else { megdnn_assert(param().data_type == Param::DataType::FLOAT_IO16xC32); - DNN_INC_FLOAT16(megdnn_assert(src.dtype == dtype::Float16(), - "invalid src dtype for conv: %s", src.dtype.name())); + DNN_INC_FLOAT16(megdnn_assert( + src.dtype == dtype::Float16(), "invalid src dtype for conv: %s", + src.dtype.name())); DNN_INC_FLOAT16(dst.dtype = dtype::Float16()); } auto img_dim = src.ndim - 2; megdnn_assert(img_dim == 3, "this is the convolution for 3D image"); - megdnn_assert(filter.ndim == img_dim + 2 || filter.ndim == img_dim + 3, - "%s", errmsg().c_str()); + megdnn_assert( + filter.ndim == img_dim + 2 || filter.ndim == img_dim + 3, "%s", + errmsg().c_str()); auto cflt = make_canonized_filter_meta(src.ndim, filter); size_t src_or_dst_c_pos = 0; size_t src_or_dst_spatial_start = 0; @@ -146,13 +148,12 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::deduce_layout_fwd( src_or_dst_c_pos = 1; src_or_dst_spatial_start = 2; } else { - megdnn_assert(param().format == Param::Format::NDHWC, - "invalid conv format"); + megdnn_assert(param().format == Param::Format::NDHWC, "invalid conv format"); src_or_dst_c_pos = 4; src_or_dst_spatial_start = 1; } - megdnn_assert(cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", - errmsg().c_str()); + megdnn_assert( + cflt.icpg * cflt.group == src[src_or_dst_c_pos], "%s", errmsg().c_str()); dst.ndim = src.ndim; dst[0] = src[0]; dst[src_or_dst_c_pos] = cflt.ocpg * cflt.group; @@ -176,15 +177,14 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DBase::check_layout_fwd( return ret; } -void Convolution3DForward::deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - TensorLayout& dst) { +void Convolution3DForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); } Convolution3DBase::CanonizedFilterMeta Convolution3DForward::check_exec( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_in_bytes) { + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_in_bytes) { auto src_fwd = src; auto dst_fwd = dst; src_fwd.init_contiguous_stride(); @@ -197,41 +197,39 @@ Convolution3DBase::CanonizedFilterMeta Convolution3DForward::check_exec( } Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardData::check_exec( - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes) { - megdnn_assert(param().data_type == Param::DataType::FLOAT, - "only float type is supported for conv backward"); + const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { + megdnn_assert( + param().data_type == Param::DataType::FLOAT, + "only float type is supported for conv backward"); auto diff_fwd = diff; auto grad_fwd = grad; diff_fwd.init_contiguous_stride(); grad_fwd.init_contiguous_stride(); auto ret = check_layout_fwd(grad_fwd, filter, diff_fwd); - auto required_workspace_in_bytes = - get_workspace_in_bytes(filter, diff, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); return ret; } -void Convolution3DBackwardData::deduce_layout(const TensorLayout& filter, - const TensorLayout& diff, - TensorLayout& grad) { - megdnn_assert(param().data_type == Param::DataType::FLOAT, - "only float type is 
supported for conv backward"); +void Convolution3DBackwardData::deduce_layout( + const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) { + megdnn_assert( + param().data_type == Param::DataType::FLOAT, + "only float type is supported for conv backward"); auto errmsg = [&]() { return get_errmsg(filter, diff, grad, param()); }; MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert_contiguous(filter); megdnn_assert_contiguous(diff); - megdnn_assert(filter.ndim == 5_z || filter.ndim == 6_z, "%s", - errmsg().c_str()); + megdnn_assert(filter.ndim == 5_z || filter.ndim == 6_z, "%s", errmsg().c_str()); megdnn_assert(diff.ndim == 5_z, "%s", errmsg().c_str()); megdnn_assert(filter.dtype == diff.dtype, "%s", errmsg().c_str()); auto cflt = make_canonized_filter_meta(diff.ndim, filter); megdnn_assert(cflt.ocpg * cflt.group == diff[1], "%s", errmsg().c_str()); - auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, - size_t pad) { + auto deduce = [&errmsg](size_t out, size_t filter, size_t stride, size_t pad) { MEGDNN_MARK_USED_VAR(errmsg); auto i = (out - 1) * stride + filter; megdnn_assert(i > pad * 2, "%s", errmsg().c_str()); @@ -243,17 +241,18 @@ void Convolution3DBackwardData::deduce_layout(const TensorLayout& filter, grad[1] = cflt.group * cflt.icpg; grad.dtype = diff.dtype; for (size_t i = 0; i < cflt.spatial_ndim; ++i) { - grad[i + 2] = deduce(diff[i + 2], cflt.dilated_spatial[i], - cflt.stride[i], cflt.padding[i]); + grad[i + 2] = deduce( + diff[i + 2], cflt.dilated_spatial[i], cflt.stride[i], cflt.padding[i]); } grad.init_contiguous_stride(); } Convolution3DBase::CanonizedFilterMeta Convolution3DBackwardFilter::check_exec( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_in_bytes) { - megdnn_assert(param().data_type == Param::DataType::FLOAT, - "only float type is supported for conv backward"); + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { + megdnn_assert( + param().data_type == Param::DataType::FLOAT, + "only float type is supported for conv backward"); auto src_fwd = src; auto diff_fwd = diff; src_fwd.init_contiguous_stride(); diff --git a/dnn/src/common/correlation.cpp b/dnn/src/common/correlation.cpp index 1d211b7c..9cb4223c 100644 --- a/dnn/src/common/correlation.cpp +++ b/dnn/src/common/correlation.cpp @@ -15,22 +15,22 @@ namespace megdnn { -void CorrelationBase::deduce_layout_fwd(const TensorLayout& data1, - const TensorLayout& data2, - TensorLayout& dst) { +void CorrelationBase::deduce_layout_fwd( + const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst) { megdnn_assert_contiguous(data1); megdnn_assert_contiguous(data2); megdnn_assert_contiguous(dst); auto errmsg = [&]() { - return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + - ", " + megdnn_layout_msg(dst); + return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + ", " + + megdnn_layout_msg(dst); }; MEGDNN_MARK_USED_VAR(errmsg); using Format = CorrelationBase::Param::Format; megdnn_assert(param().format == Format::NCHW); auto data1_dtype = data1.dtype, data2_dtype = data2.dtype; - megdnn_assert(data1_dtype == data2_dtype && - data1_dtype.category() == DTypeCategory::FLOAT); + megdnn_assert( + data1_dtype == data2_dtype && + data1_dtype.category() == DTypeCategory::FLOAT); megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str()); @@ -55,13 +55,11 @@ void CorrelationBase::deduce_layout_fwd(const 
TensorLayout& data1, uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width; megdnn_assert(top_width >= 1 && top_height >= 1); - dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, - data1.dtype}; + dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, data1.dtype}; } -void CorrelationBase::check_layout_fwd(const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& dst) { +void CorrelationBase::check_layout_fwd( + const TensorLayout& data1, const TensorLayout& data2, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(data1, dst); megdnn_assert_eq_shape(data1, data2); @@ -69,27 +67,22 @@ void CorrelationBase::check_layout_fwd(const TensorLayout& data1, megdnn_assert_eq_shape(dst_expected, dst); } -void CorrelationForward::deduce_layout(const TensorLayout& data1, - const TensorLayout& data2, - TensorLayout& dst) { +void CorrelationForward::deduce_layout( + const TensorLayout& data1, const TensorLayout& data2, TensorLayout& dst) { deduce_layout_fwd(data1, data2, dst); } -void CorrelationForward::check_exec(const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& dst, - size_t workspace_in_bytes) { +void CorrelationForward::check_exec( + const TensorLayout& data1, const TensorLayout& data2, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(data1, data2, dst); - auto required_workspace_in_bytes = - get_workspace_in_bytes(data1, data2, dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(data1, data2, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void CorrelationBackwardData1::check_exec(const TensorLayout& diff, - const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& grad1, - size_t workspace_in_bytes) { +void CorrelationBackwardData1::check_exec( + const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& grad1, size_t workspace_in_bytes) { check_layout_fwd(grad1, data2, diff); megdnn_assert_eq_shape(data1, data2); auto required_workspace_in_bytes = @@ -97,11 +90,9 @@ void CorrelationBackwardData1::check_exec(const TensorLayout& diff, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void CorrelationBackwardData2::check_exec(const TensorLayout& diff, - const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& grad2, - size_t workspace_in_bytes) { +void CorrelationBackwardData2::check_exec( + const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& grad2, size_t workspace_in_bytes) { check_layout_fwd(data1, grad2, diff); megdnn_assert_eq_shape(data1, data2); auto required_workspace_in_bytes = @@ -109,19 +100,17 @@ void CorrelationBackwardData2::check_exec(const TensorLayout& diff, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void CorrelationBackwardData2::deduce_layout(const TensorLayout& diff, - const TensorLayout& data1, - const TensorLayout& data2, - TensorLayout& grad) { +void CorrelationBackwardData2::deduce_layout( + const TensorLayout& diff, const TensorLayout& data1, const TensorLayout& data2, + TensorLayout& grad) { megdnn_assert_eq_shape(data1, data2); check_layout_fwd(data1, data2, diff); grad = data2; } -void CorrelationBackwardData1::deduce_layout(const TensorLayout& diff, - const TensorLayout& data1, - const TensorLayout& data2, - TensorLayout& grad) { +void CorrelationBackwardData1::deduce_layout( + const TensorLayout& diff, const 
TensorLayout& data1, const TensorLayout& data2, + TensorLayout& grad) { megdnn_assert_eq_shape(data1, data2); check_layout_fwd(data1, data2, diff); grad = data1; diff --git a/dnn/src/common/cumsum.cpp b/dnn/src/common/cumsum.cpp index bb29b7f9..44e9dc95 100644 --- a/dnn/src/common/cumsum.cpp +++ b/dnn/src/common/cumsum.cpp @@ -14,16 +14,13 @@ namespace megdnn { -void CumsumForward::deduce_layout(const TensorLayout &src, TensorLayout &dst) -{ +void CumsumForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { megdnn_assert_contiguous(src); dst = src; } -void CumsumForward::check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void CumsumForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { megdnn_assert_contiguous(src); megdnn_assert_eq_layout(src, dst); megdnn_assert(param().axis >= 0); @@ -32,6 +29,6 @@ void CumsumForward::check_exec(const TensorLayout &src, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/cv/aligned_allocator.h b/dnn/src/common/cv/aligned_allocator.h index d68462e3..b24f4aa1 100644 --- a/dnn/src/common/cv/aligned_allocator.h +++ b/dnn/src/common/cv/aligned_allocator.h @@ -22,7 +22,6 @@ #include "malloc.h" #endif - #if defined(__ANDROID__) || defined(ANDROID) #include "malloc.h" #define HAS_MEMALIGN @@ -98,14 +97,14 @@ public: }; template -inline bool operator==(const aligned_allocator<_T1, _A1>&, - const aligned_allocator<_T2, _A2>&) { +inline bool operator==( + const aligned_allocator<_T1, _A1>&, const aligned_allocator<_T2, _A2>&) { return true; } template -inline bool operator!=(const aligned_allocator<_T1, _A1>&, - const aligned_allocator<_T2, _A2>&) { +inline bool operator!=( + const aligned_allocator<_T1, _A1>&, const aligned_allocator<_T2, _A2>&) { return false; } diff --git a/dnn/src/common/cv/bordermode-inl.h b/dnn/src/common/cv/bordermode-inl.h index 326a0c91..2f83c479 100644 --- a/dnn/src/common/cv/bordermode-inl.h +++ b/dnn/src/common/cv/bordermode-inl.h @@ -64,8 +64,9 @@ static inline int border_interpolate(int p, int len, BorderMode bmode) { ; else if (bmode == BorderMode::BORDER_REPLICATE) p = p < 0 ? 
0 : len - 1; - else if (bmode == BorderMode::BORDER_REFLECT || - bmode == BorderMode::BORDER_REFLECT_101) { + else if ( + bmode == BorderMode::BORDER_REFLECT || + bmode == BorderMode::BORDER_REFLECT_101) { int delta = (bmode == BorderMode::BORDER_REFLECT_101); if (len == 1) return 0; @@ -82,8 +83,9 @@ static inline int border_interpolate(int p, int len, BorderMode bmode) { while (p >= len) { p -= len; } - } else if (bmode == BorderMode::BORDER_CONSTANT || - bmode == BorderMode::BORDER_TRANSPARENT) + } else if ( + bmode == BorderMode::BORDER_CONSTANT || + bmode == BorderMode::BORDER_TRANSPARENT) p = -1; else MegCVException("Unknown/unsupported border type"); diff --git a/dnn/src/common/cv/common.h b/dnn/src/common/cv/common.h index 43461354..6f928a58 100644 --- a/dnn/src/common/cv/common.h +++ b/dnn/src/common/cv/common.h @@ -129,9 +129,7 @@ private: public: void* raw_ptr() { return static_cast(m_data.get() + m_offset); } - const void* raw_ptr() const { - return static_cast(m_data.get() + m_offset); - } + const void* raw_ptr() const { return static_cast(m_data.get() + m_offset); } Mat(); Mat(size_t rows, size_t cols, size_t channels, size_t step); @@ -141,8 +139,8 @@ public: Mat(size_t rows, size_t cols, size_t channels, size_t step, T* data); // shallow-copy constructor Mat(const Mat& rhs); - Mat(const Mat& rhs, size_t row_offset, size_t row_count, - size_t col_offset, size_t col_count); + Mat(const Mat& rhs, size_t row_offset, size_t row_count, size_t col_offset, + size_t col_count); Mat& operator=(const Mat& rhs); T& at(size_t r, size_t c, size_t ch); diff --git a/dnn/src/common/cv/cvt_color.h b/dnn/src/common/cv/cvt_color.h index 377478bb..97a78eda 100644 --- a/dnn/src/common/cv/cvt_color.h +++ b/dnn/src/common/cv/cvt_color.h @@ -40,30 +40,21 @@ #define descale(x, n) (((x) + (1 << ((n)-1))) >> (n)) -#define GENERATE_UNSUPPORT_CVT_OPR_FOR_FLOAT(_cb) \ - _cb(cvt_rgba2rgb, float) \ - _cb(cvt_rgba2bgr, float) \ - _cb(cvt_rgba2gray, float) \ - _cb(cvt_rgb2bgr, float) \ - _cb(cvt_bgr2rgb, float) \ - _cb(cvt_yuv2gray_nv21, float) \ - _cb(cvt_yuv2rgb_nv21, float) \ - _cb(cvt_yuv2bgr_nv21, float) \ - _cb(cvt_yuv2gray_nv12, float) \ - _cb(cvt_yuv2rgb_nv12, float) \ - _cb(cvt_yuv2bgr_nv12, float) \ - _cb(cvt_yuv2gray_yv12, float) \ - _cb(cvt_yuv2rgb_yv12, float) \ - _cb(cvt_yuv2bgr_yv12, float) \ - _cb(cvt_yuv2gray_yu12, float) \ - _cb(cvt_yuv2rgb_yu12, float) \ - _cb(cvt_yuv2bgr_yu12, float) +#define GENERATE_UNSUPPORT_CVT_OPR_FOR_FLOAT(_cb) \ + _cb(cvt_rgba2rgb, float) _cb(cvt_rgba2bgr, float) _cb(cvt_rgba2gray, float) _cb( \ + cvt_rgb2bgr, float) _cb(cvt_bgr2rgb, float) _cb(cvt_yuv2gray_nv21, float) \ + _cb(cvt_yuv2rgb_nv21, float) _cb(cvt_yuv2bgr_nv21, float) \ + _cb(cvt_yuv2gray_nv12, float) _cb(cvt_yuv2rgb_nv12, float) _cb( \ + cvt_yuv2bgr_nv12, float) _cb(cvt_yuv2gray_yv12, float) \ + _cb(cvt_yuv2rgb_yv12, float) _cb(cvt_yuv2bgr_yv12, float) \ + _cb(cvt_yuv2gray_yu12, float) \ + _cb(cvt_yuv2rgb_yu12, float) \ + _cb(cvt_yuv2bgr_yu12, float) -#define GENERATE_UNSUPPORT_CVT_OPR(_opr, _type) \ - template <> \ - void _opr<_type>(const megcv::Mat<_type>&, megcv::Mat<_type>&) { \ - MegCVException("There is not a cvt_opr " #_opr \ - " to deal with " #_type); \ +#define GENERATE_UNSUPPORT_CVT_OPR(_opr, _type) \ + template <> \ + void _opr<_type>(const megcv::Mat<_type>&, megcv::Mat<_type>&) { \ + MegCVException("There is not a cvt_opr " #_opr " to deal with " #_type); \ } // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/cv/filter.cpp b/dnn/src/common/cv/filter.cpp index 81e0d113..c40a0ca3 
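border_interpolate in bordermode-inl.h (and its twin in helper.h further down) maps an out-of-range coordinate back into [0, len) according to the border mode: REPLICATE clamps, REFLECT and REFLECT_101 mirror with or without repeating the edge pixel, and CONSTANT/TRANSPARENT return the sentinel -1 so the caller substitutes the border value. A single-bounce restatement of the mirror rules with a worked example (the loop in the real code handles coordinates that bounce more than once):

// delta = 0 for BORDER_REFLECT, 1 for BORDER_REFLECT_101.
static inline int reflect_once(int p, int len, int delta) {
    if (p < 0)
        return -p - 1 + delta;
    if (p >= len)
        return len - 1 - (p - len) - delta;
    return p;
}

// For len = 5, p = -2: reflect_once(-2, 5, 0) == 1  (REFLECT:     ... 1 0 | 0 1 2 ...)
//                      reflect_once(-2, 5, 1) == 2  (REFLECT_101: ... 2 1 | 0 1 2 ...)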
100644 --- a/dnn/src/common/cv/filter.cpp +++ b/dnn/src/common/cv/filter.cpp @@ -68,9 +68,9 @@ namespace filter_common { #define VEC_ALIGN 16 template -FilterEngine::FilterEngine(BaseRowFilter* row_filter, - BaseColumnFilter* column_filter, size_t ch, - const ST* border_value, BorderMode bmode) +FilterEngine::FilterEngine( + BaseRowFilter* row_filter, BaseColumnFilter* column_filter, size_t ch, + const ST* border_value, BorderMode bmode) : m_row_filter(row_filter), m_column_filter(column_filter), m_ch(ch), @@ -85,8 +85,9 @@ FilterEngine::FilterEngine(BaseRowFilter* row_filter, m_buf_step = 0; //! the anchor must be in the kernerl - megdnn_assert(0 <= m_anchor.x && m_anchor.x < m_ksize.cols() && - 0 <= m_anchor.y && m_anchor.y < m_ksize.rows()); + megdnn_assert( + 0 <= m_anchor.x && m_anchor.x < m_ksize.cols() && 0 <= m_anchor.y && + m_anchor.y < m_ksize.rows()); int src_elem_size = (int)sizeof(ST) * m_ch; m_border_elem_size = src_elem_size / ((sizeof(ST) >= 4) ? sizeof(int) : 1); @@ -138,10 +139,8 @@ void FilterEngine::start(const Mat& src) { (*m_row_filter)(&m_src_row[0], dst, m_whole_size.width(), cn); } - m_buf_step = buf_elem_size * - (int)align_size(m_whole_size.width() + m_ksize.width() - 1, - VEC_ALIGN); + (int)align_size(m_whole_size.width() + m_ksize.width() - 1, VEC_ALIGN); m_ring_buf.resize(m_buf_step * m_ksize.height() + VEC_ALIGN); m_left_width = m_anchor.x; m_right_width = m_ksize.width() - m_anchor.x - 1; @@ -160,20 +159,21 @@ void FilterEngine::start(const Mat& src) { //! calc the index of the border value, we will not calc it when //! process border each time for (int i = 0; i < m_left_width; i++) { - int p0 = gaussian_blur::border_interpolate(i - m_left_width, - m_whole_size.width(), m_bmode) * + int p0 = gaussian_blur::border_interpolate( + i - m_left_width, m_whole_size.width(), m_bmode) * m_border_elem_size; for (int j = 0; j < m_border_elem_size; j++) m_border_table[i * m_border_elem_size + j] = p0 + j; } for (int i = 0; i < m_right_width; i++) { - int p0 = gaussian_blur::border_interpolate(m_whole_size.width() + i, - m_whole_size.width(), m_bmode) * + int p0 = gaussian_blur::border_interpolate( + m_whole_size.width() + i, m_whole_size.width(), + m_bmode) * m_border_elem_size; for (int j = 0; j < m_border_elem_size; j++) - m_border_table[(i + m_left_width) * m_border_elem_size + - j] = p0 + j; + m_border_table[(i + m_left_width) * m_border_elem_size + j] = + p0 + j; } } } @@ -183,8 +183,8 @@ void FilterEngine::start(const Mat& src) { } template -int FilterEngine::proceed(const uchar* src, int srcstep, int count, - uchar* dst, int dststep) { +int FilterEngine::proceed( + const uchar* src, int srcstep, int count, uchar* dst, int dststep) { const int* btab = &m_border_table[0]; int src_elem_size = static_cast(sizeof(ST) * m_ch); bool makeBorder = (m_left_width > 0 || m_right_width > 0) && @@ -201,8 +201,7 @@ int FilterEngine::proceed(const uchar* src, int srcstep, int count, count -= dcount; for (; dcount-- > 0; src += srcstep) { int bi = (start_y + row_count) % m_ksize.height(); - uchar* brow = - align_ptr(&m_ring_buf[0], VEC_ALIGN) + bi * m_buf_step; + uchar* brow = align_ptr(&m_ring_buf[0], VEC_ALIGN) + bi * m_buf_step; uchar* row = &m_src_row[0]; if (++row_count > static_cast(m_ksize.height())) { @@ -221,19 +220,16 @@ int FilterEngine::proceed(const uchar* src, int srcstep, int count, for (int i = 0; i < m_left_width * m_border_elem_size; i++) irow[i] = isrc[btab[i]]; - for (int i = 0; i < m_right_width * m_border_elem_size; - i++) { + for (int i = 0; i < 
m_right_width * m_border_elem_size; i++) { irow[i + (m_whole_size.width() + m_left_width) * m_border_elem_size] = - isrc[btab[i + - m_left_width * m_border_elem_size]]; + isrc[btab[i + m_left_width * m_border_elem_size]]; } } else { for (int i = 0; i < m_left_width * src_elem_size; i++) row[i] = src[btab[i]]; for (int i = 0; i < m_right_width * src_elem_size; i++) - row[i + (m_whole_size.width() + m_left_width) * - src_elem_size] = + row[i + (m_whole_size.width() + m_left_width) * src_elem_size] = src[btab[i + m_left_width * src_elem_size]]; } } @@ -242,11 +238,10 @@ int FilterEngine::proceed(const uchar* src, int srcstep, int count, } int max_i = std::min( - m_ksize.height(), - m_whole_size.height() - dy + (m_ksize.height() - 1)); + m_ksize.height(), m_whole_size.height() - dy + (m_ksize.height() - 1)); for (i = 0; i < max_i; i++) { - int src_y = gaussian_blur::border_interpolate(dy + i - m_anchor.y, - m_whole_size.rows(), m_bmode); + int src_y = gaussian_blur::border_interpolate( + dy + i - m_anchor.y, m_whole_size.rows(), m_bmode); if (src_y < 0) buf_rows[i] = align_ptr(&m_const_border_row[0], VEC_ALIGN); else { @@ -255,16 +250,16 @@ int FilterEngine::proceed(const uchar* src, int srcstep, int count, break; } int bi = src_y % m_ksize.height(); - buf_rows[i] = - align_ptr(&m_ring_buf[0], VEC_ALIGN) + bi * m_buf_step; + buf_rows[i] = align_ptr(&m_ring_buf[0], VEC_ALIGN) + bi * m_buf_step; } } if (i < static_cast(m_ksize.height())) { break; } i -= m_ksize.height() - 1; - (*m_column_filter)(const_cast(&buf_rows[0]), dst, - dststep, i, m_whole_size.width() * m_ch); + (*m_column_filter)( + const_cast(&buf_rows[0]), dst, dststep, i, + m_whole_size.width() * m_ch); } return dy; @@ -275,9 +270,9 @@ void FilterEngine::apply(const Mat& src, Mat& dst) { int src_step = src.step() * sizeof(ST); int dst_step = dst.step() * sizeof(ST); start(src); - proceed(reinterpret_cast(src.ptr()), - static_cast(src_step), m_whole_size.height(), - reinterpret_cast(dst.ptr()), static_cast(dst_step)); + proceed(reinterpret_cast(src.ptr()), static_cast(src_step), + m_whole_size.height(), reinterpret_cast(dst.ptr()), + static_cast(dst_step)); } //! explicit instantiation template @@ -288,10 +283,8 @@ template FilterEngine::FilterEngine( BaseRowFilter* _rowFilter, BaseColumnFilter* _columnFilter, size_t _CH, const float* _borderValue, BorderMode _BorderType); -template void FilterEngine::apply(const Mat& src, - Mat& dst); -template void FilterEngine::apply(const Mat& src, - Mat& dst); +template void FilterEngine::apply(const Mat& src, Mat& dst); +template void FilterEngine::apply(const Mat& src, Mat& dst); template FilterEngine::~FilterEngine(); template FilterEngine::~FilterEngine(); diff --git a/dnn/src/common/cv/filter.h b/dnn/src/common/cv/filter.h index 6e0e618d..99d62cdb 100644 --- a/dnn/src/common/cv/filter.h +++ b/dnn/src/common/cv/filter.h @@ -66,7 +66,7 @@ #include -namespace megdnn { +namespace megdnn { namespace megcv { namespace filter_common { @@ -92,8 +92,8 @@ struct RowNoVec { * \param width The width of the src * \param cn The channel size */ - int operator()(const uchar* /*src*/, uchar* /*dst*/, int /*width*/, - int /*cn*/) const { + int operator()( + const uchar* /*src*/, uchar* /*dst*/, int /*width*/, int /*cn*/) const { return 0; } }; @@ -117,8 +117,8 @@ struct ColumnNoVec { * \param count The count of rows that this column kernel processed. 
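The RowNoVec/ColumnNoVec hooks here and the row/column filter classes that follow implement a separable 2-D filter: FilterEngine runs a 1-D pass along each row into a ring buffer, then a 1-D pass down the columns of ksize buffered rows. A naive single-channel sketch of that decomposition, valid region only and with none of the border handling, channel support, or vectorization the real engine provides:

#include <cstddef>
#include <vector>

// out = column_kernel applied to (row_kernel applied to img), both as plain 1-D passes.
static std::vector<float> separable_filter_valid(
        const std::vector<float>& img, size_t w, size_t h,
        const std::vector<float>& kr, const std::vector<float>& kc) {
    const size_t ow = w - kr.size() + 1, oh = h - kc.size() + 1;
    std::vector<float> tmp(ow * h, 0.f), out(ow * oh, 0.f);
    for (size_t y = 0; y < h; ++y)       // row pass: horizontal 1-D filter
        for (size_t x = 0; x < ow; ++x)
            for (size_t k = 0; k < kr.size(); ++k)
                tmp[y * ow + x] += kr[k] * img[y * w + x + k];
    for (size_t y = 0; y < oh; ++y)      // column pass: vertical 1-D filter
        for (size_t x = 0; x < ow; ++x)
            for (size_t k = 0; k < kc.size(); ++k)
                out[y * ow + x] += kc[k] * tmp[(y + k) * ow + x];
    return out;
}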
* \param width The width of the src */ - int operator()(const uchar** /*src*/, uchar* /*dst*/, int& /*count*/, - int /*width*/) const { + int operator()(const uchar** /*src*/, uchar* /*dst*/, int& /*count*/, int /*width*/) + const { return 0; } }; @@ -148,8 +148,7 @@ public: //! the filtering operator. Must be overridden in the derived classes. The //! horizontal border interpolation is done outside of the class. - virtual void operator()(const uchar* src, uchar* dst, int width, - int cn) = 0; + virtual void operator()(const uchar* src, uchar* dst, int width, int cn) = 0; //! The size of the kernel int ksize; @@ -164,8 +163,8 @@ public: //! the filtering operator. Must be overridden in the derived classes. The //! vertical border interpolation is done outside of the class. - virtual void operator()(const uchar** src, uchar* dst, int dststep, - int dstcount, int width) = 0; + virtual void operator()( + const uchar** src, uchar* dst, int dststep, int dstcount, int width) = 0; //! resets the internal buffers, if any virtual void reset() {} @@ -184,8 +183,7 @@ public: */ template struct RowFilter : public BaseRowFilter { - RowFilter(const Mat
& kernel_, int anchor_, - const VecOp& vec_op_ = VecOp()) { + RowFilter(const Mat
& kernel_, int anchor_, const VecOp& vec_op_ = VecOp()) { anchor = anchor_; kernel = kernel_.clone(); ksize = kernel.cols(); @@ -240,8 +238,8 @@ struct RowFilter : public BaseRowFilter { template struct SymmRowSmallFilter : public RowFilter { - SymmRowSmallFilter(const Mat
& kernel_, int anchor_, - const VecOp& vec_op_ = VecOp()) + SymmRowSmallFilter( + const Mat
& kernel_, int anchor_, const VecOp& vec_op_ = VecOp()) : RowFilter(kernel_, anchor_, vec_op_) {} void operator()(const uchar* src, uchar* dst, int width, int cn) { @@ -287,7 +285,6 @@ struct SymmRowSmallFilter : public RowFilter { s0 += kx[k] * (S[j] + S[-j]); D[i] = s0; } - } }; @@ -296,9 +293,9 @@ struct ColumnFilter : public BaseColumnFilter { typedef typename CastOp::type1 ST; typedef typename CastOp::rtype DT; - ColumnFilter(const Mat& kernel_, int anchor_, - const CastOp& cast_op_ = CastOp(), - const VecOp& vec_op_ = VecOp()) { + ColumnFilter( + const Mat& kernel_, int anchor_, const CastOp& cast_op_ = CastOp(), + const VecOp& vec_op_ = VecOp()) { kernel = kernel_.clone(); anchor = anchor_; ksize = kernel.cols(); @@ -306,43 +303,39 @@ struct ColumnFilter : public BaseColumnFilter { vec_op = vec_op_; } - void operator()(const uchar** src, uchar* dst, int dststep, int count, int width) - { + void operator()(const uchar** src, uchar* dst, int dststep, int count, int width) { const ST* ky = this->kernel.ptr(); int i = 0, k; CastOp castOp = this->cast_op; { - for( ; count > 0; count--, dst += dststep, src++ ) - { + for (; count > 0; count--, dst += dststep, src++) { DT* D = (DT*)dst; i = (this->vec_op)(src, dst, count, width); #if MEGCV_ENABLE_UNROLLED - for( ; i <= width - 4; i += 4 ) - { + for (; i <= width - 4; i += 4) { ST f = ky[0]; const ST* S = (const ST*)src[0] + i; - ST s0 = f*S[0], s1 = f*S[1], - s2 = f*S[2], s3 = f*S[3]; + ST s0 = f * S[0], s1 = f * S[1], s2 = f * S[2], s3 = f * S[3]; - for( k = 1; k < ksize; k++ ) - { + for (k = 1; k < ksize; k++) { S = (const ST*)src[k] + i; f = ky[k]; - s0 += f*S[0]; - s1 += f*S[1]; - s2 += f*S[2]; - s3 += f*S[3]; + s0 += f * S[0]; + s1 += f * S[1]; + s2 += f * S[2]; + s3 += f * S[3]; } - D[i] = castOp(s0); D[i+1] = castOp(s1); - D[i+2] = castOp(s2); D[i+3] = castOp(s3); + D[i] = castOp(s0); + D[i + 1] = castOp(s1); + D[i + 2] = castOp(s2); + D[i + 3] = castOp(s3); } #endif - for( ; i < width; i++ ) - { + for (; i < width; i++) { ST s0 = 0; - for( k = 0; k < ksize; k++ ) { - s0 += ky[k]* ((const ST*)src[k])[i]; + for (k = 0; k < ksize; k++) { + s0 += ky[k] * ((const ST*)src[k])[i]; } D[i] = castOp(s0); } @@ -360,15 +353,12 @@ struct SymmColumnFilter : public ColumnFilter { typedef typename CastOp::type1 ST; typedef typename CastOp::rtype DT; - SymmColumnFilter(const Mat& kernel_, int anchor_, - const CastOp& cast_op_ = CastOp(), - const VecOp& vec_op_ = VecOp()) - : ColumnFilter(kernel_, anchor_, cast_op_, - vec_op_) { - } + SymmColumnFilter( + const Mat& kernel_, int anchor_, const CastOp& cast_op_ = CastOp(), + const VecOp& vec_op_ = VecOp()) + : ColumnFilter(kernel_, anchor_, cast_op_, vec_op_) {} - void operator()(const uchar** src, uchar* dst, int dststep, int count, - int width) { + void operator()(const uchar** src, uchar* dst, int dststep, int count, int width) { int ksize2 = this->ksize / 2; const ST* ky = this->kernel.ptr() + ksize2; int i, k; @@ -402,8 +392,7 @@ struct SymmColumnFilter : public ColumnFilter { for (; i < width; i++) { ST s0 = ky[0] * ((const ST*)src[0])[i]; for (k = 1; k <= ksize2; k++) { - s0 += ky[k] * - (((const ST*)src[k])[i] + ((const ST*)src[-k])[i]); + s0 += ky[k] * (((const ST*)src[k])[i] + ((const ST*)src[-k])[i]); } D[i] = this->cast_op(s0); } @@ -416,17 +405,15 @@ struct SymmColumnSmallFilter : public SymmColumnFilter { typedef typename CastOp::type1 ST; typedef typename CastOp::rtype DT; - SymmColumnSmallFilter(const Mat& kernel_, int anchor_, - const CastOp& cast_op_ = CastOp(), - const VecOp& vec_op_ 
= VecOp()) - : SymmColumnFilter(kernel_, anchor_, cast_op_, - vec_op_) { + SymmColumnSmallFilter( + const Mat& kernel_, int anchor_, const CastOp& cast_op_ = CastOp(), + const VecOp& vec_op_ = VecOp()) + : SymmColumnFilter(kernel_, anchor_, cast_op_, vec_op_) { //! \warning Only process if the kernel size is 3 megdnn_assert(this->ksize == 3); } - void operator()(const uchar** src, uchar* dst, int dststep, int count, - int width) { + void operator()(const uchar** src, uchar* dst, int dststep, int count, int width) { int ksize2 = this->ksize / 2; const ST* ky = this->kernel.ptr() + ksize2; int i; @@ -488,8 +475,9 @@ public: * \brief Init the filter and border. * \warning row_filter and column_filter must be non-null */ - FilterEngine(BaseRowFilter* row_filter, BaseColumnFilter* column_filter, - size_t ch, const ST* border_value, BorderMode bmode); + FilterEngine( + BaseRowFilter* row_filter, BaseColumnFilter* column_filter, size_t ch, + const ST* border_value, BorderMode bmode); //! the destructor ~FilterEngine(); @@ -500,8 +488,7 @@ private: //! starts filtering of the src image. void start(const Mat& src); //! processes the next srcCount rows of the image. - int proceed(const uchar* src, int srcStep, int srcCount, uchar* dst, - int dstStep); + int proceed(const uchar* src, int srcStep, int srcCount, uchar* dst, int dstStep); //! row filter filter BaseRowFilter* m_row_filter; diff --git a/dnn/src/common/cv/helper.h b/dnn/src/common/cv/helper.h index db689bb7..66fde120 100644 --- a/dnn/src/common/cv/helper.h +++ b/dnn/src/common/cv/helper.h @@ -133,8 +133,7 @@ static inline DT saturate_cast(ST x) { template <> inline unsigned char saturate_cast(int x) { - return (unsigned char)((unsigned)x <= UCHAR_MAX ? x - : x > 0 ? UCHAR_MAX : 0); + return (unsigned char)((unsigned)x <= UCHAR_MAX ? x : x > 0 ? UCHAR_MAX : 0); } template <> @@ -198,9 +197,7 @@ struct FixedPtCast { typedef DT rtype; enum { SHIFT = bits, DELTA = 1 << (bits - 1) }; - DT operator()(ST val) const { - return saturate_cast
<DT>((val + DELTA) >> SHIFT); - } + DT operator()(ST val) const { return saturate_cast<DT>
((val + DELTA) >> SHIFT); } }; template @@ -242,8 +239,9 @@ static inline int border_interpolate(int p, int len) { ; else if (bmode == BorderMode::BORDER_REPLICATE) p = p < 0 ? 0 : len - 1; - else if (bmode == BorderMode::BORDER_REFLECT || - bmode == BorderMode::BORDER_REFLECT_101) { + else if ( + bmode == BorderMode::BORDER_REFLECT || + bmode == BorderMode::BORDER_REFLECT_101) { int delta = (bmode == BorderMode::BORDER_REFLECT_101); if (len == 1) return 0; @@ -259,8 +257,9 @@ static inline int border_interpolate(int p, int len) { while (p >= len) { p -= len; } - } else if (bmode == BorderMode::BORDER_CONSTANT || - bmode == BorderMode::BORDER_TRANSPARENT) + } else if ( + bmode == BorderMode::BORDER_CONSTANT || + bmode == BorderMode::BORDER_TRANSPARENT) p = -1; else megdnn_throw("Unknown/unsupported border type"); diff --git a/dnn/src/common/cv/interp_helper.cpp b/dnn/src/common/cv/interp_helper.cpp index 310d74d9..d1a53bca 100644 --- a/dnn/src/common/cv/interp_helper.cpp +++ b/dnn/src/common/cv/interp_helper.cpp @@ -76,8 +76,7 @@ static constexpr double MEGCV_PI_4 = 0.78539816339744830962; /* pi/4 */ typename InterpolationTable< \ INTER_BITS_, INTER_MAX_, \ INTER_REMAP_COEF_BITS_>::template TableHolder<_ksize> \ - InterpolationTable::_name + InterpolationTable::_name DEF_TABLE_HOLDER(sm_tab_linear, 2); DEF_TABLE_HOLDER(sm_tab_cubic, 4); @@ -147,8 +146,9 @@ DEF_FUN(const int16_t*) get_linear_ic4_table() { short* itab = nullptr; MEGDNN_MARK_USED_VAR(tab); MEGDNN_MARK_USED_VAR(itab); - megdnn_assert(table_holder->get(&tab, &itab), - "invoke get_table before get_linear_ic4_table"); + megdnn_assert( + table_holder->get(&tab, &itab), + "invoke get_table before get_linear_ic4_table"); return table_holder->table->bilineartab_ic4_buf; } #endif @@ -189,8 +189,8 @@ DEF_FUN(const void*) get_table(InterpolationMode imode, bool fixpt) { for (k2 = 0; k2 < ksize; ++k2) { float v = vy * _tab[j * ksize + k2]; tab[k1 * ksize + k2] = v; - isum += itab[k1 * ksize + k2] = saturate_cast( - v * INTER_REMAP_COEF_SCALE); + isum += itab[k1 * ksize + k2] = + saturate_cast(v * INTER_REMAP_COEF_SCALE); } } if (isum != INTER_REMAP_COEF_SCALE) { @@ -199,12 +199,11 @@ DEF_FUN(const void*) get_table(InterpolationMode imode, bool fixpt) { int mk1 = ksize2, mk2 = ksize2; for (k1 = ksize2; k1 < ksize2 + 2; ++k1) for (k2 = ksize2; k2 < ksize2 + 2; ++k2) { - if (itab[k1 * ksize + k2] < - itab[mk1 * ksize + mk2]) { + if (itab[k1 * ksize + k2] < itab[mk1 * ksize + mk2]) { mk1 = k1; mk2 = k2; - } else if (itab[k1 * ksize + k2] > - itab[Mk1 * ksize + Mk2]) { + } else if ( + itab[k1 * ksize + k2] > itab[Mk1 * ksize + Mk2]) { Mk1 = k1; Mk2 = k2; } @@ -223,8 +222,7 @@ DEF_FUN(const void*) get_table(InterpolationMode imode, bool fixpt) { #if MEGDNN_X86 if (imode == IMode::INTER_LINEAR) { - int16_t* bilineartab_ic4_buf = - sm_tab_linear.table->bilineartab_ic4_buf; + int16_t* bilineartab_ic4_buf = sm_tab_linear.table->bilineartab_ic4_buf; for (i = 0; i < INTER_TAB_SIZE2; i++) for (j = 0; j < 4; j++) { bilineartab_ic4_buf[i * 2 * 8 + 0 * 8 + j * 2] = diff --git a/dnn/src/common/cv/interp_helper.h b/dnn/src/common/cv/interp_helper.h index 58cb2743..872b24f6 100644 --- a/dnn/src/common/cv/interp_helper.h +++ b/dnn/src/common/cv/interp_helper.h @@ -81,8 +81,7 @@ using BorderMode = megdnn::param::WarpPerspective::BorderMode; * \brief helper for generating interpolation tables for different interpolation * modes */ -template +template class InterpolationTable { public: using IMode = InterpolationMode; @@ -125,8 +124,7 @@ private: alignas(128) 
int16_t bilineartab_ic4_buf[INTER_TAB_SIZE2 * 2 * 8]; static void* operator new(std::size_t sz) { - return ah::aligned_allocator().allocate(sz / - sizeof(Table)); + return ah::aligned_allocator().allocate(sz / sizeof(Table)); } void operator delete(void* ptr) noexcept { ah::aligned_allocator().deallocate( @@ -161,8 +159,7 @@ private: } }; - static void init_inter_tab_1d(InterpolationMode imode, float* tab, - int tabsz); + static void init_inter_tab_1d(InterpolationMode imode, float* tab, int tabsz); static inline void interpolate_linear(float x, float* coeffs); static inline void interpolate_cubic(float x, float* coeffs); diff --git a/dnn/src/common/cv/linalg.h b/dnn/src/common/cv/linalg.h index b7d63fce..340f65a9 100644 --- a/dnn/src/common/cv/linalg.h +++ b/dnn/src/common/cv/linalg.h @@ -179,8 +179,7 @@ void inverse_mat(value_type* A, value_type* B, uint32_t n) { /// C = A * B /// A, B must point to memory space different from C template -void mat_mult(const value_type* A, const value_type* B, value_type* C, - uint32_t n) { +void mat_mult(const value_type* A, const value_type* B, value_type* C, uint32_t n) { #define AT(A, i, j) A[(i)*n + (j)] memset(C, 0, n * n * sizeof(value_type)); for (uint32_t k = 0; k < n; k++) { @@ -192,8 +191,7 @@ void mat_mult(const value_type* A, const value_type* B, value_type* C, } template -void transpose_mat(const value_type* A, value_type* B, uint32_t rows, - uint32_t cols) { +void transpose_mat(const value_type* A, value_type* B, uint32_t rows, uint32_t cols) { for (uint32_t i = 0; i < rows; i++) for (uint32_t j = 0; j < cols; j++) B[j * rows + i] = A[i * cols + j]; @@ -203,9 +201,9 @@ void transpose_mat(const value_type* A, value_type* B, uint32_t rows, * C_{dim0xdim2} = A_{dim0xdim1} * B_{dim1xdim2} */ template -void mat_mult_non_square(const value_type* A, const value_type* B, - value_type* C, uint8_t dim0, uint32_t dim1, - uint32_t dim2) { +void mat_mult_non_square( + const value_type* A, const value_type* B, value_type* C, uint8_t dim0, + uint32_t dim1, uint32_t dim2) { memset(C, 0, dim0 * dim2 * sizeof(value_type)); for (uint32_t k = 0; k < dim1; k++) for (uint32_t i = 0; i < dim0; i++) @@ -223,8 +221,7 @@ void mat_mult_non_square(const value_type* A, const value_type* B, * @param buf sizeof (rows + cols + cols) * cols */ template -void pseudo_inverse_mat(value_type* A, uint32_t rows, uint32_t cols, - value_type* buf) { +void pseudo_inverse_mat(value_type* A, uint32_t rows, uint32_t cols, value_type* buf) { uint32_t &n = rows, &m = cols; value_type *B = buf, // m x n, A^T @@ -247,8 +244,9 @@ void pseudo_inverse_mat(value_type* A, uint32_t rows, uint32_t cols, * for detail. 
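The interp_helper changes are formatting-only, but the fixed-point table they touch is easy to miss in the noise: get_table() quantises the float interpolation weights to integers that must sum exactly to INTER_REMAP_COEF_SCALE and fixes up rounding drift afterwards. A rough sketch of that idea for the bilinear case; the constants here are assumptions (OpenCV-style defaults), not values read out of megdnn:

#include <cmath>
#include <cstdio>

int main() {
    const int INTER_BITS = 5;              // assumed table resolution
    const int TAB_SIZE = 1 << INTER_BITS;  // 32 fractional positions
    const int COEF_BITS = 15;
    const int COEF_SCALE = 1 << COEF_BITS;

    for (int i = 0; i < TAB_SIZE; ++i) {
        float fx = float(i) / TAB_SIZE;    // fractional offset in [0, 1)
        float w0 = 1.f - fx, w1 = fx;      // interpolate_linear()
        int i0 = int(std::lround(w0 * COEF_SCALE));
        int i1 = int(std::lround(w1 * COEF_SCALE));
        // keep the integer weights summing exactly to COEF_SCALE, mirroring the
        // isum != INTER_REMAP_COEF_SCALE fix-up in get_table()
        i1 += COEF_SCALE - (i0 + i1);
        if (i == 0 || i == TAB_SIZE - 1)
            std::printf("fx=%.3f  weights=(%d, %d)\n", fx, i0, i1);
    }
}

The real get_table() spreads the correction onto the largest and smallest table entries (the mk/Mk search above) rather than dumping it on the last weight.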
*/ template -void solve_pseudo(value_type* A, uint32_t rows, uint32_t cols, - const value_type* b, value_type* x, value_type* buf) { +void solve_pseudo( + value_type* A, uint32_t rows, uint32_t cols, const value_type* b, value_type* x, + value_type* buf) { pseudo_inverse_mat(A, rows, cols, buf); // A is actual A^{+} now mat_mult_non_square(A, b, x, cols, rows, 1); diff --git a/dnn/src/common/cv/mat.cpp b/dnn/src/common/cv/mat.cpp index 48237e47..378ef9d5 100644 --- a/dnn/src/common/cv/mat.cpp +++ b/dnn/src/common/cv/mat.cpp @@ -24,17 +24,12 @@ namespace megcv { template Mat::Mat(size_t rows, size_t cols, size_t channels, size_t step) - : m_rows(rows), - m_cols(cols), - m_channels(channels), - m_step(step), - m_offset(0) { + : m_rows(rows), m_cols(cols), m_channels(channels), m_step(step), m_offset(0) { megdnn_assert(step >= cols * channels); megdnn_assert(1 <= channels && channels <= 4); T* raw_data; cuda_check(cudaMalloc((void**)&raw_data, sizeof(T) * rows * step)); - m_data = - std::shared_ptr(raw_data, [](T* d) { cuda_check(cudaFree(d)); }); + m_data = std::shared_ptr(raw_data, [](T* d) { cuda_check(cudaFree(d)); }); cudaMemset(m_data.get(), 0, sizeof(T) * rows * step); } @@ -61,15 +56,15 @@ Mat::Mat(const Mat& rhs) m_offset(0) {} template -Mat::Mat(const Mat& rhs, size_t row_offset, size_t row_count, - size_t col_offset, size_t col_count) +Mat::Mat( + const Mat& rhs, size_t row_offset, size_t row_count, size_t col_offset, + size_t col_count) : m_rows(row_count), m_cols(col_count), m_channels(rhs.m_channels), m_step(rhs.m_step), m_data(rhs.m_data), - m_offset(rhs.m_offset + row_offset * m_step + - col_offset * m_channels) {} + m_offset(rhs.m_offset + row_offset * m_step + col_offset * m_channels) {} template Mat& Mat::operator=(const Mat& rhs) { @@ -102,9 +97,9 @@ template Mat Mat::clone() const { Mat res(m_rows, m_cols, m_channels); for (size_t r = 0; r < m_rows; ++r) { - cuda_check(cudaMemcpy(res.ptr(r), this->ptr(r), - sizeof(T) * m_cols * m_channels, - cudaMemcpyDeviceToDevice)); + cuda_check(cudaMemcpy( + res.ptr(r), this->ptr(r), sizeof(T) * m_cols * m_channels, + cudaMemcpyDeviceToDevice)); } return res; } @@ -122,11 +117,12 @@ bool Mat::equals(const Mat& rhs) const { megdnn_assert(row1); megdnn_assert(row2); for (size_t r = 0; r < m_rows; ++r) { - cuda_check(cudaMemcpy(row1.get(), this->ptr(r), - sizeof(T) * m_cols * m_channels, - cudaMemcpyDeviceToHost)); - cuda_check(cudaMemcpy(row2.get(), rhs.ptr(r), sizeof(T) * m_cols * m_channels, - cudaMemcpyDeviceToHost)); + cuda_check(cudaMemcpy( + row1.get(), this->ptr(r), sizeof(T) * m_cols * m_channels, + cudaMemcpyDeviceToHost)); + cuda_check(cudaMemcpy( + row2.get(), rhs.ptr(r), sizeof(T) * m_cols * m_channels, + cudaMemcpyDeviceToHost)); for (size_t i = 0; i < m_cols * m_channels; ++i) { if (row1[i] != row2[i]) return false; @@ -143,15 +139,17 @@ bool Mat::is_continuous() const { template void Mat::read(const T* src) { megdnn_assert(is_continuous()); - cuda_check(cudaMemcpy(m_data.get(), src, sizeof(T) * this->total_nr_elem(), - cudaMemcpyHostToDevice)); + cuda_check(cudaMemcpy( + m_data.get(), src, sizeof(T) * this->total_nr_elem(), + cudaMemcpyHostToDevice)); } template void Mat::write(T* dst) const { megdnn_assert(is_continuous()); - cuda_check(cudaMemcpy(dst, m_data.get(), sizeof(T) * this->total_nr_elem(), - cudaMemcpyDeviceToHost)); + cuda_check(cudaMemcpy( + dst, m_data.get(), sizeof(T) * this->total_nr_elem(), + cudaMemcpyDeviceToHost)); } template class Mat; @@ -267,15 +265,15 @@ Mat::Mat(const Mat& rhs) m_offset(0) {} 
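solve_pseudo() composes pseudo_inverse_mat() with mat_mult_non_square() to compute x = A^+ b. As a numerical cross-check of what that means, here is a tiny normal-equations solve for an over-determined 3x2 system, using A^+ = (A^T A)^{-1} A^T; it illustrates the math only and is not the megcv implementation (which goes through its own inverse_mat and workspace layout):

#include <cstdio>

int main() {
    // A is 3x2 (rows x cols), b is 3x1; find x (2x1) minimising ||Ax - b||
    const double A[3][2] = {{1, 1}, {1, 2}, {1, 3}};
    const double b[3] = {1, 2, 2};

    // normal matrix AtA = A^T A (2x2) and right-hand side Atb = A^T b
    double AtA[2][2] = {}, Atb[2] = {};
    for (int k = 0; k < 3; ++k)
        for (int i = 0; i < 2; ++i) {
            Atb[i] += A[k][i] * b[k];
            for (int j = 0; j < 2; ++j)
                AtA[i][j] += A[k][i] * A[k][j];
        }

    // invert the 2x2 normal matrix analytically
    const double det = AtA[0][0] * AtA[1][1] - AtA[0][1] * AtA[1][0];
    const double inv[2][2] = {{AtA[1][1] / det, -AtA[0][1] / det},
                              {-AtA[1][0] / det, AtA[0][0] / det}};

    // x = (A^T A)^{-1} A^T b  ->  intercept 2/3, slope 1/2 for this data
    const double x[2] = {inv[0][0] * Atb[0] + inv[0][1] * Atb[1],
                         inv[1][0] * Atb[0] + inv[1][1] * Atb[1]};
    std::printf("x = (%.4f, %.4f)\n", x[0], x[1]);
}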
template -Mat::Mat(const Mat& rhs, size_t row_offset, size_t row_count, - size_t col_offset, size_t col_count) +Mat::Mat( + const Mat& rhs, size_t row_offset, size_t row_count, size_t col_offset, + size_t col_count) : m_rows(row_count), m_cols(col_count), m_channels(rhs.m_channels), m_step(rhs.m_step), m_data(rhs.m_data), - m_offset(rhs.m_offset + row_offset * m_step + - col_offset * m_channels) {} + m_offset(rhs.m_offset + row_offset * m_step + col_offset * m_channels) {} template Mat& Mat::operator=(const Mat& rhs) { @@ -322,8 +320,7 @@ bool Mat::equals(const Mat& rhs) const { if (this->m_channels != rhs.m_channels) return false; for (size_t r = 0; r < m_rows; ++r) { - if (0 != - memcmp(this->ptr(r), rhs.ptr(r), sizeof(T) * m_cols * m_channels)) + if (0 != memcmp(this->ptr(r), rhs.ptr(r), sizeof(T) * m_cols * m_channels)) return false; } return true; diff --git a/dnn/src/common/cvt_color.cpp b/dnn/src/common/cvt_color.cpp index 1b09651e..399ac0c6 100644 --- a/dnn/src/common/cvt_color.cpp +++ b/dnn/src/common/cvt_color.cpp @@ -14,20 +14,15 @@ namespace megdnn { -void CvtColorBase::deduce_layout_fwd(const TensorLayout& src, - TensorLayout& dst) { +void CvtColorBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src); }; MEGDNN_MARK_USED_VAR(errmsg); auto mode = param().mode; - if (mode == Param::Mode::YUV2RGB_NV21 || - mode == Param::Mode::YUV2BGR_NV21 || - mode == Param::Mode::YUV2RGB_NV12 || - mode == Param::Mode::YUV2BGR_NV12 || - mode == Param::Mode::YUV2RGB_YV12 || - mode == Param::Mode::YUV2BGR_YV12 || - mode == Param::Mode::YUV2RGB_YU12 || - mode == Param::Mode::YUV2BGR_YU12) { + if (mode == Param::Mode::YUV2RGB_NV21 || mode == Param::Mode::YUV2BGR_NV21 || + mode == Param::Mode::YUV2RGB_NV12 || mode == Param::Mode::YUV2BGR_NV12 || + mode == Param::Mode::YUV2RGB_YV12 || mode == Param::Mode::YUV2BGR_YV12 || + mode == Param::Mode::YUV2RGB_YU12 || mode == Param::Mode::YUV2BGR_YU12) { megdnn_log_warn( "Deprecated mode for cvtcolor, you should refer to the wiki " "for detail usage"); @@ -42,8 +37,8 @@ void CvtColorBase::deduce_layout_fwd(const TensorLayout& src, } megdnn_assert( - src.ndim == 4_z && (src.shape[3] == 1_z || src.shape[3] == 3_z || - src.shape[3] == 4_z), + src.ndim == 4_z && + (src.shape[3] == 1_z || src.shape[3] == 3_z || src.shape[3] == 4_z), "%s", errmsg().c_str()); size_t in = src.shape[0]; @@ -141,8 +136,7 @@ void CvtColorBase::deduce_layout_fwd(const TensorLayout& src, dst = TensorLayout(TensorShape({in, oh, ow, oc}), src.dtype); } -void CvtColorBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& dst) { +void CvtColorBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { megdnn_assert_eq_dtype(src, dst); TensorLayout dst_expected; deduce_layout_fwd(src, dst_expected); @@ -153,8 +147,8 @@ void CvtColor::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void CvtColor::check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes) { +void CvtColor::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); megdnn_assert_contiguous(src); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); diff --git a/dnn/src/common/dct.cpp b/dnn/src/common/dct.cpp index 7d726549..922ab2d4 100644 --- a/dnn/src/common/dct.cpp +++ b/dnn/src/common/dct.cpp @@ -15,10 +15,9 @@ namespace megdnn { -void 
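The Mat edits are pure reflow, but the ROI constructor they touch encodes the usual strided-view arithmetic: a sub-view shares the parent buffer and only advances an element offset by row_offset * step + col_offset * channels. A small self-contained sketch of that addressing with a hypothetical View struct (plain CPU buffer, not megcv::Mat):

#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

// rows x cols x channels view over a shared buffer; `step` elements separate
// consecutive rows and `offset` elements lead to the first one
struct View {
    size_t rows, cols, channels, step, offset;
    std::shared_ptr<std::vector<float>> data;

    const float* ptr(size_t r) const { return data->data() + offset + r * step; }

    // sub-view over rows [row_off, row_off + row_cnt) and cols [col_off, col_off + col_cnt)
    View roi(size_t row_off, size_t row_cnt, size_t col_off, size_t col_cnt) const {
        return {row_cnt, col_cnt, channels, step,
                offset + row_off * step + col_off * channels, data};
    }
};

int main() {
    const size_t rows = 4, cols = 6, ch = 1, step = cols * ch;
    auto buf = std::make_shared<std::vector<float>>(rows * step);
    for (size_t i = 0; i < buf->size(); ++i)
        (*buf)[i] = float(i);

    View full{rows, cols, ch, step, 0, buf};
    View sub = full.roi(1, 2, 2, 3);                // rows 1..2, cols 2..4
    std::printf("sub(0,0) = %g\n", sub.ptr(0)[0]);  // element (1,2) of full -> 8
}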
DctChannelSelectForward::deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& mask_offset, - const TensorLayout& mask_val, - TensorLayout& dst) { +void DctChannelSelectForward::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& mask_offset, + const TensorLayout& mask_val, TensorLayout& dst) { const size_t dct_block = param().dct_block_size; const size_t in = src.shape[0]; const size_t ic = src.shape[1]; @@ -29,50 +28,49 @@ void DctChannelSelectForward::deduce_layout_fwd(const TensorLayout& src, const size_t ow = iw / dct_block; //! mask will be empty or (ic + 1) elements size_t oc = mask_offset.ndim > 0 && mask_offset[0] >= 2 - ? mask_val.shape[0] - : ic * dct_block * dct_block; + ? mask_val.shape[0] + : ic * dct_block * dct_block; if (param().fastImpl == Param::FastImpl::FIX_32_MASK) { - megdnn_assert(oc == 32, - "Param::FastImpl::FIX_32_MASK oc must be 32, but %zu", - oc); + megdnn_assert( + oc == 32, "Param::FastImpl::FIX_32_MASK oc must be 32, but %zu", oc); } if (param().format == Param::Format::NCHW) { dst = TensorLayout(TensorShape({in, oc, oh, ow}), dst.dtype); } else { - megdnn_assert(param().format == Param::Format::NCHW4, - "dct format must be nchw or nchw4"); + megdnn_assert( + param().format == Param::Format::NCHW4, + "dct format must be nchw or nchw4"); megdnn_assert(oc % 4 == 0, "oc mod 4 == 0 in nchw4"); dst = TensorLayout(TensorShape({in, oc / 4, oh, ow, 4}), dst.dtype); } } -void DctChannelSelectForward::deduce_layout(const TensorLayout& src, - const TensorLayout& mask_offset, - const TensorLayout& mask_val, - TensorLayout& dst) { +void DctChannelSelectForward::deduce_layout( + const TensorLayout& src, const TensorLayout& mask_offset, + const TensorLayout& mask_val, TensorLayout& dst) { deduce_layout_fwd(src, mask_offset, mask_val, dst); } -void DctChannelSelectForward::check_layout_fwd(const TensorLayout& src, - const TensorLayout& mask_offset, - const TensorLayout& mask_val, - const TensorLayout& dst) { +void DctChannelSelectForward::check_layout_fwd( + const TensorLayout& src, const TensorLayout& mask_offset, + const TensorLayout& mask_val, const TensorLayout& dst) { const size_t dct_block = param().dct_block_size; const size_t ih = src.shape[2]; const size_t iw = src.shape[3]; - megdnn_assert(mask_offset.ndim == 0 || (mask_offset.ndim == 1 && - (mask_offset.shape[0] == 0 || - mask_offset.shape[0] >= 2) && - mask_val.ndim == 1), - "mask only support one valid dim"); + megdnn_assert( + mask_offset.ndim == 0 || + (mask_offset.ndim == 1 && + (mask_offset.shape[0] == 0 || mask_offset.shape[0] >= 2) && + mask_val.ndim == 1), + "mask only support one valid dim"); megdnn_assert(mask_val.ndim <= 1, "only support one dim"); - megdnn_assert(src.dtype.enumv() == DTypeEnum::Uint8, - "src.dtype == dtype::Uint8"); - megdnn_assert(dst.dtype.enumv() == DTypeEnum::Float32 || - dst.dtype.enumv() == DTypeEnum::QuantizedS8, - "dst.dtype == dtype::Float32 || dst.dtype.enumv() == " - "DTypeEnum::QuantizedS8"); + megdnn_assert(src.dtype.enumv() == DTypeEnum::Uint8, "src.dtype == dtype::Uint8"); + megdnn_assert( + dst.dtype.enumv() == DTypeEnum::Float32 || + dst.dtype.enumv() == DTypeEnum::QuantizedS8, + "dst.dtype == dtype::Float32 || dst.dtype.enumv() == " + "DTypeEnum::QuantizedS8"); megdnn_assert(ih % dct_block == 0, "ih mod dctblock == 0"); megdnn_assert(iw % dct_block == 0, "iw mod dctblock == 0"); } diff --git a/dnn/src/common/deformable_conv.cpp b/dnn/src/common/deformable_conv.cpp index 18db856a..969587ee 100644 --- a/dnn/src/common/deformable_conv.cpp +++ 
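The DctChannelSelectForward::deduce_layout_fwd hunk above is shape bookkeeping: every dct_block x dct_block tile of each input channel turns into dct_block^2 output channels (or however many channels the mask keeps), and the spatial size shrinks by dct_block. A quick worked sketch of that arithmetic with made-up example values:

#include <cstddef>
#include <cstdio>

int main() {
    // assumed example: NCHW uint8 input, 8x8 DCT blocks, no channel mask
    const size_t n = 1, ic = 3, ih = 224, iw = 224, dct_block = 8;
    const size_t oh = ih / dct_block, ow = iw / dct_block;  // 28 x 28
    const size_t oc = ic * dct_block * dct_block;           // 3 * 64 = 192

    std::printf("NCHW  dst: {%zu, %zu, %zu, %zu}\n", n, oc, oh, ow);
    // NCHW4 packs channels in groups of 4, hence the oc % 4 == 0 assert
    std::printf("NCHW4 dst: {%zu, %zu, %zu, %zu, 4}\n", n, oc / 4, oh, ow);
}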
b/dnn/src/common/deformable_conv.cpp @@ -19,9 +19,9 @@ using CanonizedFilterMeta = DeformableConvBase::CanonizedFilterMeta; namespace { template -std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& dst, const Param& param) { +std::string get_errmsg( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst, const Param& param) { MEGDNN_MARK_USED_VAR(src); MEGDNN_MARK_USED_VAR(filter); MEGDNN_MARK_USED_VAR(dst); @@ -39,15 +39,15 @@ std::string get_errmsg(const TensorLayout& src, const TensorLayout& filter, } template -void make_canonized_filter_meta_nchw(size_t src_ndim, - const TensorLayout& filter, - const Param& param, - CanonizedFilterMeta& ret) { - megdnn_assert(param.mode == Param::Mode::CROSS_CORRELATION, - "only support CROSS_CORRELATION mode"); +void make_canonized_filter_meta_nchw( + size_t src_ndim, const TensorLayout& filter, const Param& param, + CanonizedFilterMeta& ret) { + megdnn_assert( + param.mode == Param::Mode::CROSS_CORRELATION, + "only support CROSS_CORRELATION mode"); - megdnn_assert(param.format == Param::Format::NCHW, - "only support nchw input layout"); + megdnn_assert( + param.format == Param::Format::NCHW, "only support nchw input layout"); size_t flt_start, flt_spatial_start, ocpg_pos, icpg_pos; @@ -71,9 +71,9 @@ void make_canonized_filter_meta_nchw(size_t src_ndim, auto dilation = ret.dilation; for (size_t i = 0; i < ret.spatial_ndim; ++i) { - megdnn_assert(dilation[i] > 0, - "invalid dilation on spatial dim %zu, %u", i, - dilation[i]); + megdnn_assert( + dilation[i] > 0, "invalid dilation on spatial dim %zu, %u", i, + dilation[i]); ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.dilated_spatial[i] = (ret.spatial[i] - 1) * dilation[i] + 1; } @@ -84,8 +84,7 @@ void make_canonized_filter_meta_nchw(size_t src_ndim, namespace megdnn { CanonizedFilterMeta DeformableConvBase::make_canonized_filter_meta( - size_t src_ndim, const TensorLayout& filter, - const TensorLayout& offset) const { + size_t src_ndim, const TensorLayout& filter, const TensorLayout& offset) const { megdnn_assert_contiguous(filter); CanonizedFilterMeta ret; @@ -99,8 +98,7 @@ CanonizedFilterMeta DeformableConvBase::make_canonized_filter_meta( ret.dilation[1] = param().dilate_w; if (param().sparse == Param::Sparse::GROUP) { - megdnn_assert(filter.ndim == 5, - "filter dim should be 5 for group conv"); + megdnn_assert(filter.ndim == 5, "filter dim should be 5 for group conv"); ret.group = filter[0]; } @@ -114,24 +112,23 @@ CanonizedFilterMeta DeformableConvBase::make_canonized_filter_meta( return ret; } -void DeformableConvBase::deduce_layout_fwd(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - TensorLayout& dst) { +void DeformableConvBase::deduce_layout_fwd( + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, TensorLayout& dst) { // im shape: (n, IC, IH, IW) - megdnn_assert(im.ndim == 4, "invalid src layout: %s", - megdnn_layout_msg(im).c_str()); + megdnn_assert( + im.ndim == 4, "invalid src layout: %s", megdnn_layout_msg(im).c_str()); // filter shape: (OC, IC, FH, FW) or (g, OC/g, IC/g, FH, FW) - megdnn_assert(filter.ndim == 4 || filter.ndim == 5, - "invalid filter layout: %s", - megdnn_layout_msg(filter).c_str()); + megdnn_assert( + filter.ndim == 4 || filter.ndim == 5, "invalid filter 
layout: %s", + megdnn_layout_msg(filter).c_str()); // offset shape: (N, 2*dg*FH*FW, OH, OW) - megdnn_assert(offset.ndim == 4, "invalid offset layout: %s", - megdnn_layout_msg(offset).c_str()); + megdnn_assert( + offset.ndim == 4, "invalid offset layout: %s", + megdnn_layout_msg(offset).c_str()); // mask shape: (N, dg*FH*FW, OH, OW) - megdnn_assert(mask.ndim == 4, "invalid mask layout: %s", - megdnn_layout_msg(mask).c_str()); + megdnn_assert( + mask.ndim == 4, "invalid mask layout: %s", megdnn_layout_msg(mask).c_str()); size_t n = im.shape[0], ic = im.shape[1]; size_t ih = im.shape[2], iw = im.shape[3]; @@ -153,8 +150,8 @@ void DeformableConvBase::deduce_layout_fwd(const TensorLayout& im, size_t oh = (ih + ph * 2 - kh) / sh + 1; size_t ow = (iw + pw * 2 - kw) / sw + 1; - megdnn_assert(group > 0 && deformable_group > 0, - "group and deformable group should > 0"); + megdnn_assert( + group > 0 && deformable_group > 0, "group and deformable group should > 0"); megdnn_assert(ic == icpg * group, "im ic != group * icpg of filter"); megdnn_assert(ic % deformable_group == 0, "ic %% deformable_group != 0"); megdnn_assert(oc % deformable_group == 0, "oc %% deformable_group != 0"); @@ -164,32 +161,29 @@ void DeformableConvBase::deduce_layout_fwd(const TensorLayout& im, "invalid deformable group deduced from offset(%s) or mask(%s)", megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str()); - megdnn_assert((offset[1] / (2 * fh * fw)) == (mask[1] / (fh * fw)), - "offset(%s) and mask(%s) should have same deformable group", - megdnn_layout_msg(offset).c_str(), - megdnn_layout_msg(mask).c_str()); - - megdnn_assert((offset[2] == mask[2]) && (offset[3] == mask[3]), - "offset(%s) and mask(%s) should have same spatial dim", - megdnn_layout_msg(offset).c_str(), - megdnn_layout_msg(mask).c_str()); - megdnn_assert(oh == offset[2], "deduced oh(%zu) != offset oh(%zu)", oh, - offset[2]); - megdnn_assert(ow == offset[3], "deduced ow(%zu) != offset ow(%zu)", ow, - offset[3]); + megdnn_assert( + (offset[1] / (2 * fh * fw)) == (mask[1] / (fh * fw)), + "offset(%s) and mask(%s) should have same deformable group", + megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str()); + + megdnn_assert( + (offset[2] == mask[2]) && (offset[3] == mask[3]), + "offset(%s) and mask(%s) should have same spatial dim", + megdnn_layout_msg(offset).c_str(), megdnn_layout_msg(mask).c_str()); + megdnn_assert(oh == offset[2], "deduced oh(%zu) != offset oh(%zu)", oh, offset[2]); + megdnn_assert(ow == offset[3], "deduced ow(%zu) != offset ow(%zu)", ow, offset[3]); dst.ndim = 4; dst = {{n, oc, oh, ow}, im.dtype}; } -void DeformableConvBase::check_layout_fwd(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst) { +void DeformableConvBase::check_layout_fwd( + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst) { auto& im_dtype = im.dtype; TensorLayout dst_expected; - megdnn_assert(im_dtype.enumv() == DTypeEnum::Float32, - "DeformableConv only support float32 input"); + megdnn_assert( + im_dtype.enumv() == DTypeEnum::Float32, + "DeformableConv only support float32 input"); megdnn_assert_eq_dtype(im, dst); megdnn_assert_eq_dtype(im, filter); megdnn_assert_eq_dtype(im, dst); @@ -199,19 +193,16 @@ void DeformableConvBase::check_layout_fwd(const TensorLayout& im, megdnn_assert_eq_layout(dst_expected, dst); } -void DeformableConvForward::deduce_layout(const TensorLayout& im, 
- const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - TensorLayout& dst) { +void DeformableConvForward::deduce_layout( + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, TensorLayout& dst) { deduce_layout_fwd(im, filter, offset, mask, dst); return; } CanonizedFilterMeta DeformableConvForward::check_exec( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& dst, size_t workspace_in_bytes) { + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst, size_t workspace_in_bytes) { auto ret = make_canonized_filter_meta(im.ndim, filter, offset); auto required_workspace_in_bytes = get_workspace_in_bytes(im, filter, offset, mask, dst); @@ -221,9 +212,9 @@ CanonizedFilterMeta DeformableConvForward::check_exec( } CanonizedFilterMeta DeformableConvBackwardFilter::check_exec( - const TensorLayout& im, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& out_grad, - const TensorLayout& filter_grad, size_t workspace_in_bytes) { + const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& filter_grad, + size_t workspace_in_bytes) { check_layout_fwd(im, filter_grad, offset, mask, out_grad); // check dtype megdnn_assert_eq_dtype(im, filter_grad); @@ -237,11 +228,10 @@ CanonizedFilterMeta DeformableConvBackwardFilter::check_exec( } CanonizedFilterMeta DeformableConvBackwardData::check_exec( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, const TensorLayout& mask_grad, - size_t workspace_in_bytes) { + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& im_grad, const TensorLayout& offset_grad, + const TensorLayout& mask_grad, size_t workspace_in_bytes) { check_layout_fwd(im, filter, offset, mask, out_grad); // check dtype @@ -255,9 +245,8 @@ CanonizedFilterMeta DeformableConvBackwardData::check_exec( megdnn_assert_eq_shape(mask, mask_grad); auto ret = make_canonized_filter_meta(im.ndim, filter, offset); - auto required_workspace_in_bytes = - get_workspace_in_bytes(im, filter, offset, mask, out_grad, im_grad, - offset_grad, mask_grad); + auto required_workspace_in_bytes = get_workspace_in_bytes( + im, filter, offset, mask, out_grad, im_grad, offset_grad, mask_grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); return ret; diff --git a/dnn/src/common/deformable_ps_roi_pooling.cpp b/dnn/src/common/deformable_ps_roi_pooling.cpp index f61fdf00..808d6b67 100644 --- a/dnn/src/common/deformable_ps_roi_pooling.cpp +++ b/dnn/src/common/deformable_ps_roi_pooling.cpp @@ -14,11 +14,9 @@ namespace megdnn { -void DeformablePSROIPoolingBase::deduce_layout_fwd(const TensorLayout& data, - const TensorLayout& rois, - const TensorLayout& trans, - TensorLayout& out_data, - TensorLayout& out_count) { +void DeformablePSROIPoolingBase::deduce_layout_fwd( + const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans, + TensorLayout& out_data, TensorLayout& out_count) { megdnn_assert_contiguous(data); megdnn_assert_contiguous(rois); megdnn_assert_contiguous(trans); @@ -39,31 +37,30 @@ void 
DeformablePSROIPoolingBase::deduce_layout_fwd(const TensorLayout& data, MEGDNN_MARK_USED_VAR(out_count); MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(data.dtype.enumv() == DTypeEnum::Float32, - "DeformablePSROIPooling only support float32 input"); + megdnn_assert( + data.dtype.enumv() == DTypeEnum::Float32, + "DeformablePSROIPooling only support float32 input"); megdnn_assert(data.ndim == 4_z, "invalid data shape, %s", errmsg().c_str()); - megdnn_assert(rois.ndim == 2_z && rois[1] == 5, "invalid rois shape, %s", - errmsg().c_str()); - megdnn_assert(trans.ndim == 4_z, "invalid trans shape, %s", - errmsg().c_str()); + megdnn_assert( + rois.ndim == 2_z && rois[1] == 5, "invalid rois shape, %s", + errmsg().c_str()); + megdnn_assert(trans.ndim == 4_z, "invalid trans shape, %s", errmsg().c_str()); if (!param().no_trans) { - megdnn_assert(trans[1] == 2_z && trans[2] == param().pooled_h && - trans[3] == param().pooled_w, - "invalid trans shape: %s", errmsg().c_str()); + megdnn_assert( + trans[1] == 2_z && trans[2] == param().pooled_h && + trans[3] == param().pooled_w, + "invalid trans shape: %s", errmsg().c_str()); } - out_data = {{rois[0], data[1], param().pooled_h, param().pooled_w}, - data.dtype}; + out_data = {{rois[0], data[1], param().pooled_h, param().pooled_w}, data.dtype}; out_count = out_data; } -void DeformablePSROIPoolingBase::check_layout_fwd(const TensorLayout& data, - const TensorLayout& rois, - const TensorLayout& trans, - const TensorLayout& out_data, - const TensorLayout& out_count, - size_t workspace_in_bytes) { +void DeformablePSROIPoolingBase::check_layout_fwd( + const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans, + const TensorLayout& out_data, const TensorLayout& out_count, + size_t workspace_in_bytes) { MEGDNN_MARK_USED_VAR(workspace_in_bytes); TensorLayout exp_out_data, exp_out_count; @@ -73,34 +70,29 @@ void DeformablePSROIPoolingBase::check_layout_fwd(const TensorLayout& data, megdnn_assert_eq_layout(out_count, exp_out_count); } -void DeformablePSROIPoolingForward::deduce_layout(const TensorLayout& data, - const TensorLayout& rois, - const TensorLayout& trans, - TensorLayout& out_data, - TensorLayout& out_count) { +void DeformablePSROIPoolingForward::deduce_layout( + const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans, + TensorLayout& out_data, TensorLayout& out_count) { deduce_layout_fwd(data, rois, trans, out_data, out_count); } -void DeformablePSROIPoolingForward::check_exec(const TensorLayout& data, - const TensorLayout& rois, - const TensorLayout& trans, - const TensorLayout& out_data, - const TensorLayout& out_count, - size_t workspace_in_bytes) { - check_layout_fwd(data, rois, trans, out_data, out_count, - workspace_in_bytes); +void DeformablePSROIPoolingForward::check_exec( + const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans, + const TensorLayout& out_data, const TensorLayout& out_count, + size_t workspace_in_bytes) { + check_layout_fwd(data, rois, trans, out_data, out_count, workspace_in_bytes); auto required_workspace_in_bytes = get_workspace_in_bytes(data, rois, trans, out_data, out_count); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } void DeformablePSROIPoolingBackward::check_exec( - const TensorLayout& data, const TensorLayout& rois, - const TensorLayout& trans, const TensorLayout& out_diff, - const TensorLayout& out_count, const TensorLayout& data_diff, - const TensorLayout& trans_diff, size_t workspace_in_bytes) { - check_layout_fwd(data_diff, rois, 
trans_diff, out_diff, out_count, - workspace_in_bytes); + const TensorLayout& data, const TensorLayout& rois, const TensorLayout& trans, + const TensorLayout& out_diff, const TensorLayout& out_count, + const TensorLayout& data_diff, const TensorLayout& trans_diff, + size_t workspace_in_bytes) { + check_layout_fwd( + data_diff, rois, trans_diff, out_diff, out_count, workspace_in_bytes); megdnn_assert_eq_layout(data, data_diff); megdnn_assert_eq_layout(trans, trans_diff); auto required_workspace_in_bytes = get_workspace_in_bytes( diff --git a/dnn/src/common/dot.cpp b/dnn/src/common/dot.cpp index 14faf5b5..ddae6c0c 100644 --- a/dnn/src/common/dot.cpp +++ b/dnn/src/common/dot.cpp @@ -14,15 +14,12 @@ namespace megdnn { -void DotForward::check_exec(const TensorLayout &A, - const TensorLayout &B, - const TensorLayout &C, - size_t workspace_in_bytes) -{ +void DotForward::check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes) { auto errmsg = [&]() { - return megdnn_layout_msg(A) - + ", " + megdnn_layout_msg(B) - + ", " + megdnn_layout_msg(C); + return megdnn_layout_msg(A) + ", " + megdnn_layout_msg(B) + ", " + + megdnn_layout_msg(C); }; MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert(A.ndim == 1_z && A.stride[0] >= 0, "%s", errmsg().c_str()); @@ -36,13 +33,11 @@ void DotForward::check_exec(const TensorLayout &A, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void DotForward::deduce_layout(const TensorLayout &A, - const TensorLayout &, - TensorLayout &C) -{ +void DotForward::deduce_layout( + const TensorLayout& A, const TensorLayout&, TensorLayout& C) { C = TensorLayout(TensorShape{1}, A.dtype); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/dtype.cpp b/dnn/src/common/dtype.cpp index 3e490fd0..417e5613 100644 --- a/dnn/src/common/dtype.cpp +++ b/dnn/src/common/dtype.cpp @@ -12,9 +12,9 @@ #include "megdnn/dtype.h" #include "src/common/utils.h" +#include #include #include -#include using namespace megdnn; using namespace dtype; @@ -23,14 +23,12 @@ using namespace dtype; #pragma message "megdnn float16 disabled" #endif -#define IMPL(_name) \ -DType::Trait _name::sm_trait = { \ - DTypeTrait<_name>::name, \ - DTypeTrait<_name>::size_log, DTypeTrait<_name>::low_bit, \ - DTypeEnum::_name, \ - DTypeTrait<_name>::category, DTypeTrait<_name>::signedness, \ - DTypeTrait<_name>::has_param \ -}; +#define IMPL(_name) \ + DType::Trait _name::sm_trait = { \ + DTypeTrait<_name>::name, DTypeTrait<_name>::size_log, \ + DTypeTrait<_name>::low_bit, DTypeEnum::_name, \ + DTypeTrait<_name>::category, DTypeTrait<_name>::signedness, \ + DTypeTrait<_name>::has_param}; #define TEMPLATED_IMPL(_name) \ template <> \ IMPL(_name) @@ -41,9 +39,8 @@ MEGDNN_FOREACH_PARAMETERIZED_DTYPE(TEMPLATED_IMPL) #undef TEMPLATED_IMPL #undef IMPL -void DType::on_assert_is_failed(const char *rname) const { - megdnn_throw(ssprintf("attempt to access dtype %s as %s", name(), rname) - .c_str()); +void DType::on_assert_is_failed(const char* rname) const { + megdnn_throw(ssprintf("attempt to access dtype %s as %s", name(), rname).c_str()); MEGDNN_MARK_USED_VAR(rname); } @@ -53,37 +50,36 @@ void DType::on_request_lowbit_size() const { DType DType::from_enum(DTypeEnum ev) { switch (ev) { -#define cb(_dt) case DTypeEnum::_dt: return dtype::_dt(); +#define cb(_dt) \ + case DTypeEnum::_dt: \ + return dtype::_dt(); MEGDNN_FOREACH_DTYPE_NAME(cb) #undef cb #define cb(_dt) case DTypeEnum::_dt: MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) 
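DotForward::check_exec above only requires 1-d operands with non-negative strides and always deduces a one-element output. Reading stride 0 as a broadcast scalar (my interpretation of the layout convention, not something this hunk states), the whole operator is just:

#include <cstddef>
#include <cstdio>

// 1-d dot product where either operand may use stride 0, i.e. act as a scalar
static float dot(const float* a, size_t stride_a, const float* b, size_t stride_b, size_t n) {
    float acc = 0.f;
    for (size_t i = 0; i < n; ++i)
        acc += a[i * stride_a] * b[i * stride_b];
    return acc;
}

int main() {
    const float a[4] = {1.f, 2.f, 3.f, 4.f};
    const float one = 1.f;
    std::printf("dot(a, a) = %g\n", dot(a, 1, a, 1, 4));     // 30
    std::printf("dot(a, 1) = %g\n", dot(a, 1, &one, 0, 4));  // broadcast b -> 10
}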
- megdnn_throw( - "cannot construct parameterized DType via DType::from_enum"); + megdnn_throw("cannot construct parameterized DType via DType::from_enum"); #undef cb } megdnn_throw("bad DTypeEnum value"); } template -typename ParameterizedDType::Trait* -ParameterizedDType::make_from_param( - const DTypeParam& param) { +typename ParameterizedDType::Trait* ParameterizedDType< + type_enum>::make_from_param(const DTypeParam& param) { struct Hasher { std::size_t operator()(const DTypeParam& key) const { return key.hash(); } }; - static std::unordered_map, - std::unique_ptr, Hasher> + static std::unordered_map< + DTypeParam, std::unique_ptr, Hasher> entries; auto it = entries.find(param); if (it != entries.end()) { return it->second.get(); } - entries[param] = - std::make_unique(SelfType::sm_trait, param); + entries[param] = std::make_unique(SelfType::sm_trait, param); return entries[param].get(); } @@ -103,8 +99,7 @@ inline std::size_t DTypeParam::hash() const { return std::hash()(scale) ^ std::hash()(zero_point); } -inline bool DTypeParam::operator==( - const DTypeParam& rhs) const { +inline bool DTypeParam::operator==(const DTypeParam& rhs) const { return scale == rhs.scale && zero_point == rhs.zero_point; } @@ -117,8 +112,7 @@ inline std::size_t DTypeParam::hash() const { return std::hash()(scale); } -inline bool DTypeParam::operator==( - const DTypeParam& rhs) const { +inline bool DTypeParam::operator==(const DTypeParam& rhs) const { return scale == rhs.scale; } @@ -131,8 +125,7 @@ inline std::size_t DTypeParam::hash() const { return std::hash()(scale); } -inline bool DTypeParam::operator==( - const DTypeParam& rhs) const { +inline bool DTypeParam::operator==(const DTypeParam& rhs) const { return scale == rhs.scale; } @@ -145,8 +138,7 @@ inline std::size_t DTypeParam::hash() const { return std::hash()(scale); } -inline bool DTypeParam::operator==( - const DTypeParam& rhs) const { +inline bool DTypeParam::operator==(const DTypeParam& rhs) const { return scale == rhs.scale; } @@ -160,8 +152,7 @@ inline std::size_t DTypeParam::hash() const { return std::hash()(scale) ^ std::hash()(zero_point); } -inline bool DTypeParam::operator==( - const DTypeParam& rhs) const { +inline bool DTypeParam::operator==(const DTypeParam& rhs) const { return scale == rhs.scale && zero_point == rhs.zero_point; } @@ -174,8 +165,7 @@ inline std::size_t DTypeParam::hash() const { return std::hash()(scale); } -inline bool DTypeParam::operator==( - const DTypeParam& rhs) const { +inline bool DTypeParam::operator==(const DTypeParam& rhs) const { return scale == rhs.scale; } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/common/elemwise/each_mode.inl b/dnn/src/common/elemwise/each_mode.inl index 703cf8f8..22892795 100644 --- a/dnn/src/common/elemwise/each_mode.inl +++ b/dnn/src/common/elemwise/each_mode.inl @@ -1,100 +1,98 @@ // generated by gen_elemwise_each_mode.py -#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) \ +#define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) \ - 
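make_from_param() above interns one Trait object per distinct parameter value behind a static unordered_map with a custom hasher, so equal DTypeParams always resolve to the same pointer. A stripped-down sketch of that caching pattern with a hypothetical QuantParam type (not megdnn's DTypeParam):

#include <cstddef>
#include <cstdio>
#include <functional>
#include <memory>
#include <unordered_map>

struct QuantParam {
    float scale;
    bool operator==(const QuantParam& rhs) const { return scale == rhs.scale; }
    std::size_t hash() const { return std::hash<float>()(scale); }
};

struct Trait {
    QuantParam param;
};

// return a long-lived Trait*; equal params share one cached object
static Trait* intern(const QuantParam& param) {
    struct Hasher {
        std::size_t operator()(const QuantParam& key) const { return key.hash(); }
    };
    static std::unordered_map<QuantParam, std::unique_ptr<Trait>, Hasher> entries;
    auto it = entries.find(param);
    if (it != entries.end())
        return it->second.get();
    auto inserted = entries.emplace(param, std::make_unique<Trait>(Trait{param}));
    return inserted.first->second.get();
}

int main() {
    Trait* a = intern({0.5f});
    Trait* b = intern({0.5f});
    std::printf("same object: %d\n", a == b);  // 1: the second call hits the cache
}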
MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_INT(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_BOOL(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_FLOAT(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \ + 
MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) \ MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_BINARY_INT(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) -#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb) \ +#define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_BOOL(cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_FLOAT(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ + MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) #define MEGDNN_FOREACH_ELEMWISE_MODE_TERNARY_INT(cb) \ - MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) \ - + MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) diff --git a/dnn/src/common/elemwise/erfinv.h b/dnn/src/common/elemwise/erfinv.h index 00078fc0..d1bc9d96 100644 --- 
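each_mode.inl above is a generated X-macro list: every FOREACH macro expands a caller-supplied cb once per enabled mode, and kern_defs.cuh plugs kernel-defining callbacks into those lists further down. A tiny self-contained illustration of the pattern, with toy macro names rather than megdnn's:

#include <cstdio>

// the "list" macro: one cb(...) per item, like MEGDNN_FOREACH_ELEMWISE_MODE_*
#define FOREACH_TOY_MODE(cb) \
    cb(RELU)                 \
    cb(ABS)                  \
    cb(NEGATE)

// expansion 1: build an enum from the list
#define MAKE_ENUM(_mode) TOY_##_mode,
enum ToyMode { FOREACH_TOY_MODE(MAKE_ENUM) TOY_MODE_COUNT };
#undef MAKE_ENUM

// expansion 2: build a matching name table from the same list
#define MAKE_NAME(_mode) #_mode,
static const char* toy_mode_names[] = {FOREACH_TOY_MODE(MAKE_NAME)};
#undef MAKE_NAME

int main() {
    for (int i = 0; i < TOY_MODE_COUNT; ++i)
        std::printf("%d -> %s\n", i, toy_mode_names[i]);
}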
a/dnn/src/common/elemwise/erfinv.h +++ b/dnn/src/common/elemwise/erfinv.h @@ -52,356 +52,286 @@ // LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) template -inline U evaluate_polynomial(const T_* poly, U const& z, std::size_t count) -{ - megdnn_assert(count > 0); - U sum = static_cast(poly[count - 1]); - for(int i = static_cast(count) - 2; i >= 0; --i) - { - sum *= z; - sum += static_cast(poly[i]); - } - return sum; +inline U evaluate_polynomial(const T_* poly, U const& z, std::size_t count) { + megdnn_assert(count > 0); + U sum = static_cast(poly[count - 1]); + for (int i = static_cast(count) - 2; i >= 0; --i) { + sum *= z; + sum += static_cast(poly[i]); + } + return sum; } template -inline V evaluate_polynomial(const T(&a)[N], const V& val) -{ - return evaluate_polynomial(a, val, N); +inline V evaluate_polynomial(const T (&a)[N], const V& val) { + return evaluate_polynomial(a, val, N); } // // The inverse erf and erfc functions share a common implementation, // this version is for 80-bit long double's and smaller: // -inline double erfinv_imp(double p, double q) -{ - using namespace std; +inline double erfinv_imp(double p, double q) { + using namespace std; - double result = 0; + double result = 0; - if(p <= 0.5) - { - // - // Evaluate inverse erf using the rational approximation: - // - // x = p(p+10)(Y+R(p)) - // - // Where Y is a constant, and R(p) is optimised for a low - // absolute error compared to |Y|. - // - // double: Max error found: 2.001849e-18 - // long double: Max error found: 1.017064e-20 - // Maximum Deviation Found (actual error term at infinite precision) 8.030e-21 - // - static const float Y = 0.0891314744949340820313f; - static const double P[] = { - -0.000508781949658280665617, - -0.00836874819741736770379, - 0.0334806625409744615033, - -0.0126926147662974029034, - -0.0365637971411762664006, - 0.0219878681111168899165, - 0.00822687874676915743155, - -0.00538772965071242932965 - }; - static const double Q[] = { - 1.0, - -0.970005043303290640362, - -1.56574558234175846809, - 1.56221558398423026363, - 0.662328840472002992063, - -0.71228902341542847553, - -0.0527396382340099713954, - 0.0795283687341571680018, - -0.00233393759374190016776, - 0.000886216390456424707504 - }; - double g = p * (p + 10); - double r = evaluate_polynomial(P, p) / evaluate_polynomial(Q, p); - result = g * Y + g * r; - } - else if(q >= 0.25) - { - // - // Rational approximation for 0.5 > q >= 0.25 - // - // x = sqrt(-2*log(q)) / (Y + R(q)) - // - // Where Y is a constant, and R(q) is optimised for a low - // absolute error compared to Y. 
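evaluate_polynomial() above is reindented, not changed: it is plain Horner evaluation, and erfinv_imp() feeds it the P/Q coefficient tables to form each rational term R = P(x) / Q(x). A minimal sketch of the same scheme on a made-up cubic:

#include <cstddef>
#include <cstdio>

// Horner's rule: poly[0] + poly[1]*z + ... + poly[count-1]*z^(count-1)
static double evaluate_polynomial(const double* poly, double z, std::size_t count) {
    double sum = poly[count - 1];
    for (int i = int(count) - 2; i >= 0; --i)
        sum = sum * z + poly[i];
    return sum;
}

int main() {
    // p(z) = 1 + 2z + 3z^2 + 4z^3; at z = 2 this is 1 + 4 + 12 + 32 = 49
    const double p[] = {1, 2, 3, 4};
    std::printf("p(2) = %g\n", evaluate_polynomial(p, 2.0, 4));
    // a rational approximation then takes two such tables: R(z) = P(z) / Q(z)
}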
- // - // double : Max error found: 7.403372e-17 - // long double : Max error found: 6.084616e-20 - // Maximum Deviation Found (error term) 4.811e-20 - // - static const float Y = 2.249481201171875f; - static const double P[] = { - -0.202433508355938759655, - 0.105264680699391713268, - 8.37050328343119927838, - 17.6447298408374015486, - -18.8510648058714251895, - -44.6382324441786960818, - 17.445385985570866523, - 21.1294655448340526258, - -3.67192254707729348546 - }; - static const double Q[] = { - 1.0, - 6.24264124854247537712, - 3.9713437953343869095, - -28.6608180499800029974, - -20.1432634680485188801, - 48.5609213108739935468, - 10.8268667355460159008, - -22.6436933413139721736, - 1.72114765761200282724 - }; - double g = sqrt(-2 * log(q)); - double xs = q - 0.25f; - double r = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); - result = g / (Y + r); - } - else - { - // - // For q < 0.25 we have a series of rational approximations all - // of the general form: - // - // let: x = sqrt(-log(q)) - // - // Then the result is given by: - // - // x(Y+R(x-B)) - // - // where Y is a constant, B is the lowest value of x for which - // the approximation is valid, and R(x-B) is optimised for a low - // absolute error compared to Y. - // - // Note that almost all code will really go through the first - // or maybe second approximation. After than we're dealing with very - // small input values indeed: 80 and 128 bit long double's go all the - // way down to ~ 1e-5000 so the "tail" is rather long... - // - double x = sqrt(-log(q)); - if(x < 3) - { - // Max error found: 1.089051e-20 - static const float Y = 0.807220458984375f; - static const double P[] = { - -0.131102781679951906451, - -0.163794047193317060787, - 0.117030156341995252019, - 0.387079738972604337464, - 0.337785538912035898924, - 0.142869534408157156766, - 0.0290157910005329060432, - 0.00214558995388805277169, - -0.679465575181126350155e-6, - 0.285225331782217055858e-7, - -0.681149956853776992068e-9 - }; - static const double Q[] = { - 1.0, - 3.46625407242567245975, - 5.38168345707006855425, - 4.77846592945843778382, - 2.59301921623620271374, - 0.848854343457902036425, - 0.152264338295331783612, - 0.01105924229346489121 - }; - double xs = x - 1.125f; - double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); - result = Y * x + R * x; - } - else if(x < 6) - { - // Max error found: 8.389174e-21 - static const float Y = 0.93995571136474609375f; - static const double P[] = { - -0.0350353787183177984712, - -0.00222426529213447927281, - 0.0185573306514231072324, - 0.00950804701325919603619, - 0.00187123492819559223345, - 0.000157544617424960554631, - 0.460469890584317994083e-5, - -0.230404776911882601748e-9, - 0.266339227425782031962e-11 - }; - static const double Q[] = { - 1.0, - 1.3653349817554063097, - 0.762059164553623404043, - 0.220091105764131249824, - 0.0341589143670947727934, - 0.00263861676657015992959, - 0.764675292302794483503e-4 - }; - double xs = x - 3; - double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); - result = Y * x + R * x; - } - else if(x < 18) - { - // Max error found: 1.481312e-19 - static const float Y = 0.98362827301025390625f; - static const double P[] = { - -0.0167431005076633737133, - -0.00112951438745580278863, - 0.00105628862152492910091, - 0.000209386317487588078668, - 0.149624783758342370182e-4, - 0.449696789927706453732e-6, - 0.462596163522878599135e-8, - -0.281128735628831791805e-13, - 0.99055709973310326855e-16 - }; - static const double Q[] = { - 1.0, - 
0.591429344886417493481, - 0.138151865749083321638, - 0.0160746087093676504695, - 0.000964011807005165528527, - 0.275335474764726041141e-4, - 0.282243172016108031869e-6 - }; - double xs = x - 6; - double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); - result = Y * x + R * x; - } - else if(x < 44) - { - // Max error found: 5.697761e-20 - static const float Y = 0.99714565277099609375f; - static const double P[] = { - -0.0024978212791898131227, - -0.779190719229053954292e-5, - 0.254723037413027451751e-4, - 0.162397777342510920873e-5, - 0.396341011304801168516e-7, - 0.411632831190944208473e-9, - 0.145596286718675035587e-11, - -0.116765012397184275695e-17 - }; - static const double Q[] = { - 1.0, - 0.207123112214422517181, - 0.0169410838120975906478, - 0.000690538265622684595676, - 0.145007359818232637924e-4, - 0.144437756628144157666e-6, - 0.509761276599778486139e-9 - }; - double xs = x - 18; - double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); - result = Y * x + R * x; - } - else - { - // Max error found: 1.279746e-20 - static const float Y = 0.99941349029541015625f; - static const double P[] = { - -0.000539042911019078575891, - -0.28398759004727721098e-6, - 0.899465114892291446442e-6, - 0.229345859265920864296e-7, - 0.225561444863500149219e-9, - 0.947846627503022684216e-12, - 0.135880130108924861008e-14, - -0.348890393399948882918e-21 - }; - static const double Q[] = { - 1.0, - 0.0845746234001899436914, - 0.00282092984726264681981, - 0.468292921940894236786e-4, - 0.399968812193862100054e-6, - 0.161809290887904476097e-8, - 0.231558608310259605225e-11 - }; - double xs = x - 44; - double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); - result = Y * x + R * x; - } - } - return result; + if (p <= 0.5) { + // + // Evaluate inverse erf using the rational approximation: + // + // x = p(p+10)(Y+R(p)) + // + // Where Y is a constant, and R(p) is optimised for a low + // absolute error compared to |Y|. + // + // double: Max error found: 2.001849e-18 + // long double: Max error found: 1.017064e-20 + // Maximum Deviation Found (actual error term at infinite precision) 8.030e-21 + // + static const float Y = 0.0891314744949340820313f; + static const double P[] = { + -0.000508781949658280665617, -0.00836874819741736770379, + 0.0334806625409744615033, -0.0126926147662974029034, + -0.0365637971411762664006, 0.0219878681111168899165, + 0.00822687874676915743155, -0.00538772965071242932965}; + static const double Q[] = { + 1.0, + -0.970005043303290640362, + -1.56574558234175846809, + 1.56221558398423026363, + 0.662328840472002992063, + -0.71228902341542847553, + -0.0527396382340099713954, + 0.0795283687341571680018, + -0.00233393759374190016776, + 0.000886216390456424707504}; + double g = p * (p + 10); + double r = evaluate_polynomial(P, p) / evaluate_polynomial(Q, p); + result = g * Y + g * r; + } else if (q >= 0.25) { + // + // Rational approximation for 0.5 > q >= 0.25 + // + // x = sqrt(-2*log(q)) / (Y + R(q)) + // + // Where Y is a constant, and R(q) is optimised for a low + // absolute error compared to Y. 
+ // + // double : Max error found: 7.403372e-17 + // long double : Max error found: 6.084616e-20 + // Maximum Deviation Found (error term) 4.811e-20 + // + static const float Y = 2.249481201171875f; + static const double P[] = {-0.202433508355938759655, 0.105264680699391713268, + 8.37050328343119927838, 17.6447298408374015486, + -18.8510648058714251895, -44.6382324441786960818, + 17.445385985570866523, 21.1294655448340526258, + -3.67192254707729348546}; + static const double Q[] = { + 1.0, + 6.24264124854247537712, + 3.9713437953343869095, + -28.6608180499800029974, + -20.1432634680485188801, + 48.5609213108739935468, + 10.8268667355460159008, + -22.6436933413139721736, + 1.72114765761200282724}; + double g = sqrt(-2 * log(q)); + double xs = q - 0.25f; + double r = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); + result = g / (Y + r); + } else { + // + // For q < 0.25 we have a series of rational approximations all + // of the general form: + // + // let: x = sqrt(-log(q)) + // + // Then the result is given by: + // + // x(Y+R(x-B)) + // + // where Y is a constant, B is the lowest value of x for which + // the approximation is valid, and R(x-B) is optimised for a low + // absolute error compared to Y. + // + // Note that almost all code will really go through the first + // or maybe second approximation. After than we're dealing with very + // small input values indeed: 80 and 128 bit long double's go all the + // way down to ~ 1e-5000 so the "tail" is rather long... + // + double x = sqrt(-log(q)); + if (x < 3) { + // Max error found: 1.089051e-20 + static const float Y = 0.807220458984375f; + static const double P[] = { + -0.131102781679951906451, -0.163794047193317060787, + 0.117030156341995252019, 0.387079738972604337464, + 0.337785538912035898924, 0.142869534408157156766, + 0.0290157910005329060432, 0.00214558995388805277169, + -0.679465575181126350155e-6, 0.285225331782217055858e-7, + -0.681149956853776992068e-9}; + static const double Q[] = { + 1.0, + 3.46625407242567245975, + 5.38168345707006855425, + 4.77846592945843778382, + 2.59301921623620271374, + 0.848854343457902036425, + 0.152264338295331783612, + 0.01105924229346489121}; + double xs = x - 1.125f; + double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); + result = Y * x + R * x; + } else if (x < 6) { + // Max error found: 8.389174e-21 + static const float Y = 0.93995571136474609375f; + static const double P[] = { + -0.0350353787183177984712, -0.00222426529213447927281, + 0.0185573306514231072324, 0.00950804701325919603619, + 0.00187123492819559223345, 0.000157544617424960554631, + 0.460469890584317994083e-5, -0.230404776911882601748e-9, + 0.266339227425782031962e-11}; + static const double Q[] = { + 1.0, + 1.3653349817554063097, + 0.762059164553623404043, + 0.220091105764131249824, + 0.0341589143670947727934, + 0.00263861676657015992959, + 0.764675292302794483503e-4}; + double xs = x - 3; + double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); + result = Y * x + R * x; + } else if (x < 18) { + // Max error found: 1.481312e-19 + static const float Y = 0.98362827301025390625f; + static const double P[] = { + -0.0167431005076633737133, -0.00112951438745580278863, + 0.00105628862152492910091, 0.000209386317487588078668, + 0.149624783758342370182e-4, 0.449696789927706453732e-6, + 0.462596163522878599135e-8, -0.281128735628831791805e-13, + 0.99055709973310326855e-16}; + static const double Q[] = { + 1.0, + 0.591429344886417493481, + 0.138151865749083321638, + 0.0160746087093676504695, + 
0.000964011807005165528527, + 0.275335474764726041141e-4, + 0.282243172016108031869e-6}; + double xs = x - 6; + double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); + result = Y * x + R * x; + } else if (x < 44) { + // Max error found: 5.697761e-20 + static const float Y = 0.99714565277099609375f; + static const double P[] = { + -0.0024978212791898131227, -0.779190719229053954292e-5, + 0.254723037413027451751e-4, 0.162397777342510920873e-5, + 0.396341011304801168516e-7, 0.411632831190944208473e-9, + 0.145596286718675035587e-11, -0.116765012397184275695e-17}; + static const double Q[] = { + 1.0, + 0.207123112214422517181, + 0.0169410838120975906478, + 0.000690538265622684595676, + 0.145007359818232637924e-4, + 0.144437756628144157666e-6, + 0.509761276599778486139e-9}; + double xs = x - 18; + double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); + result = Y * x + R * x; + } else { + // Max error found: 1.279746e-20 + static const float Y = 0.99941349029541015625f; + static const double P[] = { + -0.000539042911019078575891, -0.28398759004727721098e-6, + 0.899465114892291446442e-6, 0.229345859265920864296e-7, + 0.225561444863500149219e-9, 0.947846627503022684216e-12, + 0.135880130108924861008e-14, -0.348890393399948882918e-21}; + static const double Q[] = { + 1.0, + 0.0845746234001899436914, + 0.00282092984726264681981, + 0.468292921940894236786e-4, + 0.399968812193862100054e-6, + 0.161809290887904476097e-8, + 0.231558608310259605225e-11}; + double xs = x - 44; + double R = evaluate_polynomial(P, xs) / evaluate_polynomial(Q, xs); + result = Y * x + R * x; + } + } + return result; } -inline double erfcinv(double z) -{ - // - // Begin by testing for domain errors, and other special cases: - // - if((z < 0) || (z > 2)) - return NAN; - if(z == 0) - return INFINITY; - if(z == 2) - return -INFINITY; - // - // Normalise the input, so it's in the range [0,1], we will - // negate the result if z is outside that range. This is a simple - // application of the erfc reflection formula: erfc(-z) = 2 - erfc(z) - // - double p, q, s; - if(z > 1) - { - q = 2 - z; - p = 1 - q; - s = -1; - } - else - { - p = 1 - z; - q = z; - s = 1; - } +inline double erfcinv(double z) { + // + // Begin by testing for domain errors, and other special cases: + // + if ((z < 0) || (z > 2)) + return NAN; + if (z == 0) + return INFINITY; + if (z == 2) + return -INFINITY; + // + // Normalise the input, so it's in the range [0,1], we will + // negate the result if z is outside that range. This is a simple + // application of the erfc reflection formula: erfc(-z) = 2 - erfc(z) + // + double p, q, s; + if (z > 1) { + q = 2 - z; + p = 1 - q; + s = -1; + } else { + p = 1 - z; + q = z; + s = 1; + } - // - // And get the result, negating where required: - // - return s * erfinv_imp(p, q); + // + // And get the result, negating where required: + // + return s * erfinv_imp(p, q); } -inline double erfinv(double z) -{ - // - // Begin by testing for domain errors, and other special cases: - // - if((z < -1) || (z > 1)) - return NAN; - if(z == 1) - return INFINITY; - if(z == -1) - return -INFINITY; - if(z == 0) - return 0; - // - // Normalise the input, so it's in the range [0,1], we will - // negate the result if z is outside that range. 
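The branches being reflowed here are Boost's piecewise rational approximations for the inverse error function. An easy way to sanity-check any such approximation is to invert std::erf directly with Newton's method, using d/dx erf(x) = 2/sqrt(pi) * exp(-x^2); this is a host-side cross-check only, not part of megdnn:

#include <cmath>
#include <cstdio>

// invert erf by Newton iteration: x <- x - (erf(x) - z) / erf'(x)
static double erfinv_newton(double z) {
    const double two_over_sqrt_pi = 1.1283791670955126;  // 2 / sqrt(pi)
    double x = 0.0;
    for (int i = 0; i < 50; ++i) {
        const double err = std::erf(x) - z;
        const double deriv = two_over_sqrt_pi * std::exp(-x * x);
        x -= err / deriv;
        if (std::fabs(err) < 1e-15)
            break;
    }
    return x;
}

int main() {
    const double zs[] = {-0.9, -0.5, 0.1, 0.7, 0.999};
    for (double z : zs) {
        const double x = erfinv_newton(z);
        std::printf("erfinv(%6.3f) ~ %.12f  (erf back: %.12f)\n", z, x, std::erf(x));
    }
}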
This is a simple - // application of the erf reflection formula: erf(-z) = -erf(z) - // - double p, q, s; - if(z < 0) - { - p = -z; - q = 1 - p; - s = -1; - } - else - { - p = z; - q = 1 - z; - s = 1; - } +inline double erfinv(double z) { + // + // Begin by testing for domain errors, and other special cases: + // + if ((z < -1) || (z > 1)) + return NAN; + if (z == 1) + return INFINITY; + if (z == -1) + return -INFINITY; + if (z == 0) + return 0; + // + // Normalise the input, so it's in the range [0,1], we will + // negate the result if z is outside that range. This is a simple + // application of the erf reflection formula: erf(-z) = -erf(z) + // + double p, q, s; + if (z < 0) { + p = -z; + q = 1 - p; + s = -1; + } else { + p = z; + q = 1 - z; + s = 1; + } - // - // And get the result, negating where required: - // - return s * erfinv_imp(p, q); + // + // And get the result, negating where required: + // + return s * erfinv_imp(p, q); } inline float erfcinvf(float z) { @@ -412,6 +342,6 @@ inline float erfinvf(float z) { return erfinv(z); } -#endif // ifndef __CUDACC__ +#endif // ifndef __CUDACC__ // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index f0a590ff..bde0341d 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -11,10 +11,10 @@ #pragma once -#include "src/common/opr_param_defs_enumv.cuh" +#include "src/common/elemwise/erfinv.h" #include "src/common/elemwise_helper.cuh" +#include "src/common/opr_param_defs_enumv.cuh" #include "src/common/utils.cuh" -#include "src/common/elemwise/erfinv.h" #include "megcore_cdefs.h" #include "megdnn/dtype.h" @@ -30,7 +30,7 @@ using std::min; #ifndef MEGDNN_ELEMWISE_MODE_ENABLE #define MEGDNN_ELEMWISE_MODE_ENABLE(_mode, _cb) _cb(_mode) -#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL 1 +#define MEGDNN_ELEMWISE_MODE_ENABLE_ALL 1 #endif #if MEGDNN_CC_HOST && !defined(__host__) @@ -41,217 +41,214 @@ using std::min; namespace megdnn { - - template - __device__ __host__ inline T log_sum_exp(T x, T y) { - T a, b; - a = x < y ? x : y; - b = x < y ? y : x; - return T(b + log1pf(exp(a - b))); - } - - __device__ __host__ inline float fast_tanh(float x) { - return x * (27.f + x * x) / (27.f + 9.f * x * x); - } - - //! use multiplying (1.f / 6.f) to replace dividing 6.f, because we didn't - //! pass - //! --use_fast_math to nvcc to enable --prec_div optimization, which will - //! cause performance drop on Turing architecture - __device__ __host__ inline float fuse_add_hswish(float x, float y) { - float z = x + y; - return z * min(max(z + 3, 0.f), 6.f) * (1.f / 6.f); - } - - __device__ __host__ inline float fast_tanh_grad(float x, float dx) { - float x_pow2 = x * x; - float deno = 3.f + x_pow2; - return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx; - } - - //! grad of silu - __device__ __host__ inline float silu_grad(float x, float dy) { - const float one = 1.0; - float sigmoid = one / (one + expf(-x)); - return dy * sigmoid * (one + x * (one - sigmoid)); - } - - __device__ __host__ inline float normcdf(float x) { +template +__device__ __host__ inline T log_sum_exp(T x, T y) { + T a, b; + a = x < y ? x : y; + b = x < y ? y : x; + return T(b + log1pf(exp(a - b))); +} + +__device__ __host__ inline float fast_tanh(float x) { + return x * (27.f + x * x) / (27.f + 9.f * x * x); +} + +//! use multiplying (1.f / 6.f) to replace dividing 6.f, because we didn't +//! pass +//! 
--use_fast_math to nvcc to enable --prec_div optimization, which will +//! cause performance drop on Turing architecture +__device__ __host__ inline float fuse_add_hswish(float x, float y) { + float z = x + y; + return z * min(max(z + 3, 0.f), 6.f) * (1.f / 6.f); +} + +__device__ __host__ inline float fast_tanh_grad(float x, float dx) { + float x_pow2 = x * x; + float deno = 3.f + x_pow2; + return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * dx; +} + +//! grad of silu +__device__ __host__ inline float silu_grad(float x, float dy) { + const float one = 1.0; + float sigmoid = one / (one + expf(-x)); + return dy * sigmoid * (one + x * (one - sigmoid)); +} + +__device__ __host__ inline float normcdf(float x) { #if MEGDNN_CC_HOST - return 0.5f * (1.f + erff(x / sqrtf(2.f))); + return 0.5f * (1.f + erff(x / sqrtf(2.f))); #else - //! use cuda build-in math - return ::normcdff(x); + //! use cuda build-in math + return ::normcdff(x); #endif - } +} - //! grad of gelu - __device__ __host__ inline float gelu_grad(float x, float dy) { - //! 1/ sqrt(2 * pi) - const float coeff = 0.3989422804014327f; - float phi = coeff * expf(-0.5f * x * x); - float normcdf_v = normcdf(x); - return dy * (normcdf_v + x * phi); - } +//! grad of gelu +__device__ __host__ inline float gelu_grad(float x, float dy) { + //! 1/ sqrt(2 * pi) + const float coeff = 0.3989422804014327f; + float phi = coeff * expf(-0.5f * x * x); + float normcdf_v = normcdf(x); + return dy * (normcdf_v + x * phi); +} #include "src/common/elemwise/each_mode.inl" - template - struct ElemwiseKern; +template +struct ElemwiseKern; //! define kernel for a single ctype -#define DEF_KERN(_ctype, _mode, _imp) \ - template \ - struct ElemwiseKern { \ - typedef _ctype ctype; \ - static __host__ __device__ _ctype apply(KERN_SIG) { \ - return ctype(_imp); \ - } \ +#define DEF_KERN(_ctype, _mode, _imp) \ + template \ + struct ElemwiseKern { \ + typedef _ctype ctype; \ + static __host__ __device__ _ctype apply(KERN_SIG) { return ctype(_imp); } \ } //! define kernel for all float types -#define DEF_KERN_FLOAT(_mode, _imp) \ - DEF_KERN(dt_float32, _mode, _imp); \ +#define DEF_KERN_FLOAT(_mode, _imp) \ + DEF_KERN(dt_float32, _mode, _imp); \ DNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \ DNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);) //! define kernel for all int types -#define DEF_KERN_INT(_mode, _imp) \ +#define DEF_KERN_INT(_mode, _imp) \ DEF_KERN(dt_int32, _mode, _imp); \ DEF_KERN(dt_int16, _mode, _imp); \ - DEF_KERN(dt_int8, _mode, _imp); \ - DEF_KERN(dt_uint8, _mode, _imp); \ + DEF_KERN(dt_int8, _mode, _imp); \ + DEF_KERN(dt_uint8, _mode, _imp); //! define kernel for all ctypes #define DEF_KERN_ALL(_mode, _imp) \ - DEF_KERN_INT(_mode, _imp); \ - DEF_KERN_FLOAT(_mode, _imp); \ + DEF_KERN_INT(_mode, _imp); \ + DEF_KERN_FLOAT(_mode, _imp); - /* ================== unary kernels ================== */ +/* ================== unary kernels ================== */ #define KERN_SIG ctype x - // int and float - DEF_KERN_ALL(NEGATE, -x); +// int and float +DEF_KERN_ALL(NEGATE, -x); #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) - DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x); - DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x); +DEF_KERN_INT(RELU, x <= ctype(0) ? ctype(0) : x); +DEF_KERN_FLOAT(RELU, x <= 0.f ? ctype(0) : x); #else - DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x); +DEF_KERN_ALL(RELU, x <= ctype(0) ? ctype(0) : x); #endif - DEF_KERN_INT(ABS, abs(int(x))); - // DEF_KERN_INT(ABS, x > ctype(0) ? 
x : -x); - DEF_KERN_FLOAT(ABS, fabsf(x)); - - // float only - DEF_KERN_FLOAT(ACOS, acosf(x)); - DEF_KERN_FLOAT(ASIN, asinf(x)); - DEF_KERN_FLOAT(CEIL, ceilf(x)); - DEF_KERN_FLOAT(COS, cosf(x)); - DEF_KERN_FLOAT(EXP, expf(x)); - DEF_KERN_FLOAT(EXPM1, expm1f(x)); - DEF_KERN_FLOAT(FLOOR, floorf(x)); - DEF_KERN_FLOAT(LOG, logf(x)); - DEF_KERN_FLOAT(LOG1P, log1pf(x)); - DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f)); - DEF_KERN_FLOAT(SIN, sinf(x)); - DEF_KERN_FLOAT(TANH, tanhf(x)); - DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x)); - DEF_KERN_FLOAT(ROUND, roundf(x)); - DEF_KERN_FLOAT(ERF, erff(x)); - DEF_KERN_FLOAT(ERFINV, erfinvf(x)); - DEF_KERN_FLOAT(ERFC, erfcf(x)); - DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); - DEF_KERN_FLOAT(H_SWISH, x * min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); - DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); - DEF_KERN_FLOAT(GELU, x * normcdf(x)); - - // int only - DEF_KERN(dt_bool, NOT, x ^ 1); +DEF_KERN_INT(ABS, abs(int(x))); +// DEF_KERN_INT(ABS, x > ctype(0) ? x : -x); +DEF_KERN_FLOAT(ABS, fabsf(x)); + +// float only +DEF_KERN_FLOAT(ACOS, acosf(x)); +DEF_KERN_FLOAT(ASIN, asinf(x)); +DEF_KERN_FLOAT(CEIL, ceilf(x)); +DEF_KERN_FLOAT(COS, cosf(x)); +DEF_KERN_FLOAT(EXP, expf(x)); +DEF_KERN_FLOAT(EXPM1, expm1f(x)); +DEF_KERN_FLOAT(FLOOR, floorf(x)); +DEF_KERN_FLOAT(LOG, logf(x)); +DEF_KERN_FLOAT(LOG1P, log1pf(x)); +DEF_KERN_FLOAT(SIGMOID, 1.f / (expf(-x) + 1.f)); +DEF_KERN_FLOAT(SIN, sinf(x)); +DEF_KERN_FLOAT(TANH, tanhf(x)); +DEF_KERN_FLOAT(FAST_TANH, fast_tanh(x)); +DEF_KERN_FLOAT(ROUND, roundf(x)); +DEF_KERN_FLOAT(ERF, erff(x)); +DEF_KERN_FLOAT(ERFINV, erfinvf(x)); +DEF_KERN_FLOAT(ERFC, erfcf(x)); +DEF_KERN_FLOAT(ERFCINV, erfcinvf(x)); +DEF_KERN_FLOAT(H_SWISH, x* min(max(x + 3, 0.f), 6.f) * (1.f / 6.f)); +DEF_KERN_FLOAT(SILU, x / (expf(-x) + 1.f)); +DEF_KERN_FLOAT(GELU, x* normcdf(x)); + +// int only +DEF_KERN(dt_bool, NOT, x ^ 1); #undef KERN_SIG - /* ================== binary kernels ================== */ +/* ================== binary kernels ================== */ #define KERN_SIG ctype x, ctype y - // int and float +// int and float #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) - DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y); - DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y); +DEF_KERN_INT(ABS_GRAD, x > ctype(0) ? y : -y); +DEF_KERN_FLOAT(ABS_GRAD, x > 0.f ? y : -y); #else - DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y); +DEF_KERN_ALL(ABS_GRAD, x > ctype(0) ? y : -y); #endif - DEF_KERN_ALL(ADD, x + y); - DEF_KERN_ALL(MAX, x > y ? x : y); - DEF_KERN_ALL(MIN, x < y ? x : y); - DEF_KERN_ALL(MUL, x* y); - DEF_KERN(dt_bool, AND, x && y); - DEF_KERN(dt_bool, OR, x || y); - DEF_KERN(dt_bool, XOR, x ^ y); - DEF_KERN_INT(RMULH, round_mulh_saturate(x, y)); - DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y); - DEF_KERN_ALL(SUB, x - y); +DEF_KERN_ALL(ADD, x + y); +DEF_KERN_ALL(MAX, x > y ? x : y); +DEF_KERN_ALL(MIN, x < y ? x : y); +DEF_KERN_ALL(MUL, x* y); +DEF_KERN(dt_bool, AND, x&& y); +DEF_KERN(dt_bool, OR, x || y); +DEF_KERN(dt_bool, XOR, x ^ y); +DEF_KERN_INT(RMULH, round_mulh_saturate(x, y)); +DEF_KERN_ALL(SIGMOID_GRAD, x*(ctype(1) - x) * y); +DEF_KERN_ALL(SUB, x - y); #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) - DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0)); - DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0)); +DEF_KERN_INT(SWITCH_GT0, x > ctype(0) ? y : ctype(0)); +DEF_KERN_FLOAT(SWITCH_GT0, x > 0.f ? y : ctype(0)); #else - DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? y : ctype(0)); +DEF_KERN_ALL(SWITCH_GT0, x > ctype(0) ? 
y : ctype(0)); #endif - DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y); - DEF_KERN_ALL(LT, x < y); - DEF_KERN_ALL(LEQ, x <= y); - DEF_KERN_ALL(EQ, x == y); - DEF_KERN(dt_bool, LT, x < y); - DEF_KERN(dt_bool, LEQ, x <= y); - DEF_KERN(dt_bool, EQ, x == y); - - DEF_KERN_INT(FLOOR_DIV, x / y); - DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); - - DEF_KERN_INT(MOD, x % y); - DEF_KERN_FLOAT(MOD, fmodf(x, y)); - - DEF_KERN_INT(SHL, x << y); - DEF_KERN_INT(SHR, x >> y); +DEF_KERN_ALL(TANH_GRAD, (ctype(1) - x * x) * y); +DEF_KERN_ALL(LT, x < y); +DEF_KERN_ALL(LEQ, x <= y); +DEF_KERN_ALL(EQ, x == y); +DEF_KERN(dt_bool, LT, x < y); +DEF_KERN(dt_bool, LEQ, x <= y); +DEF_KERN(dt_bool, EQ, x == y); + +DEF_KERN_INT(FLOOR_DIV, x / y); +DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); + +DEF_KERN_INT(MOD, x % y); +DEF_KERN_FLOAT(MOD, fmodf(x, y)); + +DEF_KERN_INT(SHL, x << y); +DEF_KERN_INT(SHR, x >> y); #if defined(__HIP_PLATFORM_HCC__) && !defined(__HIP_PLATFORM_NVCC__) - DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); - DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y)); +DEF_KERN_INT(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); +DEF_KERN_FLOAT(FUSE_ADD_RELU, (x + y) <= 0.f ? ctype(0) : (x + y)); #else - DEF_KERN_ALL(FUSE_ADD_RELU, - (x + y) <= ctype(0) ? ctype(0) : (x + y)); +DEF_KERN_ALL(FUSE_ADD_RELU, (x + y) <= ctype(0) ? ctype(0) : (x + y)); #endif - // float only - DEF_KERN_FLOAT(TRUE_DIV, x / y); - DEF_KERN_FLOAT(POW, powf(x, y)); - DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y)); - DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y)); +// float only +DEF_KERN_FLOAT(TRUE_DIV, x / y); +DEF_KERN_FLOAT(POW, powf(x, y)); +DEF_KERN_FLOAT(LOG_SUM_EXP, log_sum_exp(x, y)); +DEF_KERN_FLOAT(FAST_TANH_GRAD, fast_tanh_grad(x, y)); - DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x+y)); - DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x+y)) + 1.f)); +DEF_KERN_FLOAT(FUSE_ADD_TANH, tanhf(x + y)); +DEF_KERN_FLOAT(FUSE_ADD_SIGMOID, 1.f / (expf(-(x + y)) + 1.f)); - DEF_KERN_FLOAT(ATAN2, atan2f(x, y)); - DEF_KERN_FLOAT(H_SWISH_GRAD, - x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y))); +DEF_KERN_FLOAT(ATAN2, atan2f(x, y)); +DEF_KERN_FLOAT( + H_SWISH_GRAD, + x < -3.f ? (ctype)0.f + : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y))); - DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); - DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); - DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); +DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); +DEF_KERN_FLOAT(SILU_GRAD, silu_grad(x, y)); +DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y)); #undef KERN_SIG - /* ================== ternary kernels ================== */ +/* ================== ternary kernels ================== */ #define KERN_SIG ctype x, ctype y, ctype z - // int and float - DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? z : ctype(0)); - DEF_KERN_ALL(FUSE_MUL_ADD3, x * y + z); +// int and float +DEF_KERN_ALL(COND_LEQ_MOV, x <= y ? 
z : ctype(0)); +DEF_KERN_ALL(FUSE_MUL_ADD3, x* y + z); #undef KERN_SIG - #undef DEF_KERN_AD #undef DEF_KERN -} // namespace megdnn +} // namespace megdnn #if MEGDNN_CC_HOST && defined(MEGDNN_HOST_DEVICE_SELF_DEFINE) #undef MEGDNN_HOST_DEVICE_SELF_DEFINE diff --git a/dnn/src/common/elemwise/opr_impl.cpp b/dnn/src/common/elemwise/opr_impl.cpp index 96eb820d..3850190c 100644 --- a/dnn/src/common/elemwise/opr_impl.cpp +++ b/dnn/src/common/elemwise/opr_impl.cpp @@ -74,7 +74,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { #define cb(_m) \ MIDOUT_BEGIN(megdnn_common_elemwise, midout_iv(Mode::_m)) { \ - get(Mode::_m).allow_bool = true; \ + get(Mode::_m).allow_bool = true; \ } \ MIDOUT_END(); MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_BOOL(cb); @@ -141,8 +141,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { #if MEGDNN_ELEMWISE_MODE_ENABLE_ALL for (auto&& i : traits) { - megdnn_assert(i.arity && (i.allow_int || i.allow_float || i.allow_bool) && - (!i.commutable || i.arity == 2)); + megdnn_assert( + i.arity && (i.allow_int || i.allow_float || i.allow_bool) && + (!i.commutable || i.arity == 2)); } #else #pragma message "elemwise mode stripped" @@ -156,8 +157,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { return ret; } -void ElemwiseForward::deduce_shape(const TensorShapeArray& src, - TensorShape& dst) { +void ElemwiseForward::deduce_shape(const TensorShapeArray& src, TensorShape& dst) { auto err = [&]() { std::string msg("bad input shape for polyadic operator: "); bool first = true; @@ -189,8 +189,7 @@ void ElemwiseForward::deduce_shape(const TensorShapeArray& src, err(); } int final_idx = std::max(cur_idx, dst_idx); - dst.shape[final_idx] = - (v0 != 0 && v1 != 0) ? std::max(v0, v1) : 0; + dst.shape[final_idx] = (v0 != 0 && v1 != 0) ? std::max(v0, v1) : 0; } else { if (dst_idx < 0) { dst.shape[cur_idx] = cur.shape[cur_idx]; @@ -211,15 +210,13 @@ void FormatDeducer::feed(TensorFormat cur) { if (m_result == m_default) { m_result = cur; } else { - megdnn_assert(m_result == cur, - "different input layout formats in elemwise: %s vs %s", - m_result.impl()->to_string().c_str(), - cur.impl()->to_string().c_str()); + megdnn_assert( + m_result == cur, "different input layout formats in elemwise: %s vs %s", + m_result.impl()->to_string().c_str(), cur.impl()->to_string().c_str()); } } -void ElemwiseForward::deduce_format(const TensorFormatArray& src, - TensorFormat& dst) { +void ElemwiseForward::deduce_format(const TensorFormatArray& src, TensorFormat& dst) { FormatDeducer d; for (auto i : src) { d.feed(i); @@ -227,8 +224,7 @@ void ElemwiseForward::deduce_format(const TensorFormatArray& src, dst = d.get(); } -void ElemwiseForward::deduce_layout(const TensorLayoutArray& src, - TensorLayout& dst) { +void ElemwiseForward::deduce_layout(const TensorLayoutArray& src, TensorLayout& dst) { megdnn_assert(src.size() == mode_trait().arity); DType dtype; FormatDeducer format_deducer; @@ -237,9 +233,9 @@ void ElemwiseForward::deduce_layout(const TensorLayoutArray& src, dtype = i.dtype; dst.format = i.format; } else { - megdnn_assert(dtype == i.dtype, - "input dtype not unique: get %s and %s", dtype.name(), - i.dtype.name()); + megdnn_assert( + dtype == i.dtype, "input dtype not unique: get %s and %s", + dtype.name(), i.dtype.name()); } format_deducer.feed(i.format); @@ -286,16 +282,14 @@ void ElemwiseForward::check_dtype(DType dtype) { auto&& trait = mode_trait(); switch (dtype.category()) { case DTypeCategory::FLOAT: - megdnn_assert(trait.allow_float, "unsupport mode %s for float\n", - trait.name); + 
megdnn_assert( + trait.allow_float, "unsupport mode %s for float\n", trait.name); break; case DTypeCategory::INT: - megdnn_assert(trait.allow_int, "unsupport mode %s for int\n", - trait.name); + megdnn_assert(trait.allow_int, "unsupport mode %s for int\n", trait.name); break; case DTypeCategory::BOOL: - megdnn_assert(trait.allow_bool, "unsupport mode %s for bool\n", - trait.name); + megdnn_assert(trait.allow_bool, "unsupport mode %s for bool\n", trait.name); break; default: megdnn_throw("bad dtype"); diff --git a/dnn/src/common/elemwise/opr_impl_body.inl b/dnn/src/common/elemwise/opr_impl_body.inl index 1eb1a945..20d98e21 100644 --- a/dnn/src/common/elemwise/opr_impl_body.inl +++ b/dnn/src/common/elemwise/opr_impl_body.inl @@ -13,16 +13,15 @@ #error "on_arity_dispatched_cb_dtype and IMPL_MODE_DISPATCHER must be defined" #endif -template +template void ElemwiseForwardImpl::on_arity_dispatched() { auto src = make_elemwise_op_param(); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) MEGDNN_FOREACH_COMPUTING_DTYPE_INT(on_arity_dispatched_cb_dtype) - on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool) - megdnn_throw("bad dtype"); + on_arity_dispatched_cb_dtype(::megdnn::dtype::Bool) megdnn_throw("bad dtype"); } -template +template void ElemwiseForwardImpl::on_arity_dispatched_no_bool() { auto src = make_elemwise_op_param(); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(on_arity_dispatched_cb_dtype) @@ -62,60 +61,60 @@ IMPL_MODE_DISPATCHER(1, DTypeCategory::BOOL); IMPL_MODE_DISPATCHER(2, DTypeCategory::BOOL); #undef FOREACH -void ElemwiseForwardImpl::exec( - const TensorNDArray &src, - _megdnn_tensor_out dst) { +void ElemwiseForwardImpl::exec(const TensorNDArray& src, _megdnn_tensor_out dst) { m_src = &src; m_dst = &dst; #define CB_CHK_MODE_ENABLE(_) 1 if (m_param.mode == Mode::FUSE_MUL_ADD3) { -#if MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, CB_CHK_MODE_ENABLE) +0 +#if MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, CB_CHK_MODE_ENABLE) + 0 ElemwiseOpParamN<3> param; bool c_is_scalar; prepare_fma3(param, c_is_scalar); - switch(m_dst->layout.dtype.enumv()) { -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - { \ - using ctype = DTypeTrait<_dt>::ctype; \ - if (c_is_scalar) { \ - return impl_fuse_mul_add3(param); \ - } else { \ - return impl_fuse_mul_add3(param); \ - } \ - } + switch (m_dst->layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + if (c_is_scalar) { \ + return impl_fuse_mul_add3(param); \ + } else { \ + return impl_fuse_mul_add3(param); \ + } \ + } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb default: megdnn_throw("bad dtype"); } -#endif // enable FUSE_MUL_ADD3 +#endif // enable FUSE_MUL_ADD3 } else if (m_param.mode == Mode::FUSE_MUL_ADD4) { -#if MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD4, CB_CHK_MODE_ENABLE) +0 +#if MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD4, CB_CHK_MODE_ENABLE) + 0 ElemwiseOpParamN<4> param; prepare_fma4(param); - switch(m_dst->layout.dtype.enumv()) { -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - return impl_fuse_mul_add4::ctype>(param); + switch (m_dst->layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + return impl_fuse_mul_add4::ctype>(param); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb default: megdnn_throw("bad dtype"); } -#endif // enable FUSE_MUL_ADD4 +#endif // enable FUSE_MUL_ADD4 } #undef CB_CHK_MODE_ENABLE - switch(src.size()) { -#define D(_n) case _n: return on_arity_dispatched<_n>() + switch (src.size()) { +#define D(_n) \ + case _n: \ + return 
on_arity_dispatched<_n>() D(1); D(2); #undef D - case 3: return on_arity_dispatched_no_bool<3>(); + case 3: + return on_arity_dispatched_no_bool<3>(); default: megdnn_throw("bad size of input tensors"); } diff --git a/dnn/src/common/elemwise/opr_impl_class_def.inl b/dnn/src/common/elemwise/opr_impl_class_def.inl index b158be22..3c3d9fb7 100644 --- a/dnn/src/common/elemwise/opr_impl_class_def.inl +++ b/dnn/src/common/elemwise/opr_impl_class_def.inl @@ -9,35 +9,33 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ - protected: - template - void on_arity_dispatched(); - - template - void on_arity_dispatched_no_bool(); - - template - struct ModeDispatcher; - - /*! - * \brief special impl for FUSE_MUL_ADD3 mode - * \tparam c_is_scalar see ElemwiseForwardImplHelper::prepare_fma3 - */ - template - void impl_fuse_mul_add3(const ElemwiseOpParamN<3> ¶ms); - - /*! - * \brief special impl for FUSE_MUL_ADD4 mode - * \param[out] params see ElemwiseForwardImplHelper::prepare_fma4 - */ - template - void impl_fuse_mul_add4(const ElemwiseOpParamN<4> ¶ms); - - public: - using ElemwiseForwardImplHelper::ElemwiseForwardImplHelper; - - void exec( - const TensorNDArray &src, - _megdnn_tensor_out dst) override; +protected: +template +void on_arity_dispatched(); + +template +void on_arity_dispatched_no_bool(); + +template +struct ModeDispatcher; + +/*! + * \brief special impl for FUSE_MUL_ADD3 mode + * \tparam c_is_scalar see ElemwiseForwardImplHelper::prepare_fma3 + */ +template +void impl_fuse_mul_add3(const ElemwiseOpParamN<3>& params); + +/*! + * \brief special impl for FUSE_MUL_ADD4 mode + * \param[out] params see ElemwiseForwardImplHelper::prepare_fma4 + */ +template +void impl_fuse_mul_add4(const ElemwiseOpParamN<4>& params); + +public: +using ElemwiseForwardImplHelper::ElemwiseForwardImplHelper; + +void exec(const TensorNDArray& src, _megdnn_tensor_out dst) override; // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/elemwise/opr_impl_helper.cpp b/dnn/src/common/elemwise/opr_impl_helper.cpp index 6fbcebdd..30bf96a9 100644 --- a/dnn/src/common/elemwise/opr_impl_helper.cpp +++ b/dnn/src/common/elemwise/opr_impl_helper.cpp @@ -18,8 +18,8 @@ using namespace megdnn; template ElemwiseOpParamN ElemwiseLayoutHelper::make_elemwise_op_param( void* opr, - void (*check_layout_and_broadcast)(void*, const TensorLayoutPtrArray&, - const TensorLayout&), + void (*check_layout_and_broadcast)( + void*, const TensorLayoutPtrArray&, const TensorLayout&), const TensorNDArray& src, const TensorND& dst) { megdnn_assert(src.size() == static_cast(arity)); ElemwiseOpParamN ret; @@ -34,11 +34,9 @@ ElemwiseOpParamN ElemwiseLayoutHelper::make_elemwise_op_param( } // explicit instantiation so subclasses can call this method -#define INST(n) \ - template ElemwiseOpParamN \ - ElemwiseLayoutHelper::make_elemwise_op_param( \ - void*, \ - void (*)(void*, const TensorLayoutPtrArray&, const TensorLayout&), \ +#define INST(n) \ + template ElemwiseOpParamN ElemwiseLayoutHelper::make_elemwise_op_param( \ + void*, void (*)(void*, const TensorLayoutPtrArray&, const TensorLayout&), \ const TensorNDArray&, const TensorND&) INST(1); INST(2); @@ -48,8 +46,8 @@ INST(5); INST(6); #undef INST -void ElemwiseForwardImplHelper::prepare_fma3(ElemwiseOpParamN<3>& param, - bool& c_is_scalar) { +void ElemwiseForwardImplHelper::prepare_fma3( + ElemwiseOpParamN<3>& param, bool& c_is_scalar) { c_is_scalar = is_broadcasted_scalar(m_src->at(2).layout); param = make_elemwise_op_param<3>(); @@ -83,17 +81,17 @@ bool 
ElemwiseLayoutHelper::is_broadcasted_scalar(const TensorLayout& layout) { template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( const TensorLayout& layout, BroadcastChannelInfo& info) { - if (layout.format.type() == TensorFormat::Type::DEFAULT && - layout.ndim == 3 && layout.stride[0] == slice_size && - layout.stride[1] == 0 && layout.stride[2] == 1) { + if (layout.format.type() == TensorFormat::Type::DEFAULT && layout.ndim == 3 && + layout.stride[0] == slice_size && layout.stride[1] == 0 && + layout.stride[2] == 1) { info.x = layout.shape[0]; info.y = layout.shape[1]; info.z = layout.shape[2]; return true; - } else if (layout.format.type() == TensorFormat::Type::DEFAULT && - layout.ndim == 4 && layout.stride[0] == 0 && - layout.stride[1] == slice_size && layout.stride[2] == 0 && - layout.stride[3] == 1) { + } else if ( + layout.format.type() == TensorFormat::Type::DEFAULT && layout.ndim == 4 && + layout.stride[0] == 0 && layout.stride[1] == slice_size && + layout.stride[2] == 0 && layout.stride[3] == 1) { info.x = layout.shape[1]; info.y = layout.shape[2]; info.z = layout.shape[3]; @@ -111,14 +109,13 @@ INST(8); bool ElemwiseLayoutHelper::is_broadcasted_channel_like( const TensorLayout& layout, BroadcastChannelInfo& info) { if (layout.format.type() == TensorFormat::Type::DEFAULT) { - if (layout.ndim == 3 && layout.stride[0] == 0 && - layout.stride[2] == 0 && layout.stride[1] == 1) { + if (layout.ndim == 3 && layout.stride[0] == 0 && layout.stride[2] == 0 && + layout.stride[1] == 1) { info.x = layout.shape[0]; info.y = layout.shape[1]; info.z = layout.shape[2]; return true; - } else if (layout.ndim == 2 && layout.stride[1] == 0 && - layout.stride[0] == 1) { + } else if (layout.ndim == 2 && layout.stride[1] == 0 && layout.stride[0] == 1) { info.x = 1; info.y = layout.shape[0]; info.z = layout.shape[1]; @@ -126,8 +123,8 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( } } else { if (Image2DPack4TensorFormat::is_valid_image(layout)) { - auto align_axis = layout.format.as_impl() - .align_axis(); + auto align_axis = + layout.format.as_impl().align_axis(); if (layout.ndim == 4 && align_axis == 1 && (layout.stride[0] == 0 || layout.shape[0] == 1) && layout.stride[1] == 4 && layout.stride[2] == 0 && @@ -136,10 +133,11 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( info.y = 1; info.z = layout.shape[2]; return true; - } else if (layout.ndim == 3 && align_axis == 1 && - (layout.stride[0] == 0 || layout.shape[0] == 1) && - layout.stride[1] == 0 && layout.shape[2] == 4 && - layout.stride[2] == 1) { + } else if ( + layout.ndim == 3 && align_axis == 1 && + (layout.stride[0] == 0 || layout.shape[0] == 1) && + layout.stride[1] == 0 && layout.shape[2] == 4 && + layout.stride[2] == 1) { //! 
[1, 1, 1, 1, 4] + [N, H, 1, W, 4] info.x = 1; info.y = 1; @@ -152,8 +150,8 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like( return false; } -bool ElemwiseLayoutHelper::is_broadcasted_1x(const TensorLayout& layout, - Broadcast1xInfo& binfo) { +bool ElemwiseLayoutHelper::is_broadcasted_1x( + const TensorLayout& layout, Broadcast1xInfo& binfo) { if (layout.ndim == 2 && layout.stride[0] == 0 && layout.stride[1] == 1) { binfo.x = layout[0]; binfo.y = layout[1]; diff --git a/dnn/src/common/elemwise/opr_impl_helper.h b/dnn/src/common/elemwise/opr_impl_helper.h index 892a10f0..2aabbd85 100644 --- a/dnn/src/common/elemwise/opr_impl_helper.h +++ b/dnn/src/common/elemwise/opr_impl_helper.h @@ -49,9 +49,8 @@ public: template static ElemwiseOpParamN make_elemwise_op_param( void* opr, - void (*check_layout_and_broadcast)(void*, - const TensorLayoutPtrArray&, - const TensorLayout&), + void (*check_layout_and_broadcast)( + void*, const TensorLayoutPtrArray&, const TensorLayout&), const TensorNDArray& src, const TensorND& dst); //! check whether given layout is 1D contig @@ -67,8 +66,7 @@ public: * * Note: input can be one-dimensional. */ - static bool is_broadcasted_1x(const TensorLayout& layout, - Broadcast1xInfo& binfo); + static bool is_broadcasted_1x(const TensorLayout& layout, Broadcast1xInfo& binfo); //! check whether given layout is broadcasted scalar static bool is_broadcasted_scalar(const TensorLayout& layout); @@ -79,8 +77,8 @@ public: * Note that Input can also be 2-dimensional, and must be [y, 1] broadacsted * into [y, z]; in such case x would be set to 1. */ - static bool is_broadcasted_channel_like(const TensorLayout& layout, - BroadcastChannelInfo& info); + static bool is_broadcasted_channel_like( + const TensorLayout& layout, BroadcastChannelInfo& info); /*! 
* \brief check whether layout matches BroadcastChannelInfo @@ -89,17 +87,16 @@ public: * broadacsted into [x, y, z] */ template - static bool is_broadcastedx_channel_like(const TensorLayout& layout, - BroadcastChannelInfo& info); + static bool is_broadcastedx_channel_like( + const TensorLayout& layout, BroadcastChannelInfo& info); }; class ElemwiseForwardImplHelper : public ElemwiseForward, protected ElemwiseLayoutHelper { - static void call_check_layout_and_broadcast(void* opr, - const TensorLayoutPtrArray& src, - const TensorLayout& dst) { - return static_cast(opr) - ->check_layout_and_broadcast(src, dst); + static void call_check_layout_and_broadcast( + void* opr, const TensorLayoutPtrArray& src, const TensorLayout& dst) { + return static_cast(opr)->check_layout_and_broadcast( + src, dst); } protected: diff --git a/dnn/src/common/elemwise_helper.cpp b/dnn/src/common/elemwise_helper.cpp index 200a99d4..df5911d5 100644 --- a/dnn/src/common/elemwise_helper.cpp +++ b/dnn/src/common/elemwise_helper.cpp @@ -14,40 +14,40 @@ namespace megdnn { - template - void ElemwiseOpParamN::init_from_given_tensor() { - megdnn_assert(!size && max_ndim == -1); - max_ndim = 0; - for (int i = 0; i < arity; ++ i) { - TensorLayout &layout = param[i].layout; - layout = layout.collapse_contiguous(); - auto cur = layout.total_nr_elems(); - if (!i) { - size = cur; - } else { - megdnn_assert(size == cur); - } - max_ndim = std::max(max_ndim, layout.ndim); +template +void ElemwiseOpParamN::init_from_given_tensor() { + megdnn_assert(!size && max_ndim == -1); + max_ndim = 0; + for (int i = 0; i < arity; ++i) { + TensorLayout& layout = param[i].layout; + layout = layout.collapse_contiguous(); + auto cur = layout.total_nr_elems(); + if (!i) { + size = cur; + } else { + megdnn_assert(size == cur); } - megdnn_assert(size > 0 && max_ndim > 0); + max_ndim = std::max(max_ndim, layout.ndim); } + megdnn_assert(size > 0 && max_ndim > 0); +} - template - void ElemwiseOpParamN::assert_initialized() const { - megdnn_assert(size, "uninitialized ElemwiseOpParamN"); - } +template +void ElemwiseOpParamN::assert_initialized() const { + megdnn_assert(size, "uninitialized ElemwiseOpParamN"); +} - template struct ElemwiseOpParamN<7>; - template struct ElemwiseOpParamN<6>; - template struct ElemwiseOpParamN<5>; - template struct ElemwiseOpParamN<4>; - template struct ElemwiseOpParamN<3>; - template struct ElemwiseOpParamN<2>; - template struct ElemwiseOpParamN<1>; +template struct ElemwiseOpParamN<7>; +template struct ElemwiseOpParamN<6>; +template struct ElemwiseOpParamN<5>; +template struct ElemwiseOpParamN<4>; +template struct ElemwiseOpParamN<3>; +template struct ElemwiseOpParamN<2>; +template struct ElemwiseOpParamN<1>; - void ElemwiseOpParamN<0>::assert_initialized() const { - megdnn_assert(size, "uninitialized ElemwiseOpParamN"); - } +void ElemwiseOpParamN<0>::assert_initialized() const { + megdnn_assert(size, "uninitialized ElemwiseOpParamN"); } +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/common/elemwise_helper.cuh b/dnn/src/common/elemwise_helper.cuh index 529e851d..e91f92a7 100644 --- a/dnn/src/common/elemwise_helper.cuh +++ b/dnn/src/common/elemwise_helper.cuh @@ -17,68 +17,70 @@ namespace { template struct MulType {}; -template<> struct MulType { typedef int16_t type; }; -template<> struct MulType { typedef int32_t type; }; -template<> struct MulType { typedef int64_t type; }; -template<> struct MulType { typedef uint16_t type; }; +template <> +struct MulType { + typedef int16_t type; +}; 
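The broadcast helpers above work purely from strides after collapse_contiguous(): an axis that was broadcast has stride 0, so a channel-like operand shows up either as a 3-d layout with strides {0, 1, 0} or as a 2-d layout with strides {1, 0}. A minimal standalone sketch of that test (SimpleLayout is a hypothetical stand-in, not the megdnn TensorLayout API):

#include <cstddef>
#include <vector>

// Hypothetical stand-in for a collapsed tensor layout: shape plus element strides.
struct SimpleLayout {
    std::vector<size_t> shape;
    std::vector<long> stride;
};

// Mirrors the default-format branches shown above: after contiguous axes are
// collapsed, a channel-broadcast operand is either [x, y, z] with strides
// {0, 1, 0} or [y, z] with strides {1, 0} (x taken as 1 in that case).
static bool looks_like_channel_broadcast(
        const SimpleLayout& l, size_t& x, size_t& y, size_t& z) {
    if (l.shape.size() == 3 && l.stride[0] == 0 && l.stride[1] == 1 &&
        l.stride[2] == 0) {
        x = l.shape[0];
        y = l.shape[1];
        z = l.shape[2];
        return true;
    }
    if (l.shape.size() == 2 && l.stride[0] == 1 && l.stride[1] == 0) {
        x = 1;
        y = l.shape[0];
        z = l.shape[1];
        return true;
    }
    return false;
}

// Example: a bias of shape [1, C, 1, 1] broadcast against [N, C, H, W] collapses
// to shape {N, C, H*W} with strides {0, 1, 0}, so x == N, y == C, z == H*W.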
+template <> +struct MulType { + typedef int32_t type; +}; +template <> +struct MulType { + typedef int64_t type; +}; +template <> +struct MulType { + typedef uint16_t type; +}; } // namespace namespace megdnn { /*! - * \brief packed param for elemwise operators - * \tparam arity number of operands for this operator - */ -template + * \brief packed param for elemwise operators + * \tparam arity number of operands for this operator + */ +template struct ElemwiseOpParamN { - int max_ndim; //!< max ndim of all params - size_t size; //!< total number of elements (i.e. size of each param) + int max_ndim; //!< max ndim of all params + size_t size; //!< total number of elements (i.e. size of each param) TensorND param[arity]; - ElemwiseOpParamN(): - max_ndim(-1), size(0) - {} + ElemwiseOpParamN() : max_ndim(-1), size(0) {} - const TensorND& operator [](int idx) const { - return param[idx]; - } + const TensorND& operator[](int idx) const { return param[idx]; } - TensorND& operator [](int idx) { - return param[idx]; - } + TensorND& operator[](int idx) { return param[idx]; } /*! - * \brief initialize from current *param* - * - * *size* and *max_ndim* would be computed; params would be collapsed - * - * Each param must have the same number of elements. - */ + * \brief initialize from current *param* + * + * *size* and *max_ndim* would be computed; params would be collapsed + * + * Each param must have the same number of elements. + */ void init_from_given_tensor(); void assert_initialized() const; }; /*! - * \brief for elemwise opr without tensor arguments (i.e. only need index input) - */ -template<> + * \brief for elemwise opr without tensor arguments (i.e. only need index input) + */ +template <> struct ElemwiseOpParamN<0> { - size_t size; //!< total number of elements + size_t size; //!< total number of elements - ElemwiseOpParamN(size_t s = 0): - size(s) - { - } + ElemwiseOpParamN(size_t s = 0) : size(s) {} void assert_initialized() const; }; template -MEGDNN_DEVICE MEGDNN_HOST inline T rounding_shift_right_away_from_zero(T x, - int k) { +MEGDNN_DEVICE MEGDNN_HOST inline T rounding_shift_right_away_from_zero(T x, int k) { T mask = (T(1) << k) - 1; T threshold = (mask >> 1) + (x < 0); return (x >> k) + ((x & mask) > threshold); @@ -93,24 +95,25 @@ MEGDNN_DEVICE MEGDNN_HOST inline T rounding_shift_right_upward(T x, int k) { template MEGDNN_DEVICE MEGDNN_HOST inline T round_mulh_saturate(T a, T b) { - MEGDNN_STATIC_ASSERT(std::numeric_limits::digits <= 32, - "Portable RMULH is not supported for integer " - "types larger than 32 bits."); - MEGDNN_STATIC_ASSERT(std::numeric_limits::is_integer, - "Input types should be integer for RMULH"); + MEGDNN_STATIC_ASSERT( + std::numeric_limits::digits <= 32, + "Portable RMULH is not supported for integer " + "types larger than 32 bits."); + MEGDNN_STATIC_ASSERT( + std::numeric_limits::is_integer, + "Input types should be integer for RMULH"); bool overflow = a == b && a == DTypeTrait::min(); // TODO: This really should be // rounding_shift_right_away_from_zero, but we haven't yet found a fast way // to implement it on ARM NEON. For now, we just try to align with NEON's // VQRDMULH and hope that it does not harm our NN badly. - return overflow ? DTypeTrait::max() - : static_cast(rounding_shift_right_upward( - typename MulType::type(a) * - typename MulType::type(b), - std::numeric_limits::digits)); + return overflow + ? 
DTypeTrait::max() + : static_cast(rounding_shift_right_upward( + typename MulType::type(a) * typename MulType::type(b), + std::numeric_limits::digits)); } -} // namespace megdnn +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/common/elemwise_multi_type/kern_defs.cuh b/dnn/src/common/elemwise_multi_type/kern_defs.cuh index de901e1f..aee31f80 100644 --- a/dnn/src/common/elemwise_multi_type/kern_defs.cuh +++ b/dnn/src/common/elemwise_multi_type/kern_defs.cuh @@ -12,8 +12,8 @@ #pragma once #include "megdnn/dtype.h" -#include "src/common/utils.cuh" #include "src/common/elemwise_helper.cuh" +#include "src/common/utils.cuh" #include @@ -30,7 +30,7 @@ struct Fma3iXxf32xf32xiYOp { } }; -template +template MEGDNN_HOST MEGDNN_DEVICE dtype round_shr_saturate(stype x, int k) { stype result = rounding_shift_right_away_from_zero(x, k); if (!is_same::value) { diff --git a/dnn/src/common/elemwise_multi_type/opr_impl.cpp b/dnn/src/common/elemwise_multi_type/opr_impl.cpp index dd2046a1..164aa5b1 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/common/elemwise_multi_type/opr_impl.cpp @@ -35,9 +35,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { auto make_check_dtype_func = [](DType expected) { auto func = [expected](DType dtype) { - megdnn_assert(expected.enumv() == dtype.enumv(), - "expected %s, but got %s", expected.name(), - dtype.name()); + megdnn_assert( + expected.enumv() == dtype.enumv(), "expected %s, but got %s", + expected.name(), dtype.name()); }; return func; }; @@ -52,9 +52,9 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { auto make_out_dtype_func = [](DType expected) { auto func = [expected](DType& dtype, bool check) { if (check) { - megdnn_assert(expected.enumv() == dtype.enumv(), - "expected %s, but got %s", expected.name(), - dtype.name()); + megdnn_assert( + expected.enumv() == dtype.enumv(), "expected %s, but got %s", + expected.name(), dtype.name()); } else { dtype = expected; } @@ -230,8 +230,7 @@ const ModeTrait& ModeTrait::from_mode(Mode mode) { return traits.at(static_cast(mode)); } -void ElemwiseMultiType::deduce_layout(const TensorLayoutArray& src, - TensorLayout& dst) { +void ElemwiseMultiType::deduce_layout(const TensorLayoutArray& src, TensorLayout& dst) { auto trait = mode_trait(); megdnn_assert(src.size() == trait.arity); for (size_t i = 0; i < trait.arity; ++i) { diff --git a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp index 38cfe8a3..74d82743 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp +++ b/dnn/src/common/elemwise_multi_type/opr_impl_helper.cpp @@ -14,26 +14,26 @@ using namespace megdnn; -#define ON_QUANTIZED_MODE(_MODE, _n) \ - case Mode::Q##_MODE: \ - on_quantized_mode(make_elemwise_op_param<_n>(src, dst), dst, \ - Elemwise::Mode::_MODE); \ +#define ON_QUANTIZED_MODE(_MODE, _n) \ + case Mode::Q##_MODE: \ + on_quantized_mode( \ + make_elemwise_op_param<_n>(src, dst), dst, Elemwise::Mode::_MODE); \ break -void ElemwiseMultiTypeImplHelper::exec(_megdnn_in const TensorNDArray& src, - _megdnn_tensor_out dst) { +void ElemwiseMultiTypeImplHelper::exec( + _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) { switch (m_param.mode) { case Mode::FUSE_MUL_ADD3_INT16x32x32x32: - on_fuse_mul_add3_int16x32x32x32(make_elemwise_op_param<3>(src, dst), - dst.ptr()); + on_fuse_mul_add3_int16x32x32x32( + make_elemwise_op_param<3>(src, dst), dst.ptr()); break; case Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: - 
on_fuse_mul_add3_iXxf32xf32xi8(make_elemwise_op_param<3>(src, dst), - dst.ptr()); + on_fuse_mul_add3_iXxf32xf32xi8( + make_elemwise_op_param<3>(src, dst), dst.ptr()); break; case Mode::ROUND_SHR_SATURATE_IXxI8xI8: - on_round_shr_saturate_iXxi8xi8(make_elemwise_op_param<2>(src, dst), - dst.ptr()); + on_round_shr_saturate_iXxi8xi8( + make_elemwise_op_param<2>(src, dst), dst.ptr()); break; case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( @@ -44,61 +44,61 @@ void ElemwiseMultiTypeImplHelper::exec(_megdnn_in const TensorNDArray& src, make_elemwise_op_param<6>(src, dst), dst.ptr()); break; case Mode::ROUND_SHR_SATURATE_IXxI8xI16: - on_round_shr_saturate_iXxi8xi16(make_elemwise_op_param<2>(src, dst), - dst.ptr()); + on_round_shr_saturate_iXxi8xi16( + make_elemwise_op_param<2>(src, dst), dst.ptr()); break; - ON_QUANTIZED_MODE(RELU, 1); - ON_QUANTIZED_MODE(ABS, 1); - ON_QUANTIZED_MODE(ACOS, 1); - ON_QUANTIZED_MODE(ASIN, 1); - ON_QUANTIZED_MODE(CEIL, 1); - ON_QUANTIZED_MODE(COS, 1); - ON_QUANTIZED_MODE(EXP, 1); - ON_QUANTIZED_MODE(EXPM1, 1); - ON_QUANTIZED_MODE(FLOOR, 1); - ON_QUANTIZED_MODE(LOG, 1); - ON_QUANTIZED_MODE(LOG1P, 1); - ON_QUANTIZED_MODE(NEGATE, 1); - ON_QUANTIZED_MODE(SIGMOID, 1); - ON_QUANTIZED_MODE(SIN, 1); - ON_QUANTIZED_MODE(TANH, 1); - ON_QUANTIZED_MODE(FAST_TANH, 1); - ON_QUANTIZED_MODE(ROUND, 1); - ON_QUANTIZED_MODE(ERF, 1); - ON_QUANTIZED_MODE(ERFINV, 1); - ON_QUANTIZED_MODE(ERFC, 1); - ON_QUANTIZED_MODE(ERFCINV, 1); - ON_QUANTIZED_MODE(H_SWISH, 1); + ON_QUANTIZED_MODE(RELU, 1); + ON_QUANTIZED_MODE(ABS, 1); + ON_QUANTIZED_MODE(ACOS, 1); + ON_QUANTIZED_MODE(ASIN, 1); + ON_QUANTIZED_MODE(CEIL, 1); + ON_QUANTIZED_MODE(COS, 1); + ON_QUANTIZED_MODE(EXP, 1); + ON_QUANTIZED_MODE(EXPM1, 1); + ON_QUANTIZED_MODE(FLOOR, 1); + ON_QUANTIZED_MODE(LOG, 1); + ON_QUANTIZED_MODE(LOG1P, 1); + ON_QUANTIZED_MODE(NEGATE, 1); + ON_QUANTIZED_MODE(SIGMOID, 1); + ON_QUANTIZED_MODE(SIN, 1); + ON_QUANTIZED_MODE(TANH, 1); + ON_QUANTIZED_MODE(FAST_TANH, 1); + ON_QUANTIZED_MODE(ROUND, 1); + ON_QUANTIZED_MODE(ERF, 1); + ON_QUANTIZED_MODE(ERFINV, 1); + ON_QUANTIZED_MODE(ERFC, 1); + ON_QUANTIZED_MODE(ERFCINV, 1); + ON_QUANTIZED_MODE(H_SWISH, 1); - ON_QUANTIZED_MODE(ABS_GRAD, 2); - ON_QUANTIZED_MODE(ADD, 2); - ON_QUANTIZED_MODE(FLOOR_DIV, 2); - ON_QUANTIZED_MODE(MAX, 2); - ON_QUANTIZED_MODE(MIN, 2); - ON_QUANTIZED_MODE(MOD, 2); - ON_QUANTIZED_MODE(MUL, 2); - ON_QUANTIZED_MODE(POW, 2); - ON_QUANTIZED_MODE(SIGMOID_GRAD, 2); - ON_QUANTIZED_MODE(SUB, 2); - ON_QUANTIZED_MODE(SWITCH_GT0, 2); - ON_QUANTIZED_MODE(TANH_GRAD, 2); - ON_QUANTIZED_MODE(TRUE_DIV, 2); - ON_QUANTIZED_MODE(LOG_SUM_EXP, 2); + ON_QUANTIZED_MODE(ABS_GRAD, 2); + ON_QUANTIZED_MODE(ADD, 2); + ON_QUANTIZED_MODE(FLOOR_DIV, 2); + ON_QUANTIZED_MODE(MAX, 2); + ON_QUANTIZED_MODE(MIN, 2); + ON_QUANTIZED_MODE(MOD, 2); + ON_QUANTIZED_MODE(MUL, 2); + ON_QUANTIZED_MODE(POW, 2); + ON_QUANTIZED_MODE(SIGMOID_GRAD, 2); + ON_QUANTIZED_MODE(SUB, 2); + ON_QUANTIZED_MODE(SWITCH_GT0, 2); + ON_QUANTIZED_MODE(TANH_GRAD, 2); + ON_QUANTIZED_MODE(TRUE_DIV, 2); + ON_QUANTIZED_MODE(LOG_SUM_EXP, 2); - ON_QUANTIZED_MODE(LT, 2); - ON_QUANTIZED_MODE(LEQ, 2); - ON_QUANTIZED_MODE(EQ, 2); + ON_QUANTIZED_MODE(LT, 2); + ON_QUANTIZED_MODE(LEQ, 2); + ON_QUANTIZED_MODE(EQ, 2); - ON_QUANTIZED_MODE(FUSE_ADD_RELU, 2); - ON_QUANTIZED_MODE(FUSE_ADD_SIGMOID, 2); - ON_QUANTIZED_MODE(FUSE_ADD_TANH, 2); - ON_QUANTIZED_MODE(FAST_TANH_GRAD, 2); - ON_QUANTIZED_MODE(ATAN2, 2); - ON_QUANTIZED_MODE(H_SWISH_GRAD, 2); - 
ON_QUANTIZED_MODE(FUSE_ADD_H_SWISH, 2); + ON_QUANTIZED_MODE(FUSE_ADD_RELU, 2); + ON_QUANTIZED_MODE(FUSE_ADD_SIGMOID, 2); + ON_QUANTIZED_MODE(FUSE_ADD_TANH, 2); + ON_QUANTIZED_MODE(FAST_TANH_GRAD, 2); + ON_QUANTIZED_MODE(ATAN2, 2); + ON_QUANTIZED_MODE(H_SWISH_GRAD, 2); + ON_QUANTIZED_MODE(FUSE_ADD_H_SWISH, 2); - ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); - ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); + ON_QUANTIZED_MODE(FUSE_MUL_ADD3, 3); + ON_QUANTIZED_MODE(COND_LEQ_MOV, 3); default: megdnn_throw("invalid mode"); } diff --git a/dnn/src/common/elemwise_multi_type/opr_impl_helper.h b/dnn/src/common/elemwise_multi_type/opr_impl_helper.h index c0196908..7496ef7c 100644 --- a/dnn/src/common/elemwise_multi_type/opr_impl_helper.h +++ b/dnn/src/common/elemwise_multi_type/opr_impl_helper.h @@ -18,16 +18,15 @@ namespace megdnn { class ElemwiseMultiTypeImplHelper : public ElemwiseMultiType, protected ElemwiseLayoutHelper { - static void call_check_layout_and_broadcast(void* opr, - const TensorLayoutPtrArray& src, - const TensorLayout& dst) { + static void call_check_layout_and_broadcast( + void* opr, const TensorLayoutPtrArray& src, const TensorLayout& dst) { return static_cast(opr) ->check_layout_and_broadcast(src, dst); } template - ElemwiseOpParamN make_elemwise_op_param(const TensorNDArray& src, - const TensorND& dst) { + ElemwiseOpParamN make_elemwise_op_param( + const TensorNDArray& src, const TensorND& dst) { return ElemwiseLayoutHelper::make_elemwise_op_param( this, call_check_layout_and_broadcast, src, dst); } @@ -51,22 +50,22 @@ protected: virtual void on_round_shr_saturate_iXxi8xi16( const ElemwiseOpParamN<2>& param, dt_int16* dst) = 0; - virtual void on_quantized_mode(const ElemwiseOpParamN<1>& param, - const TensorND& dst, - Elemwise::Mode mode) { + virtual void on_quantized_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, + Elemwise::Mode mode) { MEGDNN_MARK_USED_VAR(param); MEGDNN_MARK_USED_VAR(dst); MEGDNN_MARK_USED_VAR(mode); megdnn_throw("Unrealized except arm_common"); } - virtual void on_quantized_mode(const ElemwiseOpParamN<2>& param, - const TensorND& dst, - Elemwise::Mode mode) = 0; + virtual void on_quantized_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, + Elemwise::Mode mode) = 0; - virtual void on_quantized_mode(const ElemwiseOpParamN<3>& param, - const TensorND& dst, - Elemwise::Mode mode) { + virtual void on_quantized_mode( + const ElemwiseOpParamN<3>& param, const TensorND& dst, + Elemwise::Mode mode) { MEGDNN_MARK_USED_VAR(param); MEGDNN_MARK_USED_VAR(dst); MEGDNN_MARK_USED_VAR(mode); @@ -76,8 +75,8 @@ protected: public: using ElemwiseMultiType::ElemwiseMultiType; - void exec(_megdnn_in const TensorNDArray& src, - _megdnn_tensor_out dst) override final; + void exec( + _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) override final; }; } // namespace megdnn diff --git a/dnn/src/common/eye.cpp b/dnn/src/common/eye.cpp index f0ba6f92..45c49e70 100644 --- a/dnn/src/common/eye.cpp +++ b/dnn/src/common/eye.cpp @@ -14,15 +14,13 @@ namespace megdnn { -void Eye::check_exec(const TensorLayout &dst, size_t workspace_in_bytes) -{ +void Eye::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) { megdnn_assert(dst.ndim == 2 && dst.dtype.enumv() == param().dtype); megdnn_assert_contiguous(dst); auto required_workspace_in_bytes = get_workspace_in_bytes(dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/fake_quant.cpp 
b/dnn/src/common/fake_quant.cpp index 2334dab4..7dcf39d9 100644 --- a/dnn/src/common/fake_quant.cpp +++ b/dnn/src/common/fake_quant.cpp @@ -15,15 +15,13 @@ namespace megdnn { -void FakeQuantBase::deduce_layout_fwd(const TensorLayout& input, - TensorLayout& output) { +void FakeQuantBase::deduce_layout_fwd(const TensorLayout& input, TensorLayout& output) { output = TensorLayout(input, input.dtype); } -void FakeQuantBase::check_layout_fwd(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& output) { +void FakeQuantBase::check_layout_fwd( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& output) { megdnn_assert(input.dtype == dtype::Float32()); megdnn_assert(scale.dtype == dtype::Float32()); megdnn_assert(zero_point.dtype == dtype::Float32()); @@ -32,30 +30,26 @@ void FakeQuantBase::check_layout_fwd(const TensorLayout& input, megdnn_assert_eq_layout(expected, output); } -void FakeQuantForward::deduce_layout(const TensorLayout& input, - const TensorLayout& /*scale*/, - const TensorLayout& /*zero_point*/, - TensorLayout& output) { +void FakeQuantForward::deduce_layout( + const TensorLayout& input, const TensorLayout& /*scale*/, + const TensorLayout& /*zero_point*/, TensorLayout& output) { deduce_layout_fwd(input, output); } -void FakeQuantForward::check_exec(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& output, - size_t workspace_in_bytes) { +void FakeQuantForward::check_exec( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& output, + size_t workspace_in_bytes) { check_layout_fwd(input, scale, zero_point, output); auto required_workspace_space = get_workspace_in_bytes(input, scale, zero_point, output); megdnn_assert(workspace_in_bytes >= required_workspace_space); } -void FakeQuantBackward::check_exec(const TensorLayout& diff, - const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void FakeQuantBackward::check_exec( + const TensorLayout& diff, const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad, + size_t workspace_in_bytes) { megdnn_assert_eq_shape(input, diff); megdnn_assert_eq_shape(input, grad); auto required_worspace_space = diff --git a/dnn/src/common/flag_warn.cpp b/dnn/src/common/flag_warn.cpp index a19d9ff7..400bd5be 100644 --- a/dnn/src/common/flag_warn.cpp +++ b/dnn/src/common/flag_warn.cpp @@ -12,7 +12,7 @@ #include "megdnn/config/config.h" #if !MEGDNN_ENABLE_MANGLING - #pragma message "Mangling is disabled." -#endif // MEGDNN_ENABLE_MANGLING +#pragma message "Mangling is disabled." 
+#endif // MEGDNN_ENABLE_MANGLING // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/flip.cpp b/dnn/src/common/flip.cpp index 06e29b14..6c637411 100644 --- a/dnn/src/common/flip.cpp +++ b/dnn/src/common/flip.cpp @@ -14,13 +14,13 @@ namespace megdnn { -void FlipBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst) -{ +void FlipBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src); }; MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(src.ndim == 4_z && (src.shape[3] == 1_z || - src.shape[3] == 3_z), "%s", errmsg().c_str()); + megdnn_assert( + src.ndim == 4_z && (src.shape[3] == 1_z || src.shape[3] == 3_z), "%s", + errmsg().c_str()); size_t in = src.shape[0]; size_t ih = src.shape[1]; @@ -30,28 +30,24 @@ void FlipBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst) dst = TensorLayout(TensorShape({in, ih, iw, ic}), src.dtype); } -void FlipBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ +void FlipBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_shape(dst_expected, dst); } -void Flip::deduce_layout(const TensorLayout &src, TensorLayout &dst) -{ +void Flip::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void Flip::check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void Flip::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/gaussian_blur.cpp b/dnn/src/common/gaussian_blur.cpp index ef742376..9a469a1d 100644 --- a/dnn/src/common/gaussian_blur.cpp +++ b/dnn/src/common/gaussian_blur.cpp @@ -9,19 +9,19 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #include "megdnn/oprs.h" -#include "src/common/utils.h" #include "src/common/cv/common.h" #include "src/common/cv/helper.h" +#include "src/common/utils.h" namespace megdnn { -void GaussianBlurBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst) -{ +void GaussianBlurBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src); }; MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(src.ndim == 4_z && (src.shape[3] == 1_z || - src.shape[3] == 3_z), "%s", errmsg().c_str()); + megdnn_assert( + src.ndim == 4_z && (src.shape[3] == 1_z || src.shape[3] == 3_z), "%s", + errmsg().c_str()); size_t in = src.shape[0]; size_t ih = src.shape[1]; @@ -31,28 +31,25 @@ void GaussianBlurBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout & dst = TensorLayout(TensorShape({in, ih, iw, ic}), src.dtype); } -void GaussianBlurBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ +void GaussianBlurBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_shape(dst_expected, dst); } -void GaussianBlur::deduce_layout(const TensorLayout &src, TensorLayout &dst) -{ +void GaussianBlur::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void GaussianBlur::check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void GaussianBlur::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/gaussian_blur_helper.h b/dnn/src/common/gaussian_blur_helper.h index 72681a92..d94640c9 100644 --- a/dnn/src/common/gaussian_blur_helper.h +++ b/dnn/src/common/gaussian_blur_helper.h @@ -25,13 +25,11 @@ inline static Mat getGaussianKernel(size_t n, double sigma) { {1.f}, {0.25f, 0.5f, 0.25f}, {0.0625f, 0.25f, 0.375f, 0.25f, 0.0625f}, - {0.03125f, 0.109375f, 0.21875f, 0.28125f, 0.21875f, 0.109375f, - 0.03125f}}; + {0.03125f, 0.109375f, 0.21875f, 0.28125f, 0.21875f, 0.109375f, 0.03125f}}; - const float* fixed_kernel = - n % 2 == 1 && n <= SMALL_GAUSSIAN_SIZE && sigma <= 0 - ? small_gaussian_tab[n >> 1] - : 0; + const float* fixed_kernel = n % 2 == 1 && n <= SMALL_GAUSSIAN_SIZE && sigma <= 0 + ? small_gaussian_tab[n >> 1] + : 0; Mat kernel(1, n, 1); @@ -44,8 +42,7 @@ inline static Mat getGaussianKernel(size_t n, double sigma) { int i; for (i = 0; i < (int)n; i++) { double x = i - (n - 1) * 0.5; - double t = fixed_kernel ? (double)fixed_kernel[i] - : std::exp(scale2X * x * x); + double t = fixed_kernel ? (double)fixed_kernel[i] : std::exp(scale2X * x * x); { c[i] = (T)t; sum += c[i]; @@ -60,28 +57,25 @@ inline static Mat getGaussianKernel(size_t n, double sigma) { } template -inline static void createGaussianKernels(Mat& kx, Mat& ky, Size ksize, - double sigma1, double sigma2) { +inline static void createGaussianKernels( + Mat& kx, Mat& ky, Size ksize, double sigma1, double sigma2) { if (sigma2 <= 0) sigma2 = sigma1; if (ksize.cols() <= 0 && sigma1 > 0) { - double num = - sigma1 * (std::is_same::value ? 3 : 4) * 2 + - 1; + double num = sigma1 * (std::is_same::value ? 3 : 4) * 2 + 1; num = (int)(num + (num >= 0 ? 
0.5 : -0.5)); ksize.cols() = ((int)num) | 1; } if (ksize.rows() <= 0 && sigma2 > 0) { - double num = - sigma2 * (std::is_same::value ? 3 : 4) * 2 + - 1; + double num = sigma2 * (std::is_same::value ? 3 : 4) * 2 + 1; num = (int)(num + (num >= 0 ? 0.5 : -0.5)); ksize.rows() = ((int)num) | 1; } - megdnn_assert(ksize.cols() > 0 && ksize.cols() % 2 == 1 && - ksize.rows() > 0 && ksize.rows() % 2 == 1); + megdnn_assert( + ksize.cols() > 0 && ksize.cols() % 2 == 1 && ksize.rows() > 0 && + ksize.rows() % 2 == 1); sigma1 = std::max(sigma1, 0.); sigma2 = std::max(sigma2, 0.); diff --git a/dnn/src/common/group_local.cpp b/dnn/src/common/group_local.cpp index 72485aef..859bc30d 100644 --- a/dnn/src/common/group_local.cpp +++ b/dnn/src/common/group_local.cpp @@ -14,13 +14,11 @@ namespace megdnn { -void GroupLocalBase::deduce_layout_fwd(const TensorLayout &src, - const TensorLayout &filter, - TensorLayout &dst) -{ +void GroupLocalBase::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { auto errmsg = [&]() { - return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + - ", " + megdnn_layout_msg(dst) + ", " + + return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " + + megdnn_layout_msg(dst) + ", " + "pad_h=" + std::to_string(param().pad_h) + ", " + "pad_w=" + std::to_string(param().pad_w) + ", " + "stride_h=" + std::to_string(param().stride_h) + ", " + @@ -29,13 +27,15 @@ void GroupLocalBase::deduce_layout_fwd(const TensorLayout &src, MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert_contiguous(src); megdnn_assert_contiguous(filter); - megdnn_assert(param().mode == Mode::CROSS_CORRELATION, + megdnn_assert( + param().mode == Mode::CROSS_CORRELATION, "only CROSS_CORRELATION mode is supported for glocal."); - megdnn_assert(param().sparse == Param::Sparse::DENSE && - param().dilate_h == 1 && param().dilate_w == 1 && - src.dtype.category() == DTypeCategory::FLOAT && - src.dtype == dst.dtype, + megdnn_assert( + param().sparse == Param::Sparse::DENSE && param().dilate_h == 1 && + param().dilate_w == 1 && + src.dtype.category() == DTypeCategory::FLOAT && + src.dtype == dst.dtype, "unsupported conv param for Local opr"); megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(filter.ndim == 7_z, "%s", errmsg().c_str()); @@ -44,59 +44,53 @@ void GroupLocalBase::deduce_layout_fwd(const TensorLayout &src, size_t ic = src[1]; size_t ih = src[2]; size_t iw = src[3]; - size_t oc = filter[6]*group; + size_t oc = filter[6] * group; size_t oh = filter[1], ow = filter[2]; megdnn_assert_eq_size_t(filter[0], group); - megdnn_assert_eq_size_t(filter[3]*group, ic); + megdnn_assert_eq_size_t(filter[3] * group, ic); size_t fh = filter[4], fw = filter[5]; // (group, oh, ow, ic/group, fh, fw, oc/group) - infer_conv_shape2d(ih, iw, fh, fw, - param().stride_h, param().stride_w, - param().pad_h, param().pad_w, oh, ow); + infer_conv_shape2d( + ih, iw, fh, fw, param().stride_h, param().stride_w, param().pad_h, + param().pad_w, oh, ow); dst = TensorLayout(TensorShape({n, oc, oh, ow}), src.dtype); } -void GroupLocalBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) -{ +void GroupLocalBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { TensorLayout dst_expected{dst.dtype}; megdnn_assert_eq_dtype(src, filter); megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, filter, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); - megdnn_assert(src.dtype 
== dtype::Float32() || DNN_FLOAT16_SELECT(src.dtype == dtype::Float16(), true)); + megdnn_assert( + src.dtype == dtype::Float32() || + DNN_FLOAT16_SELECT(src.dtype == dtype::Float16(), true)); } -void GroupLocalForward::check_exec(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void GroupLocalForward::check_exec( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, filter, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void GroupLocalBackwardData::check_exec(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void GroupLocalBackwardData::check_exec( + const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { check_layout_fwd(grad, filter, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void GroupLocalBackwardFilter::check_exec(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void GroupLocalBackwardFilter::check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { check_layout_fwd(src, grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle.cpp b/dnn/src/common/handle.cpp index 75b282d8..17f4718c 100644 --- a/dnn/src/common/handle.cpp +++ b/dnn/src/common/handle.cpp @@ -51,8 +51,8 @@ MIDOUT_DECL(HandleOpr); Handle::Handle(megcoreComputingHandle_t computing_handle, HandleType type) : m_computing_handle(computing_handle), m_handle_type(type) {} -std::unique_ptr Handle::make(megcoreComputingHandle_t computing_handle, - int debug_level) { +std::unique_ptr Handle::make( + megcoreComputingHandle_t computing_handle, int debug_level) { (void)debug_level; megcoreDeviceHandle_t device_handle; megcorePlatform_t platform; @@ -114,8 +114,9 @@ std::unique_ptr Handle::make(megcoreComputingHandle_t computing_handle, } else { // CUDA - megdnn_throw_if(platform != megcorePlatformCUDA, megdnn_error, - "platform should be CUDA Platform"); + megdnn_throw_if( + platform != megcorePlatformCUDA, megdnn_error, + "platform should be CUDA Platform"); #if MEGDNN_WITH_CUDA return make_unique(computing_handle); #else diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index e77fc865..adb744ff 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -29,19 +29,13 @@ public: using Handle::Handle; //! global matmul opr - virtual MatrixMul* matmul_opr() { - megdnn_throw("Unimplement matmul opr.\n"); - } + virtual MatrixMul* matmul_opr() { megdnn_throw("Unimplement matmul opr.\n"); } //! global matmul opr with first operand transposed - virtual MatrixMul* matmul_aT_opr() { - megdnn_throw("Unimplement matmul_aT opr.\n"); - } + virtual MatrixMul* matmul_aT_opr() { megdnn_throw("Unimplement matmul_aT opr.\n"); } //! 
global matmul opr with second operand transposed - virtual MatrixMul* matmul_bT_opr() { - megdnn_throw("Unimplement matmul_bT opr.\n"); - } + virtual MatrixMul* matmul_bT_opr() { megdnn_throw("Unimplement matmul_bT opr.\n"); } //! global matmul opr with both operand transposed virtual MatrixMul* matmul_aT_bT_opr() { @@ -49,13 +43,9 @@ public: } //! global relayout opr - virtual Relayout* relayout_opr() { - megdnn_throw("Unimplement Relayout opr.\n"); - } + virtual Relayout* relayout_opr() { megdnn_throw("Unimplement Relayout opr.\n"); } - virtual Checksum* checksum_opr() { - megdnn_throw("Unimplement Checksum opr.\n"); - } + virtual Checksum* checksum_opr() { megdnn_throw("Unimplement Checksum opr.\n"); } virtual MaxTensorDiff* max_tensor_diff_opr() { megdnn_throw("Unimplement MaxTensorDiff opr.\n"); @@ -65,17 +55,14 @@ protected: static constexpr size_t NR_HELPER_OPRS = 7; template - static Opr* get_helper_opr(Self self, - const typename Opr::Param& param = {}) { + static Opr* get_helper_opr(Self self, const typename Opr::Param& param = {}) { MIDOUT_BEGIN(dnn_src_common_handle_impl, Opr, idx) { static_assert(idx < NR_HELPER_OPRS, "invalid idx"); if (!self->m_helper_oprs[idx]) { MEGDNN_LOCK_GUARD(self->m_helper_oprs_mtx); if (!self->m_helper_oprs[idx]) { - self->m_helper_oprs[idx] = - self->template create_operator(); - auto ret = - static_cast(self->m_helper_oprs[idx].get()); + self->m_helper_oprs[idx] = self->template create_operator(); + auto ret = static_cast(self->m_helper_oprs[idx].get()); ret->param() = param; megdnn_assert(ret->is_thread_safe()); return ret; @@ -96,132 +83,70 @@ private: * \brief iterate though each operator class name; useful for explicit * instantialization of create_operator<> templates */ -#define MEGDNN_FOREACH_OPR_CLASS(cb) \ - cb(ConvolutionForward) \ - cb(ConvolutionBackwardData) \ - cb(ConvolutionBackwardFilter) \ - cb(ConvPoolingForward) \ - cb(ConvBiasForward) \ - cb(Images2NeibsForward) \ - cb(Images2NeibsBackward) \ - cb(SlidingWindowTransposeForward) \ - cb(SlidingWindowTransposeBackward) \ - cb(ElemwiseForward) \ - cb(ElemwiseMultiType) \ - cb(AddUpdateForward) \ - cb(RelayoutForward) \ - cb(PoolingForward) \ - cb(PoolingBackward) \ - cb(LocalForward) \ - cb(LocalBackwardData) \ - cb(LocalBackwardFilter) \ - cb(LRNForward) \ - cb(LRNBackward) \ - cb(ROIPoolingForward) \ - cb(ROIPoolingBackward) \ - cb(WarpPerspectiveForward) \ - cb(WarpPerspectiveBackwardData) \ - cb(WarpPerspectiveBackwardMat) \ - cb(DotForward) \ - cb(MatrixInverse) \ - cb(MatrixMulForward) \ - cb(BatchedMatrixMulForward) \ - cb(SVDForward) \ - cb(ReduceForward) \ - cb(CondTake) \ - cb(CumsumForward) \ - cb(ArgmaxForward) \ - cb(ArgminForward) \ - cb(TransposeForward) \ - cb(ConcatForward) \ - cb(SplitForward) \ - cb(TileForward) \ - cb(TileBackward) \ - cb(RepeatForward) \ - cb(RepeatBackward) \ - cb(ArgsortForward) \ - cb(ArgsortBackward) \ - cb(TypeCvt) \ - cb(IndexingRemapForward) \ - cb(IndexingRemapBackward) \ - cb(ChecksumForward) \ - cb(IndexingOneHotForward) \ - cb(IndexingSetOneHotForward) \ - cb(IndexingMultiAxisVec) \ - cb(IndexingSetMultiAxisVec) \ - cb(IndexingIncrMultiAxisVec) \ - cb(MeshIndexing) \ - cb(IncrMeshIndexing) \ - cb(SetMeshIndexing) \ - cb(BatchedMeshIndexing) \ - cb(BatchedIncrMeshIndexing) \ - cb(BatchedSetMeshIndexing) \ - cb(Linspace) \ - cb(Eye) \ - cb(SleepForward) \ - cb(UniformRNG) \ - cb(GaussianRNG) \ - cb(GammaRNG) \ - cb(BetaRNG) \ - cb(PoissonRNG) \ - cb(PermutationRNG) \ - cb(ShuffleRNGForward) \ - cb(ShuffleRNGBackward) \ - 
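The get_helper_opr hunk in handle_impl.h above reformats a lazily-created, shared helper-operator cache: an unlocked check, then creation under MEGDNN_LOCK_GUARD with a second check. A minimal sketch of that double-checked pattern follows, with illustrative names only; std::atomic is added here so the unlocked read is well-defined, whereas the real code relies on megdnn's own lock guard and midout markers.

#include <atomic>
#include <memory>
#include <mutex>

struct MatrixMul {  // stand-in for a megdnn operator
    void exec() {}
};

class Handle {
public:
    MatrixMul* matmul_helper() {
        MatrixMul* p = m_matmul.load(std::memory_order_acquire);
        if (!p) {                                  // fast path missed
            std::lock_guard<std::mutex> lock(m_mtx);
            p = m_matmul.load(std::memory_order_relaxed);
            if (!p) {                              // re-check under the lock
                m_owner.reset(new MatrixMul{});    // create once, keep ownership
                p = m_owner.get();
                m_matmul.store(p, std::memory_order_release);
            }
        }
        return p;                                  // shared by all callers
    }

private:
    std::atomic<MatrixMul*> m_matmul{nullptr};
    std::unique_ptr<MatrixMul> m_owner;
    std::mutex m_mtx;
};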
cb(SeparableConvForward) \ - cb(SeparableFilterForward) \ - cb(BNForward) \ - cb(BNBackward) \ - cb(GroupLocalForward) \ - cb(GroupLocalBackwardData) \ - cb(GroupLocalBackwardFilter) \ - cb(Flip) \ - cb(Rotate) \ - cb(ROICopy) \ - cb(CvtColor) \ - cb(WarpAffine) \ - cb(GaussianBlur) \ - cb(Resize) \ - cb(ResizeBackward) \ - cb(ParamPackConcat) \ - cb(MaxTensorDiff) \ - cb(MaskConvForward) \ - cb(MaskPropagate) \ - cb(Convolution3DForward) \ - cb(Convolution3DBackwardData) \ - cb(Convolution3DBackwardFilter) \ - cb(DeformableConvForward) \ - cb(DeformableConvBackwardFilter) \ - cb(DeformableConvBackwardData) \ - cb(DeformablePSROIPoolingForward) \ - cb(DeformablePSROIPoolingBackward) \ - cb(RelayoutFormat) \ - cb(TopK) \ - cb(PowC) \ - cb(LocalShareForward) \ - cb(LocalShareBackwardData) \ - cb(LocalShareBackwardFilter) \ - cb(ROIAlignForward) \ - cb(ROIAlignBackward) \ - cb(CorrelationForward) \ - cb(CorrelationBackwardData1) \ - cb(CorrelationBackwardData2) \ - cb(BatchConvBiasForward) \ - cb(Remap) \ - cb(RemapBackwardData) \ - cb(RemapBackwardMat) \ - cb(AdaptivePoolingForward) \ - cb(AdaptivePoolingBackward) \ - cb(DctChannelSelectForward) \ - cb(FakeQuantForward) \ - cb(FakeQuantBackward) \ - cb(TQTForward) \ - cb(TQTBackward) \ - cb(CheckNonFinite) \ - cb(LSQForward) \ - cb(LSQBackward) \ - cb(Fill) \ - cb(PaddingForward) \ - cb(PaddingBackward) +#define MEGDNN_FOREACH_OPR_CLASS(cb) \ + cb(ConvolutionForward) cb(ConvolutionBackwardData) cb(ConvolutionBackwardFilter) cb( \ + ConvPoolingForward) cb(ConvBiasForward) cb(Images2NeibsForward) cb(Images2NeibsBackward) \ + cb(SlidingWindowTransposeForward) cb(SlidingWindowTransposeBackward) cb( \ + ElemwiseForward) cb(ElemwiseMultiType) cb(AddUpdateForward) \ + cb(RelayoutForward) cb(PoolingForward) cb(PoolingBackward) cb( \ + LocalForward) cb(LocalBackwardData) cb(LocalBackwardFilter) \ + cb(LRNForward) cb(LRNBackward) cb(ROIPoolingForward) cb( \ + ROIPoolingBackward) cb(WarpPerspectiveForward) \ + cb(WarpPerspectiveBackwardData) cb( \ + WarpPerspectiveBackwardMat) cb(DotForward) \ + cb(MatrixInverse) cb(MatrixMulForward) cb( \ + BatchedMatrixMulForward) \ + cb(SVDForward) cb( \ + ReduceForward) cb(CondTake) \ + cb(CumsumForward) cb( \ + ArgmaxForward) \ + cb(ArgminForward) \ + cb(TransposeForward) \ + cb(ConcatForward) \ + cb(SplitForward) \ + cb(TileForward) \ + cb(TileBackward) \ + cb(RepeatForward) \ + cb(RepeatBackward) \ + cb(ArgsortForward) \ + cb(ArgsortBackward) \ + cb(TypeCvt) \ + cb(IndexingRemapForward) \ + cb(IndexingRemapBackward) \ + cb(ChecksumForward) cb(IndexingOneHotForward) cb(IndexingSetOneHotForward) cb(IndexingMultiAxisVec) cb(IndexingSetMultiAxisVec) cb(IndexingIncrMultiAxisVec) \ + cb( \ + MeshIndexing) cb(IncrMeshIndexing) cb(SetMeshIndexing) cb(BatchedMeshIndexing) cb(BatchedIncrMeshIndexing) cb(BatchedSetMeshIndexing) cb(Linspace) cb(Eye) cb(SleepForward) \ + cb(UniformRNG) cb(GaussianRNG) cb( \ + GammaRNG) \ + cb(BetaRNG) cb(PoissonRNG) cb(PermutationRNG) cb(ShuffleRNGForward) cb(ShuffleRNGBackward) cb(SeparableConvForward) cb( \ + SeparableFilterForward) \ + cb( \ + BNForward) cb(BNBackward) cb(GroupLocalForward) cb(GroupLocalBackwardData) \ + cb(GroupLocalBackwardFilter) \ + cb(Flip) cb( \ + Rotate) \ + cb( \ + ROICopy) cb(CvtColor) cb(WarpAffine) cb(GaussianBlur) cb(Resize) cb(ResizeBackward) \ + cb(ParamPackConcat) cb(MaxTensorDiff) cb(MaskConvForward) cb( \ + MaskPropagate) \ + cb(Convolution3DForward) \ + cb(Convolution3DBackwardData) cb(Convolution3DBackwardFilter) cb(DeformableConvForward) cb( \ + 
DeformableConvBackwardFilter) \ + cb( \ + DeformableConvBackwardData) cb(DeformablePSROIPoolingForward) cb(DeformablePSROIPoolingBackward) cb(RelayoutFormat) cb(TopK) \ + cb(PowC) cb(LocalShareForward) cb( \ + LocalShareBackwardData) cb(LocalShareBackwardFilter) \ + cb( \ + ROIAlignForward) cb(ROIAlignBackward) cb(CorrelationForward) cb(CorrelationBackwardData1) cb(CorrelationBackwardData2) cb(BatchConvBiasForward) cb(Remap) cb(RemapBackwardData) cb(RemapBackwardMat) cb(AdaptivePoolingForward) cb(AdaptivePoolingBackward) \ + cb(DctChannelSelectForward) cb(FakeQuantForward) cb(FakeQuantBackward) \ + cb(TQTForward) cb( \ + TQTBackward) \ + cb(CheckNonFinite) \ + cb(LSQForward) cb( \ + LSQBackward) \ + cb(Fill) cb( \ + PaddingForward) \ + cb(PaddingBackward) /*! * \brief specialize HandleImpl::create_operator for a single opr type; diff --git a/dnn/src/common/hash_ct.h b/dnn/src/common/hash_ct.h index 679686fa..36e0effe 100644 --- a/dnn/src/common/hash_ct.h +++ b/dnn/src/common/hash_ct.h @@ -67,8 +67,7 @@ private: static constexpr uint64_t rotl(uint64_t x, int r) { return ((x << r) | (x >> (64 - r))); } - static constexpr uint64_t mix1(const uint64_t h, const uint64_t prime, - int rshift) { + static constexpr uint64_t mix1(const uint64_t h, const uint64_t prime, int rshift) { return (h ^ (h >> rshift)) * prime; } static constexpr uint64_t mix2(const uint64_t p, const uint64_t v = 0) { @@ -80,32 +79,24 @@ private: #ifdef XXH64_BIG_ENDIAN static constexpr uint32_t endian32(const char* v) { return uint32_t(uint8_t(v[3])) | (uint32_t(uint8_t(v[2])) << 8) | - (uint32_t(uint8_t(v[1])) << 16) | - (uint32_t(uint8_t(v[0])) << 24); + (uint32_t(uint8_t(v[1])) << 16) | (uint32_t(uint8_t(v[0])) << 24); } static constexpr uint64_t endian64(const char* v) { return uint64_t(uint8_t(v[7])) | (uint64_t(uint8_t(v[6])) << 8) | - (uint64_t(uint8_t(v[5])) << 16) | - (uint64_t(uint8_t(v[4])) << 24) | - (uint64_t(uint8_t(v[3])) << 32) | - (uint64_t(uint8_t(v[2])) << 40) | - (uint64_t(uint8_t(v[1])) << 48) | - (uint64_t(uint8_t(v[0])) << 56); + (uint64_t(uint8_t(v[5])) << 16) | (uint64_t(uint8_t(v[4])) << 24) | + (uint64_t(uint8_t(v[3])) << 32) | (uint64_t(uint8_t(v[2])) << 40) | + (uint64_t(uint8_t(v[1])) << 48) | (uint64_t(uint8_t(v[0])) << 56); } #else static constexpr uint32_t endian32(const char* v) { return uint32_t(uint8_t(v[0])) | (uint32_t(uint8_t(v[1])) << 8) | - (uint32_t(uint8_t(v[2])) << 16) | - (uint32_t(uint8_t(v[3])) << 24); + (uint32_t(uint8_t(v[2])) << 16) | (uint32_t(uint8_t(v[3])) << 24); } static constexpr uint64_t endian64(const char* v) { return uint64_t(uint8_t(v[0])) | (uint64_t(uint8_t(v[1])) << 8) | - (uint64_t(uint8_t(v[2])) << 16) | - (uint64_t(uint8_t(v[3])) << 24) | - (uint64_t(uint8_t(v[4])) << 32) | - (uint64_t(uint8_t(v[5])) << 40) | - (uint64_t(uint8_t(v[6])) << 48) | - (uint64_t(uint8_t(v[7])) << 56); + (uint64_t(uint8_t(v[2])) << 16) | (uint64_t(uint8_t(v[3])) << 24) | + (uint64_t(uint8_t(v[4])) << 32) | (uint64_t(uint8_t(v[5])) << 40) | + (uint64_t(uint8_t(v[6])) << 48) | (uint64_t(uint8_t(v[7])) << 56); } #endif static constexpr uint64_t fetch64(const char* p, const uint64_t v = 0) { @@ -114,9 +105,7 @@ private: static constexpr uint64_t fetch32(const char* p) { return uint64_t(endian32(p)) * PRIME1; } - static constexpr uint64_t fetch8(const char* p) { - return uint8_t(*p) * PRIME5; - } + static constexpr uint64_t fetch8(const char* p) { return uint8_t(*p) * PRIME5; } // clang-format off static constexpr uint64_t finalize (const uint64_t h, const char *p, uint64_t len) { diff 
--git a/dnn/src/common/heuristic_cache.cpp b/dnn/src/common/heuristic_cache.cpp index 8ca1a593..0d6296bf 100644 --- a/dnn/src/common/heuristic_cache.cpp +++ b/dnn/src/common/heuristic_cache.cpp @@ -72,8 +72,9 @@ HeuristicCache::KeyStorage HeuristicCache::Key::build_key_storage() const { cuda_rt /= 1000; auto&& handle = static_cast(m_handle); auto&& prop = handle->device_prop(); - ctg.append(ssprintf(";dev=%s;cap=%d.%d;runtime=%d;", - prop.name, prop.major, prop.minor, cuda_rt)); + ctg.append(ssprintf( + ";dev=%s;cap=%d.%d;runtime=%d;", prop.name, prop.major, prop.minor, + cuda_rt)); break; } #endif @@ -84,8 +85,9 @@ HeuristicCache::KeyStorage HeuristicCache::Key::build_key_storage() const { int drv = -1, hip_rt = -1; hip_check(hipDriverGetVersion(&drv)); hip_check(hipRuntimeGetVersion(&hip_rt)); - ctg.append(ssprintf(";dev=%s;cap=%d.%d,drv=%d;runtime=%d;", - prop.name, prop.major, prop.minor, drv, hip_rt)); + ctg.append(ssprintf( + ";dev=%s;cap=%d.%d,drv=%d;runtime=%d;", prop.name, prop.major, + prop.minor, drv, hip_rt)); break; } #endif @@ -103,10 +105,9 @@ HeuristicCache::KeyStorage HeuristicCache::Key::build_key_storage() const { case Handle::HandleType::ARMV7: #endif { - size_t nr_threads = - static_cast(m_handle) - ->megcore_dispatcher() - ->nr_threads(); + size_t nr_threads = static_cast(m_handle) + ->megcore_dispatcher() + ->nr_threads(); ctg.append(";"); ctg.append(std::to_string(nr_threads)); ctg.append(";"); diff --git a/dnn/src/common/images2neibs.cpp b/dnn/src/common/images2neibs.cpp index 24dff13c..66cd3e3a 100644 --- a/dnn/src/common/images2neibs.cpp +++ b/dnn/src/common/images2neibs.cpp @@ -14,9 +14,7 @@ namespace megdnn { -void Images2NeibsBase::deduce_layout_fwd(const TensorLayout &src, - TensorLayout &dst) -{ +void Images2NeibsBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + "pad_h=" + std::to_string(param().pad_h) + ", " + @@ -42,42 +40,37 @@ void Images2NeibsBase::deduce_layout_fwd(const TensorLayout &src, size_t ww = this->param().window_w; size_t oh, ow; - infer_conv_shape2d(ih, iw, wh+(wh-1)*(dh-1), ww+(ww-1)*(dw-1), sh, sw, ph, pw, oh, ow); + infer_conv_shape2d( + ih, iw, wh + (wh - 1) * (dh - 1), ww + (ww - 1) * (dw - 1), sh, sw, ph, pw, + oh, ow); dst = TensorLayout(TensorShape({n, ic, oh, ow, wh, ww}), src.dtype); } -void Images2NeibsBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ +void Images2NeibsBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); } -void Images2NeibsForward::deduce_layout(const TensorLayout &src, - TensorLayout &dst) -{ +void Images2NeibsForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void Images2NeibsForward::check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void Images2NeibsForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void Images2NeibsBackward::check_exec(const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void Images2NeibsBackward::check_exec( + const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, 
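The Images2Neibs hunk above dilates the window to wh + (wh - 1) * (dh - 1) before calling infer_conv_shape2d. A small worked sketch of that shape arithmetic, assuming the usual floor((i + 2p - k) / s) + 1 rule (the helper's body is not part of this diff); the (n, ic, oh, ow, wh, ww) output layout comes straight from deduce_layout_fwd.

#include <cassert>
#include <cstdio>

static size_t conv_out_dim(size_t i, size_t k, size_t stride, size_t pad) {
    assert(i + 2 * pad >= k && "window must fit into the padded input");
    return (i + 2 * pad - k) / stride + 1;
}

int main() {
    size_t ih = 32, iw = 32, wh = 3, ww = 3, dh = 2, dw = 2;
    size_t sh = 1, sw = 1, ph = 0, pw = 0;
    size_t eff_h = wh + (wh - 1) * (dh - 1);      // 3 -> 5 with dilation 2
    size_t eff_w = ww + (ww - 1) * (dw - 1);
    size_t oh = conv_out_dim(ih, eff_h, sh, ph);  // 28
    size_t ow = conv_out_dim(iw, eff_w, sw, pw);  // 28
    std::printf("dst = (N, IC, %zu, %zu, %zu, %zu)\n", oh, ow, wh, ww);
}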
diff); auto required_workspace_in_bytes = get_workspace_in_bytes(grad, diff); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/indexing_multi_axis_vec.cpp b/dnn/src/common/indexing_multi_axis_vec.cpp index e23aec04..a4cc8b0b 100644 --- a/dnn/src/common/indexing_multi_axis_vec.cpp +++ b/dnn/src/common/indexing_multi_axis_vec.cpp @@ -15,28 +15,28 @@ using namespace megdnn; namespace { - size_t get_index_size_for_workspace( - const TensorShape &shp, const size_t *axes, size_t nr_axes) { - size_t idx_axis = axes[0]; - megdnn_assert(shp.ndim && nr_axes); - for (size_t i = 1; i < nr_axes; ++ i) { - megdnn_assert(axes[i] > axes[i - 1]); - if (axes[i] != axes[i - 1] + 1) { - idx_axis = 0; - break; - } +size_t get_index_size_for_workspace( + const TensorShape& shp, const size_t* axes, size_t nr_axes) { + size_t idx_axis = axes[0]; + megdnn_assert(shp.ndim && nr_axes); + for (size_t i = 1; i < nr_axes; ++i) { + megdnn_assert(axes[i] > axes[i - 1]); + if (axes[i] != axes[i - 1] + 1) { + idx_axis = 0; + break; } - megdnn_assert(shp.ndim > idx_axis, - "index on the %zuth axis; but shape is %s", - idx_axis, shp.to_string().c_str()); - return shp.shape[idx_axis]; } -} // anonymous namespace + megdnn_assert( + shp.ndim > idx_axis, "index on the %zuth axis; but shape is %s", idx_axis, + shp.to_string().c_str()); + return shp.shape[idx_axis]; +} +} // anonymous namespace -IndexingMultiAxisVecBase::IndexDescLayoutOnly -IndexingMultiAxisVecBase::extract_index_layout(const IndexDesc &index) { +IndexingMultiAxisVecBase::IndexDescLayoutOnly IndexingMultiAxisVecBase:: + extract_index_layout(const IndexDesc& index) { IndexDescLayoutOnly ret(index.size()); - for (size_t i = 0; i < index.size(); ++ i) { + for (size_t i = 0; i < index.size(); ++i) { ret[i].layout = index[i].vec.layout; ret[i].axis = index[i].axis; } @@ -44,16 +44,14 @@ IndexingMultiAxisVecBase::extract_index_layout(const IndexDesc &index) { } size_t IndexingMultiAxisVecBase::deduce_layout_fwd( - const TensorLayout &data, - const IndexDescLayoutOnly &index, - TensorLayout &dst) { + const TensorLayout& data, const IndexDescLayoutOnly& index, TensorLayout& dst) { megdnn_assert(!index.empty()); megdnn_assert(data.ndim >= index.size()); dst.ndim = data.ndim - index.size() + 1; dst.shape[0] = 1; dst.dtype = data.dtype; - auto brdcast = [&](const TensorLayout &ly) { + auto brdcast = [&](const TensorLayout& ly) { if (ly.ndim != 1) return false; if (dst.shape[0] == ly.shape[0]) @@ -67,24 +65,26 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( size_t dst_axis = 1; ptrdiff_t prev_axis = -1; - for (size_t axis = 0; axis < index.size(); ++ axis) { - auto &&idx = index[axis]; - megdnn_assert(idx.layout.dtype == dtype::Int32(), - "invalid index dtype: %s", idx.layout.dtype.name()); - megdnn_assert(idx.axis < data.ndim && - static_cast(idx.axis) > prev_axis, + for (size_t axis = 0; axis < index.size(); ++axis) { + auto&& idx = index[axis]; + megdnn_assert( + idx.layout.dtype == dtype::Int32(), "invalid index dtype: %s", + idx.layout.dtype.name()); + megdnn_assert( + idx.axis(idx.axis)> prev_axis, "index %zu requests invalid axis %zu", axis, idx.axis); auto brd_succ = brdcast(idx.layout); - megdnn_assert(brd_succ, "invalid layout at index %zu: %s", - axis, idx.layout.to_string().c_str()); + megdnn_assert( + brd_succ, "invalid layout at index %zu: %s", axis, + idx.layout.to_string().c_str()); - for (size_t i = prev_axis + 1; i < idx.axis; ++ i) 
{ - dst.shape[dst_axis ++] = data.shape[i]; + for (size_t i = prev_axis + 1; i < idx.axis; ++i) { + dst.shape[dst_axis++] = data.shape[i]; } prev_axis = idx.axis; } - for (size_t i = prev_axis + 1; i < data.ndim; ++ i) { - dst.shape[dst_axis ++] = data.shape[i]; + for (size_t i = prev_axis + 1; i < data.ndim; ++i) { + dst.shape[dst_axis++] = data.shape[i]; } megdnn_assert(dst_axis == dst.ndim); @@ -92,7 +92,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( { // fix idx_axis if index contains consecutive axes bool contig_idx = true; - for (size_t i = 1; i < index.size(); ++ i) { + for (size_t i = 1; i < index.size(); ++i) { if (index[i].axis != index[i - 1].axis + 1) { contig_idx = false; break; @@ -101,7 +101,7 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( if (contig_idx) { auto shp0 = dst.shape[0]; idx_axis = index[0].axis; - for (size_t i = 0; i < idx_axis; ++ i) { + for (size_t i = 0; i < idx_axis; ++i) { dst.shape[i] = dst.shape[i + 1]; } dst.shape[idx_axis] = shp0; @@ -113,25 +113,23 @@ size_t IndexingMultiAxisVecBase::deduce_layout_fwd( } size_t IndexingMultiAxisVecBase::get_nonindex_axes( - size_t src_ndim, const IndexDesc &index, size_t *out) { + size_t src_ndim, const IndexDesc& index, size_t* out) { auto iter = index.begin(); size_t nr = 0; - for (size_t i = 0; i < src_ndim; ++ i) { + for (size_t i = 0; i < src_ndim; ++i) { if (iter != index.end() && i == iter->axis) { - ++ iter; + ++iter; } else { - out[nr ++] = i; + out[nr++] = i; } } megdnn_assert(nr + index.size() == src_ndim && iter == index.end()); return nr; } -IndexingMultiAxisVecBase::ExecInfo -IndexingMultiAxisVecBase::check_exec_noworkspace( - const TensorLayout &data, const TensorLayout &value, - const IndexDesc &index, IndexDescLayoutOnly &index_layout) { - +IndexingMultiAxisVecBase::ExecInfo IndexingMultiAxisVecBase::check_exec_noworkspace( + const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, + IndexDescLayoutOnly& index_layout) { ExecInfo ret; index_layout = extract_index_layout(index); TensorLayout value_expect; @@ -139,28 +137,29 @@ IndexingMultiAxisVecBase::check_exec_noworkspace( megdnn_assert_eq_shape(value_expect, value); auto value_contig = value.collapse_contiguous(); - megdnn_assert(value_contig.ndim == 1, - "value layout must be 1-dim contiguous; got %s", + megdnn_assert( + value_contig.ndim == 1, "value layout must be 1-dim contiguous; got %s", value.to_string().c_str()); ret.value_stride = value_contig.stride[0]; return ret; } -std::pair -IndexingMultiAxisVecBase::get_value_iter_optimized_layout( - const TensorLayout &data, const TensorLayout &value, - const IndexDesc &index, size_t idx_axis) { +std::pair IndexingMultiAxisVecBase:: + get_value_iter_optimized_layout( + const TensorLayout& data, const TensorLayout& value, + const IndexDesc& index, size_t idx_axis) { size_t data_axes[TensorLayout::MAX_NDIM], - nr_axes = get_nonindex_axes(data.ndim, index, data_axes); + nr_axes = get_nonindex_axes(data.ndim, index, data_axes); - megdnn_assert(nr_axes == value.ndim - 1 && idx_axis < value.ndim && + megdnn_assert( + nr_axes == value.ndim - 1 && idx_axis < value.ndim && nr_axes + index.size() == data.ndim); TensorLayout ret; if (idx_axis) { ret.ndim = idx_axis; - for (size_t i = 0; i < idx_axis; ++ i) { + for (size_t i = 0; i < idx_axis; ++i) { ret.shape[i] = data.shape[data_axes[i]]; ret.stride[i] = data.stride[data_axes[i]]; } @@ -169,20 +168,20 @@ IndexingMultiAxisVecBase::get_value_iter_optimized_layout( ret.shape[ret.ndim] = value.shape[idx_axis]; 
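As a quick illustration of the shape rule that IndexingMultiAxisVecBase::deduce_layout_fwd (reformatted above) implements: the indexed axes collapse into a single axis whose length is the broadcast index length, placed at the first indexed axis when the indexed axes are consecutive and at axis 0 otherwise. A standalone sketch, with std::vector shapes standing in for TensorLayout:

#include <cstdio>
#include <vector>

// data shape, indexed axes in ascending order, and the broadcast length of
// the index vectors -> resulting dst shape.
static std::vector<size_t> indexed_dst_shape(
        const std::vector<size_t>& data, const std::vector<size_t>& axes,
        size_t index_len) {
    bool contig = true;
    for (size_t i = 1; i < axes.size(); ++i)
        if (axes[i] != axes[i - 1] + 1)
            contig = false;
    size_t idx_axis = contig ? axes[0] : 0;  // same rule as deduce_layout_fwd

    std::vector<size_t> dst;  // non-indexed axes, kept in order
    for (size_t i = 0, j = 0; i < data.size(); ++i) {
        if (j < axes.size() && i == axes[j])
            ++j;  // this axis is consumed by the index
        else
            dst.push_back(data[i]);
    }
    dst.insert(dst.begin() + idx_axis, index_len);
    return dst;
}

int main() {
    // data (2, 3, 4, 5), consecutive axes {1, 2}, 7 indices -> (2, 7, 5);
    // non-consecutive axes {0, 2} would give (7, 3, 5).
    for (size_t d : indexed_dst_shape({2, 3, 4, 5}, {1, 2}, 7))
        std::printf("%zu ", d);
    std::printf("\n");
}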
ret.stride[ret.ndim] = 0; size_t ret_idx_axis = ret.ndim; - ++ ret.ndim; + ++ret.ndim; if (idx_axis < nr_axes) { TensorLayout tail; tail.ndim = nr_axes - idx_axis; - for (size_t i = idx_axis; i < nr_axes; ++ i) { + for (size_t i = idx_axis; i < nr_axes; ++i) { tail.shape[i - idx_axis] = data.shape[data_axes[i]]; tail.stride[i - idx_axis] = data.stride[data_axes[i]]; } tail = tail.collapse_contiguous(); - for (size_t i = 0; i < tail.ndim; ++ i) { + for (size_t i = 0; i < tail.ndim; ++i) { ret.shape[ret.ndim] = tail.shape[i]; ret.stride[ret.ndim] = tail.stride[i]; - ++ ret.ndim; + ++ret.ndim; } } @@ -190,38 +189,36 @@ IndexingMultiAxisVecBase::get_value_iter_optimized_layout( } size_t IndexingMultiAxisVec::get_workspace_in_bytes( - const TensorShape &dst, const size_t *axes, size_t nr_axes) { - return get_workspace_in_bytes( - get_index_size_for_workspace(dst, axes, nr_axes)); + const TensorShape& dst, const size_t* axes, size_t nr_axes) { + return get_workspace_in_bytes(get_index_size_for_workspace(dst, axes, nr_axes)); } IndexingMultiAxisVec::ExecInfo IndexingMultiAxisVec::check_exec( - const TensorLayout &src, const IndexDesc &index, - const TensorLayout &dst, size_t workspace_in_bytes) { + const TensorLayout& src, const IndexDesc& index, const TensorLayout& dst, + size_t workspace_in_bytes) { IndexDescLayoutOnly index_layout; auto ret = check_exec_noworkspace(src, dst, index, index_layout); - megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes( - dst.shape[ret.idx_axis])); + megdnn_assert( + workspace_in_bytes >= get_workspace_in_bytes(dst.shape[ret.idx_axis])); megdnn_assert(ret.value_stride, "dst must be non-overlapping"); return ret; } size_t IndexingModifyMultiAxisVecBase::get_workspace_in_bytes( - const TensorShape &value, const size_t *axes, size_t nr_axes) { - return get_workspace_in_bytes( - get_index_size_for_workspace(value, axes, nr_axes)); + const TensorShape& value, const size_t* axes, size_t nr_axes) { + return get_workspace_in_bytes(get_index_size_for_workspace(value, axes, nr_axes)); } -IndexingModifyMultiAxisVecBase::ExecInfo -IndexingModifyMultiAxisVecBase::check_exec( - const TensorLayout &data, const TensorLayout &value, - const IndexDesc &index, size_t workspace_in_bytes) { - megdnn_assert(data.is_non_overlapping_strong(), - "data layout should not overlap: %s", data.to_string().c_str()); +IndexingModifyMultiAxisVecBase::ExecInfo IndexingModifyMultiAxisVecBase::check_exec( + const TensorLayout& data, const TensorLayout& value, const IndexDesc& index, + size_t workspace_in_bytes) { + megdnn_assert( + data.is_non_overlapping_strong(), "data layout should not overlap: %s", + data.to_string().c_str()); IndexDescLayoutOnly index_layout; auto ret = check_exec_noworkspace(data, value, index, index_layout); - megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes( - value.shape[ret.idx_axis])); + megdnn_assert( + workspace_in_bytes >= get_workspace_in_bytes(value.shape[ret.idx_axis])); return ret; } diff --git a/dnn/src/common/indexing_multi_axis_vec_kdef.h b/dnn/src/common/indexing_multi_axis_vec_kdef.h index 916ef2e7..fc2948b8 100644 --- a/dnn/src/common/indexing_multi_axis_vec_kdef.h +++ b/dnn/src/common/indexing_multi_axis_vec_kdef.h @@ -13,35 +13,35 @@ #if MEGDNN_CC_HOST && !defined(__device__) #define __device__ -#define def_device 1 +#define def_device 1 #endif namespace megdnn { namespace indexing_multi_axis_vec_kdef { struct OprFwd { - template - __device__ static void apply(ctype data, ctype &value) { + template + __device__ static void apply(ctype data, 
ctype& value) { value = data; } }; struct OprSet { - template - __device__ static void apply(ctype &data, ctype value) { + template + __device__ static void apply(ctype& data, ctype value) { data = value; } }; struct OprIncr { - template - __device__ static void apply(ctype &data, ctype value) { + template + __device__ static void apply(ctype& data, ctype value) { data += value; } }; -} -} +} // namespace indexing_multi_axis_vec_kdef +} // namespace megdnn #if def_device #undef __device__ diff --git a/dnn/src/common/indexing_one_hot.cpp b/dnn/src/common/indexing_one_hot.cpp index 6e7ee89f..d2ac201d 100644 --- a/dnn/src/common/indexing_one_hot.cpp +++ b/dnn/src/common/indexing_one_hot.cpp @@ -16,12 +16,11 @@ using namespace megdnn; void IndexingOneHotBase::deduce_layout_fwd( - const TensorLayout &src, const TensorLayout &index, - TensorLayout &dst) { + const TensorLayout& src, const TensorLayout& index, TensorLayout& dst) { megdnn_assert( m_param.axis < static_cast(src.ndim) && src.ndim >= 2, - "IndexingOneHot on axis %u, but input has only %zu dims", - m_param.axis, src.ndim); + "IndexingOneHot on axis %u, but input has only %zu dims", m_param.axis, + src.ndim); MEGDNN_MARK_USED_VAR(index); dst = src; dst.shape[m_param.axis] = 1; @@ -29,8 +28,7 @@ void IndexingOneHotBase::deduce_layout_fwd( } void IndexingOneHotBase::check_layout_fwd( - const TensorLayout &src, const TensorLayout &index, - const TensorLayout &dst) { + const TensorLayout& src, const TensorLayout& index, const TensorLayout& dst) { auto errmsg = [&]() -> std::string { return ssprintf( "bad layout for IndexingOneHot: " @@ -41,43 +39,44 @@ void IndexingOneHotBase::check_layout_fwd( MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert_eq_dtype(src, dst); megdnn_assert(index.dtype == dtype::Int32(), "%s", errmsg().c_str()); - megdnn_assert(src.is_contiguous() && index.is_contiguous() && - dst.is_contiguous(), "%s", errmsg().c_str()); + megdnn_assert( + src.is_contiguous() && index.is_contiguous() && dst.is_contiguous(), "%s", + errmsg().c_str()); // check index TensorShape idx_shp{src}; - -- idx_shp.ndim; + --idx_shp.ndim; megdnn_assert(m_param.axis >= 0, "%s", errmsg().c_str()); for (auto i = static_cast(m_param.axis); i < idx_shp.ndim; ++i) idx_shp[i] = idx_shp[i + 1]; - megdnn_assert(index.eq_shape(idx_shp), "%s idx_shp=%s", errmsg().c_str(), idx_shp.to_string().c_str()); + megdnn_assert( + index.eq_shape(idx_shp), "%s idx_shp=%s", errmsg().c_str(), + idx_shp.to_string().c_str()); // check dst megdnn_assert( - m_param.axis < static_cast(src.ndim) && src.ndim >= 2, - "%s", errmsg().c_str()); + m_param.axis < static_cast(src.ndim) && src.ndim >= 2, "%s", + errmsg().c_str()); TensorShape dst_shp{src}; dst_shp.shape[m_param.axis] = 1; - megdnn_assert(dst.eq_shape(dst_shp), "%s dst_shp=%s", errmsg().c_str(), dst_shp.to_string().c_str()); + megdnn_assert( + dst.eq_shape(dst_shp), "%s dst_shp=%s", errmsg().c_str(), + dst_shp.to_string().c_str()); } -void IndexingOneHotForward::check_exec(const TensorLayout &src, - const TensorLayout &index, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void IndexingOneHotForward::check_exec( + const TensorLayout& src, const TensorLayout& index, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, index, dst); - auto required_workspace_in_bytes = get_workspace_in_bytes( - src, index, dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, index, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void 
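The indexing_one_hot.cpp hunk above only touches formatting; for orientation, here is a tiny reference of what IndexingOneHotForward computes in the 2-D, axis = 1 case. The shapes follow the checks above (index has the source shape with the indexed axis removed, dst has extent 1 on that axis); the gather loop itself is illustrative, not megdnn code.

#include <cstdio>
#include <vector>

int main() {
    const size_t m = 2, n = 4;                 // src is (m, n), axis = 1
    std::vector<float> src = {0, 1, 2, 3,      // row 0
                              4, 5, 6, 7};     // row 1
    std::vector<int> index = {2, 1};           // shape (m,), dtype Int32
    std::vector<float> dst(m);                 // logically (m, 1)

    for (size_t i = 0; i < m; ++i)
        dst[i] = src[i * n + index[i]];        // dst[i][0] = src[i][index[i]]

    std::printf("%g %g\n", dst[0], dst[1]);    // prints 2 5
}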
IndexingSetOneHotForward::check_exec(const TensorLayout &data, - const TensorLayout &index, const TensorLayout &sub, - size_t workspace_in_bytes) -{ +void IndexingSetOneHotForward::check_exec( + const TensorLayout& data, const TensorLayout& index, const TensorLayout& sub, + size_t workspace_in_bytes) { check_layout_fwd(data, index, sub); - auto required_workspace_in_bytes = get_workspace_in_bytes( - data, index, sub); + auto required_workspace_in_bytes = get_workspace_in_bytes(data, index, sub); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } diff --git a/dnn/src/common/linspace.cpp b/dnn/src/common/linspace.cpp index ceae94fa..13c2fd17 100644 --- a/dnn/src/common/linspace.cpp +++ b/dnn/src/common/linspace.cpp @@ -14,14 +14,13 @@ namespace megdnn { -void Linspace::check_exec(const TensorLayout &dst, size_t workspace_in_bytes) -{ +void Linspace::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) { megdnn_assert(dst.ndim == 1 && dst.shape[0] > 0); megdnn_assert_contiguous(dst); auto required_workspace_in_bytes = get_workspace_in_bytes(dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/local/local_decl.inl b/dnn/src/common/local/local_decl.inl index 5682eec6..17651f1e 100644 --- a/dnn/src/common/local/local_decl.inl +++ b/dnn/src/common/local/local_decl.inl @@ -23,12 +23,12 @@ namespace megdnn { using LocalKParam = naive::LocalForwardImpl::FloatNoncontigBatchKernParam; -void WITH_SIMD_SUFFIX(local_xcorr)( - const LocalKParam ¶m) MEGDNN_SIMD_ATTRIBUTE_TARGET; +void WITH_SIMD_SUFFIX(local_xcorr)(const LocalKParam& param) + MEGDNN_SIMD_ATTRIBUTE_TARGET; -void WITH_SIMD_SUFFIX(local_conv)( - const LocalKParam ¶m) MEGDNN_SIMD_ATTRIBUTE_TARGET; +void WITH_SIMD_SUFFIX(local_conv)(const LocalKParam& param) + MEGDNN_SIMD_ATTRIBUTE_TARGET; -} // namespace megdnn +} // namespace megdnn #include "src/common/macro_helper_epilogue.h" diff --git a/dnn/src/common/local/local_def.inl b/dnn/src/common/local/local_def.inl index 0f912194..79808b42 100644 --- a/dnn/src/common/local/local_def.inl +++ b/dnn/src/common/local/local_def.inl @@ -18,83 +18,82 @@ #include "src/common/local/local_decl.inl" -#include "src/common/utils.h" #include "src/common/macro_helper.h" +#include "src/common/utils.h" namespace { using namespace megdnn; template -void local_xcorr_tpl(const LocalKParam &kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; +void local_xcorr_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; template -void local_xcorr_tpl(const LocalKParam &kparam) -{ +void local_xcorr_tpl(const LocalKParam& kparam) { const float* src = static_cast(kparam.src); const float* filter = static_cast(kparam.filter); float* dst = static_cast(kparam.dst); float* workspace = static_cast(kparam.workspace); const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; - const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, - SW = kparam.sw; + const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, SW = kparam.sw; const ptrdiff_t INP_BS = kparam.inp_bs, OUT_BS = kparam.out_bs; - float *dst2 = workspace; + float* dst2 = workspace; const int width = MEGDNN_SIMD_WIDTH; // dst2 is (H, W, N, C) - memset(dst2, 0, sizeof(float) * OH*OW*N*OC); - float *dst2_hwnc = dst2; + memset(dst2, 0, sizeof(float) * OH * OW * N * OC); + float* dst2_hwnc = dst2; rep(oh, OH) rep(ow, OW) { - const float *src_bak = src; + const float* 
src_bak = src; rep(ic, IC) { rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) { - int ih = -PH + oh*SH + fh; - int iw = -PW + ow*SW + fw; - if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) continue; - float *dst2_bak = dst2; + int ih = -PH + oh * SH + fh; + int iw = -PW + ow * SW + fw; + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) + continue; + float* dst2_bak = dst2; rep(n, N) { - float s = src[n*INP_BS + ih*IW + iw]; - const float *filter_bak = filter; + float s = src[n * INP_BS + ih * IW + iw]; + const float* filter_bak = filter; MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s); int oc = 0; - for (; oc+4*width <= OC; oc += 4*width, filter += 4*width) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2*width); - MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); - MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2*width); - MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3*width); + for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width); + MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); + MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width); + MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1); vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2); vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 2*width, vd2); - MEGDNN_SIMD_STOREU(dst2 + oc + 3*width, vd3); + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2); + MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3); } - if (oc+2*width <= OC) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); + if (oc + 2 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - oc += 2*width; - filter += 2*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + oc += 2 * width; + filter += 2 * width; } - if (oc+1*width <= OC) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); + if (oc + 1 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * 
width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - oc += 1*width; - filter += 1*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + oc += 1 * width; + filter += 1 * width; } for (; oc < OC; ++oc, ++filter) { dst2[oc] += s * (*filter); @@ -104,72 +103,73 @@ void local_xcorr_tpl(const LocalKParam &kparam) } dst2 = dst2_bak; } - src += IH*IW; + src += IH * IW; } src = src_bak; - dst2 += N*OC; + dst2 += N * OC; } transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS); } -void local_xcorr_generic(const LocalKParam &kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; -void local_xcorr_generic(const LocalKParam &kparam) { +void local_xcorr_generic(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; +void local_xcorr_generic(const LocalKParam& kparam) { UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(kparam, float); - float *dst2 = workspace; + float* dst2 = workspace; const int width = MEGDNN_SIMD_WIDTH; // dst2 is (H, W, N, C) - memset(dst2, 0, sizeof(float) * OH*OW*N*OC); - float *dst2_hwnc = dst2; + memset(dst2, 0, sizeof(float) * OH * OW * N * OC); + float* dst2_hwnc = dst2; rep(oh, OH) rep(ow, OW) { - const float *src_bak = src; + const float* src_bak = src; rep(ic, IC) { rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) { - int ih = -PH + oh*SH + fh; - int iw = -PW + ow*SW + fw; - if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) continue; - float *dst2_bak = dst2; + int ih = -PH + oh * SH + fh; + int iw = -PW + ow * SW + fw; + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) + continue; + float* dst2_bak = dst2; rep(n, N) { - float s = src[n*INP_BS + ih*IW + iw]; - const float *filter_bak = filter; + float s = src[n * INP_BS + ih * IW + iw]; + const float* filter_bak = filter; MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s); int oc = 0; - for (; oc+4*width <= OC; oc += 4*width, filter += 4*width) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2*width); - MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); - MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2*width); - MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3*width); + for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width); + MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); + MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width); + MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1); vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2); vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 2*width, vd2); - MEGDNN_SIMD_STOREU(dst2 + oc + 3*width, vd3); + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2); + MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3); } - if (oc+2*width <= OC) { - 
MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); + if (oc + 2 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - oc += 2*width; - filter += 2*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + oc += 2 * width; + filter += 2 * width; } - if (oc+1*width <= OC) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); + if (oc + 1 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - oc += 1*width; - filter += 1*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + oc += 1 * width; + filter += 1 * width; } for (; oc < OC; ++oc, ++filter) { dst2[oc] += s * (*filter); @@ -179,84 +179,83 @@ void local_xcorr_generic(const LocalKParam &kparam) { } dst2 = dst2_bak; } - src += IH*IW; + src += IH * IW; } src = src_bak; - dst2 += N*OC; + dst2 += N * OC; } transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS); } template -void local_conv_tpl(const LocalKParam &kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; +void local_conv_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; template -void local_conv_tpl(const LocalKParam &kparam) -{ +void local_conv_tpl(const LocalKParam& kparam) { const float* src = static_cast(kparam.src); const float* filter = static_cast(kparam.filter); float* dst = static_cast(kparam.dst); float* workspace = static_cast(kparam.workspace); const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; - const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, - SW = kparam.sw; + const uint32_t PH = kparam.ph, PW = kparam.pw, SH = kparam.sh, SW = kparam.sw; const ptrdiff_t INP_BS = kparam.inp_bs, OUT_BS = kparam.out_bs; - float *dst2 = workspace; + float* dst2 = workspace; const int width = MEGDNN_SIMD_WIDTH; // dst2 is (H, W, N, C) - memset(dst2, 0, sizeof(float) * OH*OW*N*OC); - float *dst2_hwnc = dst2; + memset(dst2, 0, sizeof(float) * OH * OW * N * OC); + float* dst2_hwnc = dst2; rep(oh, OH) rep(ow, OW) { - const float *src_bak = src; + const float* src_bak = src; rep(ic, IC) { rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) { - int ih = -PH + oh*SH + (FH-fh-1); - int iw = -PW + ow*SW + (FW-fw-1); - if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) continue; - float *dst2_bak = dst2; + int ih = -PH + oh * SH + (FH - fh - 1); + int iw = -PW + ow * SW + (FW - fw - 1); + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) + continue; + float* dst2_bak = dst2; rep(n, N) { - float s = src[n*INP_BS + ih*IW + iw]; - const float *filter_bak = filter; + float s = src[n * INP_BS + ih * IW + iw]; + const float* filter_bak = filter; MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s); int oc = 0; - for (; oc+4*width <= OC; oc 
+= 4*width, filter += 4*width) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2*width); - MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); - MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2*width); - MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3*width); + for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width); + MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); + MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width); + MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1); vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2); vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 2*width, vd2); - MEGDNN_SIMD_STOREU(dst2 + oc + 3*width, vd3); + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2); + MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3); } - if (oc+2*width <= OC) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); + if (oc + 2 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - oc += 2*width; - filter += 2*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + oc += 2 * width; + filter += 2 * width; } - if (oc+1*width <= OC) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); + if (oc + 1 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - oc += 1*width; - filter += 1*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + oc += 1 * width; + filter += 1 * width; } for (; oc < OC; ++oc, ++filter) { dst2[oc] += s * (*filter); @@ -266,73 +265,74 @@ void local_conv_tpl(const LocalKParam &kparam) } dst2 = dst2_bak; } - src += IH*IW; + src += IH * IW; } src = src_bak; - dst2 += N*OC; + dst2 += N * OC; } transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS); } -void local_conv_generic(const LocalKParam &kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; -void 
local_conv_generic(const LocalKParam &kparam) { +void local_conv_generic(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; +void local_conv_generic(const LocalKParam& kparam) { UNPACK_LOCAL_FLOAT_NONCONTIG_BATCH_KERN_PARAM(kparam, float); - float *dst2 = workspace; + float* dst2 = workspace; const int width = MEGDNN_SIMD_WIDTH; // dst2 is (H, W, N, C) - memset(dst2, 0, sizeof(float) * OH*OW*N*OC); - float *dst2_hwnc = dst2; + memset(dst2, 0, sizeof(float) * OH * OW * N * OC); + float* dst2_hwnc = dst2; rep(oh, OH) rep(ow, OW) { - const float *src_bak = src; + const float* src_bak = src; rep(ic, IC) { rep(fh, FH) for (int fw = 0; fw < FW; ++fw, filter += OC) { - int ih = -PH + oh*SH + (FH-fh-1); - int iw = -PW + ow*SW + (FW-fw-1); - if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) continue; - float *dst2_bak = dst2; + int ih = -PH + oh * SH + (FH - fh - 1); + int iw = -PW + ow * SW + (FW - fw - 1); + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW) + continue; + float* dst2_bak = dst2; rep(n, N) { - float s = src[n*INP_BS + ih*IW + iw]; - const float *filter_bak = filter; + float s = src[n * INP_BS + ih * IW + iw]; + const float* filter_bak = filter; MEGDNN_SIMD_TYPE vs = MEGDNN_SIMD_SET1(s); int oc = 0; - for (; oc+4*width <= OC; oc += 4*width, filter += 4*width) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2*width); - MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); - MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2*width); - MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3*width); + for (; oc + 4 * width <= OC; oc += 4 * width, filter += 4 * width) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vf2 = MEGDNN_SIMD_LOADU(filter + 2 * width); + MEGDNN_SIMD_TYPE vf3 = MEGDNN_SIMD_LOADU(filter + 3 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); + MEGDNN_SIMD_TYPE vd2 = MEGDNN_SIMD_LOADU(dst2 + oc + 2 * width); + MEGDNN_SIMD_TYPE vd3 = MEGDNN_SIMD_LOADU(dst2 + oc + 3 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 = MEGDNN_SIMD_FMADD(vf1, vs, vd1); vd2 = MEGDNN_SIMD_FMADD(vf2, vs, vd2); vd3 = MEGDNN_SIMD_FMADD(vf3, vs, vd3); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 2*width, vd2); - MEGDNN_SIMD_STOREU(dst2 + oc + 3*width, vd3); + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + MEGDNN_SIMD_STOREU(dst2 + oc + 2 * width, vd2); + MEGDNN_SIMD_STOREU(dst2 + oc + 3 * width, vd3); } - if (oc+2*width <= OC) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); - MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1*width); + if (oc + 2 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vf1 = MEGDNN_SIMD_LOADU(filter + 1 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); + MEGDNN_SIMD_TYPE vd1 = MEGDNN_SIMD_LOADU(dst2 + oc + 1 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); vd1 
= MEGDNN_SIMD_FMADD(vf1, vs, vd1); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 1*width, vd1); - oc += 2*width; - filter += 2*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + MEGDNN_SIMD_STOREU(dst2 + oc + 1 * width, vd1); + oc += 2 * width; + filter += 2 * width; } - if (oc+1*width <= OC) { - MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0*width); - MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0*width); + if (oc + 1 * width <= OC) { + MEGDNN_SIMD_TYPE vf0 = MEGDNN_SIMD_LOADU(filter + 0 * width); + MEGDNN_SIMD_TYPE vd0 = MEGDNN_SIMD_LOADU(dst2 + oc + 0 * width); vd0 = MEGDNN_SIMD_FMADD(vf0, vs, vd0); - MEGDNN_SIMD_STOREU(dst2 + oc + 0*width, vd0); - oc += 1*width; - filter += 1*width; + MEGDNN_SIMD_STOREU(dst2 + oc + 0 * width, vd0); + oc += 1 * width; + filter += 1 * width; } for (; oc < OC; ++oc, ++filter) { dst2[oc] += s * (*filter); @@ -342,38 +342,51 @@ void local_conv_generic(const LocalKParam &kparam) { } dst2 = dst2_bak; } - src += IH*IW; + src += IH * IW; } src = src_bak; - dst2 += N*OC; + dst2 += N * OC; } transpose_knc2nsck(dst2_hwnc, dst, OH * OW, N, OC, OUT_BS); } -} // anonymous namespace +} // anonymous namespace namespace megdnn { #define FUNC_NAME CONCAT_STR(local_xcorr_, MEGDNN_SIMD_NAME) -void FUNC_NAME(const LocalKParam &kparam) { +void FUNC_NAME(const LocalKParam& kparam) { auto N = kparam.n, OC = kparam.oc; -#define DISPATCH_WITH_N_OC(N, OC) do { \ - local_xcorr_tpl(kparam); \ - return; \ -} while (0) +#define DISPATCH_WITH_N_OC(N, OC) \ + do { \ + local_xcorr_tpl(kparam); \ + return; \ + } while (0) -#define DISPATCH_WITH_N(N) \ - switch (OC) { \ - case 16: DISPATCH_WITH_N_OC(N, 16); break; \ - case 32: DISPATCH_WITH_N_OC(N, 32); break; \ - case 48: DISPATCH_WITH_N_OC(N, 48); break; \ - case 64: DISPATCH_WITH_N_OC(N, 64); break; \ +#define DISPATCH_WITH_N(N) \ + switch (OC) { \ + case 16: \ + DISPATCH_WITH_N_OC(N, 16); \ + break; \ + case 32: \ + DISPATCH_WITH_N_OC(N, 32); \ + break; \ + case 48: \ + DISPATCH_WITH_N_OC(N, 48); \ + break; \ + case 64: \ + DISPATCH_WITH_N_OC(N, 64); \ + break; \ } -#define DISPATCH() \ - switch (N) { \ - case 1: DISPATCH_WITH_N(1); break; \ - case 2: DISPATCH_WITH_N(2); break; \ +#define DISPATCH() \ + switch (N) { \ + case 1: \ + DISPATCH_WITH_N(1); \ + break; \ + case 2: \ + DISPATCH_WITH_N(2); \ + break; \ } DISPATCH(); @@ -386,28 +399,39 @@ void FUNC_NAME(const LocalKParam &kparam) { #undef FUNC_NAME - - #define FUNC_NAME CONCAT_STR(local_conv_, MEGDNN_SIMD_NAME) -void FUNC_NAME(const LocalKParam &kparam) { +void FUNC_NAME(const LocalKParam& kparam) { auto N = kparam.n, OC = kparam.oc; -#define DISPATCH_WITH_N_OC(N, OC) do { \ - local_conv_tpl(kparam); \ - return; \ -} while (0) +#define DISPATCH_WITH_N_OC(N, OC) \ + do { \ + local_conv_tpl(kparam); \ + return; \ + } while (0) -#define DISPATCH_WITH_N(N) \ - switch (OC) { \ - case 16: DISPATCH_WITH_N_OC(N, 16); break; \ - case 32: DISPATCH_WITH_N_OC(N, 32); break; \ - case 48: DISPATCH_WITH_N_OC(N, 48); break; \ - case 64: DISPATCH_WITH_N_OC(N, 64); break; \ +#define DISPATCH_WITH_N(N) \ + switch (OC) { \ + case 16: \ + DISPATCH_WITH_N_OC(N, 16); \ + break; \ + case 32: \ + DISPATCH_WITH_N_OC(N, 32); \ + break; \ + case 48: \ + DISPATCH_WITH_N_OC(N, 48); \ + break; \ + case 64: \ + DISPATCH_WITH_N_OC(N, 64); \ + break; \ } -#define DISPATCH() \ - switch (N) { \ - case 1: DISPATCH_WITH_N(1); break; \ - case 2: DISPATCH_WITH_N(2); break; \ +#define DISPATCH() \ + switch (N) { \ + case 1: \ + DISPATCH_WITH_N(1); \ + break; \ + 
case 2: \ + DISPATCH_WITH_N(2); \ + break; \ } DISPATCH(); @@ -420,6 +444,6 @@ void FUNC_NAME(const LocalKParam &kparam) { #undef FUNC_NAME -} // namespace megdnn +} // namespace megdnn #include "src/common/macro_helper_epilogue.h" diff --git a/dnn/src/common/local/opr_impl.cpp b/dnn/src/common/local/opr_impl.cpp index bad17e03..07003e90 100644 --- a/dnn/src/common/local/opr_impl.cpp +++ b/dnn/src/common/local/opr_impl.cpp @@ -14,13 +14,12 @@ namespace megdnn { -void LocalBase::deduce_layout_fwd(const TensorLayout &src, - const TensorLayout &filter, TensorLayout &dst) -{ - auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + - ", " + megdnn_layout_msg(dst) + ", " + "is_xcorr=" + - std::to_string((param().mode == Mode::CROSS_CORRELATION)) + - ", " + "pad_h=" + std::to_string(param().pad_h) + ", " + +void LocalBase::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { + auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " + + megdnn_layout_msg(dst) + ", " + "is_xcorr=" + + std::to_string((param().mode == Mode::CROSS_CORRELATION)) + ", " + + "pad_h=" + std::to_string(param().pad_h) + ", " + "pad_w=" + std::to_string(param().pad_w) + ", " + "stride_h=" + std::to_string(param().stride_h) + ", " + "stride_w=" + std::to_string(param().stride_w); @@ -35,14 +34,14 @@ void LocalBase::deduce_layout_fwd(const TensorLayout &src, megdnn_assert_contiguous(filter); megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); megdnn_assert(filter.ndim == 6_z, "%s", errmsg_c); - megdnn_assert(param().dilate_h == 1 && param().dilate_w == 1, + megdnn_assert( + param().dilate_h == 1 && param().dilate_w == 1, "dilation in local not supported"); - megdnn_assert(param().sparse == Param::Sparse::DENSE && - param().dilate_h == 1 && param().dilate_w == 1 && - src.dtype.category() == DTypeCategory::FLOAT && - dst.dtype == src.dtype && - "unsupported conv param for Local opr"); + megdnn_assert( + param().sparse == Param::Sparse::DENSE && param().dilate_h == 1 && + param().dilate_w == 1 && src.dtype.category() == DTypeCategory::FLOAT && + dst.dtype == src.dtype && "unsupported conv param for Local opr"); size_t n = src[0]; size_t ic = src[1]; @@ -61,10 +60,8 @@ void LocalBase::deduce_layout_fwd(const TensorLayout &src, dst = TensorLayout(TensorShape({n, oc, oh, ow}), src.dtype); } -void LocalBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) -{ +void LocalBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { TensorLayout dst_expected{dst.dtype}; megdnn_assert_eq_dtype(src, filter); megdnn_assert_eq_dtype(src, dst); @@ -74,49 +71,40 @@ void LocalBase::check_layout_fwd(const TensorLayout &src, megdnn_assert_eq_layout(dst_expected, dst); megdnn_assert(src.dtype == filter.dtype && src.dtype == dst.dtype); - megdnn_assert(src.dtype == dtype::Float32() || - DNN_FLOAT16_SELECT(src.dtype == dtype::Float16(), true)); + megdnn_assert( + src.dtype == dtype::Float32() || + DNN_FLOAT16_SELECT(src.dtype == dtype::Float16(), true)); } -void LocalForward::deduce_layout(const TensorLayout &src, - const TensorLayout &filter, - TensorLayout &dst) -{ +void LocalForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); } -void LocalForward::check_exec(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void 
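The DISPATCH_WITH_N_OC / DISPATCH_WITH_N / DISPATCH macros reformatted above map a runtime (N, OC) pair onto compile-time template instantiations for the common small cases. A compact sketch of the same pattern, with placeholder kernels and an assumed fall-through to the generic kernel when no case matches (the fall-through call is not visible in the hunks shown):

#include <cstdio>

template <int N, int OC>
void kern_tpl() {  // N and OC are compile-time constants here
    std::printf("templated kernel N=%d OC=%d\n", N, OC);
}

void kern_generic(int n, int oc) {  // fallback for unsupported shapes
    std::printf("generic kernel N=%d OC=%d\n", n, oc);
}

void dispatch(int n, int oc) {
#define CASE_OC(N, OC)         \
    case OC:                   \
        kern_tpl<N, OC>();     \
        return;
#define CASE_N(N)              \
    case N:                    \
        switch (oc) {          \
            CASE_OC(N, 16)     \
            CASE_OC(N, 32)     \
            CASE_OC(N, 48)     \
            CASE_OC(N, 64)     \
        }                      \
        break;
    switch (n) {
        CASE_N(1)
        CASE_N(2)
    }
#undef CASE_N
#undef CASE_OC
    kern_generic(n, oc);  // reached only when no templated case returned
}

int main() {
    dispatch(2, 48);  // hits the templated instantiation
    dispatch(3, 20);  // falls back to the generic kernel
}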
LocalForward::check_exec( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, filter, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void LocalBackwardData::check_exec(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void LocalBackwardData::check_exec( + const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { check_layout_fwd(grad, filter, diff); - auto required_workspace_in_bytes = get_workspace_in_bytes(filter, - diff, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void LocalBackwardFilter::check_exec(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void LocalBackwardFilter::check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { check_layout_fwd(src, grad, diff); - auto required_workspace_in_bytes = get_workspace_in_bytes(src, - diff, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/local_share/opr_impl.cpp b/dnn/src/common/local_share/opr_impl.cpp index 67185d12..20897235 100644 --- a/dnn/src/common/local_share/opr_impl.cpp +++ b/dnn/src/common/local_share/opr_impl.cpp @@ -13,22 +13,20 @@ namespace megdnn { -void LocalShareBase::deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& filter, - TensorLayout& dst) { +void LocalShareBase::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { using Mode = LocalShare::Param::Mode; auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter) + ", " + - megdnn_layout_msg(dst) + ", " + "is_xcorr=" + - std::to_string((param().mode == Mode::CROSS_CORRELATION)) + ", " + - "pad_h=" + std::to_string(param().pad_h) + ", " + + megdnn_layout_msg(dst) + ", " + + "is_xcorr=" + std::to_string((param().mode == Mode::CROSS_CORRELATION)) + + ", " + "pad_h=" + std::to_string(param().pad_h) + ", " + "pad_w=" + std::to_string(param().pad_w) + ", " + "stride_h=" + std::to_string(param().stride_h) + ", " + "stride_w=" + std::to_string(param().stride_w) + ", " + "dilate_h=" + std::to_string(param().dilate_h) + ", " + "dilate_w=" + std::to_string(param().dilate_w) + ", " + - "spatial_groups_h=" + std::to_string(param().spatial_groups_h) + - ", " + + "spatial_groups_h=" + std::to_string(param().spatial_groups_h) + ", " + "spatial_groups_w=" + std::to_string(param().spatial_groups_w); auto errmsg_c = errmsg.c_str(); MEGDNN_MARK_USED_VAR(errmsg_c); @@ -39,18 +37,20 @@ void LocalShareBase::deduce_layout_fwd(const TensorLayout& src, using Sparse = Param::Sparse; using Format = Param::Format; using ComputeMode = Param::ComputeMode; - megdnn_assert(param().format == Format::NCHW, - "local shared only support NCHW format"); + megdnn_assert( + param().format == Format::NCHW, "local shared only support NCHW format"); megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); megdnn_assert( (filter.ndim == 6_z && param().sparse == Sparse::DENSE) || (filter.ndim == 7_z && param().sparse 
== Sparse::GROUP), "%s", errmsg_c); - megdnn_assert(param().dilate_h == 1 && param().dilate_w == 1, - "dilated local shared is not supported"); - megdnn_assert(src.dtype == dtype::Float32() && - param().computeMode == ComputeMode::DEFAULT, - "local shared only support float32"); + megdnn_assert( + param().dilate_h == 1 && param().dilate_w == 1, + "dilated local shared is not supported"); + megdnn_assert( + src.dtype == dtype::Float32() && + param().computeMode == ComputeMode::DEFAULT, + "local shared only support float32"); size_t n = src[0], ci = src[1], hi = src[2], wi = src[3]; size_t sgh = param().spatial_groups_h, sgw = param().spatial_groups_w; @@ -60,21 +60,21 @@ void LocalShareBase::deduce_layout_fwd(const TensorLayout& src, groups = filter[0]; weights_shp_pos = 1; } - megdnn_assert(sgh == filter[weights_shp_pos] && - sgw == filter[weights_shp_pos + 1], - "spatial groups in filter tensor mismatch with those " - "provided in parameter %s", - errmsg_c); + megdnn_assert( + sgh == filter[weights_shp_pos] && sgw == filter[weights_shp_pos + 1], + "spatial groups in filter tensor mismatch with those " + "provided in parameter %s", + errmsg_c); size_t fh = filter[weights_shp_pos + 3], fw = filter[weights_shp_pos + 4], co = filter[weights_shp_pos + 5] * groups; - megdnn_assert(filter[weights_shp_pos + 2] * groups == ci, - "input channels of src and filter mismatch %s", errmsg_c); + megdnn_assert( + filter[weights_shp_pos + 2] * groups == ci, + "input channels of src and filter mismatch %s", errmsg_c); size_t sh = param().stride_h; size_t sw = param().stride_w; size_t ph = param().pad_h; size_t pw = param().pad_w; - size_t ho = infer_conv_shape(hi, fh, sh, ph), - wo = infer_conv_shape(wi, fw, sw, pw); + size_t ho = infer_conv_shape(hi, fh, sh, ph), wo = infer_conv_shape(wi, fw, sw, pw); megdnn_assert( ho % sgh == 0 && wo % sgw == 0, "height and width of output cannot be divided by spatial groups %s", @@ -82,9 +82,8 @@ void LocalShareBase::deduce_layout_fwd(const TensorLayout& src, dst = TensorLayout{{n, co, ho, wo}, src.dtype}; } -void LocalShareBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +void LocalShareBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(src, filter); megdnn_assert_eq_dtype(src, dst); @@ -94,37 +93,33 @@ void LocalShareBase::check_layout_fwd(const TensorLayout& src, megdnn_assert(src.dtype == dtype::Float32()); } -void LocalShareForward::deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - TensorLayout& dst) { +void LocalShareForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter, TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); } -void LocalShareForward::check_exec(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst, - size_t workspace_in_bytes) { +void LocalShareForward::check_exec( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, filter, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void LocalShareBackwardData::deduce_layout(const TensorLayout& filter, - const TensorLayout& diff, - TensorLayout& grad) { +void LocalShareBackwardData::deduce_layout( + const TensorLayout& filter, const TensorLayout& diff, TensorLayout& grad) { using Mode = 
LocalShare::Param::Mode; auto errmsg = megdnn_layout_msg(filter) + ", " + megdnn_layout_msg(diff) + ", " + - megdnn_layout_msg(grad) + ", " + "is_xcorr=" + - std::to_string((param().mode == Mode::CROSS_CORRELATION)) + ", " + - "pad_h=" + std::to_string(param().pad_h) + ", " + + megdnn_layout_msg(grad) + ", " + + "is_xcorr=" + std::to_string((param().mode == Mode::CROSS_CORRELATION)) + + ", " + "pad_h=" + std::to_string(param().pad_h) + ", " + "pad_w=" + std::to_string(param().pad_w) + ", " + "stride_h=" + std::to_string(param().stride_h) + ", " + "stride_w=" + std::to_string(param().stride_w) + ", " + "dilate_h=" + std::to_string(param().dilate_h) + ", " + "dilate_w=" + std::to_string(param().dilate_w) + ", " + - "spatial_groups_h=" + std::to_string(param().spatial_groups_h) + - ", " + + "spatial_groups_h=" + std::to_string(param().spatial_groups_h) + ", " + "spatial_groups_w=" + std::to_string(param().spatial_groups_w); auto errmsg_c = errmsg.c_str(); MEGDNN_MARK_USED_VAR(errmsg_c); @@ -135,18 +130,20 @@ void LocalShareBackwardData::deduce_layout(const TensorLayout& filter, using Sparse = Param::Sparse; using Format = Param::Format; using ComputeMode = Param::ComputeMode; - megdnn_assert(param().format == Format::NCHW, - "local shared only support NCHW format"); + megdnn_assert( + param().format == Format::NCHW, "local shared only support NCHW format"); megdnn_assert( (filter.ndim == 6_z && param().sparse == Sparse::DENSE) || (filter.ndim == 7_z && param().sparse == Sparse::GROUP), "%s", errmsg_c); megdnn_assert(diff.ndim == 4_z, "%s", errmsg_c); - megdnn_assert(param().dilate_h == 1 && param().dilate_w == 1, - "dilated local shared is not supported"); - megdnn_assert(diff.dtype == dtype::Float32() && - param().computeMode == ComputeMode::DEFAULT, - "local shared only support float32"); + megdnn_assert( + param().dilate_h == 1 && param().dilate_w == 1, + "dilated local shared is not supported"); + megdnn_assert( + diff.dtype == dtype::Float32() && + param().computeMode == ComputeMode::DEFAULT, + "local shared only support float32"); size_t n = diff[0], co = diff[1], ho = diff[2], wo = diff[3]; size_t sgh = param().spatial_groups_h, sgw = param().spatial_groups_w; @@ -160,22 +157,22 @@ void LocalShareBackwardData::deduce_layout(const TensorLayout& filter, groups = filter[0]; weights_shp_pos = 1; } - megdnn_assert(sgh == filter[weights_shp_pos] && - sgw == filter[weights_shp_pos + 1], - "spatial groups in filter tensor mismatch with those " - "provided in parameter %s", - errmsg_c); - size_t ci = filter[weights_shp_pos + 2] * groups, - fh = filter[weights_shp_pos + 3], fw = filter[weights_shp_pos + 4]; - megdnn_assert(filter[weights_shp_pos + 5] * groups == co, - "input channels of src and filter mismatch %s", errmsg_c); + megdnn_assert( + sgh == filter[weights_shp_pos] && sgw == filter[weights_shp_pos + 1], + "spatial groups in filter tensor mismatch with those " + "provided in parameter %s", + errmsg_c); + size_t ci = filter[weights_shp_pos + 2] * groups, fh = filter[weights_shp_pos + 3], + fw = filter[weights_shp_pos + 4]; + megdnn_assert( + filter[weights_shp_pos + 5] * groups == co, + "input channels of src and filter mismatch %s", errmsg_c); size_t sh = param().stride_h; size_t sw = param().stride_w; size_t ph = param().pad_h; size_t pw = param().pad_w; - auto deduce = [&errmsg_c](size_t out, size_t filter, size_t stride, - size_t pad) { + auto deduce = [&errmsg_c](size_t out, size_t filter, size_t stride, size_t pad) { MEGDNN_MARK_USED_VAR(errmsg_c); auto i = (out - 1) * stride + 
filter; megdnn_assert(i > pad * 2, "%s", errmsg_c); @@ -190,28 +187,25 @@ void LocalShareBackwardData::deduce_layout(const TensorLayout& filter, grad.dtype = diff.dtype; } -void LocalShareBackwardData::check_exec(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { - auto filter_dtype = filter.dtype, diff_dtype = diff.dtype, - grad_dtype = grad.dtype; - megdnn_assert(filter_dtype == dtype::Float32() && - filter_dtype == diff_dtype && filter_dtype == grad_dtype); +void LocalShareBackwardData::check_exec( + const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { + auto filter_dtype = filter.dtype, diff_dtype = diff.dtype, grad_dtype = grad.dtype; + megdnn_assert( + filter_dtype == dtype::Float32() && filter_dtype == diff_dtype && + filter_dtype == grad_dtype); check_layout_fwd(grad, filter, diff); - auto required_workspace_in_bytes = - get_workspace_in_bytes(filter, diff, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(filter, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void LocalShareBackwardFilter::check_exec(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { - auto src_dtype = src.dtype, diff_dtype = diff.dtype, - grad_dtype = grad.dtype; - megdnn_assert(src_dtype == dtype::Float32() && src_dtype == diff_dtype && - src_dtype == grad_dtype); +void LocalShareBackwardFilter::check_exec( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { + auto src_dtype = src.dtype, diff_dtype = diff.dtype, grad_dtype = grad.dtype; + megdnn_assert( + src_dtype == dtype::Float32() && src_dtype == diff_dtype && + src_dtype == grad_dtype); check_layout_fwd(src, grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); diff --git a/dnn/src/common/lrn.cpp b/dnn/src/common/lrn.cpp index 1745311a..3b194852 100644 --- a/dnn/src/common/lrn.cpp +++ b/dnn/src/common/lrn.cpp @@ -14,19 +14,16 @@ namespace megdnn { -void LRNBase::check_param() -{ +void LRNBase::check_param() { megdnn_assert(param().n & 1); } -void LRNForward::deduce_layout(const TensorLayout &src, TensorLayout &dst) -{ +void LRNForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst = src; } -void LRNForward::check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void LRNForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_param(); megdnn_assert_contiguous(src); megdnn_assert_eq_layout(src, dst); @@ -36,22 +33,18 @@ void LRNForward::check_exec(const TensorLayout &src, const TensorLayout &dst, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void LRNBackward::check_exec(const TensorLayout &src, - const TensorLayout &dst, - const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void LRNBackward::check_exec( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes) { check_param(); megdnn_assert_contiguous(src); megdnn_assert_eq_layout(src, dst); megdnn_assert_eq_layout(src, diff); megdnn_assert_eq_layout(src, grad); megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); - auto required_workspace_in_bytes = get_workspace_in_bytes(src, 
dst, - diff, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/lsq.cpp b/dnn/src/common/lsq.cpp index c8078ec3..78fbdc49 100644 --- a/dnn/src/common/lsq.cpp +++ b/dnn/src/common/lsq.cpp @@ -15,16 +15,14 @@ namespace megdnn { -void LSQBase::deduce_layout_fwd(const TensorLayout& input, - TensorLayout& output) { +void LSQBase::deduce_layout_fwd(const TensorLayout& input, TensorLayout& output) { output = TensorLayout(input, input.dtype); } -void LSQBase::check_layout_fwd(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad_scale, - const TensorLayout& output) { +void LSQBase::check_layout_fwd( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad_scale, + const TensorLayout& output) { megdnn_assert(input.dtype == dtype::Float32()); megdnn_assert(scale.dtype == dtype::Float32()); megdnn_assert(zero_point.dtype == dtype::Float32()); @@ -34,31 +32,28 @@ void LSQBase::check_layout_fwd(const TensorLayout& input, megdnn_assert_eq_layout(expected, output); } -void LSQForward::deduce_layout(const TensorLayout& input, - const TensorLayout& /* scale */, - const TensorLayout& /*zero_point*/, - const TensorLayout& /*grad_scale*/, - TensorLayout& output) { +void LSQForward::deduce_layout( + const TensorLayout& input, const TensorLayout& /* scale */, + const TensorLayout& /*zero_point*/, const TensorLayout& /*grad_scale*/, + TensorLayout& output) { deduce_layout_fwd(input, output); } -void LSQForward::check_exec(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& zero_point, - const TensorLayout& grad_scale, - const TensorLayout& output, - size_t workspace_in_bytes) { +void LSQForward::check_exec( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad_scale, + const TensorLayout& output, size_t workspace_in_bytes) { check_layout_fwd(input, scale, zero_point, grad_scale, output); - auto required_workspace_space = get_workspace_in_bytes( - input, scale, zero_point, grad_scale, output); + auto required_workspace_space = + get_workspace_in_bytes(input, scale, zero_point, grad_scale, output); megdnn_assert(workspace_in_bytes >= required_workspace_space); } void LSQBackward::check_exec( - const TensorLayout& diff, const TensorLayout& input, - const TensorLayout& scale, const TensorLayout& zero_point, - const TensorLayout& grad_scale, const TensorLayout& grad_x, - const TensorLayout& grad_s, size_t workspace_in_bytes) { + const TensorLayout& diff, const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& zero_point, const TensorLayout& grad_scale, + const TensorLayout& grad_x, const TensorLayout& grad_s, + size_t workspace_in_bytes) { megdnn_assert_eq_shape(diff, input); megdnn_assert_eq_shape(grad_x, input); auto required_worspace_space = get_workspace_in_bytes( diff --git a/dnn/src/common/macro_helper.h b/dnn/src/common/macro_helper.h index abd75181..f997000a 100644 --- a/dnn/src/common/macro_helper.h +++ b/dnn/src/common/macro_helper.h @@ -13,10 +13,10 @@ #endif #define MAKE_STR0(v) #v -#define MAKE_STR(v) MAKE_STR0(v) +#define MAKE_STR(v) MAKE_STR0(v) -#define CONCAT_STR0(a, b) a ## b -#define CONCAT_STR(a, b) CONCAT_STR0(a, b) +#define CONCAT_STR0(a, b) a##b +#define 
CONCAT_STR(a, b) CONCAT_STR0(a, b) //! add _MEGDNN_SIMD_NAME to given prefix #define WITH_SIMD_SUFFIX(prefix) CONCAT_STR(prefix##_, MEGDNN_SIMD_NAME) diff --git a/dnn/src/common/mask_conv.cpp b/dnn/src/common/mask_conv.cpp index 289cee6b..6088d08d 100644 --- a/dnn/src/common/mask_conv.cpp +++ b/dnn/src/common/mask_conv.cpp @@ -18,24 +18,21 @@ void MaskConvForward::deduce_dtype(DType src, DType filter, DType, DType& dst) { check_or_deduce_dtype_fwd(src, filter, dst); } -void MaskConvForward::deduce_layout(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& mask, - TensorLayout& dst) { +void MaskConvForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask, + TensorLayout& dst) { deduce_layout_fwd(src, filter, dst); megdnn_assert(dst[2] == mask[0]); megdnn_assert(dst[3] == mask[1]); } -MaskConvForward::CanonizedFilterMeta -MaskConvForward::check_exec(const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& mask, const TensorLayout& dst, - size_t workspace_in_bytes) { +MaskConvForward::CanonizedFilterMeta MaskConvForward::check_exec( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& mask, + const TensorLayout& dst, size_t workspace_in_bytes) { auto ret = check_layout_fwd(src, filter, dst); megdnn_assert(dst[2] == mask[0]); megdnn_assert(dst[3] == mask[1]); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, filter, mask, dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter, mask, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); return ret; } @@ -43,9 +40,10 @@ MaskConvForward::check_exec(const TensorLayout& src, const TensorLayout& filter, void MaskPropagate::deduce_layout(const TensorLayout& src, TensorLayout& dst) { size_t oh, ow; auto p = param(); - infer_conv_shape2d(src[0], src[1], (p.kernel_h - 1) * p.dilate_h + 1, - (p.kernel_w - 1) * p.dilate_w + 1, p.stride_h, - p.stride_w, p.pad_h, p.pad_w, oh, ow); + infer_conv_shape2d( + src[0], src[1], (p.kernel_h - 1) * p.dilate_h + 1, + (p.kernel_w - 1) * p.dilate_w + 1, p.stride_h, p.stride_w, p.pad_h, p.pad_w, + oh, ow); dst = TensorLayout{{oh, ow}, src.dtype}; } diff --git a/dnn/src/common/matrix_inverse.cpp b/dnn/src/common/matrix_inverse.cpp index e6cbb30d..ba942914 100644 --- a/dnn/src/common/matrix_inverse.cpp +++ b/dnn/src/common/matrix_inverse.cpp @@ -19,24 +19,26 @@ void MatrixInverse::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst = src; } -size_t MatrixInverse::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) { +size_t MatrixInverse::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { size_t batch, n; canonize_params(src, &batch, &n); - megdnn_assert(src.eq_layout(dst), "src and dst unequal: %s vs %s", - src.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + src.eq_layout(dst), "src and dst unequal: %s vs %s", + src.to_string().c_str(), dst.to_string().c_str()); return get_workspace_in_bytes(batch, n, src.dtype.size()); } -void MatrixInverse::canonize_params(const TensorLayout& layout, size_t* batch, - size_t* n) { - megdnn_assert(layout.is_contiguous() && layout.ndim >= 2 && - layout[layout.ndim - 2] == layout[layout.ndim - 1], - "invalid MatrixInverse layout: %s", - layout.to_string().c_str()); - megdnn_assert(DNN_FLOAT16_SELECT(layout.dtype == dtype::Float16(), false) || - layout.dtype == dtype::Float32(), - "MatrixInverse only supports f16 & f32"); +void 
MatrixInverse::canonize_params( + const TensorLayout& layout, size_t* batch, size_t* n) { + megdnn_assert( + layout.is_contiguous() && layout.ndim >= 2 && + layout[layout.ndim - 2] == layout[layout.ndim - 1], + "invalid MatrixInverse layout: %s", layout.to_string().c_str()); + megdnn_assert( + DNN_FLOAT16_SELECT(layout.dtype == dtype::Float16(), false) || + layout.dtype == dtype::Float32(), + "MatrixInverse only supports f16 & f32"); if (batch) { *batch = 1; for (size_t i = 0; i < layout.ndim - 2; ++i) { @@ -48,14 +50,15 @@ void MatrixInverse::canonize_params(const TensorLayout& layout, size_t* batch, } } -void MatrixInverse::check_exec(const TensorLayout& src, const TensorLayout& dst, - _megdnn_workspace workspace, size_t* batch, - size_t* n) { +void MatrixInverse::check_exec( + const TensorLayout& src, const TensorLayout& dst, _megdnn_workspace workspace, + size_t* batch, size_t* n) { canonize_params(src, batch, n); - megdnn_assert(src.eq_layout(dst), "src and dst unequal: %s vs %s", - src.to_string().c_str(), dst.to_string().c_str()); - megdnn_assert(workspace.size >= - get_workspace_in_bytes(*batch, *n, src.dtype.size())); + megdnn_assert( + src.eq_layout(dst), "src and dst unequal: %s vs %s", + src.to_string().c_str(), dst.to_string().c_str()); + megdnn_assert( + workspace.size >= get_workspace_in_bytes(*batch, *n, src.dtype.size())); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/matrix_mul.cpp b/dnn/src/common/matrix_mul.cpp index 04490bb9..2a0b6cc7 100644 --- a/dnn/src/common/matrix_mul.cpp +++ b/dnn/src/common/matrix_mul.cpp @@ -39,23 +39,24 @@ void MatrixMulForward::deduce_dtype(DType A, DType B, DType& C) { if (!C.valid()) { C = C_candi; } - megdnn_assert(C.valid() && (C == C_candi || C == C_candi2), - "unsupported MatMul(%s, %s) -> %s", A.name(), B.name(), - C.name()); + megdnn_assert( + C.valid() && (C == C_candi || C == C_candi2), + "unsupported MatMul(%s, %s) -> %s", A.name(), B.name(), C.name()); } -void MatrixMulForward::deduce_layout(const TensorLayout& A, - const TensorLayout& B, TensorLayout& C) { - megdnn_assert(A.dtype.enumv() == B.dtype.enumv(), - "matmul input should be of same dtype, got %s and %s", - A.dtype.name(), B.dtype.name()); +void MatrixMulForward::deduce_layout( + const TensorLayout& A, const TensorLayout& B, TensorLayout& C) { + megdnn_assert( + A.dtype.enumv() == B.dtype.enumv(), + "matmul input should be of same dtype, got %s and %s", A.dtype.name(), + B.dtype.name()); deduce_dtype(A.dtype, B.dtype, C.dtype); size_t A0, A1, B0, B1; if (param().format == param::MatrixMul::Format::DEFAULT) { - megdnn_assert(A.ndim == 2 && B.ndim == 2, - "matmul requires input to be 2-dimensional; get: %s %s", - A.TensorShape::to_string().c_str(), - B.TensorShape::to_string().c_str()); + megdnn_assert( + A.ndim == 2 && B.ndim == 2, + "matmul requires input to be 2-dimensional; get: %s %s", + A.TensorShape::to_string().c_str(), B.TensorShape::to_string().c_str()); A0 = A.shape[0]; A1 = A.shape[1]; B0 = B.shape[0]; @@ -64,18 +65,20 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A, std::swap(A0, A1); if (m_param.transposeB) std::swap(B0, B1); - megdnn_assert(A1 == B0, - "shape mismatch in matmal: (transposed) A is (%zu,%zu), " - "(transposed) B is (%zu,%zu)", - A0, A1, B0, B1); + megdnn_assert( + A1 == B0, + "shape mismatch in matmal: (transposed) A is (%zu,%zu), " + "(transposed) B is (%zu,%zu)", + A0, A1, B0, B1); C = TensorLayout(TensorShape({A0, B1}), C.dtype); } else { auto do_deduce = [&](size_t pack_size) { - megdnn_assert(A.ndim == 4 && B.ndim 
== 3, - "matmul requires input dimension to be A(4), B(3); " - "get: %s %s", - A.TensorShape::to_string().c_str(), - B.TensorShape::to_string().c_str()); + megdnn_assert( + A.ndim == 4 && B.ndim == 3, + "matmul requires input dimension to be A(4), B(3); " + "get: %s %s", + A.TensorShape::to_string().c_str(), + B.TensorShape::to_string().c_str()); A0 = A.shape[0]; A1 = A.shape[1]; B0 = B.shape[0]; @@ -84,20 +87,21 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A, std::swap(A0, A1); if (m_param.transposeB) std::swap(B0, B1); - megdnn_assert(A1 == B0, - "shape mismatch in matmal: (transposed) A is " - "(%zu,%zu,4,4), " - "(transposed) B is (%zu,%zu,4)", - A0, A1, B0, B1); + megdnn_assert( + A1 == B0, + "shape mismatch in matmal: (transposed) A is " + "(%zu,%zu,4,4), " + "(transposed) B is (%zu,%zu,4)", + A0, A1, B0, B1); C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype); }; do_deduce(pack_size(param().format)); } } -void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B, - const TensorLayout& C, - size_t workspace_in_bytes) { +void MatrixMulForward::check_exec( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, + size_t workspace_in_bytes) { auto errmsg = [&]() { std::string msg; msg.append("A="); @@ -167,19 +171,20 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B, megdnn_assert(A.dtype == C.dtype); } else if (A.dtype == dtype::Int8()) { megdnn_assert(C.dtype == dtype::Int16() || C.dtype == dtype::Int32()); - } else if (A.dtype.enumv() == DTypeEnum::QuantizedS8 || - A.dtype.enumv() == DTypeEnum::Quantized8Asymm || - A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { + } else if ( + A.dtype.enumv() == DTypeEnum::QuantizedS8 || + A.dtype.enumv() == DTypeEnum::Quantized8Asymm || + A.dtype.enumv() == DTypeEnum::Quantized4Asymm) { megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS32); - } else if(A.dtype.enumv() == DTypeEnum::QuantizedS4){ + } else if (A.dtype.enumv() == DTypeEnum::QuantizedS4) { megdnn_assert(C.dtype.enumv() == DTypeEnum::QuantizedS16); } - megdnn_assert(param().compute_mode != - Param::ComputeMode::FLOAT32 DNN_INC_FLOAT16( - || A.dtype == dtype::Float16() || - A.dtype == dtype::BFloat16()), - "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " - "input / output."); + megdnn_assert( + param().compute_mode != Param::ComputeMode::FLOAT32 DNN_INC_FLOAT16( + || A.dtype == dtype::Float16() || + A.dtype == dtype::BFloat16()), + "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " + "input / output."); auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } diff --git a/dnn/src/common/max_tensor_diff.cpp b/dnn/src/common/max_tensor_diff.cpp index 96a374cf..706ac7bc 100644 --- a/dnn/src/common/max_tensor_diff.cpp +++ b/dnn/src/common/max_tensor_diff.cpp @@ -15,22 +15,24 @@ using namespace megdnn; -void megdnn::MaxTensorDiff::check_exec(const TensorLayout& layout1, - const TensorLayout& layout2, - size_t workspace_in_bytes) { - megdnn_assert(layout1.eq_layout(layout2), "layout1: %s, layout2: %s", - layout1.to_string().c_str(), layout2.to_string().c_str()); +void megdnn::MaxTensorDiff::check_exec( + const TensorLayout& layout1, const TensorLayout& layout2, + size_t workspace_in_bytes) { + megdnn_assert( + layout1.eq_layout(layout2), "layout1: %s, layout2: %s", + layout1.to_string().c_str(), layout2.to_string().c_str()); if (Image2DPack4TensorFormat::is_valid_image(layout1)) { - 
megdnn_assert(layout1.is_contiguous() && layout1.ndim == 2 && - layout1.shape[0] && layout1.eq_layout(layout2), - "layout1: %s, layout2: %s", layout1.to_string().c_str(), - layout2.to_string().c_str()); + megdnn_assert( + layout1.is_contiguous() && layout1.ndim == 2 && layout1.shape[0] && + layout1.eq_layout(layout2), + "layout1: %s, layout2: %s", layout1.to_string().c_str(), + layout2.to_string().c_str()); } else { - megdnn_assert(layout1.is_contiguous() && - (layout1.ndim == 1 || layout1.ndim == 2) && - layout1.shape[0] && layout1.eq_layout(layout2), - "layout1: %s, layout2: %s", layout1.to_string().c_str(), - layout2.to_string().c_str()); + megdnn_assert( + layout1.is_contiguous() && (layout1.ndim == 1 || layout1.ndim == 2) && + layout1.shape[0] && layout1.eq_layout(layout2), + "layout1: %s, layout2: %s", layout1.to_string().c_str(), + layout2.to_string().c_str()); } auto required_workspace_in_bytes = get_workspace_in_bytes(layout1, layout2); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); diff --git a/dnn/src/common/megcore/common/computing_context.cpp b/dnn/src/common/megcore/common/computing_context.cpp index 35d118fd..c31f2d2d 100644 --- a/dnn/src/common/megcore/common/computing_context.cpp +++ b/dnn/src/common/megcore/common/computing_context.cpp @@ -11,8 +11,8 @@ #include "src/common/utils.h" -#include "./computing_context.hpp" #include "../cpu/default_computing_context.hpp" +#include "./computing_context.hpp" #if MEGDNN_WITH_CUDA #include "src/cuda/megcore/cuda_computing_context.hpp" #endif @@ -33,8 +33,7 @@ using namespace megcore; using namespace megdnn; std::unique_ptr ComputingContext::make( - megcoreDeviceHandle_t dev_handle, unsigned int flags) -{ + megcoreDeviceHandle_t dev_handle, unsigned int flags) { megcorePlatform_t platform; megcoreGetPlatform(dev_handle, &platform); switch (platform) { @@ -50,8 +49,7 @@ std::unique_ptr ComputingContext::make( #endif #if MEGDNN_WITH_CAMBRICON case megcorePlatformCambricon: - return make_unique(dev_handle, - flags); + return make_unique(dev_handle, flags); #endif #if MEGDNN_WITH_ATLAS case megcorePlatformAtlas: diff --git a/dnn/src/common/megcore/common/device_context.cpp b/dnn/src/common/megcore/common/device_context.cpp index adb0c177..074f40fa 100644 --- a/dnn/src/common/megcore/common/device_context.cpp +++ b/dnn/src/common/megcore/common/device_context.cpp @@ -10,8 +10,8 @@ */ #include "./device_context.hpp" -#include "src/common/utils.h" #include "../cpu/default_device_context.hpp" +#include "src/common/utils.h" #if MEGDNN_WITH_CUDA #include "src/cuda/megcore/cuda_device_context.hpp" #endif @@ -29,9 +29,8 @@ using namespace megcore; using namespace megdnn; -std::unique_ptr DeviceContext::make(megcorePlatform_t platform, - int deviceID, unsigned int flags) -{ +std::unique_ptr DeviceContext::make( + megcorePlatform_t platform, int deviceID, unsigned int flags) { switch (platform) { case megcorePlatformCPU: return make_unique(deviceID, flags); @@ -45,8 +44,7 @@ std::unique_ptr DeviceContext::make(megcorePlatform_t platform, #endif #if MEGDNN_WITH_CAMBRICON case megcorePlatformCambricon: - return make_unique(deviceID, - flags); + return make_unique(deviceID, flags); #endif #if MEGDNN_WITH_ATLAS case megcorePlatformAtlas: @@ -59,5 +57,4 @@ std::unique_ptr DeviceContext::make(megcorePlatform_t platform, DeviceContext::~DeviceContext() noexcept = default; - // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/megcore/cpu/api.cpp b/dnn/src/common/megcore/cpu/api.cpp index f8dee8d9..6890a672 100644 --- 
a/dnn/src/common/megcore/cpu/api.cpp +++ b/dnn/src/common/megcore/cpu/api.cpp @@ -11,22 +11,20 @@ #include "megcore.h" #include "src/common/utils.h" -#include "./default_computing_context.hpp" #include "../common/computing_context.hpp" #include "../public_api/computing.hpp" +#include "./default_computing_context.hpp" using namespace megcore; CPUDispatcher::~CPUDispatcher() noexcept = default; megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( - megcoreComputingHandle_t *compHandle, - megcoreDeviceHandle_t devHandle, - const std::shared_ptr<CPUDispatcher>& dispatcher, - unsigned int flags) { - auto content = megdnn::make_unique< - megcore::cpu::DefaultComputingContext>(devHandle, flags); - auto &H = *compHandle; + megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, + const std::shared_ptr<CPUDispatcher>& dispatcher, unsigned int flags) { + auto content = megdnn::make_unique<megcore::cpu::DefaultComputingContext>( + devHandle, flags); + auto& H = *compHandle; content->set_dispatcher(dispatcher); H = new megcoreComputingContext; H->content = std::move(content); @@ -34,16 +32,17 @@ megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( } CPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle) { - auto &&H = handle; + auto&& H = handle; megdnn_assert(H); // Check device handle. megcoreDeviceHandle_t dev_handle = H->content->dev_handle(); megcorePlatform_t platform; megcoreGetPlatform(dev_handle, &platform); - megdnn_throw_if(!(platform & megcorePlatformCPU), megdnn_error, - "can not be default ComputingContext"); - auto context = static_cast<megcore::cpu::DefaultComputingContext*>( - H->content.get()); + megdnn_throw_if( + !(platform & megcorePlatformCPU), megdnn_error, + "can not be default ComputingContext"); + auto context = + static_cast<megcore::cpu::DefaultComputingContext*>(H->content.get()); return context->get_dispatcher(); } diff --git a/dnn/src/common/megcore/cpu/default_computing_context.cpp b/dnn/src/common/megcore/cpu/default_computing_context.cpp index 696d5ded..c56bd37d 100644 --- a/dnn/src/common/megcore/cpu/default_computing_context.cpp +++ b/dnn/src/common/megcore/cpu/default_computing_context.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ -#include "src/common/utils.h" #include "./default_computing_context.hpp" +#include "src/common/utils.h" #include @@ -35,32 +35,29 @@ using namespace megcore; using namespace cpu; DefaultComputingContext::DefaultComputingContext( - megcoreDeviceHandle_t dev_handle, unsigned int flags): - ComputingContext(dev_handle, flags), - m_dispatcher{megdnn::make_unique()} -{ + megcoreDeviceHandle_t dev_handle, unsigned int flags) + : ComputingContext(dev_handle, flags), + m_dispatcher{megdnn::make_unique()} { megcorePlatform_t platform; megcoreGetPlatform(dev_handle, &platform); - megdnn_throw_if(!(platform & megcorePlatformCPU), megdnn_error, - "can not be default ComputingContext"); + megdnn_throw_if( + !(platform & megcorePlatformCPU), megdnn_error, + "can not be default ComputingContext"); } DefaultComputingContext::~DefaultComputingContext() noexcept = default; -void DefaultComputingContext::memcpy(void *dst, const void *src, - size_t size_in_bytes, - megcoreMemcpyKind_t /* kind */) -{ +void DefaultComputingContext::memcpy( + void* dst, const void* src, size_t size_in_bytes, + megcoreMemcpyKind_t /* kind */) { ::memcpy(dst, src, size_in_bytes); } -void DefaultComputingContext::memset(void *dst, int value, size_t size_in_bytes) -{ +void DefaultComputingContext::memset(void* dst, int value, size_t size_in_bytes) { ::memset(dst, value, size_in_bytes); } -void DefaultComputingContext::synchronize() -{ +void DefaultComputingContext::synchronize() { m_dispatcher->sync(); } diff --git a/dnn/src/common/megcore/cpu/default_device_context.cpp b/dnn/src/common/megcore/cpu/default_device_context.cpp index 57d76ef7..19b8b375 100644 --- a/dnn/src/common/megcore/cpu/default_device_context.cpp +++ b/dnn/src/common/megcore/cpu/default_device_context.cpp @@ -11,16 +11,15 @@ #include "src/common/utils.h" -#include "./default_device_context.hpp" #include +#include "./default_device_context.hpp" using namespace megcore; using namespace megcore::cpu; using namespace megdnn; -DefaultDeviceContext::DefaultDeviceContext(int device_id, unsigned int flags): - DeviceContext(megcorePlatformCPU, device_id, flags) -{ +DefaultDeviceContext::DefaultDeviceContext(int device_id, unsigned int flags) + : DeviceContext(megcorePlatformCPU, device_id, flags) { megdnn_assert(device_id == -1); } @@ -30,15 +29,14 @@ size_t DefaultDeviceContext::mem_alignment_in_bytes() const noexcept { return 1; } -void DefaultDeviceContext::activate() noexcept { -} +void DefaultDeviceContext::activate() noexcept {} -void *DefaultDeviceContext::malloc(size_t size_in_bytes) { +void* DefaultDeviceContext::malloc(size_t size_in_bytes) { return new uint8_t[size_in_bytes]; } -void DefaultDeviceContext::free(void *ptr) { - delete []static_cast(ptr); +void DefaultDeviceContext::free(void* ptr) { + delete[] static_cast(ptr); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/megcore/public_api/computing.cpp b/dnn/src/common/megcore/public_api/computing.cpp index b00e9072..126e828a 100644 --- a/dnn/src/common/megcore/public_api/computing.cpp +++ b/dnn/src/common/megcore/public_api/computing.cpp @@ -12,68 +12,57 @@ #include "megcore.h" #include "src/common/utils.h" -#include "./computing.hpp" #include "../common/computing_context.hpp" +#include "./computing.hpp" using namespace megcore; megcoreStatus_t megcoreCreateComputingHandle( - megcoreComputingHandle_t *compHandle, - megcoreDeviceHandle_t devHandle, - unsigned int flags) -{ + megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, + unsigned int flags) { auto ctx = 
ComputingContext::make(devHandle, flags); - auto &H = *compHandle; + auto& H = *compHandle; H = new megcoreComputingContext; H->content = std::move(ctx); return megcoreSuccess; } -megcoreStatus_t megcoreDestroyComputingHandle( - megcoreComputingHandle_t handle) -{ +megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle) { megdnn_assert(handle); delete handle; return megcoreSuccess; } megcoreStatus_t megcoreGetDeviceHandle( - megcoreComputingHandle_t compHandle, - megcoreDeviceHandle_t *devHandle) -{ + megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle) { megdnn_assert(compHandle); *devHandle = compHandle->content->dev_handle(); return megcoreSuccess; } megcoreStatus_t megcoreGetComputingFlags( - megcoreComputingHandle_t handle, - unsigned int *flags) -{ + megcoreComputingHandle_t handle, unsigned int* flags) { megdnn_assert(handle); *flags = handle->content->flags(); return megcoreSuccess; } -megcoreStatus_t megcoreMemcpy(megcoreComputingHandle_t handle, - void *dst, const void *src, size_t sizeInBytes, - megcoreMemcpyKind_t kind) -{ +megcoreStatus_t megcoreMemcpy( + megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes, + megcoreMemcpyKind_t kind) { megdnn_assert(handle); handle->content->memcpy(dst, src, sizeInBytes, kind); return megcoreSuccess; } -megcoreStatus_t megcoreMemset(megcoreComputingHandle_t handle, - void *dst, int value, size_t sizeInBytes) -{ +megcoreStatus_t megcoreMemset( + megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes) { megdnn_assert(handle); handle->content->memset(dst, value, sizeInBytes); return megcoreSuccess; } -megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle) -{ +megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle) { megdnn_assert(handle); handle->content->synchronize(); return megcoreSuccess; diff --git a/dnn/src/common/megcore/public_api/device.cpp b/dnn/src/common/megcore/public_api/device.cpp index 77b37447..20dc9828 100644 --- a/dnn/src/common/megcore/public_api/device.cpp +++ b/dnn/src/common/megcore/public_api/device.cpp @@ -11,86 +11,74 @@ #include "megcore.h" #include "src/common/utils.h" -#include "./device.hpp" #include "../common/device_context.hpp" +#include "./device.hpp" using namespace megcore; megcoreStatus_t megcoreCreateDeviceHandle( - megcoreDeviceHandle_t *handle, - megcorePlatform_t platform, int deviceID, unsigned int flags) -{ + megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID, + unsigned int flags) { auto ctx = DeviceContext::make(platform, deviceID, flags); - auto &H = *handle; + auto& H = *handle; H = new megcoreDeviceContext; H->content = std::move(ctx); return megcoreSuccess; } -megcoreStatus_t megcoreDestroyDeviceHandle( - megcoreDeviceHandle_t handle) -{ +megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle) { megdnn_assert(handle); delete handle; return megcoreSuccess; } -megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle, - megcorePlatform_t *platform) -{ +megcoreStatus_t megcoreGetPlatform( + megcoreDeviceHandle_t handle, megcorePlatform_t* platform) { megdnn_assert(handle); *platform = handle->content->platform(); return megcoreSuccess; } -megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, - int *deviceID) -{ +megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID) { megdnn_assert(handle); *deviceID = handle->content->device_id(); return megcoreSuccess; } -megcoreStatus_t 
megcoreGetDeviceFlags(megcoreDeviceHandle_t handle, - unsigned int *flags) -{ +megcoreStatus_t megcoreGetDeviceFlags( + megcoreDeviceHandle_t handle, unsigned int* flags) { megdnn_assert(handle); *flags = handle->content->flags(); return megcoreSuccess; } -megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle, - size_t *memAlignmentInBytes) -{ +megcoreStatus_t megcoreGetMemAlignment( + megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes) { megdnn_assert(handle); *memAlignmentInBytes = handle->content->mem_alignment_in_bytes(); return megcoreSuccess; } -megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle) -{ +megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle) { megdnn_assert(handle); handle->content->activate(); return megcoreSuccess; } -megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle) -{ +megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle) { megdnn_assert(handle); handle->content->deactivate(); return megcoreSuccess; } -megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle, - void **devPtr, size_t sizeInBytes) -{ +megcoreStatus_t megcoreMalloc( + megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes) { megdnn_assert(handle); *devPtr = handle->content->malloc(sizeInBytes); return megcoreSuccess; } -megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void *devPtr) -{ +megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr) { megdnn_assert(handle); handle->content->free(devPtr); return megcoreSuccess; diff --git a/dnn/src/common/megcore/public_api/misc.cpp b/dnn/src/common/megcore/public_api/misc.cpp index 2d79dc73..6dcd0c5f 100644 --- a/dnn/src/common/megcore/public_api/misc.cpp +++ b/dnn/src/common/megcore/public_api/misc.cpp @@ -11,9 +11,10 @@ #include "megcore.h" #include "src/common/utils.h" -const char *megcoreGetErrorName(megcoreStatus_t status) -{ -#define CASE(x) case x: return (#x) +const char* megcoreGetErrorName(megcoreStatus_t status) { +#define CASE(x) \ + case x: \ + return (#x) switch (status) { CASE(megcoreSuccess); CASE(megcoreErrorMemoryAllocation); diff --git a/dnn/src/common/mesh_indexing.cpp b/dnn/src/common/mesh_indexing.cpp index 614fc7e2..9494f713 100644 --- a/dnn/src/common/mesh_indexing.cpp +++ b/dnn/src/common/mesh_indexing.cpp @@ -16,8 +16,9 @@ namespace megdnn { /* ============================== MeshIndexing ============================= */ -void MeshBase::check_exec(const TensorLayout& origin, - const TensorLayout& indexed, const IndexDesc& desc) { +void MeshBase::check_exec( + const TensorLayout& origin, const TensorLayout& indexed, + const IndexDesc& desc) { megdnn_assert(origin.dtype == indexed.dtype); megdnn_assert(origin.ndim == indexed.ndim); for (auto&& index : desc) { @@ -25,58 +26,55 @@ void MeshBase::check_exec(const TensorLayout& origin, } } -void NormalMeshBase::check_exec(const TensorLayout& src, - const TensorLayout& dst, - const IndexDesc& desc) { +void NormalMeshBase::check_exec( + const TensorLayout& src, const TensorLayout& dst, const IndexDesc& desc) { MeshBase::check_exec(src, dst, desc); for (auto&& index : desc) { size_t ndim = index.vec.layout.ndim; - megdnn_assert(ndim == 1, "index must be 1-dim vector, while dim %zu", - ndim); + megdnn_assert(ndim == 1, "index must be 1-dim vector, while dim %zu", ndim); megdnn_assert(dst.shape[index.axis] == index.vec.layout[0]); } } -void BatchedMeshBase::check_exec(const TensorLayout& src, - const TensorLayout& dst, - const IndexDesc& desc) { +void BatchedMeshBase::check_exec( + const TensorLayout& 
src, const TensorLayout& dst, const IndexDesc& desc) { MeshBase::check_exec(src, dst, desc); - megdnn_assert(src[0] == dst[0], "batch mismatch, src %zu, dst %zu", src[0], - dst[0]); + megdnn_assert(src[0] == dst[0], "batch mismatch, src %zu, dst %zu", src[0], dst[0]); for (auto&& index : desc) { size_t ndim = index.vec.layout.ndim; - megdnn_assert(ndim == 2, "index must be a 2-dim matrix, while ndim %zu", - ndim); - megdnn_assert(dst[0] == index.vec.layout[0] && - dst[index.axis] == index.vec.layout[1], - "require each index shape equals (%zu, %zu), but got " - "(%zu, %zu)", - dst[0], dst[index.axis], index.vec.layout[0], - index.vec.layout[1]); - megdnn_assert(index.axis != 0, - "index axis should be 0-th dim when executing " - "BatchedMeshIndexing"); + megdnn_assert(ndim == 2, "index must be a 2-dim matrix, while ndim %zu", ndim); + megdnn_assert( + dst[0] == index.vec.layout[0] && dst[index.axis] == index.vec.layout[1], + "require each index shape equals (%zu, %zu), but got " + "(%zu, %zu)", + dst[0], dst[index.axis], index.vec.layout[0], index.vec.layout[1]); + megdnn_assert( + index.axis != 0, + "index axis should be 0-th dim when executing " + "BatchedMeshIndexing"); } } -void MeshIndexing::deduce_layout(const TensorLayout& inp, - const IndexDescLayoutOnly& layouts, - TensorLayout& out_layout) { +void MeshIndexing::deduce_layout( + const TensorLayout& inp, const IndexDescLayoutOnly& layouts, + TensorLayout& out_layout) { out_layout = inp; for (auto&& index : layouts) { - megdnn_assert(index.layout.ndim == 1, + megdnn_assert( + index.layout.ndim == 1, "mesh indexing require index being 1-dim vector"); out_layout[index.axis] = index.layout[0]; } out_layout.init_contiguous_stride(); } -void BatchedMeshIndexing::deduce_layout(const TensorLayout& inp, - const IndexDescLayoutOnly& layouts, - TensorLayout& out_layout) { +void BatchedMeshIndexing::deduce_layout( + const TensorLayout& inp, const IndexDescLayoutOnly& layouts, + TensorLayout& out_layout) { out_layout = inp; for (auto&& index : layouts) { - megdnn_assert(index.layout.ndim == 2, + megdnn_assert( + index.layout.ndim == 2, "batch mesh indexing require index being 2-dim matrix"); out_layout[index.axis] = index.layout[1]; } diff --git a/dnn/src/common/named_tensor.cpp b/dnn/src/common/named_tensor.cpp index 0071f9d6..919145c5 100644 --- a/dnn/src/common/named_tensor.cpp +++ b/dnn/src/common/named_tensor.cpp @@ -17,16 +17,13 @@ using namespace megdnn; /* ===================== Dimension ============================ */ const Dimension::Name Dimension::NAME_ALL[] = { - Dimension::Name::N, Dimension::Name::C, Dimension::Name::H, - Dimension::Name::W, Dimension::Name::G, Dimension::Name::K, - Dimension::Name::R, Dimension::Name::S, Dimension::Name::P, - Dimension::Name::Q, + Dimension::Name::N, Dimension::Name::C, Dimension::Name::H, Dimension::Name::W, + Dimension::Name::G, Dimension::Name::K, Dimension::Name::R, Dimension::Name::S, + Dimension::Name::P, Dimension::Name::Q, }; const int Dimension::NR_NAMES = sizeof(Dimension::NAME_ALL); Dimension::Dimension(const std::string& expr) { - auto errmsg = [&]() { - return ssprintf("Invalid dimension(%s)", expr.c_str()); - }; + auto errmsg = [&]() { return ssprintf("Invalid dimension(%s)", expr.c_str()); }; const char* data = expr.data(); bool has_stride = false; bool has_extent = false; @@ -44,13 +41,13 @@ Dimension::Dimension(const std::string& expr) { megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str()); ++data; } else if (data[0] == '/' && data[1] == '/') { - 
megdnn_throw_if(!init_name || has_stride || has_extent, - megdnn_error, errmsg().c_str()); + megdnn_throw_if( + !init_name || has_stride || has_extent, megdnn_error, + errmsg().c_str()); has_stride = true; data += 2; } else if (data[0] == '%') { - megdnn_throw_if(!init_name || has_extent, megdnn_error, - errmsg().c_str()); + megdnn_throw_if(!init_name || has_extent, megdnn_error, errmsg().c_str()); has_extent = true; ++data; } else if (data[0] >= '0' && data[0] <= '9') { @@ -85,8 +82,7 @@ Dimension& Dimension::operator=(const Dimension& rhs) { } bool Dimension::operator==(const Dimension& rhs) const { - return m_name == rhs.m_name && m_stride == rhs.m_stride && - m_extent == rhs.m_extent; + return m_name == rhs.m_name && m_stride == rhs.m_stride && m_extent == rhs.m_extent; } bool Dimension::operator<(const Dimension& rhs) const { @@ -100,10 +96,11 @@ bool Dimension::operator<(const Dimension& rhs) const { } Dimension Dimension::operator*(const Dimension& rhs) const { - megdnn_assert(m_name == rhs.m_name, - "Multiply operation cannot be applied on dimensions with " - "different name(lhs:%c, rhs:%c)", - static_cast<char>(m_name), static_cast<char>(rhs.m_name)); + megdnn_assert( + m_name == rhs.m_name, + "Multiply operation cannot be applied on dimensions with " + "different name(lhs:%c, rhs:%c)", + static_cast<char>(m_name), static_cast<char>(rhs.m_name)); megdnn_assert( m_stride == rhs.m_stride * rhs.m_extent, "Multiply operation cannot be applied on operands(lhs:%s, rhs:%s)", @@ -114,41 +111,44 @@ Dimension Dimension::operator*(const Dimension& rhs) const { } Dimension Dimension::operator/(const Dimension& rhs) const { - megdnn_assert(m_name == rhs.m_name, - "Divide operation cannot be applied on dimensions with " - "different name(lhs:%c, rhs:%c)", - static_cast<char>(m_name), static_cast<char>(rhs.m_name)); + megdnn_assert( + m_name == rhs.m_name, + "Divide operation cannot be applied on dimensions with " + "different name(lhs:%c, rhs:%c)", + static_cast<char>(m_name), static_cast<char>(rhs.m_name)); if (operator==(rhs)) return Dimension(m_name, 1, 1); if (m_stride == rhs.m_stride) { if (m_extent == UNDETERMINED_EXTENT) { - megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT, - "Divide operation cannot be applied on " - "operands(dividend:%s, divisor:%s)", - to_string().c_str(), rhs.to_string().c_str()); + megdnn_assert( + rhs.m_extent != UNDETERMINED_EXTENT, + "Divide operation cannot be applied on " + "operands(dividend:%s, divisor:%s)", + to_string().c_str(), rhs.to_string().c_str()); return Dimension(m_name, rhs.m_extent * m_stride); } else { - megdnn_assert(m_extent % rhs.m_extent == 0, - "Divide operation cannot be applied on " - "operands(dividend:%s, divisor:%s)", - to_string().c_str(), rhs.to_string().c_str()); - return Dimension(m_name, rhs.m_extent * m_stride, - m_extent / rhs.m_extent); + megdnn_assert( + m_extent % rhs.m_extent == 0, + "Divide operation cannot be applied on " + "operands(dividend:%s, divisor:%s)", + to_string().c_str(), rhs.to_string().c_str()); + return Dimension(m_name, rhs.m_extent * m_stride, m_extent / rhs.m_extent); } } else { if (m_extent == UNDETERMINED_EXTENT) { - megdnn_assert(rhs.m_extent == UNDETERMINED_EXTENT && - rhs.m_stride % m_stride == 0, - "Divide operation cannot be applied on " - "operands(dividend:%s, divisor:%s)", - to_string().c_str(), rhs.to_string().c_str()); + megdnn_assert( + rhs.m_extent == UNDETERMINED_EXTENT && rhs.m_stride % m_stride == 0, + "Divide operation cannot be applied on " + "operands(dividend:%s, divisor:%s)", + to_string().c_str(), rhs.to_string().c_str()); return Dimension(m_name, m_stride, rhs.m_stride / m_stride); } else { - megdnn_assert(m_extent * m_stride == rhs.m_extent * rhs.m_stride && - rhs.m_stride % m_stride == 0, - "Divide operation cannot be applied on " - "operands(dividend:%s, divisor:%s)", - to_string().c_str(), rhs.to_string().c_str()); + megdnn_assert( + m_extent * m_stride == rhs.m_extent * rhs.m_stride && + rhs.m_stride % m_stride == 0, + "Divide operation cannot be applied on " + "operands(dividend:%s, divisor:%s)", + to_string().c_str(), rhs.to_string().c_str()); return Dimension(m_name, m_stride, m_extent / rhs.m_extent); } } @@ -164,18 +164,19 @@ std::string Dimension::to_string() const { if (m_stride == 1) return ssprintf("%c%%%u", static_cast<char>(m_name), m_extent); else - return ssprintf("%c//%u%%%u", static_cast<char>(m_name), m_stride, - m_extent); + return ssprintf( + "%c//%u%%%u", static_cast<char>(m_name), m_stride, m_extent); } } /* ===================== NamedTensorShape ===================== */ NamedTensorShape::NamedTensorShape(const SmallVector<Dimension>& init_shape) { - megdnn_assert(init_shape.size() <= MAX_NDIM, - "Illegal to construct a NamedTensorShape with " - "more than MAX_NDIM(%zu) axes; got(%zu)", - MAX_NDIM, init_shape.size()); + megdnn_assert( + init_shape.size() <= MAX_NDIM, + "Illegal to construct a NamedTensorShape with " + "more than MAX_NDIM(%zu) axes; got(%zu)", + MAX_NDIM, init_shape.size()); ndim = init_shape.size(); memcpy(this->dims.data(), init_shape.data(), sizeof(Dimension) * ndim); } @@ -246,9 +247,8 @@ NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) { case Format::NCHW44_DOT: return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}}; default: - megdnn_throw( - ssprintf("Format unimplement(%d)", static_cast<int>(format)) - .c_str()); + megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format)) + .c_str()); } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/nchw_nchwxx_valid.cpp b/dnn/src/common/nchw_nchwxx_valid.cpp index 0df3ec85..6f70723c 100644 --- a/dnn/src/common/nchw_nchwxx_valid.cpp +++ b/dnn/src/common/nchw_nchwxx_valid.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied.
*/ -#include "megdnn/oprs/nn.h" #include "src/common/nchw_nchwxx_valid.h" +#include "megdnn/oprs/nn.h" using namespace megdnn; namespace { using NchwNchwxxFuncInterface = std::function::CanonizedFilterMeta& fm, - const BiasMode bias_mode, - const param::ConvBias::NonlineMode nonline_mode); + const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode); template <> inline bool nchw_nchwxx_valid( const DTypeEnum src_dtype, const DTypeEnum filter_dtype, const DTypeEnum dst_dtype, const ConvolutionBase::CanonizedFilterMeta& fm, - const BiasMode bias_mode, - const param::ConvBias::NonlineMode nonline_mode) { - bool ok_type = ((src_dtype == DTypeEnum::Float32 && - filter_dtype == DTypeEnum::Float32 && - (dst_dtype == DTypeEnum::Float32))) && - (fm.format == param::Convolution::Format::NCHW44); + const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) { + bool ok_type = + ((src_dtype == DTypeEnum::Float32 && filter_dtype == DTypeEnum::Float32 && + (dst_dtype == DTypeEnum::Float32))) && + (fm.format == param::Convolution::Format::NCHW44); bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY || nonline_mode == param::ConvBias::NonlineMode::RELU || nonline_mode == param::ConvBias::NonlineMode::H_SWISH; @@ -48,14 +46,14 @@ inline bool nchw_nchwxx_valid( fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && - (fm.spatial[0] == 2 || fm.spatial[0] == 3 || - fm.spatial[0] == 5 || fm.spatial[0] == 7); + (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || + fm.spatial[0] == 7); bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && (fm.stride[0] == 1 || fm.stride[1] == 2); bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; - bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && - ok_slide && ok_conv; + bool avaible = + ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv; return avaible; } template <> @@ -63,8 +61,7 @@ inline bool nchw_nchwxx_valid( const DTypeEnum src_dtype, const DTypeEnum filter_dtype, const DTypeEnum dst_dtype, const ConvolutionBase::CanonizedFilterMeta& fm, - const BiasMode bias_mode, - const param::ConvBias::NonlineMode nonline_mode) { + const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) { bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 && filter_dtype == DTypeEnum::QuantizedS8 && (dst_dtype == DTypeEnum::QuantizedS8))) && @@ -74,17 +71,16 @@ inline bool nchw_nchwxx_valid( nonline_mode == param::ConvBias::NonlineMode::H_SWISH; bool ok_src_dst = fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; - bool ok_filter = - fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && - (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || - fm.spatial[0] == 7 || - (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); + bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && + (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || + fm.spatial[0] == 7 || + (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && (fm.stride[0] == 1 || fm.stride[1] == 2); bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; - bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && - ok_slide && ok_conv; + bool avaible = + ok_type && ok_nonline && ok_src_dst && 
ok_filter && ok_slide && ok_conv; return avaible; } template <> @@ -92,24 +88,22 @@ inline bool nchw_nchwxx_valid( const DTypeEnum src_dtype, const DTypeEnum filter_dtype, const DTypeEnum dst_dtype, const ConvolutionBase::CanonizedFilterMeta& fm, - const BiasMode bias_mode, - const param::ConvBias::NonlineMode nonline_mode) { - bool ok_type = - ((src_dtype == DTypeEnum::Int8 && filter_dtype == DTypeEnum::Int8 && - (dst_dtype == DTypeEnum::Int16))) && - (fm.format == param::Convolution::Format::NCHW44); + const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) { + bool ok_type = ((src_dtype == DTypeEnum::Int8 && filter_dtype == DTypeEnum::Int8 && + (dst_dtype == DTypeEnum::Int16))) && + (fm.format == param::Convolution::Format::NCHW44); bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY; bool ok_src_dst = fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && - (fm.spatial[0] == 2 || fm.spatial[0] == 3 || - fm.spatial[0] == 5 || fm.spatial[0] == 7); + (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || + fm.spatial[0] == 7); bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && (fm.stride[0] == 2 || fm.stride[0] == 1); bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; - bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && - ok_slide && ok_conv; + bool avaible = + ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv; return avaible; } template <> @@ -117,8 +111,7 @@ inline bool nchw_nchwxx_valid( const DTypeEnum src_dtype, const DTypeEnum filter_dtype, const DTypeEnum dst_dtype, const ConvolutionBase::CanonizedFilterMeta& fm, - const BiasMode bias_mode, - const param::ConvBias::NonlineMode nonline_mode) { + const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) { bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 && filter_dtype == DTypeEnum::QuantizedS8 && (dst_dtype == DTypeEnum::QuantizedS8))) && @@ -128,17 +121,16 @@ inline bool nchw_nchwxx_valid( nonline_mode == param::ConvBias::NonlineMode::H_SWISH; bool ok_src_dst = fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1; - bool ok_filter = - fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && - (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || - fm.spatial[0] == 7 || - (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); + bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] && + (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 || + fm.spatial[0] == 7 || + (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1)); bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] && (fm.stride[0] == 1 || fm.stride[1] == 2); bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; - bool avaible = ok_type && ok_nonline && ok_src_dst && ok_filter && - ok_slide && ok_conv; + bool avaible = + ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv; return avaible; } @@ -147,12 +139,11 @@ inline bool nchw_nchwxx_valid( const DTypeEnum src_dtype, const DTypeEnum filter_dtype, const DTypeEnum dst_dtype, const ConvolutionBase::CanonizedFilterMeta& fm, - const BiasMode bias_mode, - const param::ConvBias::NonlineMode ) { - bool ok_type = ((src_dtype == DTypeEnum::Float32 && - filter_dtype == DTypeEnum::Float32 && - (dst_dtype == DTypeEnum::Float32))) && - (fm.format 
== param::Convolution::Format::NCHW88); + const BiasMode bias_mode, const param::ConvBias::NonlineMode) { + bool ok_type = + ((src_dtype == DTypeEnum::Float32 && filter_dtype == DTypeEnum::Float32 && + (dst_dtype == DTypeEnum::Float32))) && + (fm.format == param::Convolution::Format::NCHW88); bool ok_src_dst = fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1; bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS; diff --git a/dnn/src/common/padding.cpp b/dnn/src/common/padding.cpp index eda3e74f..682bd702 100644 --- a/dnn/src/common/padding.cpp +++ b/dnn/src/common/padding.cpp @@ -20,15 +20,15 @@ namespace megdnn { using padding_param = megdnn::param_enumv::Padding; -void PaddingForward::forward_check_exec(const TensorLayout& src, - const TensorLayout& dst) { +void PaddingForward::forward_check_exec( + const TensorLayout& src, const TensorLayout& dst) { check_exec(src, dst); - megdnn_assert(src.dtype.enumv() != DTypeEnum::Bool && - src.dtype.enumv() != DTypeEnum::IntB1 && - src.dtype.enumv() != DTypeEnum::IntB2 && - src.dtype.enumv() != DTypeEnum::IntB4, - "unsupported %s dtype for forward padding opr", - src.dtype.name()); + megdnn_assert( + src.dtype.enumv() != DTypeEnum::Bool && + src.dtype.enumv() != DTypeEnum::IntB1 && + src.dtype.enumv() != DTypeEnum::IntB2 && + src.dtype.enumv() != DTypeEnum::IntB4, + "unsupported %s dtype for forward padding opr", src.dtype.name()); } void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { @@ -39,26 +39,30 @@ void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst_shape = {src.shape[0] + offsets[0] + offsets[1]}; break; case 2: - dst_shape = {src.shape[0] + offsets[0] + offsets[1], - src.shape[1] + offsets[2] + offsets[3]}; + dst_shape = { + src.shape[0] + offsets[0] + offsets[1], + src.shape[1] + offsets[2] + offsets[3]}; break; case 3: - dst_shape = {src.shape[0] + offsets[0] + offsets[1], - src.shape[1] + offsets[2] + offsets[3], - src.shape[2] + offsets[4] + offsets[5]}; + dst_shape = { + src.shape[0] + offsets[0] + offsets[1], + src.shape[1] + offsets[2] + offsets[3], + src.shape[2] + offsets[4] + offsets[5]}; break; case 4: - dst_shape = {src.shape[0] + offsets[0] + offsets[1], - src.shape[1] + offsets[2] + offsets[3], - src.shape[2] + offsets[4] + offsets[5], - src.shape[3] + offsets[6] + offsets[7]}; + dst_shape = { + src.shape[0] + offsets[0] + offsets[1], + src.shape[1] + offsets[2] + offsets[3], + src.shape[2] + offsets[4] + offsets[5], + src.shape[3] + offsets[6] + offsets[7]}; break; case 5: - dst_shape = {src.shape[0] + offsets[0] + offsets[1], - src.shape[1] + offsets[2] + offsets[3], - src.shape[2] + offsets[4] + offsets[5], - src.shape[3] + offsets[6] + offsets[7], - src.shape[4] + offsets[8] + offsets[9]}; + dst_shape = { + src.shape[0] + offsets[0] + offsets[1], + src.shape[1] + offsets[2] + offsets[3], + src.shape[2] + offsets[4] + offsets[5], + src.shape[3] + offsets[6] + offsets[7], + src.shape[4] + offsets[8] + offsets[9]}; break; case 6: dst_shape = {src.shape[0] + offsets[0] + offsets[1], @@ -84,26 +88,24 @@ void PaddingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst = TensorLayout(dst_shape, src.dtype); } -void PaddingBackward::backward_check_exec(const TensorLayout& src, - const TensorLayout& dst) { +void PaddingBackward::backward_check_exec( + const TensorLayout& src, const TensorLayout& dst) { check_exec(dst, src); - megdnn_assert(src.dtype.enumv() == - DTypeEnum::Float32 DNN_INC_FLOAT16( - || src.dtype.enumv() == 
DTypeEnum::Float16 || - src.dtype.enumv() == DTypeEnum::BFloat16), - "unsupported %s dtype for forward padding opr", - src.dtype.name()); + megdnn_assert( + src.dtype.enumv() == DTypeEnum::Float32 DNN_INC_FLOAT16( + || src.dtype.enumv() == DTypeEnum::Float16 || + src.dtype.enumv() == DTypeEnum::BFloat16), + "unsupported %s dtype for forward padding opr", src.dtype.name()); } SmallVector PaddingBase::get_offsets() { - SmallVector offsets = { - param().front_offset_dim0, param().back_offset_dim0, - param().front_offset_dim1, param().back_offset_dim1, - param().front_offset_dim2, param().back_offset_dim2, - param().front_offset_dim3, param().back_offset_dim3, - param().front_offset_dim4, param().back_offset_dim4, - param().front_offset_dim5, param().back_offset_dim5, - param().front_offset_dim6, param().back_offset_dim6}; + SmallVector offsets = {param().front_offset_dim0, param().back_offset_dim0, + param().front_offset_dim1, param().back_offset_dim1, + param().front_offset_dim2, param().back_offset_dim2, + param().front_offset_dim3, param().back_offset_dim3, + param().front_offset_dim4, param().back_offset_dim4, + param().front_offset_dim5, param().back_offset_dim5, + param().front_offset_dim6, param().back_offset_dim6}; return offsets; } @@ -114,29 +116,31 @@ void PaddingBase::check_exec(const TensorLayout& src, const TensorLayout& dst) { // make sure src and dst is same dtype megdnn_assert_eq_dtype(src, dst); // make sure src and dst is same ndim - megdnn_assert(src.ndim == dst.ndim, "the src.ndim = %zu the dst.ndim = %zu", - src.ndim, dst.ndim); + megdnn_assert( + src.ndim == dst.ndim, "the src.ndim = %zu the dst.ndim = %zu", src.ndim, + dst.ndim); // make sure in every dimension dst is equal or greater than src for (size_t i = 0; i < src.ndim; ++i) { - megdnn_assert(dst.shape[i] == - src.shape[i] + offsets[i * 2] + offsets[i * 2 + 1]); + megdnn_assert( + dst.shape[i] == src.shape[i] + offsets[i * 2] + offsets[i * 2 + 1]); } // check the padding mode is valid - megdnn_assert(static_cast(param().padding_mode) == - padding_param::PaddingMode::REFLECT || - static_cast(param().padding_mode) == - padding_param::PaddingMode::REPLICATE || - static_cast(param().padding_mode) == - padding_param::PaddingMode::CONSTANT, - "unsupported padding mode"); + megdnn_assert( + static_cast(param().padding_mode) == + padding_param::PaddingMode::REFLECT || + static_cast(param().padding_mode) == + padding_param::PaddingMode::REPLICATE || + static_cast(param().padding_mode) == + padding_param::PaddingMode::CONSTANT, + "unsupported padding mode"); // addition check for reflect padding, make sure the reflected index is // valid if (static_cast(param().padding_mode) == padding_param::PaddingMode::REFLECT) { for (size_t i = 0; i < src.ndim; ++i) { - megdnn_assert(offsets[i * 2] < src.shape[i] && - dst.shape[i] - offsets[i * 2] - src.shape[i] < - src.shape[i]); + megdnn_assert( + offsets[i * 2] < src.shape[i] && + dst.shape[i] - offsets[i * 2] - src.shape[i] < src.shape[i]); } } } diff --git a/dnn/src/common/param_pack.cpp b/dnn/src/common/param_pack.cpp index 407981ae..f6fea62e 100644 --- a/dnn/src/common/param_pack.cpp +++ b/dnn/src/common/param_pack.cpp @@ -14,29 +14,31 @@ using namespace megdnn; -void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, - const TensorLayout& offsets, - const TensorLayout& parts) { - megdnn_assert(offsets.dtype == dtype::Int32{}, "bad dtype: %s", - offsets.dtype.name()); - megdnn_assert(concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 && - 
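The padding checks in padding.cpp reduce to one per-dimension rule: each output extent is the input extent plus its front and back offsets, and REFLECT mode additionally requires both offsets to stay below the input extent so every mirrored index lands inside the source. A minimal standalone sketch of that rule (illustrative names, not the MegDNN API):

#include <cassert>
#include <cstddef>
#include <vector>

// dst[i] = src[i] + front[i] + back[i]; REFLECT additionally needs
// front[i] < src[i] and back[i] < src[i].
std::vector<size_t> padded_shape(
        const std::vector<size_t>& src, const std::vector<size_t>& front,
        const std::vector<size_t>& back, bool reflect) {
    std::vector<size_t> dst(src.size());
    for (size_t i = 0; i < src.size(); ++i) {
        if (reflect)
            assert(front[i] < src[i] && back[i] < src[i]);
        dst[i] = src[i] + front[i] + back[i];
    }
    return dst;
}
// e.g. src = {4, 5}, front = {1, 2}, back = {0, 1}  ->  dst = {5, 8}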
concated.stride[0] == 1 && offsets.stride[0] == 1 && - parts.stride[0] == 1, - "bad layout: concated=%s offsets=%s parts=%s", - concated.to_string().c_str(), offsets.to_string().c_str(), - parts.to_string().c_str()); +void ParamPackConcatSplitBase::check_exec( + const TensorLayout& concated, const TensorLayout& offsets, + const TensorLayout& parts) { + megdnn_assert( + offsets.dtype == dtype::Int32{}, "bad dtype: %s", offsets.dtype.name()); + megdnn_assert( + concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 && + concated.stride[0] == 1 && offsets.stride[0] == 1 && + parts.stride[0] == 1, + "bad layout: concated=%s offsets=%s parts=%s", concated.to_string().c_str(), + offsets.to_string().c_str(), parts.to_string().c_str()); } std::vector ParamPackConcatSplitBase::gen_offsets( const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) { - megdnn_assert(alignment && (alignment & (alignment - 1)) == 0, - "alignment must be power of 2: %zu", alignment); + megdnn_assert( + alignment && (alignment & (alignment - 1)) == 0, + "alignment must be power of 2: %zu", alignment); if (alignment < dtype_size) alignment = dtype_size; - megdnn_assert(alignment % dtype_size == 0, - "alignment must be multiple of dtype size: %zu vs %zu", - alignment, dtype_size); + megdnn_assert( + alignment % dtype_size == 0, + "alignment must be multiple of dtype size: %zu vs %zu", alignment, + dtype_size); alignment /= dtype_size; auto get_aligned = [alignment](size_t v) { diff --git a/dnn/src/common/pooling.cpp b/dnn/src/common/pooling.cpp index 32a11d4e..5d99b4d8 100644 --- a/dnn/src/common/pooling.cpp +++ b/dnn/src/common/pooling.cpp @@ -15,8 +15,7 @@ namespace megdnn { -void PoolingBase::deduce_layout_fwd(const TensorLayout& src, - TensorLayout& dst) { +void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + "pad_h=" + std::to_string(param().pad_h) + ", " + @@ -26,9 +25,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, "window_h=" + std::to_string(param().window_h) + ", " + "window_w=" + std::to_string(param().window_w) + ", " + "is_max=" + std::to_string(param().mode == Mode::MAX) + ", " + - "is_nhwc=" + std::to_string(param().format == Param::Format::NHWC) + - ", " + "is_nhwcd4=" + - std::to_string(param().format == Param::Format::NHWCD4); + "is_nhwc=" + std::to_string(param().format == Param::Format::NHWC) + ", " + + "is_nhwcd4=" + std::to_string(param().format == Param::Format::NHWCD4); auto errmsg_c = errmsg.c_str(); MEGDNN_MARK_USED_VAR(errmsg_c); @@ -44,11 +42,12 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, spatial_pos = 1; c_pos = 3; - } else if (param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW44 || - param().format == Param::Format::NCHW88 || - param().format == Param::Format::NCHW32 || - param().format == Param::Format::NCHW64) { + } else if ( + param().format == Param::Format::NCHW4 || + param().format == Param::Format::NCHW44 || + param().format == Param::Format::NCHW88 || + param().format == Param::Format::NCHW32 || + param().format == Param::Format::NCHW64) { megdnn_assert(src.ndim == 5_z, "%s", errmsg_c); spatial_pos = 2; @@ -59,8 +58,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, batch_pos = 3; } else { megdnn_assert( - param().format == Param::Format::NHWCD4 && src.ndim == 5_z, - "%s", errmsg_c); + param().format == Param::Format::NHWCD4 && src.ndim == 5_z, "%s", + errmsg_c); spatial_pos = 1; 
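The gen_offsets helper in param_pack.cpp above first converts the byte alignment into an element count (asserting it is a power of two and a multiple of the dtype size); the elided get_aligned lambda then presumably rounds each tensor's start offset up to that element alignment. A sketch of that round-up under those assumptions (the formula is an assumption, since the lambda body is not shown in this hunk):

#include <cassert>
#include <cstddef>

// Round v up to a multiple of alignment; valid because alignment is a power of two.
size_t round_up_to_alignment(size_t v, size_t alignment) {
    assert(alignment && (alignment & (alignment - 1)) == 0);
    return (v + alignment - 1) & ~(alignment - 1);
}
// e.g. a 64-byte alignment with float32 (4-byte) elements becomes a 16-element
// alignment, so a 100-element tensor is followed by a part starting at offset 112.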
c_pos = 2; } @@ -103,11 +102,11 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, if (param().format == Param::Format::NCHW) { dst = TensorLayout(TensorShape({n, c, oh, ow}), src.dtype); } else if (param().format == Param::Format::NHWC) { - megdnn_assert(param().format == Param::Format::NHWC, - "invalid pooling format"); + megdnn_assert(param().format == Param::Format::NHWC, "invalid pooling format"); dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format); - } else if (param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW44) { + } else if ( + param().format == Param::Format::NCHW4 || + param().format == Param::Format::NCHW44) { dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format}; } else if (param().format == Param::Format::NCHW88) { dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format}; @@ -118,46 +117,42 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, } else if (param().format == Param::Format::CHWN4) { dst = TensorLayout{{c / 4, oh, ow, n, 4}, src.dtype, src.format}; } else { - megdnn_assert(param().format == Param::Format::NHWCD4, - "invalid pooling format"); + megdnn_assert( + param().format == Param::Format::NHWCD4, "invalid pooling format"); dst = TensorLayout{{n, oh, c / 4, ow, 4}, src.dtype, src.format}; } } -void PoolingBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& dst) { +void PoolingBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); megdnn_assert(src.dtype == dst.dtype); - megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT || - src.dtype == dtype::Int8() || - src.dtype.category() == DTypeCategory::QUANTIZED); + megdnn_assert( + src.dtype.category() == DTypeCategory::FLOAT || + src.dtype == dtype::Int8() || + src.dtype.category() == DTypeCategory::QUANTIZED); } void PoolingForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void PoolingForward::check_exec(const TensorLayout& src, - const TensorLayout& dst, - size_t workspace_in_bytes) { +void PoolingForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void PoolingBackward::check_exec(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void PoolingBackward::check_exec( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(src, dst); megdnn_assert_eq_layout(src, grad); megdnn_assert_eq_layout(dst, diff); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, dst, diff, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } diff --git a/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_decl.inl b/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_decl.inl index d50fa794..f78a2656 100644 --- a/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_decl.inl +++ b/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_decl.inl @@ -25,14 +25,12 @@ namespace megdnn { #define FUNC_NAME 
CONCAT_STR(do_max_pooling_3x3_s2x2_float_, MEGDNN_SIMD_NAME) -void FUNC_NAME(const float *src, float *dst, - size_t IH_, size_t IW_, size_t OH_, size_t OW_, size_t PH_, size_t PW_, - const WorkspaceBundle& ws) -MEGDNN_SIMD_ATTRIBUTE_TARGET; +void FUNC_NAME( + const float* src, float* dst, size_t IH_, size_t IW_, size_t OH_, size_t OW_, + size_t PH_, size_t PW_, const WorkspaceBundle& ws) MEGDNN_SIMD_ATTRIBUTE_TARGET; #undef FUNC_NAME -} +} // namespace megdnn #include "src/common/macro_helper_epilogue.h" - diff --git a/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_def.inl b/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_def.inl index 4ed1b09d..534a72e4 100644 --- a/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_def.inl +++ b/dnn/src/common/pooling/do_max_pooling_3x3_s2x2_float_def.inl @@ -20,55 +20,54 @@ #include "src/common/utils.h" -#include "src/common/macro_helper.h" -#include #include +#include +#include "src/common/macro_helper.h" namespace megdnn { #define FUNC_NAME CONCAT_STR(do_max_pooling_3x3_s2x2_float_, MEGDNN_SIMD_NAME) MEGDNN_SIMD_ATTRIBUTE_TARGET -void FUNC_NAME(const float *src, float *dst, - size_t IH_, size_t IW_, size_t OH_, size_t OW_, size_t PH_, size_t PW_, - const WorkspaceBundle& ws) -{ +void FUNC_NAME( + const float* src, float* dst, size_t IH_, size_t IW_, size_t OH_, size_t OW_, + size_t PH_, size_t PW_, const WorkspaceBundle& ws) { int IH = IH_, IW = IW_, OH = OH_, OW = OW_, PH = PH_, PW = PW_; // cache[i] stores the answer of the i-th line after // pooling along the W dimension. - float* cache[3] = {static_cast(ws.get(0)), - static_cast(ws.get(1)), - static_cast(ws.get(2))}; + float* cache[3] = { + static_cast(ws.get(0)), static_cast(ws.get(1)), + static_cast(ws.get(2))}; float* odd = static_cast(ws.get(3)); float* even = static_cast(ws.get(4)); int ih_next = 0; // "good" area means we can use SIMD to accelerate. - auto get_good_area = [](int I, int /* O */, int P, int &O_from, int &O_to) { + auto get_good_area = [](int I, int /* O */, int P, int& O_from, int& O_to) { // x*2 - P >= 0; 2x >= P; x >= P/2 - O_from = (P+1) / 2; + O_from = (P + 1) / 2; // x*2 - P + 3 <= I; x*2 <= I+P-3; x <= (I+P-3)/2 - O_to = (I+P-3) / 2 + 1; + O_to = (I + P - 3) / 2 + 1; // we must have I >= 2 to ensure O_from <= O_to }; int OW_from, OW_to; get_good_area(IW, OW, PW, OW_from, OW_to); auto process_cache = [&](int ih) MEGDNN_SIMD_LAMBDA_ATTRIBUTE_TARGET { - const float * __restrict sptr = src + ih*IW; + const float* __restrict sptr = src + ih * IW; auto tmp = cache[2]; cache[2] = cache[1]; cache[1] = cache[0]; cache[0] = tmp; // cache 0 is used to store the current answer. 
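The "good" area computed by get_good_area follows from plain window arithmetic: with stride 2 and padding PW, output column ow reads input columns 2*ow - PW through 2*ow - PW + 2, so the window is fully inside [0, IW) exactly when ow >= (PW + 1) / 2 and ow <= (IW + PW - 3) / 2. A quick standalone check of those bounds (sample sizes chosen only for illustration, not part of the kernel):

#include <cassert>

int main() {
    const int IW = 10, PW = 1;               // example input width and padding
    const int O_from = (PW + 1) / 2;         // == 1
    const int O_to = (IW + PW - 3) / 2 + 1;  // == 5, exclusive upper bound
    for (int ow = O_from; ow < O_to; ++ow) {
        // every column in [O_from, O_to) reads a fully in-bounds 3-wide window
        assert(2 * ow - PW >= 0 && 2 * ow - PW + 2 <= IW - 1);
    }
    return 0;  // columns outside this range use the scalar run_single path instead
}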
auto run_single = [&](int ow) { - int iw = ow*2 - PW; + int iw = ow * 2 - PW; float res = std::numeric_limits::lowest(); - if (iw+0 >= 0 && iw+0 < IW) { - res = std::max(res, sptr[iw+0]); + if (iw + 0 >= 0 && iw + 0 < IW) { + res = std::max(res, sptr[iw + 0]); } - if (iw+1 >= 0 && iw+1 < IW) { - res = std::max(res, sptr[iw+1]); + if (iw + 1 >= 0 && iw + 1 < IW) { + res = std::max(res, sptr[iw + 1]); } - if (iw+2 >= 0 && iw+2 < IW) { - res = std::max(res, sptr[iw+2]); + if (iw + 2 >= 0 && iw + 2 < IW) { + res = std::max(res, sptr[iw + 2]); } cache[0][ow] = res; }; @@ -76,7 +75,7 @@ void FUNC_NAME(const float *src, float *dst, int iw = 0; int odd_offset = 0, even_offset = 0; - for (; iw+2*MEGDNN_SIMD_WIDTH <= IW; iw += 2*MEGDNN_SIMD_WIDTH) { + for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) { MEGDNN_SIMD_TYPE s0, s1, d0, d1; s0 = MEGDNN_SIMD_LOADU(sptr + iw); s1 = MEGDNN_SIMD_LOADU(sptr + iw + MEGDNN_SIMD_WIDTH); @@ -93,38 +92,40 @@ void FUNC_NAME(const float *src, float *dst, even[even_offset++] = sptr[iw]; } int ow = 0; - for (; ow < OW_from; ++ow) run_single(ow); + for (; ow < OW_from; ++ow) + run_single(ow); if (PW & 1) { - for (; ow+MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { MEGDNN_SIMD_TYPE d, s0, s1, s2; - s0 = MEGDNN_SIMD_LOADU(odd + ow - (PW>>1) - 1); - s1 = MEGDNN_SIMD_LOADU(even + ow - (PW>>1)); - s2 = MEGDNN_SIMD_LOADU(odd + ow - (PW>>1)); + s0 = MEGDNN_SIMD_LOADU(odd + ow - (PW >> 1) - 1); + s1 = MEGDNN_SIMD_LOADU(even + ow - (PW >> 1)); + s2 = MEGDNN_SIMD_LOADU(odd + ow - (PW >> 1)); d = MEGDNN_SIMD_MAX(MEGDNN_SIMD_MAX(s0, s1), s2); MEGDNN_SIMD_STOREU(cache[0] + ow, d); } } else { - for (; ow+MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW_to; ow += MEGDNN_SIMD_WIDTH) { MEGDNN_SIMD_TYPE d, s0, s1, s2; - s0 = MEGDNN_SIMD_LOADU(even + ow - (PW>>1)); - s1 = MEGDNN_SIMD_LOADU(odd + ow - (PW>>1)); - s2 = MEGDNN_SIMD_LOADU(even + ow - (PW>>1) + 1); + s0 = MEGDNN_SIMD_LOADU(even + ow - (PW >> 1)); + s1 = MEGDNN_SIMD_LOADU(odd + ow - (PW >> 1)); + s2 = MEGDNN_SIMD_LOADU(even + ow - (PW >> 1) + 1); d = MEGDNN_SIMD_MAX(MEGDNN_SIMD_MAX(s0, s1), s2); MEGDNN_SIMD_STOREU(cache[0] + ow, d); } } - for (; ow < OW; ++ow) run_single(ow); + for (; ow < OW; ++ow) + run_single(ow); }; for (int oh = 0; oh < OH; ++oh) { - float * __restrict dptr = dst + oh*OW; - int ih_from = std::min(IH, std::max(0, oh*2 - PH)); - int ih_to = std::min(IH, std::max(0, oh*2 - PH + 3)); + float* __restrict dptr = dst + oh * OW; + int ih_from = std::min(IH, std::max(0, oh * 2 - PH)); + int ih_to = std::min(IH, std::max(0, oh * 2 - PH + 3)); while (ih_next < ih_to) { process_cache(ih_next++); } if (ih_to - ih_from == 3) { int ow = 0; - for (; ow+MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { MEGDNN_SIMD_TYPE d, s0, s1, s2; s0 = MEGDNN_SIMD_LOADU(cache[0] + ow); s1 = MEGDNN_SIMD_LOADU(cache[1] + ow); @@ -133,14 +134,13 @@ void FUNC_NAME(const float *src, float *dst, MEGDNN_SIMD_STOREU(dptr + ow, d); } for (; ow < OW; ++ow) { - dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), - cache[2][ow]); + dptr[ow] = std::max(std::max(cache[0][ow], cache[1][ow]), cache[2][ow]); } } else { std::memcpy(dptr, cache[0], sizeof(float) * OW); for (int i = 1; i < ih_to - ih_from; ++i) { int ow = 0; - for (; ow+MEGDNN_SIMD_WIDTH <= OW; ow += MEGDNN_SIMD_WIDTH) { + for (; ow + MEGDNN_SIMD_WIDTH <= OW; ow += 
MEGDNN_SIMD_WIDTH) { MEGDNN_SIMD_TYPE d, s; s = MEGDNN_SIMD_LOADU(cache[i] + ow); d = MEGDNN_SIMD_LOADU(dptr + ow); @@ -155,4 +155,4 @@ void FUNC_NAME(const float *src, float *dst, } } -} // namespace megdnn +} // namespace megdnn diff --git a/dnn/src/common/postprocess_helper.h b/dnn/src/common/postprocess_helper.h index 3aef6be8..05de2368 100644 --- a/dnn/src/common/postprocess_helper.h +++ b/dnn/src/common/postprocess_helper.h @@ -31,14 +31,15 @@ namespace { MEGDNN_MARK_USED_VAR(OW); \ MEGDNN_MARK_USED_VAR(pack_oc_size) -template +template < + typename ctype, typename dtype = ctype, + megdnn::PostprocessMode postprocess_mode = megdnn::PostprocessMode::FLOAT> struct PostProcess { - static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + static void run( + void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { POST_PROCESS_UNUSED_VAR(); megdnn_throw("not impl PostProcess"); } @@ -46,10 +47,11 @@ struct PostProcess { template struct PostProcess { - static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + static void run( + void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { POST_PROCESS_UNUSED_VAR(); megdnn_throw("not impl PostProcess"); } @@ -57,10 +59,11 @@ struct PostProcess { template struct PostProcess { - static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + static void run( + void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { POST_PROCESS_UNUSED_VAR(); megdnn_throw("not impl PostProcess"); } @@ -68,10 +71,11 @@ struct PostProcess { template struct PostProcess { - static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, - megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, - megdnn::DType bias_type, megdnn::DType dst_type, size_t N, - size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { + static void run( + void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, + megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, + megdnn::DType bias_type, megdnn::DType dst_type, size_t N, size_t OC, + size_t OH, size_t OW, size_t pack_oc_size = 1) { POST_PROCESS_UNUSED_VAR(); megdnn_throw("not impl PostProcess"); } diff --git a/dnn/src/common/powc.cpp b/dnn/src/common/powc.cpp index 06f627cd..84865319 100644 --- a/dnn/src/common/powc.cpp +++ b/dnn/src/common/powc.cpp @@ -17,17 +17,17 @@ using namespace megdnn; void PowC::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { - megdnn_assert(src.layout.dtype == 
dst.layout.dtype && - src.layout.dtype.category() == DTypeCategory::FLOAT && - src.layout.eq_shape(dst.layout), - "invalid layout: %s vs %s", src.layout.to_string().c_str(), - dst.layout.to_string().c_str()); + megdnn_assert( + src.layout.dtype == dst.layout.dtype && + src.layout.dtype.category() == DTypeCategory::FLOAT && + src.layout.eq_shape(dst.layout), + "invalid layout: %s vs %s", src.layout.to_string().c_str(), + dst.layout.to_string().c_str()); int iv, *ivp = nullptr; float fv, *fvp = nullptr; float p = param().exp; int pi = static_cast(std::round(p)); - if (std::abs(static_cast(pi) - p) < - std::numeric_limits::epsilon()) { + if (std::abs(static_cast(pi) - p) < std::numeric_limits::epsilon()) { iv = pi; ivp = &iv; } else { @@ -38,4 +38,3 @@ void PowC::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/reduce.cpp b/dnn/src/common/reduce.cpp index f6311fd6..27ce2778 100644 --- a/dnn/src/common/reduce.cpp +++ b/dnn/src/common/reduce.cpp @@ -30,13 +30,11 @@ DType get_out_dtype(const Reduce::DataType data_type, const DType inp_dtype) { } if (data_type == Reduce::DataType::QUINT_I8xO32) { megdnn_assert(inp_dtype.enumv() == DTypeEnum::Quantized8Asymm); - return dtype::QuantizedS32( - inp_dtype.param().scale); + return dtype::QuantizedS32(inp_dtype.param().scale); } if (data_type == Reduce::DataType::QINT_I8xO32) { megdnn_assert(inp_dtype.enumv() == DTypeEnum::QuantizedS8); - return dtype::QuantizedS32( - inp_dtype.param().scale); + return dtype::QuantizedS32(inp_dtype.param().scale); } megdnn_assert(data_type == Reduce::DataType::DEFAULT); return inp_dtype; @@ -57,13 +55,14 @@ void ReduceForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst.init_contiguous_stride(); } -void ReduceForward::check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes) { +void ReduceForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst); }; - megdnn_assert(param().data_type != Reduce::DataType::FLOAT_IO16xC32, - "FLOAT_IO16xC32 is deprecated"); + megdnn_assert( + param().data_type != Reduce::DataType::FLOAT_IO16xC32, + "FLOAT_IO16xC32 is deprecated"); MEGDNN_MARK_USED_VAR(errmsg); megdnn_assert_contiguous(src); megdnn_assert_contiguous(dst); @@ -78,21 +77,24 @@ void ReduceForward::check_exec(const TensorLayout& src, const TensorLayout& dst, megdnn_assert(dst.shape[i] == 1_z, "%s", errmsg().c_str()); } } - megdnn_assert(src.dtype.category() == dst.dtype.category() || - param().data_type == Reduce::DataType::FLOAT_O32xC32, - "the category of reduce output and input must be the same," - " or the data_type is FLOAT_O32xC32"); + megdnn_assert( + src.dtype.category() == dst.dtype.category() || + param().data_type == Reduce::DataType::FLOAT_O32xC32, + "the category of reduce output and input must be the same," + " or the data_type is FLOAT_O32xC32"); if (param().data_type == DataType::DEFAULT) { - megdnn_assert(src.dtype == dst.dtype && - (src.dtype.category() == DTypeCategory::FLOAT || - src.dtype.category() == DTypeCategory::INT || - src.dtype.category() == DTypeCategory::QUANTIZED)); + megdnn_assert( + src.dtype == dst.dtype && + (src.dtype.category() == DTypeCategory::FLOAT || + src.dtype.category() == DTypeCategory::INT || + src.dtype.category() == DTypeCategory::QUANTIZED)); } else if (param().data_type == DataType::QUINT_I8xO32) { megdnn_assert(src.dtype.enumv() == 
DTypeEnum::Quantized8Asymm); } else if (param().data_type == DataType::QINT_I8xO32) { megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8); - } else if (param().data_type == DataType::FLOAT_IO16xC32 || - param().data_type == DataType::FLOAT_O16xC32) { + } else if ( + param().data_type == DataType::FLOAT_IO16xC32 || + param().data_type == DataType::FLOAT_O16xC32) { megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); } else { megdnn_assert(param().data_type == DataType::FLOAT_O32xC32); diff --git a/dnn/src/common/reduce_helper.cpp b/dnn/src/common/reduce_helper.cpp index bfecda1e..297833fe 100644 --- a/dnn/src/common/reduce_helper.cpp +++ b/dnn/src/common/reduce_helper.cpp @@ -17,15 +17,13 @@ namespace megdnn { namespace reduce { -void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, - size_t axis) { +void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis) { auto shape_arr = shape.shape; auto ndim = shape.ndim; - A = std::accumulate(shape_arr, shape_arr + axis, 1_z, - SafeMultiplies()); + A = std::accumulate(shape_arr, shape_arr + axis, 1_z, SafeMultiplies()); B = shape_arr[axis]; - C = std::accumulate(shape_arr + (axis + 1), shape_arr + ndim, 1_z, - SafeMultiplies()); + C = std::accumulate( + shape_arr + (axis + 1), shape_arr + ndim, 1_z, SafeMultiplies()); } } // namespace reduce diff --git a/dnn/src/common/reduce_helper.h b/dnn/src/common/reduce_helper.h index a5dbaefc..14a0f689 100644 --- a/dnn/src/common/reduce_helper.h +++ b/dnn/src/common/reduce_helper.h @@ -29,9 +29,7 @@ struct SumOp { const size_t B; MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } - MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { - dst[idx] = val; - } + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } @@ -48,7 +46,7 @@ struct MeanOp { src_ctype* src; dst_ctype* dst; const size_t B; - + MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val / static_cast(B); @@ -73,9 +71,7 @@ struct SumSqrOp { MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return static_cast(src[idx]) * static_cast(src[idx]); } - MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { - dst[idx] = val; - } + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } @@ -93,9 +89,7 @@ struct ProdOp { const size_t B; MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } - MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { - dst[idx] = val; - } + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { return lhs * rhs; } @@ -113,9 +107,7 @@ struct MinOp { const size_t B; MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } - MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { - dst[idx] = val; - } + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { #if defined(__CUDA_ARCH__) return lhs < rhs ? 
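The get_ABC helper above collapses the tensor around the reduction axis: A is the product of the leading extents, B the reduced extent itself, and C the product of the trailing extents, so the reduce kernels (and the B member of the ops above) only ever see an (A, B, C) view. A standalone illustration of the same decomposition (written against std::vector rather than TensorShape):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

void get_abc(const std::vector<size_t>& shape, size_t axis,
             size_t& A, size_t& B, size_t& C) {
    A = std::accumulate(shape.begin(), shape.begin() + axis, size_t(1),
                        std::multiplies<size_t>());
    B = shape[axis];
    C = std::accumulate(shape.begin() + axis + 1, shape.end(), size_t(1),
                        std::multiplies<size_t>());
}
// e.g. shape = {2, 3, 4, 5}, axis = 2  ->  A = 6, B = 4, C = 5; element (a, b, c)
// of the view sits at index (a * B + b) * C + c of the contiguous tensor, and
// reducing over b leaves an (A, C) = (6, 5) result.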
lhs : rhs; @@ -137,9 +129,7 @@ struct MaxOp { const size_t B; MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } - MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { - dst[idx] = val; - } + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { #if defined(__CUDA_ARCH__) return lhs > rhs ? lhs : rhs; @@ -167,20 +157,16 @@ struct CheckNonFiniteOp { return !std::isfinite(src[idx]); #endif } - MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { - dst[idx] = val; - } + MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; } - MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, - size_t B) + MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B) : INIT(wtype(0)), src(src), dst(dst), B(B) {} }; #if MEGDNN_CC_HOST -void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, - size_t axis); +void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); #endif } // namespace reduce diff --git a/dnn/src/common/relayout.cpp b/dnn/src/common/relayout.cpp index 7e03eb2a..eacf6662 100644 --- a/dnn/src/common/relayout.cpp +++ b/dnn/src/common/relayout.cpp @@ -44,9 +44,7 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { * shape: n, m * stride: 1, n */ - auto strd = [&](size_t idx, ptrdiff_t v) { - return layout.stride[idx] == v; - }; + auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; }; if (layout.ndim == 4) { p.batch = layout[0]; p.n = layout[1]; @@ -89,19 +87,18 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { } // anonymous namespace -void RelayoutForward::check_layout_and_canonize(TensorLayout& src, - TensorLayout& dst) { +void RelayoutForward::check_layout_and_canonize(TensorLayout& src, TensorLayout& dst) { megdnn_assert(dst.is_non_overlapping_strong()); src = src.collapse_contiguous(); dst = dst.collapse_contiguous(); - megdnn_assert(src.dtype == dst.dtype && - src.total_nr_elems() == dst.total_nr_elems(), - "check %s == %s and %zu == %zu", src.dtype.name(), - dst.dtype.name(), src.total_nr_elems(), dst.total_nr_elems()); + megdnn_assert( + src.dtype == dst.dtype && src.total_nr_elems() == dst.total_nr_elems(), + "check %s == %s and %zu == %zu", src.dtype.name(), dst.dtype.name(), + src.total_nr_elems(), dst.total_nr_elems()); } -bool relayout::is_transpose(const TensorLayout& src, const TensorLayout& dst, - TransposeParam& p) { +bool relayout::is_transpose( + const TensorLayout& src, const TensorLayout& dst, TransposeParam& p) { if (is_contig(dst) && is_transpose_single(src, p)) { // if the original intention is to transpose (m, n) to (n, m), // then we should use (n, m) as the contig dst and use a corrsponding diff --git a/dnn/src/common/relayout_format.cpp b/dnn/src/common/relayout_format.cpp index d3bf7115..56a91c39 100644 --- a/dnn/src/common/relayout_format.cpp +++ b/dnn/src/common/relayout_format.cpp @@ -16,8 +16,7 @@ using namespace megdnn; -void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, - TensorLayout& dst) { +void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { using Param = param::RelayoutFormat; switch (param().mode) { case Param::Mode::NCHW_NHWCD4: @@ -67,9 +66,10 @@ void RelayoutFormat::deduce_layout_fwd(const 
TensorLayout& src, case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4"); dst.ndim = 6; - megdnn_assert(src[0] % 8 == 0, - "NCHW_NCHW88_CONV_DENSE_WEIGHT out channel must " - "align to 8"); + megdnn_assert( + src[0] % 8 == 0, + "NCHW_NCHW88_CONV_DENSE_WEIGHT out channel must " + "align to 8"); dst[0] = src[0] / 8; dst[1] = div_ceil(src[1], 8_z); dst[2] = src[2]; @@ -91,9 +91,10 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5"); dst.ndim = 7; dst[0] = src[0]; - megdnn_assert(src[1] % 8 == 0, - "NCHW_NCHW88_CONV_GROUP_WEIGHT out channel must " - "align to 8"); + megdnn_assert( + src[1] % 8 == 0, + "NCHW_NCHW88_CONV_GROUP_WEIGHT out channel must " + "align to 8"); dst[1] = src[1] / 8; dst[2] = div_ceil(src[2], 8_z); dst[3] = src[3]; @@ -305,9 +306,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { size_t align = handle()->image2d_pitch_alignment(); auto vendor_type = handle()->vendor_type(); using Param = param::RelayoutFormat; -#define CHECK_SRC(_expect) \ - megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \ - _expect.to_string().c_str(), src.to_string().c_str()) +#define CHECK_SRC(_expect) \ + megdnn_assert( \ + src == _expect, "invalid src format: expect=%s got=%s", \ + _expect.to_string().c_str(), src.to_string().c_str()) switch (param().mode) { case Param::Mode::NHWC_NHWCD4: CHECK_SRC(DefaultTensorFormat::make()); @@ -330,8 +332,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type); break; case Param::Mode::NHWCD4I_NCHW: - CHECK_SRC( - Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); + CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type)); dst = DefaultTensorFormat::make(); break; case Param::Mode::NHWCD4_NCHW: @@ -402,36 +403,33 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) { if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 && ( - handle()->type() != Handle::HandleType::NAIVE && - handle()->type() != Handle::HandleType::X86)) { - megdnn_throw( - "Dump with Image2DPack4TensorFormat is not available on CUDA compnode, " - "try export CUDA_VISIBLE_DEVICES=\'\'"); + handle()->type() != Handle::HandleType::NAIVE && + handle()->type() != Handle::HandleType::X86)) { + megdnn_throw( + "Dump with Image2DPack4TensorFormat is not available on CUDA compnode, " + "try export CUDA_VISIBLE_DEVICES=\'\'"); } #undef CHECK_SRC } -void RelayoutFormat::check_layout_fwd(const TensorLayout& src, - const TensorLayout& dst) { +void RelayoutFormat::check_layout_fwd( + const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; dst_expected.dtype = dst.dtype; deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); } -void RelayoutFormat::check_exec(const TensorLayout& src, - const TensorLayout& dst, - size_t workspace_in_bytes) { +void RelayoutFormat::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, - const TensorLayout& dst, - TensorLayout& exec_workspace, - TensorLayout& exec_src, - TensorLayout& exec_dst) { +void RelayoutFormat::deduce_exec_layout( + const TensorLayout& 
src, const TensorLayout& dst, TensorLayout& exec_workspace, + TensorLayout& exec_src, TensorLayout& exec_dst) { check_layout_fwd(src, dst); using Param = param::RelayoutFormat; switch (param().mode) { @@ -439,11 +437,12 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, // nchw to nchw8c { exec_workspace = TensorLayout( - {src[0], round_up(src[1], 8_z), src[2], src[3]}, - src.dtype, src.format); + {src[0], round_up(src[1], 8_z), src[2], src[3]}, src.dtype, + src.format); exec_src = exec_workspace - .reshape({src[0], div_ceil(src[1], 8_z), 8, - src[2], src[3]}) + .reshape( + {src[0], div_ceil(src[1], 8_z), 8, src[2], + src[3]}) .dimshuffle({0, 1, 3, 4, 2}); exec_dst = dst; } @@ -456,11 +455,11 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, exec_workspace = TensorLayout( {src[0], group * round_up(icpg, 4_z), src[2], src[3]}, src.dtype, src.format); - exec_src = - exec_workspace - .reshape({src[0], group * div_ceil(icpg, 4_z), - 4, src[2], src[3]}) - .dimshuffle({0, 1, 3, 4, 2}); + exec_src = exec_workspace + .reshape( + {src[0], group * div_ceil(icpg, 4_z), 4, + src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); exec_dst = dst; } break; @@ -469,25 +468,27 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, { if (src.ndim == 4) { exec_workspace = TensorLayout( - {round_up(src[0], 4_z), round_up(src[1], 4_z), - src[2], src[3]}, + {round_up(src[0], 4_z), round_up(src[1], 4_z), src[2], + src[3]}, src.dtype, src.format); - exec_src = exec_workspace - .reshape({round_up(src[0], 4_z), - div_ceil(src[1], 4_z), 4, - src[2], src[3]}) - .dimshuffle({0, 1, 3, 4, 2}); + exec_src = + exec_workspace + .reshape( + {round_up(src[0], 4_z), + div_ceil(src[1], 4_z), 4, src[2], src[3]}) + .dimshuffle({0, 1, 3, 4, 2}); exec_dst = dst; } else if (src.ndim == 5) { exec_workspace = TensorLayout( - {src[0], round_up(src[1], 4_z), - round_up(src[2], 4_z), src[3], src[4]}, + {src[0], round_up(src[1], 4_z), round_up(src[2], 4_z), + src[3], src[4]}, src.dtype, src.format); - exec_src = exec_workspace - .reshape({src[0], round_up(src[1], 4_z), - div_ceil(src[2], 4_z), 4, - src[3], src[4]}) - .dimshuffle({0, 1, 2, 4, 5, 3}); + exec_src = + exec_workspace + .reshape( + {src[0], round_up(src[1], 4_z), + div_ceil(src[2], 4_z), 4, src[3], src[4]}) + .dimshuffle({0, 1, 2, 4, 5, 3}); exec_dst = dst; } } @@ -496,9 +497,8 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, // nchw to nchw4 { megdnn_assert(src.format == dst.format); - exec_workspace = - TensorLayout({src[0], src[1] * 4, src[2], src[3]}, - dst.dtype, dst.format); + exec_workspace = TensorLayout( + {src[0], src[1] * 4, src[2], src[3]}, dst.dtype, dst.format); exec_src = src.dimshuffle({0, 1, 4, 2, 3}); exec_dst = dst; } @@ -515,13 +515,13 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, megdnn_assert(src.ndim == 4); megdnn_assert(src[0] % 8 == 0); exec_workspace = TensorLayout( - {src[0], round_up(src[1], 8_z), src[2], src[3]}, - src.dtype, src.format); - exec_src = - exec_workspace - .reshape({src[0] / 8, 8, div_ceil(src[1], 8_z), - 8, src[2], src[3]}) - .dimshuffle({0, 2, 4, 5, 3, 1}); + {src[0], round_up(src[1], 8_z), src[2], src[3]}, src.dtype, + src.format); + exec_src = exec_workspace + .reshape( + {src[0] / 8, 8, div_ceil(src[1], 8_z), 8, + src[2], src[3]}) + .dimshuffle({0, 2, 4, 5, 3, 1}); exec_dst = dst; } break; @@ -533,8 +533,9 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, {round_up(src[0], 8_z), src[1], src[2], src[3], src[4]}, src.dtype, 
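Most branches of deduce_exec_layout above follow the same recipe: pad the channel axis of the workspace up to a multiple of the pack size, reshape so the pack becomes its own axis, then dimshuffle that axis to the end. Spelled out for NCHW -> NCHW4 (a worked illustration of the index mapping, not code taken from the diff):

#include <array>
#include <cstddef>

// reshape({N, C/4, 4, H, W}) followed by dimshuffle({0, 1, 3, 4, 2}) sends the
// padded-NCHW coordinate (n, c, h, w) to (n, c / 4, h, w, c % 4) in NCHW4; the
// NCHW88 / NCHW32 / NCHW64 branches are the same pattern with 8, 32 or 64.
std::array<size_t, 5> nchw_to_nchw4(size_t n, size_t c, size_t h, size_t w) {
    return {n, c / 4, h, w, c % 4};
}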
src.format); exec_src = exec_workspace - .reshape({div_ceil(src[0], 8_z), 8, src[1], - src[2], src[3], src[4]}) + .reshape( + {div_ceil(src[0], 8_z), 8, src[1], src[2], + src[3], src[4]}) .dimshuffle({0, 2, 3, 4, 5, 1}); exec_dst = dst; } @@ -548,9 +549,9 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, {src[0], src[1], round_up(src[2], 8_z), src[3], src[4]}, src.dtype, src.format); exec_src = exec_workspace - .reshape({src[0], src[1] / 8, 8, - div_ceil(src[2], 8_z), 8, src[3], - src[4]}) + .reshape( + {src[0], src[1] / 8, 8, + div_ceil(src[2], 8_z), 8, src[3], src[4]}) .dimshuffle({0, 1, 3, 5, 6, 4, 2}); exec_dst = dst; } @@ -561,11 +562,12 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, // nchw to nchw4c or oihw to oihw4i { exec_workspace = TensorLayout( - {src[0], round_up(src[1], 4_z), src[2], src[3]}, - src.dtype, src.format); + {src[0], round_up(src[1], 4_z), src[2], src[3]}, src.dtype, + src.format); exec_src = exec_workspace - .reshape({src[0], div_ceil(src[1], 4_z), 4, - src[2], src[3]}) + .reshape( + {src[0], div_ceil(src[1], 4_z), 4, src[2], + src[3]}) .dimshuffle({0, 1, 3, 4, 2}); exec_dst = dst; } @@ -620,9 +622,8 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, // group conv filter // src is {G, ocpg, icpg, fh, fw} // dst is {G, ocpgb, fh, fw, icpg, 4} - exec_src = - src.reshape({src[0], src[1] / 4, 4, src[2], src[3], src[4]}) - .dimshuffle({0, 1, 4, 5, 3, 2}); + exec_src = src.reshape({src[0], src[1] / 4, 4, src[2], src[3], src[4]}) + .dimshuffle({0, 1, 4, 5, 3, 2}); exec_dst = dst; break; case Param::Mode::INTER_WEIGHT_CHAN: @@ -649,9 +650,9 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, case Param::Mode::INTER_WEIGHT_GROUPI_DOT: // src is {G, ocpg, icpg, fh, fw} // dst is {G, ocpg/4, fh, fw, icpg/4, 4, 4} - exec_src = src.reshape({src[0], src[1] / 4, 4, src[2] / 4, 4, - src[3], src[4]}) - .dimshuffle({0, 1, 5, 6, 3, 2, 4}); + exec_src = + src.reshape({src[0], src[1] / 4, 4, src[2] / 4, 4, src[3], src[4]}) + .dimshuffle({0, 1, 5, 6, 3, 2, 4}); exec_dst = dst; break; case Param::Mode::NCHW4_CHWN4: @@ -670,19 +671,19 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src, // src is {N, C, H, W} // dst is {N, C/64, H, W, 64} exec_workspace = TensorLayout( - {src[0], round_up(src[1], 64_z), src[2], src[3]}, - src.dtype); + {src[0], round_up(src[1], 64_z), src[2], src[3]}, src.dtype); exec_src = exec_workspace - .reshape({src[0], div_ceil(src[1], 64_z), 64, - src[2], src[3]}) + .reshape( + {src[0], div_ceil(src[1], 64_z), 64, src[2], + src[3]}) .dimshuffle({0, 1, 3, 4, 2}); exec_dst = dst; break; case Param::Mode::NCHW64_NCHW: // src is {N, C/64, H, W, 64} // dst is {N, C, H, W} - exec_workspace = TensorLayout({src[0], src[1] * 64, src[2], src[3]}, - dst.dtype); + exec_workspace = + TensorLayout({src[0], src[1] * 64, src[2], src[3]}, dst.dtype); exec_src = src.dimshuffle({0, 1, 4, 2, 3}); exec_dst = dst; break; diff --git a/dnn/src/common/relayout_helper.h b/dnn/src/common/relayout_helper.h index 241ab7b3..16ff10ce 100644 --- a/dnn/src/common/relayout_helper.h +++ b/dnn/src/common/relayout_helper.h @@ -36,8 +36,7 @@ struct TransposeParam { * Note that \p src and \p dst should have been processed by * RelayoutForward::check_layout_and_canonize */ -bool is_transpose(const TensorLayout& src, const TensorLayout& dst, - TransposeParam& p); +bool is_transpose(const TensorLayout& src, const TensorLayout& dst, TransposeParam& p); namespace transpose_fallback { @@ -62,9 +61,9 @@ struct 
transpose_traits { }; template -void transpose_block_fallback(const T* src, T* dst, const size_t src_stride, - const size_t dst_stride, size_t block_h, - size_t block_w) { +void transpose_block_fallback( + const T* src, T* dst, const size_t src_stride, const size_t dst_stride, + size_t block_h, size_t block_w) { constexpr size_t block_size = transpose_traits::block_size; T block[block_size][block_size]; @@ -83,10 +82,10 @@ void transpose_block_fallback(const T* src, T* dst, const size_t src_stride, } template -void transpose_block(const T* src, T* dst, const size_t src_stride, - const size_t dst_stride, size_t block_h, size_t block_w) { - transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, - block_w); +void transpose_block( + const T* src, T* dst, const size_t src_stride, const size_t dst_stride, + size_t block_h, size_t block_w) { + transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w); } /*! @@ -96,11 +95,10 @@ void transpose_block(const T* src, T* dst, const size_t src_stride, * block transpose */ template -void transpose_block(const T* src, T* dst, const size_t src_stride, - const size_t dst_stride) { +void transpose_block( + const T* src, T* dst, const size_t src_stride, const size_t dst_stride) { constexpr size_t block_size = transpose_traits::block_size; - transpose_block_fallback(src, dst, src_stride, dst_stride, block_size, - block_size); + transpose_block_fallback(src, dst, src_stride, dst_stride, block_size, block_size); } /*! diff --git a/dnn/src/common/remap.cpp b/dnn/src/common/remap.cpp index 218f420c..d77a5fa0 100644 --- a/dnn/src/common/remap.cpp +++ b/dnn/src/common/remap.cpp @@ -16,9 +16,8 @@ namespace megdnn { -void RemapBase::deduce_layout_fwd(const TensorLayout& src, - const TensorLayout& map_xy, - TensorLayout& dst) { +void RemapBase::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) { dst.dtype = src.dtype; dst.ndim = src.ndim; dst.shape[0] = src.shape[0]; @@ -36,16 +35,14 @@ void RemapBase::deduce_layout_fwd(const TensorLayout& src, dst.shape[channel_index] = src.shape[channel_index]; } -void RemapBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& map_xy, - const TensorLayout& dst) { +void RemapBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& map_xy, const TensorLayout& dst) { auto errmsg = [&]() { - return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(map_xy) + - ", " + megdnn_layout_msg(dst); + return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(map_xy) + ", " + + megdnn_layout_msg(dst); }; MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && - src.ndim == 4); + megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && src.ndim == 4); megdnn_assert(dst.dtype == src.dtype); megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); megdnn_assert(map_xy.shape[3] == 2); @@ -59,14 +56,14 @@ void RemapBase::check_layout_fwd(const TensorLayout& src, // In remap opr, H, W is same as H W in map_xy. 
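In other words, map_xy fixes the output spatial size: it is laid out as {N, H_out, W_out, 2}, with the trailing 2 holding the sampling coordinates, while batch and channel are carried over from src. A worked example of the deduced dst shape (sizes invented purely for illustration):

// NCHW: src = {2, 3, 32, 32}, map_xy = {2, 64, 48, 2}  ->  dst = {2, 3, 64, 48}
// NHWC: src = {2, 32, 32, 3}, map_xy = {2, 64, 48, 2}  ->  dst = {2, 64, 48, 3}
// (N and C come from src; H and W come from map_xy, as the checks below assert.)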
if (param().format == param::Remap::Format::NHWC) { megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str()); - megdnn_assert(dst.shape[2] == map_xy.shape[2] && - dst.shape[1] == map_xy.shape[1], - "%s", errmsg().c_str()); + megdnn_assert( + dst.shape[2] == map_xy.shape[2] && dst.shape[1] == map_xy.shape[1], + "%s", errmsg().c_str()); } else if (param().format == param::Remap::Format::NCHW) { megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); - megdnn_assert(dst.shape[2] == map_xy.shape[1] && - dst.shape[3] == map_xy.shape[2], - "%s", errmsg().c_str()); + megdnn_assert( + dst.shape[2] == map_xy.shape[1] && dst.shape[3] == map_xy.shape[2], + "%s", errmsg().c_str()); } else { megdnn_throw( "currently do not support other param.format except NHWC and " @@ -74,43 +71,41 @@ void RemapBase::check_layout_fwd(const TensorLayout& src, } } -void Remap::deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, - TensorLayout& dst) { +void Remap::deduce_layout( + const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) { deduce_layout_fwd(src, map_xy, dst); } -void Remap::check_exec(const TensorLayout& src, const TensorLayout& map_xy, - const TensorLayout& dst, size_t workspace_in_bytes) { +void Remap::check_exec( + const TensorLayout& src, const TensorLayout& map_xy, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, map_xy, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void RemapBackwardData::check_exec(const TensorLayout& map_xy, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void RemapBackwardData::check_exec( + const TensorLayout& map_xy, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_in_bytes) { check_layout_fwd(grad, map_xy, diff); - megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16( - || grad.dtype == dtype::BFloat16()), - "Backward Remap only supports Float32/BFloat16."); - auto required_workspace_in_bytes = - get_workspace_in_bytes(map_xy, diff, grad); + megdnn_assert( + grad.dtype == dtype::Float32() + DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), + "Backward Remap only supports Float32/BFloat16."); + auto required_workspace_in_bytes = get_workspace_in_bytes(map_xy, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void RemapBackwardMat::check_exec(const TensorLayout& src, - const TensorLayout& map_xy, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void RemapBackwardMat::check_exec( + const TensorLayout& src, const TensorLayout& map_xy, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(src, map_xy, diff); megdnn_assert_eq_layout(map_xy, grad); - megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16( - || grad.dtype == dtype::BFloat16()), - "Backward Remap only supports Float32/BFloat16."); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, map_xy, diff, grad); + megdnn_assert( + grad.dtype == dtype::Float32() + DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), + "Backward Remap only supports Float32/BFloat16."); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, map_xy, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } diff --git a/dnn/src/common/resize.cpp b/dnn/src/common/resize.cpp index d7821c22..fafde92d 100644 --- 
a/dnn/src/common/resize.cpp +++ b/dnn/src/common/resize.cpp @@ -18,21 +18,22 @@ namespace megdnn { -void ResizeBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& dst) { +void ResizeBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + ", " + megdnn_layout_msg(dst); }; MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(dst.dtype == src.dtype && dst.shape[0] == src.shape[0], "%s", - errmsg().c_str()); + megdnn_assert( + dst.dtype == src.dtype && dst.shape[0] == src.shape[0], "%s", + errmsg().c_str()); if (param().format == Param::Format::NCHW) { megdnn_assert(dst.shape[1] == src.shape[1], "%s", errmsg().c_str()); auto imode = param().imode; using IMode = param::Resize::InterpolationMode; - megdnn_assert(imode == IMode::INTER_LINEAR || imode == IMode::NEAREST || - imode == IMode::INTER_CUBIC); + megdnn_assert( + imode == IMode::INTER_LINEAR || imode == IMode::NEAREST || + imode == IMode::INTER_CUBIC); } else if (param().format == Param::Format::NHWC) { megdnn_assert(dst.shape[3] == src.shape[3], "%s", errmsg().c_str()); } else if (param().format == Param::Format::NCHW4) { @@ -44,45 +45,42 @@ void ResizeBase::check_layout_fwd(const TensorLayout& src, megdnn_assert(src.ndim == 5); megdnn_assert(src.shape[4] == 4); megdnn_assert(dst.shape[4] == 4); - megdnn_assert(param().imode == - param::Resize::InterpolationMode::INTER_LINEAR || - param().imode == - param::Resize::InterpolationMode::INTER_NEAREST); + megdnn_assert( + param().imode == param::Resize::InterpolationMode::INTER_LINEAR || + param().imode == param::Resize::InterpolationMode::INTER_NEAREST); } else if (param().format == Param::Format::NCHW88) { megdnn_assert(src.ndim == 5); megdnn_assert(src.shape[4] == 8); megdnn_assert(dst.shape[4] == 8); - megdnn_assert(param().imode == - param::Resize::InterpolationMode::INTER_LINEAR || - param().imode == - param::Resize::InterpolationMode::INTER_NEAREST); + megdnn_assert( + param().imode == param::Resize::InterpolationMode::INTER_LINEAR || + param().imode == param::Resize::InterpolationMode::INTER_NEAREST); } else { - megdnn_assert(param().format == Param::Format::NHWCD4, - "invalid resize tensor format"); - megdnn_assert(param().imode == - param::Resize::InterpolationMode::INTER_LINEAR || - param().imode == - param::Resize::InterpolationMode::INTER_NEAREST); + megdnn_assert( + param().format == Param::Format::NHWCD4, + "invalid resize tensor format"); + megdnn_assert( + param().imode == param::Resize::InterpolationMode::INTER_LINEAR || + param().imode == param::Resize::InterpolationMode::INTER_NEAREST); megdnn_assert(dst.shape[2] == src.shape[2], "%s", errmsg().c_str()); } } -void Resize::check_exec(const TensorLayout& src, const TensorLayout& dst, - size_t workspace_in_bytes) { +void Resize::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void ResizeBackward::check_exec(const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void ResizeBackward::check_exec( + const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); - megdnn_assert(param().format == Param::Format::NCHW && - 
grad.dtype == dtype::Float32(), - "Backward resize only supports Float32 and NCHW."); + megdnn_assert( + param().format == Param::Format::NCHW && grad.dtype == dtype::Float32(), + "Backward resize only supports Float32 and NCHW."); } std::pair ResizeBase::get_cubic_coord(float scale, int idx) { diff --git a/dnn/src/common/rng.cpp b/dnn/src/common/rng.cpp index 68acbf5f..0a6192c3 100644 --- a/dnn/src/common/rng.cpp +++ b/dnn/src/common/rng.cpp @@ -15,16 +15,14 @@ namespace megdnn { -void ShuffleRNGForward::deduce_layout(const TensorLayout& src, - TensorLayout& dst, - TensorLayout& indices) { +void ShuffleRNGForward::deduce_layout( + const TensorLayout& src, TensorLayout& dst, TensorLayout& indices) { dst = src; indices = TensorLayout(TensorShape({src.shape[0]}), dtype::Int32()); } -void ShuffleRNGForward::check_exec(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& indices, - size_t workspace_in_bytes) { +void ShuffleRNGForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& indices, + size_t workspace_in_bytes) { TensorLayout dst_expected, indices_expected; megdnn_assert_contiguous(src); deduce_layout(src, dst_expected, indices_expected); @@ -35,84 +33,79 @@ void ShuffleRNGForward::check_exec(const TensorLayout& src, megdnn_assert(src.dtype == dst.dtype); megdnn_assert(indices.dtype == dtype::Int32()); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, dst, indices); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst, indices); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void ShuffleRNGBackward::check_exec(const TensorLayout& diff, - const TensorLayout& indices, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void ShuffleRNGBackward::check_exec( + const TensorLayout& diff, const TensorLayout& indices, const TensorLayout& grad, + size_t workspace_in_bytes) { megdnn_assert( diff.shape[0] == indices.shape[0] && diff.dtype == grad.dtype && indices.dtype == dtype::Int32{} && diff.is_contiguous() && indices.is_contiguous() && grad.is_contiguous(), - "invalid layouts: diff=%s indices=%s grad=%s", - diff.to_string().c_str(), indices.to_string().c_str(), - grad.to_string().c_str()); - auto required_workspace_in_bytes = - get_workspace_in_bytes(diff, indices, grad); + "invalid layouts: diff=%s indices=%s grad=%s", diff.to_string().c_str(), + indices.to_string().c_str(), grad.to_string().c_str()); + auto required_workspace_in_bytes = get_workspace_in_bytes(diff, indices, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void PermutationRNG::check_exec( - const TensorLayout &dst, size_t workspace_in_bytes) { - megdnn_assert((dst.dtype == dtype::Float32() || - dst.dtype == dtype::Int32() || - dst.dtype == dtype::Int16() ) && - dst.dtype.enumv() == param().dtype && - dst.is_contiguous()); +void PermutationRNG::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) { + megdnn_assert( + (dst.dtype == dtype::Float32() || dst.dtype == dtype::Int32() || + dst.dtype == dtype::Int16()) && + dst.dtype.enumv() == param().dtype && dst.is_contiguous()); megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); } -void PoissonRNG::check_exec(const TensorLayout &lam, const TensorLayout &dst, - size_t workspace_in_bytes){ - megdnn_assert( dst.dtype.category() == DTypeCategory::FLOAT && - lam.dtype == dst.dtype); +void PoissonRNG::check_exec( + const TensorLayout& lam, const TensorLayout& dst, size_t workspace_in_bytes) { + megdnn_assert( + 
dst.dtype.category() == DTypeCategory::FLOAT && lam.dtype == dst.dtype); megdnn_assert(dst.is_contiguous() && lam.is_contiguous()); megdnn_assert(lam.total_nr_elems() == dst.total_nr_elems()); megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(lam, dst)); } -void GammaRNG::check_exec(const TensorLayout &shape,const TensorLayout &scale, - const TensorLayout &dst, size_t workspace_in_bytes){ - megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && - shape.dtype == dst.dtype && - scale.dtype == dst.dtype); - megdnn_assert(shape.is_contiguous() && scale.is_contiguous() - && dst.is_contiguous()); - megdnn_assert(shape.total_nr_elems() == dst.total_nr_elems() && - scale.total_nr_elems() == dst.total_nr_elems()); - megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(shape,scale,dst)); +void GammaRNG::check_exec( + const TensorLayout& shape, const TensorLayout& scale, const TensorLayout& dst, + size_t workspace_in_bytes) { + megdnn_assert( + dst.dtype.category() == DTypeCategory::FLOAT && shape.dtype == dst.dtype && + scale.dtype == dst.dtype); + megdnn_assert( + shape.is_contiguous() && scale.is_contiguous() && dst.is_contiguous()); + megdnn_assert( + shape.total_nr_elems() == dst.total_nr_elems() && + scale.total_nr_elems() == dst.total_nr_elems()); + megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(shape, scale, dst)); } -void BetaRNG::check_exec(const TensorLayout &alpha,const TensorLayout &beta, - const TensorLayout &dst, size_t workspace_in_bytes){ - megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && - alpha.dtype == dst.dtype && - beta.dtype == dst.dtype); - megdnn_assert(alpha.is_contiguous() && beta.is_contiguous() - && dst.is_contiguous()); - megdnn_assert(alpha.total_nr_elems() == dst.total_nr_elems() && - beta.total_nr_elems() == dst.total_nr_elems()); - megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(alpha,beta, dst)); +void BetaRNG::check_exec( + const TensorLayout& alpha, const TensorLayout& beta, const TensorLayout& dst, + size_t workspace_in_bytes) { + megdnn_assert( + dst.dtype.category() == DTypeCategory::FLOAT && alpha.dtype == dst.dtype && + beta.dtype == dst.dtype); + megdnn_assert(alpha.is_contiguous() && beta.is_contiguous() && dst.is_contiguous()); + megdnn_assert( + alpha.total_nr_elems() == dst.total_nr_elems() && + beta.total_nr_elems() == dst.total_nr_elems()); + megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(alpha, beta, dst)); } -#define INST_CHECK_EXEC(RNG_NAME) \ - void RNG_NAME::check_exec( \ - const TensorLayout &dst, size_t workspace_in_bytes) { \ - megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && \ - dst.dtype.enumv() == param().dtype && \ - dst.is_contiguous()); \ - megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); \ +#define INST_CHECK_EXEC(RNG_NAME) \ + void RNG_NAME::check_exec(const TensorLayout& dst, size_t workspace_in_bytes) { \ + megdnn_assert( \ + dst.dtype.category() == DTypeCategory::FLOAT && \ + dst.dtype.enumv() == param().dtype && dst.is_contiguous()); \ + megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); \ } INST_CHECK_EXEC(UniformRNG) INST_CHECK_EXEC(GaussianRNG) #undef INST_CHECK_EXEC -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/roi_align.cpp b/dnn/src/common/roi_align.cpp index 3f0bbd49..867172a6 100644 --- a/dnn/src/common/roi_align.cpp +++ b/dnn/src/common/roi_align.cpp @@ -14,9 +14,9 @@ namespace megdnn { -void ROIAlignBase::deduce_layout_fwd(const TensorLayout& src, - 
const TensorLayout& rois, - TensorLayout& dst, TensorLayout& index) { +void ROIAlignBase::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, + TensorLayout& index) { megdnn_assert_contiguous(src); megdnn_assert_contiguous(rois); megdnn_assert_contiguous(dst); @@ -29,8 +29,8 @@ void ROIAlignBase::deduce_layout_fwd(const TensorLayout& src, using Format = ROIAlignBase::Param::Format; megdnn_assert(param().format == Format::NCHW); auto src_dtype = src.dtype, rois_dtype = rois.dtype; - megdnn_assert(src_dtype == rois_dtype && - src_dtype.category() == DTypeCategory::FLOAT); + megdnn_assert( + src_dtype == rois_dtype && src_dtype.category() == DTypeCategory::FLOAT); megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); size_t channels = src.shape[1]; megdnn_assert(rois.ndim == 2_z, "%s", errmsg().c_str()); @@ -44,10 +44,9 @@ void ROIAlignBase::deduce_layout_fwd(const TensorLayout& src, index.dtype = dtype::Int32(); } -void ROIAlignBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& rois, - const TensorLayout& dst, - const TensorLayout& index) { +void ROIAlignBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index) { TensorLayout dst_expected, index_expected; megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, rois, dst_expected, index_expected); @@ -56,31 +55,25 @@ void ROIAlignBase::check_layout_fwd(const TensorLayout& src, megdnn_assert(index.dtype == dtype::Int32()); } -void ROIAlignForward::deduce_layout(const TensorLayout& src, - const TensorLayout& rois, TensorLayout& dst, - TensorLayout& index) { +void ROIAlignForward::deduce_layout( + const TensorLayout& src, const TensorLayout& rois, TensorLayout& dst, + TensorLayout& index) { deduce_layout_fwd(src, rois, dst, index); } -void ROIAlignForward::check_exec(const TensorLayout& src, - const TensorLayout& rois, - const TensorLayout& dst, - const TensorLayout& index, - size_t workspace_in_bytes) { +void ROIAlignForward::check_exec( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index, size_t workspace_in_bytes) { check_layout_fwd(src, rois, dst, index); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, rois, dst, index); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, rois, dst, index); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void ROIAlignBackward::check_exec(const TensorLayout& diff, - const TensorLayout& rois, - const TensorLayout& index, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void ROIAlignBackward::check_exec( + const TensorLayout& diff, const TensorLayout& rois, const TensorLayout& index, + const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, rois, diff, index); - auto required_workspace_in_bytes = - get_workspace_in_bytes(diff, rois, index, grad); + auto required_workspace_in_bytes = get_workspace_in_bytes(diff, rois, index, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } diff --git a/dnn/src/common/roi_align_helper.h b/dnn/src/common/roi_align_helper.h index 200601e5..d60274e7 100644 --- a/dnn/src/common/roi_align_helper.h +++ b/dnn/src/common/roi_align_helper.h @@ -19,34 +19,32 @@ namespace megdnn { namespace roi_align { template -MEGDNN_HOST MEGDNN_DEVICE T bilinear_interp(const T* data, const float h, - const float w, const int height, - const int width) { +MEGDNN_HOST MEGDNN_DEVICE T bilinear_interp( + const T* 
data, const float h, const float w, const int height, + const int width) { int h0 = floorf(h), w0 = floorf(w), h1 = h0 + 1, w1 = w0 + 1; T top_left = (h0 >= 0 && h0 < height && w0 >= 0 && w0 < width) - ? data[h0 * width + w0] - : T(0.f); + ? data[h0 * width + w0] + : T(0.f); T top_right = (h0 >= 0 && h0 < height && w1 >= 0 && w1 < width) - ? data[h0 * width + w1] - : T(0.f); + ? data[h0 * width + w1] + : T(0.f); T bottom_left = (h1 >= 0 && h1 < height && w0 >= 0 && w0 < width) - ? data[h1 * width + w0] - : T(0.f); + ? data[h1 * width + w0] + : T(0.f); T bottom_right = (h1 >= 0 && h1 < height && w1 >= 0 && w1 < width) - ? data[h1 * width + w1] - : T(0.f); + ? data[h1 * width + w1] + : T(0.f); T top = top_left + (top_right - top_left) * static_cast(w - w0); - T bottom = - bottom_left + (bottom_right - bottom_left) * static_cast(w - w0); + T bottom = bottom_left + (bottom_right - bottom_left) * static_cast(w - w0); T res = top + (bottom - top) * static_cast(h - h0); return res; } template -MEGDNN_HOST MEGDNN_DEVICE void distribute_diff(T* diff, const T top_diff, - const float h, const float w, - const int height, - const int width) { +MEGDNN_HOST MEGDNN_DEVICE void distribute_diff( + T* diff, const T top_diff, const float h, const float w, const int height, + const int width) { #if MEGDNN_CC_CUDA using namespace ::megdnn::cuda; #endif @@ -103,9 +101,7 @@ struct MaxPooler { maxidx = idx; } } - MEGDNN_HOST MEGDNN_DEVICE void writeback_val(T& val) { - val = cnt > 0 ? maxval : 0; - } + MEGDNN_HOST MEGDNN_DEVICE void writeback_val(T& val) { val = cnt > 0 ? maxval : 0; } MEGDNN_HOST MEGDNN_DEVICE void writeback_idx(int& idx) { idx = maxidx; } }; @@ -131,10 +127,9 @@ struct BwdPooler { int height, width; float roi_start_h, roi_start_w, bin_size_h, bin_size_w; float sample_h_rate, sample_w_rate; - MEGDNN_HOST MEGDNN_DEVICE BwdPooler(int ph, int pw, int sample_height, - int sample_width, int height, int width, - float roi_start_h, float roi_start_w, - float bin_size_h, float bin_size_w) + MEGDNN_HOST MEGDNN_DEVICE BwdPooler( + int ph, int pw, int sample_height, int sample_width, int height, int width, + float roi_start_h, float roi_start_w, float bin_size_h, float bin_size_w) : ph{ph}, pw{pw}, sample_height{sample_height}, @@ -153,57 +148,52 @@ struct BwdPooler { template struct BwdMaxPooler : public BwdPooler { using Super = BwdPooler; - MEGDNN_HOST MEGDNN_DEVICE BwdMaxPooler(int ph, int pw, int sample_height, - int sample_width, int height, - int width, float roi_start_h, - float roi_start_w, float bin_size_h, - float bin_size_w) + MEGDNN_HOST MEGDNN_DEVICE BwdMaxPooler( + int ph, int pw, int sample_height, int sample_width, int height, int width, + float roi_start_h, float roi_start_w, float bin_size_h, float bin_size_w) : BwdPooler{ph, pw, sample_height, sample_width, height, width, roi_start_h, roi_start_w, bin_size_h, bin_size_w} {} - MEGDNN_HOST MEGDNN_DEVICE void update(int index, const T* diff, - const int* argmax, T* grad) { + MEGDNN_HOST MEGDNN_DEVICE void update( + int index, const T* diff, const int* argmax, T* grad) { int h_iter = argmax[index] / Super::sample_width; int w_iter = argmax[index] - Super::sample_width * h_iter; - float hcenter = - Super::roi_start_h + - Super::bin_size_h * - (Super::ph + Super::sample_h_rate * (h_iter + 0.5f)); - float wcenter = - Super::roi_start_w + - Super::bin_size_w * - (Super::pw + Super::sample_w_rate * (w_iter + 0.5f)); - distribute_diff(grad, diff[index], hcenter, wcenter, Super::height, - Super::width); + float hcenter = Super::roi_start_h + + 
Super::bin_size_h * + (Super::ph + Super::sample_h_rate * (h_iter + 0.5f)); + float wcenter = Super::roi_start_w + + Super::bin_size_w * + (Super::pw + Super::sample_w_rate * (w_iter + 0.5f)); + distribute_diff( + grad, diff[index], hcenter, wcenter, Super::height, Super::width); } }; template struct BwdAveragePooler : public BwdPooler { using Super = BwdPooler; - MEGDNN_HOST MEGDNN_DEVICE - BwdAveragePooler(int ph, int pw, int sample_height, int sample_width, - int height, int width, float roi_start_h, - float roi_start_w, float bin_size_h, float bin_size_w) + MEGDNN_HOST MEGDNN_DEVICE BwdAveragePooler( + int ph, int pw, int sample_height, int sample_width, int height, int width, + float roi_start_h, float roi_start_w, float bin_size_h, float bin_size_w) : BwdPooler{ph, pw, sample_height, sample_width, height, width, roi_start_h, roi_start_w, bin_size_h, bin_size_w} {} - MEGDNN_HOST MEGDNN_DEVICE void update(int index, const T* diff, - const int* /* argmax */, T* grad) { + MEGDNN_HOST MEGDNN_DEVICE void update( + int index, const T* diff, const int* /* argmax */, T* grad) { int cnt = Super::sample_height * Super::sample_width; for (int h_iter = 0; h_iter < Super::sample_height; ++h_iter) { for (int w_iter = 0; w_iter < Super::sample_width; ++w_iter) { - float hcenter = Super::roi_start_h + - Super::bin_size_h * - (Super::ph + Super::sample_h_rate * - (h_iter + 0.5f)); - float wcenter = Super::roi_start_w + - Super::bin_size_w * - (Super::pw + Super::sample_w_rate * - (w_iter + 0.5f)); + float hcenter = + Super::roi_start_h + + Super::bin_size_h * + (Super::ph + Super::sample_h_rate * (h_iter + 0.5f)); + float wcenter = + Super::roi_start_w + + Super::bin_size_w * + (Super::pw + Super::sample_w_rate * (w_iter + 0.5f)); T val = diff[index] / static_cast(cnt); - distribute_diff(grad, val, hcenter, wcenter, Super::height, - Super::width); + distribute_diff( + grad, val, hcenter, wcenter, Super::height, Super::width); } } } diff --git a/dnn/src/common/roi_copy.cpp b/dnn/src/common/roi_copy.cpp index 6cbcee9e..ab4d162e 100644 --- a/dnn/src/common/roi_copy.cpp +++ b/dnn/src/common/roi_copy.cpp @@ -14,8 +14,7 @@ namespace megdnn { -void ROICopyBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst) -{ +void ROICopyBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { size_t in = src.shape[0]; size_t ih = src.shape[1]; size_t iw = src.shape[2]; @@ -30,28 +29,24 @@ void ROICopyBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst) dst = TensorLayout(TensorShape({in, oh, ow, ic}), src.dtype); } -void ROICopyBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ +void ROICopyBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_shape(dst_expected, dst); } -void ROICopy::deduce_layout(const TensorLayout &src, TensorLayout &dst) -{ +void ROICopy::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void ROICopy::check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void ROICopy::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git 
a/dnn/src/common/roi_pooling.cpp b/dnn/src/common/roi_pooling.cpp index f116824b..797597e1 100644 --- a/dnn/src/common/roi_pooling.cpp +++ b/dnn/src/common/roi_pooling.cpp @@ -14,21 +14,17 @@ namespace megdnn { -void ROIPoolingBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &rois, - const TensorLayout &dst, - const TensorLayout &index) -{ +void ROIPoolingBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index) { // all should be contiguous megdnn_assert_contiguous(src); megdnn_assert_contiguous(rois); megdnn_assert_contiguous(dst); megdnn_assert_contiguous(index); auto errmsg = [&]() { - return megdnn_layout_msg(src) + ", " - + megdnn_layout_msg(rois) + ", " - + megdnn_layout_msg(dst) + ", " - + megdnn_layout_msg(index); + return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(rois) + ", " + + megdnn_layout_msg(dst) + ", " + megdnn_layout_msg(index); }; MEGDNN_MARK_USED_VAR(errmsg); // src @@ -50,32 +46,25 @@ void ROIPoolingBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(index.dtype == dtype::Int32()); } -void ROIPoolingForward::check_exec(const TensorLayout &src, - const TensorLayout &rois, - const TensorLayout &dst, - const TensorLayout &index, - size_t workspace_in_bytes) -{ +void ROIPoolingForward::check_exec( + const TensorLayout& src, const TensorLayout& rois, const TensorLayout& dst, + const TensorLayout& index, size_t workspace_in_bytes) { check_layout_fwd(src, rois, dst, index); - auto required_workspace_in_bytes = get_workspace_in_bytes(src, - rois, dst, index); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, rois, dst, index); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void ROIPoolingBackward::check_exec(const TensorLayout &diff, - const TensorLayout &src, - const TensorLayout &rois, - const TensorLayout &index, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void ROIPoolingBackward::check_exec( + const TensorLayout& diff, const TensorLayout& src, const TensorLayout& rois, + const TensorLayout& index, const TensorLayout& grad, + size_t workspace_in_bytes) { check_layout_fwd(src, rois, diff, index); megdnn_assert_eq_layout(src, grad); - auto required_workspace_in_bytes = get_workspace_in_bytes(diff, - src, rois, index, grad); + auto required_workspace_in_bytes = + get_workspace_in_bytes(diff, src, rois, index, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/roi_pooling_helper.h b/dnn/src/common/roi_pooling_helper.h index 967c6dd8..fb644d4d 100644 --- a/dnn/src/common/roi_pooling_helper.h +++ b/dnn/src/common/roi_pooling_helper.h @@ -14,93 +14,67 @@ namespace megdnn { namespace roi_pooling { -template struct MaxPooler { +template +struct MaxPooler { T maxval; int maxidx; size_t cnt; - MEGDNN_HOST MEGDNN_DEVICE MaxPooler(): - maxval(DTypeTrait::min()), - maxidx(-1), - cnt(0) - {} - MEGDNN_HOST MEGDNN_DEVICE void feed(T val, int idx) - { + MEGDNN_HOST MEGDNN_DEVICE MaxPooler() + : maxval(DTypeTrait::min()), maxidx(-1), cnt(0) {} + MEGDNN_HOST MEGDNN_DEVICE void feed(T val, int idx) { ++cnt; if (val > maxval) { maxval = val; maxidx = idx; } } - MEGDNN_HOST MEGDNN_DEVICE void writeback_val(T &val) - { - val = cnt > 0 ? maxval : 0; - } - MEGDNN_HOST MEGDNN_DEVICE void writeback_idx(int &idx) - { - idx = maxidx; - } + MEGDNN_HOST MEGDNN_DEVICE void writeback_val(T& val) { val = cnt > 0 ? 
maxval : 0; } + MEGDNN_HOST MEGDNN_DEVICE void writeback_idx(int& idx) { idx = maxidx; } }; -template struct AveragePooler { +template +struct AveragePooler { T sum; size_t cnt; - MEGDNN_HOST MEGDNN_DEVICE AveragePooler(): - sum(T(0)), cnt(0) - {} - MEGDNN_HOST MEGDNN_DEVICE void feed(T val, int) - { + MEGDNN_HOST MEGDNN_DEVICE AveragePooler() : sum(T(0)), cnt(0) {} + MEGDNN_HOST MEGDNN_DEVICE void feed(T val, int) { sum += val; ++cnt; } - MEGDNN_HOST MEGDNN_DEVICE void writeback_val(T &val) - { + MEGDNN_HOST MEGDNN_DEVICE void writeback_val(T& val) { val = cnt > 0 ? sum / T(cnt) : 0; } - MEGDNN_HOST MEGDNN_DEVICE void writeback_idx(int &) - { - } + MEGDNN_HOST MEGDNN_DEVICE void writeback_idx(int&) {} }; -template struct BwdMaxPooler { +template +struct BwdMaxPooler { MEGDNN_HOST MEGDNN_DEVICE void update( - int ph, int pw, int h, int w, - float /* bin_size_h */, float /* bin_size_w */, - int /* roi_start_h */, int /* roi_start_w */, - size_t /* pooled_height */, size_t pooled_width, - size_t /* height */, size_t width, - const T *offset_src_diff, - const int *offset_fp_idx, - T &gradient) - { - if (offset_fp_idx[ph * pooled_width + pw] == - (int)(h * width + w)) { - gradient += offset_src_diff[ph * pooled_width + pw]; + int ph, int pw, int h, int w, float /* bin_size_h */, + float /* bin_size_w */, int /* roi_start_h */, int /* roi_start_w */, + size_t /* pooled_height */, size_t pooled_width, size_t /* height */, + size_t width, const T* offset_src_diff, const int* offset_fp_idx, + T& gradient) { + if (offset_fp_idx[ph * pooled_width + pw] == (int)(h * width + w)) { + gradient += offset_src_diff[ph * pooled_width + pw]; } } }; -template struct BwdAveragePooler -{ +template +struct BwdAveragePooler { MEGDNN_HOST MEGDNN_DEVICE void update( int ph, int pw, int h, int w, float bin_size_h, float bin_size_w, - int roi_start_h, int roi_start_w, - size_t /* pooled_height */, size_t pooled_width, - size_t height, size_t width, - const T *offset_src_diff, - const int * /* offset_fp_idx */, - T &gradient) - { + int roi_start_h, int roi_start_w, size_t /* pooled_height */, + size_t pooled_width, size_t height, size_t width, const T* offset_src_diff, + const int* /* offset_fp_idx */, T& gradient) { #if MEGDNN_CC_HOST - using std::min; using std::max; + using std::min; #endif - int hstart = static_cast(floor(static_cast(ph) - * bin_size_h)); - int wstart = static_cast(floor(static_cast(pw) - * bin_size_w)); - int hend = static_cast(ceil(static_cast(ph + 1) - * bin_size_h)); - int wend = static_cast(ceil(static_cast(pw + 1) - * bin_size_w)); + int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); + int wstart = static_cast(floor(static_cast(pw) * bin_size_w)); + int hend = static_cast(ceil(static_cast(ph + 1) * bin_size_h)); + int wend = static_cast(ceil(static_cast(pw + 1) * bin_size_w)); // Add roi offsets and clip to input boundaries hstart = min(max(hstart + roi_start_h, 0), (int)height); hend = min(max(hend + roi_start_h, 0), (int)height); @@ -109,12 +83,12 @@ template struct BwdAveragePooler int size = (hend - hstart) * (wend - wstart); float inv_size = 1.0f / size; if (h >= hstart && h < hend && w >= wstart && w < wend) { - gradient += offset_src_diff[ph * pooled_width + pw] * inv_size; + gradient += offset_src_diff[ph * pooled_width + pw] * inv_size; } } }; -} // namespace roi_pooling -} // namespace megdnn +} // namespace roi_pooling +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/rotate.cpp b/dnn/src/common/rotate.cpp index 02528742..ca3ccaa1 100644 --- 
a/dnn/src/common/rotate.cpp +++ b/dnn/src/common/rotate.cpp @@ -15,13 +15,13 @@ namespace megdnn { -void RotateBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst) -{ +void RotateBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src); }; MEGDNN_MARK_USED_VAR(errmsg); - megdnn_assert(src.ndim == 4_z && (src.shape[3] == 1_z || - src.shape[3] == 3_z), "%s", errmsg().c_str()); + megdnn_assert( + src.ndim == 4_z && (src.shape[3] == 1_z || src.shape[3] == 3_z), "%s", + errmsg().c_str()); size_t in = src.shape[0]; size_t ih = src.shape[1]; @@ -31,28 +31,24 @@ void RotateBase::deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst) dst = TensorLayout(TensorShape({in, iw, ih, ic}), src.dtype); } -void RotateBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ +void RotateBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(src, dst); deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_shape(dst_expected, dst); } -void Rotate::deduce_layout(const TensorLayout &src, TensorLayout &dst) -{ +void Rotate::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void Rotate::check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void Rotate::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/rounding_converter.cuh b/dnn/src/common/rounding_converter.cuh index 336f8309..46519eb7 100644 --- a/dnn/src/common/rounding_converter.cuh +++ b/dnn/src/common/rounding_converter.cuh @@ -20,8 +20,7 @@ struct RoundingConverter; template <> struct RoundingConverter { - MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()( - float x) const { + MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE float operator()(float x) const { return x; } }; @@ -38,8 +37,8 @@ struct RoundingConverter { template <> struct RoundingConverter { - MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16 - operator()(float x) const { + MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE half_bfloat16::bfloat16 operator()( + float x) const { return static_cast(x); } }; @@ -48,8 +47,7 @@ struct RoundingConverter { template <> struct RoundingConverter { - MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t - operator()(float x) const { + MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE int8_t operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif @@ -59,8 +57,7 @@ struct RoundingConverter { template <> struct RoundingConverter { - MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t - operator()(float x) const { + MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE uint8_t operator()(float x) const { #if MEGDNN_CC_HOST using std::max; using std::min; @@ -73,8 +70,7 @@ struct RoundingConverter { template <> struct RoundingConverter { - MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4 - operator()(float x) const { + MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_qint4 operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif @@ -84,8 +80,7 @@ struct RoundingConverter { template <> struct RoundingConverter { - MEGDNN_HOST MEGDNN_DEVICE 
MEGDNN_FORCE_INLINE dt_quint4 - operator()(float x) const { + MEGDNN_HOST MEGDNN_DEVICE MEGDNN_FORCE_INLINE dt_quint4 operator()(float x) const { #if MEGDNN_CC_HOST using std::round; #endif diff --git a/dnn/src/common/separableConv.cpp b/dnn/src/common/separableConv.cpp index f22832c6..521e3068 100644 --- a/dnn/src/common/separableConv.cpp +++ b/dnn/src/common/separableConv.cpp @@ -14,17 +14,14 @@ namespace megdnn { -void SeparableConvBase::deduce_layout_fwd(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - TensorLayout &dst) -{ +void SeparableConvBase::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst) { auto errmsg = [&]() { - return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter_x) + - ", " + megdnn_layout_msg(dst) + ", " + - "is_xcorr=" + "borderMode=" + - std::to_string((param().mode == Mode::CROSS_CORRELATION)) + - ", " + std::to_string((int)(param().borderMode)) + ", " + + return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter_x) + ", " + + megdnn_layout_msg(dst) + ", " + "is_xcorr=" + "borderMode=" + + std::to_string((param().mode == Mode::CROSS_CORRELATION)) + ", " + + std::to_string((int)(param().borderMode)) + ", " + "pad_h=" + std::to_string(param().pad_h) + ", " + "pad_w=" + std::to_string(param().pad_w) + ", " + "stride_h=" + std::to_string(param().stride_h) + ", " + @@ -53,11 +50,9 @@ void SeparableConvBase::deduce_layout_fwd(const TensorLayout &src, dst = TensorLayout(TensorShape({n, oc, oh, ow}), src.dtype); } -void SeparableConvBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - const TensorLayout &dst) -{ +void SeparableConvBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_dtype(src, filter_x); megdnn_assert_eq_dtype(src, filter_y); @@ -67,25 +62,22 @@ void SeparableConvBase::check_layout_fwd(const TensorLayout &src, megdnn_assert_eq_layout(dst_expected, dst); } -void SeparableConvForward::deduce_layout(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - TensorLayout &dst) -{ +void SeparableConvForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst) { deduce_layout_fwd(src, filter_x, filter_y, dst); } -void SeparableConvForward::check_exec(const TensorLayout &src, - const TensorLayout &filter_x, - const TensorLayout &filter_y, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void SeparableConvForward::check_exec( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, filter_x, filter_y, dst); - auto required_workspace_in_bytes = get_workspace_in_bytes(src, filter_x, filter_y, dst); + auto required_workspace_in_bytes = + get_workspace_in_bytes(src, filter_x, filter_y, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/separableFilter.cpp b/dnn/src/common/separableFilter.cpp index d2c1adc9..3f4df850 100644 --- a/dnn/src/common/separableFilter.cpp +++ b/dnn/src/common/separableFilter.cpp @@ -14,15 +14,14 @@ namespace megdnn { -void SeparableFilterBase::deduce_layout_fwd(const TensorLayout& 
src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, - TensorLayout& dst) { +void SeparableFilterBase::deduce_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst) { auto errmsg = [&]() { - return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter_x) + - ", " + megdnn_layout_msg(dst) + ", " + - "borderMode=" + std::to_string((int)(param().borderMode)) + - ", " + "ksize_h=" + std::to_string(param().ksize_h) + ", " + + return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(filter_x) + ", " + + megdnn_layout_msg(dst) + ", " + + "borderMode=" + std::to_string((int)(param().borderMode)) + ", " + + "ksize_h=" + std::to_string(param().ksize_h) + ", " + "ksize_w=" + std::to_string(param().ksize_w) + ", " + "anchor_h=" + std::to_string(param().anchor_h) + ", " + "anchor_w=" + std::to_string(param().anchor_w); @@ -32,8 +31,8 @@ void SeparableFilterBase::deduce_layout_fwd(const TensorLayout& src, megdnn_assert_contiguous(filter_x); megdnn_assert_contiguous(filter_y); megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); - megdnn_assert(param().format == Param::Format::NHWC, - "Only NHWC was supported by now"); + megdnn_assert( + param().format == Param::Format::NHWC, "Only NHWC was supported by now"); size_t n = src[0]; size_t ih = src[1]; size_t iw = src[2]; @@ -41,28 +40,25 @@ void SeparableFilterBase::deduce_layout_fwd(const TensorLayout& src, dst = TensorLayout(TensorShape({n, ih, iw, ic}), src.dtype); } -void SeparableFilterBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, - const TensorLayout& dst) { +void SeparableFilterBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst) { TensorLayout dst_expected; megdnn_assert_eq_layout(src, dst); deduce_layout_fwd(src, filter_x, filter_y, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); } -void SeparableFilterForward::deduce_layout(const TensorLayout& src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, - TensorLayout& dst) { +void SeparableFilterForward::deduce_layout( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, TensorLayout& dst) { deduce_layout_fwd(src, filter_x, filter_y, dst); } -void SeparableFilterForward::check_exec(const TensorLayout& src, - const TensorLayout& filter_x, - const TensorLayout& filter_y, - const TensorLayout& dst, - size_t workspace_in_bytes) { +void SeparableFilterForward::check_exec( + const TensorLayout& src, const TensorLayout& filter_x, + const TensorLayout& filter_y, const TensorLayout& dst, + size_t workspace_in_bytes) { megdnn_assert(param().ksize_h > 0 && (param().ksize_h & 1)); megdnn_assert(param().ksize_w > 0 && (param().ksize_w & 1)); check_layout_fwd(src, filter_x, filter_y, dst); @@ -71,6 +67,6 @@ void SeparableFilterForward::check_exec(const TensorLayout& src, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/sliding_window_transpose.cpp b/dnn/src/common/sliding_window_transpose.cpp index 6801df09..08f377f0 100644 --- a/dnn/src/common/sliding_window_transpose.cpp +++ b/dnn/src/common/sliding_window_transpose.cpp @@ -14,9 +14,8 @@ namespace megdnn { -void SlidingWindowTransposeBase::deduce_layout_fwd(const TensorLayout &src, - TensorLayout &dst) -{ +void 
SlidingWindowTransposeBase::deduce_layout_fwd( + const TensorLayout& src, TensorLayout& dst) { auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + "out_h=" + std::to_string(param().out_h) + ", " + @@ -38,38 +37,32 @@ void SlidingWindowTransposeBase::deduce_layout_fwd(const TensorLayout &src, dst = TensorLayout(TensorShape({n, ic, oh, ow}), src.dtype); } -void SlidingWindowTransposeBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ +void SlidingWindowTransposeBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& dst) { TensorLayout dst_expected; deduce_layout_fwd(src, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); } -void SlidingWindowTransposeForward::deduce_layout(const TensorLayout &src, - TensorLayout &dst) -{ +void SlidingWindowTransposeForward::deduce_layout( + const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void SlidingWindowTransposeForward::check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void SlidingWindowTransposeForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void SlidingWindowTransposeBackward::check_exec(const TensorLayout &diff, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void SlidingWindowTransposeBackward::check_exec( + const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(grad, diff); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/small_vector.cpp b/dnn/src/common/small_vector.cpp index 3377294c..5eab20b4 100644 --- a/dnn/src/common/small_vector.cpp +++ b/dnn/src/common/small_vector.cpp @@ -20,8 +20,8 @@ void SmallVectorBase::on_invalid_at(size_t idx, size_t size) { MEGDNN_MARK_USED_VAR(size); } -void SmallVectorBase::grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, - size_t type_size) { +void SmallVectorBase::grow_pod( + void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size) { size_t cur_sz_in_bytes = size_in_bytes(); size_t new_capacity_in_bytes = 2 * capacity_in_bytes() + type_size; if (new_capacity_in_bytes < min_sz_in_bytes) { diff --git a/dnn/src/common/svd.cpp b/dnn/src/common/svd.cpp index bc6f86dc..8f536abf 100644 --- a/dnn/src/common/svd.cpp +++ b/dnn/src/common/svd.cpp @@ -14,8 +14,8 @@ using namespace megdnn; -void SVD::deduce_layout(const TensorLayout& src, TensorLayout& u, - TensorLayout& s, TensorLayout& vt) { +void SVD::deduce_layout( + const TensorLayout& src, TensorLayout& u, TensorLayout& s, TensorLayout& vt) { Param p = param(); size_t m, n; canonize_params(src, nullptr, &m, &n); @@ -51,9 +51,9 @@ void SVD::deduce_layout(const TensorLayout& src, TensorLayout& u, vt = {shape_vt, src.dtype}; } -size_t SVD::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& u, const TensorLayout& s, - const TensorLayout& vt) { +size_t SVD::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, + const TensorLayout& vt) { MEGDNN_MARK_USED_VAR(u); MEGDNN_MARK_USED_VAR(s); MEGDNN_MARK_USED_VAR(vt); @@ -63,10 +63,11 @@ size_t SVD::get_workspace_in_bytes(const TensorLayout& src, return 
get_workspace_in_bytes(block_cnt, m, n, src.dtype.size()); } -void SVD::canonize_params(const TensorLayout& layout, size_t* block_cnt, - size_t* m, size_t* n) { - megdnn_assert(layout.is_contiguous() && layout.ndim >= 2, - "invalid SVD layout: %s", layout.to_string().c_str()); +void SVD::canonize_params( + const TensorLayout& layout, size_t* block_cnt, size_t* m, size_t* n) { + megdnn_assert( + layout.is_contiguous() && layout.ndim >= 2, "invalid SVD layout: %s", + layout.to_string().c_str()); megdnn_assert(layout.dtype == dtype::Float32(), "SVD only supports f32"); if (block_cnt) { *block_cnt = 1; @@ -82,9 +83,9 @@ void SVD::canonize_params(const TensorLayout& layout, size_t* block_cnt, } } -void SVD::check_exec(const TensorLayout& src, const TensorLayout& u, - const TensorLayout& s, const TensorLayout& vt, - size_t workspace_in_bytes) { +void SVD::check_exec( + const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, + const TensorLayout& vt, size_t workspace_in_bytes) { size_t m, n; canonize_params(src, nullptr, &m, &n); // get_workspace_in_bytes runs the canonize_params, thus runs the check diff --git a/dnn/src/common/tensor_format.cpp b/dnn/src/common/tensor_format.cpp index 1b700e94..c639a4d4 100644 --- a/dnn/src/common/tensor_format.cpp +++ b/dnn/src/common/tensor_format.cpp @@ -24,14 +24,13 @@ DefaultTensorFormat* default_tensor_format_obj; /* ===================== TensorFormat ===================== */ -TensorFormat TensorFormat::deserialize(const std::string& bin, - const Handle* handle) { +TensorFormat TensorFormat::deserialize(const std::string& bin, const Handle* handle) { using Type = TensorFormat::Type; auto type = reinterpret_cast(bin.data()); switch (*type) { case Type::DEFAULT: - return DefaultTensorFormat::deserialize(handle, type + 1, - bin.size() - sizeof(Type)); + return DefaultTensorFormat::deserialize( + handle, type + 1, bin.size() - sizeof(Type)); case Type::IMAGE2D_PACK4: return Image2DPack4TensorFormat::deserialize( handle, type + 1, bin.size() - sizeof(Type)); @@ -46,13 +45,13 @@ TensorFormat TensorFormat::deserialize(const std::string& bin, TensorFormat::Format() : m_impl{DefaultTensorFormat::make().m_impl} {} TensorFormat::Format(DType dtype) { - if (dtype.valid() && - dtype.is_quantized_lowbit()) { // quantized lowbit, by default - // aligned to bytes + if (dtype.valid() && dtype.is_quantized_lowbit()) { // quantized lowbit, by default + // aligned to bytes size_t size_nbits = dtype.low_bit(); - megdnn_assert(size_nbits == 1 || size_nbits == 2 || size_nbits == 4, - "unsupported lowbits data type(%s, size in bits: %zu)", - dtype.name(), size_nbits); + megdnn_assert( + size_nbits == 1 || size_nbits == 2 || size_nbits == 4, + "unsupported lowbits data type(%s, size in bits: %zu)", dtype.name(), + size_nbits); m_impl = LowbitsAlignedToBytesTensorFormat::make(size_nbits).m_impl; } else { // non parameterized lowbit, default format m_impl = DefaultTensorFormat::make().m_impl; @@ -74,9 +73,9 @@ std::string TensorFormat::serialize() const { void TensorFormat::on_bad_cvt(Type dst_type) const { MEGDNN_MARK_USED_VAR(dst_type); - megdnn_throw(ssprintf("can not convert tensor format %s to %d", - impl()->to_string().c_str(), - static_cast(dst_type))); + megdnn_throw(ssprintf( + "can not convert tensor format %s to %d", impl()->to_string().c_str(), + static_cast(dst_type))); } bool TensorFormat::is_default() const { @@ -155,8 +154,7 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec( return res; } -TensorLayout::Span DefaultTensorFormat::span_spec( 
- const TensorLayout& layout) const { +TensorLayout::Span DefaultTensorFormat::span_spec(const TensorLayout& layout) const { assert_valid(layout); if (layout.ndim == 0) return {0, 0, 0, 0}; @@ -192,8 +190,8 @@ std::string DefaultTensorFormat::to_string() const { void DefaultTensorFormat::serialize_append(std::string&) const {} -TensorFormat DefaultTensorFormat::deserialize(const Handle* handle, - const void* buf, size_t size) { +TensorFormat DefaultTensorFormat::deserialize( + const Handle* handle, const void* buf, size_t size) { MEGDNN_MARK_USED_VAR(handle); MEGDNN_MARK_USED_VAR(buf); megdnn_assert(!size); @@ -203,8 +201,8 @@ TensorFormat DefaultTensorFormat::deserialize(const Handle* handle, TensorFormat DefaultTensorFormat::make() { // use static storage so the object is accessible in global destructing // phase - static std::aligned_storage_t + static std::aligned_storage_t< + sizeof(DefaultTensorFormat), alignof(DefaultTensorFormat)> storage; static DefaultTensorFormat* obj = default_tensor_format_obj = new (&storage) DefaultTensorFormat{}; @@ -213,8 +211,8 @@ TensorFormat DefaultTensorFormat::make() { /* ===================== Image2DTensorFormatBase ===================== */ -Image2DTensorFormatBase::Image2DTensorFormatBase(Type type, size_t align_axis, - size_t align_size_in_elements) +Image2DTensorFormatBase::Image2DTensorFormatBase( + Type type, size_t align_axis, size_t align_size_in_elements) : ImplBase(type), m_align_axis(align_axis) { megdnn_assert(align_size_in_elements && align_axis); m_align_size_in_elements_log2 = __builtin_ctz(align_size_in_elements); @@ -242,8 +240,7 @@ size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const { return accum; } -size_t Image2DTensorFormatBase::image_width_elems( - const TensorLayout& layout) const { +size_t Image2DTensorFormatBase::image_width_elems(const TensorLayout& layout) const { size_t high_elem = 0; for (size_t i = m_align_axis; i < layout.ndim; ++i) { high_elem += (layout.shape[i] - 1) * layout.stride[i]; @@ -252,8 +249,7 @@ size_t Image2DTensorFormatBase::image_width_elems( } std::string Image2DTensorFormatBase::to_string() const { - return ssprintf("I2D{%zu,%d}", m_align_axis, - 1 << m_align_size_in_elements_log2); + return ssprintf("I2D{%zu,%d}", m_align_axis, 1 << m_align_size_in_elements_log2); } /* ===================== Image2DPackedTensorFormatBase ===================== */ @@ -270,10 +266,12 @@ template void Image2DPackedTensorFormatBase::assert_valid( const TensorLayout& layout) const { auto m_align_axis = align_axis(); - megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE), - "bad shape: %zu", layout.shape[layout.ndim - 1]); - megdnn_assert(layout.dtype.valid() && !layout.dtype.is_quantized_lowbit() && - layout.ndim > m_align_axis); + megdnn_assert( + !(layout.shape[layout.ndim - 1] % PIXEL_SIZE), "bad shape: %zu", + layout.shape[layout.ndim - 1]); + megdnn_assert( + layout.dtype.valid() && !layout.dtype.is_quantized_lowbit() && + layout.ndim > m_align_axis); ptrdiff_t first_non_zero_stride = 0; for (int i = layout.ndim - 1; i >= 0; --i) { megdnn_assert(layout.shape[i] && layout.stride[i] >= 0); @@ -281,14 +279,13 @@ void Image2DPackedTensorFormatBase::assert_valid( first_non_zero_stride = layout.stride[i]; } } - size_t mask = - image_pitch_alignment_in_bytes( - align_size_in_elements(layout.dtype.size_log()), layout) - - 1; + size_t mask = image_pitch_alignment_in_bytes( + align_size_in_elements(layout.dtype.size_log()), layout) - + 1; - megdnn_assert(!(first_non_zero_stride & mask), - "first 
stride is %d, but alignment is %zu", - static_cast(first_non_zero_stride), mask + 1); + megdnn_assert( + !(first_non_zero_stride & mask), "first stride is %d, but alignment is %zu", + static_cast(first_non_zero_stride), mask + 1); } template @@ -303,22 +300,19 @@ size_t Image2DPackedTensorFormatBase::image_row_pitch( // use width for all broadcasted case size_t alignment_in_bytes_log2 = align_size_in_elements_log2(); if (m_vendor_type == Handle::HandleVendorType::MALI) { - alignment_in_bytes_log2 += - __builtin_ctz(layout.dtype.size() * PIXEL_SIZE); + alignment_in_bytes_log2 += __builtin_ctz(layout.dtype.size() * PIXEL_SIZE); } return get_aligned_power2( - layout.dtype.size(image_width_elems(layout)), - 1 << alignment_in_bytes_log2); + layout.dtype.size(image_width_elems(layout)), 1 << alignment_in_bytes_log2); } template -size_t -Image2DPackedTensorFormatBase::image_pitch_alignment_in_bytes( +size_t Image2DPackedTensorFormatBase::image_pitch_alignment_in_bytes( size_t align_size_in_elements, const TensorLayout& layout) const { return m_vendor_type == Handle::HandleVendorType::MALI - ? (align_size_in_elements * layout.dtype.size() * PIXEL_SIZE) - : align_size_in_elements; + ? (align_size_in_elements * layout.dtype.size() * PIXEL_SIZE) + : align_size_in_elements; } template @@ -337,9 +331,10 @@ size_t Image2DPackedTensorFormatBase::init_contiguous_stride( auto m_align_axis = align_axis(); if (!layout.ndim) return 0; - megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis, - "dtype=%s ndim=%zu align=%zu", layout.dtype.name(), - layout.ndim, m_align_axis); + megdnn_assert( + layout.dtype.valid() && layout.ndim > m_align_axis, + "dtype=%s ndim=%zu align=%zu", layout.dtype.name(), layout.ndim, + m_align_axis); size_t align_size = image_pitch_alignment_in_bytes( align_size_in_elements(layout.dtype.size_log()), layout); @@ -379,15 +374,15 @@ bool Image2DPackedTensorFormatBase::is_contiguous_spec( return false; } - size_t mask = - image_pitch_alignment_in_bytes( - align_size_in_elements(layout.dtype.size_log()), - layout) - - 1; + size_t mask = image_pitch_alignment_in_bytes( + align_size_in_elements(layout.dtype.size_log()), + layout) - + 1; - megdnn_assert(s > expected && !(s & mask), - "invalid row pitch: %d; layout: %s", - static_cast(s), layout.to_string().c_str()); + megdnn_assert( + s > expected && !(s & mask), + "invalid row pitch: %d; layout: %s", static_cast(s), + layout.to_string().c_str()); expected = s; } else { return false; @@ -452,7 +447,6 @@ TensorLayout Image2DPackedTensorFormatBase::collapse_contiguous_spec return res; } - namespace megdnn { namespace detail { template class Image2DPackedTensorFormatBase<4>; @@ -465,9 +459,10 @@ LowbitsAlignedTensorFormatBase::LowbitsAlignedTensorFormatBase( : ImplBase(type), m_size_nbits(size_nbits), m_align_size_in_bits(align_size_in_bits) { - megdnn_assert(!(m_align_size_in_bits % m_size_nbits), - "align size(%zu) must be a multiple of element size(%zu)", - m_align_size_in_bits, m_size_nbits); + megdnn_assert( + !(m_align_size_in_bits % m_size_nbits), + "align size(%zu) must be a multiple of element size(%zu)", + m_align_size_in_bits, m_size_nbits); m_align_size_in_elements = m_align_size_in_bits / m_size_nbits; } @@ -475,10 +470,10 @@ std::string LowbitsAlignedTensorFormatBase::to_string() const { return ssprintf("LOWBITS{%zu,%zu}", m_size_nbits, m_align_size_in_bits); } -void LowbitsAlignedTensorFormatBase::assert_valid( - const TensorLayout& layout) const { - megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() && - 
layout.dtype.low_bit() == m_size_nbits); +void LowbitsAlignedTensorFormatBase::assert_valid(const TensorLayout& layout) const { + megdnn_assert( + layout.dtype.valid() && layout.dtype.is_low_bit() && + layout.dtype.low_bit() == m_size_nbits); bool has_dim_unity_stride = false; bool has_dim_aligned_stride = false; for (int i = layout.ndim - 1; i >= 0; --i) { @@ -500,13 +495,11 @@ void LowbitsAlignedTensorFormatBase::assert_valid( "innermost dim not contiguous"); } -void LowbitsAlignedTensorFormatBase::serialize_append( - std::string& result) const { +void LowbitsAlignedTensorFormatBase::serialize_append(std::string& result) const { SerializePack pack; pack.size_nbits = m_size_nbits; pack.align_size_in_bits = m_align_size_in_bits; - megdnn_assert(pack.align_size_in_bits == - m_align_size_in_bits); // detect overflow; + megdnn_assert(pack.align_size_in_bits == m_align_size_in_bits); // detect overflow; result.append(reinterpret_cast(&pack), sizeof(pack)); } @@ -523,8 +516,8 @@ TensorLayout::Span LowbitsAlignedTensorFormatBase::span_spec( return {0, 0, 0, 0}; } auto stride_val = layout.stride[i]; - megdnn_assert(stride_val >= 0, - "lowbit tensors shouldn't have negative strides"); + megdnn_assert( + stride_val >= 0, "lowbit tensors shouldn't have negative strides"); high_elem += (shape_val - 1) * stride_val; } ++high_elem; @@ -558,8 +551,7 @@ bool LowbitsAlignedTensorFormatBase::is_contiguous_spec( bool is_valid_stride = (layout.stride[i] == expected) || (expected == 1 && - (int)layout.stride[i] == - round_up(1, (int)m_align_size_in_elements)); + (int)layout.stride[i] == round_up(1, (int)m_align_size_in_elements)); if (layout.shape[i] != 1 && !is_valid_stride) return false; auto multiplier = layout.shape[i]; @@ -605,14 +597,14 @@ TensorFormat Image2DPack4TensorFormat::make_raw( size_t align_axis, size_t align_size_in_elements, Handle::HandleVendorType vendor_type) { static DNN_MUTEX mtx; - static std::unordered_map> + static std::unordered_map> cache; - megdnn_assert(std::max(align_axis, align_size_in_elements) <= - std::numeric_limits::max()); + megdnn_assert( + std::max(align_axis, align_size_in_elements) <= + std::numeric_limits::max()); MEGDNN_LOCK_GUARD(mtx); - auto&& ptr = cache[(static_cast(align_axis) << 32) | - align_size_in_elements]; + auto&& ptr = + cache[(static_cast(align_axis) << 32) | align_size_in_elements]; if (!ptr) { ptr.reset(new Image2DPack4TensorFormat{ align_axis, align_size_in_elements, vendor_type}); @@ -620,15 +612,13 @@ TensorFormat Image2DPack4TensorFormat::make_raw( return impl_to_tensor_format(ptr.get()); } -TensorFormat Image2DPack4TensorFormat::make(size_t align_axis, - const Handle* handle) { - return make_raw(align_axis, handle->image2d_pitch_alignment(), - handle->vendor_type()); +TensorFormat Image2DPack4TensorFormat::make(size_t align_axis, const Handle* handle) { + return make_raw( + align_axis, handle->image2d_pitch_alignment(), handle->vendor_type()); } -TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle, - const void* buf, - size_t size) { +TensorFormat Image2DPack4TensorFormat::deserialize( + const Handle* handle, const void* buf, size_t size) { megdnn_assert(size == sizeof(SerializePack)); auto pack = *static_cast(buf); return make(pack.align_axis, handle); @@ -654,9 +644,8 @@ TensorFormat LowbitsAlignedToBytesTensorFormat::make(size_t size_nbits) { return impl_to_tensor_format(ptr.get()); } -TensorFormat LowbitsAlignedToBytesTensorFormat::deserialize(const Handle*, - const void* buf, - size_t size) { +TensorFormat 
LowbitsAlignedToBytesTensorFormat::deserialize( + const Handle*, const void* buf, size_t size) { megdnn_assert(size == sizeof(SerializePack)); auto pack = *static_cast(buf); return make(pack.size_nbits); diff --git a/dnn/src/common/tensor_iter.cpp b/dnn/src/common/tensor_iter.cpp index 0fa27db3..b6a27ea5 100644 --- a/dnn/src/common/tensor_iter.cpp +++ b/dnn/src/common/tensor_iter.cpp @@ -19,15 +19,12 @@ TypeRef::TypeRef(dt_quint4* _ptr, size_t _offset) { ptr = reinterpret_cast(_ptr); offset = _offset; uint8_t cur = ptr[offset >> 1]; - val = convert(cur, dt_quint4(cur), offset & 0x1) - .as_uint8(); - + val = convert(cur, dt_quint4(cur), offset & 0x1).as_uint8(); } void TypeRef::operator=(const uint8_t _) { uint8_t cur = ptr[offset >> 1]; - ptr[offset >> 1] = - convert(dt_quint4(_), cur, offset & 0x1); + ptr[offset >> 1] = convert(dt_quint4(_), cur, offset & 0x1); } TypeRef::TypeRef(dt_qint4* _ptr, size_t _offset) { @@ -39,16 +36,14 @@ TypeRef::TypeRef(dt_qint4* _ptr, size_t _offset) { void TypeRef::operator=(const int8_t _) { int8_t cur = ptr[offset >> 1]; - ptr[offset >> 1] = - convert(dt_qint4(_), cur, offset & 0x1); + ptr[offset >> 1] = convert(dt_qint4(_), cur, offset & 0x1); } ////////////////////// TensorIter ///////////////////// -template -typename TensorIter::Iter -TensorIter::Iter::make( - ctype *ptr, const TensorLayout &layout, size_t offset) { +template +typename TensorIter::Iter TensorIter::Iter::make( + ctype* ptr, const TensorLayout& layout, size_t offset) { megdnn_assert(layout.ndim); Iter rst; rst.m_ptr = ptr; @@ -60,7 +55,7 @@ TensorIter::Iter::make( rst.m_tot_nr_elems = rst.m_layout.total_nr_elems(); rst.m_offset = 0; megdnn_assert(offset <= rst.m_tot_nr_elems); - for (int i = rst.m_layout.ndim - 1; i >= 0; i --) { + for (int i = rst.m_layout.ndim - 1; i >= 0; i--) { auto shp = rst.m_layout.shape[i]; auto stride = rst.m_layout.stride[i]; if (!shp) { @@ -75,19 +70,19 @@ TensorIter::Iter::make( return rst; } -template +template void TensorIter::Iter::on_access_idx_valonly_true() const { megdnn_throw("can not access idx of TensorIter if valonly is true"); } namespace megdnn { -#define cb(_dt) \ +#define cb(_dt) \ template class TensorIter::ctype, false>; \ template class TensorIter::ctype, true>; - MEGDNN_FOREACH_DTYPE_NAME(cb) - MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) +MEGDNN_FOREACH_DTYPE_NAME(cb) +MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) #undef cb -} +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/common/tensor_remap.cpp b/dnn/src/common/tensor_remap.cpp index 06cc6c64..ba09268a 100644 --- a/dnn/src/common/tensor_remap.cpp +++ b/dnn/src/common/tensor_remap.cpp @@ -14,16 +14,13 @@ namespace megdnn { -void IndexingRemapBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &map, - const TensorLayout &dst) -{ +void IndexingRemapBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& map, const TensorLayout& dst) { megdnn_assert_non_overlapping_strong(src); megdnn_assert_contiguous(map); megdnn_assert_non_overlapping_strong(dst); - auto errmsg = megdnn_layout_msg(src) + ", " - + megdnn_layout_msg(map) + ", " - + megdnn_layout_msg(dst); + auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(map) + ", " + + megdnn_layout_msg(dst); auto errmsg_c = errmsg.c_str(); MEGDNN_MARK_USED_VAR(errmsg_c); megdnn_assert(map.ndim == dst.ndim + 1, "%s", errmsg_c); @@ -33,41 +30,35 @@ void IndexingRemapBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(map.shape[dst.ndim] == src.ndim, 
"%s", errmsg_c); megdnn_assert(dst.dtype == src.dtype); - megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(), - "indexing remap only support float32/int32, got %s", - src.dtype.name()); + megdnn_assert( + src.dtype == dtype::Float32() || src.dtype == dtype::Int32(), + "indexing remap only support float32/int32, got %s", src.dtype.name()); megdnn_assert(map.dtype == dtype::Int32()); } -void IndexingRemapForward::deduce_layout(const TensorLayout &src, - const TensorLayout &map, - TensorLayout &dst) -{ +void IndexingRemapForward::deduce_layout( + const TensorLayout& src, const TensorLayout& map, TensorLayout& dst) { dst = map; dst.dtype = src.dtype; --dst.ndim; dst.init_contiguous_stride(); } -void IndexingRemapForward::check_exec(const TensorLayout &src, - const TensorLayout &map, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void IndexingRemapForward::check_exec( + const TensorLayout& src, const TensorLayout& map, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, map, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, map, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void IndexingRemapBackward::check_exec(const TensorLayout &diff, - const TensorLayout &map, - const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void IndexingRemapBackward::check_exec( + const TensorLayout& diff, const TensorLayout& map, const TensorLayout& grad, + size_t workspace_in_bytes) { check_layout_fwd(grad, map, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, map, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/tile_repeat.cpp b/dnn/src/common/tile_repeat.cpp index 120b3d91..02bb93ad 100644 --- a/dnn/src/common/tile_repeat.cpp +++ b/dnn/src/common/tile_repeat.cpp @@ -16,11 +16,10 @@ namespace megdnn { -void TileRepeatBase::check_layout_fwd(const TensorLayout &src, - const TensorLayout &dst) -{ - auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) - + ", " + "times=" + param().times.to_string(); +void TileRepeatBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& dst) { + auto errmsg = megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + + "times=" + param().times.to_string(); auto errmsg_c = errmsg.c_str(); MEGDNN_MARK_USED_VAR(errmsg_c); megdnn_assert_contiguous(src); @@ -29,30 +28,23 @@ void TileRepeatBase::check_layout_fwd(const TensorLayout &src, megdnn_assert(expected_ndim == src.ndim, "%s", errmsg_c); megdnn_assert(expected_ndim == dst.ndim, "%s", errmsg_c); rep(i, expected_ndim) { - megdnn_assert(dst.shape[i] == param().times[i] * src.shape[i], - "%s", errmsg_c); + megdnn_assert(dst.shape[i] == param().times[i] * src.shape[i], "%s", errmsg_c); } megdnn_assert(src.dtype == dst.dtype); } -void TileRepeatBase::deduce_layout_fwd(const TensorLayout &src, - TensorLayout &dst) -{ +void TileRepeatBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { dst.ndim = src.ndim; - rep(i, src.ndim) { - dst.shape[i] = src.shape[i] * param().times[i]; - } + rep(i, src.ndim) { dst.shape[i] = src.shape[i] * param().times[i]; } dst.dtype = src.dtype; dst.init_contiguous_stride(); check_layout_fwd(src, dst); } -size_t TileRepeatBase::get_workspace_in_bytes_fwd(const TensorShape & /* src */, - const TensorShape &dst, - const TensorShape ×, - DType dtype) -{ +size_t TileRepeatBase::get_workspace_in_bytes_fwd( 
+ const TensorShape& /* src */, const TensorShape& dst, const TensorShape& times, + DType dtype) { size_t nr_workspace = 0; auto nr_reduces = count_not_ones_in_shape(times); if (nr_reduces == 0) { @@ -78,18 +70,14 @@ size_t TileRepeatBase::get_workspace_in_bytes_fwd(const TensorShape & /* src */, } } -void TileBase::simplify_shape(const TensorShape &src, - const TensorShape &dst, - const TensorShape &times, - TensorShape &src2, - TensorShape &dst2, - TensorShape &times2) -{ +void TileBase::simplify_shape( + const TensorShape& src, const TensorShape& dst, const TensorShape& times, + TensorShape& src2, TensorShape& dst2, TensorShape& times2) { size_t n = 0; for (size_t i = 0; i < src.ndim; ++i) { if (times.shape[i] == 1 && n > 0) { - src2.shape[n-1] *= src.shape[i]; - dst2.shape[n-1] *= dst.shape[i]; + src2.shape[n - 1] *= src.shape[i]; + dst2.shape[n - 1] *= dst.shape[i]; } else { src2.shape[n] = src.shape[i]; dst2.shape[n] = dst.shape[i]; @@ -100,54 +88,46 @@ void TileBase::simplify_shape(const TensorShape &src, src2.ndim = dst2.ndim = times2.ndim = n; } -size_t TileBase::get_workspace_in_bytes_fwd(const TensorLayout &src_, - const TensorLayout &dst_) -{ +size_t TileBase::get_workspace_in_bytes_fwd( + const TensorLayout& src_, const TensorLayout& dst_) { TensorShape src, dst, times; simplify_shape(src_, dst_, param().times, src, dst, times); - return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, - src_.dtype); + return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype); } -void TileForward::deduce_layout(const TensorLayout &src, - TensorLayout &dst) -{ +void TileForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void TileForward::check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void TileForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void TileBackward::check_exec(const TensorLayout &diff, const TensorLayout &grad, - size_t workspace_in_bytes) -{ +void TileBackward::check_exec( + const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void RepeatBase::simplify_shape(const TensorShape &src, - const TensorShape & /* dst */, - const TensorShape &times, - TensorShape &src2, - TensorShape &dst2, - TensorShape &times2) -{ +void RepeatBase::simplify_shape( + const TensorShape& src, const TensorShape& /* dst */, const TensorShape& times, + TensorShape& src2, TensorShape& dst2, TensorShape& times2) { auto n = 0u; size_t i = 0; while (i < times.ndim) { size_t j = i; - while (j < times.ndim && times.shape[j] == 1) ++j; + while (j < times.ndim && times.shape[j] == 1) + ++j; // Here: j is times.ndim, or times.shape[j] != 1 - if (j < times.ndim) ++j; - src2.shape[n] = std::accumulate(src.shape + i, src.shape + j, - 1_z, SafeMultiplies()); - times2.shape[n] = times.shape[j-1]; + if (j < times.ndim) + ++j; + src2.shape[n] = std::accumulate( + src.shape + i, src.shape + j, 1_z, SafeMultiplies()); + times2.shape[n] = times.shape[j - 1]; dst2.shape[n] = src2.shape[n] * times2.shape[n]; ++n; i = j; @@ -155,37 +135,31 @@ void RepeatBase::simplify_shape(const TensorShape &src, src2.ndim =
dst2.ndim = times2.ndim = n; } -size_t RepeatBase::get_workspace_in_bytes_fwd(const TensorLayout &src_, - const TensorLayout &dst_) -{ +size_t RepeatBase::get_workspace_in_bytes_fwd( + const TensorLayout& src_, const TensorLayout& dst_) { TensorShape src, dst, times; simplify_shape(src_, dst_, param().times, src, dst, times); - return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, - src_.dtype); + return TileRepeatBase::get_workspace_in_bytes_fwd(src, dst, times, src_.dtype); } -void RepeatForward::deduce_layout(const TensorLayout &src, - TensorLayout &dst) -{ +void RepeatForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { deduce_layout_fwd(src, dst); } -void RepeatForward::check_exec(const TensorLayout &src, const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void RepeatForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void RepeatBackward::check_exec(const TensorLayout &diff, - const TensorLayout &grad, size_t workspace_in_bytes) -{ +void RepeatBackward::check_exec( + const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, diff); auto required_workspace_in_bytes = get_workspace_in_bytes(diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/tile_repeat_helper.cpp b/dnn/src/common/tile_repeat_helper.cpp index 27bca9fe..4fc3105c 100644 --- a/dnn/src/common/tile_repeat_helper.cpp +++ b/dnn/src/common/tile_repeat_helper.cpp @@ -10,40 +10,37 @@ */ #include "src/common/tile_repeat_helper.h" -#include "src/common/utils.h" #include +#include "src/common/utils.h" namespace megdnn { // Tile (m, n) to (m, n*times) or Repeat (m, n) to (m*times, n) template -void tile_or_repeat_single_axis(const T * __restrict src, - T * __restrict dst, - const size_t m, const size_t n, const size_t times) -{ +void tile_or_repeat_single_axis( + const T* __restrict src, T* __restrict dst, const size_t m, const size_t n, + const size_t times) { rep(i, m) { // copy Ts of length n to dst std::memcpy(dst, src, sizeof(T) * n); size_t k = 1u; - while (k*2 <= times) { - std::memcpy(dst + k*n, dst, sizeof(T) * (k*n)); + while (k * 2 <= times) { + std::memcpy(dst + k * n, dst, sizeof(T) * (k * n)); k *= 2; } if (k < times) { - std::memcpy(dst + k*n, dst, sizeof(T) * (times-k) * n); + std::memcpy(dst + k * n, dst, sizeof(T) * (times - k) * n); } src += n; - dst += n*times; + dst += n * times; } } template -void init_tile_repeat_state(const T *src, T *dst, - T *workspace0, T * /* workspace1 */, - T *&current, T *&next, size_t &state, - size_t nr_reduces) -{ - current = const_cast(src); +void init_tile_repeat_state( + const T* src, T* dst, T* workspace0, T* /* workspace1 */, T*& current, T*& next, + size_t& state, size_t nr_reduces) { + current = const_cast(src); if (nr_reduces == 1) { next = dst; } else { @@ -53,11 +50,9 @@ void init_tile_repeat_state(const T *src, T *dst, } template -void update_tile_repeat_state(const T * /* src */, T *dst, - T *workspace0, T *workspace1, - T *&current, T *&next, size_t &state, - size_t nr_reduces) -{ +void update_tile_repeat_state( + const T* /* src */, T* dst, T* workspace0, T* workspace1, T*& current, T*& next, + size_t& state, size_t nr_reduces) { current = next; if (nr_reduces ==
1) { next = nullptr; @@ -75,7 +70,8 @@ void update_tile_repeat_state(const T * /* src */, T *dst, } else if (state + 2 == nr_reduces) { next = dst; } else { - megdnn_assert(current == workspace0 || current == workspace1, + megdnn_assert( + current == workspace0 || current == workspace1, "Impossible happened; internal bug."); next = (current == workspace0 ? workspace1 : workspace0); } @@ -83,19 +79,18 @@ void update_tile_repeat_state(const T * /* src */, T *dst, ++state; } -#define INST(T) \ -template void tile_or_repeat_single_axis(const T *, T *, \ - const size_t, const size_t, const size_t); \ -template void init_tile_repeat_state(const T *, T *, T *, T *, T *&, T *&, \ - size_t &, size_t); \ -template void update_tile_repeat_state(const T *, T *, T *, T *, T *&, T *&, \ - size_t &, size_t); +#define INST(T) \ + template void tile_or_repeat_single_axis( \ + const T*, T*, const size_t, const size_t, const size_t); \ + template void init_tile_repeat_state( \ + const T*, T*, T*, T*, T*&, T*&, size_t&, size_t); \ + template void update_tile_repeat_state( \ + const T*, T*, T*, T*, T*&, T*&, size_t&, size_t); #define INST_DT(d) INST(DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(INST_DT) -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/tile_repeat_helper.h b/dnn/src/common/tile_repeat_helper.h index 10396955..6ff34ab5 100644 --- a/dnn/src/common/tile_repeat_helper.h +++ b/dnn/src/common/tile_repeat_helper.h @@ -15,22 +15,19 @@ namespace megdnn { // Tile (m, n) to (m, n*times) or Repeat (m, n) to (m*times, n) template -void tile_or_repeat_single_axis(const T * __restrict src, - T * __restrict dst, - const size_t m, const size_t n, const size_t times); +void tile_or_repeat_single_axis( + const T* __restrict src, T* __restrict dst, const size_t m, const size_t n, + const size_t times); // forward and backward can share the same init/update functions. 
template -void init_tile_repeat_state(const T *src, T *dst, - T *workspace0, T *workspace1, - T *&current, T *&next, size_t &state, - size_t nr_reduces); +void init_tile_repeat_state( + const T* src, T* dst, T* workspace0, T* workspace1, T*& current, T*& next, + size_t& state, size_t nr_reduces); template -void update_tile_repeat_state(const T *src, T *dst, - T *workspace0, T *workspace1, - T *&current, T *&next, size_t &state, - size_t nr_reduces); +void update_tile_repeat_state( + const T* src, T* dst, T* workspace0, T* workspace1, T*& current, T*& next, + size_t& state, size_t nr_reduces); -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/topk.cpp b/dnn/src/common/topk.cpp index ecf0a7ae..df2d665d 100644 --- a/dnn/src/common/topk.cpp +++ b/dnn/src/common/topk.cpp @@ -17,10 +17,11 @@ using namespace megdnn; -void TopK::deduce_layout(int k, const TensorLayout& data, TensorLayout& values, - TensorLayout& indices) { - megdnn_assert(k && data.ndim == 2 && data.stride[1] == 1, - "invalid k=%d data=%s", k, data.to_string().c_str()); +void TopK::deduce_layout( + int k, const TensorLayout& data, TensorLayout& values, TensorLayout& indices) { + megdnn_assert( + k && data.ndim == 2 && data.stride[1] == 1, "invalid k=%d data=%s", k, + data.to_string().c_str()); values.dtype = data.dtype; indices.dtype = dtype::Int32{}; switch (param().mode) { @@ -39,8 +40,9 @@ void TopK::deduce_layout(int k, const TensorLayout& data, TensorLayout& values, } } -void TopK::exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values, - _megdnn_tensor_out indices, _megdnn_workspace workspace) { +void TopK::exec( + int k, _megdnn_tensor_in data, _megdnn_tensor_out values, + _megdnn_tensor_out indices, _megdnn_workspace workspace) { TensorLayout oval, oidx; deduce_layout(k, data.layout, oval, oidx); megdnn_assert_eq_layout(oval, values.layout); @@ -51,9 +53,9 @@ void TopK::exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values, iptr = indices.ptr(); megdnn_assert_eq_layout(oidx, indices.layout); } - megdnn_assert(workspace.size >= get_workspace_in_bytes(k, data.layout, - values.layout, - indices.layout)); + megdnn_assert( + workspace.size >= + get_workspace_in_bytes(k, data.layout, values.layout, indices.layout)); if (static_cast(std::abs(k)) > data.layout[1]) { if (k > 0) { k = data.layout[1]; @@ -65,4 +67,3 @@ void TopK::exec(int k, _megdnn_tensor_in data, _megdnn_tensor_out values, } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/tqt.cpp b/dnn/src/common/tqt.cpp index 757bfa24..95a094f5 100644 --- a/dnn/src/common/tqt.cpp +++ b/dnn/src/common/tqt.cpp @@ -15,14 +15,13 @@ namespace megdnn { -void TQTBase::deduce_layout_fwd(const TensorLayout& input, - TensorLayout& output) { +void TQTBase::deduce_layout_fwd(const TensorLayout& input, TensorLayout& output) { output = TensorLayout(input, input.dtype); } -void TQTBase::check_layout_fwd(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& output) { +void TQTBase::check_layout_fwd( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& output) { megdnn_assert(input.dtype == dtype::Float32()); megdnn_assert(scale.dtype == dtype::Float32()); TensorLayout expected; @@ -30,28 +29,24 @@ void TQTBase::check_layout_fwd(const TensorLayout& input, megdnn_assert_eq_layout(expected, output); } -void TQTForward::deduce_layout(const TensorLayout& input, - const TensorLayout& /* scale */, - TensorLayout& output) { +void TQTForward::deduce_layout( + const TensorLayout& input, const
TensorLayout& /* scale */, + TensorLayout& output) { deduce_layout_fwd(input, output); } -void TQTForward::check_exec(const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& output, - size_t workspace_in_bytes) { +void TQTForward::check_exec( + const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& output, size_t workspace_in_bytes) { check_layout_fwd(input, scale, output); - auto required_workspace_space = - get_workspace_in_bytes(input, scale, output); + auto required_workspace_space = get_workspace_in_bytes(input, scale, output); megdnn_assert(workspace_in_bytes >= required_workspace_space); } -void TQTBackward::check_exec(const TensorLayout& diff, - const TensorLayout& input, - const TensorLayout& scale, - const TensorLayout& grad_x, - const TensorLayout& grad_s, - size_t workspace_in_bytes) { +void TQTBackward::check_exec( + const TensorLayout& diff, const TensorLayout& input, const TensorLayout& scale, + const TensorLayout& grad_x, const TensorLayout& grad_s, + size_t workspace_in_bytes) { megdnn_assert_eq_shape(diff, input); megdnn_assert_eq_shape(grad_x, input); auto required_worspace_space = diff --git a/dnn/src/common/transpose.cpp b/dnn/src/common/transpose.cpp index e9d5abc7..d30f559b 100644 --- a/dnn/src/common/transpose.cpp +++ b/dnn/src/common/transpose.cpp @@ -14,18 +14,15 @@ namespace megdnn { -void TransposeForward::deduce_layout(const TensorLayout &src, TensorLayout &dst) -{ +void TransposeForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) { dst = src; dst.dtype = src.dtype; std::swap(dst.shape[0], dst.shape[1]); dst.init_contiguous_stride(); } -void TransposeForward::check_exec(const TensorLayout &src, - const TensorLayout &dst, - size_t workspace_in_bytes) -{ +void TransposeForward::check_exec( + const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { // dtype must collide megdnn_assert(src.dtype == dst.dtype); // ndim must be 2 @@ -47,5 +44,5 @@ void TransposeForward::check_exec(const TensorLayout &src, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/type_cvt.cpp b/dnn/src/common/type_cvt.cpp index b287063d..af1e773e 100644 --- a/dnn/src/common/type_cvt.cpp +++ b/dnn/src/common/type_cvt.cpp @@ -14,19 +14,19 @@ namespace megdnn { -void TypeCvt::check_exec(const TensorLayout &src, const TensorLayout &dst) { +void TypeCvt::check_exec(const TensorLayout& src, const TensorLayout& dst) { megdnn_assert_contiguous(dst); megdnn_assert_eq_shape(src, dst); auto cat = src.dtype.category(); - megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || - cat == DTypeCategory::QUANTIZED || - cat == DTypeCategory::BOOL); + megdnn_assert( + cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || + cat == DTypeCategory::QUANTIZED || cat == DTypeCategory::BOOL); cat = dst.dtype.category(); - megdnn_assert(cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || - cat == DTypeCategory::QUANTIZED || - cat == DTypeCategory::BOOL); + megdnn_assert( + cat == DTypeCategory::FLOAT || cat == DTypeCategory::INT || + cat == DTypeCategory::QUANTIZED || cat == DTypeCategory::BOOL); } -} // namespace megdnn +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp index edf5ddd5..28501839 100644 --- a/dnn/src/common/utils.cpp +++ b/dnn/src/common/utils.cpp @@ -10,8 +10,8 @@ */ #include "src/common/utils.h" -#include 
"megdnn/oprs/utils.h" #include "megdnn/handle.h" +#include "megdnn/oprs/utils.h" #include #include @@ -63,8 +63,9 @@ std::string megdnn::ssprintf(const char* fmt, ...) { return rst; } -void megdnn::__assert_fail__(const char* file, int line, const char* func, - const char* expr, const char* msg_fmt, ...) { +void megdnn::__assert_fail__( + const char* file, int line, const char* func, const char* expr, + const char* msg_fmt, ...) { std::string msg; if (msg_fmt) { va_list ap; @@ -73,13 +74,13 @@ void megdnn::__assert_fail__(const char* file, int line, const char* func, msg.append(svsprintf(msg_fmt, ap)); va_end(ap); } - msg = ssprintf("assertion `%s' failed at %s:%d: %s%s", expr, file, line, - func, msg.c_str()); + msg = ssprintf( + "assertion `%s' failed at %s:%d: %s%s", expr, file, line, func, + msg.c_str()); megdnn_throw(msg.c_str()); } -bool megdnn::get_next_addr(size_t* idx, const size_t* shp, size_t n, - size_t stride) { +bool megdnn::get_next_addr(size_t* idx, const size_t* shp, size_t n, size_t stride) { auto errmsg = [&]() { std::string res; res.append("idx={"); @@ -137,25 +138,25 @@ size_t megdnn::get_linear_addr(size_t* index, const size_t* shape, size_t n) { return ans; } -size_t megdnn::infer_conv_shape(size_t inp, size_t flt, size_t stride, - size_t pad, bool is_floor) { - megdnn_assert(inp + 2 * pad >= flt, "input=%zu padding=%zu filter=%zu", inp, - pad, flt); +size_t megdnn::infer_conv_shape( + size_t inp, size_t flt, size_t stride, size_t pad, bool is_floor) { + megdnn_assert( + inp + 2 * pad >= flt, "input=%zu padding=%zu filter=%zu", inp, pad, flt); if (is_floor) { return (inp + 2 * pad - flt) / stride + 1; } return (inp + 2 * pad - flt + stride - 1) / stride + 1; } -void megdnn::infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, - size_t sh, size_t sw, size_t ph, size_t pw, - size_t& oh, size_t& ow, bool is_floor) { +void megdnn::infer_conv_shape2d( + size_t ih, size_t iw, size_t fh, size_t fw, size_t sh, size_t sw, size_t ph, + size_t pw, size_t& oh, size_t& ow, bool is_floor) { oh = infer_conv_shape(ih, fh, sh, ph, is_floor); ow = infer_conv_shape(iw, fw, sw, pw, is_floor); } -WorkspaceBundle::WorkspaceBundle(void* ptr, SmallVector sizes_in_bytes, - size_t align_in_bytes) +WorkspaceBundle::WorkspaceBundle( + void* ptr, SmallVector sizes_in_bytes, size_t align_in_bytes) : m_ptr(ptr), m_sizes(std::move(sizes_in_bytes)), m_align_in_bytes(align_in_bytes) { @@ -197,9 +198,8 @@ void WorkspaceBundle::set(void* ptr) { size_t WorkspaceBundle::total_size_in_bytes() const { //! return 0 if the WorkspaceBundle is empty - size_t size = - std::accumulate(m_aligned_sizes.begin(), m_aligned_sizes.end(), - static_cast(0)); + size_t size = std::accumulate( + m_aligned_sizes.begin(), m_aligned_sizes.end(), static_cast(0)); return size ? 
size + m_align_in_bytes : size; } @@ -211,8 +211,7 @@ size_t megdnn::count_not_ones_in_shape(const TensorShape& shape) { } bool megdnn::is_nhwc_contig_wc(const TensorLayout& layout) { - return layout.ndim == 4 && - (layout.stride[3] == 1 || layout.shape[3] == 1) && + return layout.ndim == 4 && (layout.stride[3] == 1 || layout.shape[3] == 1) && (layout.stride[2] == static_cast(layout.shape[3]) || layout.shape[2] == 1); } @@ -222,8 +221,7 @@ megcoreDeviceHandle_t megdnn::get_device_handle(Handle* handle) { megcoreDeviceHandle_t dev_handle; megcoreComputingHandle_t comp_handle = handle->megcore_computing_handle(); status = megcoreGetDeviceHandle(comp_handle, &dev_handle); - megdnn_throw_if(status != megcoreSuccess, megdnn_error, - "get device handle error!"); + megdnn_throw_if(status != megcoreSuccess, megdnn_error, "get device handle error!"); return dev_handle; } @@ -279,10 +277,8 @@ bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { } template <> -uint8_t megdnn::convert(dt_quint4 src, uint8_t dst, - size_t offset) { - uint8_t _src = - std::min(src.as_uint8(), DTypeTrait::max()); +uint8_t megdnn::convert(dt_quint4 src, uint8_t dst, size_t offset) { + uint8_t _src = std::min(src.as_uint8(), DTypeTrait::max()); if (offset == 0) { _src &= 0xF; dst &= 0xF0; @@ -296,8 +292,8 @@ uint8_t megdnn::convert(dt_quint4 src, uint8_t dst, } template <> -dt_quint4 megdnn::convert(uint8_t src, dt_quint4 dst, - size_t offset) { +dt_quint4 megdnn::convert( + uint8_t src, dt_quint4 dst, size_t offset) { src >>= (offset << 2); src &= 0xF; dst = dt_quint4(src); @@ -305,8 +301,7 @@ dt_quint4 megdnn::convert(uint8_t src, dt_quint4 dst, } template <> -int8_t megdnn::convert(dt_qint4 src, int8_t dst, - size_t offset) { +int8_t megdnn::convert(dt_qint4 src, int8_t dst, size_t offset) { int8_t _src = std::max( std::min(src.as_int8(), DTypeTrait::max()), DTypeTrait::min()); @@ -323,8 +318,7 @@ int8_t megdnn::convert(dt_qint4 src, int8_t dst, } template <> -dt_qint4 megdnn::convert(int8_t src, dt_qint4 dst, - size_t offset) { +dt_qint4 megdnn::convert(int8_t src, dt_qint4 dst, size_t offset) { src <<= (4 - (offset << 2)); src >>= 4; dst = dt_qint4(src); @@ -341,37 +335,39 @@ std::string CpuNDRange::to_string() const { } size_t& CpuNDRange::operator[](size_t idx) { - megdnn_assert(idx < m_dimension, "invalid index: %zu expected < %zu", idx, - m_dimension); + megdnn_assert( + idx < m_dimension, "invalid index: %zu expected < %zu", idx, m_dimension); return m_dim[idx]; } -bool megdnn::check_bias_share_in_channel(const TensorLayout& bias, - const param::ConvBias::Format format) { +bool megdnn::check_bias_share_in_channel( + const TensorLayout& bias, const param::ConvBias::Format format) { bool share_in_channel = false; if (format == param::ConvBias::Format::NCHW || format == param::ConvBias::Format::NCHW4_NCHW) { - share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && - bias[3] == 1); - } else if (format == param::ConvBias::Format::NHWC || - format == param::ConvBias::Format::NCHW4_NHWC) { - share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 && - bias[2] == 1); - } else if (format == param::ConvBias::Format::NCHW4 || - format == param::ConvBias::Format::NCHW8 || - format == param::ConvBias::Format::NCHW32 || - format == param::ConvBias::Format::NCHW64 || - format == param::ConvBias::Format::NCHW4_NCHW32 || - format == param::ConvBias::Format::NCHW32_NCHW4) { - share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && - bias[3] == 1); + share_in_channel = + (bias.ndim == 4 && 
bias[0] == 1 && bias[2] == 1 && bias[3] == 1); + } else if ( + format == param::ConvBias::Format::NHWC || + format == param::ConvBias::Format::NCHW4_NHWC) { + share_in_channel = + (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 && bias[2] == 1); + } else if ( + format == param::ConvBias::Format::NCHW4 || + format == param::ConvBias::Format::NCHW8 || + format == param::ConvBias::Format::NCHW32 || + format == param::ConvBias::Format::NCHW64 || + format == param::ConvBias::Format::NCHW4_NCHW32 || + format == param::ConvBias::Format::NCHW32_NCHW4) { + share_in_channel = + (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && bias[3] == 1); } else if (format == param::ConvBias::Format::NHWCD4) { - share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 && - bias[3] == 1); + share_in_channel = + (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 && bias[3] == 1); } else { megdnn_assert(format == param::ConvBias::Format::CHWN4); - share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 && - bias[3] == 1); + share_in_channel = + (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 && bias[3] == 1); } return share_in_channel; } diff --git a/dnn/src/common/utils.cuh b/dnn/src/common/utils.cuh index 12deba34..89e99944 100644 --- a/dnn/src/common/utils.cuh +++ b/dnn/src/common/utils.cuh @@ -13,7 +13,7 @@ #include "megdnn/arch.h" //! a comma to be used in macro for template params -#define MEGDNN_COMMA , +#define MEGDNN_COMMA , #define MEGDNN_MARK_USED_VAR(v) static_cast(v) #if MEGDNN_ENABLE_LOGGING @@ -34,21 +34,21 @@ //! megdnn_assert #if MEGDNN_ENABLE_LOGGING #if MEGDNN_ENABLE_MANGLING -#define megdnn_assert(expr, ...) \ - do { \ - if (megdnn_unlikely(!(expr))) { \ - ::megdnn::__assert_fail__( \ - "about location info, please build with debug", __LINE__, \ - NULL, #expr, ##__VA_ARGS__); \ - } \ +#define megdnn_assert(expr, ...) \ + do { \ + if (megdnn_unlikely(!(expr))) { \ + ::megdnn::__assert_fail__( \ + "about location info, please build with debug", __LINE__, NULL, \ + #expr, ##__VA_ARGS__); \ + } \ } while (0) #else -#define megdnn_assert(expr, ...) \ - do { \ - if (megdnn_unlikely(!(expr))) { \ - ::megdnn::__assert_fail__(__FILE__, __LINE__, __PRETTY_FUNCTION__, \ - #expr, ##__VA_ARGS__); \ - } \ +#define megdnn_assert(expr, ...) \ + do { \ + if (megdnn_unlikely(!(expr))) { \ + ::megdnn::__assert_fail__( \ + __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##__VA_ARGS__); \ + } \ } while (0) #endif // MEGDNN_ENABLE_MANGLING #else @@ -69,18 +69,19 @@ namespace megdnn { -void __assert_fail__(const char *file, int line, const char *func, - const char *expr, const char *msg_fmt = nullptr, ...) +void __assert_fail__( + const char* file, int line, const char* func, const char* expr, + const char* msg_fmt = nullptr, ...) #if defined(__GNUC__) || defined(__clang__) - __attribute__((format(printf, 5, 6), noreturn)) + __attribute__((format(printf, 5, 6), noreturn)) #endif - ; + ; -void __dummy_printf__(const char *msg_fmt, ...) +void __dummy_printf__(const char* msg_fmt, ...) #ifdef __GNUC__ - __attribute__((format(printf, 1, 2))) + __attribute__((format(printf, 1, 2))) #endif -; + ; //! 
typetrait, just the same as std::is_same in c++11 template @@ -93,6 +94,6 @@ struct is_same { static const bool value = true; }; -} // namespace megdnn +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index 452477d9..b47f3df0 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -41,8 +41,8 @@ #include "megdnn/basic_types.h" #include "megdnn/dtype.h" #include "megdnn/handle.h" -#include "megdnn/thin/small_vector.h" #include "megdnn/oprs/general.h" +#include "megdnn/thin/small_vector.h" #include "src/common/hash_ct.h" #include "src/common/utils.cuh" @@ -60,7 +60,6 @@ #include #endif - #if MEGDNN_AARCH64 || MEGDNN_ARMV7 #if MGB_ENABLE_CPUINFO #include "cpuinfo.h" @@ -68,66 +67,68 @@ #endif #if __cplusplus >= 201703L || __clang_major__ >= 4 - #define MEGDNN_FALLTHRU [[fallthrough]]; +#define MEGDNN_FALLTHRU [[fallthrough]]; #elif __GNUC__ >= 7 - #define MEGDNN_FALLTHRU __attribute__ ((fallthrough)); +#define MEGDNN_FALLTHRU __attribute__((fallthrough)); #else - #define MEGDNN_FALLTHRU +#define MEGDNN_FALLTHRU #endif -#define rep(i, n) for (auto i = decltype(n){0}; i < (n); ++i) +#define rep(i, n) for (auto i = decltype(n){0}; i < (n); ++i) #define rep_step(i, n, step) for (auto i = decltype(n){0}; i < (n); i += (step)) -#define megdnn_assert_contiguous(layout) \ - do { \ - megdnn_assert((layout).is_contiguous(), "%s is %s.", #layout, \ - (layout).to_string().c_str()); \ +#define megdnn_assert_contiguous(layout) \ + do { \ + megdnn_assert( \ + (layout).is_contiguous(), "%s is %s.", #layout, \ + (layout).to_string().c_str()); \ } while (0) -#define megdnn_assert_non_overlapping_strong(layout) \ - do { \ - megdnn_assert((layout).is_non_overlapping_strong(), "%s is %s.", \ - #layout, (layout).to_string().c_str()); \ +#define megdnn_assert_non_overlapping_strong(layout) \ + do { \ + megdnn_assert( \ + (layout).is_non_overlapping_strong(), "%s is %s.", #layout, \ + (layout).to_string().c_str()); \ } while (0) -#define megdnn_assert_eq_size_t(lhs_, rhs_) \ - do { \ - size_t lhs = lhs_, rhs = rhs_; \ - megdnn_assert(lhs == rhs, "%s is %zu, %s is %zu.", #lhs_, lhs, #rhs_, \ - rhs); \ +#define megdnn_assert_eq_size_t(lhs_, rhs_) \ + do { \ + size_t lhs = lhs_, rhs = rhs_; \ + megdnn_assert(lhs == rhs, "%s is %zu, %s is %zu.", #lhs_, lhs, #rhs_, rhs); \ } while (0) -#define megdnn_assert_eq_layout(lhs, rhs) \ - do { \ - megdnn_assert(lhs.eq_layout(rhs), "%s is %s, %s is %s.", #lhs, \ - lhs.to_string().c_str(), #rhs, rhs.to_string().c_str()); \ +#define megdnn_assert_eq_layout(lhs, rhs) \ + do { \ + megdnn_assert( \ + lhs.eq_layout(rhs), "%s is %s, %s is %s.", #lhs, \ + lhs.to_string().c_str(), #rhs, rhs.to_string().c_str()); \ } while (0) -#define megdnn_assert_eq_shape(lhs, rhs) \ - do { \ - megdnn_assert(lhs.eq_shape(rhs), "%s is %s, %s is %s.", #lhs, \ - lhs.to_string().c_str(), #rhs, rhs.to_string().c_str()); \ +#define megdnn_assert_eq_shape(lhs, rhs) \ + do { \ + megdnn_assert( \ + lhs.eq_shape(rhs), "%s is %s, %s is %s.", #lhs, \ + lhs.to_string().c_str(), #rhs, rhs.to_string().c_str()); \ } while (0) -#define megdnn_assert_eq_dtype(lhs, rhs) \ - do { \ - megdnn_assert(lhs.dtype == rhs.dtype, "%s is %s, %s is %s.", #lhs, \ - lhs.dtype.name(), #rhs, rhs.dtype.name()); \ +#define megdnn_assert_eq_dtype(lhs, rhs) \ + do { \ + megdnn_assert( \ + lhs.dtype == rhs.dtype, "%s is %s, %s is %s.", #lhs, lhs.dtype.name(), \ + #rhs, rhs.dtype.name()); \ } while (0) -#define megdnn_layout_msg(layout) \ - std::string(#layout "=" 
+ (layout).to_string()) +#define megdnn_layout_msg(layout) std::string(#layout "=" + (layout).to_string()) #if __DEPLOY_ON_XP_SP2__ -#define DNN_MUTEX size_t +#define DNN_MUTEX size_t #define MEGDNN_LOCK_GUARD(var) MEGDNN_MARK_USED_VAR(var) #else -#define DNN_MUTEX std::mutex -#define DNN_TOKENPASTE(x, y) x##y -#define DNN_TOKENPASTE2(x, y) DNN_TOKENPASTE(x, y) +#define DNN_MUTEX std::mutex +#define DNN_TOKENPASTE(x, y) x##y +#define DNN_TOKENPASTE2(x, y) DNN_TOKENPASTE(x, y) #define DNN_LOCK_GUARD_CTOR(mtx) DNN_TOKENPASTE2(__lock_guard_, __LINE__)(mtx) -#define MEGDNN_LOCK_GUARD(mtx) \ - std::lock_guard DNN_LOCK_GUARD_CTOR(mtx) +#define MEGDNN_LOCK_GUARD(mtx) std::lock_guard DNN_LOCK_GUARD_CTOR(mtx) #endif namespace megdnn { @@ -154,8 +155,9 @@ namespace megdnn { #endif #if MEGDNN_ENABLE_LOGGING -void __log__(LogLevel level, const char* file, const char* func, int line, - const char* fmt, ...) __attribute__((format(printf, 5, 6))); +void __log__( + LogLevel level, const char* file, const char* func, int line, const char* fmt, + ...) __attribute__((format(printf, 5, 6))); #define _megdnn_do_log ::megdnn::__log__ #else @@ -177,19 +179,17 @@ constexpr int32_t cast_int(T data) { * \return true if index is updated successfully, false otherwise (index is * already the last one, next index does not exist) */ -bool get_next_addr(size_t* index, const size_t* shape, size_t n, - size_t stride = 1); +bool get_next_addr(size_t* index, const size_t* shape, size_t n, size_t stride = 1); size_t get_linear_addr(size_t* index, const size_t* shape, size_t n); int get_linear_addr_noncont(size_t* index, const TensorLayout& layout); -size_t infer_conv_shape(size_t inp, size_t flt, size_t stride, size_t pad, - bool is_floor = true); -void infer_conv_shape2d(size_t ih, size_t iw, size_t fh, size_t fw, size_t sh, - size_t sw, size_t ph, size_t pw, size_t& oh, size_t& ow, - bool is_floor = true); +size_t infer_conv_shape( + size_t inp, size_t flt, size_t stride, size_t pad, bool is_floor = true); +void infer_conv_shape2d( + size_t ih, size_t iw, size_t fh, size_t fw, size_t sh, size_t sw, size_t ph, + size_t pw, size_t& oh, size_t& ow, bool is_floor = true); template SmallVector apply_vector(Func&& func, const SmallVector& vec); -std::string ssprintf(const char* fmt, ...) - __attribute__((format(printf, 1, 2))); +std::string ssprintf(const char* fmt, ...) __attribute__((format(printf, 1, 2))); /*! * \brief transpose (m*n) matrix to (n*m) matrix @@ -201,16 +201,17 @@ std::string ssprintf(const char* fmt, ...) * */ template -void transpose(const dtype* src, dtype* dst, size_t m, size_t n, - ptrdiff_t lds = -1, ptrdiff_t ldd = -1); +void transpose( + const dtype* src, dtype* dst, size_t m, size_t n, ptrdiff_t lds = -1, + ptrdiff_t ldd = -1); /*! * transpose src with contiguous layout (k, n, c) into dst with shape * (n, c, k), with given stride (\p n_stride) on first dimension */ template -void transpose_knc2nsck(const dtype* src, dtype* dst, size_t k, size_t n, - size_t c, size_t n_stride); +void transpose_knc2nsck( + const dtype* src, dtype* dst, size_t k, size_t n, size_t c, size_t n_stride); /*! * \brief divide get result ceiled to int; both dividend and divisor shoud be @@ -249,16 +250,16 @@ std::unique_ptr make_unique(Args&&... args) { /*! 
* \brief check whether the source enum contain the target data type enum */ -bool inline contain_data_type(detail::AlgoDataType source, - detail::AlgoDataType target) { - return static_cast(static_cast(source) & - static_cast(target)); +bool inline contain_data_type( + detail::AlgoDataType source, detail::AlgoDataType target) { + return static_cast( + static_cast(source) & static_cast(target)); } /*! * \brief get the source enum contain the data type number */ -template +template size_t nr_type_contain(T index) { uint32_t sr_index = static_cast(index); size_t nr_type = 0; @@ -276,8 +277,8 @@ size_t nr_type_contain(T index) { */ class WorkspaceBundle { public: - WorkspaceBundle(void* ptr, SmallVector sizes_in_bytes, - size_t align_in_bytes = 512); + WorkspaceBundle( + void* ptr, SmallVector sizes_in_bytes, size_t align_in_bytes = 512); /** * \returns raw workspace ptr. * @@ -349,13 +350,13 @@ size_t count_not_ones_in_shape(const TensorShape& shape); */ bool is_nhwc_contig_wc(const TensorLayout& layout); -static inline void copy_plane_in_bytes(void* dst, const void* src, - size_t height, size_t width, - size_t stride_dst, size_t stride_src) { +static inline void copy_plane_in_bytes( + void* dst, const void* src, size_t height, size_t width, size_t stride_dst, + size_t stride_src) { for (size_t h = 0; h < height; ++h) { - std::memcpy(static_cast(dst) + h * stride_dst, - static_cast(src) + h * stride_src, - width); + std::memcpy( + static_cast(dst) + h * stride_dst, + static_cast(src) + h * stride_src, width); } } @@ -483,8 +484,9 @@ struct _SafeMultipliesImplUnsigned : public std::binary_function { t += yodd ? x : 0; overflow |= yodd & (t < x); - megdnn_assert(!overflow, "multiply overflow: %s %s", - std::to_string(x).c_str(), std::to_string(y).c_str()); + megdnn_assert( + !overflow, "multiply overflow: %s %s", std::to_string(x).c_str(), + std::to_string(y).c_str()); return t; } @@ -534,11 +536,10 @@ dt_qint4 convert(int8_t src, dt_qint4 dst, size_t offset); * \brief check float equal within given ULP(unit in the last place) */ template -static inline - typename std::enable_if::is_integer, bool>::type - almost_equal(T x, T y, int unit_last_place = 1) { - return std::abs(x - y) < (std::numeric_limits::epsilon() * - std::abs(x + y) * unit_last_place) || +static inline typename std::enable_if::is_integer, bool>::type +almost_equal(T x, T y, int unit_last_place = 1) { + return std::abs(x - y) < (std::numeric_limits::epsilon() * std::abs(x + y) * + unit_last_place) || std::abs(x - y) < std::numeric_limits::min(); } @@ -556,8 +557,9 @@ private: public: //! \brief Constructs seven-dimensional range. - CpuNDRange(size_t size0, size_t size1, size_t size2, size_t size3, - size_t size4, size_t size5, size_t size6) + CpuNDRange( + size_t size0, size_t size1, size_t size2, size_t size3, size_t size4, + size_t size5, size_t size6) : m_dimension(7) { m_dim[0] = size0; m_dim[1] = size1; @@ -571,13 +573,10 @@ public: CpuNDRange() : CpuNDRange(1, 1, 1, 1, 1, 1, 1) { m_dimension = 0; } //! \brief Constructs one-dimensional range. - CpuNDRange(size_t size0) : CpuNDRange(size0, 1, 1, 1, 1, 1, 1) { - m_dimension = 1; - } + CpuNDRange(size_t size0) : CpuNDRange(size0, 1, 1, 1, 1, 1, 1) { m_dimension = 1; } //! \brief Constructs two-dimensional range. - CpuNDRange(size_t size0, size_t size1) - : CpuNDRange(size0, size1, 1, 1, 1, 1, 1) { + CpuNDRange(size_t size0, size_t size1) : CpuNDRange(size0, size1, 1, 1, 1, 1, 1) { m_dimension = 2; } @@ -594,15 +593,15 @@ public: } //! \brief Constructs five-dimensional range. 
- CpuNDRange(size_t size0, size_t size1, size_t size2, size_t size3, - size_t size4) + CpuNDRange(size_t size0, size_t size1, size_t size2, size_t size3, size_t size4) : CpuNDRange(size0, size1, size2, size3, size4, 1, 1) { m_dimension = 5; } //! \brief Constructs six-dimensional range. - CpuNDRange(size_t size0, size_t size1, size_t size2, size_t size3, - size_t size4, size_t size5) + CpuNDRange( + size_t size0, size_t size1, size_t size2, size_t size3, size_t size4, + size_t size5) : CpuNDRange(size0, size1, size2, size3, size4, size5, 1) { m_dimension = 6; } @@ -693,25 +692,22 @@ struct CompTypeCvter { return *this; } - Workspace workspace() { - return m_workspace_bundle->get_workspace(m_workspace_idx); - } + Workspace workspace() { return m_workspace_bundle->get_workspace(m_workspace_idx); } }; /*! * \brief get TensorND raw_ptr+low_byte pointer. */ inline dt_byte* get_low_ptr(const TensorND* tensor) { - return static_cast(tensor->raw_ptr) + - tensor->layout.span().low_byte; + return static_cast(tensor->raw_ptr) + tensor->layout.span().low_byte; } /*! * \brief get the zero element pointer of TensorND. */ inline void* get_origin_ptr(const TensorND* tensor, void* ptr) { - return static_cast(static_cast(ptr) - - tensor->layout.span().low_byte); + return static_cast( + static_cast(ptr) - tensor->layout.span().low_byte); } } // namespace megdnn diff --git a/dnn/src/common/warp_affine.cpp b/dnn/src/common/warp_affine.cpp index 48bd648d..00e0b0d3 100644 --- a/dnn/src/common/warp_affine.cpp +++ b/dnn/src/common/warp_affine.cpp @@ -14,9 +14,8 @@ namespace megdnn { -void WarpAffineBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& mat, - const TensorLayout& dst) { +void WarpAffineBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { megdnn_assert_contiguous(mat); auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " + @@ -33,21 +32,20 @@ void WarpAffineBase::check_layout_fwd(const TensorLayout& src, if (param().format == Param::Format::NCHW) { megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str()); - megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 || - DNN_FLOAT16_SELECT( - src.dtype.enumv() == DTypeEnum::Float16, - false) || - src.dtype.enumv() == DTypeEnum::Int8 || - src.dtype.enumv() == DTypeEnum::Uint8 || - (src.dtype.enumv() == DTypeEnum::QuantizedS8 || - src.dtype.enumv() == DTypeEnum::Quantized8Asymm), - "WarpAffine NCHW input dtype should be " - "Float32/Int8/Uint8/QInt8/QUint8" DNN_FLOAT16_SELECT( - "/Float16", "") "."); + megdnn_assert( + src.dtype.enumv() == DTypeEnum::Float32 || + DNN_FLOAT16_SELECT( + src.dtype.enumv() == DTypeEnum::Float16, false) || + src.dtype.enumv() == DTypeEnum::Int8 || + src.dtype.enumv() == DTypeEnum::Uint8 || + (src.dtype.enumv() == DTypeEnum::QuantizedS8 || + src.dtype.enumv() == DTypeEnum::Quantized8Asymm), + "WarpAffine NCHW input dtype should be " + "Float32/Int8/Uint8/QInt8/QUint8" DNN_FLOAT16_SELECT( + "/Float16", "") "."); megdnn_assert( (src.dtype.category() == DTypeCategory::FLOAT && - (src.dtype == mat.dtype || - mat.dtype.enumv() == DTypeEnum::Float32)) || + (src.dtype == mat.dtype || mat.dtype.enumv() == DTypeEnum::Float32)) || ((src.dtype.category() == DTypeCategory::INT || src.dtype.category() == DTypeCategory::QUANTIZED) && mat.dtype.enumv() == DTypeEnum::Float32), @@ -58,36 +56,35 @@ void WarpAffineBase::check_layout_fwd(const TensorLayout& src, mat.dtype.name()); 
megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().border_mode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().border_mode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert( + param().border_mode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert( + param().border_mode != param::WarpPerspective::BorderMode::ISOLATED); } else if (param().format == Param::Format::NHWC) { megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str()); - megdnn_assert(param().imode != - param::WarpPerspective::InterpolationMode::AREA); + megdnn_assert(param().imode != param::WarpPerspective::InterpolationMode::AREA); } else { megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str()); megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); megdnn_assert(param().format == Param::Format::NHWCD4); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().border_mode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().border_mode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert( + param().border_mode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert( + param().border_mode != param::WarpPerspective::BorderMode::ISOLATED); } } -void WarpAffine::check_exec(const TensorLayout& src, const TensorLayout& mat, - const TensorLayout& dst, - size_t workspace_in_bytes) { +void WarpAffine::check_exec( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst, + size_t workspace_in_bytes) { check_layout_fwd(src, mat, dst); auto required_workspace_in_bytes = get_workspace_in_bytes(src, mat, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); diff --git a/dnn/src/common/warp_common.cpp b/dnn/src/common/warp_common.cpp index 40ff6db1..1ecfa05a 100644 --- a/dnn/src/common/warp_common.cpp +++ b/dnn/src/common/warp_common.cpp @@ -13,10 +13,10 @@ using namespace megdnn; -bool warp::is_cv_available(const TensorLayout& src, const TensorLayout& /*mat*/, - const TensorLayout& /*dst*/, - param::WarpAffine::InterpolationMode imode, - param::WarpAffine::Format format) { +bool warp::is_cv_available( + const TensorLayout& src, const TensorLayout& /*mat*/, + const TensorLayout& /*dst*/, param::WarpAffine::InterpolationMode imode, + param::WarpAffine::Format format) { return format == param::WarpAffine::Format::NHWC && (src[3] == 1 || src[3] == 2 || src[3] == 3) && (src.dtype == dtype::Float32() || src.dtype == dtype::Uint8()) && @@ -26,11 +26,10 @@ bool warp::is_cv_available(const TensorLayout& src, const TensorLayout& /*mat*/, imode == param::WarpAffine::InterpolationMode::LANCZOS4); } -bool warp::is_dnn_available(const TensorLayout& /*src*/, - const TensorLayout& /*mat*/, - const TensorLayout& /*dst*/, - param::WarpAffine::InterpolationMode imode, - param::WarpAffine::Format /*format*/) { +bool warp::is_dnn_available( + const TensorLayout& /*src*/, const TensorLayout& /*mat*/, + const TensorLayout& /*dst*/, param::WarpAffine::InterpolationMode 
imode, + param::WarpAffine::Format /*format*/) { return imode == param::WarpAffine::InterpolationMode::LINEAR; } diff --git a/dnn/src/common/warp_common.h b/dnn/src/common/warp_common.h index 02e0298c..948536ac 100644 --- a/dnn/src/common/warp_common.h +++ b/dnn/src/common/warp_common.h @@ -82,15 +82,13 @@ MIDOUT_DECL(remapBilinear_ch) namespace megdnn { namespace warp { -bool is_cv_available(const TensorLayout& src, const TensorLayout& mat, - const TensorLayout& dst, - param::WarpAffine::InterpolationMode imode, - param::WarpAffine::Format format); +bool is_cv_available( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst, + param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format); -bool is_dnn_available(const TensorLayout&, const TensorLayout&, - const TensorLayout&, - param::WarpAffine::InterpolationMode imode, - param::WarpAffine::Format format); +bool is_dnn_available( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + param::WarpAffine::InterpolationMode imode, param::WarpAffine::Format format); using namespace megcv; using IMode = InterpolationMode; @@ -104,8 +102,8 @@ constexpr int INTER_REMAP_COEF_SCALE = InterpTable::INTER_REMAP_COEF_SCALE; template struct RemapVec { - int operator()(const Mat&, void*, const short*, const ushort*, - const void*, int) const { + int operator()( + const Mat&, void*, const short*, const ushort*, const void*, int) const { return 0; } }; @@ -114,16 +112,17 @@ struct RemapVec { template struct RemapVec { - int operator()(const Mat8u& _src, void* _dst, const short* XY, - const ushort* FXY, const void* _wtab, int width) const { + int operator()( + const Mat8u& _src, void* _dst, const short* XY, const ushort* FXY, + const void* _wtab, int width) const { int x = 0, sstep = (int)_src.step(); if ((CH != 1 && CH != 3) || sstep > 0x8000) return 0; const uchar *S0 = _src.ptr(), *S1 = _src.ptr(1); - const short* wtab = CH == 1 ? (const short*)_wtab - : InterpTable::get_linear_ic4_table(); + const short* wtab = + CH == 1 ? 
(const short*)_wtab : InterpTable::get_linear_ic4_table(); uchar* D = (uchar*)_dst; __m128i delta = _mm_set1_epi32(INTER_REMAP_COEF_SCALE / 2); __m128i xy2ofs = _mm_set1_epi32(CH + (sstep << 16)); @@ -143,18 +142,12 @@ struct RemapVec { _mm_store_si128((__m128i*)iofs0, xy0); _mm_store_si128((__m128i*)iofs1, xy1); - i0 = *(ushort*)(S0 + iofs0[0]) + - (*(ushort*)(S0 + iofs0[1]) << 16); - i1 = *(ushort*)(S0 + iofs0[2]) + - (*(ushort*)(S0 + iofs0[3]) << 16); - v0 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), - _mm_cvtsi32_si128(i1)); - i0 = *(ushort*)(S1 + iofs0[0]) + - (*(ushort*)(S1 + iofs0[1]) << 16); - i1 = *(ushort*)(S1 + iofs0[2]) + - (*(ushort*)(S1 + iofs0[3]) << 16); - v1 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), - _mm_cvtsi32_si128(i1)); + i0 = *(ushort*)(S0 + iofs0[0]) + (*(ushort*)(S0 + iofs0[1]) << 16); + i1 = *(ushort*)(S0 + iofs0[2]) + (*(ushort*)(S0 + iofs0[3]) << 16); + v0 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), _mm_cvtsi32_si128(i1)); + i0 = *(ushort*)(S1 + iofs0[0]) + (*(ushort*)(S1 + iofs0[1]) << 16); + i1 = *(ushort*)(S1 + iofs0[2]) + (*(ushort*)(S1 + iofs0[3]) << 16); + v1 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), _mm_cvtsi32_si128(i1)); v0 = _mm_unpacklo_epi8(v0, z); v1 = _mm_unpacklo_epi8(v1, z); @@ -170,18 +163,12 @@ struct RemapVec { v1 = _mm_madd_epi16(v1, b1); v0 = _mm_add_epi32(_mm_add_epi32(v0, v1), delta); - i0 = *(ushort*)(S0 + iofs1[0]) + - (*(ushort*)(S0 + iofs1[1]) << 16); - i1 = *(ushort*)(S0 + iofs1[2]) + - (*(ushort*)(S0 + iofs1[3]) << 16); - v2 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), - _mm_cvtsi32_si128(i1)); - i0 = *(ushort*)(S1 + iofs1[0]) + - (*(ushort*)(S1 + iofs1[1]) << 16); - i1 = *(ushort*)(S1 + iofs1[2]) + - (*(ushort*)(S1 + iofs1[3]) << 16); - v3 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), - _mm_cvtsi32_si128(i1)); + i0 = *(ushort*)(S0 + iofs1[0]) + (*(ushort*)(S0 + iofs1[1]) << 16); + i1 = *(ushort*)(S0 + iofs1[2]) + (*(ushort*)(S0 + iofs1[3]) << 16); + v2 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), _mm_cvtsi32_si128(i1)); + i0 = *(ushort*)(S1 + iofs1[0]) + (*(ushort*)(S1 + iofs1[1]) << 16); + i1 = *(ushort*)(S1 + iofs1[2]) + (*(ushort*)(S1 + iofs1[3]) << 16); + v3 = _mm_unpacklo_epi32(_mm_cvtsi32_si128(i0), _mm_cvtsi32_si128(i1)); v2 = _mm_unpacklo_epi8(v2, z); v3 = _mm_unpacklo_epi8(v3, z); @@ -229,14 +216,12 @@ struct RemapVec { v0 = _mm_unpacklo_epi8(v0, z); u1 = _mm_unpacklo_epi8(u1, z); v1 = _mm_unpacklo_epi8(v1, z); - u0 = _mm_add_epi32(_mm_madd_epi16(u0, w0[0]), - _mm_madd_epi16(v0, w0[1])); - u1 = _mm_add_epi32(_mm_madd_epi16(u1, w1[0]), - _mm_madd_epi16(v1, w1[1])); - u0 = _mm_srai_epi32(_mm_add_epi32(u0, delta), - INTER_REMAP_COEF_BITS); - u1 = _mm_srai_epi32(_mm_add_epi32(u1, delta), - INTER_REMAP_COEF_BITS); + u0 = _mm_add_epi32( + _mm_madd_epi16(u0, w0[0]), _mm_madd_epi16(v0, w0[1])); + u1 = _mm_add_epi32( + _mm_madd_epi16(u1, w1[0]), _mm_madd_epi16(v1, w1[1])); + u0 = _mm_srai_epi32(_mm_add_epi32(u0, delta), INTER_REMAP_COEF_BITS); + u1 = _mm_srai_epi32(_mm_add_epi32(u1, delta), INTER_REMAP_COEF_BITS); u0 = _mm_slli_si128(u0, 4); u0 = _mm_packs_epi32(u0, u1); u0 = _mm_packus_epi16(u0, u0); @@ -261,14 +246,12 @@ struct RemapVec { v0 = _mm_unpacklo_epi8(v0, z); u1 = _mm_unpacklo_epi8(u1, z); v1 = _mm_unpacklo_epi8(v1, z); - u0 = _mm_add_epi32(_mm_madd_epi16(u0, w0[0]), - _mm_madd_epi16(v0, w0[1])); - u1 = _mm_add_epi32(_mm_madd_epi16(u1, w1[0]), - _mm_madd_epi16(v1, w1[1])); - u0 = _mm_srai_epi32(_mm_add_epi32(u0, delta), - INTER_REMAP_COEF_BITS); - u1 = _mm_srai_epi32(_mm_add_epi32(u1, delta), - INTER_REMAP_COEF_BITS); + 
u0 = _mm_add_epi32( + _mm_madd_epi16(u0, w0[0]), _mm_madd_epi16(v0, w0[1])); + u1 = _mm_add_epi32( + _mm_madd_epi16(u1, w1[0]), _mm_madd_epi16(v1, w1[1])); + u0 = _mm_srai_epi32(_mm_add_epi32(u0, delta), INTER_REMAP_COEF_BITS); + u1 = _mm_srai_epi32(_mm_add_epi32(u1, delta), INTER_REMAP_COEF_BITS); u0 = _mm_slli_si128(u0, 4); u0 = _mm_packs_epi32(u0, u1); u0 = _mm_packus_epi16(u0, u0); @@ -282,16 +265,16 @@ struct RemapVec { #endif template -using RemapNNFunc = void (*)(const Mat& _src, Mat& _dst, - const Mat& _xy, const T* bvalue); +using RemapNNFunc = void (*)( + const Mat& _src, Mat& _dst, const Mat& _xy, const T* bvalue); template -using RemapFunc = void (*)(const Mat& _src, Mat& _dst, - const Mat& _xy, const Mat& _fxy, - const void* _wtab, const T* bvalue); +using RemapFunc = void (*)( + const Mat& _src, Mat& _dst, const Mat& _xy, + const Mat& _fxy, const void* _wtab, const T* bvalue); template -static void remapNearest(const Mat& _src, Mat& _dst, - const Mat& _xy, const T* bvalue) { +static void remapNearest( + const Mat& _src, Mat& _dst, const Mat& _xy, const T* bvalue) { const T* S0 = _src.ptr(); size_t sstep = _src.step(); int dx, dy; @@ -356,11 +339,10 @@ static void remapNearest(const Mat& _src, Mat& _dst, } } -template -static void remapBicubic(const Mat& _src, Mat& _dst, - const Mat& _xy, const Mat& _fxy, - const void* _wtab, const T* bvalue) { +template +static void remapBicubic( + const Mat& _src, Mat& _dst, const Mat& _xy, + const Mat& _fxy, const void* _wtab, const T* bvalue) { typedef typename CastOp::type1 WT; const AT* wtab = (const AT*)_wtab; const T* S0 = _src.ptr(); @@ -369,8 +351,7 @@ static void remapBicubic(const Mat& _src, Mat& _dst, CastOp castOp; int swidth = _src.width(), sheight = _src.height(); int dwidth = _dst.width(), dheight = _dst.height(); - unsigned width1 = std::max(swidth - 3, 0), - height1 = std::max(sheight - 3, 0); + unsigned width1 = std::max(swidth - 3, 0), height1 = std::max(sheight - 3, 0); if (_dst.is_continuous() && _xy.is_continuous() && _fxy.is_continuous()) { dwidth *= dheight; dheight = 1; @@ -407,8 +388,7 @@ static void remapBicubic(const Mat& _src, Mat& _dst, (unsigned)(sy + 1) >= (unsigned)sheight)) continue; if (bmode == BMode::BORDER_CONSTANT && - (sx >= swidth || sx + 4 <= 0 || sy >= sheight || - sy + 4 <= 0)) { + (sx >= swidth || sx + 4 <= 0 || sy >= sheight || sy + 4 <= 0)) { for (size_t i = 0; i < CH; i++) { D[i] = bvalue[i]; } @@ -442,198 +422,198 @@ static void remapBicubic(const Mat& _src, Mat& _dst, } } -template -static void remapBilinear(const Mat& _src, Mat& _dst, - const Mat& _xy, const Mat& _fxy, - const void* _wtab, const T* bvalue) { +template < + class CastOp, class VecOp, typename AT, typename T, BorderMode bmode, size_t CH> +static void remapBilinear( + const Mat& _src, Mat& _dst, const Mat& _xy, + const Mat& _fxy, const void* _wtab, const T* bvalue) { MIDOUT_BEGIN(remapBilinear_bmode, midout_iv(bmode)) { - typedef typename CastOp::type1 WT; - const AT* wtab = (const AT*)_wtab; - const T* S0 = _src.ptr(); - size_t sstep = _src.step(); - int dx, dy; - CastOp castOp; - VecOp vecOp; - int swidth = _src.width(), sheight = _src.height(); - int dwidth = _dst.width(), dheight = _dst.height(); - unsigned width1 = std::max(swidth - 1, 0), - height1 = std::max(sheight - 1, 0); - for (dy = 0; dy < dheight; dy++) { - T* D = _dst.ptr(dy); - const short* XY = _xy.ptr(dy); - const ushort* FXY = _fxy.ptr(dy); - int X0 = 0; - bool prevInlier = false; - - for (dx = 0; dx <= dwidth; dx++) { - bool curInlier = - dx < dwidth ? 
(unsigned)XY[dx * 2] < width1 && - (unsigned)XY[dx * 2 + 1] < height1 - : !prevInlier; - if (curInlier == prevInlier) - continue; - - int X1 = dx; - dx = X0; - X0 = X1; - prevInlier = curInlier; - - if (!curInlier) { - int len = vecOp(_src, D, XY + dx * 2, FXY + dx, wtab, X1 - dx); - D += len * CH; - dx += len; - - if (CH == 1) { - MIDOUT_BEGIN(remapBilinear_bmode, 0, 1) { - for (; dx < X1; dx++, D++) { - int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; - const AT* w = wtab + FXY[dx] * 4; - const T* S = S0 + sy * sstep + sx; - *D = castOp(WT(S[0] * w[0] + S[1] * w[1] + - S[sstep] * w[2] + S[sstep + 1] * w[3])); + typedef typename CastOp::type1 WT; + const AT* wtab = (const AT*)_wtab; + const T* S0 = _src.ptr(); + size_t sstep = _src.step(); + int dx, dy; + CastOp castOp; + VecOp vecOp; + int swidth = _src.width(), sheight = _src.height(); + int dwidth = _dst.width(), dheight = _dst.height(); + unsigned width1 = std::max(swidth - 1, 0), height1 = std::max(sheight - 1, 0); + for (dy = 0; dy < dheight; dy++) { + T* D = _dst.ptr(dy); + const short* XY = _xy.ptr(dy); + const ushort* FXY = _fxy.ptr(dy); + int X0 = 0; + bool prevInlier = false; + + for (dx = 0; dx <= dwidth; dx++) { + bool curInlier = dx < dwidth + ? (unsigned)XY[dx * 2] < width1 && + (unsigned)XY[dx * 2 + 1] < height1 + : !prevInlier; + if (curInlier == prevInlier) + continue; + + int X1 = dx; + dx = X0; + X0 = X1; + prevInlier = curInlier; + + if (!curInlier) { + int len = vecOp(_src, D, XY + dx * 2, FXY + dx, wtab, X1 - dx); + D += len * CH; + dx += len; + + if (CH == 1) { + MIDOUT_BEGIN(remapBilinear_bmode, 0, 1) { + for (; dx < X1; dx++, D++) { + int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; + const AT* w = wtab + FXY[dx] * 4; + const T* S = S0 + sy * sstep + sx; + *D = castOp( + WT(S[0] * w[0] + S[1] * w[1] + S[sstep] * w[2] + + S[sstep + 1] * w[3])); + } } - } - MIDOUT_END(); - } else if (CH == 2) { - MIDOUT_BEGIN(remapBilinear_bmode, 0, 2) { - for (; dx < X1; dx++, D += 2) { - int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; - const AT* w = wtab + FXY[dx] * 4; - const T* S = S0 + sy * sstep + sx * 2; - WT t0 = S[0] * w[0] + S[2] * w[1] + S[sstep] * w[2] + - S[sstep + 2] * w[3]; - WT t1 = S[1] * w[0] + S[3] * w[1] + - S[sstep + 1] * w[2] + S[sstep + 3] * w[3]; - D[0] = castOp(t0); - D[1] = castOp(t1); + MIDOUT_END(); + } else if (CH == 2) { + MIDOUT_BEGIN(remapBilinear_bmode, 0, 2) { + for (; dx < X1; dx++, D += 2) { + int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; + const AT* w = wtab + FXY[dx] * 4; + const T* S = S0 + sy * sstep + sx * 2; + WT t0 = S[0] * w[0] + S[2] * w[1] + S[sstep] * w[2] + + S[sstep + 2] * w[3]; + WT t1 = S[1] * w[0] + S[3] * w[1] + + S[sstep + 1] * w[2] + S[sstep + 3] * w[3]; + D[0] = castOp(t0); + D[1] = castOp(t1); + } } - } - MIDOUT_END(); - } else if (CH == 3) - MIDOUT_BEGIN(remapBilinear_bmode, 0, 3) { - for (; dx < X1; dx++, D += 3) { - int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; - const AT* w = wtab + FXY[dx] * 4; - const T* S = S0 + sy * sstep + sx * 3; - WT t0 = S[0] * w[0] + S[3] * w[1] + S[sstep] * w[2] + - S[sstep + 3] * w[3]; - WT t1 = S[1] * w[0] + S[4] * w[1] + - S[sstep + 1] * w[2] + S[sstep + 4] * w[3]; - WT t2 = S[2] * w[0] + S[5] * w[1] + - S[sstep + 2] * w[2] + S[sstep + 5] * w[3]; - D[0] = castOp(t0); - D[1] = castOp(t1); - D[2] = castOp(t2); + MIDOUT_END(); + } else if (CH == 3) + MIDOUT_BEGIN(remapBilinear_bmode, 0, 3) { + for (; dx < X1; dx++, D += 3) { + int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; + const AT* w = wtab + FXY[dx] * 4; + const T* S = S0 + sy * sstep + sx * 3; + WT t0 = S[0] * w[0] + S[3] 
* w[1] + S[sstep] * w[2] + + S[sstep + 3] * w[3]; + WT t1 = S[1] * w[0] + S[4] * w[1] + + S[sstep + 1] * w[2] + S[sstep + 4] * w[3]; + WT t2 = S[2] * w[0] + S[5] * w[1] + + S[sstep + 2] * w[2] + S[sstep + 5] * w[3]; + D[0] = castOp(t0); + D[1] = castOp(t1); + D[2] = castOp(t2); + } } - } MIDOUT_END(); - else - megdnn_throw("nr. of channels must be 1/2/3."); + else megdnn_throw("nr. of channels must be 1/2/3."); - } else { - if (bmode == BMode::BORDER_TRANSPARENT && CH != 3) { - megdnn_throw( - "unsupported Linear InterpolationMode" - " with BORDER_TRANSPARENT and channel size 1"); - continue; - } - if (CH == 1) { - MIDOUT_BEGIN(remapBilinear_bmode, 1, 1) { - for (; dx < X1; dx++, D++) { + } else { + if (bmode == BMode::BORDER_TRANSPARENT && CH != 3) { + megdnn_throw( + "unsupported Linear InterpolationMode" + " with BORDER_TRANSPARENT and channel size 1"); + continue; + } + if (CH == 1) { + MIDOUT_BEGIN(remapBilinear_bmode, 1, 1) { + for (; dx < X1; dx++, D++) { + int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; + if (bmode == BMode::BORDER_CONSTANT && + (sx >= swidth || sx + 1 < 0 || sy >= sheight || + sy + 1 < 0)) { + D[0] = bvalue[0]; + } else { + int sx0, sx1, sy0, sy1; + T v0, v1, v2, v3; + const AT* w = wtab + FXY[dx] * 4; + if (bmode == BMode::BORDER_REPLICATE) { + sx0 = saturate(sx, 0, swidth); + sx1 = saturate(sx + 1, 0, swidth); + sy0 = saturate(sy, 0, sheight); + sy1 = saturate(sy + 1, 0, sheight); + v0 = S0[sy0 * sstep + sx0]; + v1 = S0[sy0 * sstep + sx1]; + v2 = S0[sy1 * sstep + sx0]; + v3 = S0[sy1 * sstep + sx1]; + } else { + sx0 = border_interpolate(sx, swidth); + sx1 = border_interpolate(sx + 1, swidth); + sy0 = border_interpolate(sy, sheight); + sy1 = border_interpolate( + sy + 1, sheight); + v0 = sx0 >= 0 && sy0 >= 0 + ? S0[sy0 * sstep + sx0] + : bvalue[0]; + v1 = sx1 >= 0 && sy0 >= 0 + ? S0[sy0 * sstep + sx1] + : bvalue[0]; + v2 = sx0 >= 0 && sy1 >= 0 + ? S0[sy1 * sstep + sx0] + : bvalue[0]; + v3 = sx1 >= 0 && sy1 >= 0 + ? S0[sy1 * sstep + sx1] + : bvalue[0]; + } + D[0] = castOp( + WT(v0 * w[0] + v1 * w[1] + v2 * w[2] + + v3 * w[3])); + } + } + } + MIDOUT_END(); + } else { + for (; dx < X1; dx++, D += CH) { int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; if (bmode == BMode::BORDER_CONSTANT && (sx >= swidth || sx + 1 < 0 || sy >= sheight || sy + 1 < 0)) { - D[0] = bvalue[0]; + for (size_t k = 0; k < CH; k++) + D[k] = bvalue[k]; } else { int sx0, sx1, sy0, sy1; - T v0, v1, v2, v3; + const T *v0, *v1, *v2, *v3; const AT* w = wtab + FXY[dx] * 4; if (bmode == BMode::BORDER_REPLICATE) { sx0 = saturate(sx, 0, swidth); sx1 = saturate(sx + 1, 0, swidth); sy0 = saturate(sy, 0, sheight); sy1 = saturate(sy + 1, 0, sheight); - v0 = S0[sy0 * sstep + sx0]; - v1 = S0[sy0 * sstep + sx1]; - v2 = S0[sy1 * sstep + sx0]; - v3 = S0[sy1 * sstep + sx1]; - } else { + v0 = S0 + sy0 * sstep + sx0 * CH; + v1 = S0 + sy0 * sstep + sx1 * CH; + v2 = S0 + sy1 * sstep + sx0 * CH; + v3 = S0 + sy1 * sstep + sx1 * CH; + } else if ( + bmode == BMode::BORDER_TRANSPARENT && + ((unsigned)sx >= (unsigned)(swidth - 1) || + (unsigned)sy >= (unsigned)(sheight - 1))) + continue; + else { sx0 = border_interpolate(sx, swidth); sx1 = border_interpolate(sx + 1, swidth); sy0 = border_interpolate(sy, sheight); - sy1 = border_interpolate(sy + 1, - sheight); + sy1 = border_interpolate(sy + 1, sheight); v0 = sx0 >= 0 && sy0 >= 0 - ? S0[sy0 * sstep + sx0] - : bvalue[0]; + ? S0 + sy0 * sstep + sx0 * CH + : &bvalue[0]; v1 = sx1 >= 0 && sy0 >= 0 - ? S0[sy0 * sstep + sx1] - : bvalue[0]; + ? 
S0 + sy0 * sstep + sx1 * CH + : &bvalue[0]; v2 = sx0 >= 0 && sy1 >= 0 - ? S0[sy1 * sstep + sx0] - : bvalue[0]; + ? S0 + sy1 * sstep + sx0 * CH + : &bvalue[0]; v3 = sx1 >= 0 && sy1 >= 0 - ? S0[sy1 * sstep + sx1] - : bvalue[0]; + ? S0 + sy1 * sstep + sx1 * CH + : &bvalue[0]; } - D[0] = castOp(WT(v0 * w[0] + v1 * w[1] + v2 * w[2] + - v3 * w[3])); - } - } - } - MIDOUT_END(); - } else { - for (; dx < X1; dx++, D += CH) { - int sx = XY[dx * 2], sy = XY[dx * 2 + 1]; - if (bmode == BMode::BORDER_CONSTANT && - (sx >= swidth || sx + 1 < 0 || sy >= sheight || - sy + 1 < 0)) { - for (size_t k = 0; k < CH; k++) - D[k] = bvalue[k]; - } else { - int sx0, sx1, sy0, sy1; - const T *v0, *v1, *v2, *v3; - const AT* w = wtab + FXY[dx] * 4; - if (bmode == BMode::BORDER_REPLICATE) { - sx0 = saturate(sx, 0, swidth); - sx1 = saturate(sx + 1, 0, swidth); - sy0 = saturate(sy, 0, sheight); - sy1 = saturate(sy + 1, 0, sheight); - v0 = S0 + sy0 * sstep + sx0 * CH; - v1 = S0 + sy0 * sstep + sx1 * CH; - v2 = S0 + sy1 * sstep + sx0 * CH; - v3 = S0 + sy1 * sstep + sx1 * CH; - } else if (bmode == BMode::BORDER_TRANSPARENT && - ((unsigned)sx >= - (unsigned)(swidth - 1) || - (unsigned)sy >= - (unsigned)(sheight - 1))) - continue; - else { - sx0 = border_interpolate(sx, swidth); - sx1 = border_interpolate(sx + 1, swidth); - sy0 = border_interpolate(sy, sheight); - sy1 = border_interpolate(sy + 1, - sheight); - v0 = sx0 >= 0 && sy0 >= 0 - ? S0 + sy0 * sstep + sx0 * CH - : &bvalue[0]; - v1 = sx1 >= 0 && sy0 >= 0 - ? S0 + sy0 * sstep + sx1 * CH - : &bvalue[0]; - v2 = sx0 >= 0 && sy1 >= 0 - ? S0 + sy1 * sstep + sx0 * CH - : &bvalue[0]; - v3 = sx1 >= 0 && sy1 >= 0 - ? S0 + sy1 * sstep + sx1 * CH - : &bvalue[0]; - } - for (size_t k = 0; k < CH; k++) { - D[k] = castOp(WT(v0[k] * w[0] + v1[k] * w[1] + - v2[k] * w[2] + v3[k] * w[3])); + for (size_t k = 0; k < CH; k++) { + D[k] = castOp( + WT(v0[k] * w[0] + v1[k] * w[1] + + v2[k] * w[2] + v3[k] * w[3])); + } } } } @@ -641,15 +621,13 @@ static void remapBilinear(const Mat& _src, Mat& _dst, } } } - } MIDOUT_END(); } -template -static void remapLanczos4(const Mat& _src, Mat& _dst, - const Mat& _xy, const Mat& _fxy, - const void* _wtab, const T* bvalue) { +template +static void remapLanczos4( + const Mat& _src, Mat& _dst, const Mat& _xy, + const Mat& _fxy, const void* _wtab, const T* bvalue) { typedef typename CastOp::type1 WT; const AT* wtab = (const AT*)_wtab; const T* S0 = _src.ptr(); @@ -658,8 +636,7 @@ static void remapLanczos4(const Mat& _src, Mat& _dst, CastOp castOp; int swidth = _src.width(), sheight = _src.height(); int dwidth = _dst.width(), dheight = _dst.height(); - unsigned width1 = std::max(swidth - 7, 0), - height1 = std::max(sheight - 7, 0); + unsigned width1 = std::max(swidth - 7, 0), height1 = std::max(sheight - 7, 0); if (_dst.is_continuous() && _xy.is_continuous() && _fxy.is_continuous()) { dwidth *= dheight; dheight = 1; @@ -678,9 +655,8 @@ static void remapLanczos4(const Mat& _src, Mat& _dst, WT sum = 0; for (int r = 0; r < 8; r++, S += sstep, w += 8) sum += S[0] * w[0] + S[CH] * w[1] + S[CH * 2] * w[2] + - S[CH * 3] * w[3] + S[CH * 4] * w[4] + - S[CH * 5] * w[5] + S[CH * 6] * w[6] + - S[CH * 7] * w[7]; + S[CH * 3] * w[3] + S[CH * 4] * w[4] + S[CH * 5] * w[5] + + S[CH * 6] * w[6] + S[CH * 7] * w[7]; w -= 64; S -= sstep * 8 - 1; D[k] = castOp(sum); @@ -692,8 +668,7 @@ static void remapLanczos4(const Mat& _src, Mat& _dst, (unsigned)(sy + 3) >= (unsigned)sheight)) continue; if (bmode == BMode::BORDER_CONSTANT && - (sx >= swidth || sx + 8 <= 0 || sy >= sheight || - sy + 8 
<= 0)) { + (sx >= swidth || sx + 8 <= 0 || sy >= sheight || sy + 8 <= 0)) { for (size_t i = 0; i < CH; i++) { D[i] = bvalue[i]; } @@ -735,15 +710,15 @@ static void remapLanczos4(const Mat& _src, Mat& _dst, } } -template +template < + typename T, InterpolationMode imode, BorderMode bmode, size_t CH, + typename RemapVec> struct RemapFuncHolder; -template +template struct RemapFuncHolder { - static void get_funcs(RemapNNFunc& nnfunc, - RemapFunc& ifunc) { + static void get_funcs( + RemapNNFunc& nnfunc, RemapFunc& ifunc) { switch (imode) { case IMode::INTER_NEAREST: MIDOUT_BEGIN(megdnn_warp, midout_iv(0)) { @@ -754,24 +729,24 @@ struct RemapFuncHolder { case IMode::INTER_LINEAR: MIDOUT_BEGIN(megdnn_warp, midout_iv(1)) { ifunc = remapBilinear< - FixedPtCast, - RemapVec, short, uchar, bmode, CH>; + FixedPtCast, RemapVec, + short, uchar, bmode, CH>; } MIDOUT_END(); break; case IMode::INTER_CUBIC: MIDOUT_BEGIN(megdnn_warp, midout_iv(2)) { ifunc = remapBicubic< - FixedPtCast, - short, INTER_REMAP_COEF_SCALE, uchar, bmode, CH>; + FixedPtCast, short, + INTER_REMAP_COEF_SCALE, uchar, bmode, CH>; } MIDOUT_END(); break; case IMode::INTER_LANCZOS4: MIDOUT_BEGIN(megdnn_warp, midout_iv(3)) { ifunc = remapLanczos4< - FixedPtCast, - short, INTER_REMAP_COEF_SCALE, uchar, bmode, CH>; + FixedPtCast, short, + INTER_REMAP_COEF_SCALE, uchar, bmode, CH>; } MIDOUT_END(); break; @@ -781,11 +756,10 @@ struct RemapFuncHolder { } }; -template +template struct RemapFuncHolder { - static void get_funcs(RemapNNFunc& nnfunc, - RemapFunc& ifunc) { + static void get_funcs( + RemapNNFunc& nnfunc, RemapFunc& ifunc) { switch (imode) { case IMode::INTER_NEAREST: MIDOUT_BEGIN(megdnn_warp, midout_iv(0)) { @@ -795,22 +769,22 @@ struct RemapFuncHolder { break; case IMode::INTER_LINEAR: MIDOUT_BEGIN(megdnn_warp, midout_iv(1)) { - ifunc = remapBilinear, RemapVec, float, - float, bmode, CH>; + ifunc = remapBilinear< + Cast, RemapVec, float, float, bmode, CH>; } MIDOUT_END(); break; case IMode::INTER_CUBIC: MIDOUT_BEGIN(megdnn_warp, midout_iv(2)) { - ifunc = remapBicubic, float, 1, float, - bmode, CH>; + ifunc = remapBicubic< + Cast, float, 1, float, bmode, CH>; } MIDOUT_END(); break; case IMode::INTER_LANCZOS4: MIDOUT_BEGIN(megdnn_warp, midout_iv(3)) { - ifunc = remapLanczos4, float, 1, float, - bmode, CH>; + ifunc = remapLanczos4< + Cast, float, 1, float, bmode, CH>; } MIDOUT_END(); break; @@ -820,13 +794,15 @@ struct RemapFuncHolder { } }; -template +template < + typename T, InterpolationMode imode, BorderMode bmode, size_t CH, + typename RemapVec> #if MEGDNN_X86 MEGDNN_ATTRIBUTE_TARGET("sse3") #endif -void remap(const Mat& src, Mat& dst, Mat& map1, Mat& map2, - const T* bvalue) { +void remap( + const Mat& src, Mat& dst, Mat& map1, Mat& map2, + const T* bvalue) { RemapNNFunc nnfunc = 0; RemapFunc ifunc = 0; bool fixpt = std::is_same::value; @@ -864,8 +840,7 @@ void remap(const Mat& src, Mat& dst, Mat& map1, Mat& map2, x1 = 0; #if MEGDNN_X86 __m128i sA_data, d_data; - __m128i v_INTER_TAB_SIZE2 = - _mm_set1_epi16(INTER_TAB_SIZE2 - 1); + __m128i v_INTER_TAB_SIZE2 = _mm_set1_epi16(INTER_TAB_SIZE2 - 1); for (; x1 <= bcols - 8; x1 += 8) { __m128i const* src = (__m128i const*)(sA + x1); @@ -878,8 +853,7 @@ void remap(const Mat& src, Mat& dst, Mat& map1, Mat& map2, #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 uint16x8_t v_scale = vdupq_n_u16(INTER_TAB_SIZE2 - 1); for (; x1 <= bcols - 8; x1 += 8) - vst1q_u16(A + x1, - vandq_u16(vld1q_u16(sA + x1), v_scale)); + vst1q_u16(A + x1, vandq_u16(vld1q_u16(sA + x1), v_scale)); #endif for (; x1 < bcols; ++x1) 
A[x1] = (ushort)(sA[x1] & (INTER_TAB_SIZE2 - 1)); @@ -889,24 +863,23 @@ void remap(const Mat& src, Mat& dst, Mat& map1, Mat& map2, } } -#define DISPATCH_CHANNEL(_imode, _bmode, _ch, _cb) \ - switch (_ch) { \ - case 1: { \ - _cb(_imode, _bmode, 1); \ - break; \ - } \ - case 2: { \ - _cb(_imode, _bmode, 2); \ - break; \ - } \ - case 3: { \ - _cb(_imode, _bmode, 3); \ - break; \ - } \ - default: { \ - megdnn_assert(0, "unsupport channels: %zu, only supprt 1/2/3", \ - _ch); \ - } \ +#define DISPATCH_CHANNEL(_imode, _bmode, _ch, _cb) \ + switch (_ch) { \ + case 1: { \ + _cb(_imode, _bmode, 1); \ + break; \ + } \ + case 2: { \ + _cb(_imode, _bmode, 2); \ + break; \ + } \ + case 3: { \ + _cb(_imode, _bmode, 3); \ + break; \ + } \ + default: { \ + megdnn_assert(0, "unsupport channels: %zu, only supprt 1/2/3", _ch); \ + } \ } #define DISPATCH_BMODE(_imode, _bmode, _ch, _cb) \ @@ -931,32 +904,36 @@ void remap(const Mat& src, Mat& dst, Mat& map1, Mat& map2, DISPATCH_CHANNEL(_imode, BorderMode::CONSTANT, _ch, _cb); \ break; \ } \ - default: { megdnn_assert(0, "unsupport border mode for cv"); } \ + default: { \ + megdnn_assert(0, "unsupport border mode for cv"); \ + } \ } -#define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \ - switch (_imode) { \ - case InterpolationMode::NEAREST: { \ - DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \ - break; \ - } \ - case InterpolationMode::LINEAR: { \ - DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \ - break; \ - } \ - case InterpolationMode::AREA: { \ - DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \ - break; \ - } \ - case InterpolationMode::CUBIC: { \ - DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \ - break; \ - } \ - case InterpolationMode::LANCZOS4: { \ - DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \ - break; \ - } \ - default: { megdnn_assert(0, "unsupport interpolation mode for cv"); } \ +#define DISPATCH_IMODE(_imode, _bmode, _ch, _cb) \ + switch (_imode) { \ + case InterpolationMode::NEAREST: { \ + DISPATCH_BMODE(InterpolationMode::NEAREST, _bmode, _ch, _cb); \ + break; \ + } \ + case InterpolationMode::LINEAR: { \ + DISPATCH_BMODE(InterpolationMode::LINEAR, _bmode, _ch, _cb); \ + break; \ + } \ + case InterpolationMode::AREA: { \ + DISPATCH_BMODE(InterpolationMode::AREA, _bmode, _ch, _cb); \ + break; \ + } \ + case InterpolationMode::CUBIC: { \ + DISPATCH_BMODE(InterpolationMode::CUBIC, _bmode, _ch, _cb); \ + break; \ + } \ + case InterpolationMode::LANCZOS4: { \ + DISPATCH_BMODE(InterpolationMode::LANCZOS4, _bmode, _ch, _cb); \ + break; \ + } \ + default: { \ + megdnn_assert(0, "unsupport interpolation mode for cv"); \ + } \ } } // namespace warp diff --git a/dnn/src/common/warp_perspective.cpp b/dnn/src/common/warp_perspective.cpp index b247c943..9ebb9ed9 100644 --- a/dnn/src/common/warp_perspective.cpp +++ b/dnn/src/common/warp_perspective.cpp @@ -15,17 +15,16 @@ namespace megdnn { -void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, - const TensorLayout& mat, - const TensorLayout& mat_idx, - const TensorLayout& dst) { +void WarpPerspectiveBase::check_layout_fwd( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, + const TensorLayout& dst) { megdnn_assert_contiguous(mat); megdnn_assert_contiguous(src); megdnn_assert_contiguous(dst); auto errmsg = [&]() { return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(mat) + ", " + - megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) + - ", " + param_msg(); + 
megdnn_layout_msg(mat_idx) + ", " + megdnn_layout_msg(dst) + ", " + + param_msg(); }; MEGDNN_MARK_USED_VAR(errmsg); if (param().format == param::WarpPerspective::Format::NHWCD4 || @@ -34,25 +33,25 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); - } else if (param().format == - param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL || - param().format == - param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) { + } else if ( + param().format == param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL || + param().format == param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) { megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); } else { - megdnn_assert(param().format == param::WarpPerspective::Format::NHWC || - param().format == param::WarpPerspective::Format::NCHW || - param().format == - param::WarpPerspective::Format::NHWC_NCHW); + megdnn_assert( + param().format == param::WarpPerspective::Format::NHWC || + param().format == param::WarpPerspective::Format::NCHW || + param().format == param::WarpPerspective::Format::NHWC_NCHW); megdnn_assert(src.ndim == 4_z, "%s", errmsg().c_str()); megdnn_assert(dst.ndim == 4_z, "%s", errmsg().c_str()); } megdnn_assert(mat.ndim == 3_z, "%s", errmsg().c_str()); megdnn_assert(dst.shape[0] == mat.shape[0], "%s", errmsg().c_str()); if (mat_idx.ndim) { - megdnn_assert(mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1, - "%s", errmsg().c_str()); + megdnn_assert( + mat_idx.dtype == dtype::Int32() && mat_idx.ndim == 1, "%s", + errmsg().c_str()); megdnn_assert(mat.shape[0] == mat_idx.shape[0], "%s", errmsg().c_str()); megdnn_assert_contiguous(mat_idx); } else { @@ -83,8 +82,7 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, (src.dtype == mat.dtype || mat.dtype.enumv() == DTypeEnum::Float32)) || ((src.dtype.category() == DTypeCategory::INT || - src.dtype.category() == - DTypeCategory::QUANTIZED) && + src.dtype.category() == DTypeCategory::QUANTIZED) && mat.dtype.enumv() == DTypeEnum::Float32), "The input to WarpPerspective is in NCHW format, in this " "case, if the input dtype is floating point, the " @@ -94,61 +92,61 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::ISOLATED); } else if (param().format == param::WarpPerspective::Format::NHWC) { megdnn_assert(src.shape[3] == dst.shape[3], "%s", errmsg().c_str()); } else if (param().format == param::WarpPerspective::Format::NCHW4) { - megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8, - "src expected QuantizedS8, but got %s", - src.dtype.name()); - megdnn_assert(mat.dtype == dtype::Float32(), - "matrix dtype expected float, got %s", - mat.dtype.name()); + megdnn_assert( + src.dtype.enumv() == DTypeEnum::QuantizedS8, + "src expected QuantizedS8, but got %s", src.dtype.name()); + megdnn_assert( + mat.dtype == 
dtype::Float32(), + "matrix dtype expected float, got %s", mat.dtype.name()); megdnn_assert(src.shape[4] == 4 && dst.shape[4] == 4); megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::ISOLATED); } else if (param().format == param::WarpPerspective::Format::NCHW64) { - megdnn_assert((src.dtype.enumv() == DTypeEnum::QuantizedS4 || - src.dtype.enumv() == DTypeEnum::Quantized4Asymm), - "src expected QuantizedS4/Quantized4Asymm, but got %s", - src.dtype.name()); - megdnn_assert(mat.dtype == dtype::Float32(), - "matrix dtype expected float, got %s", - mat.dtype.name()); + megdnn_assert( + (src.dtype.enumv() == DTypeEnum::QuantizedS4 || + src.dtype.enumv() == DTypeEnum::Quantized4Asymm), + "src expected QuantizedS4/Quantized4Asymm, but got %s", + src.dtype.name()); + megdnn_assert( + mat.dtype == dtype::Float32(), + "matrix dtype expected float, got %s", mat.dtype.name()); megdnn_assert(src.shape[4] == 64 && dst.shape[4] == 64); megdnn_assert(src.shape[1] == dst.shape[1], "%s", errmsg().c_str()); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::ISOLATED); } else { - megdnn_assert(param().format == - param::WarpPerspective::Format::NHWCD4); + megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4); megdnn_assert( src.dtype == dtype::Float32() || - DNN_FLOAT16_SELECT((src.dtype == dtype::Float16() || - src.dtype == dtype::BFloat16()), - false) || + DNN_FLOAT16_SELECT( + (src.dtype == dtype::Float16() || + src.dtype == dtype::BFloat16()), + false) || src.dtype.enumv() == DTypeEnum::QuantizedS8 || src.dtype.enumv() == DTypeEnum::Quantized8Asymm, "WarpPerspective NHWCD4 input dtype should be " "Float32" DNN_FLOAT16_SELECT( - "/Float16/BFloat16", - "") ",QunatizedS8, Quantized8Asymm."); + "/Float16/BFloat16", "") ",QunatizedS8, Quantized8Asymm."); megdnn_assert( (src.dtype == mat.dtype || mat.dtype == dtype::Float32()), "The input to WarpPerspective is in NHWCD4 format, in this " @@ -158,51 +156,49 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout& src, mat.dtype.name()); //! 
number of channels is same megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str()); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert( + param().bmode != param::WarpPerspective::BorderMode::ISOLATED); } - } else if (param().format == - param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL || - param().format == - param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) { - megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm || - src.dtype.enumv() == DTypeEnum::Uint8), - "src expected Quantized8Asymm or Uint8, but got %s", - src.dtype.name()); - megdnn_assert(mat.dtype == dtype::Float32(), - "matrix dtype expected float, got %s", mat.dtype.name()); + } else if ( + param().format == param::WarpPerspective::Format::NHWC_NCHW4_IC_SMALL || + param().format == param::WarpPerspective::Format::NCHW_NCHW4_IC_SMALL) { + megdnn_assert( + (src.dtype.enumv() == DTypeEnum::Quantized8Asymm || + src.dtype.enumv() == DTypeEnum::Uint8), + "src expected Quantized8Asymm or Uint8, but got %s", src.dtype.name()); + megdnn_assert( + mat.dtype == dtype::Float32(), "matrix dtype expected float, got %s", + mat.dtype.name()); megdnn_assert(dst.shape[4] == 4); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::ISOLATED); } else if (param().format == param::WarpPerspective::Format::NHWC_NCHW) { - megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm || - src.dtype.enumv() == DTypeEnum::Uint8), - "src expected Quantized8Asymm or Uint8, but got %s", - src.dtype.name()); - megdnn_assert(mat.dtype == dtype::Float32(), - "matrix dtype expected float, got %s", mat.dtype.name()); + megdnn_assert( + (src.dtype.enumv() == DTypeEnum::Quantized8Asymm || + src.dtype.enumv() == DTypeEnum::Uint8), + "src expected Quantized8Asymm or Uint8, but got %s", src.dtype.name()); + megdnn_assert( + mat.dtype == dtype::Float32(), "matrix dtype expected float, got %s", + mat.dtype.name()); megdnn_assert(src.shape[3] == dst.shape[1], "%s", errmsg().c_str()); - megdnn_assert(param().imode == - param::WarpPerspective::InterpolationMode::LINEAR); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::TRANSPARENT); - megdnn_assert(param().bmode != - param::WarpPerspective::BorderMode::ISOLATED); + megdnn_assert( + param().imode == param::WarpPerspective::InterpolationMode::LINEAR); + megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::TRANSPARENT); + megdnn_assert(param().bmode != param::WarpPerspective::BorderMode::ISOLATED); } else { megdnn_assert(param().format == param::WarpPerspective::Format::NCHW); - megdnn_assert((src.dtype.enumv() == DTypeEnum::Quantized8Asymm || - src.dtype.enumv() == DTypeEnum::Uint8) && - dst.dtype.enumv() == DTypeEnum::Float32); + 
megdnn_assert( + (src.dtype.enumv() == DTypeEnum::Quantized8Asymm || + src.dtype.enumv() == DTypeEnum::Uint8) && + dst.dtype.enumv() == DTypeEnum::Float32); } } @@ -287,21 +283,17 @@ int WarpPerspectiveBase::get_real_coord(int p, int len) { return p; } -void WarpPerspectiveForward::check_exec(const TensorLayout& src, - const TensorLayout& mat, - const TensorLayout& mat_idx, - const TensorLayout& dst, - size_t workspace_in_bytes) { +void WarpPerspectiveForward::check_exec( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, + const TensorLayout& dst, size_t workspace_in_bytes) { check_exec_allow_nhwc_mat_idx(src, mat, mat_idx, dst, workspace_in_bytes); } void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( - const TensorLayout& src, const TensorLayout& mat, - const TensorLayout& mat_idx, const TensorLayout& dst, - size_t workspace_in_bytes) { + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, + const TensorLayout& dst, size_t workspace_in_bytes) { check_layout_fwd(src, mat, mat_idx, dst); - auto required_workspace_in_bytes = - get_workspace_in_bytes(src, mat, mat_idx, dst); + auto required_workspace_in_bytes = get_workspace_in_bytes(src, mat, mat_idx, dst); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); if (param().format != Param::Format::NHWC && param().format != Param::Format::NCHW && @@ -310,36 +302,31 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( param().format != Param::Format::NHWC_NCHW4_IC_SMALL && param().format != Param::Format::NCHW_NCHW4_IC_SMALL && param().format != Param::Format::NCHW64) { - megdnn_assert(!mat_idx.ndim, - "mat_idx not supported for current format"); + megdnn_assert(!mat_idx.ndim, "mat_idx not supported for current format"); } } -void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat, - const TensorLayout& mat_idx, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void WarpPerspectiveBackwardData::check_exec( + const TensorLayout& mat, const TensorLayout& mat_idx, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(grad, mat, mat_idx, diff); - megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16( - || grad.dtype == dtype::BFloat16()), - "Backward WarpPerspective only supports Float32/BFloat16."); - auto required_workspace_in_bytes = - get_workspace_in_bytes(mat, mat_idx, diff, grad); + megdnn_assert( + grad.dtype == dtype::Float32() + DNN_INC_FLOAT16(|| grad.dtype == dtype::BFloat16()), + "Backward WarpPerspective only supports Float32/BFloat16."); + auto required_workspace_in_bytes = get_workspace_in_bytes(mat, mat_idx, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src, - const TensorLayout& mat, - const TensorLayout& mat_idx, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_in_bytes) { +void WarpPerspectiveBackwardMat::check_exec( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, + const TensorLayout& diff, const TensorLayout& grad, size_t workspace_in_bytes) { check_layout_fwd(src, mat, mat_idx, diff); megdnn_assert_eq_layout(mat, grad); - megdnn_assert(grad.dtype == dtype::Float32() DNN_INC_FLOAT16( - || grad.dtype == dtype::BFloat16()), - "Backward WarpPerspective only supports Float32/BFloat16."); + megdnn_assert( + grad.dtype == dtype::Float32() + DNN_INC_FLOAT16(|| grad.dtype == 
dtype::BFloat16()), + "Backward WarpPerspective only supports Float32/BFloat16."); auto required_workspace_in_bytes = get_workspace_in_bytes(src, mat, mat_idx, diff, grad); megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); diff --git a/dnn/src/common/warp_perspective_helper.cpp b/dnn/src/common/warp_perspective_helper.cpp index 658ac2d6..6b1b379a 100644 --- a/dnn/src/common/warp_perspective_helper.cpp +++ b/dnn/src/common/warp_perspective_helper.cpp @@ -12,13 +12,11 @@ #include "./warp_perspective_helper.h" using namespace megdnn; -bool warp_perspective::is_cv_available(const TensorLayout& src, - const TensorLayout& /*mat*/, - const TensorLayout& mat_idx, - const TensorLayout& /*dst*/, - Param param) { - return param.format == Param::Format::NHWC && - (src[3] == 1 || src[3] == 3) && !mat_idx.ndim && +bool warp_perspective::is_cv_available( + const TensorLayout& src, const TensorLayout& /*mat*/, + const TensorLayout& mat_idx, const TensorLayout& /*dst*/, Param param) { + return param.format == Param::Format::NHWC && (src[3] == 1 || src[3] == 3) && + !mat_idx.ndim && (src.dtype == dtype::Float32() || src.dtype == dtype::Uint8()) && (param.imode == Param::InterpolationMode::NEAREST || param.imode == Param::InterpolationMode::LINEAR || @@ -26,11 +24,9 @@ bool warp_perspective::is_cv_available(const TensorLayout& src, param.imode == Param::InterpolationMode::LANCZOS4); } -bool warp_perspective::is_dnn_available(const TensorLayout& /*src*/, - const TensorLayout& /*mat*/, - const TensorLayout& /*mat_idx*/, - const TensorLayout& /*dst*/, - Param param) { +bool warp_perspective::is_dnn_available( + const TensorLayout& /*src*/, const TensorLayout& /*mat*/, + const TensorLayout& /*mat_idx*/, const TensorLayout& /*dst*/, Param param) { return param.imode == Param::InterpolationMode::LINEAR; } diff --git a/dnn/src/common/warp_perspective_helper.h b/dnn/src/common/warp_perspective_helper.h index 47678be7..4cb75856 100644 --- a/dnn/src/common/warp_perspective_helper.h +++ b/dnn/src/common/warp_perspective_helper.h @@ -15,11 +15,12 @@ namespace megdnn { namespace warp_perspective { using Param = param::WarpPerspective; -bool is_cv_available(const TensorLayout& src, const TensorLayout& mat, - const TensorLayout& mat_idx, const TensorLayout& dst, - Param param); -bool is_dnn_available(const TensorLayout&, const TensorLayout&, - const TensorLayout&, const TensorLayout&, Param param); +bool is_cv_available( + const TensorLayout& src, const TensorLayout& mat, const TensorLayout& mat_idx, + const TensorLayout& dst, Param param); +bool is_dnn_available( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, Param param); } // namespace warp_perspective } // namespace megdnn diff --git a/dnn/src/common/winograd/winograd_generator.cpp b/dnn/src/common/winograd/winograd_generator.cpp index 3db76e2d..136462cc 100644 --- a/dnn/src/common/winograd/winograd_generator.cpp +++ b/dnn/src/common/winograd/winograd_generator.cpp @@ -32,10 +32,10 @@ */ #include "src/common/winograd/winograd_generator.h" -#include "src/common/utils.h" #include #include #include +#include "src/common/utils.h" using namespace megdnn; using namespace winograd; @@ -120,8 +120,9 @@ WinogradGenerator::Matrix computeB(const std::vector& a, int alpha) { WinogradGenerator::Matrix B(alpha, alpha); for (int y = 0; y < alpha - 1; ++y) { - std::memcpy(B.data() + B.cols() * y, BT.data() + BT.cols() * y, - alpha * sizeof(float)); + std::memcpy( + B.data() + B.cols() * y, BT.data() + BT.cols() * y, + alpha * 
sizeof(float)); } for (int x = 0; x < alpha - 1; ++x) { B.at(alpha - 1, x) = 0; @@ -131,8 +132,7 @@ WinogradGenerator::Matrix computeB(const std::vector& a, int alpha) { return B; } -WinogradGenerator::Matrix computeFPlusOne(const std::vector& a, - int alpha) { +WinogradGenerator::Matrix computeFPlusOne(const std::vector& a, int alpha) { auto fdiag = computeF(a, alpha - 1); WinogradGenerator::Matrix res(1, alpha); for (int i = 0; i < alpha - 1; i++) { @@ -190,8 +190,7 @@ WinogradGenerator::Matrix WinogradGenerator::Matrix::mul(const Matrix& rhs) { return res; } -WinogradGenerator::Matrix WinogradGenerator::Matrix::poly_multi( - const Matrix& B) { +WinogradGenerator::Matrix WinogradGenerator::Matrix::poly_multi(const Matrix& B) { megdnn_assert(rows() == 1 && B.rows() == 1); auto aw = cols(); auto bw = B.cols(); @@ -211,8 +210,7 @@ WinogradGenerator::Matrix WinogradGenerator::Matrix::poly_multi( return res; } -void WinogradGenerator::Matrix::div_per_line( - const WinogradGenerator::Matrix& line) { +void WinogradGenerator::Matrix::div_per_line(const WinogradGenerator::Matrix& line) { megdnn_assert(line.rows() == 1 && line.cols() >= m_rows); for (size_t y = 0; y < m_rows; ++y) { @@ -222,8 +220,7 @@ void WinogradGenerator::Matrix::div_per_line( } } -void WinogradGenerator::Matrix::mul_per_row( - const WinogradGenerator::Matrix& line) { +void WinogradGenerator::Matrix::mul_per_row(const WinogradGenerator::Matrix& line) { megdnn_assert(line.rows() == 1 && line.cols() >= m_cols); for (size_t y = 0; y < m_rows; ++y) { for (size_t x = 0; x < m_cols; ++x) { @@ -232,8 +229,6 @@ void WinogradGenerator::Matrix::mul_per_row( } } - - WinogradGenerator::WinogradGenerator(size_t m, size_t r, float interp) { size_t alpha = m + r - 1; @@ -249,17 +244,18 @@ WinogradGenerator::WinogradGenerator(size_t m, size_t r, float interp) { generate(m, r, a); } -WinogradGenerator::WinogradGenerator(size_t m, size_t r, - const std::vector& interp_points) { - megdnn_assert(interp_points.size() == m + r - 2, - "interp_points should be %zu, but got: %zu", m + r - 2, - interp_points.size()); +WinogradGenerator::WinogradGenerator( + size_t m, size_t r, const std::vector& interp_points) { + megdnn_assert( + interp_points.size() == m + r - 2, + "interp_points should be %zu, but got: %zu", m + r - 2, + interp_points.size()); generate(m, r, interp_points); } -void WinogradGenerator::generate(size_t m, size_t r, - const std::vector& interp_points) { +void WinogradGenerator::generate( + size_t m, size_t r, const std::vector& interp_points) { size_t alpha = m + r - 1; m_A = computeA(interp_points, alpha, m); m_A.transpose(); diff --git a/dnn/src/common/winograd/winograd_generator.h b/dnn/src/common/winograd/winograd_generator.h index c6486b4c..17c5cdb7 100644 --- a/dnn/src/common/winograd/winograd_generator.h +++ b/dnn/src/common/winograd/winograd_generator.h @@ -32,9 +32,9 @@ */ #pragma once -#include #include #include +#include #include "src/common/utils.h" namespace megdnn { @@ -46,8 +46,7 @@ namespace winograd { class WinogradGenerator { public: WinogradGenerator(size_t m, size_t r, float interp = 0.5f); - WinogradGenerator(size_t m, size_t r, - const std::vector& interp_points); + WinogradGenerator(size_t m, size_t r, const std::vector& interp_points); ~WinogradGenerator() = default; class Matrix { @@ -120,8 +119,7 @@ template class WinogradCoeff { std::unique_ptr m_generator; - std::vector generate(float rescale, - const WinogradGenerator::Matrix& m) { + std::vector generate(float rescale, const WinogradGenerator::Matrix& m) { 
std::vector ret; for (size_t r = 0; r < m.rows(); r++) { for (size_t c = 0; c < m.cols(); c++) { @@ -146,17 +144,11 @@ public: m_generator = std::make_unique(m, r, interp_points); } - std::vector A(float rescale) { - return generate(rescale, m_generator->A()); - } + std::vector A(float rescale) { return generate(rescale, m_generator->A()); } - std::vector B(float rescale) { - return generate(rescale, m_generator->B()); - } + std::vector B(float rescale) { return generate(rescale, m_generator->B()); } - std::vector G(float rescale) { - return generate(rescale, m_generator->G()); - } + std::vector G(float rescale) { return generate(rescale, m_generator->G()); } }; } // namespace winograd diff --git a/dnn/src/common/winograd/winograd_helper.cpp b/dnn/src/common/winograd/winograd_helper.cpp index 9d96d373..43734400 100644 --- a/dnn/src/common/winograd/winograd_helper.cpp +++ b/dnn/src/common/winograd/winograd_helper.cpp @@ -22,12 +22,10 @@ struct Getter { }; template -struct Getter::value>> { +struct Getter< + ctype, otype, typename std::enable_if_t::value>> { otype zp; - Getter(DType dtype) { - zp = dtype.param().zero_point; - } + Getter(DType dtype) { zp = dtype.param().zero_point; } otype operator()(ctype item) { return static_cast(item) - zp; } }; @@ -39,8 +37,7 @@ struct OutputGetter { template struct OutputGetter< - ctype, otype, - typename std::enable_if_t::value>> { + ctype, otype, typename std::enable_if_t::value>> { DType dtype; OutputGetter(DType dtype) : dtype{dtype} {} otype operator()(float item) { @@ -50,8 +47,7 @@ struct OutputGetter< template struct OutputGetter< - ctype, otype, - typename std::enable_if_t::value>> { + ctype, otype, typename std::enable_if_t::value>> { DType dtype; OutputGetter(DType dtype) : dtype{dtype} {} otype operator()(float item) { @@ -110,11 +106,10 @@ struct FilterVisitor { size_t OCB = OC / matmul_pack_size; size_t ICB = IC / matmul_pack_size; - return (h * alpha + w) * OCB * ICB * matmul_pack_size * - matmul_pack_size + + return (h * alpha + w) * OCB * ICB * matmul_pack_size * matmul_pack_size + ocb * ICB * matmul_pack_size * matmul_pack_size + - icb * matmul_pack_size * matmul_pack_size + - ic_pack * matmul_pack_size + oc_pack; + icb * matmul_pack_size * matmul_pack_size + ic_pack * matmul_pack_size + + oc_pack; } }; @@ -123,18 +118,18 @@ struct InputVisitor { size_t IC; InputVisitor(size_t IC) : IC(IC) {} - size_t get(size_t /*alpha*/, size_t ic, size_t IH, size_t IW, size_t ih, - size_t iw) { + size_t get( + size_t /*alpha*/, size_t ic, size_t IH, size_t IW, size_t ih, size_t iw) { constexpr size_t input_pack_size = layout_pack_size(layout); size_t icb_layout = ic / input_pack_size; size_t ic_layout = ic % input_pack_size; - return (icb_layout * IH * IW + ih * IW + iw) * input_pack_size + - ic_layout; + return (icb_layout * IH * IW + ih * IW + iw) * input_pack_size + ic_layout; } - size_t put(size_t alpha, size_t ic, size_t nr_units_in_tile, - size_t unit_idx, size_t h, size_t w) { + size_t put( + size_t alpha, size_t ic, size_t nr_units_in_tile, size_t unit_idx, size_t h, + size_t w) { if (format == param::MatrixMul::Format::DEFAULT) { return (h * alpha + w) * nr_units_in_tile * IC + unit_idx * IC + ic; } @@ -144,8 +139,8 @@ struct InputVisitor { size_t ICB = IC / matmul_pack_size; return (h * alpha + w) * ICB * nr_units_in_tile * matmul_pack_size + - icb * nr_units_in_tile * matmul_pack_size + - unit_idx * matmul_pack_size + ic_pack; + icb * nr_units_in_tile * matmul_pack_size + unit_idx * matmul_pack_size + + ic_pack; } }; @@ -154,11 +149,11 
@@ struct OutputVisitor { size_t OC; OutputVisitor(size_t OC) : OC(OC) {} - size_t get(size_t alpha, size_t oc_index, size_t oc, - size_t nr_units_in_tile, size_t unit_idx, size_t h, size_t w) { + size_t get( + size_t alpha, size_t oc_index, size_t oc, size_t nr_units_in_tile, + size_t unit_idx, size_t h, size_t w) { if (format == param::MatrixMul::Format::DEFAULT) { - return (h * alpha + w) * nr_units_in_tile * OC + unit_idx * OC + - oc_index; + return (h * alpha + w) * nr_units_in_tile * OC + unit_idx * OC + oc_index; } size_t matmul_pack_size = MatrixMulForward::pack_size(format); size_t ocb = oc_index / matmul_pack_size; @@ -166,34 +161,32 @@ struct OutputVisitor { size_t OCB = OC / matmul_pack_size; return (h * alpha + w) * OCB * nr_units_in_tile * matmul_pack_size + - ocb * nr_units_in_tile * matmul_pack_size + - unit_idx * matmul_pack_size + oc_pack; + ocb * nr_units_in_tile * matmul_pack_size + unit_idx * matmul_pack_size + + oc_pack; } size_t put(size_t oc, size_t OH, size_t OW, size_t oh, size_t ow) { constexpr size_t input_pack_size = layout_pack_size(layout); size_t oc_layout = oc % input_pack_size; - return (oc / input_pack_size * OH * OW + oh * OW + ow) * - input_pack_size + + return (oc / input_pack_size * OH * OW + oh * OW + ow) * input_pack_size + oc_layout; } }; -template +template < + typename ctype, typename dst_type, typename input_filter_compute_type, + typename output_compute_type, param::ConvBias::Format layout, + param::MatrixMul::Format format> void StrategyHelper< ctype, dst_type, input_filter_compute_type, output_compute_type, layout, - format>::filter(const ctype* filter, - input_filter_compute_type* filter_transform_buf, - input_filter_compute_type* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end, size_t m, - size_t r, const std::vector& interp_points, - DType dtype, float rescale) { + format>:: + filter(const ctype* filter, input_filter_compute_type* filter_transform_buf, + input_filter_compute_type* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end, size_t m, size_t r, + const std::vector& interp_points, DType dtype, float rescale) { size_t alpha = m + r - 1; - WinogradCoeff winograd_coeff(m, r, - interp_points); + WinogradCoeff winograd_coeff(m, r, interp_points); input_filter_compute_type* mid_buf1 = transform_mid_buf; input_filter_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; Getter getter(dtype); @@ -207,17 +200,15 @@ void StrategyHelper< } /* tmp = Matmul(G, src) */ - megdnn::naive::run_matrix_mul_tpl( - winograd_coeff.G(rescale).data(), mid_buf1, mid_buf2, alpha, - r, r, r, r, r, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl< + input_filter_compute_type, input_filter_compute_type, false, false>( + winograd_coeff.G(rescale).data(), mid_buf1, mid_buf2, alpha, r, r, + r, r, r, dtype, dtype); /* dst = Matmul(tmp, G^T) */ - megdnn::naive::run_matrix_mul_tpl( - mid_buf2, winograd_coeff.G(rescale).data(), mid_buf1, alpha, - alpha, r, r, r, alpha, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl< + input_filter_compute_type, input_filter_compute_type, false, true>( + mid_buf2, winograd_coeff.G(rescale).data(), mid_buf1, alpha, alpha, + r, r, r, alpha, dtype, dtype); rep(i, alpha) rep(j, alpha) { filter_transform_buf[filter_visitor.put(alpha, oc, ic, i, j)] = @@ -227,22 +218,20 @@ void StrategyHelper< } } -template +template < + typename ctype, typename dst_type, typename input_filter_compute_type, + typename output_compute_type, param::ConvBias::Format layout, + param::MatrixMul::Format 
format> void StrategyHelper< ctype, dst_type, input_filter_compute_type, output_compute_type, layout, - format>::input(const ctype* input, - input_filter_compute_type* input_transform_buf, - input_filter_compute_type* transform_mid_buf, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t IC, size_t ic, size_t unit_idx, size_t nr_units_in_tile, - size_t m, size_t r, - const std::vector& interp_points, DType dtype, - float rescale) { + format>:: + input(const ctype* input, input_filter_compute_type* input_transform_buf, + input_filter_compute_type* transform_mid_buf, int ih_start, int iw_start, + size_t IH, size_t IW, size_t IC, size_t ic, size_t unit_idx, + size_t nr_units_in_tile, size_t m, size_t r, + const std::vector& interp_points, DType dtype, float rescale) { size_t alpha = m + r - 1; - WinogradCoeff winograd_coeff(m, r, - interp_points); + WinogradCoeff winograd_coeff(m, r, interp_points); input_filter_compute_type* mid_buf1 = transform_mid_buf; input_filter_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; Getter getter(dtype); @@ -253,47 +242,43 @@ void StrategyHelper< int ih = ih_start + i; int iw = iw_start + j; if (ih >= 0 && ih < (int)IH && iw >= 0 && iw < (int)IW) { - mid_buf1[i * alpha + j] = getter( - input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); + mid_buf1[i * alpha + j] = + getter(input[intput_visitor.get(alpha, ic, IH, IW, ih, iw)]); } } - megdnn::naive::run_matrix_mul_tpl( - winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, - alpha, alpha, alpha, alpha, alpha, dtype, dtype); - megdnn::naive::run_matrix_mul_tpl( - mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, - alpha, alpha, alpha, alpha, alpha, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl< + input_filter_compute_type, input_filter_compute_type, true, false>( + winograd_coeff.B(rescale).data(), mid_buf1, mid_buf2, alpha, alpha, alpha, + alpha, alpha, alpha, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl< + input_filter_compute_type, input_filter_compute_type, false, false>( + mid_buf2, winograd_coeff.B(rescale).data(), mid_buf1, alpha, alpha, alpha, + alpha, alpha, alpha, dtype, dtype); rep(i, alpha) rep(j, alpha) { - input_transform_buf[intput_visitor.put(alpha, ic, nr_units_in_tile, - unit_idx, i, j)] = - mid_buf1[i * alpha + j]; + input_transform_buf[intput_visitor.put( + alpha, ic, nr_units_in_tile, unit_idx, i, j)] = mid_buf1[i * alpha + j]; } } -template +template < + typename ctype, typename dst_type, typename input_filter_compute_type, + typename output_compute_type, param::ConvBias::Format layout, + param::MatrixMul::Format format> void StrategyHelper< ctype, dst_type, input_filter_compute_type, output_compute_type, layout, - format>::output(const output_compute_type* output_transform_buf, - const output_compute_type* bias, dst_type* output, - output_compute_type* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t oh_start, - size_t ow_start, size_t OH, size_t OW, size_t OC, size_t oc_start, - size_t oc_index, size_t unit_idx, size_t nr_units_in_tile, - size_t m, size_t r, - const std::vector& interp_points, DType dtype, - float input_filter_scale, float input_filter_rescale, - float rescale) { + format>:: + output(const output_compute_type* output_transform_buf, + const output_compute_type* bias, dst_type* output, + output_compute_type* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t OC, size_t oc_start, size_t oc_index, size_t unit_idx, + size_t 
nr_units_in_tile, size_t m, size_t r, + const std::vector& interp_points, DType dtype, + float input_filter_scale, float input_filter_rescale, float rescale) { size_t alpha = m + r - 1; - winograd::WinogradCoeff winograd_coeff(m, r, - interp_points); + winograd::WinogradCoeff winograd_coeff(m, r, interp_points); output_compute_type* mid_buf1 = transform_mid_buf; output_compute_type* mid_buf2 = transform_mid_buf + alpha * alpha; OutputGetter getter(dtype); @@ -304,18 +289,17 @@ void StrategyHelper< /* gather */ rep(i, alpha) rep(j, alpha) { mid_buf1[i * alpha + j] = output_transform_buf[output_visitor.get( - alpha, oc_index, oc, nr_units_in_tile, unit_idx, i, - j)]; + alpha, oc_index, oc, nr_units_in_tile, unit_idx, i, j)]; } /* A[alpha*m] M[alpha*alpha] */ - megdnn::naive::run_matrix_mul_tpl( - winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, - alpha, m, alpha, alpha, dtype, dtype); - megdnn::naive::run_matrix_mul_tpl( - mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, - alpha, alpha, m, m, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl< + output_compute_type, output_compute_type, true, false>( + winograd_coeff.A(rescale).data(), mid_buf1, mid_buf2, m, alpha, alpha, m, + alpha, alpha, dtype, dtype); + megdnn::naive::run_matrix_mul_tpl< + output_compute_type, output_compute_type, false, false>( + mid_buf2, winograd_coeff.A(rescale).data(), mid_buf1, m, m, alpha, alpha, m, + m, dtype, dtype); rep(i, m) rep(j, m) { auto oh = oh_start + i; @@ -323,15 +307,13 @@ void StrategyHelper< if (oh < OH && ow < OW) { float val = mid_buf1[i * m + j]; if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { - val += bias[oc] * input_filter_rescale * - input_filter_rescale; + val += bias[oc] * input_filter_rescale * input_filter_rescale; } else if (bmode == BiasMode::BIAS) { val += bias[output_visitor.put(oc, OH, OW, oh, ow)] * - input_filter_rescale * input_filter_rescale; + input_filter_rescale * input_filter_rescale; } val = val * input_filter_scale / - (input_filter_rescale * input_filter_rescale * rescale * - rescale); + (input_filter_rescale * input_filter_rescale * rescale * rescale); if (nonline_mode == NonlineMode::RELU) { val = val > 0 ? 
val : 0; } else if (nonline_mode == NonlineMode::SIGMOID) { @@ -346,11 +328,9 @@ void StrategyHelper< } }; -#define INST(_ctype, _dst_type, _input_filter_compute_type, \ - _output_compute_type) \ - template class StrategyHelper<_ctype, _dst_type, \ - _input_filter_compute_type, \ - _output_compute_type>; +#define INST(_ctype, _dst_type, _input_filter_compute_type, _output_compute_type) \ + template class StrategyHelper< \ + _ctype, _dst_type, _input_filter_compute_type, _output_compute_type>; INST(float, float, float, float) DNN_INC_FLOAT16(INST(dt_float16, dt_float16, dt_float16, dt_float16)) @@ -358,26 +338,26 @@ INST(int8_t, int8_t, int16_t, int) INST(uint8_t, uint8_t, int16_t, int) #undef INST -#define INST(_ctype, _dst_type, _input_filter_compute_type, \ - _output_compute_type, layout) \ - template class StrategyHelper< \ - _ctype, _dst_type, _input_filter_compute_type, \ - _output_compute_type, layout, param::MatrixMul::Format::MK4>; +#define INST( \ + _ctype, _dst_type, _input_filter_compute_type, _output_compute_type, layout) \ + template class StrategyHelper< \ + _ctype, _dst_type, _input_filter_compute_type, _output_compute_type, \ + layout, param::MatrixMul::Format::MK4>; INST(float, float, float, float, param::ConvBias::Format::NCHW) INST(float, float, float, float, param::ConvBias::Format::NCHW44) INST(int8_t, int8_t, float, float, param::ConvBias::Format::NCHW44) #undef INST -#define INST(_ctype, _dst_type, _input_filter_compute_type, \ - _output_compute_type, layout) \ - template class StrategyHelper< \ - _ctype, _dst_type, _input_filter_compute_type, \ - _output_compute_type, layout, param::MatrixMul::Format::MK8>; +#define INST( \ + _ctype, _dst_type, _input_filter_compute_type, _output_compute_type, layout) \ + template class StrategyHelper< \ + _ctype, _dst_type, _input_filter_compute_type, _output_compute_type, \ + layout, param::MatrixMul::Format::MK8>; INST(int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW) INST(int8_t, int8_t, int16_t, int, param::ConvBias::Format::NCHW44) INST(float, float, float, float, param::ConvBias::Format::NCHW88) -DNN_INC_FLOAT16(INST(dt_float16, dt_float16, dt_float16, dt_float16, - param::ConvBias::Format::NCHW)) +DNN_INC_FLOAT16(INST( + dt_float16, dt_float16, dt_float16, dt_float16, param::ConvBias::Format::NCHW)) #undef INST } // namespace winograd } // namespace megdnn diff --git a/dnn/src/common/winograd/winograd_helper.h b/dnn/src/common/winograd/winograd_helper.h index 8e111017..13c741c7 100644 --- a/dnn/src/common/winograd/winograd_helper.h +++ b/dnn/src/common/winograd/winograd_helper.h @@ -26,39 +26,37 @@ using BiasMode = ConvBiasForward::BiasMode; * * \warning The layout should be NCHW */ -template +template < + typename ctype, typename dst_type, typename input_filter_compute_type, + typename output_compute_type, + param::ConvBias::Format layout = param::ConvBias::Format::NCHW, + param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT> class StrategyHelper { public: - static void filter(const ctype* filter, - input_filter_compute_type* filter_transform_buf, - input_filter_compute_type* transform_mid_buf, size_t OC, - size_t IC, size_t oc_start, size_t oc_end, size_t m, - size_t r, const std::vector& interp_points, - DType dtype, float rescale = 1.0f); + static void filter( + const ctype* filter, input_filter_compute_type* filter_transform_buf, + input_filter_compute_type* transform_mid_buf, size_t OC, size_t IC, + size_t oc_start, size_t oc_end, size_t m, size_t r, + const std::vector& interp_points, DType 
dtype, float rescale = 1.0f); - static void input(const ctype* input, - input_filter_compute_type* input_transform_buf, - input_filter_compute_type* transform_mid_buf, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t IC, size_t ic, size_t unit_idx, - size_t nr_units_in_tile, size_t m, size_t r, - const std::vector& interp_points, DType dtype, - float rescale = 1.0f); + static void input( + const ctype* input, input_filter_compute_type* input_transform_buf, + input_filter_compute_type* transform_mid_buf, int ih_start, int iw_start, + size_t IH, size_t IW, size_t IC, size_t ic, size_t unit_idx, + size_t nr_units_in_tile, size_t m, size_t r, + const std::vector& interp_points, DType dtype, float rescale = 1.0f); - static void - output(const output_compute_type* output_transform_buf, - const output_compute_type* bias, dst_type* output, - output_compute_type* transform_mid_buf, BiasMode bmode, - NonlineMode nonline_mode, size_t oh_start, size_t ow_start, - size_t OH, size_t OW, size_t OC, size_t oc_start, size_t oc_index, - size_t unit_idx, size_t nr_units_in_tile, size_t m, size_t r, - const std::vector& interp_points, DType dtype, - float input_filter_scale = 1.0f, // input_scale * filter_scale - float input_filter_rescale = 1.0f, // input_rescale * filter_rescale - float rescale = 1.0f); + static void output( + const output_compute_type* output_transform_buf, + const output_compute_type* bias, dst_type* output, + output_compute_type* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t OC, size_t oc_start, size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, size_t m, size_t r, + const std::vector& interp_points, DType dtype, + float input_filter_scale = 1.0f, // input_scale * filter_scale + float input_filter_rescale = 1.0f, // input_rescale * filter_rescale + float rescale = 1.0f); }; } // namespace winograd diff --git a/dnn/src/cuda/adaptive_pooling/opr_impl.cpp b/dnn/src/cuda/adaptive_pooling/opr_impl.cpp index 1afc7e83..3bf9bfbb 100644 --- a/dnn/src/cuda/adaptive_pooling/opr_impl.cpp +++ b/dnn/src/cuda/adaptive_pooling/opr_impl.cpp @@ -15,9 +15,8 @@ namespace megdnn { namespace cuda { -void AdaptivePoolingForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void AdaptivePoolingForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { auto opr = handle()->create_operator(); opr->param() = deduce_pooling_param(src.layout, dst.layout); opr->exec(src, dst, workspace); @@ -30,19 +29,17 @@ size_t AdaptivePoolingForwardImpl::get_workspace_in_bytes( return opr->get_workspace_in_bytes(src, dst); } -void AdaptivePoolingBackwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in dst, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void AdaptivePoolingBackwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) { auto opr = handle()->create_operator(); opr->param() = deduce_pooling_param(src.layout, dst.layout); opr->exec(src, dst, diff, grad, workspace); } size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad) { + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) { auto opr = handle()->create_operator(); opr->param() = 
deduce_pooling_param(src, dst); return opr->get_workspace_in_bytes(src, dst, diff, grad); diff --git a/dnn/src/cuda/adaptive_pooling/opr_impl.h b/dnn/src/cuda/adaptive_pooling/opr_impl.h index e68b1236..6a47fa2c 100644 --- a/dnn/src/cuda/adaptive_pooling/opr_impl.h +++ b/dnn/src/cuda/adaptive_pooling/opr_impl.h @@ -21,22 +21,22 @@ namespace cuda { class AdaptivePoolingForwardImpl final : public AdaptivePoolingForward { public: using AdaptivePoolingForward::AdaptivePoolingForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; }; class AdaptivePoolingBackwardImpl final : public AdaptivePoolingBackward { public: using AdaptivePoolingBackward::AdaptivePoolingBackward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) override; }; } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/add_update/kern.cu b/dnn/src/cuda/add_update/kern.cu index a8183285..4e52d0af 100644 --- a/dnn/src/cuda/add_update/kern.cu +++ b/dnn/src/cuda/add_update/kern.cu @@ -14,18 +14,16 @@ namespace megdnn { namespace cuda { -#define cb(_dtype) \ - INST_RUN_ELEMWISE( \ - AddUpdateKernOp::ctype>, \ - DTypeTrait<_dtype>::ctype, 1); \ - INST_RUN_ELEMWISE( \ - AddUpdateKernOpNonContig::ctype>, \ - DTypeTrait<_dtype>::ctype, 2); +#define cb(_dtype) \ + INST_RUN_ELEMWISE( \ + AddUpdateKernOp::ctype>, DTypeTrait<_dtype>::ctype, 1); \ + INST_RUN_ELEMWISE( \ + AddUpdateKernOpNonContig::ctype>, \ + DTypeTrait<_dtype>::ctype, 2); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) -} // namespace megdnn -} // namespace cuda +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/add_update/kern.cuh b/dnn/src/cuda/add_update/kern.cuh index 44ee1a71..a8b5983f 100644 --- a/dnn/src/cuda/add_update/kern.cuh +++ b/dnn/src/cuda/add_update/kern.cuh @@ -11,103 +11,98 @@ #pragma once -#include "src/cuda/utils.cuh" #include "src/cuda/elemwise_helper.cuh" +#include "src/cuda/utils.cuh" #if MEGDNN_CC_HOST #include "megdnn/oprs.h" #endif -namespace megdnn{ +namespace megdnn { namespace cuda { - template - struct AddUpdateKernOp { - ctype *dst; - ctype alpha, beta, bias; +template +struct AddUpdateKernOp { + ctype* dst; + ctype alpha, beta, bias; - __device__ void operator() (uint32_t idx, ctype delta) { - dst[idx] = dst[idx] * alpha + delta * beta + bias; - } + __device__ void operator()(uint32_t idx, ctype delta) { + dst[idx] = dst[idx] * alpha + delta * beta + bias; + } #if MEGDNN_CC_HOST - AddUpdateKernOp(const TensorND &dest, const AddUpdate::Param ¶m): - dst{dest.ptr()}, - alpha(param.alpha), beta(param.beta), bias(param.bias) - { - } + AddUpdateKernOp(const TensorND& dest, const AddUpdate::Param& param) + : dst{dest.ptr()}, + 
alpha(param.alpha), + beta(param.beta), + bias(param.bias) {} #endif - }; +}; - template - struct AddUpdateKernOp< - ctype, typename std::enable_if< - std::is_same::value || - std::is_same::value>::type> { - typedef typename elemwise_intl::VectTypeTrait::vect_type - vect_type; - ctype* dst; - ctype alpha, beta, bias; - __device__ void operator()(uint32_t idx, ctype delta) { - dst[idx] = dst[idx] * alpha + delta * beta + bias; - } - __device__ void operator()(uint32_t idx, vect_type delta) { - vect_type& x = *(vect_type*)(&dst[idx]); - x.x = x.x * alpha + delta.x * beta + bias; - x.y = x.y * alpha + delta.y * beta + bias; - x.z = x.z * alpha + delta.z * beta + bias; - x.w = x.w * alpha + delta.w * beta + bias; - } +template +struct AddUpdateKernOp< + ctype, typename std::enable_if< + std::is_same::value || + std::is_same::value>::type> { + typedef typename elemwise_intl::VectTypeTrait::vect_type vect_type; + ctype* dst; + ctype alpha, beta, bias; + __device__ void operator()(uint32_t idx, ctype delta) { + dst[idx] = dst[idx] * alpha + delta * beta + bias; + } + __device__ void operator()(uint32_t idx, vect_type delta) { + vect_type& x = *(vect_type*)(&dst[idx]); + x.x = x.x * alpha + delta.x * beta + bias; + x.y = x.y * alpha + delta.y * beta + bias; + x.z = x.z * alpha + delta.z * beta + bias; + x.w = x.w * alpha + delta.w * beta + bias; + } #if MEGDNN_CC_HOST - AddUpdateKernOp(const TensorND& dest, const AddUpdate::Param& param) - : dst{dest.ptr()}, - alpha(param.alpha), - beta(param.beta), - bias(param.bias){}; + AddUpdateKernOp(const TensorND& dest, const AddUpdate::Param& param) + : dst{dest.ptr()}, + alpha(param.alpha), + beta(param.beta), + bias(param.bias){}; #endif - }; +}; - template - struct AddUpdateKernOpNonContig { - ctype alpha, beta, bias; +template +struct AddUpdateKernOpNonContig { + ctype alpha, beta, bias; - __device__ void operator() (uint32_t /*idx*/, ctype &dst, ctype delta) { - dst = dst * alpha + delta * beta + bias; - } + __device__ void operator()(uint32_t /*idx*/, ctype& dst, ctype delta) { + dst = dst * alpha + delta * beta + bias; + } #if MEGDNN_CC_HOST - AddUpdateKernOpNonContig(const AddUpdate::Param ¶m): - alpha(param.alpha), beta(param.beta), bias(param.bias) - { - } + AddUpdateKernOpNonContig(const AddUpdate::Param& param) + : alpha(param.alpha), beta(param.beta), bias(param.bias) {} #endif - }; +}; - template - struct AddUpdateKernOpNonContig< - ctype, typename std::enable_if< - std::is_same::value || - std::is_same::value>::type> { - typedef typename elemwise_intl::VectTypeTrait::vect_type - vect_type; - ctype alpha, beta, bias; - __device__ void operator()(uint32_t, ctype& dst, ctype delta) { - dst = dst * alpha + delta * beta + bias; - } - __device__ void operator()(uint32_t, vect_type& dst, vect_type delta) { - dst.x = dst.x * alpha + delta.x * beta + bias; - dst.y = dst.y * alpha + delta.y * beta + bias; - dst.z = dst.z * alpha + delta.z * beta + bias; - dst.w = dst.w * alpha + delta.w * beta + bias; - } +template +struct AddUpdateKernOpNonContig< + ctype, typename std::enable_if< + std::is_same::value || + std::is_same::value>::type> { + typedef typename elemwise_intl::VectTypeTrait::vect_type vect_type; + ctype alpha, beta, bias; + __device__ void operator()(uint32_t, ctype& dst, ctype delta) { + dst = dst * alpha + delta * beta + bias; + } + __device__ void operator()(uint32_t, vect_type& dst, vect_type delta) { + dst.x = dst.x * alpha + delta.x * beta + bias; + dst.y = dst.y * alpha + delta.y * beta + bias; + dst.z = dst.z * alpha + delta.z * 
beta + bias; + dst.w = dst.w * alpha + delta.w * beta + bias; + } #if MEGDNN_CC_HOST - AddUpdateKernOpNonContig(const AddUpdate::Param& param) - : alpha(param.alpha), beta(param.beta), bias(param.bias) {} + AddUpdateKernOpNonContig(const AddUpdate::Param& param) + : alpha(param.alpha), beta(param.beta), bias(param.bias) {} #endif - }; +}; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/add_update/opr_impl.cpp b/dnn/src/cuda/add_update/opr_impl.cpp index 5341d20b..5ec2df07 100644 --- a/dnn/src/cuda/add_update/opr_impl.cpp +++ b/dnn/src/cuda/add_update/opr_impl.cpp @@ -9,16 +9,15 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kern.cuh" #include "./opr_impl.h" +#include "./kern.cuh" #include "src/common/utils.h" using namespace megdnn; using namespace cuda; -void AddUpdateForwardImpl::exec( - _megdnn_tensor_inout dest, _megdnn_tensor_in delta) { +void AddUpdateForwardImpl::exec(_megdnn_tensor_inout dest, _megdnn_tensor_in delta) { check_exec(dest.layout, delta.layout); if (!dest.layout.is_contiguous()) { return exec_noncontig(dest, delta); @@ -29,11 +28,11 @@ void AddUpdateForwardImpl::exec( param.init_from_given_tensor(); auto stream = cuda_stream(handle()); switch (dest.layout.dtype.enumv()) { - -#define cb(_dt) case DTypeTrait<_dt>::enumv: { \ - using ctype = DTypeTrait<_dt>::ctype; \ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ return run_elemwise, ctype, 1>( \ - param, stream, {dest, m_param}); \ + param, stream, {dest, m_param}); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb @@ -45,15 +44,14 @@ void AddUpdateForwardImpl::exec( void AddUpdateForwardImpl::exec_noncontig( _megdnn_tensor_inout dest, _megdnn_tensor_in delta) { - ElemwiseOpParamN<2> param = make_param(dest, delta); auto stream = cuda_stream(handle()); switch (dest.layout.dtype.enumv()) { - -#define cb(_dt) case DTypeTrait<_dt>::enumv: { \ - using ctype = DTypeTrait<_dt>::ctype; \ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ return run_elemwise, ctype, 2>( \ - param, stream, {m_param}); \ + param, stream, {m_param}); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb @@ -64,4 +62,3 @@ void AddUpdateForwardImpl::exec_noncontig( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/argmxx/argmxx.cu b/dnn/src/cuda/argmxx/argmxx.cu index 112ae525..b507fdb2 100644 --- a/dnn/src/cuda/argmxx/argmxx.cu +++ b/dnn/src/cuda/argmxx/argmxx.cu @@ -10,17 +10,17 @@ */ #include "src/common/argmxx_helper.h" -#include "src/cuda/reduce_helper.cuh" #include "megdnn/dtype.h" +#include "src/cuda/reduce_helper.cuh" namespace megdnn { namespace cuda { -#define INST(_dt) \ +#define INST(_dt) \ INST_REDUCE(argmxx::ArgmxxOp::ctype MEGDNN_COMMA false>, false); \ - INST_REDUCE(argmxx::ArgmxxOp::ctype MEGDNN_COMMA true>, false); \ + INST_REDUCE(argmxx::ArgmxxOp::ctype MEGDNN_COMMA true>, false); - MEGDNN_FOREACH_COMPUTING_DTYPE(INST) +MEGDNN_FOREACH_COMPUTING_DTYPE(INST) -} // namespace argmxx -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/argmxx/opr_impl.cpp b/dnn/src/cuda/argmxx/opr_impl.cpp index 01d6795c..78337028 100644 --- a/dnn/src/cuda/argmxx/opr_impl.cpp +++ b/dnn/src/cuda/argmxx/opr_impl.cpp @@ -10,10 +10,10 @@ */ #include "src/cuda/argmxx/opr_impl.h" -#include "src/cuda/utils.h" -#include "src/common/reduce_helper.h" #include 
"src/common/argmxx_helper.h" +#include "src/common/reduce_helper.h" #include "src/cuda/reduce_helper.cuh" +#include "src/cuda/utils.h" namespace { @@ -22,40 +22,34 @@ using namespace cuda; using namespace argmxx; template -size_t get_workspace_in_bytes_impl(const TensorLayout &src, - const TensorLayout & /* dst */, - size_t axis) -{ +size_t get_workspace_in_bytes_impl( + const TensorLayout& src, const TensorLayout& /* dst */, size_t axis) { size_t A, B, C; reduce::get_ABC(src, A, B, C, axis); - return get_reduce_workspace_in_bytes>( - A, B, C); + return get_reduce_workspace_in_bytes>(A, B, C); } template -void exec_impl(const T *src, int *dst, void *workspace, - size_t A, size_t B, size_t C, - cudaStream_t stream) -{ - argmxx::ArgmxxOp opr(const_cast(src), dst, A, B, C); +void exec_impl( + const T* src, int* dst, void* workspace, size_t A, size_t B, size_t C, + cudaStream_t stream) { + argmxx::ArgmxxOp opr(const_cast(src), dst, A, B, C); run_reduce, false>( - (typename argmxx::ArgmxxOp::wtype *)workspace, - A, B, C, - stream, opr); + (typename argmxx::ArgmxxOp::wtype*)workspace, A, B, C, stream, + opr); after_kernel_launch(); } -} // anonymous namespace +} // anonymous namespace namespace megdnn { namespace cuda { -size_t ArgmaxForwardImpl::get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) -{ -#define cb(DType) \ - if (src.dtype == DType()) { \ - using ctype = typename DTypeTrait::ctype; \ +size_t ArgmaxForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { +#define cb(DType) \ + if (src.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ return get_workspace_in_bytes_impl(src, dst, param().axis); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) @@ -63,33 +57,29 @@ size_t ArgmaxForwardImpl::get_workspace_in_bytes(const TensorLayout &src, megdnn_assert_internal(false); } -void ArgmaxForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ +void ArgmaxForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); size_t A, B, C; reduce::get_ABC(src.layout, A, B, C, param().axis); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ - using ctype = typename DTypeTrait::ctype; \ - exec_impl(src.ptr(), \ - dst.ptr(), \ - workspace.raw_ptr, \ - A, B, C, stream); \ +#define cb(DType) \ + if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + exec_impl( \ + src.ptr(), dst.ptr(), workspace.raw_ptr, A, B, C, \ + stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb } -size_t ArgminForwardImpl::get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) -{ -#define cb(DType) \ - if (src.dtype == DType()) { \ - using ctype = typename DTypeTrait::ctype; \ +size_t ArgminForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { +#define cb(DType) \ + if (src.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ return get_workspace_in_bytes_impl(src, dst, param().axis); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) @@ -97,28 +87,25 @@ size_t ArgminForwardImpl::get_workspace_in_bytes(const TensorLayout &src, megdnn_assert_internal(false); } -void ArgminForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ +void ArgminForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, 
_megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); size_t A, B, C; reduce::get_ABC(src.layout, A, B, C, param().axis); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ - using ctype = typename DTypeTrait::ctype; \ - exec_impl(src.ptr(), \ - dst.ptr(), \ - workspace.raw_ptr, \ - A, B, C, stream); \ +#define cb(DType) \ + if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + exec_impl( \ + src.ptr(), dst.ptr(), workspace.raw_ptr, A, B, C, \ + stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) - + #undef cb } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/argmxx/opr_impl.h b/dnn/src/cuda/argmxx/opr_impl.h index b97c5aba..207f91d2 100644 --- a/dnn/src/cuda/argmxx/opr_impl.h +++ b/dnn/src/cuda/argmxx/opr_impl.h @@ -14,28 +14,26 @@ namespace megdnn { namespace cuda { -class ArgmaxForwardImpl final: public ArgmaxForward { - public: - using ArgmaxForward::ArgmaxForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) override; +class ArgmaxForwardImpl final : public ArgmaxForward { +public: + using ArgmaxForward::ArgmaxForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; }; -class ArgminForwardImpl: public ArgminForward { - public: - using ArgminForward::ArgminForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) override; +class ArgminForwardImpl : public ArgminForward { +public: + using ArgminForward::ArgminForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/argsort/argsort.cu b/dnn/src/cuda/argsort/argsort.cu index 666cf62f..25236c9f 100644 --- a/dnn/src/cuda/argsort/argsort.cu +++ b/dnn/src/cuda/argsort/argsort.cu @@ -24,12 +24,9 @@ namespace { struct StridedOffsetIterator { int bias, stride; - StridedOffsetIterator(int bias_, int stride_) - : bias(bias_), stride(stride_) {} + StridedOffsetIterator(int bias_, int stride_) : bias(bias_), stride(stride_) {} - __device__ __forceinline__ int operator[](int i) const { - return stride * i + bias; - } + __device__ __forceinline__ int operator[](int i) const { return stride * i + bias; } }; bool use_bitonic(uint32_t /*M*/, uint32_t N) { @@ -61,8 +58,9 @@ size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { if (use_bitonic(M, N)) { return 0; } - return argsort::cub_sort_pairs(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, - M, N, 0, sizeof(float)*8, NULL); + return argsort::cub_sort_pairs( + is_ascending, NULL, 0, NULL, NULL, NULL, NULL, M, N, 0, sizeof(float) * 8, + NULL); } } // anonymous namespace @@ -70,29 +68,30 @@ template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( bool is_ascending, void* workspace, size_t workspace_size, const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, 
- ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,cudaStream_t stream){ + ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit, + cudaStream_t stream) { cudaError_t err; if (use_segmented(M, N)) { if (is_ascending) { err = cub::DeviceSegmentedRadixSort::SortPairs( - workspace, workspace_size, keys_in, keys_out, values_in, - values_out, N * M, M, StridedOffsetIterator(0, N), - StridedOffsetIterator(N, N), begin_bit, end_bit, stream); + workspace, workspace_size, keys_in, keys_out, values_in, values_out, + N * M, M, StridedOffsetIterator(0, N), StridedOffsetIterator(N, N), + begin_bit, end_bit, stream); cuda_check(err); } else { err = cub::DeviceSegmentedRadixSort::SortPairsDescending( - workspace, workspace_size, keys_in, keys_out, values_in, - values_out, N * M, M, StridedOffsetIterator(0, N), - StridedOffsetIterator(N, N), begin_bit, end_bit, stream); + workspace, workspace_size, keys_in, keys_out, values_in, values_out, + N * M, M, StridedOffsetIterator(0, N), StridedOffsetIterator(N, N), + begin_bit, end_bit, stream); cuda_check(err); } } else { if (is_ascending) { for (size_t i = 0; i < M; ++i) { err = cub::DeviceRadixSort::SortPairs( - workspace, workspace_size, keys_in + N * i, - keys_out + N * i, values_in + N * i, values_out + N * i, - N, begin_bit, end_bit, stream); + workspace, workspace_size, keys_in + N * i, keys_out + N * i, + values_in + N * i, values_out + N * i, N, begin_bit, end_bit, + stream); cuda_check(err); if (!keys_in) { return workspace_size; @@ -101,9 +100,9 @@ MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( } else { for (size_t i = 0; i < M; ++i) { err = cub::DeviceRadixSort::SortPairsDescending( - workspace, workspace_size, keys_in + N * i, - keys_out + N * i, values_in + N * i, values_out + N * i, - N, begin_bit, end_bit, stream); + workspace, workspace_size, keys_in + N * i, keys_out + N * i, + values_in + N * i, values_out + N * i, N, begin_bit, end_bit, + stream); cuda_check(err); if (!keys_in) { return workspace_size; @@ -114,9 +113,8 @@ MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( return workspace_size; } -size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, - bool is_ascending, - bool iptr_src_given) { +size_t argsort::get_fwd_workspace_in_bytes( + uint32_t M, uint32_t N, DType dtype, bool is_ascending, bool iptr_src_given) { size_t size = 0; switch (dtype.enumv().ev) { #define cb(ctype) \ @@ -135,48 +133,46 @@ size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, } template -void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr, - void* workspace, uint32_t M, uint32_t N, - bool is_ascending, cudaStream_t stream, - const int* iptr_src) { +void argsort::forward( + const dtype* sptr, dtype* dptr, int* iptr, void* workspace, uint32_t M, + uint32_t N, bool is_ascending, cudaStream_t stream, const int* iptr_src) { size_t wk_size = get_sort_workspace(M, N, is_ascending); if (!iptr_src) { - int* ptr = reinterpret_cast(static_cast(workspace) + - DIVUP(wk_size, sizeof(float)) * - sizeof(float)); + int* ptr = reinterpret_cast( + static_cast(workspace) + + DIVUP(wk_size, sizeof(float)) * sizeof(float)); kern_arange<<>>(ptr, M * N, N); iptr_src = ptr; } if (use_bitonic(M, N)) { - cuda_check(bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending, - stream)); + cuda_check( + bitonic_sort(M, N, sptr, iptr_src, dptr, iptr, is_ascending, stream)); } else { - cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src, - iptr, M, N, 0, 
sizeof(float)*8, stream); + cub_sort_pairs( + is_ascending, workspace, wk_size, sptr, dptr, iptr_src, iptr, M, N, 0, + sizeof(float) * 8, stream); } } namespace megdnn { namespace cuda { -#define INST_CUB_SORT(dtype) \ -template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(bool, \ - void*, size_t, const dtype*, dtype*, \ - const dtype*, dtype*, uint32_t, uint32_t,\ - int, int, cudaStream_t); - -#define INST_FORWARD(dtype) \ -template void argsort::forward(const dtype*, dtype*, int*, void*, \ - uint32_t, uint32_t, bool, cudaStream_t, \ - const int*); - +#define INST_CUB_SORT(dtype) \ + template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( \ + bool, void*, size_t, const dtype*, dtype*, const dtype*, dtype*, uint32_t, \ + uint32_t, int, int, cudaStream_t); + +#define INST_FORWARD(dtype) \ + template void argsort::forward( \ + const dtype*, dtype*, int*, void*, uint32_t, uint32_t, bool, cudaStream_t, \ + const int*); + ARGSORT_FOREACH_CTYPE(INST_FORWARD) INST_CUB_SORT(uint32_t) INST_CUB_SORT(uint64_t) #undef INST_CUB_SORT #undef INST_FORWARD -} +} // namespace cuda } // namespace megdnn // vim: ft=cuda syntax=cuda.doxygen - diff --git a/dnn/src/cuda/argsort/argsort.cuh b/dnn/src/cuda/argsort/argsort.cuh index 77c02078..a4531984 100644 --- a/dnn/src/cuda/argsort/argsort.cuh +++ b/dnn/src/cuda/argsort/argsort.cuh @@ -20,27 +20,27 @@ namespace megdnn { namespace cuda { namespace argsort { -size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, - bool is_ascending, - bool iptr_src_given = false); +size_t get_fwd_workspace_in_bytes( + uint32_t M, uint32_t N, DType dtype, bool is_ascending, + bool iptr_src_given = false); template size_t cub_sort_pairs( bool is_ascending, void* workspace, size_t workspace_size, const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, - ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,cudaStream_t stream); + ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit, + cudaStream_t stream); /*! * \param iptr_src pointer to indices; a range would be generated if it is null */ template -void forward(const dtype* sptr, dtype* dptr, int* iptr, void* workspace, - uint32_t M, uint32_t N, bool is_ascending, cudaStream_t stream, - const int* iptr_src = NULL); +void forward( + const dtype* sptr, dtype* dptr, int* iptr, void* workspace, uint32_t M, + uint32_t N, bool is_ascending, cudaStream_t stream, const int* iptr_src = NULL); //! 
iterate over all supported data types -#define ARGSORT_FOREACH_CTYPE(cb) \ - cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16)) +#define ARGSORT_FOREACH_CTYPE(cb) cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16)) } // namespace argsort } // namespace cuda diff --git a/dnn/src/cuda/argsort/backward.cu b/dnn/src/cuda/argsort/backward.cu index 66cda244..b46b1966 100644 --- a/dnn/src/cuda/argsort/backward.cu +++ b/dnn/src/cuda/argsort/backward.cu @@ -21,9 +21,9 @@ using namespace argsort; namespace { template -__global__ void backward_kernel(uint32_t dst_w, uint32_t src_w, - uint32_t src_size, T* dst, const T* src_data, - const int* src_idx) { +__global__ void backward_kernel( + uint32_t dst_w, uint32_t src_w, uint32_t src_size, T* dst, const T* src_data, + const int* src_idx) { uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < src_size) { uint32_t r = idx / src_w; @@ -34,9 +34,9 @@ __global__ void backward_kernel(uint32_t dst_w, uint32_t src_w, } // namespace template -void argsort::backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, - T* dst, const T* src_data, const int* src_idx, - cudaStream_t stream) { +void argsort::backward_proxy( + uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst, const T* src_data, + const int* src_idx, cudaStream_t stream) { if (dst_w != src_w) { cudaMemsetAsync(dst, 0, dst_h * dst_w * sizeof(T), stream); } @@ -51,10 +51,10 @@ namespace megdnn { namespace cuda { namespace argsort { -#define INST(T) \ - template void backward_proxy(uint32_t dst_h, uint32_t dst_w, \ - uint32_t src_w, T* dst, const T* src_data, \ - const int* src_idx, cudaStream_t stream); +#define INST(T) \ + template void backward_proxy( \ + uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst, const T* src_data, \ + const int* src_idx, cudaStream_t stream); ARGSORT_FOREACH_CTYPE(INST) #undef INST diff --git a/dnn/src/cuda/argsort/backward.cuh b/dnn/src/cuda/argsort/backward.cuh index 3b4daa0a..edcde300 100644 --- a/dnn/src/cuda/argsort/backward.cuh +++ b/dnn/src/cuda/argsort/backward.cuh @@ -18,12 +18,12 @@ namespace cuda { namespace argsort { template -void backward_proxy(uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst, - const T* src_data, const int* src_idx, cudaStream_t stream); +void backward_proxy( + uint32_t dst_h, uint32_t dst_w, uint32_t src_w, T* dst, const T* src_data, + const int* src_idx, cudaStream_t stream); } // namespace argsort } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/argsort/bitonic_sort.cu b/dnn/src/cuda/argsort/bitonic_sort.cu index 386e4a57..57d921f4 100644 --- a/dnn/src/cuda/argsort/bitonic_sort.cu +++ b/dnn/src/cuda/argsort/bitonic_sort.cu @@ -10,8 +10,8 @@ */ #include "./bitonic_sort.cuh" -#include "src/cuda/query_blocksize.cuh" #include "megdnn/dtype.h" +#include "src/cuda/query_blocksize.cuh" #if __CUDACC_VER_MAJOR__ < 9 #pragma message "warp sync disabled due to insufficient cuda version" @@ -28,16 +28,16 @@ namespace bitonic_sort_impl { //! load keys and init idx template -__device__ __forceinline__ void safe_load0(T* dst, uint16_t* idx, const T* src, - uint32_t id, uint32_t size) { +__device__ __forceinline__ void safe_load0( + T* dst, uint16_t* idx, const T* src, uint32_t id, uint32_t size) { dst[id] = id < size ? src[id] : CompareLess::template max(); idx[id] = id; } //! 
load values template -__device__ __forceinline__ void safe_load1(T* dst, const T* src, uint32_t id, - uint32_t size) { +__device__ __forceinline__ void safe_load1( + T* dst, const T* src, uint32_t id, uint32_t size) { // broadcast last value to avoid out-of-bound values (for example, when // input contains NaN) dst[id] = src[min(id, size - 1)]; @@ -45,8 +45,8 @@ __device__ __forceinline__ void safe_load1(T* dst, const T* src, uint32_t id, //! write keys template -__device__ __forceinline__ void safe_write0(T* dst, const T* src, uint32_t id, - uint32_t size) { +__device__ __forceinline__ void safe_write0( + T* dst, const T* src, uint32_t id, uint32_t size) { if (id < size) { dst[id] = src[id]; } @@ -54,9 +54,8 @@ __device__ __forceinline__ void safe_write0(T* dst, const T* src, uint32_t id, //! write values template -__device__ __forceinline__ void safe_write1(T* dst, const T* src, - const uint16_t* remap, uint32_t id, - uint32_t size) { +__device__ __forceinline__ void safe_write1( + T* dst, const T* src, const uint16_t* remap, uint32_t id, uint32_t size) { if (id < size) { dst[id] = src[remap[id]]; } @@ -97,8 +96,7 @@ struct NumTrait { struct LessThan { template - static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, - Value v1) { + static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, Value v1) { return k0 < k1 | ((k0 == k1) & (v0 < v1)); } @@ -110,8 +108,7 @@ struct LessThan { struct GreaterThan { template - static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, - Value v1) { + static __device__ __forceinline__ bool cmp(Key k0, Value v0, Key k1, Value v1) { return k0 > k1 | ((k0 == k1) & (v0 < v1)); } @@ -141,11 +138,12 @@ static int get_shmem(int block_size, void* = NULL) { * * where N / 4 == 1 << nr_th_log2 */ -template -static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp, - const Value* value_inp, Key* key_out, - Value* value_out) { +template < + class Sync, typename Key, typename Value, class CompareLess, + uint32_t nr_th_log2> +static __global__ void kern( + uint32_t batch, uint32_t length, const Key* key_inp, const Value* value_inp, + Key* key_out, Value* value_out) { const uint32_t nr_th = 1 << nr_th_log2; // 24KiB shared memory for 4-byte keys for 1024 threads @@ -168,10 +166,8 @@ static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp, cur_length = cur_batch < batch ? length : 0; safe_load0(keys, values, key_inp, tid0, cur_length); safe_load0(keys, values, key_inp, tid0 + nr_th, cur_length); - safe_load0(keys, values, key_inp, tid0 + nr_th * 2, - cur_length); - safe_load0(keys, values, key_inp, tid0 + nr_th * 3, - cur_length); + safe_load0(keys, values, key_inp, tid0 + nr_th * 2, cur_length); + safe_load0(keys, values, key_inp, tid0 + nr_th * 3, cur_length); Sync::s(); @@ -192,12 +188,10 @@ static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp, for (uint32_t slen_log = 0; slen_log <= (nr_th_log2 + 1); ++slen_log) { // log2 of half of current bitonic sequence (i.e. 
length of its // monotonic part) - uint32_t asc0 = !((tid0 >> slen_log) & 1), - asc1 = !((tid1 >> slen_log) & 1); + uint32_t asc0 = !((tid0 >> slen_log) & 1), asc1 = !((tid1 >> slen_log) & 1); #pragma unroll for (uint32_t j = 0; j <= slen_log; ++j) { - uint32_t step = 1 << (slen_log - j), xmask = step - 1, - ymask = ~xmask; + uint32_t step = 1 << (slen_log - j), xmask = step - 1, ymask = ~xmask; WORK((tid0 & xmask) + ((tid0 & ymask) << 1), asc0); WORK((tid1 & xmask) + ((tid1 & ymask) << 1), asc1); Sync::s(); @@ -230,25 +224,25 @@ static __global__ void kern(uint32_t batch, uint32_t length, const Key* key_inp, } // namespace bitonic_sort_impl template -cudaError_t cuda::bitonic_sort(uint32_t batch, uint32_t length, - const Key* key_inp, const Value* value_inp, - Key* key_out, Value* value_out, bool ascending, - cudaStream_t stream) { +cudaError_t cuda::bitonic_sort( + uint32_t batch, uint32_t length, const Key* key_inp, const Value* value_inp, + Key* key_out, Value* value_out, bool ascending, cudaStream_t stream) { using namespace bitonic_sort_impl; if (length == 1) { if (key_inp != key_out) { - cudaMemcpyAsync(key_out, key_inp, sizeof(Key) * batch, - cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync( + key_out, key_inp, sizeof(Key) * batch, cudaMemcpyDeviceToDevice, + stream); } if (value_inp != value_out) { - cudaMemcpyAsync(value_out, value_inp, sizeof(Value) * batch, - cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync( + value_out, value_inp, sizeof(Value) * batch, + cudaMemcpyDeviceToDevice, stream); } return cudaGetLastError(); } - void (*kptr)(uint32_t, uint32_t, const Key*, const Value*, Key*, Value*) = - NULL; + void (*kptr)(uint32_t, uint32_t, const Key*, const Value*, Key*, Value*) = NULL; uint32_t l4 = (length + 3) / 4; dim3 block; @@ -288,8 +282,8 @@ cudaError_t cuda::bitonic_sort(uint32_t batch, uint32_t length, } int suggested_block_size = - query_launch_config_for_kernel(reinterpret_cast(kptr), - get_shmem) + query_launch_config_for_kernel( + reinterpret_cast(kptr), get_shmem) .block_size; block.y = std::max(suggested_block_size / block.x, 1); int shmem = get_shmem(block.y * block.x); @@ -301,18 +295,16 @@ cudaError_t cuda::bitonic_sort(uint32_t batch, uint32_t length, namespace megdnn { namespace cuda { -#define INST(k, v) \ - template cudaError_t bitonic_sort(uint32_t, uint32_t, const k*, \ - const v*, k*, v*, bool, \ - cudaStream_t) +#define INST(k, v) \ + template cudaError_t bitonic_sort( \ + uint32_t, uint32_t, const k*, const v*, k*, v*, bool, cudaStream_t) INST(float, int); INST(int32_t, int); DNN_INC_FLOAT16(INST(dt_float16, int)); #undef INST -} // namespace megdnn +} // namespace cuda } // namespace megdnn // vim: ft=cuda syntax=cuda.doxygen - diff --git a/dnn/src/cuda/argsort/bitonic_sort.cuh b/dnn/src/cuda/argsort/bitonic_sort.cuh index 9f5f3171..6950369e 100644 --- a/dnn/src/cuda/argsort/bitonic_sort.cuh +++ b/dnn/src/cuda/argsort/bitonic_sort.cuh @@ -27,12 +27,11 @@ const uint32_t BITONIC_SORT_MAX_LENGTH = 2048; * and \p key_out can be identical, and so are \p value_inp and \p value_out. 
*/ template -cudaError_t bitonic_sort(uint32_t batch, uint32_t length, const Key* key_inp, - const Value* value_inp, Key* key_out, Value* value_out, - bool ascending, cudaStream_t stream); +cudaError_t bitonic_sort( + uint32_t batch, uint32_t length, const Key* key_inp, const Value* value_inp, + Key* key_out, Value* value_out, bool ascending, cudaStream_t stream); } // namespace cuda } // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/argsort/opr_impl.cpp b/dnn/src/cuda/argsort/opr_impl.cpp index b18e5a4a..29b15f47 100644 --- a/dnn/src/cuda/argsort/opr_impl.cpp +++ b/dnn/src/cuda/argsort/opr_impl.cpp @@ -19,9 +19,9 @@ using namespace megdnn; using namespace cuda; -void ArgsortForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_tensor_out indices, - _megdnn_workspace workspace) { +void ArgsortForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, + _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, indices.layout, workspace.size); auto M = src.layout.shape[0], N = src.layout.shape[1]; auto iptr = indices.ptr(); @@ -29,50 +29,47 @@ void ArgsortForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, bool is_ascending = (param().order == Order::ASCENDING); auto stream = cuda_stream(this->handle()); switch (src.layout.dtype.enumv()) { -#define cb(t) \ - case DTypeTrait::enumv: \ - argsort::forward(src.ptr(), dst.ptr(), iptr, wptr, M, N, \ - is_ascending, stream); \ +#define cb(t) \ + case DTypeTrait::enumv: \ + argsort::forward( \ + src.ptr(), dst.ptr(), iptr, wptr, M, N, is_ascending, stream); \ break; ARGSORT_FOREACH_CTYPE(cb); #undef cb default: - megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", - src.layout.dtype.name())); + megdnn_throw(ssprintf( + "unsupported argsort dtype on cuda: %s", src.layout.dtype.name())); } } -size_t ArgsortForwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout&, - const TensorLayout&) { - megdnn_assert(src.ndim == 2, "invalid src layout: %s", - src.to_string().c_str()); +size_t ArgsortForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout&, const TensorLayout&) { + megdnn_assert(src.ndim == 2, "invalid src layout: %s", src.to_string().c_str()); auto M = src.shape[0], N = src.shape[1]; auto&& dtype = src.dtype; - megdnn_assert(std::max(M, N) <= - static_cast(std::numeric_limits::max())); + megdnn_assert( + std::max(M, N) <= static_cast(std::numeric_limits::max())); return argsort::get_fwd_workspace_in_bytes( M, N, dtype, param().order == Param::Order::ASCENDING); } -void ArgsortBackwardImpl::exec(_megdnn_tensor_in diff, - _megdnn_tensor_in indices, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void ArgsortBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(diff.layout, indices.layout, grad.layout, workspace.size); auto stream = cuda_stream(this->handle()); switch (diff.layout.dtype.enumv()) { -#define cb(t) \ - case DTypeTrait::enumv: \ - argsort::backward_proxy(grad.layout[0], grad.layout[1], \ - diff.layout[1], grad.ptr(), diff.ptr(), \ - indices.ptr(), stream); \ +#define cb(t) \ + case DTypeTrait::enumv: \ + argsort::backward_proxy( \ + grad.layout[0], grad.layout[1], diff.layout[1], grad.ptr(), \ + diff.ptr(), indices.ptr(), stream); \ break; ARGSORT_FOREACH_CTYPE(cb); #undef cb default: - megdnn_throw(ssprintf("unsupported argsort dtype on cuda: %s", - 
diff.layout.dtype.name())); + megdnn_throw(ssprintf( + "unsupported argsort dtype on cuda: %s", diff.layout.dtype.name())); } } diff --git a/dnn/src/cuda/argsort/opr_impl.h b/dnn/src/cuda/argsort/opr_impl.h index 82d5a41f..b2abe158 100644 --- a/dnn/src/cuda/argsort/opr_impl.h +++ b/dnn/src/cuda/argsort/opr_impl.h @@ -14,34 +14,30 @@ namespace megdnn { namespace cuda { -class ArgsortForwardImpl final: public ArgsortForward { - public: - using ArgsortForward::ArgsortForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_tensor_out indices, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst, - const TensorLayout &indices) override; +class ArgsortForwardImpl final : public ArgsortForward { +public: + using ArgsortForward::ArgsortForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& indices) override; }; -class ArgsortBackwardImpl final: public ArgsortBackward { - public: - using ArgsortBackward::ArgsortBackward; - void exec(_megdnn_tensor_in diff, - _megdnn_tensor_in indices, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override { - return 0; - } +class ArgsortBackwardImpl final : public ArgsortBackward { +public: + using ArgsortBackward::ArgsortBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/atomic_add.cuh b/dnn/src/cuda/atomic_add.cuh index adab0ee6..4d7b4c6d 100644 --- a/dnn/src/cuda/atomic_add.cuh +++ b/dnn/src/cuda/atomic_add.cuh @@ -35,22 +35,19 @@ template <> MEGDNN_DEVICE void atomic_add(dt_float16* address, dt_float16 val) { #if (__CUDA_ARCH__ < 700 || __CUDACC_VER_MAJOR__ <= 9) unsigned int* address_as_ui = reinterpret_cast( - reinterpret_cast(address) - - (reinterpret_cast(address) & 2)); + reinterpret_cast(address) - (reinterpret_cast(address) & 2)); unsigned int old = *address_as_ui; unsigned int assumed; do { assumed = old; - unsigned short data = reinterpret_cast(address) & 2 - ? (old >> 16) - : (old & 0xffff); + unsigned short data = + reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); dt_float16 hsum = *reinterpret_cast(&data); hsum += val; data = *reinterpret_cast(&hsum); - old = reinterpret_cast(address) & 2 - ? (old & 0xffff) | (data << 16) - : (old & 0xffff0000) | data; + old = reinterpret_cast(address) & 2 ? 
(old & 0xffff) | (data << 16) + : (old & 0xffff0000) | data; old = ::atomicCAS(address_as_ui, assumed, old); } while (assumed != old); #else @@ -61,22 +58,19 @@ MEGDNN_DEVICE void atomic_add(dt_float16* address, dt_float16 val) { template <> MEGDNN_DEVICE void atomic_add(dt_bfloat16* address, dt_bfloat16 val) { unsigned int* address_as_ui = reinterpret_cast( - reinterpret_cast(address) - - (reinterpret_cast(address) & 2)); + reinterpret_cast(address) - (reinterpret_cast(address) & 2)); unsigned int old = *address_as_ui; unsigned int assumed; do { assumed = old; - unsigned short data = reinterpret_cast(address) & 2 - ? (old >> 16) - : (old & 0xffff); + unsigned short data = + reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); dt_bfloat16 hsum = *reinterpret_cast(&data); hsum += val; data = *reinterpret_cast(&hsum); - old = reinterpret_cast(address) & 2 - ? (old & 0xffff) | (data << 16) - : (old & 0xffff0000) | data; + old = reinterpret_cast(address) & 2 ? (old & 0xffff) | (data << 16) + : (old & 0xffff0000) | data; old = ::atomicCAS(address_as_ui, assumed, old); } while (assumed != old); } @@ -99,8 +93,8 @@ struct AtomicAddIntegerImpl { old_byte = (old >> shift) & 0xff; // preserve size in initial cast. Casting directly to uint32_t pads // negative signed values with 1's (e.g. signed -1 = unsigned ~0). - newval = static_cast(static_cast(val) + - static_cast(old_byte)); + newval = static_cast( + static_cast(val) + static_cast(old_byte)); // newval = static_cast(THCNumerics::add(val, // old_byte)); newval = (old & ~(0x000000ff << shift)) | (newval << shift); @@ -124,8 +118,8 @@ struct AtomicAddIntegerImpl { old_bytes = is_32_align ? old >> 16 : old & 0xffff; // preserve size in initial cast. Casting directly to uint32_t pads // negative signed values with 1's (e.g. signed -1 = unsigned ~0). - newval = static_cast(static_cast(val) + - static_cast(old_bytes)); + newval = static_cast( + static_cast(val) + static_cast(old_bytes)); // newval = static_cast(THCNumerics::add(val, // old_bytes)); newval = is_32_align ? 
(old & 0xffff) | (newval << 16) diff --git a/dnn/src/cuda/batch_conv_bias/algo.cpp b/dnn/src/cuda/batch_conv_bias/algo.cpp index 91425147..ba359611 100644 --- a/dnn/src/cuda/batch_conv_bias/algo.cpp +++ b/dnn/src/cuda/batch_conv_bias/algo.cpp @@ -30,8 +30,8 @@ BatchConvBiasForwardImpl::AlgoPack BatchConvBiasForwardImpl::sm_algo_pack; BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( BatchConvBiasForwardImpl* o, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& bias, - const TensorLayout& z, const TensorLayout& dst) + const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) : opr{o}, src_layout{src}, filter_layout{filter}, @@ -40,11 +40,10 @@ BatchConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( dst_layout{dst} {} BatchConvBiasForwardImpl::AlgoBase::ExecArgs::ExecArgs( - BatchConvBiasForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, _megdnn_workspace workspace) - : SizeArgs(opr, src.layout, filter.layout, bias.layout, z.layout, - dst.layout), + BatchConvBiasForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_in bias, _megdnn_tensor_in z, _megdnn_tensor_out dst, + _megdnn_workspace workspace) + : SizeArgs(opr, src.layout, filter.layout, bias.layout, z.layout, dst.layout), src_tensor{&src}, filter_tensor{&filter}, bias_tensor{&bias}, @@ -61,11 +60,11 @@ std::string BatchConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const { "dtype=(%s(src),%s(flt),%s(bias),%s(z))->(%s(dst))", src_layout.to_string().c_str(), filter_layout.to_string().c_str(), bias_layout.to_string().c_str(), z_layout.to_string().c_str(), - dst_layout.to_string().c_str(), param.pad_h, param.pad_w, - param.stride_h, param.stride_w, param.dilate_h, param.dilate_w, + dst_layout.to_string().c_str(), param.pad_h, param.pad_w, param.stride_h, + param.stride_w, param.dilate_h, param.dilate_w, static_cast(param.mode), src_layout.dtype.name(), - filter_layout.dtype.name(), bias_layout.dtype.name(), - z_layout.dtype.name(), dst_layout.dtype.name()); + filter_layout.dtype.name(), bias_layout.dtype.name(), z_layout.dtype.name(), + dst_layout.dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/batch_conv_bias/algo.h b/dnn/src/cuda/batch_conv_bias/algo.h index be5cf602..359dada7 100644 --- a/dnn/src/cuda/batch_conv_bias/algo.h +++ b/dnn/src/cuda/batch_conv_bias/algo.h @@ -39,23 +39,23 @@ public: AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } struct SizeArgs { BatchConvBiasForwardImpl* opr; - TensorLayout src_layout, filter_layout, bias_layout, z_layout, - dst_layout; + TensorLayout src_layout, filter_layout, bias_layout, z_layout, dst_layout; std::string to_string() const; - SizeArgs(BatchConvBiasForwardImpl* opr, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& bias, - const TensorLayout& z, const TensorLayout& dst); + SizeArgs( + BatchConvBiasForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst); }; struct ExecArgs : public SizeArgs { const TensorND *src_tensor, *filter_tensor, *bias_tensor, *z_tensor, *dst_tensor; Workspace workspace; - ExecArgs(BatchConvBiasForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_in bias, - _megdnn_tensor_in z, _megdnn_tensor_out dst, - _megdnn_workspace workspace); + ExecArgs( + BatchConvBiasForwardImpl* opr, 
_megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_in bias, _megdnn_tensor_in z, + _megdnn_tensor_out dst, _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -71,31 +71,27 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "batch conv bias fwd algo %s: required workspace %zu " - "bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "batch conv bias fwd algo %s: required workspace %zu " + "bytes, got %zu", + name(), req, workspace.size); return *this; } }; -class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm final - : public AlgoBase { +class BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; @@ -110,9 +106,7 @@ public: size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; @@ -120,8 +114,7 @@ public: MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_PRECOMP_NCHW4_DOTPROD_INT8) private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; }; class BatchConvBiasForwardImpl::AlgoPack : NonCopyableObj { diff --git a/dnn/src/cuda/batch_conv_bias/batch_conv_bias.cuh b/dnn/src/cuda/batch_conv_bias/batch_conv_bias.cuh index c434704f..44c8a099 100644 --- a/dnn/src/cuda/batch_conv_bias/batch_conv_bias.cuh +++ b/dnn/src/cuda/batch_conv_bias/batch_conv_bias.cuh @@ -34,33 +34,30 @@ struct LaunchConfig { }; template -void do_batch_conv_bias_int8_gemm_ncdiv4hw4(const int8_t* d_src, - const int8_t* d_filter, - BiasVisitor bias, Epilogue epilogue, - const convolution::ConvParam& param, - float alpha, float beta, - cudaStream_t stream); +void do_batch_conv_bias_int8_gemm_ncdiv4hw4( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, 
float beta, + cudaStream_t stream); template void do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4( - const int8_t* d_src, const int8_t* d_filter, int* workspace, - BiasVisitor bias, Epilogue epilogue, - const convolution::ConvParam& param, float alpha, float beta, + const int8_t* d_src, const int8_t* d_filter, int* workspace, BiasVisitor bias, + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, cudaStream_t stream); } // namespace batch_conv_bias } // namespace cuda } // namespace megdnn -#define MARK_USED_VAR \ - MEGDNN_MARK_USED_VAR(n + ci + hi + wi + co + fh + fw + ho + wo + ph + pw + \ - sh + sw + dh + dw); +#define MARK_USED_VAR \ + MEGDNN_MARK_USED_VAR( \ + n + ci + hi + wi + co + fh + fw + ho + wo + ph + pw + sh + sw + dh + dw); #define UNPACK_BATCH_CONV_PARAMETER(_param) \ size_t ph = _param.pad_h, pw = _param.pad_w; \ diff --git a/dnn/src/cuda/batch_conv_bias/gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/batch_conv_bias/gemm_int8_nchw4_dp4a.cpp index 433c2c23..348ebb98 100644 --- a/dnn/src/cuda/batch_conv_bias/gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/batch_conv_bias/gemm_int8_nchw4_dp4a.cpp @@ -25,52 +25,50 @@ using namespace cuda; using namespace convolution; namespace { template -void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, Epilogue epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream) { - void (*kern_wrapper)(const int8_t*, const int8_t*, BiasVisitor, Epilogue, - const ConvParam&, float, float, cudaStream_t); +void dispatch_kernel( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + Epilogue epilogue, const ConvParam& param, float alpha, float beta, + cudaStream_t stream) { + void (*kern_wrapper)( + const int8_t*, const int8_t*, BiasVisitor, Epilogue, const ConvParam&, + float, float, cudaStream_t); using namespace batch_conv_bias; int img_pixels = param.ho * param.wo; if (img_pixels % 4 == 0) { kern_wrapper = - do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128; + do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128; } else { - kern_wrapper = - do_batch_conv_bias_int8_gemm_ncdiv4hw4; + kern_wrapper = do_batch_conv_bias_int8_gemm_ncdiv4hw4; } megdnn_assert(kern_wrapper != nullptr); - return kern_wrapper(d_src, d_filter, bias_visitor, epilogue, param, alpha, - beta, stream); + return kern_wrapper( + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream); } template -void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, const int8_t* d_z, - int8_t* d_dst, const ConvParam& param, float alpha, - float beta, float gamma, float scale, - cudaStream_t stream, - param::BatchConvBias::NonlineMode nonlinear_mode) { +void dispatch_nonlinear_mode( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + const int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, + float beta, float gamma, float scale, cudaStream_t stream, + param::BatchConvBias::NonlineMode nonlinear_mode) { using NonlineMode = megdnn::param_enumv::BatchConvBias::NonlineMode; Layout layout; layout.init(param.n, param.co, param.ho, param.wo); using namespace batch_conv_bias; -#define DISPATCH_CONV_INT8_EPILOGUE(_act_op) \ - do { \ - IConvEpilogue<_act_op> epilogue{d_dst, \ - d_z, \ - layout.batch_stride, \ - layout.channel_stride / 4, \ - layout.height_stride, \ - layout.width_stride, \ - gamma, \ - _act_op{scale, 1.f / scale}}; \ - dispatch_kernel>( \ - d_src, d_filter, bias_visitor, epilogue, param, alpha, 
beta, \ - stream); \ - return; \ +#define DISPATCH_CONV_INT8_EPILOGUE(_act_op) \ + do { \ + IConvEpilogue<_act_op> epilogue{ \ + d_dst, \ + d_z, \ + layout.batch_stride, \ + layout.channel_stride / 4, \ + layout.height_stride, \ + layout.width_stride, \ + gamma, \ + _act_op{scale, 1.f / scale}}; \ + dispatch_kernel>( \ + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream); \ + return; \ } while (0) #define cb(_nonline_mode) \ if (static_cast(nonlinear_mode) == NonlineMode::_nonline_mode) { \ @@ -82,12 +80,11 @@ void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, #undef DISPATCH_CONV_INT8_EPILOGUE } -#define INST(_visitor) \ - template void dispatch_nonlinear_mode<_visitor>( \ - const int8_t* d_src, const int8_t* d_filter, \ - _visitor bias_visitor, const int8_t* d_z, int8_t* d_dst, \ - const ConvParam& param, float alpha, float beta, float gamma, \ - float scale, cudaStream_t stream, \ +#define INST(_visitor) \ + template void dispatch_nonlinear_mode<_visitor>( \ + const int8_t* d_src, const int8_t* d_filter, _visitor bias_visitor, \ + const int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, \ + float beta, float gamma, float scale, cudaStream_t stream, \ param::BatchConvBias::NonlineMode nonlinear_mode); INST(PerChannelBiasVisitor); @@ -110,33 +107,31 @@ bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm::is_available( return false; if (param.format != Format::NCHW4) return false; - UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM(args.src_layout, args.filter_layout, - args.dst_layout, param); + UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM( + args.src_layout, args.filter_layout, args.dst_layout, param); // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout.dtype, - filter_dtype = args.filter_layout.dtype, + auto src_dtype = args.src_layout.dtype, filter_dtype = args.filter_layout.dtype, bias_dtype = args.bias_layout.dtype, dst_dtype = args.dst_layout.dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // TODO: support dialtion available &= dh == 1 && dw == 1; // can be treat as gemm - available &= - (fh == 1 && sh == 1 && fw == 1 && sw == 1 && ph == 0 && pw == 0); + available &= (fh == 1 && sh == 1 && fw == 1 && sw == 1 && ph == 0 && pw == 0); // only support sm_61 or later, platform should have fast native int8 // support available &= is_compute_capability_required(6, 1); return available; } -size_t -BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm::get_workspace_in_bytes( +size_t BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm::get_workspace_in_bytes( const SizeArgs& /* args */) const { return 0; } @@ -145,25 +140,21 @@ void BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm::exec( const ExecArgs& args) const { using Format = Param::Format; auto&& param = args.opr->param(); - UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM(args.src_layout, args.filter_layout, - args.dst_layout, param); + UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM( + args.src_layout, args.filter_layout, args.dst_layout, param); auto&& stream = 
cuda_stream(args.opr->handle()); ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, + kern_param.pw = pw, kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, kern_param.fw = fw; float src_scale = args.src_layout.dtype.param().scale, - filter_scale = - args.filter_layout.dtype.param().scale, - bias_scale = - args.bias_layout.dtype.param().scale, + filter_scale = args.filter_layout.dtype.param().scale, + bias_scale = args.bias_layout.dtype.param().scale, dst_scale = args.dst_layout.dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; float gamma = 1.f; if (args.z_layout.ndim > 0) { @@ -175,9 +166,9 @@ void BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm::exec( bias_visitor.bias = args.bias_tensor->compatible_ptr(); dispatch_nonlinear_mode( args.src_tensor->compatible_ptr(), - args.filter_tensor->compatible_ptr(), bias_visitor, - z_dev_ptr, args.dst_tensor->compatible_ptr(), kern_param, - alpha, beta, gamma, dst_scale, stream, param.nonlineMode); + args.filter_tensor->compatible_ptr(), bias_visitor, z_dev_ptr, + args.dst_tensor->compatible_ptr(), kern_param, alpha, beta, gamma, + dst_scale, stream, param.nonlineMode); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/batch_conv_bias/helper.cu b/dnn/src/cuda/batch_conv_bias/helper.cu index 849fe1e1..f8b0f087 100644 --- a/dnn/src/cuda/batch_conv_bias/helper.cu +++ b/dnn/src/cuda/batch_conv_bias/helper.cu @@ -17,8 +17,8 @@ using namespace cuda; using namespace batch_conv_bias; namespace { -__global__ void kern_compute_offset(int* __restrict__ offset, - const convolution::ConvParam param) { +__global__ void kern_compute_offset( + int* __restrict__ offset, const convolution::ConvParam param) { const int tid = threadIdx.x + blockDim.x * blockIdx.x; const int img_pixels = param.ho * param.wo; const int img_pixels_ru128 = DIVUP(img_pixels, 128) * 128; @@ -33,8 +33,7 @@ __global__ void kern_compute_offset(int* __restrict__ offset, const int kw = filter_idx - param.fw * kh; const int ih = param.sh * oh - param.ph + kh; const int iw = param.sw * ow - param.pw + kw; - if (img_idx < img_pixels && ih >= 0 && ih < param.hi && iw >= 0 && - iw < param.wi) { + if (img_idx < img_pixels && ih >= 0 && ih < param.hi && iw >= 0 && iw < param.wi) { offset[tid] = ih * param.wi + iw; } else { offset[tid] = -1; diff --git a/dnn/src/cuda/batch_conv_bias/helper.cuh b/dnn/src/cuda/batch_conv_bias/helper.cuh index 4434a675..920c026c 100644 --- a/dnn/src/cuda/batch_conv_bias/helper.cuh +++ b/dnn/src/cuda/batch_conv_bias/helper.cuh @@ -15,9 +15,9 @@ namespace megdnn { namespace cuda { namespace batch_conv_bias { -void compute_offset(int* offset, const convolution::ConvParam& param, - cudaStream_t stream); -} // namespace batched_conv2d +void compute_offset( + int* offset, const convolution::ConvParam& param, cudaStream_t stream); +} // namespace batch_conv_bias } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/batch_conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp 
b/dnn/src/cuda/batch_conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp index ba32390d..beff2991 100644 --- a/dnn/src/cuda/batch_conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/batch_conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp @@ -26,31 +26,31 @@ using namespace cuda; using namespace convolution; namespace { template -void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, - int* d_workspace, BiasVisitor bias_visitor, - const int8_t* d_z, int8_t* d_dst, - const ConvParam& param, float alpha, float beta, - float gamma, float scale, cudaStream_t stream, - param::BatchConvBias::NonlineMode nonlinear_mode) { +void dispatch_nonlinear_mode( + const int8_t* d_src, const int8_t* d_filter, int* d_workspace, + BiasVisitor bias_visitor, const int8_t* d_z, int8_t* d_dst, + const ConvParam& param, float alpha, float beta, float gamma, float scale, + cudaStream_t stream, param::BatchConvBias::NonlineMode nonlinear_mode) { using NonlineMode = megdnn::param_enumv::BatchConvBias::NonlineMode; Layout layout; layout.init(param.n, param.co, param.ho, param.wo); using namespace batch_conv_bias; -#define DISPATCH_CONV_INT8_EPILOGUE(_act_op) \ - do { \ - IConvEpilogue<_act_op> epilogue{d_dst, \ - d_z, \ - layout.batch_stride, \ - layout.channel_stride / 4, \ - layout.height_stride, \ - layout.width_stride, \ - gamma, \ - _act_op{scale, 1.f / scale}}; \ - do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4< \ - BiasVisitor, IConvEpilogue<_act_op>>( \ - d_src, d_filter, d_workspace, bias_visitor, epilogue, param, \ - alpha, beta, stream); \ - return; \ +#define DISPATCH_CONV_INT8_EPILOGUE(_act_op) \ + do { \ + IConvEpilogue<_act_op> epilogue{ \ + d_dst, \ + d_z, \ + layout.batch_stride, \ + layout.channel_stride / 4, \ + layout.height_stride, \ + layout.width_stride, \ + gamma, \ + _act_op{scale, 1.f / scale}}; \ + do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4< \ + BiasVisitor, IConvEpilogue<_act_op>>( \ + d_src, d_filter, d_workspace, bias_visitor, epilogue, param, alpha, \ + beta, stream); \ + return; \ } while (0) #define cb(_nonline_mode) \ if (static_cast(nonlinear_mode) == NonlineMode::_nonline_mode) { \ @@ -62,21 +62,20 @@ void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, #undef DISPATCH_CONV_INT8_EPILOGUE } -#define INST(_visitor) \ - template void dispatch_nonlinear_mode<_visitor>( \ - const int8_t* d_src, const int8_t* d_filter, int* workspace, \ - _visitor bias_visitor, const int8_t* d_z, int8_t* d_dst, \ - const ConvParam& param, float alpha, float beta, float gamma, \ - float scale, cudaStream_t stream, \ - param::BatchConvBias::NonlineMode nonlinear_mode); +#define INST(_visitor) \ + template void dispatch_nonlinear_mode<_visitor>( \ + const int8_t* d_src, const int8_t* d_filter, int* workspace, \ + _visitor bias_visitor, const int8_t* d_z, int8_t* d_dst, \ + const ConvParam& param, float alpha, float beta, float gamma, float scale, \ + cudaStream_t stream, param::BatchConvBias::NonlineMode nonlinear_mode); INST(PerChannelBiasVisitor); #undef INST } // namespace -bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp:: - is_available(const SizeArgs& args) const { +bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp::is_available( + const SizeArgs& args) const { if (args.bias_layout.ndim <= 0) return false; @@ -90,20 +89,20 @@ bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp:: return false; if (param.format != Format::NCHW4) return false; - 
UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM(args.src_layout, args.filter_layout, - args.dst_layout, param); + UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM( + args.src_layout, args.filter_layout, args.dst_layout, param); // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout.dtype, - filter_dtype = args.filter_layout.dtype, + auto src_dtype = args.src_layout.dtype, filter_dtype = args.filter_layout.dtype, bias_dtype = args.bias_layout.dtype, dst_dtype = args.dst_layout.dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // TODO: support dialtion available &= dh == 1 && dw == 1; // TODO: support fh fw != 1 @@ -117,8 +116,8 @@ bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp:: size_t BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp:: get_workspace_in_bytes(const SizeArgs& args) const { auto&& param = args.opr->param(); - UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM(args.src_layout, args.filter_layout, - args.dst_layout, param); + UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM( + args.src_layout, args.filter_layout, args.dst_layout, param); size_t img_pixels = ho * wo; size_t img_pixels_ru128 = round_up(img_pixels, 128_z); size_t filter_pixels = fh * fw; @@ -129,25 +128,21 @@ void BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp::exec( const ExecArgs& args) const { using Format = Param::Format; auto&& param = args.opr->param(); - UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM(args.src_layout, args.filter_layout, - args.dst_layout, param); + UNPACK_BATCH_CONV_BIAS_NCHW4_PARAM( + args.src_layout, args.filter_layout, args.dst_layout, param); auto&& stream = cuda_stream(args.opr->handle()); ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, + kern_param.pw = pw, kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, kern_param.fw = fw; float src_scale = args.src_layout.dtype.param().scale, - filter_scale = - args.filter_layout.dtype.param().scale, - bias_scale = - args.bias_layout.dtype.param().scale, + filter_scale = args.filter_layout.dtype.param().scale, + bias_scale = args.bias_layout.dtype.param().scale, dst_scale = args.dst_layout.dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; float gamma = 1.f; if (args.z_layout.ndim > 0) { @@ -160,9 +155,9 @@ void BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp::exec( dispatch_nonlinear_mode( args.src_tensor->compatible_ptr(), args.filter_tensor->compatible_ptr(), - reinterpret_cast(args.workspace.raw_ptr), bias_visitor, - z_dev_ptr, 
args.dst_tensor->compatible_ptr(), kern_param, - alpha, beta, gamma, dst_scale, stream, param.nonlineMode); + reinterpret_cast(args.workspace.raw_ptr), bias_visitor, z_dev_ptr, + args.dst_tensor->compatible_ptr(), kern_param, alpha, beta, gamma, + dst_scale, stream, param.nonlineMode); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_hswish.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_hswish.cu index 5e3ba1ac..cacf8146 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_hswish.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias:: + do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_id.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_id.cu index 5413afa3..46e00e3e 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_id.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias:: + do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_relu.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_relu.cu index bfd64cad..92da9af1 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_relu.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float 
alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias:: + do_batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_hswish.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_hswish.cu index 9b4e34f8..b60811dd 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_hswish.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_hswish.cu @@ -1,13 +1,12 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_gemm_ncdiv4hw4.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_id.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_id.cu index c14e0e96..46391d43 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_id.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_id.cu @@ -1,13 +1,12 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_gemm_ncdiv4hw4.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_relu.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_relu.cu index 43caacc9..49ee0efb 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_relu.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_gemm_ncdiv4hw4_per_chan_relu.cu @@ -1,13 +1,11 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_gemm_ncdiv4hw4.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void 
megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_gemm_ncdiv4hw4< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_hswish.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_hswish.cu index 74b7dd7c..0b42b023 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_hswish.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_hswish.cu @@ -1,14 +1,14 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4>>( - const int8_t* d_src, - const int8_t* d_filter, -int* d_workspace, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias:: + do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, int* d_workspace, + PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_id.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_id.cu index 9cedc650..d401fcc6 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_id.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_id.cu @@ -1,14 +1,14 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4.cuinl" -template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4>>( - const int8_t* d_src, - const int8_t* d_filter, -int* d_workspace, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias:: + do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, int* d_workspace, + PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_relu.cu b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_relu.cu index 9de1a406..6106cd22 100644 --- a/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_relu.cu +++ b/dnn/src/cuda/batch_conv_bias/int8/kimpl/batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4_per_chan_relu.cu @@ -1,14 +1,14 @@ // generated by gen_batch_cuda_conv_bias_kern_impls.py #include "../batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4.cuinl" 
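The AlgoInt8NCHW4DotProdImplicitGemmPrecomp::exec path above folds the quantized tensor scales into the alpha, beta and gamma coefficients that these generated per-channel kernels consume. A minimal host-side sketch of that folding, assuming hypothetical scale values (the ConvScales struct and the literals are illustrative and not part of this change):

#include <cstdio>

// Hypothetical scales, for illustration only; in the real exec() they come from
// the QuantizedS8/QuantizedS32 dtype params of src, filter, bias, dst and z.
struct ConvScales {
    float src, filter, bias, dst, z;
};

int main() {
    ConvScales s{0.5f, 0.25f, 0.125f, 1.f, 1.f};
    float alpha = s.src * s.filter / s.dst;  // rescales the int8*int8 accumulator
    float beta = s.bias / s.dst;             // rescales the QuantizedS32 bias
    float gamma = s.z / s.dst;               // rescales the optional residual z (stays 1.f when z is absent)
    std::printf("alpha=%g beta=%g gamma=%g\n", alpha, beta, gamma);
    return 0;
}

With these coefficients the int32 accumulator, the QuantizedS32 bias and the optional residual z input are all rescaled into the destination QuantizedS8 domain before the activation epilogue runs.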
-template void megdnn::cuda::batch_conv_bias::do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4>>( - const int8_t* d_src, - const int8_t* d_filter, -int* d_workspace, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::batch_conv_bias:: + do_batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, int* d_workspace, + PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/batch_conv_bias/opr_impl.cpp b/dnn/src/cuda/batch_conv_bias/opr_impl.cpp index 5e5ff3af..f3bb821d 100644 --- a/dnn/src/cuda/batch_conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/batch_conv_bias/opr_impl.cpp @@ -17,13 +17,10 @@ using namespace megdnn; using namespace cuda; /* ============== BatchConvBiasForwardImpl ============== */ -BatchConvBiasForwardImpl::Algorithm* -BatchConvBiasForwardImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +BatchConvBiasForwardImpl::Algorithm* BatchConvBiasForwardImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, src, filter, bias, z, dst); if (sm_algo_pack.int8_nchw4_gemm_dotprod.is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { @@ -33,51 +30,47 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.int8_nchw4_implicit_gemm_dotprod; } - megdnn_throw( - ssprintf("no batch conv bias algorithm without attribute(%s) with " - "attribute(%s) args(%s) and " - "workspace limit (%zu bytes)", - Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); + megdnn_throw(ssprintf( + "no batch conv bias algorithm without attribute(%s) with " + "attribute(%s) args(%s) and " + "workspace limit (%zu bytes)", + Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); } -std::vector -BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst) { +std::vector BatchConvBiasForwardImpl:: + get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; return megdnn::get_all_algorithms(args); } -std::vector -BatchConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst) { +std::vector BatchConvBiasForwardImpl:: + get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const 
TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; return megdnn::get_all_algorithms_safe(args); } size_t BatchConvBiasForwardImpl::get_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst) { + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst) { return get_dnn_workspace(this, src, filter, bias, z, dst); } -void BatchConvBiasForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - check_exec(src.layout, filter.layout, bias.layout, z.layout, dst.layout, - workspace.size); +void BatchConvBiasForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, + _megdnn_tensor_in z, _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec( + src.layout, filter.layout, bias.layout, z.layout, dst.layout, + workspace.size); AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace); - auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout, - z.layout, dst.layout); + auto algo = get_algorithm( + this, src.layout, filter.layout, bias.layout, z.layout, dst.layout); algo->exec(args); } diff --git a/dnn/src/cuda/batch_conv_bias/opr_impl.h b/dnn/src/cuda/batch_conv_bias/opr_impl.h index 29ab7a30..9c731ae8 100644 --- a/dnn/src/cuda/batch_conv_bias/opr_impl.h +++ b/dnn/src/cuda/batch_conv_bias/opr_impl.h @@ -18,14 +18,14 @@ namespace cuda { class BatchConvBiasForwardImpl : public BatchConvBiasForward { public: using BatchConvBiasForward::BatchConvBiasForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, + _megdnn_tensor_in z, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -48,9 +48,8 @@ protected: const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) override; private: diff --git a/dnn/src/cuda/batch_normalization/opr_impl.cpp b/dnn/src/cuda/batch_normalization/opr_impl.cpp index c0606a8d..999b9896 100644 --- a/dnn/src/cuda/batch_normalization/opr_impl.cpp +++ b/dnn/src/cuda/batch_normalization/opr_impl.cpp @@ -17,9 +17,8 @@ namespace cuda { namespace batch_normalization { -BNTensorDescHolder::BNTensorDescHolder(const TensorLayout& x, - const ParamDim& param_dim, - const FwdMode& fwd_mode) { +BNTensorDescHolder::BNTensorDescHolder( + const TensorLayout& x, const ParamDim& param_dim, 
const FwdMode& fwd_mode) { TensorShape xy_shape(x); Format xy_format = Format::NCHW; @@ -52,8 +51,8 @@ BNTensorDescHolder::BNTensorDescHolder(const TensorLayout& x, param_desc.set(xy_desc.desc, bn_mode); } -size_t get_reserve_size(const cudnnHandle_t& handle, - const BNTensorDescHolder& tensor_desc) { +size_t get_reserve_size( + const cudnnHandle_t& handle, const BNTensorDescHolder& tensor_desc) { #if CUDNN_VERSION >= 7410 size_t reserve_size; cudnn_check(cudnnGetBatchNormalizationTrainingExReserveSpaceSize( @@ -95,59 +94,55 @@ size_t BNForwardImpl::get_workspace_in_bytes( size_t BNForwardImpl::get_reserve_in_bytes(const TensorLayout& src) { BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); - return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), - tensor_desc); + return batch_normalization::get_reserve_size( + cudnn_handle(this->handle()), tensor_desc); } -void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, - _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, - _megdnn_tensor_out variance, - _megdnn_tensor_out batch_mean, - _megdnn_tensor_out batch_inv_variance, - _megdnn_tensor_out reserve, _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - check_exec(src.layout, bn_scale.layout, bn_bias.layout, mean.layout, - variance.layout, batch_mean.layout, batch_inv_variance.layout, - dst.layout, workspace.size, reserve.layout.access_bytes()); +void BNForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in bn_scale, _megdnn_tensor_in bn_bias, + _megdnn_tensor_out mean, _megdnn_tensor_out variance, + _megdnn_tensor_out batch_mean, _megdnn_tensor_out batch_inv_variance, + _megdnn_tensor_out reserve, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec( + src.layout, bn_scale.layout, bn_bias.layout, mean.layout, variance.layout, + batch_mean.layout, batch_inv_variance.layout, dst.layout, workspace.size, + reserve.layout.access_bytes()); auto handle = cudnn_handle(this->handle()); - BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim, - m_param.fwd_mode); + BNTensorDescHolder tensor_desc(src.layout, m_param.param_dim, m_param.fwd_mode); float alpha = 1.0f, beta = 0.0f; switch (m_param.fwd_mode) { case param::BN::FwdMode::TRAINING: #if CUDNN_VERSION >= 7410 cudnn_check(cudnnBatchNormalizationForwardTrainingEx( - handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, - &alpha, &beta, // one & zero + handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, + &beta, // one & zero tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x nullptr, nullptr, // zDesc & z tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y - tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc - bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, - mean.raw_ptr, variance.raw_ptr, m_param.epsilon, - batch_mean.raw_ptr, batch_inv_variance.raw_ptr, nullptr, - workspace.raw_ptr, workspace.size, reserve.raw_ptr, - reserve.layout.access_bytes())); + tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc + bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, mean.raw_ptr, + variance.raw_ptr, m_param.epsilon, batch_mean.raw_ptr, + batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr, + workspace.size, reserve.raw_ptr, reserve.layout.access_bytes())); #else cudnn_check(cudnnBatchNormalizationForwardTraining( handle, tensor_desc.bn_mode, &alpha, &beta, tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y - tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc - 
bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, - mean.raw_ptr, variance.raw_ptr, m_param.epsilon, - batch_mean.raw_ptr, batch_inv_variance.raw_ptr)); + tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc + bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, mean.raw_ptr, + variance.raw_ptr, m_param.epsilon, batch_mean.raw_ptr, + batch_inv_variance.raw_ptr)); #endif // CUDNN_VERSION >= 7410 break; case param::BN::FwdMode::INFERENCE: cudnn_check(cudnnBatchNormalizationForwardInference( handle, tensor_desc.bn_mode, &alpha, &beta, - tensor_desc.xy_desc.desc, src.raw_ptr, - tensor_desc.xy_desc.desc, dst.raw_ptr, - tensor_desc.param_desc.desc, bn_scale.raw_ptr, - bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, - m_param.epsilon)); + tensor_desc.xy_desc.desc, src.raw_ptr, tensor_desc.xy_desc.desc, + dst.raw_ptr, tensor_desc.param_desc.desc, bn_scale.raw_ptr, + bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, m_param.epsilon)); break; default: megdnn_throw("Unknown forward mode type of batch normalization."); @@ -181,30 +176,28 @@ size_t BNBackwardImpl::get_workspace_in_bytes( size_t BNBackwardImpl::get_reserve_in_bytes(const TensorLayout& src) { BNTensorDescHolder tensor_desc(src, m_param.param_dim, m_param.fwd_mode); - return batch_normalization::get_reserve_size(cudnn_handle(this->handle()), - tensor_desc); + return batch_normalization::get_reserve_size( + cudnn_handle(this->handle()), tensor_desc); } -void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, - _megdnn_tensor_in saved_batch_mean, - _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, - _megdnn_tensor_out d_bn_scale, - _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, - _megdnn_workspace workspace) { - check_exec(x.layout, dy.layout, saved_batch_mean.layout, - saved_batch_inv_variance.layout, bn_scale.layout, - d_bn_scale.layout, d_bn_bias.layout, dx.layout, workspace.size, - reserve.layout.access_bytes()); +void BNBackwardImpl::exec( + _megdnn_tensor_in x, _megdnn_tensor_in dy, _megdnn_tensor_in saved_batch_mean, + _megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in bn_scale, + _megdnn_tensor_in reserve, _megdnn_tensor_out d_bn_scale, + _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, + _megdnn_workspace workspace) { + check_exec( + x.layout, dy.layout, saved_batch_mean.layout, + saved_batch_inv_variance.layout, bn_scale.layout, d_bn_scale.layout, + d_bn_bias.layout, dx.layout, workspace.size, reserve.layout.access_bytes()); auto handle = cudnn_handle(this->handle()); - BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim, - m_param.fwd_mode); + BNTensorDescHolder tensor_desc(x.layout, m_param.param_dim, m_param.fwd_mode); float alpha = 1.0, beta = 0.0; #if CUDNN_VERSION >= 7410 cudnn_check(cudnnBatchNormalizationBackwardEx( - handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, - &alpha, &beta, tensor_desc.xy_desc.desc, + handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, &alpha, + &beta, tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x nullptr, nullptr, // yDesc & y tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy @@ -213,9 +206,9 @@ void BNBackwardImpl::exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale nullptr, // bnBias d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias - m_param.epsilon, saved_batch_mean.raw_ptr, - saved_batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr, - workspace.size, reserve.raw_ptr, 
reserve.layout.access_bytes())); + m_param.epsilon, saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr, + nullptr, workspace.raw_ptr, workspace.size, reserve.raw_ptr, + reserve.layout.access_bytes())); #else cudnn_check(cudnnBatchNormalizationBackward( handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, diff --git a/dnn/src/cuda/batch_normalization/opr_impl.h b/dnn/src/cuda/batch_normalization/opr_impl.h index acfb29a8..83b3ed90 100644 --- a/dnn/src/cuda/batch_normalization/opr_impl.h +++ b/dnn/src/cuda/batch_normalization/opr_impl.h @@ -27,47 +27,47 @@ struct BNTensorDescHolder { BNParamDesc param_desc; cudnnBatchNormMode_t bn_mode; - BNTensorDescHolder(const TensorLayout& x, const ParamDim& param_dim, - const FwdMode& fwd_mode); + BNTensorDescHolder( + const TensorLayout& x, const ParamDim& param_dim, const FwdMode& fwd_mode); }; -size_t get_reserve_size(const cudnnHandle_t& handle, - const BNTensorDescHolder& tensor_desc); +size_t get_reserve_size( + const cudnnHandle_t& handle, const BNTensorDescHolder& tensor_desc); } // namespace batch_normalization class BNForwardImpl final : public BNForward { public: using BNForward::BNForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale, - _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, - _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, - _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in bn_scale, + _megdnn_tensor_in bn_bias, _megdnn_tensor_out mean, + _megdnn_tensor_out variance, _megdnn_tensor_out batch_mean, + _megdnn_tensor_out batch_inv_variance, _megdnn_tensor_out reserve, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) override; size_t get_reserve_in_bytes(const TensorLayout& src) override; }; class BNBackwardImpl final : public BNBackward { public: using BNBackward::BNBackward; - void exec(_megdnn_tensor_in x, _megdnn_tensor_in dy, - _megdnn_tensor_in saved_batch_mean, - _megdnn_tensor_in saved_batch_inv_variance, - _megdnn_tensor_in bn_scale, _megdnn_tensor_in reserve, - _megdnn_tensor_out d_bn_scale, _megdnn_tensor_out d_bn_bias, - _megdnn_tensor_out dx, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in x, _megdnn_tensor_in dy, + _megdnn_tensor_in saved_batch_mean, + _megdnn_tensor_in saved_batch_inv_variance, _megdnn_tensor_in bn_scale, + _megdnn_tensor_in reserve, _megdnn_tensor_out d_bn_scale, + _megdnn_tensor_out d_bn_bias, _megdnn_tensor_out dx, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& x, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&) override; + size_t get_workspace_in_bytes( + const TensorLayout& x, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const 
TensorLayout&) override; size_t get_reserve_in_bytes(const TensorLayout& src) override; }; diff --git a/dnn/src/cuda/batched_matrix_mul/algo.cpp b/dnn/src/cuda/batched_matrix_mul/algo.cpp index 3f3994de..04cd2634 100644 --- a/dnn/src/cuda/batched_matrix_mul/algo.cpp +++ b/dnn/src/cuda/batched_matrix_mul/algo.cpp @@ -30,18 +30,18 @@ std::string BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { return ssprintf( "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", - m, k, k, n, m, n, param.transposeA, param.transposeB, - layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]); + m, k, k, n, m, n, param.transposeA, param.transposeB, layout_a.stride[0], + layout_b.stride[0], layout_c.stride[0]); } BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( - BatchedMatrixMulForwardImpl* o, const TensorLayout& A, - const TensorLayout& B, const TensorLayout& C) + BatchedMatrixMulForwardImpl* o, const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) : opr(o), layout_a(A), layout_b(B), layout_c(C){}; BatchedMatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( - BatchedMatrixMulForwardImpl* o, _megdnn_tensor_in A, - _megdnn_tensor_in B, _megdnn_tensor_in C, _megdnn_workspace workspace) + BatchedMatrixMulForwardImpl* o, _megdnn_tensor_in A, _megdnn_tensor_in B, + _megdnn_tensor_in C, _megdnn_workspace workspace) : SizeArgs(o, A.layout, B.layout, C.layout), tensor_a{A}, tensor_b{B}, diff --git a/dnn/src/cuda/batched_matrix_mul/algo.h b/dnn/src/cuda/batched_matrix_mul/algo.h index 637d3da7..90c85ada 100644 --- a/dnn/src/cuda/batched_matrix_mul/algo.h +++ b/dnn/src/cuda/batched_matrix_mul/algo.h @@ -13,10 +13,10 @@ #include #include "megdnn/dtype.h" #include "megdnn/oprs.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/cuda/batched_matrix_mul/opr_impl.h" #include "src/cuda/matrix_mul/cublasLt_wrapper.h" -#include "src/common/metahelper.h" #if CUDA_VERSION >= 10010 #include @@ -43,8 +43,9 @@ public: BatchedMatrixMulForwardImpl* opr; TensorLayout layout_a, layout_b, layout_c; std::string to_string() const; - SizeArgs(BatchedMatrixMulForwardImpl* o, const TensorLayout& A, - const TensorLayout& B, const TensorLayout& C); + SizeArgs( + BatchedMatrixMulForwardImpl* o, const TensorLayout& A, + const TensorLayout& B, const TensorLayout& C); bool can_be_treated_as_int8x8x32() const { return layout_a.dtype.enumv() == layout_b.dtype.enumv() && (layout_a.dtype.enumv() == DTypeEnum::Int8 || @@ -57,9 +58,9 @@ public: struct ExecArgs : public SizeArgs { TensorND tensor_a, tensor_b, tensor_c; Workspace workspace; - ExecArgs(BatchedMatrixMulForwardImpl* o, _megdnn_tensor_in A, - _megdnn_tensor_in B, _megdnn_tensor_in C, - _megdnn_workspace workspace); + ExecArgs( + BatchedMatrixMulForwardImpl* o, _megdnn_tensor_in A, + _megdnn_tensor_in B, _megdnn_tensor_in C, _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -74,22 +75,22 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto 
req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "batched matrix mul fwd algo %s: required workspace %zu " - "bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "batched matrix mul fwd algo %s: required workspace %zu " + "bytes, got %zu", + name(), req, workspace.size); return *this; } }; class BatchedMatrixMulForwardImpl::AlgoBruteForce final : public BatchedMatrixMulForwardImpl::AlgoBase { using Param = MatrixMulForward::Param; + private: WorkspaceBundle get_workspace_bundle(); @@ -97,15 +98,12 @@ public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; void exec(const ExecArgs& args) const final; - AlgoAttribute attribute()const override{ - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "BRUTE_FORCE"; } MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; }; class BatchedMatrixMulForwardImpl::AlgoCublas final : public BatchedMatrixMulForwardImpl::AlgoBase { @@ -115,8 +113,7 @@ public: size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; void exec(const ExecArgs& args) const final; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } const char* name() const override { return "CUBLAS"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) @@ -129,8 +126,7 @@ public: size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; void exec(const ExecArgs& args) const final; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } const char* name() const override { return "CUBLAS_LT"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) @@ -143,9 +139,7 @@ public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; void exec(const ExecArgs& args) const final; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "INT8x8x32"; } MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32) }; diff --git a/dnn/src/cuda/batched_matrix_mul/brute_force.cpp b/dnn/src/cuda/batched_matrix_mul/brute_force.cpp index 96f03c20..5f6e163f 100644 --- a/dnn/src/cuda/batched_matrix_mul/brute_force.cpp +++ b/dnn/src/cuda/batched_matrix_mul/brute_force.cpp @@ -12,8 +12,8 @@ #include #include "./algo.h" #include "megdnn/opr_param_defs.h" -#include "src/common/algo_chooser.h" #include "src/common/algo_base.h" +#include "src/common/algo_chooser.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" @@ -37,8 +37,8 @@ std::pair> prepare_sub_opr( set_execution_policy( args.opr, matmul_opr.get()); - auto&& config = sub_opr_config(args.layout_a, args.layout_b, args.layout_c, - args.opr); + auto&& config = + sub_opr_config(args.layout_a, args.layout_b, args.layout_c, args.opr); matmul_opr->param() = config.second; return {config.first, std::move(matmul_opr)}; @@ -46,9 +46,9 @@ 
std::pair> prepare_sub_opr( } // namespace -std::vector -BatchedMatrixMulForwardImpl::AlgoBruteForce::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector BatchedMatrixMulForwardImpl::AlgoBruteForce:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { const BatchedMatrixMulForwardImpl* bmm_opr = static_cast(opr); auto&& config = sub_opr_config(layouts[0], layouts[1], layouts[2], bmm_opr); @@ -63,8 +63,8 @@ bool BatchedMatrixMulForwardImpl::AlgoBruteForce::is_available( auto config = prepare_sub_opr(args); return get_algorithm( - static_cast(config.second.get()), - config.first[0], config.first[1], config.first[2]); + static_cast(config.second.get()), config.first[0], + config.first[1], config.first[2]); } size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( const SizeArgs& args) const { @@ -73,17 +73,16 @@ size_t BatchedMatrixMulForwardImpl::AlgoBruteForce::get_workspace_in_bytes( return config.second->get_workspace_in_bytes( config.first[0], config.first[1], config.first[2]); } -void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec( - const ExecArgs& args) const { +void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(const ExecArgs& args) const { auto N = args.layout_a.shape[0]; auto config = prepare_sub_opr(args); rep(n, N) { TensorND A_, B_, C_; auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { - out.raw_ptr = static_cast(static_cast(in.raw_ptr) + - n * in.layout.stride[0] * - in.layout.dtype.size()); + out.raw_ptr = static_cast( + static_cast(in.raw_ptr) + + n * in.layout.stride[0] * in.layout.dtype.size()); out.layout = in.layout.remove_axis(0); }; tensor_n_from_batch(args.tensor_a, A_); diff --git a/dnn/src/cuda/batched_matrix_mul/cublas.cpp b/dnn/src/cuda/batched_matrix_mul/cublas.cpp index 81db2372..b2261fec 100644 --- a/dnn/src/cuda/batched_matrix_mul/cublas.cpp +++ b/dnn/src/cuda/batched_matrix_mul/cublas.cpp @@ -18,8 +18,7 @@ using namespace megdnn; using namespace cuda; using namespace batched_matrix_mul; -bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available( - const SizeArgs& args) const { +bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args) const { auto dtype = args.layout_a.dtype; auto&& param = args.opr->param(); auto&& handle = concrete_handle(args.opr->handle()); @@ -61,19 +60,22 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { auto k = args.layout_a.shape[param.transposeA ? 
1 : 2]; auto workspace = args.workspace; - uintptr_t* As = static_cast(static_cast( - workspace.raw_ptr + 0 * batch * sizeof(uintptr_t))); - uintptr_t* Bs = static_cast(static_cast( - workspace.raw_ptr + 1 * batch * sizeof(uintptr_t))); - uintptr_t* Cs = static_cast(static_cast( - workspace.raw_ptr + 2 * batch * sizeof(uintptr_t))); + uintptr_t* As = static_cast( + static_cast(workspace.raw_ptr + 0 * batch * sizeof(uintptr_t))); + uintptr_t* Bs = static_cast( + static_cast(workspace.raw_ptr + 1 * batch * sizeof(uintptr_t))); + uintptr_t* Cs = static_cast( + static_cast(workspace.raw_ptr + 2 * batch * sizeof(uintptr_t))); - arange(As, reinterpret_cast(args.tensor_a.raw_ptr), - args.layout_a.stride[0] * dtype.size(), batch, stream); - arange(Bs, reinterpret_cast(args.tensor_b.raw_ptr), - args.layout_b.stride[0] * dtype.size(), batch, stream); - arange(Cs, reinterpret_cast(args.tensor_c.raw_ptr), - args.layout_c.stride[0] * dtype.size(), batch, stream); + arange( + As, reinterpret_cast(args.tensor_a.raw_ptr), + args.layout_a.stride[0] * dtype.size(), batch, stream); + arange( + Bs, reinterpret_cast(args.tensor_b.raw_ptr), + args.layout_b.stride[0] * dtype.size(), batch, stream); + arange( + Cs, reinterpret_cast(args.tensor_c.raw_ptr), + args.layout_c.stride[0] * dtype.size(), batch, stream); auto io32_c32 = [&]() { auto zero = handle->zero_device(); @@ -81,12 +83,9 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { cublas_check(cublasSgemmBatched( cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, - reinterpret_cast(Bs), - args.layout_b.stride[1], - reinterpret_cast(As), - args.layout_a.stride[1], zero, - reinterpret_cast(Cs), args.layout_c.stride[1], - batch)); + reinterpret_cast(Bs), args.layout_b.stride[1], + reinterpret_cast(As), args.layout_a.stride[1], zero, + reinterpret_cast(Cs), args.layout_c.stride[1], batch)); }; #if CUDART_VERSION >= 9010 @@ -97,12 +96,10 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { cublas_check(cublasGemmBatchedEx( cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, - reinterpret_cast(Bs), CUDA_R_16F, - args.layout_b.stride[1], reinterpret_cast(As), - CUDA_R_16F, args.layout_a.stride[1], zero, - reinterpret_cast(Cs), CUDA_R_16F, - args.layout_c.stride[1], batch, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT)); + reinterpret_cast(Bs), CUDA_R_16F, args.layout_b.stride[1], + reinterpret_cast(As), CUDA_R_16F, args.layout_a.stride[1], + zero, reinterpret_cast(Cs), CUDA_R_16F, args.layout_c.stride[1], + batch, CUDA_R_32F, CUBLAS_GEMM_DEFAULT)); cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH)); }; #endif @@ -116,9 +113,8 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, param.transposeA ? 
CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, reinterpret_cast(Bs), args.layout_b.stride[1], - reinterpret_cast(As), args.layout_a.stride[1], - zero, reinterpret_cast<__half**>(Cs), args.layout_c.stride[1], - batch)); + reinterpret_cast(As), args.layout_a.stride[1], zero, + reinterpret_cast<__half**>(Cs), args.layout_c.stride[1], batch)); cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH)); }; #endif diff --git a/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp b/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp index 627ff031..b7a56da1 100644 --- a/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp +++ b/dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp @@ -10,8 +10,8 @@ */ #include "./algo.h" #include "src/cuda/handle.h" -#include "src/cuda/utils.h" #include "src/cuda/matrix_mul/cublasLt_wrapper.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -23,8 +23,7 @@ static inline CUBLASLTMatmulDesc::SizeArgs from_local_size_args( auto&& handle = concrete_handle(args.opr->handle()); bool transA = param.transposeA; bool transB = param.transposeB; - return {handle, transA, transB, - args.layout_a, args.layout_b, args.layout_c}; + return {handle, transA, transB, args.layout_a, args.layout_b, args.layout_c}; } bool BatchedMatrixMulForwardImpl::AlgoCublasLt::is_available( @@ -33,7 +32,7 @@ bool BatchedMatrixMulForwardImpl::AlgoCublasLt::is_available( auto&& dev_prop = current_device_prop(); bool is_dev_support = dev_prop.major >= 7; bool res = is_dev_support && CUBLASLTMatmulDesc(cublasLt_args, true) - .is_available(cublasLt_args, INT_MAX); + .is_available(cublasLt_args, INT_MAX); return res; } @@ -46,8 +45,7 @@ size_t BatchedMatrixMulForwardImpl::AlgoCublasLt::get_workspace_in_bytes( return desc.get_workspace_bundle(cublasLt_args, algo).total_size_in_bytes(); } -void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec( - const ExecArgs& args) const { +void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const { auto cublasLt_args = from_local_size_args(args); cublasLtMatmulAlgo_t algo; CUBLASLTMatmulDesc desc(cublasLt_args, true); @@ -59,40 +57,36 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec( auto batched_hgemm = [&]() { auto zero_half = handle->zero_device_h(); auto one_half = handle->one_device_h(); - megdnn_assert(ws_bundle.nr_workspace() == 1, - "workspace bundle size should be 1(ws_algo)"); + megdnn_assert( + ws_bundle.nr_workspace() == 1, + "workspace bundle size should be 1(ws_algo)"); cublas_check(cublasLtMatmul( cublasLt_handle, desc.matmul_desc, one_half, - static_cast(args.tensor_b.raw_ptr), - desc.layout_b, - static_cast(args.tensor_a.raw_ptr), - desc.layout_a, zero_half, - static_cast(args.tensor_c.raw_ptr), + static_cast(args.tensor_b.raw_ptr), desc.layout_b, + static_cast(args.tensor_a.raw_ptr), desc.layout_a, + zero_half, static_cast(args.tensor_c.raw_ptr), desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr), - desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), - stream)); + desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream)); }; auto batched_sgemm = [&]() { auto zero = handle->zero_device(); auto one = handle->one_device(); - auto dev_b = - (desc.dt_b == CUDA_R_16F) - ? static_cast(args.tensor_b.ptr()) - : static_cast(args.tensor_b.ptr()); - auto dev_a = - (desc.dt_a == CUDA_R_16F) - ? static_cast(args.tensor_a.ptr()) - : static_cast(args.tensor_a.ptr()); + auto dev_b = (desc.dt_b == CUDA_R_16F) + ? 
static_cast(args.tensor_b.ptr()) + : static_cast(args.tensor_b.ptr()); + auto dev_a = (desc.dt_a == CUDA_R_16F) + ? static_cast(args.tensor_a.ptr()) + : static_cast(args.tensor_a.ptr()); auto dev_c = static_cast(args.tensor_c.raw_ptr); - megdnn_assert(ws_bundle.nr_workspace() == 1, - "workspace bundle size should be 1(ws_algo)"); - cublas_check(cublasLtMatmul(cublasLt_handle, desc.matmul_desc, one, - dev_b, desc.layout_b, dev_a, desc.layout_a, - zero, dev_c, desc.layout_c, dev_c, - desc.layout_c, &algo, ws_bundle.get(0), - ws_bundle.get_size(0), stream)); + megdnn_assert( + ws_bundle.nr_workspace() == 1, + "workspace bundle size should be 1(ws_algo)"); + cublas_check(cublasLtMatmul( + cublasLt_handle, desc.matmul_desc, one, dev_b, desc.layout_b, dev_a, + desc.layout_a, zero, dev_c, desc.layout_c, dev_c, desc.layout_c, &algo, + ws_bundle.get(0), ws_bundle.get_size(0), stream)); }; - + auto batched_igemm = [&]() { auto zero = handle->zero_device(); auto one = handle->one_device(); @@ -105,34 +99,32 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec( int32_t pm = CUBLAS_POINTER_MODE_DEVICE; cublasOperation_t trans_a = CUBLAS_OP_T, trans_c = CUBLAS_OP_N; cublasLtMatrixTransformDesc_t transform_desc = nullptr; - cublas_check( - cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F)); + cublas_check(cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F)); cublas_check(cublasLtMatrixTransformDescSetAttribute( - transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, - &pm, sizeof(pm))); + transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm, + sizeof(pm))); cublas_check(cublasLtMatrixTransform( cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr, - desc.layout_b, zero, nullptr, nullptr, ws_b, - desc.layout_trans_b, stream)); + desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b, + stream)); cublas_check(cublasLtMatrixTransformDescSetAttribute( transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a, sizeof(trans_a))); cublas_check(cublasLtMatrixTransform( cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr, - desc.layout_a, zero, nullptr, nullptr, ws_a, - desc.layout_trans_a, stream)); + desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a, + stream)); cublas_check(cublasLtMatmul( - cublasLt_handle, desc.matmul_desc, one, ws_b, - desc.layout_trans_b, ws_a, desc.layout_trans_a, zero, ws_c, - desc.layout_trans_c, ws_c, desc.layout_trans_c, &algo, - ws_bundle.get(0), ws_bundle.get_size(0), stream)); + cublasLt_handle, desc.matmul_desc, one, ws_b, desc.layout_trans_b, ws_a, + desc.layout_trans_a, zero, ws_c, desc.layout_trans_c, ws_c, + desc.layout_trans_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), + stream)); cublas_check(cublasLtMatrixTransformDescSetAttribute( transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_c, sizeof(trans_c))); cublas_check(cublasLtMatrixTransform( - cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, - zero, nullptr, nullptr, args.tensor_c.raw_ptr, desc.layout_c, - stream)); + cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero, + nullptr, nullptr, args.tensor_c.raw_ptr, desc.layout_c, stream)); cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc)); }; diff --git a/dnn/src/cuda/batched_matrix_mul/helper.cu b/dnn/src/cuda/batched_matrix_mul/helper.cu index 3d5e385c..de29522b 100644 --- a/dnn/src/cuda/batched_matrix_mul/helper.cu +++ b/dnn/src/cuda/batched_matrix_mul/helper.cu @@ -13,35 +13,32 @@ namespace { template -__global__ void 
kernel(T *Xs, T start, uint32_t step, uint32_t n) -{ +__global__ void kernel(T* Xs, T start, uint32_t step, uint32_t n) { uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { - Xs[i] = start + i*step; + Xs[i] = start + i * step; } } -} // anonymous namespace +} // anonymous namespace namespace megdnn { namespace cuda { namespace batched_matrix_mul { template -void arange(T *Xs, T start, uint32_t step, uint32_t n, cudaStream_t stream) -{ +void arange(T* Xs, T start, uint32_t step, uint32_t n, cudaStream_t stream) { uint32_t threads = NR_THREADS; uint32_t blocks = DIVUP(n, threads); kernel<<>>(Xs, start, step, n); after_kernel_launch(); } -template void arange(uintptr_t *, uintptr_t, - uint32_t, uint32_t, cudaStream_t); +template void arange( + uintptr_t*, uintptr_t, uint32_t, uint32_t, cudaStream_t); -} // namespace batched_matrix_mul -} // namespace cuda -} // namespace megdnn +} // namespace batched_matrix_mul +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/batched_matrix_mul/int8x8x32.cpp b/dnn/src/cuda/batched_matrix_mul/int8x8x32.cpp index c9b57e7a..27db5020 100644 --- a/dnn/src/cuda/batched_matrix_mul/int8x8x32.cpp +++ b/dnn/src/cuda/batched_matrix_mul/int8x8x32.cpp @@ -25,18 +25,15 @@ bool BatchedMatrixMulForwardImpl::AlgoInt8x8x32::is_available( return args.can_be_treated_as_int8x8x32(); } -void BatchedMatrixMulForwardImpl::AlgoInt8x8x32::exec( - const ExecArgs& args) const { +void BatchedMatrixMulForwardImpl::AlgoInt8x8x32::exec(const ExecArgs& args) const { auto&& param = args.opr->param(); auto batch_count = args.layout_a.shape[0]; auto m = args.tensor_c.layout.shape[1], n = args.tensor_c.layout.shape[2], k = args.tensor_a.layout.shape[param.transposeA ? 1 : 2]; - auto LDA = args.tensor_a.layout.stride[0], - LDB = args.tensor_b.layout.stride[0], + auto LDA = args.tensor_a.layout.stride[0], LDB = args.tensor_b.layout.stride[0], LDC = args.tensor_c.layout.stride[0]; - auto STA = args.tensor_a.layout.stride[1], - STB = args.tensor_b.layout.stride[1], + auto STA = args.tensor_a.layout.stride[1], STB = args.tensor_b.layout.stride[1], STC = args.tensor_c.layout.stride[1]; int8_t* A = args.tensor_a.compatible_ptr(); @@ -44,9 +41,9 @@ void BatchedMatrixMulForwardImpl::AlgoInt8x8x32::exec( int32_t* C = args.tensor_c.compatible_ptr(); auto&& handle = concrete_handle(args.opr->handle()); - exec_igemm_8x8x32(A, B, C, batch_count, m, n, k, LDA, LDB, LDC, STA, STB, - STC, param.transposeA, param.transposeB, - cuda_stream(handle)); + exec_igemm_8x8x32( + A, B, C, batch_count, m, n, k, LDA, LDB, LDC, STA, STB, STC, + param.transposeA, param.transposeB, cuda_stream(handle)); } size_t BatchedMatrixMulForwardImpl::AlgoInt8x8x32::get_workspace_in_bytes( @@ -55,4 +52,3 @@ size_t BatchedMatrixMulForwardImpl::AlgoInt8x8x32::get_workspace_in_bytes( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/batched_matrix_mul/int8x8x32.cu b/dnn/src/cuda/batched_matrix_mul/int8x8x32.cu index 37592871..a3ab3128 100644 --- a/dnn/src/cuda/batched_matrix_mul/int8x8x32.cu +++ b/dnn/src/cuda/batched_matrix_mul/int8x8x32.cu @@ -51,20 +51,14 @@ __device__ __forceinline__ void Global2SharedMem::gmem2reg_cpy() { int32_t val = 0; if (row < check_bound_row && col * 4 < check_bound_col) val = (int32_t)0xff & g_ptr[row * ld_src + col * 4]; - if (row < check_bound_row && - (col * 4 + 1) < check_bound_col) - val |= (((int32_t)0xff & - g_ptr[row * ld_src + col * 4 + 1]) + if (row < check_bound_row && (col * 4 + 1) < check_bound_col) + val |= (((int32_t)0xff & 
g_ptr[row * ld_src + col * 4 + 1]) << 8); - if (row < check_bound_row && - (col * 4 + 2) < check_bound_col) - val |= (((int32_t)0xff & - g_ptr[row * ld_src + col * 4 + 2]) + if (row < check_bound_row && (col * 4 + 2) < check_bound_col) + val |= (((int32_t)0xff & g_ptr[row * ld_src + col * 4 + 2]) << 16); - if (row < check_bound_row && - (col * 4 + 3) < check_bound_col) - val |= (((int32_t)0xff & - g_ptr[row * ld_src + col * 4 + 3]) + if (row < check_bound_row && (col * 4 + 3) < check_bound_col) + val |= (((int32_t)0xff & g_ptr[row * ld_src + col * 4 + 3]) << 24); cpy_reg[row][col] = val; } @@ -74,24 +68,19 @@ __device__ __forceinline__ void Global2SharedMem::gmem2reg_cpy() { for (int col = 0; col < SmemConfig::smem_col / 4; ++col) { #pragma unroll for (int row = 0; row < SmemConfig::smem_row / 4; ++row) { - int32_t src0 = cpy_reg[row * 4][col], - src1 = cpy_reg[row * 4 + 1][col], + int32_t src0 = cpy_reg[row * 4][col], src1 = cpy_reg[row * 4 + 1][col], src2 = cpy_reg[row * 4 + 2][col], src3 = cpy_reg[row * 4 + 3][col]; - reg[col * 4 + 3][row] = ((src3 >> 24 & 0xff) << 24) | - ((src2 >> 24 & 0xff) << 16) | - ((src1 >> 24 & 0xff) << 8) | - (src0 >> 24 & 0xff); - reg[col * 4 + 2][row] = ((src3 >> 16 & 0xff) << 24) | - ((src2 >> 16 & 0xff) << 16) | - ((src1 >> 16 & 0xff) << 8) | - (src0 >> 16 & 0xff); + reg[col * 4 + 3][row] = + ((src3 >> 24 & 0xff) << 24) | ((src2 >> 24 & 0xff) << 16) | + ((src1 >> 24 & 0xff) << 8) | (src0 >> 24 & 0xff); + reg[col * 4 + 2][row] = + ((src3 >> 16 & 0xff) << 24) | ((src2 >> 16 & 0xff) << 16) | + ((src1 >> 16 & 0xff) << 8) | (src0 >> 16 & 0xff); reg[col * 4 + 1][row] = ((src3 >> 8 & 0xff) << 24) | ((src2 >> 8 & 0xff) << 16) | - ((src1 >> 8 & 0xff) << 8) | - (src0 >> 8 & 0xff); - reg[col * 4][row] = ((src3 & 0xff) << 24) | - ((src2 & 0xff) << 16) | + ((src1 >> 8 & 0xff) << 8) | (src0 >> 8 & 0xff); + reg[col * 4][row] = ((src3 & 0xff) << 24) | ((src2 & 0xff) << 16) | ((src1 & 0xff) << 8) | (src0 & 0xff); } } @@ -128,20 +117,14 @@ __device__ __forceinline__ void Global2SharedMem::gmem2reg_cpy() { int32_t val = 0; if (col < check_bound_col && row * 4 < check_bound_row) val = (int32_t)0xff & g_ptr[col * ld_src + row * 4]; - if (col < check_bound_col && - (row * 4 + 1) < check_bound_row) - val |= (((int32_t)0xff & - g_ptr[col * ld_src + row * 4 + 1]) + if (col < check_bound_col && (row * 4 + 1) < check_bound_row) + val |= (((int32_t)0xff & g_ptr[col * ld_src + row * 4 + 1]) << 8); - if (col < check_bound_col && - (row * 4 + 2) < check_bound_row) - val |= (((int32_t)0xff & - g_ptr[col * ld_src + row * 4 + 2]) + if (col < check_bound_col && (row * 4 + 2) < check_bound_row) + val |= (((int32_t)0xff & g_ptr[col * ld_src + row * 4 + 2]) << 16); - if (col < check_bound_col && - (row * 4 + 3) < check_bound_row) - val |= (((int32_t)0xff & - g_ptr[col * ld_src + row * 4 + 3]) + if (col < check_bound_col && (row * 4 + 3) < check_bound_row) + val |= (((int32_t)0xff & g_ptr[col * ld_src + row * 4 + 3]) << 24); reg[col][row] = val; } @@ -168,20 +151,17 @@ __device__ __forceinline__ void Global2SharedMem::iter_forward() { } template -__global__ void batched_8x8x32_kern(const int8_t* a, int lda, int sta, bool tra, - const int8_t* b, int ldb, int stb, bool trb, - int32_t* c, int ldc, int stc, int m, int n, - int k) { +__global__ void batched_8x8x32_kern( + const int8_t* a, int lda, int sta, bool tra, const int8_t* b, int ldb, int stb, + bool trb, int32_t* c, int ldc, int stc, int m, int n, int k) { typedef UnrollConfig_ UnrollConfig; typedef ThreadConfig_ ThreadConfig; int off_batch = 
blockIdx.z, off_m = blockIdx.x, off_n = blockIdx.y, off_w = threadIdx.x, off_h = threadIdx.y, tid_x = off_m * ThreadConfig::thread_x + off_w, tid_y = off_n * ThreadConfig::thread_y + off_h; - static int const unroll = UnrollConfig::unroll, - thread_k = UnrollConfig::thread_k, - load_m = UnrollConfig::load_m, - load_n = UnrollConfig::load_n; + static int const unroll = UnrollConfig::unroll, thread_k = UnrollConfig::thread_k, + load_m = UnrollConfig::load_m, load_n = UnrollConfig::load_n; typedef SmemConfig SmemA; typedef SmemConfig SmemB; @@ -203,12 +183,10 @@ __global__ void batched_8x8x32_kern(const int8_t* a, int lda, int sta, bool tra, int32_t* smem_b = reinterpret_cast( &smem_a[(UnrollConfig::unroll_k / 4) * UnrollConfig::block_m]); - int off_smem_a = - (off_w * UnrollConfig::unroll_m + (off_h / thread_k) * load_m) * - UnrollConfig::unroll_k / 4, - off_smem_b = - (off_h * UnrollConfig::unroll_n + (off_w / thread_k) * load_n) * - UnrollConfig::unroll_k / 4; + int off_smem_a = (off_w * UnrollConfig::unroll_m + (off_h / thread_k) * load_m) * + UnrollConfig::unroll_k / 4, + off_smem_b = (off_h * UnrollConfig::unroll_n + (off_w / thread_k) * load_n) * + UnrollConfig::unroll_k / 4; int a_col = load_m; if (a_col > m - idx_m) a_col = m - idx_m; @@ -247,14 +225,12 @@ __global__ void batched_8x8x32_kern(const int8_t* a, int lda, int sta, bool tra, bool al_a = tra ? (m % 4 == 0) : (k % 4 == 0), al_b = trb ? (k % 4 == 0) : (n % 4 == 0); - gl2sh_type_a gl2sh_a(&smem_a[off_smem_a], idx_k_a * unroll / 4, - UnrollConfig::unroll_k / 4, sta, - UnrollConfig::unroll_k / 4, a_row, a_col, step_a, tra, - al_a); - gl2sh_type_b gl2sh_b(&smem_b[off_smem_b], idx_k_b * unroll / 4, - UnrollConfig::unroll_k / 4, stb, - UnrollConfig::unroll_k / 4, b_row, b_col, step_b, !trb, - al_b); + gl2sh_type_a gl2sh_a( + &smem_a[off_smem_a], idx_k_a * unroll / 4, UnrollConfig::unroll_k / 4, sta, + UnrollConfig::unroll_k / 4, a_row, a_col, step_a, tra, al_a); + gl2sh_type_b gl2sh_b( + &smem_b[off_smem_b], idx_k_b * unroll / 4, UnrollConfig::unroll_k / 4, stb, + UnrollConfig::unroll_k / 4, b_row, b_col, step_b, !trb, al_b); gl2sh_a.g_ptr = &a[off_a]; gl2sh_b.g_ptr = &b[off_b]; @@ -294,29 +270,29 @@ __global__ void batched_8x8x32_kern(const int8_t* a, int lda, int sta, bool tra, } __syncthreads(); if (off_c != -1) { - int32_t reg_a[UnrollConfig::unroll_m], - reg_b[UnrollConfig::unroll_n]; + int32_t reg_a[UnrollConfig::unroll_m], reg_b[UnrollConfig::unroll_n]; #pragma unroll - for (int k_in = 0; - k_in < UnrollConfig::unroll_k / 4 && k_in * 4 < k_out; + for (int k_in = 0; k_in < UnrollConfig::unroll_k / 4 && k_in * 4 < k_out; ++k_in) { #pragma unroll for (int i = 0; i < UnrollConfig::unroll_m; ++i) - reg_a[i] = smem_a[(off_w * UnrollConfig::unroll_m + i) * - UnrollConfig::unroll_k / 4 + - k_in]; + reg_a[i] = + smem_a[(off_w * UnrollConfig::unroll_m + i) * + UnrollConfig::unroll_k / 4 + + k_in]; #pragma unroll for (int j = 0; j < UnrollConfig::unroll_n; ++j) - reg_b[j] = smem_b[(off_h * UnrollConfig::unroll_n + j) * - UnrollConfig::unroll_k / 4 + - k_in]; + reg_b[j] = + smem_b[(off_h * UnrollConfig::unroll_n + j) * + UnrollConfig::unroll_k / 4 + + k_in]; #pragma unroll for (int i = 0; i < UnrollConfig::unroll_m; ++i) #pragma unroll for (int j = 0; j < UnrollConfig::unroll_n; ++j) { - dot_prod(reg_a[i], reg_b[j], - sum[i * UnrollConfig::unroll_n + j], - sum[i * UnrollConfig::unroll_n + j]); + dot_prod( + reg_a[i], reg_b[j], sum[i * UnrollConfig::unroll_n + j], + sum[i * UnrollConfig::unroll_n + j]); } } } @@ -329,15 +305,14 @@ 
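batched_8x8x32_kern above packs four consecutive int8 values into one int32 lane (the (int32_t)0xff masks and shifts in gmem2reg_cpy) so that a single dp4a instruction can perform a 4-element dot product with int32 accumulation. A minimal standalone sketch of that pattern, assuming dot_prod ultimately maps to the __dp4a intrinsic (sm_61 and later); the demo kernel, sizes and values below are illustrative only:

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dp4a_demo(const int8_t* a, const int8_t* b, int32_t* out) {
    int32_t pa = 0, pb = 0;
    // Little-endian byte packing, same 0xff-mask/shift pattern as gmem2reg_cpy.
    for (int i = 0; i < 4; ++i) {
        pa |= (int32_t(0xff) & int32_t(a[i])) << (8 * i);
        pb |= (int32_t(0xff) & int32_t(b[i])) << (8 * i);
    }
    int32_t acc = 0;
#if __CUDA_ARCH__ >= 610
    acc = __dp4a(pa, pb, acc);  // 4-way signed int8 dot product + int32 add
#else
    for (int i = 0; i < 4; ++i)  // scalar fallback for older architectures
        acc += int32_t(a[i]) * int32_t(b[i]);
#endif
    *out = acc;
}

int main() {
    int8_t ha[4] = {1, -2, 3, -4}, hb[4] = {5, 6, -7, 8};
    int8_t *da, *db;
    int32_t *dout, hout = 0;
    cudaMalloc(&da, 4);
    cudaMalloc(&db, 4);
    cudaMalloc(&dout, sizeof(int32_t));
    cudaMemcpy(da, ha, 4, cudaMemcpyHostToDevice);
    cudaMemcpy(db, hb, 4, cudaMemcpyHostToDevice);
    dp4a_demo<<<1, 1>>>(da, db, dout);
    cudaMemcpy(&hout, dout, sizeof(int32_t), cudaMemcpyDeviceToHost);
    std::printf("dot = %d\n", hout);  // 1*5 + (-2)*6 + 3*(-7) + (-4)*8 = -60
    return 0;
}

Packing the operands this way is what lets the kernel keep its shared-memory tiles in int32 units while still issuing one dot-product instruction per four int8 multiply-accumulates.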
__global__ void batched_8x8x32_kern(const int8_t* a, int lda, int sta, bool tra, for (int j = 0; j < UnrollConfig::unroll_n; ++j) if (tid_x * UnrollConfig::unroll_m + i < m && tid_y * UnrollConfig::unroll_n + j < n) - *(ptr_c + i * stc + j) = - sum[i * UnrollConfig::unroll_n + j]; + *(ptr_c + i * stc + j) = sum[i * UnrollConfig::unroll_n + j]; } } -void exec_igemm_8x8x32(const int8_t* A, const int8_t* B, int32_t* C, - const int batch_count, const int m, const int n, - const int k, int ldA, int ldB, int ldC, int stA, int stB, - int stC, bool transA, bool transB, cudaStream_t stream) { +void exec_igemm_8x8x32( + const int8_t* A, const int8_t* B, int32_t* C, const int batch_count, + const int m, const int n, const int k, int ldA, int ldB, int ldC, int stA, + int stB, int stC, bool transA, bool transB, cudaStream_t stream) { static int const unroll_m = 8, unroll_n = 8, unroll_k = 32, unroll = 4; typedef ThreadConfig<8, 8> Thread; typedef UnrollConfig Unroll; @@ -346,12 +321,13 @@ void exec_igemm_8x8x32(const int8_t* A, const int8_t* B, int32_t* C, grid.x = (m + Unroll::block_m - 1) / Unroll::block_m; grid.y = (n + Unroll::block_n - 1) / Unroll::block_n; grid.z = batch_count; - static uint32_t shared_storage = (Unroll::block_m + Unroll::block_n) * - Unroll::unroll_k * sizeof(int8_t); + static uint32_t shared_storage = + (Unroll::block_m + Unroll::block_n) * Unroll::unroll_k * sizeof(int8_t); - void (*kern)(const int8_t* a, int lda, int sta, bool tra, const int8_t* b, - int ldb, int stb, bool trb, int32_t* c, int ldc, int stc, - int m, int n, int k) = batched_8x8x32_kern; + void (*kern)( + const int8_t* a, int lda, int sta, bool tra, const int8_t* b, int ldb, + int stb, bool trb, int32_t* c, int ldc, int stc, int m, int n, int k) = + batched_8x8x32_kern; kern<<>>( A, ldA, stA, transA, B, ldB, stB, transB, C, ldC, stC, m, n, k); after_kernel_launch(); diff --git a/dnn/src/cuda/batched_matrix_mul/int8x8x32.cuh b/dnn/src/cuda/batched_matrix_mul/int8x8x32.cuh index 8db89bfe..5247fe11 100644 --- a/dnn/src/cuda/batched_matrix_mul/int8x8x32.cuh +++ b/dnn/src/cuda/batched_matrix_mul/int8x8x32.cuh @@ -24,10 +24,8 @@ struct UnrollConfig { static int const unroll_k = k_tot; static int const unroll = k_; static int const thread_k = k_tot / k_; - static int const load_m = - (m_ / 4) / (ThreadConfig::thread_y / thread_k) * 4; - static int const load_n = - (n_ / 4) / (ThreadConfig::thread_x / thread_k) * 4; + static int const load_m = (m_ / 4) / (ThreadConfig::thread_y / thread_k) * 4; + static int const load_n = (n_ / 4) / (ThreadConfig::thread_x / thread_k) * 4; }; template @@ -58,10 +56,9 @@ struct Global2SharedMem { bool tr; bool aligned; - __device__ __forceinline__ Global2SharedMem(int32_t* smem_, int s_off, - int s_bound, int ld_src_, - int ld_dst_, int b_r_, int b_c_, - int step_, bool tr_, bool al_) + __device__ __forceinline__ Global2SharedMem( + int32_t* smem_, int s_off, int s_bound, int ld_src_, int ld_dst_, int b_r_, + int b_c_, int step_, bool tr_, bool al_) : smem(smem_), smem_off(s_off), smem_bound(s_bound), @@ -79,15 +76,14 @@ struct Global2SharedMem { }; template -__global__ void batched_8x8x32_kern(const int8_t* a, int lda, int sta, bool tra, - const int8_t* b, int ldb, int stb, bool trb, - int32_t* c, int ldc, int stc, int m, int n, - int k); +__global__ void batched_8x8x32_kern( + const int8_t* a, int lda, int sta, bool tra, const int8_t* b, int ldb, int stb, + bool trb, int32_t* c, int ldc, int stc, int m, int n, int k); -void exec_igemm_8x8x32(const int8_t* A, const int8_t* B, int32_t* 
C, - const int batch_count, const int m, const int n, - const int k, int ldA, int ldB, int ldC, int stA, int stB, - int stC, bool transA, bool transB, cudaStream_t stream); +void exec_igemm_8x8x32( + const int8_t* A, const int8_t* B, int32_t* C, const int batch_count, + const int m, const int n, const int k, int ldA, int ldB, int ldC, int stA, + int stB, int stC, bool transA, bool transB, cudaStream_t stream); } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp b/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp index 01e902f3..32d3fb85 100644 --- a/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/cuda/batched_matrix_mul/opr_impl.cpp @@ -22,9 +22,9 @@ using namespace cuda; using Algorithm = BatchedMatrixMulForwardImpl::Algorithm; -void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace) { +void BatchedMatrixMulForwardImpl::exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) { using namespace batched_matrix_mul; //! //! \Note (int8, int8) => int32 is supported @@ -53,7 +53,7 @@ std::vector BatchedMatrixMulForwardImpl::get_all_algorithms( } std::vector BatchedMatrixMulForwardImpl::get_all_algorithms_safe( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { - auto ret_safe = get_all_algorithms(A,B,C); + auto ret_safe = get_all_algorithms(A, B, C); megdnn_assert(!ret_safe.empty(), "no usable batchedmatrixmulForward fwd algorithm"); return ret_safe; } @@ -64,22 +64,22 @@ Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( const AlgoAttribute& negative_attr) { MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); AlgoBase::SizeArgs args(this, A, B, C); - if (sm_algo_pack.cublas.is_available_attribute(args, positive_attr, - negative_attr)) { + if (sm_algo_pack.cublas.is_available_attribute( + args, positive_attr, negative_attr)) { return &sm_algo_pack.cublas; } #if CUDA_VERSION >= 10010 - else if (sm_algo_pack.cublasLt.is_available_attribute(args, positive_attr, - negative_attr)) { + else if (sm_algo_pack.cublasLt.is_available_attribute( + args, positive_attr, negative_attr)) { return &sm_algo_pack.cublasLt; } #endif - else if (sm_algo_pack.int8x8x32.is_available_attribute(args, positive_attr, - negative_attr)) { + else if (sm_algo_pack.int8x8x32.is_available_attribute( + args, positive_attr, negative_attr)) { return &sm_algo_pack.int8x8x32; } else { - if (sm_algo_pack.brute_force.is_available_attribute(args, positive_attr, - negative_attr)) { + if (sm_algo_pack.brute_force.is_available_attribute( + args, positive_attr, negative_attr)) { return &sm_algo_pack.brute_force; } } @@ -89,8 +89,8 @@ Algorithm* BatchedMatrixMulForwardImpl::get_algorithm_heuristic( "attribute(%s) args(%s) and " "workspace limit (%zu bytes)", Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); return nullptr; }; diff --git a/dnn/src/cuda/batched_matrix_mul/opr_impl.h b/dnn/src/cuda/batched_matrix_mul/opr_impl.h index cef99e8f..cc124175 100644 --- a/dnn/src/cuda/batched_matrix_mul/opr_impl.h +++ b/dnn/src/cuda/batched_matrix_mul/opr_impl.h @@ -28,26 +28,26 @@ public: class AlgoInt8x8x32; class AlgoPack; - void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, - _megdnn_workspace workspace) 
override; - size_t get_workspace_in_bytes(const TensorLayout& A, const TensorLayout& B, - const TensorLayout& C) override; + void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) override; - const char* get_algorithm_set_name() const override { - return "BATCHED_MATMUL"; - } + const char* get_algorithm_set_name() const override { return "BATCHED_MATMUL"; } bool is_thread_safe() const override { return true; } static const AlgoPack& algo_pack() { return sm_algo_pack; } Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: - std::vector get_all_algorithms(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) override; - std::vector get_all_algorithms_safe(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) override; + std::vector get_all_algorithms( + const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) override; + std::vector get_all_algorithms_safe( + const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) override; Algorithm* get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, diff --git a/dnn/src/cuda/check_non_finite/kern.cu b/dnn/src/cuda/check_non_finite/kern.cu index f688d61f..6692e08d 100644 --- a/dnn/src/cuda/check_non_finite/kern.cu +++ b/dnn/src/cuda/check_non_finite/kern.cu @@ -21,7 +21,7 @@ namespace cuda { INST_REDUCE(reduce::CheckNonFiniteOp, false); #undef COMMA -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/check_non_finite/opr_impl.cpp b/dnn/src/cuda/check_non_finite/opr_impl.cpp index 3b548f7d..47214983 100644 --- a/dnn/src/cuda/check_non_finite/opr_impl.cpp +++ b/dnn/src/cuda/check_non_finite/opr_impl.cpp @@ -22,14 +22,14 @@ namespace cuda { using reduce::CheckNonFiniteOp; -size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) { +size_t CheckNonFiniteImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { typedef CheckNonFiniteOp Op; return get_reduce_workspace_in_bytes(1, src.total_nr_elems(), 1); } -void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void CheckNonFiniteImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); typedef CheckNonFiniteOp Op; auto stream = cuda_stream(this->handle()); diff --git a/dnn/src/cuda/check_non_finite/opr_impl.h b/dnn/src/cuda/check_non_finite/opr_impl.h index 5ab9d635..8c89b61a 100644 --- a/dnn/src/cuda/check_non_finite/opr_impl.h +++ b/dnn/src/cuda/check_non_finite/opr_impl.h @@ -21,13 +21,14 @@ class CheckNonFiniteImpl final : public CheckNonFinite { public: using CheckNonFinite::CheckNonFinite; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; bool is_thread_safe() const override { return true; } - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; }; } 
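/* Illustrative sketch, not from this patch: CheckNonFiniteImpl delegates to the
 * generic reduction framework (CheckNonFiniteOp + get_reduce_workspace_in_bytes).
 * The kernel below shows the underlying idea with a plain grid-stride-free pass
 * instead of that framework; check_non_finite_naive is a hypothetical name used
 * only for illustration. */
#include <math.h>

__global__ void check_non_finite_naive(const float* src, size_t n, int* flag) {
    size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
    if (i < n && !isfinite(src[i])) {
        // Any NaN or Inf element marks the whole tensor as non-finite.
        atomicOr(flag, 1);
    }
}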
// namespace cuda diff --git a/dnn/src/cuda/checksum/kern.cu b/dnn/src/cuda/checksum/kern.cu index ca622a97..8edc8a4d 100644 --- a/dnn/src/cuda/checksum/kern.cu +++ b/dnn/src/cuda/checksum/kern.cu @@ -10,68 +10,57 @@ */ #include "./kern.cuh" -#include "src/cuda/utils.cuh" #include "src/cuda/reduce_helper.cuh" +#include "src/cuda/utils.cuh" namespace { - struct ChecksumOp { - typedef uint32_t wtype; - const uint32_t *src; - uint32_t *dst; +struct ChecksumOp { + typedef uint32_t wtype; + const uint32_t* src; + uint32_t* dst; - static const uint32_t INIT = 0; + static const uint32_t INIT = 0; - __host__ __device__ void write(uint32_t idx, uint32_t val) { - dst[idx] = val; - } + __host__ __device__ void write(uint32_t idx, uint32_t val) { dst[idx] = val; } - __host__ __device__ static uint32_t apply(uint32_t a, uint32_t b) { - return a + b; - } - }; + __host__ __device__ static uint32_t apply(uint32_t a, uint32_t b) { return a + b; } +}; - struct NonFourAlignedChecksumOp : ChecksumOp { - __host__ __device__ uint32_t read(uint32_t idx) { - uint8_t* data = (uint8_t*) (src + idx); - return (data[0] | ((uint32_t) data[1] << 8) | - ((uint32_t) data[2] << 16) | ((uint32_t) data[3] << 24)) * - (idx + 1); - } - }; - - struct FourAlignedChecksumOp : ChecksumOp { - __host__ __device__ uint32_t read(uint32_t idx) { - return src[idx] * (idx + 1); - } - }; +struct NonFourAlignedChecksumOp : ChecksumOp { + __host__ __device__ uint32_t read(uint32_t idx) { + uint8_t* data = (uint8_t*)(src + idx); + return (data[0] | ((uint32_t)data[1] << 8) | ((uint32_t)data[2] << 16) | + ((uint32_t)data[3] << 24)) * + (idx + 1); + } +}; +struct FourAlignedChecksumOp : ChecksumOp { + __host__ __device__ uint32_t read(uint32_t idx) { return src[idx] * (idx + 1); } +}; -} // anonymous namespace +} // anonymous namespace void megdnn::cuda::checksum::calc( - uint32_t *dest, - const uint32_t *buf, - uint32_t *workspace, - size_t nr_elem, cudaStream_t stream) { + uint32_t* dest, const uint32_t* buf, uint32_t* workspace, size_t nr_elem, + cudaStream_t stream) { if (!nr_elem) return; if (reinterpret_cast(buf) & 0b11) { NonFourAlignedChecksumOp op; op.src = buf; op.dst = dest; - run_reduce(workspace, - 1, nr_elem, 1, stream, op); + run_reduce( + workspace, 1, nr_elem, 1, stream, op); } else { FourAlignedChecksumOp op; op.src = buf; op.dst = dest; - run_reduce(workspace, - 1, nr_elem, 1, stream, op); + run_reduce(workspace, 1, nr_elem, 1, stream, op); } } -size_t megdnn::cuda::checksum::get_workspace_in_bytes(size_t nr_elem) -{ +size_t megdnn::cuda::checksum::get_workspace_in_bytes(size_t nr_elem) { return get_reduce_workspace_in_bytes(1, nr_elem, 1); } // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/checksum/kern.cuh b/dnn/src/cuda/checksum/kern.cuh index 8000adca..5e059c6e 100644 --- a/dnn/src/cuda/checksum/kern.cuh +++ b/dnn/src/cuda/checksum/kern.cuh @@ -13,20 +13,18 @@ #include "src/cuda/utils.cuh" -namespace megdnn{ +namespace megdnn { namespace cuda { namespace checksum { void calc( - uint32_t *dest, const uint32_t *buf, uint32_t *workspace, - size_t nr_elem, + uint32_t* dest, const uint32_t* buf, uint32_t* workspace, size_t nr_elem, cudaStream_t stream); size_t get_workspace_in_bytes(size_t nr_elem); -} -} -} +} // namespace checksum +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/checksum/opr_impl.cpp b/dnn/src/cuda/checksum/opr_impl.cpp index db1042c9..63926bea 100644 --- a/dnn/src/cuda/checksum/opr_impl.cpp +++ b/dnn/src/cuda/checksum/opr_impl.cpp @@ 
-9,11 +9,11 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kern.cuh" #include "./opr_impl.h" +#include "./kern.cuh" -#include "src/cuda/reduce_helper.cuh" #include "src/common/utils.h" +#include "src/cuda/reduce_helper.cuh" #include @@ -22,23 +22,20 @@ using namespace cuda; namespace { -WorkspaceBundle get_wbundle(const TensorLayout &data) -{ - size_t size_all = data.shape[0], - size_ints = size_all / sizeof(uint32_t); +WorkspaceBundle get_wbundle(const TensorLayout& data) { + size_t size_all = data.shape[0], size_ints = size_all / sizeof(uint32_t); size_t part1 = checksum::get_workspace_in_bytes(size_ints); size_t part2 = sizeof(ChecksumForward::Result::checksum); return {nullptr, {part1, part2}}; } -} // anonymous namespace +} // anonymous namespace -size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout &data) { +size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& data) { auto wbundle = get_wbundle(data); return wbundle.total_size_in_bytes(); } - ChecksumForward::Result ChecksumForwardImpl::exec( _megdnn_tensor_in data, _megdnn_workspace workspace) { auto wbundle = get_wbundle(data.layout); @@ -49,19 +46,19 @@ ChecksumForward::Result ChecksumForwardImpl::exec( auto stream = cuda_stream(handle()); auto ptr = static_cast(data.raw_ptr); - size_t size_all = data.layout.shape[0], - size_ints = size_all / sizeof(uint32_t); + size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); auto last_val_size = std::min(size_all, 4); cuda_check(cudaMemcpyAsync( - &result.last_val, ptr + size_all - last_val_size, last_val_size, - cudaMemcpyDeviceToHost, stream)); + &result.last_val, ptr + size_all - last_val_size, last_val_size, + cudaMemcpyDeviceToHost, stream)); if (size_ints) { - checksum::calc(static_cast(wbundle.get(1)), - static_cast(data.raw_ptr), - static_cast(wbundle.get(0)), - size_ints, stream); - cuda_check(cudaMemcpyAsync(&result.checksum, wbundle.get(1), - sizeof(result.checksum), cudaMemcpyDeviceToHost, stream)); + checksum::calc( + static_cast(wbundle.get(1)), + static_cast(data.raw_ptr), + static_cast(wbundle.get(0)), size_ints, stream); + cuda_check(cudaMemcpyAsync( + &result.checksum, wbundle.get(1), sizeof(result.checksum), + cudaMemcpyDeviceToHost, stream)); } cuda_check(cudaStreamSynchronize(stream)); return result; diff --git a/dnn/src/cuda/checksum/opr_impl.h b/dnn/src/cuda/checksum/opr_impl.h index 293e2406..4e385116 100644 --- a/dnn/src/cuda/checksum/opr_impl.h +++ b/dnn/src/cuda/checksum/opr_impl.h @@ -17,23 +17,18 @@ namespace megdnn { namespace cuda { -class ChecksumForwardImpl final: public ChecksumForward { - public: - using ChecksumForward::ChecksumForward; +class ChecksumForwardImpl final : public ChecksumForward { +public: + using ChecksumForward::ChecksumForward; - size_t get_workspace_in_bytes(const TensorLayout &) override; + size_t get_workspace_in_bytes(const TensorLayout&) override; - bool is_thread_safe() const override { - return true; - } + bool is_thread_safe() const override { return true; } - Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) - override; + Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override; }; -} -} +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - - diff --git a/dnn/src/cuda/concat/concat.cu b/dnn/src/cuda/concat/concat.cu index d2150f2f..27246348 100644 --- a/dnn/src/cuda/concat/concat.cu +++ b/dnn/src/cuda/concat/concat.cu @@ -10,58 +10,46 @@ */ #include 
"src/cuda/concat/concat.cuh" -#include "src/cuda/utils.cuh" #include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { namespace concat { template -__global__ void forward_kernel(const T **srcs, T *dst, - size_t nr_srcs, - size_t A, size_t B, size_t C, - const size_t *Bv, - const size_t *table_outer, - const size_t *table_inner) -{ +__global__ void forward_kernel( + const T** srcs, T* dst, size_t nr_srcs, size_t A, size_t B, size_t C, + const size_t* Bv, const size_t* table_outer, const size_t* table_inner) { size_t addr = threadIdx.x + blockIdx.x * blockDim.x; - if (addr < A*B*C) { + if (addr < A * B * C) { size_t c = addr % C; size_t b = addr / C % B; - size_t a = addr / (B*C); + size_t a = addr / (B * C); size_t i = table_outer[b]; size_t B_src = Bv[i]; size_t b_src = table_inner[b]; - size_t addr_src = (a*B_src + b_src)*C + c; + size_t addr_src = (a * B_src + b_src) * C + c; dst[addr] = srcs[i][addr_src]; } } template -void forward_proxy(const T **srcs, - T *dst, - size_t nr_srcs, - size_t A, size_t B, size_t C, - const size_t *Bv, - const size_t *table_outer, - const size_t *table_inner, - cudaStream_t stream) -{ +void forward_proxy( + const T** srcs, T* dst, size_t nr_srcs, size_t A, size_t B, size_t C, + const size_t* Bv, const size_t* table_outer, const size_t* table_inner, + cudaStream_t stream) { size_t total_nr_elem = A * B * C; size_t NR_BLOCKS = DIVUP(total_nr_elem, NR_THREADS); - forward_kernel<<>>(srcs, dst, - nr_srcs, - A, B, C, - Bv, - table_outer, - table_inner); + forward_kernel<<>>( + srcs, dst, nr_srcs, A, B, C, Bv, table_outer, table_inner); after_kernel_launch(); } -#define INST(T) \ -template void forward_proxy(const T**, T *, size_t, size_t, size_t, size_t, \ - const size_t *, const size_t *, const size_t *, cudaStream_t); +#define INST(T) \ + template void forward_proxy( \ + const T**, T*, size_t, size_t, size_t, size_t, const size_t*, \ + const size_t*, const size_t*, cudaStream_t); #define cb(DType) INST(typename DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) @@ -69,9 +57,8 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb #undef INST -} // namespace concat -} // namespace cuda -} // namespace megdnn +} // namespace concat +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/concat/concat.cuh b/dnn/src/cuda/concat/concat.cuh index e0d50baf..09cf6336 100644 --- a/dnn/src/cuda/concat/concat.cuh +++ b/dnn/src/cuda/concat/concat.cuh @@ -17,18 +17,13 @@ namespace cuda { namespace concat { template -void forward_proxy(const T **srcs, - T *dst, - size_t nr_srcs, - size_t A, size_t B, size_t C, - const size_t *Bv, - const size_t *table_outer, - const size_t *table_inner, +void forward_proxy( + const T** srcs, T* dst, size_t nr_srcs, size_t A, size_t B, size_t C, + const size_t* Bv, const size_t* table_outer, const size_t* table_inner, cudaStream_t stream); -} // namespace concat -} // namespace cuda -} // namespace megdnn +} // namespace concat +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/concat/opr_impl.cpp b/dnn/src/cuda/concat/opr_impl.cpp index 776ac7e6..bc334fed 100644 --- a/dnn/src/cuda/concat/opr_impl.cpp +++ b/dnn/src/cuda/concat/opr_impl.cpp @@ -9,32 +9,26 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #include "src/cuda/concat/opr_impl.h" -#include "src/cuda/utils.h" #include "src/cuda/concat/concat.cuh" +#include "src/cuda/utils.h" namespace megdnn { namespace cuda { size_t ConcatForwardImpl::get_workspace_in_bytes( - const TensorLayoutArray &srcs, - const TensorLayout &dst) -{ + const TensorLayoutArray& srcs, const TensorLayout& dst) { auto B = dst.shape[param().axis]; // Please refer to the comment in ConcatForwardImpl::exec for detail. - WorkspaceBundle bundle(nullptr, { - sizeof(uintptr_t) * srcs.size(), - sizeof(size_t) * srcs.size(), - sizeof(size_t) * B, - sizeof(size_t) * B}); + WorkspaceBundle bundle( + nullptr, {sizeof(uintptr_t) * srcs.size(), sizeof(size_t) * srcs.size(), + sizeof(size_t) * B, sizeof(size_t) * B}); return bundle.total_size_in_bytes(); } template void ConcatForwardImpl::exec_internal( - _megdnn_in const TensorNDArray &srcs, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ + _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { auto srcs_layout = apply_vector(m_get_layout, srcs); auto srcs_shape = apply_vector(m_get_shape, srcs_layout); check_exec(srcs_layout, dst.layout, workspace.size); @@ -46,10 +40,10 @@ void ConcatForwardImpl::exec_internal( // workspace_cpu will be freed by cuda callback. SmallVector workspace_sizes{ - sizeof(const T *) * srcs.size(), - sizeof(size_t) * srcs.size(), - sizeof(size_t) * B, - sizeof(size_t) * B, + sizeof(const T*) * srcs.size(), + sizeof(size_t) * srcs.size(), + sizeof(size_t) * B, + sizeof(size_t) * B, }; // What do we need: @@ -67,30 +61,30 @@ void ConcatForwardImpl::exec_internal( // These temporary spaces reside in the device side. // The actually work is delegated to concat::forward_proxy. WorkspaceBundle workspace_cpu(nullptr, workspace_sizes), - workspace_gpu(nullptr, workspace_sizes); + workspace_gpu(nullptr, workspace_sizes); auto total_workspace_size = workspace_cpu.total_size_in_bytes(); - void *workspace_cpu_raw = malloc(total_workspace_size); + void* workspace_cpu_raw = malloc(total_workspace_size); megdnn_assert_internal(workspace_cpu_raw); - void *workspace_gpu_raw = workspace.raw_ptr; + void* workspace_gpu_raw = workspace.raw_ptr; workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes); workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes); // srcs - auto srcs_cpu = static_cast(workspace_cpu.get(0)); - auto srcs_gpu = static_cast(workspace_gpu.get(0)); + auto srcs_cpu = static_cast(workspace_cpu.get(0)); + auto srcs_gpu = static_cast(workspace_gpu.get(0)); for (size_t i = 0; i < srcs.size(); ++i) { srcs_cpu[i] = srcs[i].ptr(); } // Bv - auto Bv_cpu = static_cast(workspace_cpu.get(1)); - auto Bv_gpu = static_cast(workspace_gpu.get(1)); + auto Bv_cpu = static_cast(workspace_cpu.get(1)); + auto Bv_gpu = static_cast(workspace_gpu.get(1)); get_ABC(srcs_shape, A, Bv_cpu, C); // table_outer - auto table_outer_cpu = static_cast(workspace_cpu.get(2)); - auto table_outer_gpu = static_cast(workspace_gpu.get(2)); - auto table_inner_cpu = static_cast(workspace_cpu.get(3)); - auto table_inner_gpu = static_cast(workspace_gpu.get(3)); + auto table_outer_cpu = static_cast(workspace_cpu.get(2)); + auto table_outer_gpu = static_cast(workspace_gpu.get(2)); + auto table_inner_cpu = static_cast(workspace_cpu.get(3)); + auto table_inner_gpu = static_cast(workspace_gpu.get(3)); { size_t outer_idx = 0, inner_idx = 0; @@ -105,11 +99,9 @@ void ConcatForwardImpl::exec_internal( } } for (size_t i = 0; i < workspace_cpu.nr_workspace(); ++i) { - 
cuda_check(cudaMemcpyAsync(workspace_gpu.get(i), - workspace_cpu.get(i), - workspace_cpu.get_size(i), - cudaMemcpyHostToDevice, - stream)); + cuda_check(cudaMemcpyAsync( + workspace_gpu.get(i), workspace_cpu.get(i), workspace_cpu.get_size(i), + cudaMemcpyHostToDevice, stream)); } /* CUDA_CK(cudaMemcpyAsync(workspace_gpu_raw, workspace_cpu_raw, @@ -117,30 +109,26 @@ void ConcatForwardImpl::exec_internal( cudaMemcpyHostToDevice, stream)); */ - cuda_check(cudaStreamAddCallback(stream, callback_free, - static_cast(workspace_cpu_raw), 0)); - concat::forward_proxy(srcs_gpu, dst.ptr(), srcs.size(), - A, B, C, - Bv_gpu, - table_outer_gpu, - table_inner_gpu, - stream); + cuda_check(cudaStreamAddCallback( + stream, callback_free, static_cast(workspace_cpu_raw), 0)); + concat::forward_proxy( + srcs_gpu, dst.ptr(), srcs.size(), A, B, C, Bv_gpu, table_outer_gpu, + table_inner_gpu, stream); } -void ConcatForwardImpl::exec(_megdnn_in const TensorNDArray &srcs, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ -#define cb(DType) \ +void ConcatForwardImpl::exec( + _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { +#define cb(DType) \ if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ - using ctype = typename DTypeTrait::ctype; \ - exec_internal(srcs, dst, workspace); \ + using ctype = typename DTypeTrait::ctype; \ + exec_internal(srcs, dst, workspace); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/concat/opr_impl.h b/dnn/src/cuda/concat/opr_impl.h index 40ef491d..df1838e0 100644 --- a/dnn/src/cuda/concat/opr_impl.h +++ b/dnn/src/cuda/concat/opr_impl.h @@ -14,23 +14,23 @@ namespace megdnn { namespace cuda { -class ConcatForwardImpl: public ConcatForward { - public: - using ConcatForward::ConcatForward; - void exec(_megdnn_in const TensorNDArray &srcs, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayoutArray &, - const TensorLayout &) override; - private: - template - void exec_internal(_megdnn_in const TensorNDArray &srcs, - _megdnn_tensor_out dst, - _megdnn_workspace workspace); +class ConcatForwardImpl : public ConcatForward { +public: + using ConcatForward::ConcatForward; + void exec( + _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayoutArray&, const TensorLayout&) override; + +private: + template + void exec_internal( + _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, + _megdnn_workspace workspace); }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/cond_take/kern.cu b/dnn/src/cuda/cond_take/kern.cu index f099f1e4..e0bffbb1 100644 --- a/dnn/src/cuda/cond_take/kern.cu +++ b/dnn/src/cuda/cond_take/kern.cu @@ -9,11 +9,11 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
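/* Illustrative sketch, not from this patch: how the concat lookup tables built in
 * ConcatForwardImpl::exec_internal are laid out. For every output index b along
 * the concatenated axis, table_outer[b] names the source tensor and table_inner[b]
 * the position inside it, so forward_kernel can read
 * srcs[table_outer[b]][(a * Bv[table_outer[b]] + table_inner[b]) * C + c].
 * build_concat_tables is a hypothetical host-side helper shown only to make the
 * table construction explicit. */
#include <cstddef>
#include <vector>

inline void build_concat_tables(
        const std::vector<size_t>& Bv,       // extent of each source on the axis
        std::vector<size_t>& table_outer,    // output index -> source id
        std::vector<size_t>& table_inner) {  // output index -> index inside source
    table_outer.clear();
    table_inner.clear();
    for (size_t o = 0; o < Bv.size(); ++o) {
        for (size_t i = 0; i < Bv[o]; ++i) {
            table_outer.push_back(o);
            table_inner.push_back(i);
        }
    }
}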
*/ +#include #include "./kern.cuh" +#include "src/common/cond_take/predicate.cuh" #include "src/cuda/cumsum/kern_impl.cuinl" #include "src/cuda/query_blocksize.cuh" -#include "src/common/cond_take/predicate.cuh" -#include using namespace megdnn; using namespace megdnn::cond_take; diff --git a/dnn/src/cuda/cond_take/kern.cuh b/dnn/src/cuda/cond_take/kern.cuh index 4fbcc00a..ec739794 100644 --- a/dnn/src/cuda/cond_take/kern.cuh +++ b/dnn/src/cuda/cond_take/kern.cuh @@ -10,9 +10,9 @@ */ #pragma once +#include #include "megdnn/dtype.h" #include "src/common/cond_take/predicate.cuh" -#include namespace megdnn { namespace cuda { @@ -26,11 +26,10 @@ typedef dt_int32 IdxType; * \param size number of elements in mask * \return output size; i.e. number of elements taken */ -template +template size_t gen_idx( - void *workspace, size_t workspace_size, - IdxType *dest_idx, const T *mask, size_t size, - uint32_t mode, const megdnn::cond_take::KParam &kparam, + void* workspace, size_t workspace_size, IdxType* dest_idx, const T* mask, + size_t size, uint32_t mode, const megdnn::cond_take::KParam& kparam, cudaStream_t stream); //! get workspace size in bytes for gen_idx() @@ -44,13 +43,13 @@ size_t gen_idx_get_workspace_size(size_t size); * \param src_idx index input, must have been filled by gen_idx() * \param size size of original mask */ -template -void copy_output(T *dest_data, IdxType *dest_idx, - const T *src_data, IdxType *src_idx, uint32_t size, - cudaStream_t stream); - -} // namespace cond_take -} // namespace cuda -} // namespace megdnn +template +void copy_output( + T* dest_data, IdxType* dest_idx, const T* src_data, IdxType* src_idx, + uint32_t size, cudaStream_t stream); + +} // namespace cond_take +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/cond_take/kern.inl b/dnn/src/cuda/cond_take/kern.inl index 03d324c7..9a7622b8 100644 --- a/dnn/src/cuda/cond_take/kern.inl +++ b/dnn/src/cuda/cond_take/kern.inl @@ -9,11 +9,11 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include #include "./kern.cuh" +#include "src/common/cond_take/predicate.cuh" #include "src/cuda/cumsum/kern_impl.cuinl" #include "src/cuda/query_blocksize.cuh" -#include "src/common/cond_take/predicate.cuh" -#include using namespace megdnn; using namespace megdnn::cond_take; @@ -21,69 +21,57 @@ using namespace megdnn::cuda::cond_take; namespace { - //! cumsum opr to get output index - template - struct IdxGetter { - typedef ::megdnn::cuda::cumsum::SumOp ContigOp; - - const T * data; - Pred pred; - - IdxGetter(const T *d, const ::megdnn::cond_take::KParam &p): - data(d), pred(p) - {} - - __host__ __device__ static IdxType init() { - return 0; - } - - __device__ static IdxType apply(IdxType lhs, IdxType rhs) { - return lhs + rhs; - } - - __device__ IdxType visit(uint32_t idx) const { - return pred(data[idx]); - } - - static ContigOp make_contig(const IdxType *data) { - return ContigOp(data); - } - }; - - template - __global__ void copy_kern( - T *dest_data, IdxType *dest_idx, - const T *src_data, const IdxType *src_idx, uint32_t size) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < size && src_idx[tid] > src_idx[tid - 1]) { - uint32_t v = src_idx[tid] - 1; - dest_data[v] = src_data[tid]; - dest_idx[v] = tid; - } - } +//! 
cumsum opr to get output index +template +struct IdxGetter { + typedef ::megdnn::cuda::cumsum::SumOp ContigOp; + + const T* data; + Pred pred; + + IdxGetter(const T* d, const ::megdnn::cond_take::KParam& p) : data(d), pred(p) {} + + __host__ __device__ static IdxType init() { return 0; } + + __device__ static IdxType apply(IdxType lhs, IdxType rhs) { return lhs + rhs; } - // set zero for the first element - __global__ void set_zero(IdxType *dest) { - dest[0] = 0; + __device__ IdxType visit(uint32_t idx) const { return pred(data[idx]); } + + static ContigOp make_contig(const IdxType* data) { return ContigOp(data); } +}; + +template +__global__ void copy_kern( + T* dest_data, IdxType* dest_idx, const T* src_data, const IdxType* src_idx, + uint32_t size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < size && src_idx[tid] > src_idx[tid - 1]) { + uint32_t v = src_idx[tid] - 1; + dest_data[v] = src_data[tid]; + dest_idx[v] = tid; } +} + +// set zero for the first element +__global__ void set_zero(IdxType* dest) { + dest[0] = 0; +} -} // anonymous namespace +} // anonymous namespace -template +template size_t cuda::cond_take::gen_idx( - void *workspace, size_t workspace_size, - IdxType *dest_idx, const T *mask, size_t size, - uint32_t mode, const KParam &kparam, cudaStream_t stream) { - + void* workspace, size_t workspace_size, IdxType* dest_idx, const T* mask, + size_t size, uint32_t mode, const KParam& kparam, cudaStream_t stream) { switch (mode) { -#define cb(_m) case PEnum::_m: \ - { \ - typedef IdxGetter Op; \ - cuda::cumsum::run_kern( \ - dest_idx + 1, workspace, workspace_size, \ - 1, size, 1, Op(mask, kparam), stream); \ - break; \ - } +#define cb(_m) \ + case PEnum::_m: { \ + typedef IdxGetter Op; \ + cuda::cumsum::run_kern( \ + dest_idx + 1, workspace, workspace_size, 1, size, 1, Op(mask, kparam), \ + stream); \ + break; \ + } MEGDNN_FOREACH_COND_TAKE_MODE(cb) #undef cb default: @@ -91,20 +79,21 @@ size_t cuda::cond_take::gen_idx( } IdxType host_sum_size; - cuda_check(cudaMemcpyAsync(&host_sum_size, dest_idx + size, sizeof(IdxType), - cudaMemcpyDeviceToHost, stream)); + cuda_check(cudaMemcpyAsync( + &host_sum_size, dest_idx + size, sizeof(IdxType), cudaMemcpyDeviceToHost, + stream)); cuda_check(cudaStreamSynchronize(stream)); return host_sum_size; } -template -void cuda::cond_take::copy_output(T *dest_data, IdxType *dest_idx, - const T *src_data, IdxType *src_idx, uint32_t size, - cudaStream_t stream) { +template +void cuda::cond_take::copy_output( + T* dest_data, IdxType* dest_idx, const T* src_data, IdxType* src_idx, + uint32_t size, cudaStream_t stream) { int nr_thread = query_blocksize_for_kernel(copy_kern); int nr_block = DIVUP(size, nr_thread); - set_zero <<< 1, 1, 0, stream >>> (src_idx); - copy_kern <<< nr_block, nr_thread, 0, stream >>> ( + set_zero<<<1, 1, 0, stream>>>(src_idx); + copy_kern<<>>( dest_data, dest_idx, src_data, src_idx + 1, size); after_kernel_launch(); } @@ -113,19 +102,18 @@ namespace megdnn { namespace cuda { namespace cond_take { -#define inst_genidx(dt) \ - template size_t gen_idx( \ - void*, size_t, IdxType*, const DTypeTrait
::ctype*, \ - size_t, uint32_t, const KParam &, cudaStream_t); +#define inst_genidx(dt) \ + template size_t gen_idx( \ + void*, size_t, IdxType*, const DTypeTrait
::ctype*, size_t, uint32_t, \ + const KParam&, cudaStream_t); -#define inst_copy_(ct) \ - template void copy_output(ct*, IdxType*, const ct*, \ - IdxType*, uint32_t, cudaStream_t); +#define inst_copy_(ct) \ + template void copy_output( \ + ct*, IdxType*, const ct*, IdxType*, uint32_t, cudaStream_t); #define inst_copy(dt) inst_copy_(DTypeTrait
::ctype) -} // namespace cond_take -} // namespace cuda -} // namespace megdnn - +} // namespace cond_take +} // namespace cuda +} // namespace megdnn // vim: ft=cuda syntax=cuda.doxygen diff --git a/dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu b/dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu index ccaff9fe..e4a3fd2e 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu @@ -9,11 +9,11 @@ namespace cond_take { inst_genidx(::megdnn::dtype::BFloat16) #undef inst_genidx -inst_copy(::megdnn::dtype::BFloat16) + inst_copy(::megdnn::dtype::BFloat16) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn #endif diff --git a/dnn/src/cuda/cond_take/kimpl/dt_bool.cu b/dnn/src/cuda/cond_take/kimpl/dt_bool.cu index 7afa17b7..bad94744 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_bool.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_bool.cu @@ -8,10 +8,10 @@ namespace cond_take { inst_genidx(::megdnn::dtype::Bool) #undef inst_genidx -inst_copy(::megdnn::dtype::Bool) + inst_copy(::megdnn::dtype::Bool) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/cond_take/kimpl/dt_float16.cu b/dnn/src/cuda/cond_take/kimpl/dt_float16.cu index d9dddc97..5e5f8cdc 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_float16.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_float16.cu @@ -9,11 +9,11 @@ namespace cond_take { inst_genidx(::megdnn::dtype::Float16) #undef inst_genidx -inst_copy(::megdnn::dtype::Float16) + inst_copy(::megdnn::dtype::Float16) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn #endif diff --git a/dnn/src/cuda/cond_take/kimpl/dt_float32.cu b/dnn/src/cuda/cond_take/kimpl/dt_float32.cu index d72c1e3c..3f424d66 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_float32.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_float32.cu @@ -8,10 +8,10 @@ namespace cond_take { inst_genidx(::megdnn::dtype::Float32) #undef inst_genidx -inst_copy(::megdnn::dtype::Float32) + inst_copy(::megdnn::dtype::Float32) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/cond_take/kimpl/dt_int16.cu b/dnn/src/cuda/cond_take/kimpl/dt_int16.cu index f06bc7e7..a44bf1ae 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_int16.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_int16.cu @@ -8,10 +8,10 @@ namespace cond_take { inst_genidx(::megdnn::dtype::Int16) #undef inst_genidx -inst_copy(::megdnn::dtype::Int16) + inst_copy(::megdnn::dtype::Int16) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/cond_take/kimpl/dt_int32.cu b/dnn/src/cuda/cond_take/kimpl/dt_int32.cu index c86de346..357dfd42 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_int32.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_int32.cu @@ -8,10 +8,10 @@ namespace cond_take { inst_genidx(::megdnn::dtype::Int32) #undef inst_genidx -inst_copy(::megdnn::dtype::Int32) + inst_copy(::megdnn::dtype::Int32) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/cond_take/kimpl/dt_int8.cu b/dnn/src/cuda/cond_take/kimpl/dt_int8.cu index 
f78a1f37..7ccbb3b7 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_int8.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_int8.cu @@ -8,10 +8,10 @@ namespace cond_take { inst_genidx(::megdnn::dtype::Int8) #undef inst_genidx -inst_copy(::megdnn::dtype::Int8) + inst_copy(::megdnn::dtype::Int8) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/cond_take/kimpl/dt_uint8.cu b/dnn/src/cuda/cond_take/kimpl/dt_uint8.cu index 4bdd6e7b..76047f9a 100644 --- a/dnn/src/cuda/cond_take/kimpl/dt_uint8.cu +++ b/dnn/src/cuda/cond_take/kimpl/dt_uint8.cu @@ -8,10 +8,10 @@ namespace cond_take { inst_genidx(::megdnn::dtype::Uint8) #undef inst_genidx -inst_copy(::megdnn::dtype::Uint8) + inst_copy(::megdnn::dtype::Uint8) #undef inst_copy #undef inst_copy_ -} // cond_take -} // cuda -} // megdnn +} // namespace cond_take +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/cond_take/opr_impl.cpp b/dnn/src/cuda/cond_take/opr_impl.cpp index 81c23da1..bdedc920 100644 --- a/dnn/src/cuda/cond_take/opr_impl.cpp +++ b/dnn/src/cuda/cond_take/opr_impl.cpp @@ -11,8 +11,8 @@ #include "./opr_impl.h" #include "./kern.cuh" -#include "src/common/utils.h" #include "src/common/cond_take/predicate.cuh" +#include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" @@ -36,8 +36,7 @@ size_t CondTakeImpl::get_workspace_in_bytes(const TensorLayout& data) { } CondTakeImpl::Output CondTakeImpl::exec( - _megdnn_tensor_in data, _megdnn_tensor_in mask, - _megdnn_workspace workspace, + _megdnn_tensor_in data, _megdnn_tensor_in mask, _megdnn_workspace workspace, DynOutMallocPolicyCall malloc_policy) { size_t size = check_exec_get_size(data.layout, mask.layout, workspace.size); auto wk_bundle = make_bundle(size); @@ -49,43 +48,38 @@ CondTakeImpl::Output CondTakeImpl::exec( auto stream = cuda_stream(handle()); size_t out_size; switch (mask.layout.dtype.enumv()) { -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - using ctype = DTypeTrait<_dt>::ctype; \ - out_size = gen_idx(wk_bundle.get(1), wk_bundle.get_size(1), \ - idx_tmp, mask.ptr(), \ - size, static_cast(param().mode), kparam, \ - stream); \ - break; \ - } +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + out_size = gen_idx( \ + wk_bundle.get(1), wk_bundle.get_size(1), idx_tmp, mask.ptr(), \ + size, static_cast(param().mode), kparam, stream); \ + break; \ + } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) cb(::megdnn::dtype::Bool) #undef cb - default: - megdnn_throw("bad mask dtype"); + default : megdnn_throw("bad mask dtype"); } - auto out_data = malloc_policy.alloc_output(0, - data.layout.dtype, {out_size}); + auto out_data = malloc_policy.alloc_output(0, data.layout.dtype, {out_size}); auto out_idx = malloc_policy.alloc_output(1, dtype::Int32(), {out_size}); auto out_idx_ptr = out_idx.ptr(); switch (data.layout.dtype.enumv()) { -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - using ctype = DTypeTrait<_dt>::ctype; \ - auto out_data_ptr = out_data.ptr(); \ - auto data_ptr = data.ptr(); \ - copy_output( \ - out_data_ptr, out_idx_ptr, data_ptr, idx_tmp, size, \ - stream); \ - break; \ - } +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + auto out_data_ptr = out_data.ptr(); \ + auto data_ptr = data.ptr(); \ + copy_output( \ + out_data_ptr, out_idx_ptr, data_ptr, idx_tmp, size, stream); \ + break; \ + } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) 
cb(::megdnn::dtype::Bool) #undef cb - default: - megdnn_throw("bad data dtype"); + default : megdnn_throw("bad data dtype"); } return {{out_data, out_idx}}; diff --git a/dnn/src/cuda/cond_take/opr_impl.h b/dnn/src/cuda/cond_take/opr_impl.h index 9f5c077e..c4ec6572 100644 --- a/dnn/src/cuda/cond_take/opr_impl.h +++ b/dnn/src/cuda/cond_take/opr_impl.h @@ -16,20 +16,19 @@ namespace megdnn { namespace cuda { -class CondTakeImpl final: public CondTake { +class CondTakeImpl final : public CondTake { WorkspaceBundle make_bundle(size_t nr_item); - public: - using CondTake::CondTake; - Output exec( - _megdnn_tensor_in data, _megdnn_tensor_in mask, - _megdnn_workspace workspace, - DynOutMallocPolicyCall malloc_policy) override; +public: + using CondTake::CondTake; + Output exec( + _megdnn_tensor_in data, _megdnn_tensor_in mask, _megdnn_workspace workspace, + DynOutMallocPolicyCall malloc_policy) override; - size_t get_workspace_in_bytes(const TensorLayout& data) override; + size_t get_workspace_in_bytes(const TensorLayout& data) override; }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 4475c071..6aa2578e 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -109,18 +109,17 @@ MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvBiasForwardImpl) ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( const ConvBiasForwardImpl* o, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& bias, - const TensorLayout& z, const TensorLayout& dst, - const PreprocessedFilter* preprocessed_filter) - : SizeArgs(o, src, filter, - o->make_canonized_filter_meta(src.ndim, filter), bias, z, - dst, preprocessed_filter) {} + const TensorLayout& filter, const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) + : SizeArgs( + o, src, filter, o->make_canonized_filter_meta(src.ndim, filter), bias, + z, dst, preprocessed_filter) {} ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( const ConvBiasForwardImpl* o, const TensorLayout& src, const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, + const PreprocessedFilter* preprocessed_filter) : BiasForwardSizeArgs{concrete_handle(o->handle()), &src, &filter, @@ -133,12 +132,12 @@ ConvBiasForwardImpl::AlgoBase::SizeArgs::SizeArgs( preprocessed_filter{preprocessed_filter} {} ConvBiasForwardImpl::AlgoBase::ExecArgs::ExecArgs( - ConvBiasForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, _megdnn_workspace workspace, - const PreprocessedFilter* preprocessed_filter) - : SizeArgs(opr, src.layout, filter.layout, bias.layout, z.layout, - dst.layout, preprocessed_filter), + ConvBiasForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_in bias, _megdnn_tensor_in z, _megdnn_tensor_out dst, + _megdnn_workspace workspace, const PreprocessedFilter* preprocessed_filter) + : SizeArgs( + opr, src.layout, filter.layout, bias.layout, z.layout, dst.layout, + preprocessed_filter), src_tensor{&src}, filter_tensor{&filter}, bias_tensor{&bias}, @@ -172,9 +171,9 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const 
{ "nonlinear_mode=%s", src_layout->to_string().c_str(), filter_layout->to_string().c_str(), bias_layout->to_string().c_str(), z_layout->to_string().c_str(), - dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1], - fm.stride[0], fm.stride[1], fm.dilation[0], fm.dilation[1], - !fm.should_flip, src_layout->dtype.name(), dst_layout->dtype.name(), + dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1], fm.stride[0], + fm.stride[1], fm.dilation[0], fm.dilation[1], !fm.should_flip, + src_layout->dtype.name(), dst_layout->dtype.name(), nonlinear_mode_str.c_str()); } @@ -200,59 +199,39 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { int8_nchw4_imma.push_back( {AlgoInt8NCHW4IMMAImplicitGemm::MMATileSize::IMMA8x32x16}); int8_chwn4_imma_reorder_filter.push_back( - {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize:: - IMMA16x16x16}); + {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA16x16x16}); int8_chwn4_imma_reorder_filter.push_back( - {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize:: - IMMA32x8x16}); + {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA32x8x16}); int8_chwn4_imma_reorder_filter.push_back( - {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize:: - IMMA8x32x16}); + {AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::MMATileSize::IMMA8x32x16}); int8_chwn4_imma_unroll_width.push_back( - {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize:: - IMMA16x16x16}); + {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA16x16x16}); int8_chwn4_imma_unroll_width.push_back( - {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize:: - IMMA32x8x16}); + {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA32x8x16}); int8_chwn4_imma_unroll_width.push_back( - {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize:: - IMMA8x32x16}); + {AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::MMATileSize::IMMA8x32x16}); #if CUDA_VERSION >= 10020 { using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; - int8_nchw32_imma.emplace_back( - AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2}); - int8_nchw32_imma.emplace_back( - AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2}); - int8_nchw32_imma.emplace_back( - AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2}); - int8_nchw32_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2}); - int8_nchw32_imma.emplace_back( - AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2}); - int8_nchw32_imma.emplace_back( - AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1}); - int8_nchw32_imma.emplace_back( - AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1}); - int8_nchw32_imma.emplace_back( - AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1}); - int8_nchw32_imma.emplace_back( - AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 8, 8, 16, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 8, 8, 16, 1}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1}); + int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 8, 8, 16, 1}); + int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 8, 8, 16, 1}); } { using AlgoParam = 
AlgoInt8NHWCIMMAImplicitGemm::AlgoParam; - int8_nhwc_imma.emplace_back( - AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 16}); - int8_nhwc_imma.emplace_back( - AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 8}); - int8_nhwc_imma.emplace_back( - AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 4}); + int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 16}); + int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 8}); + int8_nhwc_imma.emplace_back(AlgoParam{64, 16, 32, 64, 16, 32, 8, 8, 16, 2, 4}); int8_nhwc_imma.emplace_back( AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 16}); - int8_nhwc_imma.emplace_back( - AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 8}); - int8_nhwc_imma.emplace_back( - AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 4}); + int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 8}); + int8_nhwc_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 8, 8, 16, 1, 4}); } { using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; @@ -324,46 +303,35 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam; - int8_nchw4_dotprod.emplace_back( - AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1}); - int8_nchw4_dotprod.emplace_back( - AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 1, 1, 4, 2}); + int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1, 1, 4, 1}); + int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 1, 1, 4, 2}); } -ConvBiasForwardImpl::AlgoBase* -ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( +ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack::cudnn_conv_from_enum( cudnnConvolutionFwdAlgo_t algo) { for (auto&& i : cudnn_convs) { if (i.cudnn_enum() == algo) return &i; } - megdnn_throw(ssprintf("can not find cudnn conv fwd algorithm %d", - static_cast(algo))); + megdnn_throw(ssprintf( + "can not find cudnn conv fwd algorithm %d", static_cast(algo))); } -ConvBiasForwardImpl::AlgoBase* -ConvBiasForwardImpl::AlgoPack::cudnn_conv_bias_act_from_enum( - cudnnConvolutionFwdAlgo_t algo) { +ConvBiasForwardImpl::AlgoBase* ConvBiasForwardImpl::AlgoPack:: + cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo) { for (auto&& i : cudnn_conv_bias_activations) { if (i.cudnn_enum() == algo) return 
&i; } - megdnn_throw(ssprintf("can not find cudnn conv bias act algorithm %d", - static_cast(algo))); + megdnn_throw(ssprintf( + "can not find cudnn conv bias act algorithm %d", static_cast(algo))); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index 90e0dc7b..682357de 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -94,20 +94,22 @@ public: const PreprocessedFilter* preprocessed_filter; std::string to_string() const; - SizeArgs(const ConvBiasForwardImpl* opr, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& bias, - const TensorLayout& z, const TensorLayout& dst, - const PreprocessedFilter* preprocessed_filter = nullptr); - SizeArgs(const ConvBiasForwardImpl* opr, const TensorLayout& src, - const TensorLayout& filter, - const CanonizedFilterMeta& filter_meta, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, - const PreprocessedFilter* preprocessed_filter = nullptr); + SizeArgs( + const ConvBiasForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst, + const PreprocessedFilter* preprocessed_filter = nullptr); + SizeArgs( + const ConvBiasForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, + const TensorLayout& bias, const TensorLayout& z, + const TensorLayout& dst, + const PreprocessedFilter* preprocessed_filter = nullptr); void init_conv_bias_desc(conv_bias::CUDNNForwardDescs& desc) const { - desc.set_conv_bias(*src_layout, filter_meta, *dst_layout, - *bias_layout, *z_layout, opr->param()); + desc.set_conv_bias( + *src_layout, filter_meta, *dst_layout, *bias_layout, *z_layout, + opr->param()); } void init_conv_desc(conv_bias::CUDNNForwardDescs& desc) const { @@ -119,17 +121,16 @@ public: *dst_tensor; Workspace workspace; - ExecArgs(ConvBiasForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_in bias, - _megdnn_tensor_in z, _megdnn_tensor_out dst, - _megdnn_workspace workspace, - const PreprocessedFilter* preprocessed_filter = nullptr); + ExecArgs( + ConvBiasForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_in bias, _megdnn_tensor_in z, + _megdnn_tensor_out dst, _megdnn_workspace workspace, + const PreprocessedFilter* preprocessed_filter = nullptr); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; virtual void exec(const ExecArgs& args) const = 0; - virtual size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const { + virtual size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const { MEGDNN_MARK_USED_VAR(args); return 0; } @@ -152,17 +153,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); megdnn_assert( req <= workspace.size, - "conv bias fwd algo %s: required workspace %zu bytes, got %zu", - name(), req, workspace.size); + "conv bias fwd algo %s: required workspace 
%zu bytes, got %zu", name(), + req, workspace.size); return *this; } @@ -173,8 +172,9 @@ class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase { public: AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { - megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != - CudnnAlgoPack::conv_fwd_algos().end()); + megdnn_assert( + CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_fwd_algos().end()); m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); m_name = ConvBiasForward::algo_name( "CUDNN:ConvBiasActivation:" + m_attr.name, {}); @@ -224,14 +224,11 @@ public: const char* name() const override { if (m_name.empty()) { - m_name = - ConvBiasForward::algo_name("CHANNEL_WISE", {}); + m_name = ConvBiasForward::algo_name("CHANNEL_WISE", {}); } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) @@ -247,15 +244,12 @@ public: const char* name() const override { if (m_name.empty()) { - m_name = ConvBiasForward::algo_name( - "CHANNEL_WISE_SMALL", {}); + m_name = ConvBiasForward::algo_name("CHANNEL_WISE_SMALL", {}); } return m_name.c_str(); } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: mutable std::string m_name; @@ -268,15 +262,12 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { if (m_name.empty()) { - m_name = ConvBiasForward::algo_name( - "CHANNEL_WISE_8X8X32", {}); + m_name = ConvBiasForward::algo_name("CHANNEL_WISE_8X8X32", {}); } return m_name.c_str(); } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: mutable std::string m_name; @@ -284,10 +275,10 @@ private: class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase { public: - AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) - : m_cudnn_enum(cudnn_enum) { - megdnn_assert(CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != - CudnnAlgoPack::conv_fwd_algos().end()); + AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { + megdnn_assert( + CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_fwd_algos().end()); m_attr = CudnnAlgoPack::conv_fwd_algos().at(cudnn_enum); m_name = ConvBiasForward::algo_name( "CUDNN:Convolution:" + m_attr.name, {}); @@ -345,9 +336,7 @@ public: return m_name.c_str(); } MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: mutable std::string m_name; @@ -356,8 +345,7 @@ private: //! 
im2col and matmul, with dilation class ConvBiasForwardImpl::AlgoMatmul final : public AlgoBase { template - static void exec_internal(const ExecArgs& args, - const WorkspaceBundle& bundle); + static void exec_internal(const ExecArgs& args, const WorkspaceBundle& bundle); public: bool is_available(const SizeArgs& args) const override; @@ -366,19 +354,16 @@ public: const char* name() const override { if (m_name.empty()) { - m_name = ConvBiasForward::algo_name("MATMUL", - {}); + m_name = ConvBiasForward::algo_name("MATMUL", {}); } return m_name.c_str(); } std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } private: @@ -399,9 +384,7 @@ public: return m_name.c_str(); } MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: bool need_src_unroll(const SizeArgs& args) const; @@ -428,12 +411,10 @@ public: } std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) @@ -451,20 +432,16 @@ public: void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { if (m_name.empty()) { - m_name = ConvBiasForward::algo_name("CUDA:GROUP_CONV", - {}); + m_name = ConvBiasForward::algo_name("CUDA:GROUP_CONV", {}); } return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) @@ -481,38 +458,29 @@ public: size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return "QUINT4x4x32_WMMA"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; bool use_kernel_fhxfw(const SizeArgs& args) const; size_t get_workspace_in_bytes_do_conv(const SizeArgs& args) const; MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) }; #endif -class ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm final - : public AlgoBase { +class ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm final : public AlgoBase { public: AlgoInt8CHWN4DotProdImplicitGemm() = default; bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) 
const override; void exec(const ExecArgs& args) const override; - const char* name() const override { - return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"; - } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + const char* name() const override { return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"; } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } template static void dispatch_nonlinear_mode( - const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, const int8_t* d_z, int8_t* d_dst, - const convolution::ConvParam& param, float alpha, float beta, - float gamma, float scale, cudaStream_t stream, + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + const int8_t* d_z, int8_t* d_dst, const convolution::ConvParam& param, + float alpha, float beta, float gamma, float scale, cudaStream_t stream, param::ConvBias::NonlineMode nonlinear_mode); MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_DOTPROD_INT8) }; @@ -548,12 +516,7 @@ public: // corresponds to cutlass::conv::ConvType. we hope that algo.h does not // depend on cutlass headers - enum class ConvType { - kConvolution, - kBatchConvolution, - kLocal, - kLocalShare - }; + enum class ConvType { kConvolution, kBatchConvolution, kLocal, kLocalShare }; // common parameters for operation selection struct AlgoParam { @@ -569,16 +532,15 @@ public: int stage; int access_size; - AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_, - int warp_m_, int warp_n_, int warp_k_, int instruction_m_, - int instruction_n_, int instruction_k_, int stage_, - int access_size_ = 0); + AlgoParam( + int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_, + int warp_n_, int warp_k_, int instruction_m_, int instruction_n_, + int instruction_k_, int stage_, int access_size_ = 0); std::string to_string() const; }; - AlgoCutlassConvolutionBase(AlgoParam algo_param) - : m_algo_param{algo_param} {} + AlgoCutlassConvolutionBase(AlgoParam algo_param) : m_algo_param{algo_param} {} // generate a cutlass::library::ConvolutionKey and find the corresponding // operation (cutlass kernel) from the global OperationTable @@ -589,18 +551,14 @@ public: // execute the cutlass kernel found by get_cutlass_conv_op. 
we give // subclasses full freedom to decide where and how these arguments are // extracted - void execute_cutlass_conv_op(const cutlass::library::Operation* op, - const void* src, const void* filter, - const void* bias, const void* z, void* dst, - void* workspace, size_t n, size_t hi, - size_t wi, size_t ci, size_t co, size_t fh, - size_t fw, size_t ho, size_t wo, size_t ph, - size_t pw, size_t sh, size_t sw, size_t dh, - size_t dw, const void* alpha, const void* beta, - const void* gamma, const void* delta, - const void* theta, const void* threshold, - const void* dst_scale, cudaStream_t stream, - const void* extra_param = nullptr) const; + void execute_cutlass_conv_op( + const cutlass::library::Operation* op, const void* src, const void* filter, + const void* bias, const void* z, void* dst, void* workspace, size_t n, + size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, size_t ho, + size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, size_t dh, size_t dw, + const void* alpha, const void* beta, const void* gamma, const void* delta, + const void* theta, const void* threshold, const void* dst_scale, + cudaStream_t stream, const void* extra_param = nullptr) const; protected: AlgoParam m_algo_param; @@ -611,17 +569,15 @@ class ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm final public: AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) : AlgoCutlassConvolutionBase(algo_param), - m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", - m_algo_param.to_string().c_str())} {} + m_name{ssprintf( + "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", + m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } - size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const override; + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; @@ -634,8 +590,7 @@ public: } private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; std::string m_name; }; @@ -644,49 +599,34 @@ public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - const char* name() const override { - return "FALLBACK_CONV_NCHW_QS8"; - } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + const char* name() const override { return "FALLBACK_CONV_NCHW_QS8"; } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8) std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; - + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; + private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; }; #if CUDA_VERSION >= 10000 -class ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm final - : public AlgoBase { +class 
ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm final : public AlgoBase { public: - enum class MMATileSize : uint32_t { - IMMA16x16x16, - IMMA32x8x16, - IMMA8x32x16 - }; + enum class MMATileSize : uint32_t { IMMA16x16x16, IMMA32x8x16, IMMA8x32x16 }; AlgoInt8CHWN4IMMAImplicitGemm(MMATileSize mma_tile_size) : m_mma_tile_size{mma_tile_size}, - m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_" + - to_string(m_mma_tile_size)} {} + m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_" + to_string(m_mma_tile_size)} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } template static void dispatch_nonlinear_mode( - const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, int8_t* d_z, int8_t* d_dst, - const convolution::ConvParam& param, float alpha, float beta, - float gamma, float scale, cudaStream_t stream, - param::ConvBias::NonlineMode nonlinear_mode, - MMATileSize mma_tile_size); + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + int8_t* d_z, int8_t* d_dst, const convolution::ConvParam& param, + float alpha, float beta, float gamma, float scale, cudaStream_t stream, + param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size); static std::string to_string(MMATileSize mma_tile_size); MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_CHWN4_IMMA_INT8) @@ -702,15 +642,13 @@ private: std::string m_name; }; -class ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm final - : public AlgoBase { +class ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm final : public AlgoBase { public: using MMATileSize = AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize; AlgoInt8NCHW4IMMAImplicitGemm(MMATileSize mma_tile_size) : m_mma_tile_size{mma_tile_size}, m_name{"INT8_NCHW4_IMMA_IMPLICIT_GEMM_" + - AlgoInt8CHWN4IMMAImplicitGemm::to_string( - m_mma_tile_size)} {} + AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; @@ -722,13 +660,10 @@ public: serialize_write_pod(m_mma_tile_size, ret); return ret; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; MMATileSize m_mma_tile_size; std::string m_name; }; @@ -740,8 +675,7 @@ public: AlgoInt8CHWN4IMMAImplicitGemmReorderFilter(MMATileSize mma_tile_size) : m_mma_tile_size{mma_tile_size}, m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_REORDER_FILTER_" + - AlgoInt8CHWN4IMMAImplicitGemm::to_string( - m_mma_tile_size)} {} + AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; @@ -753,9 +687,7 @@ public: serialize_write_pod(m_mma_tile_size, ret); return ret; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute 
attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: MMATileSize m_mma_tile_size; @@ -769,8 +701,7 @@ public: AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth(MMATileSize mma_tile_size) : m_mma_tile_size{mma_tile_size}, m_name{"INT8_CHWN4_IMMA_IMPLICIT_GEMM_UNROLL_WIDTH_" + - AlgoInt8CHWN4IMMAImplicitGemm::to_string( - m_mma_tile_size)} {} + AlgoInt8CHWN4IMMAImplicitGemm::to_string(m_mma_tile_size)} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; @@ -782,9 +713,7 @@ public: serialize_write_pod(m_mma_tile_size, ret); return ret; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: MMATileSize m_mma_tile_size; @@ -799,20 +728,18 @@ public: AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param) : AlgoCutlassConvolutionBase(algo_param) { m_name = ConvBias::algo_name( - ssprintf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s", - to_string(m_algo_param).c_str()), + ssprintf( + "INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s", + to_string(m_algo_param).c_str()), ConvBias::DirectParam{}); } bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } static std::string to_string(AlgoParam algo_param); - size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const override; + size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; @@ -825,8 +752,7 @@ public: } private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; std::string m_name; }; @@ -837,20 +763,18 @@ public: AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param) : AlgoCutlassConvolutionBase(algo_param) { m_name = ConvBias::algo_name( - ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM_%s", - to_string(m_algo_param).c_str()), + ssprintf( + "INT8_NHWC_IMMA_IMPLICIT_GEMM_%s", + to_string(m_algo_param).c_str()), ConvBias::DirectParam{}); } bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } static std::string to_string(AlgoParam algo_param); - size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const override; + size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; @@ -866,8 +790,8 @@ private: std::tuple get_constants( const ExecArgs& args) const; - void reorder_filter(const ExecArgs& args, int interleaved, - void* reordered_filter) const; + void reorder_filter( 
+ const ExecArgs& args, int interleaved, void* reordered_filter) const; std::string m_name; }; @@ -878,9 +802,7 @@ public: AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param) : AlgoCutlassConvolutionBase(algo_param) {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return m_name.c_str(); } std::string param() const override; @@ -911,17 +833,16 @@ public: using Base = AlgoInt4NCHW64IMMAImplicitGemmBase; using AlgoParam = Base::AlgoParam; - AlgoInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) - : Base{algo_param} { + AlgoInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} { m_name = ConvBias::algo_name( - ssprintf("INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s", - to_string(m_algo_param).c_str()), + ssprintf( + "INT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s", + to_string(m_algo_param).c_str()), ConvBias::DirectParam{}); } size_t get_workspace_in_bytes(const SizeArgs& args) const override; - size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const override; + size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; @@ -931,8 +852,7 @@ public: private: DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; } - std::tuple prepare_filter_bias( - const ExecArgs& args) const override; + std::tuple prepare_filter_bias(const ExecArgs& args) const override; std::tuple get_constants( const ExecArgs& args) const override; @@ -944,17 +864,16 @@ public: using Base = AlgoInt4NCHW64IMMAImplicitGemmBase; using AlgoParam = Base::AlgoParam; - AlgoUInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) - : Base{algo_param} { + AlgoUInt4Int4NCHW64IMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} { m_name = ConvBias::algo_name( - ssprintf("UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s", - to_string(m_algo_param).c_str()), + ssprintf( + "UINT4_INT4_NCHW64_IMMA_IMPLICIT_GEMM_%s", + to_string(m_algo_param).c_str()), ConvBias::DirectParam{}); } size_t get_workspace_in_bytes(const SizeArgs& args) const override; - size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const override; + size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; @@ -964,14 +883,14 @@ public: private: DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; } - std::tuple prepare_filter_bias( - const ExecArgs& args) const override; + std::tuple prepare_filter_bias(const ExecArgs& args) const override; std::tuple get_constants( const ExecArgs& args) const override; - void update_bias(const ExecArgs& args, void* updated_bias, - void* reduce_filter_ptr, void* reduce_workspace) const; + void update_bias( + const ExecArgs& args, void* updated_bias, void* reduce_filter_ptr, + void* reduce_workspace) const; }; class ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase @@ -980,9 +899,7 @@ public: AlgoInt4NHWCIMMAImplicitGemmBase(AlgoParam algo_param) : AlgoCutlassConvolutionBase(algo_param) {} - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const 
override { return m_name.c_str(); } std::string param() const override; @@ -1002,8 +919,8 @@ protected: virtual std::tuple get_constants( const ExecArgs& args) const = 0; - void reorder_filter(const ExecArgs& args, int interleaved, - void* reordered_filter) const; + void reorder_filter( + const ExecArgs& args, int interleaved, void* reordered_filter) const; std::string m_name; }; @@ -1016,14 +933,14 @@ public: AlgoInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} { m_name = ConvBias::algo_name( - ssprintf("INT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s", - to_string(m_algo_param).c_str()), + ssprintf( + "INT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s", + to_string(m_algo_param).c_str()), ConvBias::DirectParam{}); } size_t get_workspace_in_bytes(const SizeArgs& args) const override; - size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const override; + size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; @@ -1033,8 +950,7 @@ public: private: DTypeEnum src_dtype() const override { return DTypeEnum::QuantizedS4; } - std::tuple prepare_filter_bias( - const ExecArgs& args) const override; + std::tuple prepare_filter_bias(const ExecArgs& args) const override; std::tuple get_constants( const ExecArgs& args) const override; @@ -1048,14 +964,14 @@ public: AlgoUInt4Int4NHWCIMMAImplicitGemm(AlgoParam algo_param) : Base{algo_param} { m_name = ConvBias::algo_name( - ssprintf("UINT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s", - to_string(m_algo_param).c_str()), + ssprintf( + "UINT4_INT4_NHWC_IMMA_IMPLICIT_GEMM_%s", + to_string(m_algo_param).c_str()), ConvBias::DirectParam{}); } size_t get_workspace_in_bytes(const SizeArgs& args) const override; - size_t get_preprocess_workspace_in_bytes( - const SizeArgs& args) const override; + size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override; SmallVector deduce_preprocessed_filter_layout( const SizeArgs& args) const override; void exec_preprocess(const ExecArgs& args) const override; @@ -1065,14 +981,14 @@ public: private: DTypeEnum src_dtype() const override { return DTypeEnum::Quantized4Asymm; } - std::tuple prepare_filter_bias( - const ExecArgs& args) const override; + std::tuple prepare_filter_bias(const ExecArgs& args) const override; std::tuple get_constants( const ExecArgs& args) const override; - void update_bias(const ExecArgs& args, void* updated_bias, - void* reduce_filter_ptr, void* reduce_workspace) const; + void update_bias( + const ExecArgs& args, void* updated_bias, void* reduce_filter_ptr, + void* reduce_workspace) const; }; #endif @@ -1083,14 +999,11 @@ public: void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "CONVBIAS_BFLOAT16"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) private: @@ -1125,8 +1038,7 @@ public: std::vector int8_nchw4_imma; std::vector int8_chwn4_imma_reorder_filter; - std::vector - int8_chwn4_imma_unroll_width; + std::vector int8_chwn4_imma_unroll_width; #endif #if CUDA_VERSION >= 10020 std::vector int8_nchw32_imma; diff --git 
a/dnn/src/cuda/conv_bias/batched_matmul.cpp b/dnn/src/cuda/conv_bias/batched_matmul.cpp index 398a5f62..5ca12154 100644 --- a/dnn/src/cuda/conv_bias/batched_matmul.cpp +++ b/dnn/src/cuda/conv_bias/batched_matmul.cpp @@ -10,8 +10,8 @@ * implied. */ -#include "src/common/algo_chooser.h" #include "src/common/algo_base.h" +#include "src/common/algo_chooser.h" #include "src/common/conv_bias.h" #include "src/cuda/batched_matrix_mul/algo.h" #include "src/cuda/conv_bias/algo.h" @@ -42,8 +42,7 @@ std::pair sub_opr_config( B.stride[1] = src_layout.stride[1]; B.stride[0] = src_layout.stride[0]; B.dtype = src_layout.dtype; - C = {{dst_layout.shape[0], dst_layout.shape[1], B.shape[2]}, - dst_layout.dtype}; + C = {{dst_layout.shape[0], dst_layout.shape[1], B.shape[2]}, dst_layout.dtype}; C.stride[2] = 1; C.stride[1] = dst_layout.stride[1]; C.stride[0] = dst_layout.stride[0]; @@ -56,44 +55,43 @@ std::pair sub_opr_config( return {{A, B, C}, param}; } -std::pair> -prepare_sub_opr(const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) { +std::pair> prepare_sub_opr( + const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) { auto bmatmul_opr = args.handle->create_operator(); set_execution_policy( args.opr, bmatmul_opr.get()); - auto&& config = - sub_opr_config(args.filter_meta, *args.src_layout, - *args.filter_layout, *args.dst_layout, args.opr); + auto&& config = sub_opr_config( + args.filter_meta, *args.src_layout, *args.filter_layout, *args.dst_layout, + args.opr); bmatmul_opr->param() = config.second; return {config.first, std::move(bmatmul_opr)}; } } // namespace -std::vector -ConvBiasForwardImpl::AlgoBatchedMatmul::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector ConvBiasForwardImpl::AlgoBatchedMatmul:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { const ConvBiasForwardImpl* conv_bias_opr = static_cast(opr); - CanonizedFilterMeta fm = conv_bias_opr->make_canonized_filter_meta( - layouts[0].ndim, layouts[1]); - auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[4], - conv_bias_opr); + CanonizedFilterMeta fm = + conv_bias_opr->make_canonized_filter_meta(layouts[0].ndim, layouts[1]); + auto&& config = + sub_opr_config(fm, layouts[0], layouts[1], layouts[4], conv_bias_opr); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, - config.first}}; + return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, config.first}}; } -bool ConvBiasForwardImpl::AlgoBatchedMatmul::is_available( - const SizeArgs& args) const { +bool ConvBiasForwardImpl::AlgoBatchedMatmul::is_available(const SizeArgs& args) const { if (args.z_layout->ndim > 0) return false; auto config = prepare_sub_opr(args); //! 
The dst of batched matmul should be contiguous - if (!config.first[2].is_contiguous()) return false; + if (!config.first[2].is_contiguous()) + return false; auto&& fm = args.filter_meta; return fm.format == Param::Format::NCHW && @@ -103,9 +101,9 @@ bool ConvBiasForwardImpl::AlgoBatchedMatmul::is_available( fm.dilation[1] == 1 && fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.padding[0] == 0 && fm.padding[1] == 0 && fm.stride[0] == 1 && fm.stride[1] == 1 && - get_algorithm(static_cast( - config.second.get()), - config.first[0], config.first[1], config.first[2]); + get_algorithm( + static_cast(config.second.get()), + config.first[0], config.first[1], config.first[2]); } WorkspaceBundle ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_bundle( @@ -114,9 +112,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_bundle( SmallVector sizes; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); sizes.push_back(dst_layout.span().dist_byte()); } @@ -125,9 +122,9 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_bundle( auto config = prepare_sub_opr(args); - sizes.insert(sizes.begin(), - config.second->get_workspace_in_bytes( - config.first[0], config.first[1], config.first[2])); + sizes.insert( + sizes.begin(), config.second->get_workspace_in_bytes( + config.first[0], config.first[1], config.first[2])); return {ptr, std::move(sizes)}; } @@ -142,9 +139,9 @@ void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(1); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } ExecArgs conv_args = args; @@ -158,9 +155,9 @@ void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { C{args.dst_tensor->raw_ptr, config.first[2]}; config.second->exec(A, B, C, bundle.get_workspace(0)); } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/bfloat16.cpp b/dnn/src/cuda/conv_bias/bfloat16.cpp index 60567edb..66d53f49 100644 --- a/dnn/src/cuda/conv_bias/bfloat16.cpp +++ b/dnn/src/cuda/conv_bias/bfloat16.cpp @@ -10,11 +10,11 @@ * implied. 
*/ +#include "src/common/algo_base.h" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/handle.h" #include "src/cuda/utils.cuh" #include "src/cuda/utils.h" -#include "src/common/algo_base.h" using namespace megdnn; using namespace cuda; @@ -46,8 +46,8 @@ std::pair> prepare_sub_opr( const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) { auto convbias_opr = args.handle->create_operator(); auto&& config = sub_opr_config( - {*args.src_layout, *args.filter_layout, *args.bias_layout, - *args.z_layout, *args.dst_layout}, + {*args.src_layout, *args.filter_layout, *args.bias_layout, *args.z_layout, + *args.dst_layout}, args.opr); convbias_opr->param() = config.second; @@ -55,26 +55,25 @@ std::pair> prepare_sub_opr( } } // namespace -std::vector -ConvBiasForwardImpl::AlgoBFloat16::get_subopr_list( +std::vector ConvBiasForwardImpl::AlgoBFloat16::get_subopr_list( const TensorLayoutArray& layouts, const OperatorBase* opr) const { - auto&& config = sub_opr_config( - layouts, static_cast(opr)); + auto&& config = + sub_opr_config(layouts, static_cast(opr)); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); return {{Algorithm::OprType::CONVBIAS_FORWARD, param_str, config.first}}; } -bool ConvBiasForwardImpl::AlgoBFloat16::is_available( - const SizeArgs& args) const { +bool ConvBiasForwardImpl::AlgoBFloat16::is_available(const SizeArgs& args) const { auto config = prepare_sub_opr(args); return args.src_layout->dtype == args.filter_layout->dtype && args.src_layout->dtype == dtype::BFloat16() && - get_algorithm(static_cast(config.second.get()), - config.first[0], config.first[1], config.first[2], - config.first[3], config.first[4]); + get_algorithm( + static_cast(config.second.get()), + config.first[0], config.first[1], config.first[2], config.first[3], + config.first[4]); } WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle( @@ -82,8 +81,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle( auto config = prepare_sub_opr(args); SmallVector sizes; - auto get_workspace = [&sizes](const TensorLayout& src, - const TensorLayout& dst) { + auto get_workspace = [&sizes](const TensorLayout& src, const TensorLayout& dst) { if (src.dtype != dst.dtype) { sizes.push_back(dst.span().dist_byte()); } @@ -123,8 +121,9 @@ void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { { auto config = prepare_sub_opr(args); - config.second->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, - fz_tensor, fdst_tensor, nullptr, cvter.workspace()); + config.second->exec( + fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor, fdst_tensor, + nullptr, cvter.workspace()); } { cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); } } diff --git a/dnn/src/cuda/conv_bias/chanwise.cpp b/dnn/src/cuda/conv_bias/chanwise.cpp index 77ce4590..a2ee6d5c 100644 --- a/dnn/src/cuda/conv_bias/chanwise.cpp +++ b/dnn/src/cuda/conv_bias/chanwise.cpp @@ -18,10 +18,8 @@ using namespace megdnn; using namespace cuda; using namespace conv_bias; -bool ConvBiasForwardImpl::AlgoChanwise::is_available( - const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { +bool ConvBiasForwardImpl::AlgoChanwise::is_available(const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.src_layout->dtype == args.filter_layout->dtype && @@ -33,10 +31,10 @@ bool ConvBiasForwardImpl::AlgoChanwise::is_available( auto&& fm = args.filter_meta; bool flag = 
args.filter_meta.format == Param::Format::NCHW && - args.src_layout->dtype.category() == DTypeCategory::FLOAT && - args.opr->param().compute_mode == Param::ComputeMode::DEFAULT && - fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && !fm.should_flip; + args.src_layout->dtype.category() == DTypeCategory::FLOAT && + args.opr->param().compute_mode == Param::ComputeMode::DEFAULT && + fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && !fm.should_flip; return flag; } @@ -45,24 +43,22 @@ size_t ConvBiasForwardImpl::AlgoChanwise::get_workspace_in_bytes( auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); return dst_layout.span().dist_byte(); } return 0; } void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { - WorkspaceBundle bundle{args.workspace.raw_ptr, - {get_workspace_in_bytes(args)}}; + WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; auto conv_dst_tensor = *args.dst_tensor; if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(0); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } { @@ -70,10 +66,9 @@ void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { auto stream = cuda_stream(args.handle); switch (args.src_layout->dtype.enumv()) { case DTypeEnum::Float32: - chanwise::run_fwd(conv_dst_tensor.ptr(), - args.src_tensor->ptr(), - args.filter_tensor->ptr(), kparam, - stream); + chanwise::run_fwd( + conv_dst_tensor.ptr(), args.src_tensor->ptr(), + args.filter_tensor->ptr(), kparam, stream); break; case DTypeEnum::Float16: #if CUDA_VERSION >= 9000 @@ -81,19 +76,19 @@ void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { chanwise::run_fwd( static_cast(conv_dst_tensor.raw_ptr), static_cast(args.src_tensor->raw_ptr), - static_cast(args.filter_tensor->raw_ptr), - kparam, stream); + static_cast(args.filter_tensor->raw_ptr), kparam, + stream); } else { - chanwise::run_fwd(conv_dst_tensor.ptr(), - args.src_tensor->ptr(), - args.filter_tensor->ptr(), - kparam, stream); + chanwise::run_fwd( + conv_dst_tensor.ptr(), + args.src_tensor->ptr(), + args.filter_tensor->ptr(), kparam, stream); } #else - chanwise::run_fwd(conv_dst_tensor.ptr(), - args.src_tensor->ptr(), - args.filter_tensor->ptr(), kparam, - stream); + chanwise::run_fwd( + conv_dst_tensor.ptr(), + args.src_tensor->ptr(), + args.filter_tensor->ptr(), kparam, stream); #endif break; default: @@ -101,9 +96,9 @@ void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { } } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/chanwise/fwd.cu b/dnn/src/cuda/conv_bias/chanwise/fwd.cu index 0f46bf2a..dd0a36c7 100644 --- 
a/dnn/src/cuda/conv_bias/chanwise/fwd.cu +++ b/dnn/src/cuda/conv_bias/chanwise/fwd.cu @@ -26,8 +26,7 @@ namespace { // each y-slice of a block works on an (N, CHL_MUL, OH, OW) spatial image at // given inp_chl template -__global__ void kern_fwd_float(T* dst, const T* src, const T* flt_tot, - Param param) { +__global__ void kern_fwd_float(T* dst, const T* src, const T* flt_tot, Param param) { extern __shared__ uint8_t flt_storage[]; T* const flt = reinterpret_cast(flt_storage); @@ -69,15 +68,13 @@ __global__ void kern_fwd_float(T* dst, const T* src, const T* flt_tot, #pragma unroll for (uint32_t fw = 0; fw < FW; ++fw) { if (static_cast(fw + iw) < IW) { - sum += flt_base[fh * FW + fw] * - src_base[fh * IW + fw]; + sum += flt_base[fh * FW + fw] * src_base[fh * IW + fw]; } } } } } else { - int fhmax = min(int(FH), int(IH - ih)), - fwmax = min(int(FW), int(IW - iw)); + int fhmax = min(int(FH), int(IH - ih)), fwmax = min(int(FW), int(IW - iw)); for (int fh = max(0, -ih); fh < fhmax; ++fh) { for (int fw = max(0, -iw); fw < fwmax; ++fw) { sum += flt_base[fh * FW + fw] * src_base[fh * IW + fw]; @@ -90,8 +87,8 @@ __global__ void kern_fwd_float(T* dst, const T* src, const T* flt_tot, #if CUDA_VERSION >= 9000 template -__global__ void kern_fwd_half(__half* dst, const __half* src, - const __half* flt_tot, Param param) { +__global__ void kern_fwd_half( + __half* dst, const __half* src, const __half* flt_tot, Param param) { extern __shared__ uint8_t flt_storage[]; __half* const flt = reinterpret_cast<__half*>(flt_storage); @@ -122,8 +119,7 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, int ih = int(oh * SH) - int(PH), iw = int(ow * SW) - int(PW); const __half* flt_base = flt + chl_mul * FSIZE; - const __half* src_base = - src + int(((n * IC + ic) * IH + ih) * IW + iw); + const __half* src_base = src + int(((n * IC + ic) * IH + ih) * IW + iw); __half2 sum{0.0, 0.0}; @@ -134,10 +130,8 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, if (static_cast(fh + ih) < IH) { if (FH_SET == 3 && FW_SET == 3 && SW_SET == 1) { __half2 fil0 = {flt_base[fh * FW], flt_base[fh * FW]}; - __half2 fil1 = {flt_base[fh * FW + 1], - flt_base[fh * FW + 1]}; - __half2 fil2 = {flt_base[fh * FW + 2], - flt_base[fh * FW + 2]}; + __half2 fil1 = {flt_base[fh * FW + 1], flt_base[fh * FW + 1]}; + __half2 fil2 = {flt_base[fh * FW + 2], flt_base[fh * FW + 2]}; __half2 src0 = {0.0, 0.0}; if (static_cast(iw) < IW) @@ -157,14 +151,10 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, sum = fma2(src1, fil1, sum); } else if (FH_SET == 5 && FW_SET == 5 && SW_SET == 1) { __half2 fil0 = {flt_base[fh * FW], flt_base[fh * FW]}; - __half2 fil1 = {flt_base[fh * FW + 1], - flt_base[fh * FW + 1]}; - __half2 fil2 = {flt_base[fh * FW + 2], - flt_base[fh * FW + 2]}; - __half2 fil3 = {flt_base[fh * FW + 3], - flt_base[fh * FW + 3]}; - __half2 fil4 = {flt_base[fh * FW + 4], - flt_base[fh * FW + 4]}; + __half2 fil1 = {flt_base[fh * FW + 1], flt_base[fh * FW + 1]}; + __half2 fil2 = {flt_base[fh * FW + 2], flt_base[fh * FW + 2]}; + __half2 fil3 = {flt_base[fh * FW + 3], flt_base[fh * FW + 3]}; + __half2 fil4 = {flt_base[fh * FW + 4], flt_base[fh * FW + 4]}; __half2 src0 = {0.0, 0.0}; if (static_cast(iw) < IW) @@ -196,14 +186,13 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, } else { #pragma unroll for (uint32_t fw = 0; fw < FW; ++fw) { - __half2 fil = {flt_base[fh * FW + fw], - flt_base[fh * FW + fw]}; + __half2 fil = { + flt_base[fh * FW + fw], flt_base[fh * FW + fw]}; __half2 src = {0.0, 0.0}; - 
if (static_cast(static_cast(fw) + - iw) < IW) + if (static_cast(static_cast(fw) + iw) < IW) src.x = src_base[fh * IW + fw]; - if (static_cast(static_cast(fw) + - iw + SW) < IW) + if (static_cast(static_cast(fw) + iw + SW) < + IW) src.y = src_base[fh * IW + fw + SW]; sum = fma2(src, fil, sum); } @@ -211,10 +200,8 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, } } - dst[(((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow] = - sum.x; - dst[(((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow + 1] = - sum.y; + dst[(((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow] = sum.x; + dst[(((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow + 1] = sum.y; continue; } @@ -232,8 +219,7 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, int ih = int(oh * SH) - int(PH), iw = int(ow * SW) - int(PW); const __half* flt_base = flt + chl_mul * FSIZE; - const __half* src_base = - src + int(((n * IC + ic) * IH + ih) * IW + iw); + const __half* src_base = src + int(((n * IC + ic) * IH + ih) * IW + iw); __half sum(0); @@ -246,8 +232,9 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, #pragma unroll for (uint32_t fw = 0; fw < FW; ++fw) { if (static_cast(fw + iw) < IW) { - sum = fma(flt_base[fh * FW + fw], - src_base[fh * IW + fw], sum); + sum = + fma(flt_base[fh * FW + fw], + src_base[fh * IW + fw], sum); } } } @@ -257,16 +244,13 @@ __global__ void kern_fwd_half(__half* dst, const __half* src, fwmax = min(int(FW), int(IW - iw)); for (int fh = max(0, -ih); fh < fhmax; ++fh) { for (int fw = max(0, -iw); fw < fwmax; ++fw) { - sum = fma(flt_base[fh * FW + fw], - src_base[fh * IW + fw], sum); + sum = fma(flt_base[fh * FW + fw], src_base[fh * IW + fw], sum); } } } - dst[(((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow] = - sum; + dst[(((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow] = sum; - if (n == N - 1 && chl_mul == CHL_MUL - 1 && ow == OW - 1 && - oh == OH - 1) + if (n == N - 1 && chl_mul == CHL_MUL - 1 && ow == OW - 1 && oh == OH - 1) break; } } @@ -335,30 +319,28 @@ namespace conv_bias { namespace chanwise { template -void run_fwd(T* dst, const T* src, const T* flt, const Param& param, - cudaStream_t stream) { +void run_fwd( + T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream) { void (*kern)(T*, const T*, const T*, Param); kern = get_kern(param).f; int nr_thread = query_blocksize_for_kernel(kern), nr_out_dimx = param.out_h * param.out_w * param.batch * param.chl_mul; - dim3 nr_block(param.src_chl, - std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); + dim3 nr_block(param.src_chl, std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); kern<<>>(dst, src, flt, param); after_kernel_launch(); } -template void run_fwd(float*, const float*, const float*, const Param&, - cudaStream_t); +template void run_fwd(float*, const float*, const float*, const Param&, cudaStream_t); #if CUDA_VERSION >= 9000 -template void run_fwd(__half*, const __half*, const __half*, const Param&, - cudaStream_t); +template void run_fwd( + __half*, const __half*, const __half*, const Param&, cudaStream_t); #endif -template void run_fwd(dt_float16*, const dt_float16*, const dt_float16*, - const Param&, cudaStream_t); +template void run_fwd( + dt_float16*, const dt_float16*, const dt_float16*, const Param&, cudaStream_t); } // namespace chanwise } // namespace conv_bias diff --git a/dnn/src/cuda/conv_bias/chanwise/fwd_8x8x32.cu 
b/dnn/src/cuda/conv_bias/chanwise/fwd_8x8x32.cu index 15b2a471..7a09064f 100644 --- a/dnn/src/cuda/conv_bias/chanwise/fwd_8x8x32.cu +++ b/dnn/src/cuda/conv_bias/chanwise/fwd_8x8x32.cu @@ -20,11 +20,9 @@ using namespace chanwise; namespace { -__host__ __device__ void get_receptive_field_size(uint32_t OH, uint32_t OW, - uint32_t FH, uint32_t FW, - uint32_t SH, uint32_t SW, - uint32_t DH, uint32_t DW, - uint32_t* RH, uint32_t* RW) { +__host__ __device__ void get_receptive_field_size( + uint32_t OH, uint32_t OW, uint32_t FH, uint32_t FW, uint32_t SH, uint32_t SW, + uint32_t DH, uint32_t DW, uint32_t* RH, uint32_t* RW) { // DFH = dilationd FH, DFW = dilationd FW // RH = receptive field height, RW = receptive field width uint32_t DFH = (FH - 1) * DH + 1, DFW = (FW - 1) * DW + 1; @@ -37,11 +35,10 @@ __host__ __device__ void get_receptive_field_size(uint32_t OH, uint32_t OW, // F == 0: FH/FW should be retrieved from param // F != 0: FH/FW should use F template -__global__ void kern(int32_t* dst, const int8_t* src, const int8_t* flt, - Param param) { +__global__ void kern(int32_t* dst, const int8_t* src, const int8_t* flt, Param param) { // each block would process 128 channels at every 4x4 spatial area. - uint32_t C = param.src_chl, IH = param.src_h, IW = param.src_w, - OH = param.out_h, OW = param.out_w, FH = F == 0 ? param.flt_h : F, + uint32_t C = param.src_chl, IH = param.src_h, IW = param.src_w, OH = param.out_h, + OW = param.out_w, FH = F == 0 ? param.flt_h : F, FW = F == 0 ? param.flt_w : F, PH = param.pad_h, PW = param.pad_w, SH = param.stride_h, SW = param.stride_w, DH = param.dilation_h, DW = param.dilation_w; @@ -52,9 +49,8 @@ __global__ void kern(int32_t* dst, const int8_t* src, const int8_t* flt, uint32_t c_beg = blockIdx.x * 128, c_end = min((blockIdx.x + 1) * 128, C), c_cur = c_beg + threadIdx.x * 4; uint32_t tidx = threadIdx.x, tidy = threadIdx.y, tidz = threadIdx.z, - tid = (tidx << 0) | (tidy << 5) | (tidz << 7), - tid_stride = 32 * 4 * 4, tidyz = (tidy << 0) | (tidz << 2), - tidyz_stride = 4 * 4; + tid = (tidx << 0) | (tidy << 5) | (tidz << 7), tid_stride = 32 * 4 * 4, + tidyz = (tidy << 0) | (tidz << 2), tidyz_stride = 4 * 4; uint32_t oh = bidz * 4 + tidz, ow = bidy * 4 + tidy; uint32_t C_32 = C >> 2; // calculate receptive field of 4x4 output pixels @@ -70,14 +66,13 @@ __global__ void kern(int32_t* dst, const int8_t* src, const int8_t* flt, static_cast(shared + 128 * FH * FW * sizeof(int8_t))); uint32_t* flt_shared_32 = reinterpret_cast(flt_shared); - int8_t* src_shared = static_cast( - static_cast(shared + 128 * FH * FW * sizeof(int8_t) + - 128 * FH * FW * sizeof(int8_t))); + int8_t* src_shared = static_cast(static_cast( + shared + 128 * FH * FW * sizeof(int8_t) + 128 * FH * FW * sizeof(int8_t))); uint32_t* src_shared_32 = reinterpret_cast(src_shared); int32_t* dst_shared = static_cast(static_cast( - shared + 128 * FH * FW * sizeof(int8_t) + - 128 * FH * FW * sizeof(int8_t) + 128 * RH * RW * sizeof(int8_t))); + shared + 128 * FH * FW * sizeof(int8_t) + 128 * FH * FW * sizeof(int8_t) + + 128 * RH * RW * sizeof(int8_t))); // read original filter to shared memory // *_int8 vars must be multiples of 4 here. 
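
The fwd_8x8x32 hunks above carve one dynamic shared-memory allocation into two 128-channel filter copies, the input patch covered by a 4x4 block of output pixels, and a 129-column int32 accumulator (129 rather than 128 to avoid shared-memory bank conflicts). A minimal host-side sketch of that budget follows; the receptive-field formula for the 4x4 tile is an assumption (the tail of get_receptive_field_size lies outside these hunks), and chanwise_8x8x32_smem_bytes is a hypothetical helper, not part of MegDNN.

#include <cstddef>
#include <cstdint>

// Sketch only: mirrors the pointer arithmetic in kern<> and the
// shared_mem_size computation in run_fwd_8x8x32 above.
static size_t chanwise_8x8x32_smem_bytes(
        uint32_t FH, uint32_t FW, uint32_t SH, uint32_t SW, uint32_t DH,
        uint32_t DW) {
    // dilated filter extent, as in get_receptive_field_size
    uint32_t DFH = (FH - 1) * DH + 1, DFW = (FW - 1) * DW + 1;
    // assumed receptive field of a 4x4 output tile
    uint32_t RH = (4 - 1) * SH + DFH, RW = (4 - 1) * SW + DFW;
    size_t filter = 128 * FH * FW * sizeof(int8_t);  // one 128-channel filter copy
    size_t src = 128 * RH * RW * sizeof(int8_t);     // input tile incl. halo
    size_t dst = 129 * 4 * 4 * sizeof(int32_t);      // 129 avoids bank conflicts
    return 2 * filter + src + dst;                   // = shared_mem_size
}
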
@@ -152,8 +147,7 @@ __global__ void kern(int32_t* dst, const int8_t* src, const int8_t* flt, for (uint32_t k = 0; k < 4; ++k) { uint32_t c = c_beg + tidx + k * 32; if (c < c_end) { - dst[(oh * OW + ow) * C + c] = - dst_shared[tidyz * 129 + tidx + k * 32]; + dst[(oh * OW + ow) * C + c] = dst_shared[tidyz * 129 + tidx + k * 32]; } } } @@ -161,15 +155,13 @@ __global__ void kern(int32_t* dst, const int8_t* src, const int8_t* flt, } // anonymous namespace -void megdnn::cuda::conv_bias::chanwise::run_fwd_8x8x32(int32_t* dst, - const int8_t* src, - const int8_t* flt, - const Param& param, - cudaStream_t stream) { - uint32_t N = param.batch, C = param.src_chl, IH = param.src_h, - IW = param.src_w, OH = param.out_h, OW = param.out_w, - FH = param.flt_h, FW = param.flt_w, SH = param.stride_h, - SW = param.stride_w, DH = param.dilation_h, DW = param.dilation_w; +void megdnn::cuda::conv_bias::chanwise::run_fwd_8x8x32( + int32_t* dst, const int8_t* src, const int8_t* flt, const Param& param, + cudaStream_t stream) { + uint32_t N = param.batch, C = param.src_chl, IH = param.src_h, IW = param.src_w, + OH = param.out_h, OW = param.out_w, FH = param.flt_h, FW = param.flt_w, + SH = param.stride_h, SW = param.stride_w, DH = param.dilation_h, + DW = param.dilation_w; dim3 threads(32, 4, 4); dim3 blocks(DIVUP(C, 128), DIVUP(OW, 4), DIVUP(OH, 4)); @@ -185,8 +177,8 @@ void megdnn::cuda::conv_bias::chanwise::run_fwd_8x8x32(int32_t* dst, // use 129 instead of 128 to avoid shared memory bank conflict uint32_t dst_shared_mem_size = 129 * 4 * 4 * sizeof(int32_t); - uint32_t shared_mem_size = 2 * filter_shared_mem_size + - src_shared_mem_size + dst_shared_mem_size; + uint32_t shared_mem_size = + 2 * filter_shared_mem_size + src_shared_mem_size + dst_shared_mem_size; void (*kptr)(int32_t*, const int8_t*, const int8_t*, Param) = kern<0>; if (FH == 1 && FW == 1) @@ -200,8 +192,7 @@ void megdnn::cuda::conv_bias::chanwise::run_fwd_8x8x32(int32_t* dst, int32_t* dptr = dst + n * C * OH * OW; const int8_t* sptr = src + n * C * IH * IW; const int8_t* fptr = flt; - kptr<<>>(dptr, sptr, fptr, - param); + kptr<<>>(dptr, sptr, fptr, param); } after_kernel_launch(); } diff --git a/dnn/src/cuda/conv_bias/chanwise/fwd_small.cu b/dnn/src/cuda/conv_bias/chanwise/fwd_small.cu index f0b6616a..43489a18 100644 --- a/dnn/src/cuda/conv_bias/chanwise/fwd_small.cu +++ b/dnn/src/cuda/conv_bias/chanwise/fwd_small.cu @@ -34,17 +34,18 @@ enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; // one each in the lower and upper half of a tile. // Backprop input direction is the same as forward direction with the filter // rotated by 180°. -template +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> __global__ void #if __CUDA_ARCH__ >= 750 __launch_bounds__(1024, 1) #else __launch_bounds__(1024, 2) #endif - DepthwiseConv2dGPUKernelNCHWSmall(const Param param, const T* input, - const T* filter, T* output) { + DepthwiseConv2dGPUKernelNCHWSmall( + const Param param, const T* input, const T* filter, T* output) { // Holds block plus halo and filter data for blockDim.z depths. 
extern __shared__ __align__(8) unsigned char shared_memory[]; static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); @@ -54,12 +55,10 @@ __launch_bounds__(1024, 2) const int in_height = static_cast(param.src_h); const int in_width = static_cast(param.src_w); const int in_depth = static_cast(param.src_chl); - const int filter_height = kKnownFilterHeight < 0 - ? static_cast(param.flt_h) - : kKnownFilterHeight; - const int filter_width = kKnownFilterWidth < 0 - ? static_cast(param.flt_w) - : kKnownFilterWidth; + const int filter_height = + kKnownFilterHeight < 0 ? static_cast(param.flt_h) : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? static_cast(param.flt_w) : kKnownFilterWidth; const int pad_height = static_cast(param.pad_h); const int pad_width = static_cast(param.pad_w); @@ -139,8 +138,7 @@ __launch_bounds__(1024, 2) if (filter_write_offset != 0) { const int filter_offset = - (channel + filter_channel) % in_depth * filter_pixels + - filter_pix; + (channel + filter_channel) % in_depth * filter_pixels + filter_pix; shared_data[filter_write_offset] = *(filter_offset + filter); } @@ -181,23 +179,23 @@ __launch_bounds__(1024, 2) } } -template -void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, - const T* filter, T* output, - cudaStream_t stream) { +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> +void LaunchDepthwiseConv2dGPUSmall( + const Param& param, const T* input, const T* filter, T* output, + cudaStream_t stream) { const int block_height = (param.src_h + 1) / 2; dim3 block_dim; int block_count; void (*kernel)(const Param, const T*, const T*, T*); block_dim = dim3(param.src_w, block_height, kBlockDepth); - block_count = - DIVUP(param.batch * param.src_chl * param.chl_mul, kBlockDepth) * - kBlockDepth; + block_count = DIVUP(param.batch * param.src_chl * param.chl_mul, kBlockDepth) * + kBlockDepth; kernel = DepthwiseConv2dGPUKernelNCHWSmall< - T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, - kBlockDepth, kKnownEvenHeight>; + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, + kKnownEvenHeight>; const int tile_width = param.src_w + param.flt_w - 1; const int tile_height = block_height * 2 + param.flt_h - 1; const int tile_pixels = tile_height * tile_width; @@ -206,48 +204,51 @@ void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, kBlockDepth * (tile_pixels + filter_pixels) * sizeof(T); const int num_outputs = param.out_h * param.out_w * block_count; - block_count = GetFixedBlockSize(num_outputs, kernel, shared_memory_size, - block_dim.x * block_dim.y * block_dim.z); + block_count = GetFixedBlockSize( + num_outputs, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); kernel<<>>( param, input, filter, output); after_kernel_launch(); } -template -void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, - const T* filter, T* output, - cudaStream_t stream) { +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth> +void LaunchDepthwiseConv2dGPUSmall( + const Param& param, const T* input, const T* filter, T* output, + cudaStream_t stream) { if (param.src_h & 1) { return LaunchDepthwiseConv2dGPUSmall< - T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, - kBlockDepth, false>(param, input, filter, output, stream); + T, T2, kDirection, kKnownFilterWidth, 
kKnownFilterHeight, kBlockDepth, + false>(param, input, filter, output, stream); } else { return LaunchDepthwiseConv2dGPUSmall< - T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, - kBlockDepth, true>(param, input, filter, output, stream); + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, + true>(param, input, filter, output, stream); } } -template -void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, - const T* filter, T* output, - cudaStream_t stream) { +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight> +void LaunchDepthwiseConv2dGPUSmall( + const Param& param, const T* input, const T* filter, T* output, + cudaStream_t stream) { // Maximize (power of two) kBlockDepth while keeping a block within 1024 // threads (2 pixels per thread). const int block_pixels = (param.src_h + 1) / 2 * param.src_w; if (block_pixels > 256) { - LaunchDepthwiseConv2dGPUSmall( + LaunchDepthwiseConv2dGPUSmall< + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, 2>( param, input, filter, output, stream); } else if (block_pixels > 128) { - LaunchDepthwiseConv2dGPUSmall( + LaunchDepthwiseConv2dGPUSmall< + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, 4>( param, input, filter, output, stream); } else { - LaunchDepthwiseConv2dGPUSmall( + LaunchDepthwiseConv2dGPUSmall< + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, 8>( param, input, filter, output, stream); } } @@ -260,27 +261,29 @@ namespace conv_bias { namespace chanwise { // =====================================fwd===================================== -#define LAUNCH(type, type2) \ - if (param.flt_h == 3 && param.flt_w == 3) { \ - LaunchDepthwiseConv2dGPUSmall< \ - type, type2, DepthwiseConv2dDirection::DIRECTION_FORWARD, 3, \ - 3>(param, src, flt, dst, stream); \ - } else { \ - LaunchDepthwiseConv2dGPUSmall< \ - type, type2, DepthwiseConv2dDirection::DIRECTION_FORWARD, -1, \ - -1>(param, src, flt, dst, stream); \ +#define LAUNCH(type, type2) \ + if (param.flt_h == 3 && param.flt_w == 3) { \ + LaunchDepthwiseConv2dGPUSmall< \ + type, type2, DepthwiseConv2dDirection::DIRECTION_FORWARD, 3, 3>( \ + param, src, flt, dst, stream); \ + } else { \ + LaunchDepthwiseConv2dGPUSmall< \ + type, type2, DepthwiseConv2dDirection::DIRECTION_FORWARD, -1, -1>( \ + param, src, flt, dst, stream); \ } template <> -void run_fwd_small(float* dst, const float* src, const float* flt, - const Param& param, cudaStream_t stream) { +void run_fwd_small( + float* dst, const float* src, const float* flt, const Param& param, + cudaStream_t stream) { LAUNCH(float, float2); } #if CUDA_VERSION >= 9000 template <> -void run_fwd_small(__half* dst, const __half* src, const __half* flt, - const Param& param, cudaStream_t stream) { +void run_fwd_small( + __half* dst, const __half* src, const __half* flt, const Param& param, + cudaStream_t stream) { LAUNCH(__half, __half2); } #endif diff --git a/dnn/src/cuda/conv_bias/chanwise/kern.cuh b/dnn/src/cuda/conv_bias/chanwise/kern.cuh index 256a9e73..4b5a60d0 100644 --- a/dnn/src/cuda/conv_bias/chanwise/kern.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/kern.cuh @@ -25,8 +25,8 @@ namespace conv_bias { namespace chanwise { struct Param { - uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w; + uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h, + pad_w, stride_h, stride_w, dilation_h, dilation_w; #if 
MEGDNN_CC_HOST static Param from_fwd_args(const BiasForwardSizeArgs& args) { #define U(v) static_cast(v) @@ -54,16 +54,17 @@ struct Param { }; template -void run_fwd(T* dst, const T* src, const T* flt, const Param& param, - cudaStream_t stream); +void run_fwd( + T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream); template -void run_fwd_small(T* dst, const T* src, const T* flt, const Param& param, - cudaStream_t stream); +void run_fwd_small( + T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream); // implemented in fwd_8x8x32.cu -void run_fwd_8x8x32(int32_t* dst, const int8_t* src, const int8_t* flt, - const Param& param, cudaStream_t stream); +void run_fwd_8x8x32( + int32_t* dst, const int8_t* src, const int8_t* flt, const Param& param, + cudaStream_t stream); } // namespace chanwise } // namespace conv_bias diff --git a/dnn/src/cuda/conv_bias/chanwise/kern_helper.cuh b/dnn/src/cuda/conv_bias/chanwise/kern_helper.cuh index 53765e76..ab270364 100644 --- a/dnn/src/cuda/conv_bias/chanwise/kern_helper.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/kern_helper.cuh @@ -26,8 +26,7 @@ namespace chanwise { /*! * \brief return a / b and set mod to a % b */ -__device__ __forceinline__ uint32_t div_mod(uint32_t a, uint32_t b, - uint32_t& mod) { +__device__ __forceinline__ uint32_t div_mod(uint32_t a, uint32_t b, uint32_t& mod) { uint32_t ret = a / b; mod = a - ret * b; return ret; @@ -38,8 +37,7 @@ __device__ __forceinline__ uint32_t div_mod(uint32_t a, uint32_t b, * \param rs row stride */ template -__device__ __forceinline__ void block_memcpy(T* dst, const T* src, - uint32_t size) { +__device__ __forceinline__ void block_memcpy(T* dst, const T* src, uint32_t size) { for (uint32_t i = threadIdx.x; i < size; i += blockDim.x) { dst[i] = src[i]; } diff --git a/dnn/src/cuda/conv_bias/chanwise/launch_config.cpp b/dnn/src/cuda/conv_bias/chanwise/launch_config.cpp index 164a85f2..c1186f40 100644 --- a/dnn/src/cuda/conv_bias/chanwise/launch_config.cpp +++ b/dnn/src/cuda/conv_bias/chanwise/launch_config.cpp @@ -16,9 +16,9 @@ using namespace megdnn; using namespace cuda; using namespace conv_bias; -int chanwise::GetFixedBlockSize1(int work_element_count, const void* func, - int dynamic_shared_memory_size, - int fixed_block_size) { +int chanwise::GetFixedBlockSize1( + int work_element_count, const void* func, int dynamic_shared_memory_size, + int fixed_block_size) { int block_count = 0; cuda_check(cudaOccupancyMaxActiveBlocksPerMultiprocessor( diff --git a/dnn/src/cuda/conv_bias/chanwise/launch_config.cuh b/dnn/src/cuda/conv_bias/chanwise/launch_config.cuh index d5000853..548d55e1 100644 --- a/dnn/src/cuda/conv_bias/chanwise/launch_config.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/launch_config.cuh @@ -16,15 +16,17 @@ namespace cuda { namespace conv_bias { namespace chanwise { -int GetFixedBlockSize1(int work_element_count, const void* func, - int dynamic_shared_memory_size, int fixed_block_size); +int GetFixedBlockSize1( + int work_element_count, const void* func, int dynamic_shared_memory_size, + int fixed_block_size); template -int GetFixedBlockSize(int work_element_count, DeviceFunc func, - int dynamic_shared_memory_size, int fixed_block_size) { - return GetFixedBlockSize1(work_element_count, - reinterpret_cast(func), - dynamic_shared_memory_size, fixed_block_size); +int GetFixedBlockSize( + int work_element_count, DeviceFunc func, int dynamic_shared_memory_size, + int fixed_block_size) { + return GetFixedBlockSize1( + work_element_count, reinterpret_cast(func), + 
dynamic_shared_memory_size, fixed_block_size); } } // namespace chanwise diff --git a/dnn/src/cuda/conv_bias/chanwise_8x8x32.cpp b/dnn/src/cuda/conv_bias/chanwise_8x8x32.cpp index 178e78e7..5e5560b7 100644 --- a/dnn/src/cuda/conv_bias/chanwise_8x8x32.cpp +++ b/dnn/src/cuda/conv_bias/chanwise_8x8x32.cpp @@ -10,19 +10,17 @@ */ #include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/conv_bias/chanwise/kern.cuh" #include "src/common/conv_bias.h" #include "src/common/elemwise/kern_defs.cuh" +#include "src/cuda/conv_bias/chanwise/kern.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace conv_bias; -bool ConvBiasForwardImpl::AlgoChanwise8x8x32::is_available( - const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { +bool ConvBiasForwardImpl::AlgoChanwise8x8x32::is_available(const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.z_layout->ndim > 0) @@ -43,36 +41,33 @@ size_t ConvBiasForwardImpl::AlgoChanwise8x8x32::get_workspace_in_bytes( auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); return dst_layout.span().dist_byte(); } return 0; } void ConvBiasForwardImpl::AlgoChanwise8x8x32::exec(const ExecArgs& args) const { - WorkspaceBundle bundle{args.workspace.raw_ptr, - {get_workspace_in_bytes(args)}}; + WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; auto conv_dst_tensor = *args.dst_tensor; if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(0); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } { auto kparam = chanwise::Param::from_fwd_args(args); auto stream = cuda_stream(args.handle); - chanwise::run_fwd_8x8x32(conv_dst_tensor.ptr(), - args.src_tensor->ptr(), - args.filter_tensor->ptr(), kparam, - stream); + chanwise::run_fwd_8x8x32( + conv_dst_tensor.ptr(), args.src_tensor->ptr(), + args.filter_tensor->ptr(), kparam, stream); } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/chanwise_small.cpp b/dnn/src/cuda/conv_bias/chanwise_small.cpp index a87aca06..b2097ca9 100644 --- a/dnn/src/cuda/conv_bias/chanwise_small.cpp +++ b/dnn/src/cuda/conv_bias/chanwise_small.cpp @@ -21,17 +21,15 @@ using namespace conv_bias; namespace { inline bool is_available_small(const chanwise::Param& param) { return param.chl_mul == 1 && param.stride_h == 1 && param.stride_w == 1 && - param.src_h <= 32 && param.src_w <= 32 && - param.src_h == param.out_h && param.src_w == param.out_w && - param.pad_h < param.flt_h && param.pad_w < param.flt_w && + param.src_h <= 32 && param.src_w <= 32 && param.src_h == param.out_h && + param.src_w == param.out_w && 
param.pad_h < param.flt_h && + param.pad_w < param.flt_w && param.flt_h * param.flt_w <= (param.src_h + 1) / 2 * param.src_w; } } // anonymous namespace -bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available( - const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { +bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available(const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.src_layout->dtype == args.filter_layout->dtype && @@ -58,34 +56,31 @@ size_t ConvBiasForwardImpl::AlgoChanwiseSmall::get_workspace_in_bytes( auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); return dst_layout.span().dist_byte(); } return 0; } void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { - WorkspaceBundle bundle{args.workspace.raw_ptr, - {get_workspace_in_bytes(args)}}; + WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; auto conv_dst_tensor = *args.dst_tensor; if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(0); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } { auto kparam = chanwise::Param::from_fwd_args(args); auto stream = cuda_stream(args.handle); switch (args.src_layout->dtype.enumv()) { case DTypeEnum::Float32: - chanwise::run_fwd_small(conv_dst_tensor.ptr(), - args.src_tensor->ptr(), - args.filter_tensor->ptr(), - kparam, stream); + chanwise::run_fwd_small( + conv_dst_tensor.ptr(), args.src_tensor->ptr(), + args.filter_tensor->ptr(), kparam, stream); break; #if CUDA_VERSION >= 9000 case DTypeEnum::Float16: @@ -100,9 +95,9 @@ void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { megdnn_assert_internal(0); } } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/conv_bias_int8.cuh b/dnn/src/cuda/conv_bias/conv_bias_int8.cuh index 8585a710..ad25acfa 100644 --- a/dnn/src/cuda/conv_bias/conv_bias_int8.cuh +++ b/dnn/src/cuda/conv_bias/conv_bias_int8.cuh @@ -37,88 +37,88 @@ struct LaunchConfig { template void do_conv_bias_int8_implicit_gemm_cdiv4hwn4( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + 
cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); template void do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t 
stream); template void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width( const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias, - Epilogue epilogue, const convolution::ConvParam& param, float alpha, - float beta, cudaStream_t stream); + Epilogue epilogue, const convolution::ConvParam& param, float alpha, float beta, + cudaStream_t stream); } // namespace conv_bias_int8 } // namespace cuda } // namespace megdnn -#define MARK_USED_VAR \ - MEGDNN_MARK_USED_VAR(n + ci + hi + wi + co + fh + fw + ho + wo + ph + pw + \ - sh + sw + dh + dw); +#define MARK_USED_VAR \ + MEGDNN_MARK_USED_VAR( \ + n + ci + hi + wi + co + fh + fw + ho + wo + ph + pw + sh + sw + dh + dw); #define UNPACK_CONV_PARAMETER(_filter_meta, _param) \ size_t ph = _param.pad_h, pw = _param.pad_w; \ diff --git a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp index 7336d1b8..778781fb 100644 --- a/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp +++ b/dnn/src/cuda/conv_bias/conv_nchwqs8.cpp @@ -22,11 +22,10 @@ using namespace cuda; using namespace conv_bias; namespace { -inline void deduce_reformat_layout(std::unique_ptr& relayout, - const TensorLayout& src_layout, - TensorLayout& dst_layout, - RelayoutFormat::Param::Mode mode, - const int oc = 0, const int group = 1) { +inline void deduce_reformat_layout( + std::unique_ptr& relayout, const TensorLayout& src_layout, + TensorLayout& dst_layout, RelayoutFormat::Param::Mode mode, const int oc = 0, + const int group = 1) { if (src_layout.ndim > 0) { RelayoutFormat::Param trans_param; trans_param.mode = mode; @@ -48,28 +47,27 @@ std::pair sub_opr_config( TensorLayout inner_dst_layout; auto relayout_src = args.handle->create_operator(); - deduce_reformat_layout(relayout_src, *args.src_layout, inner_src_layout, - RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, - args.filter_meta.group); - deduce_reformat_layout(relayout_src, *args.filter_layout, - inner_filter_layout, - RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT); + deduce_reformat_layout( + relayout_src, *args.src_layout, inner_src_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, args.filter_meta.group); + deduce_reformat_layout( + relayout_src, *args.filter_layout, inner_filter_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4_WEIGHT); bool dst_float = args.dst_layout->dtype.enumv() == DTypeEnum::Float32; if (dst_float) { inner_dst_layout = *args.dst_layout; inner_bias_layout = *args.bias_layout; inner_z_layout = *args.z_layout; } else { - deduce_reformat_layout(relayout_src, *args.dst_layout, inner_dst_layout, - RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, - args.filter_meta.group); - deduce_reformat_layout(relayout_src, *args.bias_layout, - inner_bias_layout, - RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, - args.filter_meta.group); - deduce_reformat_layout(relayout_src, *args.z_layout, inner_z_layout, - RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, - args.filter_meta.group); + deduce_reformat_layout( + relayout_src, *args.dst_layout, inner_dst_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, args.filter_meta.group); + deduce_reformat_layout( + relayout_src, *args.bias_layout, inner_bias_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, args.filter_meta.group); + deduce_reformat_layout( + relayout_src, *args.z_layout, inner_z_layout, + RelayoutFormat::Param::Mode::NCHW_NCHW4, 0, args.filter_meta.group); } megdnn::param::ConvBias inner_conv_param = args.opr->param(); @@ -79,8 +77,9 @@ std::pair sub_opr_config( inner_conv_param.format = megdnn::param::ConvBias::Format::NCHW4; } 
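The NCHW_NCHW4 relayout that sub_opr_config depends on packs channels in groups of four into an innermost dimension of the tensor. A minimal sketch of the shape and index arithmetic, written independently of MegDNN's RelayoutFormat implementation and therefore only illustrative, looks like this (the NCHW_NCHW4_WEIGHT mode applies the same 4-wide packing to the filter's input channels, with group handling omitted here):

#include <array>
#include <cstddef>

// NCHW (n, c, h, w) -> NCHW4 (n, ceil(c/4), h, w, 4); the channel count is
// padded up to a multiple of 4 so int8 kernels can load four channels at once.
std::array<std::size_t, 5> nchw_to_nchw4_shape(
        std::size_t n, std::size_t c, std::size_t h, std::size_t w) {
    return {n, (c + 3) / 4, h, w, 4};
}

// Flat offset of NCHW element (n, c, h, w) inside the packed NCHW4 tensor.
std::size_t nchw4_offset(
        std::size_t n, std::size_t c, std::size_t h, std::size_t w,
        std::size_t C, std::size_t H, std::size_t W) {
    const std::size_t C4 = (C + 3) / 4;
    return (((n * C4 + c / 4) * H + h) * W + w) * 4 + c % 4;
}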
std::pair ret; - ret.first = {inner_src_layout, inner_filter_layout, inner_bias_layout, - inner_z_layout, inner_dst_layout}; + ret.first = { + inner_src_layout, inner_filter_layout, inner_bias_layout, inner_z_layout, + inner_dst_layout}; ret.second = inner_conv_param; return ret; @@ -89,8 +88,8 @@ std::pair sub_opr_config( std::pair> prepare_sub_opr( const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) { auto convbias_opr = args.handle->create_operator(); - set_execution_policy(args.opr, - convbias_opr.get()); + set_execution_policy( + args.opr, convbias_opr.get()); auto&& config = sub_opr_config(args); convbias_opr->param() = config.second; @@ -98,12 +97,13 @@ std::pair> prepare_sub_opr( } } // namespace -std::vector -ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector ConvBiasForwardImpl::AlgoFallbackNCHWQS8:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { const ConvBiasForwardImpl* o = static_cast(opr); - SizeArgs args(const_cast(o), layouts[0], layouts[1], - layouts[2], layouts[3], layouts[4], nullptr); + SizeArgs args( + const_cast(o), layouts[0], layouts[1], layouts[2], + layouts[3], layouts[4], nullptr); auto&& config = sub_opr_config(args); @@ -114,8 +114,7 @@ ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_subopr_list( bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available( const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } auto&& param = args.opr->param(); @@ -128,8 +127,7 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available( bool is_bias_ok = args.bias_layout->ndim == 0 || (args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 && - args.bias_layout->shape[2] == 1 && - args.bias_layout->shape[3] == 1); + args.bias_layout->shape[2] == 1 && args.bias_layout->shape[3] == 1); bool is_ok = is_format_ok && is_version_ok && is_dtype_ok && is_bias_ok; if (!is_ok) { return false; @@ -140,15 +138,15 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available( bool is_relayout_ok = true; if (args.dst_layout->dtype.enumv() != DTypeEnum::Float32) { is_relayout_ok = relayout_format::RelayoutFormatFast::usable( - config.first[4], *args.dst_layout, - RelayoutFormat::Param::Mode::NCHW4_NCHW); + config.first[4], *args.dst_layout, + RelayoutFormat::Param::Mode::NCHW4_NCHW); } return is_relayout_ok && has_available_algo( static_cast(config.second.get()), - config.first[0], config.first[1], config.first[2], - config.first[3], config.first[4]); + config.first[0], config.first[1], config.first[2], config.first[3], + config.first[4]); } WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle( @@ -165,9 +163,9 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_bundle( config.first[0], config.first[1], config.first[2], config.first[3], config.first[4], nullptr); - return WorkspaceBundle(ptr, {config.first[0].span().dist_byte(), - config.first[1].span().dist_byte(), ws_bias, - ws_z, ws_dst, inner_ws}); + return WorkspaceBundle( + ptr, {config.first[0].span().dist_byte(), + config.first[1].span().dist_byte(), ws_bias, ws_z, ws_dst, inner_ws}); } size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_in_bytes( @@ -176,8 +174,7 @@ size_t ConvBiasForwardImpl::AlgoFallbackNCHWQS8::get_workspace_in_bytes( return trans_bundle.total_size_in_bytes(); } 
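To follow the workspace bookkeeping above: get_workspace_in_bytes reports the total of all slice sizes, and exec later carves one raw allocation into those slices (the relayouted src and filter, the optional bias/z/dst buffers, and the inner operator's own workspace). A simplified stand-in for that bookkeeping, which is not MegDNN's actual WorkspaceBundle and ignores the per-slice alignment it applies, is:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// One raw buffer, carved into consecutive slices; slice i starts at the sum of
// the sizes of slices 0..i-1. The real WorkspaceBundle also aligns each slice.
class SimpleBundle {
public:
    SimpleBundle(void* ptr, std::vector<std::size_t> sizes)
            : m_ptr(static_cast<std::uint8_t*>(ptr)), m_sizes(std::move(sizes)) {}

    std::size_t total_size_in_bytes() const {
        std::size_t total = 0;
        for (std::size_t s : m_sizes)
            total += s;
        return total;
    }

    void* get(std::size_t idx) const {
        std::size_t offset = 0;
        for (std::size_t i = 0; i < idx; ++i)
            offset += m_sizes[i];
        return m_ptr + offset;
    }

private:
    std::uint8_t* m_ptr;
    std::vector<std::size_t> m_sizes;
};

The caller first asks for total_size_in_bytes() to size the single allocation (which is why get_workspace_bundle may be constructed over a null pointer when only the size is wanted), then rebuilds the bundle over the real workspace pointer and addresses slice i with get(i) — which is how exec reaches bundle.get(5) for the inner convolution's workspace.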
-void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec( - const ExecArgs& args) const { +void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec(const ExecArgs& args) const { auto relayout_nchw_nchw4 = args.handle->create_operator(); RelayoutFormat::Param in_trans; in_trans.mode = RelayoutFormat::Param::Mode::NCHW_NCHW4; @@ -223,8 +220,7 @@ void ConvBiasForwardImpl::AlgoFallbackNCHWQS8::exec( relayout_nchw_nchw4->exec(*args.z_tensor, inner_z, {}); } config.second->exec( - inner_src, inner_weight, inner_bias, inner_z, inner_dst, - nullptr, + inner_src, inner_weight, inner_bias, inner_z, inner_dst, nullptr, Workspace((dt_byte*)bundle.get(5), bundle.get_size(5))); relayout_nchw4_nchw->exec(inner_dst, *args.dst_tensor, {}); } diff --git a/dnn/src/cuda/conv_bias/cudnn_conv.cpp b/dnn/src/cuda/conv_bias/cudnn_conv.cpp index 9d4fa298..ae0d3450 100644 --- a/dnn/src/cuda/conv_bias/cudnn_conv.cpp +++ b/dnn/src/cuda/conv_bias/cudnn_conv.cpp @@ -9,24 +9,22 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "src/common/conv_bias.h" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/cudnn_wrapper.h" #include "src/cuda/utils.h" -#include "src/common/conv_bias.h" using namespace megdnn; using namespace cuda; using namespace conv_bias; -bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available( - const SizeArgs& args) const { +bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) const { if (args.z_layout->ndim > 0) return false; if (args.filter_meta.format != Param::Format::NCHW && args.filter_meta.format != Param::Format::NHWC) { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } } @@ -48,9 +46,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available( auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); } SizeArgs conv_args = args; conv_args.dst_layout = &dst_layout; @@ -62,9 +59,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available( size_t workspace_size; auto status = cudnnGetConvolutionForwardWorkspaceSize( - conv_args.handle->cudnn_handle(), D.src_desc.desc, - D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, - m_cudnn_enum, &workspace_size); + conv_args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, + D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); return status == CUDNN_STATUS_SUCCESS; } @@ -74,9 +70,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle( SmallVector sizes; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); sizes.push_back(dst_layout.span().dist_byte()); } @@ -88,12 +83,12 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle( size_t conv_workspace_size; auto status = cudnnGetConvolutionForwardWorkspaceSize( - conv_args.handle->cudnn_handle(), D.src_desc.desc, - D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, - 
m_cudnn_enum, &conv_workspace_size); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv fwd get workspace failed: %s; info: %s", - cudnnGetErrorString(status), args.to_string().c_str()); + conv_args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, + D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &conv_workspace_size); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, + "conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status), + args.to_string().c_str()); sizes.insert(sizes.begin(), conv_workspace_size); return {ptr, std::move(sizes)}; } @@ -109,9 +104,9 @@ void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(1); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } ExecArgs conv_args = args; @@ -126,17 +121,17 @@ void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { auto status = cudnnConvolutionForward( conv_args.handle->cudnn_handle(), &alpha, D.src_desc.desc, conv_args.src_tensor->raw_ptr, D.filter_desc.desc, - conv_args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, - m_cudnn_enum, conv_workspace.raw_ptr, conv_workspace.size, - &beta, D.dst_desc.desc, conv_args.dst_tensor->raw_ptr); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv fwd failed: %s; info: %s", cudnnGetErrorString(status), - conv_args.to_string().c_str()); + conv_args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, + conv_workspace.raw_ptr, conv_workspace.size, &beta, D.dst_desc.desc, + conv_args.dst_tensor->raw_ptr); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", + cudnnGetErrorString(status), conv_args.to_string().c_str()); } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp index 0980f8cb..6d86a49e 100644 --- a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp +++ b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp @@ -13,10 +13,10 @@ #include "./algo.h" +#include "src/common/conv_bias.h" #include "src/cuda/conv_bias/helper.h" #include "src/cuda/cudnn_wrapper.h" #include "src/cuda/utils.h" -#include "src/common/conv_bias.h" using namespace megdnn; using namespace cuda; @@ -26,8 +26,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( const SizeArgs& args) const { if (args.filter_meta.format != Param::Format::NCHW && args.filter_meta.format != Param::Format::NHWC) { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } } @@ -44,8 +43,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( } if (args.bias_layout->ndim == 0 || - !check_bias_share_in_channel(*(args.bias_layout), - args.opr->param().format)) { + !check_bias_share_in_channel(*(args.bias_layout), args.opr->param().format)) { return false; } auto&& param = args.opr->param(); 
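Both the plain AlgoCUDNNConv path above and the fused conv-bias-activation path below follow the same cuDNN calling convention: query the workspace size for a fixed algorithm enum, then execute with a workspace of at least that size. A minimal self-contained sketch of that convention, with illustrative shapes and an arbitrarily chosen algorithm rather than anything MegDNN actually dispatches, is:

#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>

#define CHECK_CUDNN(expr)                                               \
    do {                                                                \
        cudnnStatus_t s_ = (expr);                                      \
        if (s_ != CUDNN_STATUS_SUCCESS) {                               \
            std::printf("cudnn failed: %s\n", cudnnGetErrorString(s_)); \
            return 1;                                                   \
        }                                                               \
    } while (0)

int main() {
    const int N = 1, C = 16, H = 32, W = 32, K = 32, FH = 3, FW = 3;
    cudnnHandle_t handle;
    CHECK_CUDNN(cudnnCreate(&handle));

    cudnnTensorDescriptor_t x_desc, y_desc;
    cudnnFilterDescriptor_t w_desc;
    cudnnConvolutionDescriptor_t conv_desc;
    CHECK_CUDNN(cudnnCreateTensorDescriptor(&x_desc));
    CHECK_CUDNN(cudnnCreateTensorDescriptor(&y_desc));
    CHECK_CUDNN(cudnnCreateFilterDescriptor(&w_desc));
    CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&conv_desc));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(
            x_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, N, C, H, W));
    CHECK_CUDNN(cudnnSetFilter4dDescriptor(
            w_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, K, C, FH, FW));
    CHECK_CUDNN(cudnnSetConvolution2dDescriptor(
            conv_desc, 1, 1, 1, 1, 1, 1, CUDNN_CROSS_CORRELATION,
            CUDNN_DATA_FLOAT));

    int on, oc, oh, ow;
    CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
            conv_desc, x_desc, w_desc, &on, &oc, &oh, &ow));
    CHECK_CUDNN(cudnnSetTensor4dDescriptor(
            y_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, on, oc, oh, ow));

    // 1) ask how much scratch memory this algorithm needs for these shapes
    const cudnnConvolutionFwdAlgo_t algo =
            CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
    size_t ws_size = 0;
    CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
            handle, x_desc, w_desc, conv_desc, y_desc, algo, &ws_size));

    // 2) allocate tensors and the workspace (contents left uninitialized)
    void *x = nullptr, *w = nullptr, *y = nullptr, *workspace = nullptr;
    cudaMalloc(&x, sizeof(float) * N * C * H * W);
    cudaMalloc(&w, sizeof(float) * K * C * FH * FW);
    cudaMalloc(&y, sizeof(float) * on * oc * oh * ow);
    if (ws_size)
        cudaMalloc(&workspace, ws_size);

    // 3) run the convolution with that workspace
    const float alpha = 1.f, beta = 0.f;
    CHECK_CUDNN(cudnnConvolutionForward(
            handle, &alpha, x_desc, x, w_desc, w, conv_desc, algo, workspace,
            ws_size, &beta, y_desc, y));

    // cleanup of device buffers and descriptors omitted for brevity
    return 0;
}

Checking the returned status against CUDNN_STATUS_SUCCESS, rather than asserting, is what lets the is_available implementations above treat "cuDNN cannot provide a workspace for this algorithm" as "algorithm unavailable", while the exec paths assert because failure there is a real error.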
@@ -111,8 +109,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( return false; // sm version auto&& device_prop = current_device_prop(); - if (device_prop.major < 7 || - (device_prop.major == 7 && device_prop.minor < 5)) + if (device_prop.major < 7 || (device_prop.major == 7 && device_prop.minor < 5)) return false; } @@ -130,11 +127,10 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) return false; MEGDNN_FALLTHRU // XXX: why? - case param::ConvBias::NonlineMode::IDENTITY: - if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) - break; - if (m_cudnn_enum != - CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) { + case param::ConvBias::NonlineMode::IDENTITY + : if (args.src_layout->dtype.category() == + DTypeCategory::QUANTIZED) break; + if (m_cudnn_enum != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) { // cudnn require algo to // CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM // when activation if IDENTITY @@ -151,8 +147,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( size_t workspace_size; auto status = cudnnGetConvolutionForwardWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, - D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, - &workspace_size); + D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); return status == CUDNN_STATUS_SUCCESS; } @@ -164,11 +159,11 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes( size_t workspace_size; auto status = cudnnGetConvolutionForwardWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, - D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, - &workspace_size); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv fwd get workspace failed: %s; info: %s", - cudnnGetErrorString(status), args.to_string().c_str()); + D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, + "conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status), + args.to_string().c_str()); if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() && args.src_layout->dtype.category() != DTypeCategory::FLOAT) { // cudnn require bias to be float when executing CONFIG_INT @@ -203,8 +198,7 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( } }; - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, + auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, dst_dtype = args.dst_layout->dtype; megdnn_assert( (src_dtype.category() == dst_dtype.category()) || @@ -220,12 +214,12 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( alpha /= get_scale(args.dst_layout->dtype); if (args.z_layout->ndim > 0 && args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { - beta = get_scale(args.z_layout->dtype) / - get_scale(args.dst_layout->dtype); + beta = get_scale(args.z_layout->dtype) / get_scale(args.dst_layout->dtype); } if (args.bias_layout->dtype.category() == DTypeCategory::QUANTIZED) { - megdnn_assert(fabs(expected_bias_scale - - get_scale(args.bias_layout->dtype)) < 1e-4); + megdnn_assert( + fabs(expected_bias_scale - get_scale(args.bias_layout->dtype)) < + 1e-4); } } @@ -241,8 +235,9 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( float_bias_layout.dtype = dtype::Float32(); auto bias_size_in_bytes = 
float_bias_layout.span().dist_byte(); megdnn_assert(args.workspace.size >= bias_size_in_bytes); - cvt->exec({args.bias_tensor->raw_ptr, converted_bias_layout}, - TensorND{workspace_ptr, float_bias_layout}); + cvt->exec( + {args.bias_tensor->raw_ptr, converted_bias_layout}, + TensorND{workspace_ptr, float_bias_layout}); bias_ptr = workspace_ptr; workspace_ptr += bias_size_in_bytes; @@ -254,33 +249,30 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( status = cudnnConvolutionBiasActivationForward( args.handle->cudnn_handle(), &alpha, D.src_desc.desc, args.src_tensor->raw_ptr, D.filter_desc.desc, - args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, - m_cudnn_enum, workspace_ptr, workspace_size, &beta, - D.dst_desc.desc, args.dst_tensor->raw_ptr, D.bias_desc.desc, - bias_ptr, D.conv_desc.act_desc, D.dst_desc.desc, - args.dst_tensor->raw_ptr); + args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, + workspace_ptr, workspace_size, &beta, D.dst_desc.desc, + args.dst_tensor->raw_ptr, D.bias_desc.desc, bias_ptr, + D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr); } else { status = cudnnConvolutionBiasActivationForward( args.handle->cudnn_handle(), &alpha, D.src_desc.desc, args.src_tensor->raw_ptr, D.filter_desc.desc, - args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, - m_cudnn_enum, workspace_ptr, workspace_size, &beta, - D.z_desc.desc, args.z_tensor->raw_ptr, D.bias_desc.desc, - bias_ptr, D.conv_desc.act_desc, D.dst_desc.desc, - args.dst_tensor->raw_ptr); + args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, + workspace_ptr, workspace_size, &beta, D.z_desc.desc, + args.z_tensor->raw_ptr, D.bias_desc.desc, bias_ptr, + D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr); } - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv fwd failed: %s; info: %s, algo %s", - cudnnGetErrorString(status), args.to_string().c_str(), - name()); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s, algo %s", + cudnnGetErrorString(status), args.to_string().c_str(), name()); // Noline switch (args.nonlinear_mode) { case param::ConvBias::NonlineMode::RELU: break; case param::ConvBias::NonlineMode::SIGMOID: { - megdnn_assert(args.dst_layout->dtype.category() != - DTypeCategory::QUANTIZED); + megdnn_assert( + args.dst_layout->dtype.category() != DTypeCategory::QUANTIZED); auto&& elem_opr = args.handle->create_operator(); elem_opr->param().mode = Elemwise::Param::Mode::SIGMOID; elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); @@ -289,21 +281,16 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( case param::ConvBias::NonlineMode::IDENTITY: break; case param::ConvBias::NonlineMode::H_SWISH: { - megdnn_assert(args.dst_layout->dtype.category() == - DTypeCategory::QUANTIZED || - (args.dst_layout->dtype.category() == - DTypeCategory::FLOAT && - args.opr->param().format == - param::ConvBias::Format::NCHW4_NCHW)); + megdnn_assert( + args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED || + (args.dst_layout->dtype.category() == DTypeCategory::FLOAT && + args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW)); if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) { - auto&& elem_opr = - args.handle->create_operator(); - elem_opr->param().mode = - ElemwiseMultiType::Param::Mode::QH_SWISH; + auto&& elem_opr = args.handle->create_operator(); + elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH; elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); } else { - 
auto&& elem_opr = - args.handle->create_operator(); + auto&& elem_opr = args.handle->create_operator(); elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH; elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor)); } diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp b/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp index 096686d1..d6dcf2de 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp +++ b/dnn/src/cuda/conv_bias/cutlass_convolution_base.cpp @@ -35,15 +35,16 @@ ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::AlgoParam( stage(stage_), access_size(access_size_) {} -std::string -ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::to_string() const { +std::string ConvBiasForwardImpl::AlgoCutlassConvolutionBase::AlgoParam::to_string() + const { /// default algorithm if (threadblock_m == 128 && threadblock_n == 128 && threadblock_k == 32 && warp_m == 32 && warp_n == 64 && warp_k == 32 && stage == 2) { return ""; } - return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, - threadblock_k, warp_m, warp_n, warp_k, stage); + return ssprintf( + "_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, threadblock_k, + warp_m, warp_n, warp_k, stage); } namespace { @@ -106,8 +107,7 @@ struct LayoutPack { LayoutTypeID bias; }; -LayoutPack get_layout_pack(const param::ConvBias::Format format, - int access_type) { +LayoutPack get_layout_pack(const param::ConvBias::Format format, int access_type) { using Format = param::ConvBias::Format; switch (format) { @@ -122,44 +122,30 @@ LayoutPack get_layout_pack(const param::ConvBias::Format format, LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; case Format::NCHW4_NCHW32: return {LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorC4RSK4, - LayoutTypeID::kTensorNC32HW32, - LayoutTypeID::kTensorNC32HW32}; + LayoutTypeID::kTensorNC32HW32, LayoutTypeID::kTensorNC32HW32}; case Format::NCHW32: - return {LayoutTypeID::kTensorNC32HW32, - LayoutTypeID::kTensorC32RSK32, - LayoutTypeID::kTensorNC32HW32, - LayoutTypeID::kTensorNC32HW32}; + return {LayoutTypeID::kTensorNC32HW32, LayoutTypeID::kTensorC32RSK32, + LayoutTypeID::kTensorNC32HW32, LayoutTypeID::kTensorNC32HW32}; case Format::NCHW32_NCHW4: - return {LayoutTypeID::kTensorNC32HW32, - LayoutTypeID::kTensorC32RSK32, LayoutTypeID::kTensorNC4HW4, - LayoutTypeID::kTensorNC4HW4}; + return {LayoutTypeID::kTensorNC32HW32, LayoutTypeID::kTensorC32RSK32, + LayoutTypeID::kTensorNC4HW4, LayoutTypeID::kTensorNC4HW4}; case Format::NCHW64: - return {LayoutTypeID::kTensorNC64HW64, - LayoutTypeID::kTensorC64RSK64, - LayoutTypeID::kTensorNC64HW64, - LayoutTypeID::kTensorNC64HW64}; + return {LayoutTypeID::kTensorNC64HW64, LayoutTypeID::kTensorC64RSK64, + LayoutTypeID::kTensorNC64HW64, LayoutTypeID::kTensorNC64HW64}; case Format::NHWC: switch (access_type) { case 4: - return {LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNC4HW4, - LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNHWC}; + return {LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNC4HW4, + LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; case 8: - return {LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNC8HW8, - LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNHWC}; + return {LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNC8HW8, + LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; case 16: - return {LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNC16HW16, - LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNHWC}; + return {LayoutTypeID::kTensorNHWC, 
LayoutTypeID::kTensorNC16HW16, + LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; case 32: - return {LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNC32HW32, - LayoutTypeID::kTensorNHWC, - LayoutTypeID::kTensorNHWC}; + return {LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNC32HW32, + LayoutTypeID::kTensorNHWC, LayoutTypeID::kTensorNHWC}; default: megdnn_assert(0, "invalid access_type"); } @@ -168,8 +154,7 @@ LayoutPack get_layout_pack(const param::ConvBias::Format format, } } -EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, - bool clamp) { +EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, bool clamp) { using NonlineMode = param::ConvBias::NonlineMode; if (clamp) { @@ -194,55 +179,53 @@ EpilogueType get_epilogue_type(const param::ConvBias::NonlineMode mode, } // namespace -const Operation* -ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op( +const Operation* ConvBiasForwardImpl::AlgoCutlassConvolutionBase::get_cutlass_conv_op( const SizeArgs& args, ConvOperator conv_op, ConvType conv_type, bool use_conv_filter_unity_opt, bool without_shared_load) const { auto&& param = args.opr->param(); auto layouts = get_layout_pack(param.format, m_algo_param.access_size); auto epilogue_type = get_epilogue_type( - param.nonlineMode, - args.dst_layout->dtype.enumv() != DTypeEnum::Float32); + param.nonlineMode, args.dst_layout->dtype.enumv() != DTypeEnum::Float32); cutlass::conv::SpecialOptimizeDesc special_optimization = (use_conv_filter_unity_opt) ? cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY : cutlass::conv::SpecialOptimizeDesc::NONE; - ConvolutionKey key{convert_conv_op(conv_op), - convert_dtype(args.src_layout->dtype.enumv()), - layouts.src, - convert_dtype(args.filter_layout->dtype.enumv()), - layouts.filter, - convert_dtype(args.dst_layout->dtype.enumv()), - layouts.dst, - convert_dtype(args.bias_layout->dtype.enumv()), - layouts.bias, - convert_conv_type(conv_type), - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k, - m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k, - m_algo_param.instruction_m, - m_algo_param.instruction_n, - m_algo_param.instruction_k, - epilogue_type, - m_algo_param.stage, - special_optimization, - without_shared_load}; + ConvolutionKey key{ + convert_conv_op(conv_op), + convert_dtype(args.src_layout->dtype.enumv()), + layouts.src, + convert_dtype(args.filter_layout->dtype.enumv()), + layouts.filter, + convert_dtype(args.dst_layout->dtype.enumv()), + layouts.dst, + convert_dtype(args.bias_layout->dtype.enumv()), + layouts.bias, + convert_conv_type(conv_type), + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + epilogue_type, + m_algo_param.stage, + special_optimization, + without_shared_load}; return Singleton::get().operation_table.find_op(key); } void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op( - const Operation* op, const void* src, const void* filter, - const void* bias, const void* z, void* dst, void* workspace, size_t n, - size_t hi, size_t wi, size_t ci, size_t co, size_t fh, size_t fw, - size_t ho, size_t wo, size_t ph, size_t pw, size_t sh, size_t sw, - size_t dh, size_t dw, const void* alpha, const void* beta, - const void* gamma, const void* delta, const void* theta, + const Operation* op, const void* 
src, const void* filter, const void* bias, + const void* z, void* dst, void* workspace, size_t n, size_t hi, size_t wi, + size_t ci, size_t co, size_t fh, size_t fw, size_t ho, size_t wo, size_t ph, + size_t pw, size_t sh, size_t sw, size_t dh, size_t dw, const void* alpha, + const void* beta, const void* gamma, const void* delta, const void* theta, const void* threshold, const void* dst_scale, cudaStream_t stream, const void* extra_param) const { // gcc prints warnings when size_t values are implicitly narrowed to int @@ -253,9 +236,8 @@ void ConvBiasForwardImpl::AlgoCutlassConvolutionBase::execute_cutlass_conv_op( int(sw), int(dh), int(dw), cutlass::conv::Mode::kCrossCorrelation}; ConvolutionArguments conv_args{ - problem_size, src, filter, bias, z, - dst, alpha, beta, gamma, delta, - theta, threshold, dst_scale, extra_param}; + problem_size, src, filter, bias, z, dst, alpha, + beta, gamma, delta, theta, threshold, dst_scale, extra_param}; cutlass_check(op->run(&conv_args, workspace, stream)); } diff --git a/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu index c5ca513c..47875acc 100644 --- a/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu +++ b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu @@ -11,8 +11,8 @@ */ #include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" -#include "src/cuda/query_blocksize.cuh" #include "src/cuda/integer_subbyte_utils.cuh" +#include "src/cuda/query_blocksize.cuh" using namespace megdnn; using namespace cuda; @@ -24,8 +24,7 @@ __device__ __forceinline__ void reorder_ncxhwx_imma_filter_func( int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, uint32_t lane, bool trans_oc) { static constexpr uint32_t elements_per_lane = 128 / size_bits; - static constexpr uint32_t threads_per_interleaved = - interleaved / elements_per_lane; + static constexpr uint32_t threads_per_interleaved = interleaved / elements_per_lane; static constexpr uint32_t instruction_shape_col = 8; // 4 threads per Quad static constexpr uint32_t elements_per_thread = instruction_shape_col / 4; @@ -44,14 +43,12 @@ __device__ __forceinline__ void reorder_ncxhwx_imma_filter_func( ((row % reordered_elements_per_thread) / elements_per_thread) * instruction_shape_col + - ((row % interleaved) / - reordered_elements_per_thread) * + ((row % interleaved) / reordered_elements_per_thread) * elements_per_thread + (row % elements_per_thread) : row; - uint32_t dst_offset = - (col * OC + row) * interleaved + residue * elements_per_lane; + uint32_t dst_offset = (col * OC + row) * interleaved + residue * elements_per_lane; *(reinterpret_cast(dst + dst_offset * size_bits / 8)) = *(reinterpret_cast(src + src_offset * size_bits / 8)); @@ -89,28 +86,24 @@ __device__ __forceinline__ void reorder_nhwc_imma_filter_func( uint32_t src_offset = lane * elements_per_access; // reorder k k = (trans_oc) - ? (k / interleaved) * interleaved + - ((k % reordered_elements_per_thread) / - elements_per_thread) * - instruction_shape_col + - ((k % interleaved) / reordered_elements_per_thread) * - elements_per_thread + - (k % elements_per_thread) - : k; - uint32_t dst_offset = - (k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access; + ? 
(k / interleaved) * interleaved + + ((k % reordered_elements_per_thread) / elements_per_thread) * + instruction_shape_col + + ((k % interleaved) / reordered_elements_per_thread) * + elements_per_thread + + (k % elements_per_thread) + : k; + uint32_t dst_offset = (k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access; if (alignbits == 32) { - *(reinterpret_cast(dst + dst_offset * size_bits / 8)) = *( - reinterpret_cast(src + src_offset * size_bits / 8)); + *(reinterpret_cast(dst + dst_offset * size_bits / 8)) = + *(reinterpret_cast(src + src_offset * size_bits / 8)); } else if (alignbits == 64) { *(reinterpret_cast(dst + dst_offset * size_bits / 8)) = - *(reinterpret_cast(src + - src_offset * size_bits / 8)); + *(reinterpret_cast(src + src_offset * size_bits / 8)); } else { *(reinterpret_cast(dst + dst_offset * size_bits / 8)) = - *(reinterpret_cast(src + - src_offset * size_bits / 8)); + *(reinterpret_cast(src + src_offset * size_bits / 8)); } } @@ -133,15 +126,14 @@ void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter( int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) { static constexpr uint32_t elements_per_lane = 128 / size_bits; - uint32_t nr_threads = - query_blocksize_for_kernel(reinterpret_cast( - reorder_ncxhwx_imma_filter_kernel)); + uint32_t nr_threads = query_blocksize_for_kernel(reinterpret_cast( + reorder_ncxhwx_imma_filter_kernel)); uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane); nr_threads = std::min(nr_threads, vthreads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads); reorder_ncxhwx_imma_filter_kernel - <<>>(dst_filter, src_filter, OC, - IC, FH, FW, trans_oc); + <<>>( + dst_filter, src_filter, OC, IC, FH, FW, trans_oc); after_kernel_launch(); } template @@ -151,17 +143,16 @@ void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter( uint32_t interleaved, cudaStream_t stream) { const uint32_t elements_per_access = alignbits / size_bits; - void (*kern)(int8_t* __restrict__, const int8_t* __restrict__, uint32_t, - uint32_t, uint32_t, uint32_t, bool); + void (*kern)( + int8_t* __restrict__, const int8_t* __restrict__, uint32_t, uint32_t, + uint32_t, uint32_t, bool); kern = nullptr; - auto get_kern = [&kern](const uint32_t alignbits, - const uint32_t interleaved) { -#define DISPATCH_KERNEL(alignbits_, interleaved_) \ - if (alignbits == alignbits_ && interleaved == interleaved_) { \ - kern = reorder_nhwc_imma_filter_kernel; \ - return; \ + auto get_kern = [&kern](const uint32_t alignbits, const uint32_t interleaved) { +#define DISPATCH_KERNEL(alignbits_, interleaved_) \ + if (alignbits == alignbits_ && interleaved == interleaved_) { \ + kern = reorder_nhwc_imma_filter_kernel; \ + return; \ } DISPATCH_KERNEL(128, 16); DISPATCH_KERNEL(64, 16); @@ -183,29 +174,27 @@ void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter( nr_threads = std::min(nr_threads, vthreads); uint32_t nr_blocks = DIVUP(vthreads, nr_threads); - kern<<>>(dst_filter, src_filter, OC, IC, - FH, FW, trans_oc); + kern<<>>( + dst_filter, src_filter, OC, IC, FH, FW, trans_oc); after_kernel_launch(); } -#define INST(_size_bits, _interleaved) \ - template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \ - _size_bits, _interleaved>(int8_t * dst_filter, \ - const int8_t* src_filter, uint32_t OC, \ - uint32_t IC, uint32_t FH, uint32_t FW, \ - bool trans_oc, cudaStream_t stream); +#define INST(_size_bits, _interleaved) \ + template void 
megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \ + _size_bits, _interleaved>( \ + int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, \ + uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream); INST(8, 32) INST(4, 64) #undef INST -#define INST(_size_bits) \ - template void \ - megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter<_size_bits>( \ - int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, \ - uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, \ - uint32_t alignbits, uint32_t interleaved, cudaStream_t stream); +#define INST(_size_bits) \ + template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter<_size_bits>( \ + int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, \ + uint32_t FH, uint32_t FW, bool trans_oc, uint32_t alignbits, \ + uint32_t interleaved, cudaStream_t stream); INST(4) INST(8) #undef INST diff --git a/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh index 03f74e90..69500284 100644 --- a/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh +++ b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh @@ -18,16 +18,15 @@ namespace cuda { namespace cutlass_wrapper { template -void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter, - uint32_t OC, uint32_t IC, uint32_t FH, - uint32_t FW, bool trans_oc, - cudaStream_t stream); +void reorder_ncxhwx_imma_filter( + int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, + uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream); template -void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter, - uint32_t OC, uint32_t IC, uint32_t FH, - uint32_t FW, bool trans_oc, uint32_t alignbits, - uint32_t interleaved, cudaStream_t stream); +void reorder_nhwc_imma_filter( + int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC, + uint32_t FH, uint32_t FW, bool trans_oc, uint32_t alignbits, + uint32_t interleaved, cudaStream_t stream); } // namespace cutlass_wrapper } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/group_conv.cpp b/dnn/src/cuda/conv_bias/group_conv.cpp index 51c78a74..1366c64f 100644 --- a/dnn/src/cuda/conv_bias/group_conv.cpp +++ b/dnn/src/cuda/conv_bias/group_conv.cpp @@ -32,9 +32,9 @@ std::pair sub_opr_config( args.filter_meta.format == megdnn::param::ConvBias::Format::NCHW4) { c_pos = 1; } else { - megdnn_assert(args.filter_meta.format == - megdnn::param::ConvBias::Format::NHWC, - "invalid conv format"); + megdnn_assert( + args.filter_meta.format == megdnn::param::ConvBias::Format::NHWC, + "invalid conv format"); c_pos = 3; } @@ -45,8 +45,7 @@ std::pair sub_opr_config( megdnn::param::ConvBias param = args.opr->param(); param.sparse = megdnn::param::ConvBias::Sparse::DENSE; - param.nonlineMode = - megdnn::param::ConvBias::NonlineMode::IDENTITY; + param.nonlineMode = megdnn::param::ConvBias::NonlineMode::IDENTITY; std::pair ret; ret.first = {src_pg, filter_pg, bias_pg, z_pg, dst_pg}; ret.second = param; @@ -66,15 +65,16 @@ std::pair> prepare_sub_opr( } } // namespace -std::vector -ConvBiasForwardImpl::AlgoGroupConvGeneral::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { - AlgoBase::SizeArgs args{static_cast(opr), - layouts[0], - layouts[1], - layouts[2], - layouts[3], - layouts[4]}; +std::vector ConvBiasForwardImpl::AlgoGroupConvGeneral:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { + AlgoBase::SizeArgs args{ 
+ static_cast(opr), + layouts[0], + layouts[1], + layouts[2], + layouts[3], + layouts[4]}; auto&& config = sub_opr_config(args); std::string param_str; @@ -99,9 +99,8 @@ bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available( auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); } auto conv_args = args; @@ -109,9 +108,8 @@ bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available( auto config = prepare_sub_opr(conv_args); bool ret = has_available_algo( - static_cast(config.second.get()), - config.first[0], config.first[1], config.first[2], config.first[3], - config.first[4]); + static_cast(config.second.get()), config.first[0], + config.first[1], config.first[2], config.first[3], config.first[4]); return ret; } @@ -121,9 +119,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_bundle( SmallVector sizes; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); sizes.push_back(dst_layout.span().dist_byte()); } @@ -131,8 +128,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_bundle( conv_args.dst_layout = &dst_layout; auto config = prepare_sub_opr(conv_args); size_t mm_ws = config.second->get_workspace_in_bytes( - config.first[0], config.first[1], config.first[2], - config.first[3], config.first[4], nullptr); + config.first[0], config.first[1], config.first[2], config.first[3], + config.first[4], nullptr); sizes.insert(sizes.begin(), mm_ws); return {ptr, std::move(sizes)}; @@ -143,16 +140,15 @@ size_t ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec( - const ExecArgs& args) const { +void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const { auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); auto conv_dst_tensor = *args.dst_tensor; if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } { auto sub_args = args; @@ -171,18 +167,17 @@ void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec( args.filter_meta.format == Param::Format::NCHW4) { c_pos = 1; } else { - megdnn_assert(args.filter_meta.format == Param::Format::NHWC, - "invalid conv format"); + megdnn_assert( + args.filter_meta.format == Param::Format::NHWC, + "invalid conv format"); c_pos = 3; } auto grp = args.filter_meta.group; auto&& fm = args.filter_meta; - auto strd_src = tsrc.layout.stride[c_pos] * fm.icpg * - tsrc.layout.dtype.size(), - strd_dst = tdst.layout.stride[c_pos] * fm.ocpg * - tdst.layout.dtype.size(), + auto strd_src = tsrc.layout.stride[c_pos] * fm.icpg * 
tsrc.layout.dtype.size(), + strd_dst = tdst.layout.stride[c_pos] * fm.ocpg * tdst.layout.dtype.size(), strd_flt = fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * tfilter.layout.dtype.size(); if (args.filter_meta.format == Param::Format::NCHW4) { @@ -190,16 +185,16 @@ void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec( strd_dst >>= 2; } for (uint32_t g = 0; g < grp; ++g) { - config.second->exec(tsrc, tfilter, tbias, - tz, tdst, nullptr, bundle.get_workspace(0)); + config.second->exec( + tsrc, tfilter, tbias, tz, tdst, nullptr, bundle.get_workspace(0)); incr_voidp(tsrc.raw_ptr, strd_src); incr_voidp(tdst.raw_ptr, strd_dst); incr_voidp(tfilter.raw_ptr, strd_flt); } } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/helper.cpp b/dnn/src/cuda/conv_bias/helper.cpp index d4cd291b..0c884116 100644 --- a/dnn/src/cuda/conv_bias/helper.cpp +++ b/dnn/src/cuda/conv_bias/helper.cpp @@ -28,11 +28,10 @@ ConvBiasDesc::~ConvBiasDesc() { cudnn_check(cudnnDestroyActivationDescriptor(act_desc)); } -void ConvBiasDesc::set_conv_bias(DType data_type, const param::ConvBias& param, - size_t nr_group) { +void ConvBiasDesc::set_conv_bias( + DType data_type, const param::ConvBias& param, size_t nr_group) { #if CUDNN_VERSION < 7100 - megdnn_throw( - "ConvBias(CUDNN_ACTIVATION_IDENTITY) require cudnn 7.1 or higher"); + megdnn_throw("ConvBias(CUDNN_ACTIVATION_IDENTITY) require cudnn 7.1 or higher"); #else cudnnConvolutionMode_t mode; using Param = param::ConvBias; @@ -72,13 +71,11 @@ void ConvBiasDesc::set_conv_bias(DType data_type, const param::ConvBias& param, case Param::NonlineMode::SIGMOID: case Param::NonlineMode::H_SWISH: cudnn_check(cudnnSetActivationDescriptor( - act_desc, CUDNN_ACTIVATION_IDENTITY, - CUDNN_NOT_PROPAGATE_NAN, 0)); + act_desc, CUDNN_ACTIVATION_IDENTITY, CUDNN_NOT_PROPAGATE_NAN, 0)); break; case Param::NonlineMode::RELU: cudnn_check(cudnnSetActivationDescriptor( - act_desc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, - 0)); + act_desc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0)); break; default: megdnn_throw("unsupported non linear mode"); @@ -86,8 +83,8 @@ void ConvBiasDesc::set_conv_bias(DType data_type, const param::ConvBias& param, #endif } -void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param, - const size_t nr_group) { +void ConvBiasDesc::set_conv( + DType data_type, const param::ConvBias& param, const size_t nr_group) { using Param = param::ConvBias; cudnnConvolutionMode_t mode; switch (param.mode) { @@ -109,8 +106,9 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param, auto comp_mode = param.compute_mode; compute_type = get_compute_type_fp16(comp_mode); #if CUDNN_MAJOR >= 7 - } else if (data_type.category() == DTypeCategory::INT || - data_type.category() == DTypeCategory::QUANTIZED) { + } else if ( + data_type.category() == DTypeCategory::INT || + data_type.category() == DTypeCategory::QUANTIZED) { compute_type = CUDNN_DATA_INT32; #endif } else { @@ -157,8 +155,9 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) { args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) { return false; } - } else if (args.filter_meta.format != param::Convolution::Format::NCHW && - args.filter_meta.format != param::Convolution::Format::NHWC) { + } else if ( + args.filter_meta.format 
!= param::Convolution::Format::NCHW && + args.filter_meta.format != param::Convolution::Format::NHWC) { return false; } auto& fm = args.filter_meta; @@ -173,24 +172,24 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) { return supported; } -SmallVector matmul_get_workspace_bundle( - const BiasForwardSizeArgs& args) { +SmallVector matmul_get_workspace_bundle(const BiasForwardSizeArgs& args) { auto dtype = args.src_layout->dtype; auto&& fm = args.filter_meta; megdnn_assert(fm.group == 1); auto N = args.src_layout->shape[0]; auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; auto OH = args.dst_layout->shape[2], OW = args.dst_layout->shape[3]; - SmallVector sizes{dtype.size() * args.dst_layout->total_nr_elems(), - dtype.size() * IC * FH * FW * OH * OW * N}; + SmallVector sizes{ + dtype.size() * args.dst_layout->total_nr_elems(), + dtype.size() * IC * FH * FW * OH * OW * N}; if (args.filter_meta.should_flip) { sizes.push_back(dtype.size() * OC * IC * FH * FW); } return sizes; } -void flip_filter(const BiasForwardSizeArgs& args, const Workspace& workspace, - void*& raw_ptr) { +void flip_filter( + const BiasForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { auto&& fm = args.filter_meta; megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; @@ -205,9 +204,9 @@ void flip_filter(const BiasForwardSizeArgs& args, const Workspace& workspace, raw_ptr = workspace.raw_ptr; } -} // conv_bias +} // namespace conv_bias -} // cuda -} // megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/helper.h b/dnn/src/cuda/conv_bias/helper.h index 0d3687d4..9c141acf 100644 --- a/dnn/src/cuda/conv_bias/helper.h +++ b/dnn/src/cuda/conv_bias/helper.h @@ -11,10 +11,10 @@ #pragma once #include "./opr_impl.h" -#include "src/cuda/handle.h" -#include "src/cuda/cudnn_wrapper.h" -#include "src/common/utils.h" #include "src/common/algo_chooser.h" +#include "src/common/utils.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/handle.h" namespace megdnn { namespace cuda { @@ -22,112 +22,109 @@ namespace cuda { class ConvBiasDesc { public: ConvBiasDesc(); - void set_conv_bias(DType data_type, const param::ConvBias& param, - const size_t nr_group); - void set_conv(DType data_type, const param::ConvBias& param, - const size_t nr_group); + void set_conv_bias( + DType data_type, const param::ConvBias& param, const size_t nr_group); + void set_conv(DType data_type, const param::ConvBias& param, const size_t nr_group); ~ConvBiasDesc(); cudnnConvolutionDescriptor_t conv_desc; cudnnActivationDescriptor_t act_desc; }; namespace conv_bias { - using CanonizedFilterMeta = ConvBiasForward::CanonizedFilterMeta; +using CanonizedFilterMeta = ConvBiasForward::CanonizedFilterMeta; - //! conv size descriptor in the forward view - struct BiasForwardSizeArgs { - HandleImpl *handle; - const TensorLayout *src_layout; - const TensorLayout *filter_layout; - const TensorLayout *bias_layout; - const TensorLayout *z_layout; - CanonizedFilterMeta filter_meta; - const TensorLayout *dst_layout; - param::ConvBias::NonlineMode nonlinear_mode; - }; +//! 
conv size descriptor in the forward view +struct BiasForwardSizeArgs { + HandleImpl* handle; + const TensorLayout* src_layout; + const TensorLayout* filter_layout; + const TensorLayout* bias_layout; + const TensorLayout* z_layout; + CanonizedFilterMeta filter_meta; + const TensorLayout* dst_layout; + param::ConvBias::NonlineMode nonlinear_mode; +}; - //! whether cudnn is supported for a filter meta - bool is_cudnn_supported(const BiasForwardSizeArgs& args); +//! whether cudnn is supported for a filter meta +bool is_cudnn_supported(const BiasForwardSizeArgs& args); - //! get workspace bundle for matmul algo - SmallVector matmul_get_workspace_bundle( - const BiasForwardSizeArgs& args); +//! get workspace bundle for matmul algo +SmallVector matmul_get_workspace_bundle(const BiasForwardSizeArgs& args); - /*! - * \brief flip conv filter - * - * Flip conv filter pointed by \p raw_ptr, store result in workspace, and - * change \p raw_ptr to workspace. - */ - void flip_filter(const BiasForwardSizeArgs& args, - const Workspace& workspace, void*& raw_ptr); +/*! + * \brief flip conv filter + * + * Flip conv filter pointed by \p raw_ptr, store result in workspace, and + * change \p raw_ptr to workspace. + */ +void flip_filter( + const BiasForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); - struct CUDNNForwardDescs { - TensorDesc src_desc, dst_desc, bias_desc, z_desc; - FilterDesc filter_desc; - ConvBiasDesc conv_desc; +struct CUDNNForwardDescs { + TensorDesc src_desc, dst_desc, bias_desc, z_desc; + FilterDesc filter_desc; + ConvBiasDesc conv_desc; - void set_conv_bias(const TensorLayout& src, - const CanonizedFilterMeta& filter, - const TensorLayout& dst, const TensorLayout& bias, - const TensorLayout& z, - const param::ConvBias& param) { - using Format = param::ConvBias::Format; - Format src_format, dst_format; - src_format = dst_format = param.format; - if (param.format == Format::NCHW4_NCHW) { - src_format = Format::NCHW4; - dst_format = Format::NCHW; - } - src_desc.set(src, src_format); - filter_desc.set(filter); - if (z.ndim > 0) { - z_desc.set(z, dst_format); - } - dst_desc.set(dst, dst_format); - conv_desc.set_conv_bias(src.dtype, param, filter.group); + void set_conv_bias( + const TensorLayout& src, const CanonizedFilterMeta& filter, + const TensorLayout& dst, const TensorLayout& bias, const TensorLayout& z, + const param::ConvBias& param) { + using Format = param::ConvBias::Format; + Format src_format, dst_format; + src_format = dst_format = param.format; + if (param.format == Format::NCHW4_NCHW) { + src_format = Format::NCHW4; + dst_format = Format::NCHW; + } + src_desc.set(src, src_format); + filter_desc.set(filter); + if (z.ndim > 0) { + z_desc.set(z, dst_format); + } + dst_desc.set(dst, dst_format); + conv_desc.set_conv_bias(src.dtype, param, filter.group); - // cudnn requires the bias to be float tensor. - auto float_bias_layout = bias; - float_bias_layout.dtype = dtype::Float32(); - if (param.format == param::ConvBias::Format::NCHW4 || - param.format == param::ConvBias::Format::NCHW32) { - // cudnn require bias to be NCHW, not NCHW4. 
- float_bias_layout = float_bias_layout.reshape( - {float_bias_layout[0], - float_bias_layout[1] * float_bias_layout[4], - float_bias_layout[2], float_bias_layout[3]}); - bias_desc.set(float_bias_layout); - } else if (param.format == param::ConvBias::Format::NCHW4_NCHW) { - megdnn_assert(float_bias_layout.ndim == 4, - "NCHW4_NCHW format assumes bias tensor is stored " - "in NCHW layout, ndim(expected:4,got:%zu)", - float_bias_layout.ndim); - bias_desc.set(float_bias_layout); - } else { - bias_desc.set(float_bias_layout, param.format); - } + // cudnn requires the bias to be float tensor. + auto float_bias_layout = bias; + float_bias_layout.dtype = dtype::Float32(); + if (param.format == param::ConvBias::Format::NCHW4 || + param.format == param::ConvBias::Format::NCHW32) { + // cudnn require bias to be NCHW, not NCHW4. + float_bias_layout = float_bias_layout.reshape( + {float_bias_layout[0], float_bias_layout[1] * float_bias_layout[4], + float_bias_layout[2], float_bias_layout[3]}); + bias_desc.set(float_bias_layout); + } else if (param.format == param::ConvBias::Format::NCHW4_NCHW) { + megdnn_assert( + float_bias_layout.ndim == 4, + "NCHW4_NCHW format assumes bias tensor is stored " + "in NCHW layout, ndim(expected:4,got:%zu)", + float_bias_layout.ndim); + bias_desc.set(float_bias_layout); + } else { + bias_desc.set(float_bias_layout, param.format); } + } - void set_conv(const TensorLayout& src, - const CanonizedFilterMeta& filter, - const TensorLayout& dst, const param::ConvBias& param) { - using Format = param::ConvBias::Format; - Format src_format, dst_format; - src_format = dst_format = param.format; - if (param.format == Format::NCHW4_NCHW) { - src_format = Format::NCHW4; - dst_format = Format::NCHW; - } - src_desc.set(src, src_format); - filter_desc.set(filter); - dst_desc.set(dst, dst_format); - conv_desc.set_conv(src.dtype, param, filter.group); + void set_conv( + const TensorLayout& src, const CanonizedFilterMeta& filter, + const TensorLayout& dst, const param::ConvBias& param) { + using Format = param::ConvBias::Format; + Format src_format, dst_format; + src_format = dst_format = param.format; + if (param.format == Format::NCHW4_NCHW) { + src_format = Format::NCHW4; + dst_format = Format::NCHW; } - }; + src_desc.set(src, src_format); + filter_desc.set(filter); + dst_desc.set(dst, dst_format); + conv_desc.set_conv(src.dtype, param, filter.group); + } +}; } // namespace conv_bias -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp index bcd8c655..24c2e404 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp @@ -17,8 +17,7 @@ using namespace cuda; using namespace convolution; #if CUDA_VERSION >= 10020 -size_t -ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { if (args.preprocessed_filter) { return 0; @@ -32,9 +31,8 @@ size_t ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm:: return 0; } -SmallVector ConvBiasForwardImpl:: - AlgoInt4Int4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout( - const SizeArgs& args) const { +SmallVector ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm:: + deduce_preprocessed_filter_layout(const 
SizeArgs& args) const { return {args.filter_layout->collapse_contiguous()}; } @@ -45,9 +43,8 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( reorder_filter(args, filter_ptr); } -std::tuple -ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::prepare_filter_bias( - const ExecArgs& args) const { +std::tuple ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm:: + prepare_filter_bias(const ExecArgs& args) const { void* filter_ptr = nullptr; if (args.preprocessed_filter) { megdnn_assert(args.preprocessed_filter->tensors.size() == 1); @@ -60,18 +57,15 @@ ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::prepare_filter_bias( return {filter_ptr, bias_ptr}; } -std::tuple -ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::get_constants( - const ExecArgs& args) const { +std::tuple ConvBiasForwardImpl:: + AlgoInt4Int4NCHW64IMMAImplicitGemm::get_constants(const ExecArgs& args) const { float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale, gamma = 0.f, delta = 0.f, theta = 0.f; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale, + gamma = 0.f, delta = 0.f, theta = 0.f; if (args.z_layout->ndim > 0) { float z_scale = args.z_layout->dtype.param().scale; diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp index 8b5f1cda..c4db640d 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp @@ -17,8 +17,7 @@ using namespace cuda; using namespace convolution; #if CUDA_VERSION >= 10020 -size_t -ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { if (args.preprocessed_filter) { return 0; @@ -32,9 +31,8 @@ size_t ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm:: return 0; } -SmallVector ConvBiasForwardImpl:: - AlgoInt4Int4NHWCIMMAImplicitGemm::deduce_preprocessed_filter_layout( - const SizeArgs& args) const { +SmallVector ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm:: + deduce_preprocessed_filter_layout(const SizeArgs& args) const { return {args.filter_layout->collapse_contiguous()}; } @@ -45,9 +43,8 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( reorder_filter(args, m_algo_param.access_size, filter_ptr); } -std::tuple -ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::prepare_filter_bias( - const ExecArgs& args) const { +std::tuple ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm:: + prepare_filter_bias(const ExecArgs& args) const { void* filter_ptr = nullptr; if (args.preprocessed_filter) { megdnn_assert(args.preprocessed_filter->tensors.size() == 1); @@ -60,14 +57,11 @@ ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::prepare_filter_bias( return {filter_ptr, bias_ptr}; } -std::tuple -ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants( - const ExecArgs& args) const { +std::tuple ConvBiasForwardImpl:: + AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants(const ExecArgs& args) const { 
float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale; if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { @@ -77,16 +71,15 @@ ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::get_constants( dst_scale = args.dst_layout->dtype.param().scale; } - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale, gamma = 0.f, delta = 0.f, theta = 0.f; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale, + gamma = 0.f, delta = 0.f, theta = 0.f; if (args.z_layout->ndim > 0) { float z_scale; if (args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { z_scale = args.z_layout->dtype.param().scale; } else { // DTypeEnum::QuantizedS8 - megdnn_assert(args.z_layout->dtype.enumv() == - DTypeEnum::QuantizedS8); + megdnn_assert(args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS8); z_scale = args.z_layout->dtype.param().scale; } gamma = z_scale / dst_scale; diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp index 40bb0e82..227339a3 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp @@ -22,8 +22,7 @@ using namespace cuda; using namespace convolution; #if CUDA_VERSION >= 10020 -std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::param() - const { +std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::param() const { std::string ret; serialize_write_pod(m_algo_param, ret); return ret; @@ -91,13 +90,10 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( const ExecArgs& args) const { auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](1) * 64, - hi = args.src_layout->operator[](2), - wi = args.src_layout->operator[](3); + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 64, + hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); size_t co = args.dst_layout->operator[](1) * 64, - ho = args.dst_layout->operator[](2), - wo = args.dst_layout->operator[](3); + ho = args.dst_layout->operator[](2), wo = args.dst_layout->operator[](3); UNPACK_CONV_PARAMETER(fm, param); MARK_USED_VAR @@ -122,35 +118,33 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( bool without_shared_load = true; if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - dst_scale = - args.dst_layout->dtype.param().scale; - src_zero = args.src_layout->dtype.param() - .zero_point; + dst_scale = args.dst_layout->dtype.param().scale; + src_zero = args.src_layout->dtype.param().zero_point; } else { // DTypeEnum::QuantizedS4 dst_scale = args.dst_layout->dtype.param().scale; } cudaStream_t stream = cuda_stream(args.opr->handle()); - const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, - ConvType::kConvolution, - use_conv_filter_unity_opt, without_shared_load); + const auto* op = get_cutlass_conv_op( + args, ConvOperator::kFprop, ConvType::kConvolution, + use_conv_filter_unity_opt, without_shared_load); - execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, - z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, 
- ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, - &alpha, &beta, &gamma, &delta, &theta, &threshold, - &dst_scale, stream, &src_zero); + execute_cutlass_conv_op( + op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, + args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, + pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, + &dst_scale, stream, &src_zero); after_kernel_launch(); } std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( AlgoParam algo_param) { - return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, - algo_param.threadblock_n, algo_param.threadblock_k, - algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, - algo_param.stage); + return ssprintf( + "%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, algo_param.threadblock_n, + algo_param.threadblock_k, algo_param.warp_m, algo_param.warp_n, + algo_param.warp_k, algo_param.stage); } void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( @@ -165,8 +159,8 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( // filter: KCRS64 => CRSK64 and reorder oc cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( reinterpret_cast(reordered_filter), - reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, - fw, true, stream); + reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, fw, + true, stream); } #endif diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp index cf53fd39..b25d2690 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp @@ -22,8 +22,7 @@ using namespace cuda; using namespace convolution; #if CUDA_VERSION >= 10020 -std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::param() - const { +std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::param() const { std::string ret; serialize_write_pod(m_algo_param, ret); return ret; @@ -87,9 +86,9 @@ bool ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::is_available( return false; bool use_conv_filter_unity_opt = (fh == 1 && fw == 1); - bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && - (m_algo_param.threadblock_n == 32 || - m_algo_param.threadblock_n == 64)); + bool without_shared_load = + ((co % m_algo_param.threadblock_n == 0) && + (m_algo_param.threadblock_n == 32 || m_algo_param.threadblock_n == 64)); const auto* op = get_cutlass_conv_op( args, ConvOperator::kFprop, ConvType::kConvolution, use_conv_filter_unity_opt, without_shared_load); @@ -103,12 +102,9 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( const ExecArgs& args) const { auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](3), - hi = args.src_layout->operator[](1), - wi = args.src_layout->operator[](2); - size_t co = args.dst_layout->operator[](3), - ho = args.dst_layout->operator[](1), + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](3), + hi = args.src_layout->operator[](1), wi = args.src_layout->operator[](2); + size_t co = args.dst_layout->operator[](3), ho = args.dst_layout->operator[](1), wo = args.dst_layout->operator[](2); UNPACK_CONV_PARAMETER(fm, param); MARK_USED_VAR @@ -132,18 +128,16 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( uint8_t src_zero = 0; bool use_conv_filter_unity_opt 
= (fh == 1 && fw == 1); - bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && - (m_algo_param.threadblock_n == 32 || - m_algo_param.threadblock_n == 64)); + bool without_shared_load = + ((co % m_algo_param.threadblock_n == 0) && + (m_algo_param.threadblock_n == 32 || m_algo_param.threadblock_n == 64)); if (args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - src_zero = args.src_layout->dtype.param() - .zero_point; + src_zero = args.src_layout->dtype.param().zero_point; } if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - dst_scale = - args.dst_layout->dtype.param().scale; + dst_scale = args.dst_layout->dtype.param().scale; } else if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { dst_scale = args.dst_layout->dtype.param().scale; } else { // DTypeEnum::QuantizedS8 @@ -152,30 +146,30 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( cudaStream_t stream = cuda_stream(args.opr->handle()); - const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, - ConvType::kConvolution, - use_conv_filter_unity_opt, without_shared_load); + const auto* op = get_cutlass_conv_op( + args, ConvOperator::kFprop, ConvType::kConvolution, + use_conv_filter_unity_opt, without_shared_load); - execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, - z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, - ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, - &alpha, &beta, &gamma, &delta, &theta, &threshold, - &dst_scale, stream, &src_zero); + execute_cutlass_conv_op( + op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, + args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, + pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, + &dst_scale, stream, &src_zero); after_kernel_launch(); } std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( AlgoParam algo_param) { - return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, - algo_param.threadblock_n, algo_param.threadblock_k, - algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, - algo_param.stage, algo_param.access_size); + return ssprintf( + "%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, + algo_param.threadblock_n, algo_param.threadblock_k, algo_param.warp_m, + algo_param.warp_n, algo_param.warp_k, algo_param.stage, + algo_param.access_size); } void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( - const ExecArgs& args, const int iterleaved, - void* reordered_filter) const { + const ExecArgs& args, const int iterleaved, void* reordered_filter) const { size_t co = args.filter_layout->operator[](0), ci = args.filter_layout->operator[](3), fh = args.filter_layout->operator[](1), @@ -185,17 +179,17 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( // reformat filter from nhwc to ncxhwx and reorder oc // use trans_oc threadblock_n must be 32 or 64 and src dtype == dest dtype - bool trans_oc = ((co % m_algo_param.threadblock_n == 0) && - (m_algo_param.threadblock_n == 32 || - m_algo_param.threadblock_n == 64)); + bool trans_oc = + ((co % m_algo_param.threadblock_n == 0) && + (m_algo_param.threadblock_n == 32 || m_algo_param.threadblock_n == 64)); uint32_t oc_iterleaved = (m_algo_param.threadblock_n == 64) ? 
64 : 32; uint32_t alignbits = iterleaved * 4; cutlass_wrapper::reorder_nhwc_imma_filter<4>( reinterpret_cast(reordered_filter), - reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, - fw, trans_oc, alignbits, oc_iterleaved, stream); + reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, fw, + trans_oc, alignbits, oc_iterleaved, stream); } #endif diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_dp4a.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_dp4a.cpp index d0c332a4..52f2e47b 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_dp4a.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_dp4a.cpp @@ -10,12 +10,12 @@ */ #include "./algo.h" +#include "src/common/conv_bias.h" #include "src/cuda/convolution_helper/bias_visitor.cuh" #include "src/cuda/convolution_helper/epilogue.cuh" #include "src/cuda/convolution_helper/layout.cuh" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/utils.h" -#include "src/common/conv_bias.h" using namespace megdnn; using namespace cuda; @@ -23,12 +23,13 @@ using namespace convolution; namespace { template -void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, Epilogue epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream) { - void (*kern_wrapper)(const int8_t*, const int8_t*, BiasVisitor, Epilogue, - const ConvParam&, float, float, cudaStream_t); +void dispatch_kernel( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + Epilogue epilogue, const ConvParam& param, float alpha, float beta, + cudaStream_t stream) { + void (*kern_wrapper)( + const int8_t*, const int8_t*, BiasVisitor, Epilogue, const ConvParam&, + float, float, cudaStream_t); using namespace conv_bias_int8; // for turing if (is_compute_capability_required(7, 5)) { @@ -41,19 +42,16 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width< BiasVisitor, Epilogue>; } else { - kern_wrapper = - do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit< - BiasVisitor, Epilogue>; + kern_wrapper = do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit< + BiasVisitor, Epilogue>; } } else { if (use_unroll_width) { - kern_wrapper = - do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width< - BiasVisitor, Epilogue>; + kern_wrapper = do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width< + BiasVisitor, Epilogue>; } else { - kern_wrapper = - do_conv_bias_int8_implicit_gemm_cdiv4hwn4; + kern_wrapper = do_conv_bias_int8_implicit_gemm_cdiv4hwn4< + BiasVisitor, Epilogue>; } } } else { // volta or lower @@ -62,20 +60,18 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, BiasVisitor, Epilogue>; } else { kern_wrapper = - do_conv_bias_int8_implicit_gemm_cdiv4hwn4; + do_conv_bias_int8_implicit_gemm_cdiv4hwn4; } } megdnn_assert(kern_wrapper != nullptr); - return kern_wrapper(d_src, d_filter, bias_visitor, epilogue, param, alpha, - beta, stream); + return kern_wrapper( + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream); } } // namespace bool ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::is_available( const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.bias_layout->ndim <= 0) @@ -88,26 +84,23 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::is_available( bool available = true; auto&& 
param = args.opr->param(); auto&& fm = args.filter_meta; - if (!check_bias_share_in_channel(*(args.bias_layout), - param.format)) + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; if (param.format != Format::CHWN4) return false; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, - bias_dtype = args.bias_layout->dtype, - dst_dtype = args.dst_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, + bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // TODO: support dialtion available &= dh == 1 && dw == 1; // only support sm_61 or later, platform should have fast native int8 @@ -116,8 +109,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::is_available( return available; } -size_t -ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::get_workspace_in_bytes( const SizeArgs& /* args */) const { return 0; } @@ -127,25 +119,20 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); auto&& stream = cuda_stream(args.opr->handle()); ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, + kern_param.pw = pw, kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, kern_param.fw = fw; float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; float gamma = 1.f; if (args.z_layout->ndim > 0) { @@ -157,36 +144,34 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::exec( bias_visitor.bias = args.bias_tensor->compatible_ptr(); dispatch_nonlinear_mode( args.src_tensor->compatible_ptr(), - 
args.filter_tensor->compatible_ptr(), bias_visitor, - z_dev_ptr, args.dst_tensor->compatible_ptr(), kern_param, - alpha, beta, gamma, dst_scale, stream, param.nonlineMode); + args.filter_tensor->compatible_ptr(), bias_visitor, z_dev_ptr, + args.dst_tensor->compatible_ptr(), kern_param, alpha, beta, gamma, + dst_scale, stream, param.nonlineMode); } template -void ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm:: - dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, const int8_t* d_z, - int8_t* d_dst, const ConvParam& param, - float alpha, float beta, float gamma, - float scale, cudaStream_t stream, - param::ConvBias::NonlineMode nonlinear_mode) { +void ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::dispatch_nonlinear_mode( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + const int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, + float beta, float gamma, float scale, cudaStream_t stream, + param::ConvBias::NonlineMode nonlinear_mode) { using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; Layout layout; layout.init(param.n, param.co, param.ho, param.wo); -#define DISPATCH_CONV_INT8_EPILOGUE(_act_op) \ - do { \ - IConvEpilogue<_act_op> epilogue{d_dst, \ - d_z, \ - layout.batch_stride, \ - layout.channel_stride / 4, \ - layout.height_stride, \ - layout.width_stride, \ - gamma, \ - _act_op{scale, 1.f / scale}}; \ - dispatch_kernel>( \ - d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, \ - stream); \ - return; \ +#define DISPATCH_CONV_INT8_EPILOGUE(_act_op) \ + do { \ + IConvEpilogue<_act_op> epilogue{ \ + d_dst, \ + d_z, \ + layout.batch_stride, \ + layout.channel_stride / 4, \ + layout.height_stride, \ + layout.width_stride, \ + gamma, \ + _act_op{scale, 1.f / scale}}; \ + dispatch_kernel>( \ + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream); \ + return; \ } while (0) #define cb(_nonline_mode) \ if (static_cast(nonlinear_mode) == NonlineMode::_nonline_mode) { \ @@ -198,13 +183,13 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm:: #undef DISPATCH_CONV_INT8_EPILOGUE } -#define INST(_visitor) \ - template void ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm:: \ - dispatch_nonlinear_mode<_visitor>( \ - const int8_t* d_src, const int8_t* d_filter, \ - _visitor bias_visitor, const int8_t* d_z, int8_t* d_dst, \ - const ConvParam& param, float alpha, float beta, \ - float gamma, float scale, cudaStream_t stream, \ +#define INST(_visitor) \ + template void ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm:: \ + dispatch_nonlinear_mode<_visitor>( \ + const int8_t* d_src, const int8_t* d_filter, \ + _visitor bias_visitor, const int8_t* d_z, int8_t* d_dst, \ + const ConvParam& param, float alpha, float beta, float gamma, \ + float scale, cudaStream_t stream, \ param::ConvBias::NonlineMode nonlinear_mode); INST(PerChannelBiasVisitor); diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma.cpp index 97abc14a..98598ebb 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma.cpp @@ -10,12 +10,12 @@ */ #include "./algo.h" +#include "src/common/conv_bias.h" #include "src/cuda/convolution_helper/bias_visitor.cuh" #include "src/cuda/convolution_helper/epilogue.cuh" #include "src/cuda/convolution_helper/layout.cuh" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/utils.h" -#include 
"src/common/conv_bias.h" using namespace megdnn; using namespace cuda; @@ -23,47 +23,42 @@ using namespace convolution; #if CUDA_VERSION >= 10000 namespace { -using MMATileSize = - ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize; +using MMATileSize = ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize; template -void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, Epilogue epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream, MMATileSize mma_tile_size) { - void (*kern_wrapper)(const int8_t*, const int8_t*, BiasVisitor, Epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream); +void dispatch_kernel( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + Epilogue epilogue, const ConvParam& param, float alpha, float beta, + cudaStream_t stream, MMATileSize mma_tile_size) { + void (*kern_wrapper)( + const int8_t*, const int8_t*, BiasVisitor, Epilogue, const ConvParam& param, + float alpha, float beta, cudaStream_t stream); using namespace conv_bias_int8; // for turing switch (mma_tile_size) { case MMATileSize::IMMA8x32x16: - kern_wrapper = - do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4< - BiasVisitor, Epilogue>; + kern_wrapper = do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4< + BiasVisitor, Epilogue>; break; case MMATileSize::IMMA32x8x16: - kern_wrapper = - do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4< - BiasVisitor, Epilogue>; + kern_wrapper = do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4< + BiasVisitor, Epilogue>; break; case MMATileSize::IMMA16x16x16: - kern_wrapper = - do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4< - BiasVisitor, Epilogue>; + kern_wrapper = do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4< + BiasVisitor, Epilogue>; break; default: megdnn_assert(false, "invalid mma tile size"); } - return kern_wrapper(d_src, d_filter, bias_visitor, epilogue, param, alpha, - beta, stream); + return kern_wrapper( + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream); } }; // namespace bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::is_available( const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.bias_layout->ndim <= 0) @@ -76,26 +71,23 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::is_available( bool available = true; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - if (!check_bias_share_in_channel(*(args.bias_layout), - param.format)) + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; if (param.format != Format::CHWN4) return false; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, - bias_dtype = args.bias_layout->dtype, - dst_dtype = args.dst_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + auto 
src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, + bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // check layout available &= (ci % 16 == 0); // TODO: support dialtion @@ -106,8 +98,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::is_available( return available; } -size_t -ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::get_workspace_in_bytes( const SizeArgs& /* args */) const { return 0; } @@ -117,25 +108,20 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); auto&& stream = cuda_stream(args.opr->handle()); ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, + kern_param.pw = pw, kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, kern_param.fw = fw; float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; float gamma = 1.f; if (args.z_layout->ndim > 0) { @@ -147,38 +133,35 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::exec( bias_visitor.bias = args.bias_tensor->compatible_ptr(); dispatch_nonlinear_mode( args.src_tensor->compatible_ptr(), - args.filter_tensor->compatible_ptr(), bias_visitor, - z_dev_ptr, args.dst_tensor->compatible_ptr(), kern_param, - alpha, beta, gamma, dst_scale, stream, param.nonlineMode, - m_mma_tile_size); + args.filter_tensor->compatible_ptr(), bias_visitor, z_dev_ptr, + args.dst_tensor->compatible_ptr(), kern_param, alpha, beta, gamma, + dst_scale, stream, param.nonlineMode, m_mma_tile_size); } template -void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm:: - dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, int8_t* d_z, - int8_t* d_dst, const ConvParam& param, - float alpha, float beta, float gamma, - float scale, cudaStream_t stream, - param::ConvBias::NonlineMode nonlinear_mode, - MMATileSize mma_tile_size) { +void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::dispatch_nonlinear_mode( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, float beta, + float gamma, float scale, cudaStream_t stream, + 
param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size) { using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; Layout layout; layout.init(param.n, param.co, param.ho, param.wo); -#define DISPATCH_CONV_IMMA_EPILOGUE(_act_op) \ - do { \ - IConvEpilogue<_act_op> epilogue{d_dst, \ - d_z, \ - layout.batch_stride, \ - layout.channel_stride / 4, \ - layout.height_stride, \ - layout.width_stride, \ - gamma, \ - _act_op{scale, 1.f / scale}}; \ - dispatch_kernel>( \ - d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, \ - stream, mma_tile_size); \ - return; \ +#define DISPATCH_CONV_IMMA_EPILOGUE(_act_op) \ + do { \ + IConvEpilogue<_act_op> epilogue{ \ + d_dst, \ + d_z, \ + layout.batch_stride, \ + layout.channel_stride / 4, \ + layout.height_stride, \ + layout.width_stride, \ + gamma, \ + _act_op{scale, 1.f / scale}}; \ + dispatch_kernel>( \ + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream, \ + mma_tile_size); \ + return; \ } while (0) #define cb(_nonline_mode) \ if (static_cast(nonlinear_mode) == NonlineMode::_nonline_mode) { \ @@ -190,15 +173,14 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm:: #undef DISPATCH_CONV_IMMA_EPILOGUE } -#define INST(_visitor) \ - template void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm:: \ - dispatch_nonlinear_mode<_visitor>( \ - const int8_t* d_src, const int8_t* d_filter, \ - _visitor bias_visitor, int8_t* d_z, int8_t* d_dst, \ - const ConvParam& param, float alpha, float beta, \ - float gamma, float scale, cudaStream_t stream, \ - param::ConvBias::NonlineMode nonlinear_mode, \ - MMATileSize mma_tile_size); +#define INST(_visitor) \ + template void \ + ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::dispatch_nonlinear_mode< \ + _visitor>( \ + const int8_t* d_src, const int8_t* d_filter, _visitor bias_visitor, \ + int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, \ + float beta, float gamma, float scale, cudaStream_t stream, \ + param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size); INST(PerChannelBiasVisitor); diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_reorder_filter.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_reorder_filter.cpp index 687dbbd9..a925efb9 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_reorder_filter.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_reorder_filter.cpp @@ -10,12 +10,12 @@ */ #include "./algo.h" +#include "src/common/conv_bias.h" #include "src/cuda/convolution_helper/bias_visitor.cuh" #include "src/cuda/convolution_helper/epilogue.cuh" #include "src/cuda/convolution_helper/layout.cuh" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/utils.h" -#include "src/common/conv_bias.h" using namespace megdnn; using namespace cuda; @@ -23,17 +23,16 @@ using namespace convolution; #if CUDA_VERSION >= 10000 namespace { -using MMATileSize = - ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize; +using MMATileSize = ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize; template -void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, Epilogue epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream, MMATileSize mma_tile_size) { - void (*kern_wrapper)(const int8_t*, const int8_t*, BiasVisitor, Epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream); +void dispatch_kernel( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor 
bias_visitor, + Epilogue epilogue, const ConvParam& param, float alpha, float beta, + cudaStream_t stream, MMATileSize mma_tile_size) { + void (*kern_wrapper)( + const int8_t*, const int8_t*, BiasVisitor, Epilogue, const ConvParam& param, + float alpha, float beta, cudaStream_t stream); using namespace conv_bias_int8; switch (mma_tile_size) { case MMATileSize::IMMA8x32x16: @@ -54,35 +53,34 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, default: megdnn_assert(false, "invalid mma tile size"); } - return kern_wrapper(d_src, d_filter, bias_visitor, epilogue, param, alpha, - beta, stream); + return kern_wrapper( + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream); } template -void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, int8_t* d_z, - int8_t* d_dst, const ConvParam& param, float alpha, - float beta, float gamma, float scale, - cudaStream_t stream, - param::ConvBias::NonlineMode nonlinear_mode, - MMATileSize mma_tile_size) { +void dispatch_nonlinear_mode( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, float beta, + float gamma, float scale, cudaStream_t stream, + param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size) { using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; Layout layout; layout.init(param.n, param.co, param.ho, param.wo); -#define DISPATCH_CONV_IMMA_EPILOGUE(_act_op) \ - do { \ - IConvEpilogue<_act_op> epilogue{d_dst, \ - d_z, \ - layout.batch_stride, \ - layout.channel_stride / 4, \ - layout.height_stride, \ - layout.width_stride, \ - gamma, \ - _act_op{scale, 1.f / scale}}; \ - dispatch_kernel>( \ - d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, \ - stream, mma_tile_size); \ - return; \ +#define DISPATCH_CONV_IMMA_EPILOGUE(_act_op) \ + do { \ + IConvEpilogue<_act_op> epilogue{ \ + d_dst, \ + d_z, \ + layout.batch_stride, \ + layout.channel_stride / 4, \ + layout.height_stride, \ + layout.width_stride, \ + gamma, \ + _act_op{scale, 1.f / scale}}; \ + dispatch_kernel>( \ + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream, \ + mma_tile_size); \ + return; \ } while (0) #define cb(_nonline_mode) \ if (static_cast(nonlinear_mode) == NonlineMode::_nonline_mode) { \ @@ -94,23 +92,20 @@ void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, #undef DISPATCH_CONV_IMMA_EPILOGUE } -#define INST(_visitor) \ - template void dispatch_nonlinear_mode<_visitor>( \ - const int8_t* d_src, const int8_t* d_filter, \ - _visitor bias_visitor, int8_t* d_z, int8_t* d_dst, \ - const ConvParam& param, float alpha, float beta, float gamma, \ - float scale, cudaStream_t stream, \ - param::ConvBias::NonlineMode nonlinear_mode, \ - MMATileSize mma_tile_size); +#define INST(_visitor) \ + template void dispatch_nonlinear_mode<_visitor>( \ + const int8_t* d_src, const int8_t* d_filter, _visitor bias_visitor, \ + int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, \ + float beta, float gamma, float scale, cudaStream_t stream, \ + param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size); INST(PerChannelBiasVisitor); }; // namespace -bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter:: - is_available(const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { +bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::is_available( + const 
SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.bias_layout->ndim <= 0) @@ -123,26 +118,23 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter:: bool available = true; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - if (!check_bias_share_in_channel(*(args.bias_layout), - param.format)) + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; if (param.format != Format::CHWN4) return false; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, - bias_dtype = args.bias_layout->dtype, - dst_dtype = args.dst_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, + bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // check layout available &= (ci % 16 == 0); // TODO: support dialtion @@ -163,8 +155,7 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); // reorder filter { TensorLayout in = *(args.filter_layout); @@ -179,27 +170,22 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::exec( ts_in.layout = in, ts_out.layout = out; ts_in.raw_ptr = args.filter_tensor->raw_ptr, ts_out.raw_ptr = args.workspace.raw_ptr; - args.opr->handle()->create_operator()->exec(ts_in, - ts_out); + args.opr->handle()->create_operator()->exec(ts_in, ts_out); } auto&& stream = cuda_stream(args.opr->handle()); ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, + kern_param.pw = pw, kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, kern_param.fw = fw; float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = 
bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; float gamma = 1.f; if (args.z_layout->ndim > 0) { @@ -211,10 +197,9 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::exec( bias_visitor.bias = args.bias_tensor->compatible_ptr(); dispatch_nonlinear_mode( args.src_tensor->compatible_ptr(), - reinterpret_cast(args.workspace.raw_ptr), bias_visitor, - z_dev_ptr, args.dst_tensor->compatible_ptr(), kern_param, - alpha, beta, gamma, dst_scale, stream, param.nonlineMode, - m_mma_tile_size); + reinterpret_cast(args.workspace.raw_ptr), bias_visitor, z_dev_ptr, + args.dst_tensor->compatible_ptr(), kern_param, alpha, beta, gamma, + dst_scale, stream, param.nonlineMode, m_mma_tile_size); } #undef INST diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_unroll_width.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_unroll_width.cpp index a61e89cb..ac4666c7 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_unroll_width.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_chwn4_imma_unroll_width.cpp @@ -9,13 +9,13 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "src/common/conv_bias.h" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/convolution_helper/bias_visitor.cuh" #include "src/cuda/convolution_helper/epilogue.cuh" #include "src/cuda/convolution_helper/layout.cuh" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/utils.h" -#include "src/common/conv_bias.h" using namespace megdnn; using namespace cuda; @@ -23,17 +23,16 @@ using namespace convolution; #if CUDA_VERSION >= 10000 namespace { -using MMATileSize = - ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize; +using MMATileSize = ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::MMATileSize; template -void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, Epilogue epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream, MMATileSize mma_tile_size) { - void (*kern_wrapper)(const int8_t*, const int8_t*, BiasVisitor, Epilogue, - const ConvParam& param, float alpha, float beta, - cudaStream_t stream); +void dispatch_kernel( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + Epilogue epilogue, const ConvParam& param, float alpha, float beta, + cudaStream_t stream, MMATileSize mma_tile_size) { + void (*kern_wrapper)( + const int8_t*, const int8_t*, BiasVisitor, Epilogue, const ConvParam& param, + float alpha, float beta, cudaStream_t stream); using namespace conv_bias_int8; switch (mma_tile_size) { case MMATileSize::IMMA8x32x16: @@ -54,35 +53,34 @@ void dispatch_kernel(const int8_t* d_src, const int8_t* d_filter, default: megdnn_assert(false, "invalid mma tile size"); } - return kern_wrapper(d_src, d_filter, bias_visitor, epilogue, param, alpha, - beta, stream); + return kern_wrapper( + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream); } template -void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, - BiasVisitor bias_visitor, int8_t* d_z, - int8_t* d_dst, const ConvParam& param, float alpha, - float beta, float gamma, float scale, - cudaStream_t stream, - param::ConvBias::NonlineMode nonlinear_mode, - MMATileSize mma_tile_size) { +void dispatch_nonlinear_mode( + const int8_t* d_src, const int8_t* d_filter, BiasVisitor bias_visitor, + int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, float beta, + float gamma, float scale, cudaStream_t stream, + 
param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size) { using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; Layout layout; layout.init(param.n, param.co, param.ho, param.wo); -#define DISPATCH_CONV_IMMA_EPILOGUE(_act_op) \ - do { \ - IConvEpilogue<_act_op> epilogue{d_dst, \ - d_z, \ - layout.batch_stride, \ - layout.channel_stride / 4, \ - layout.height_stride, \ - layout.width_stride, \ - gamma, \ - _act_op{scale, 1.f / scale}}; \ - dispatch_kernel>( \ - d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, \ - stream, mma_tile_size); \ - return; \ +#define DISPATCH_CONV_IMMA_EPILOGUE(_act_op) \ + do { \ + IConvEpilogue<_act_op> epilogue{ \ + d_dst, \ + d_z, \ + layout.batch_stride, \ + layout.channel_stride / 4, \ + layout.height_stride, \ + layout.width_stride, \ + gamma, \ + _act_op{scale, 1.f / scale}}; \ + dispatch_kernel>( \ + d_src, d_filter, bias_visitor, epilogue, param, alpha, beta, stream, \ + mma_tile_size); \ + return; \ } while (0) #define cb(_nonline_mode) \ if (static_cast(nonlinear_mode) == NonlineMode::_nonline_mode) { \ @@ -94,23 +92,20 @@ void dispatch_nonlinear_mode(const int8_t* d_src, const int8_t* d_filter, #undef DISPATCH_CONV_IMMA_EPILOGUE } -#define INST(_visitor) \ - template void dispatch_nonlinear_mode<_visitor>( \ - const int8_t* d_src, const int8_t* d_filter, \ - _visitor bias_visitor, int8_t* d_z, int8_t* d_dst, \ - const ConvParam& param, float alpha, float beta, float gamma, \ - float scale, cudaStream_t stream, \ - param::ConvBias::NonlineMode nonlinear_mode, \ - MMATileSize mma_tile_size); +#define INST(_visitor) \ + template void dispatch_nonlinear_mode<_visitor>( \ + const int8_t* d_src, const int8_t* d_filter, _visitor bias_visitor, \ + int8_t* d_z, int8_t* d_dst, const ConvParam& param, float alpha, \ + float beta, float gamma, float scale, cudaStream_t stream, \ + param::ConvBias::NonlineMode nonlinear_mode, MMATileSize mma_tile_size); INST(PerChannelBiasVisitor); }; // namespace -bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth:: - is_available(const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { +bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::is_available( + const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.bias_layout->ndim <= 0) @@ -123,26 +118,23 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth:: bool available = true; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - if (!check_bias_share_in_channel(*(args.bias_layout), - param.format)) + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; if (param.format != Format::CHWN4) return false; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, - bias_dtype = args.bias_layout->dtype, - dst_dtype = args.dst_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + auto 
src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, + bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // check batch size available &= (n % 4 == 0); // check layout @@ -165,8 +157,7 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); // reorder filter { TensorLayout in = *(args.filter_layout); @@ -181,27 +172,22 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::exec( ts_in.layout = in, ts_out.layout = out; ts_in.raw_ptr = args.filter_tensor->raw_ptr, ts_out.raw_ptr = args.workspace.raw_ptr; - args.opr->handle()->create_operator()->exec(ts_in, - ts_out); + args.opr->handle()->create_operator()->exec(ts_in, ts_out); } auto&& stream = cuda_stream(args.opr->handle()); ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, + kern_param.pw = pw, kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, kern_param.fw = fw; float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; float gamma = 1.f; if (args.z_layout->ndim > 0) { @@ -213,10 +199,9 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::exec( bias_visitor.bias = args.bias_tensor->compatible_ptr(); dispatch_nonlinear_mode( args.src_tensor->compatible_ptr(), - reinterpret_cast(args.workspace.raw_ptr), bias_visitor, - z_dev_ptr, args.dst_tensor->compatible_ptr(), kern_param, - alpha, beta, gamma, dst_scale, stream, param.nonlineMode, - m_mma_tile_size); + reinterpret_cast(args.workspace.raw_ptr), bias_visitor, z_dev_ptr, + args.dst_tensor->compatible_ptr(), kern_param, alpha, beta, gamma, + dst_scale, stream, param.nonlineMode, m_mma_tile_size); } #undef INST diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp index f182f00b..2bfdb0ea 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp @@ -23,8 +23,7 @@ using namespace convolution; #if CUDA_VERSION >= 10020 bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if 
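The exec paths reflowed above derive alpha = src_scale * filter_scale / dst_scale and beta = bias_scale / dst_scale so that the int32 accumulator and int32 bias can be rescaled to the output dtype in one step. A small numeric sketch of that requantization, using assumed example scales that do not come from the diff:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Quantized value = scale * integer, so for a conv accumulator
//   acc_real  = src_scale * filter_scale * acc_i32
//   bias_real = bias_scale * bias_i32
// and the requantized int8 output is (acc_real + bias_real) / dst_scale.
int8_t requantize(int32_t acc, int32_t bias, float src_scale, float filter_scale,
                  float bias_scale, float dst_scale) {
    float alpha = src_scale * filter_scale / dst_scale;  // accumulator term
    float beta = bias_scale / dst_scale;                 // bias term
    float v = alpha * acc + beta * bias;
    v = std::max(-128.f, std::min(127.f, std::round(v)));
    return static_cast<int8_t>(v);
}

int main() {
    // bias_scale is typically src_scale * filter_scale for quantized conv bias
    std::printf("%d\n", requantize(/*acc=*/1000, /*bias=*/-300,
                                   0.02f, 0.01f, 0.02f * 0.01f, 0.05f));
}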
(!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.bias_layout->ndim <= 0) @@ -41,12 +40,9 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( return false; if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) return false; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](1) * 32, - hi = args.src_layout->operator[](2), - wi = args.src_layout->operator[](3); - size_t ho = args.dst_layout->operator[](2), - wo = args.dst_layout->operator[](3); + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 32, + hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); + size_t ho = args.dst_layout->operator[](2), wo = args.dst_layout->operator[](3); size_t co; if (param.format == Format::NCHW32) { co = args.dst_layout->operator[](1) * 32; @@ -61,14 +57,13 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, - bias_dtype = args.bias_layout->dtype, - dst_dtype = args.dst_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, + bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // TODO: support dialtion available &= dh == 1 && dw == 1; // only support sm_75 or later, platform should have tensorcore int8 @@ -88,9 +83,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( return available; } -WorkspaceBundle -ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::get_workspace_bundle( - dt_byte* raw_ptr, const SizeArgs& args) const { +WorkspaceBundle ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: + get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const { if (args.preprocessed_filter) { return WorkspaceBundle{raw_ptr, {}}; } else { @@ -99,8 +93,7 @@ ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::get_workspace_bundle( } } -size_t -ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -110,12 +103,9 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](1) * 32, - hi = args.src_layout->operator[](2), - wi = args.src_layout->operator[](3); - size_t ho = args.dst_layout->operator[](2), - wo = args.dst_layout->operator[](3); + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 32, + hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); + size_t ho = args.dst_layout->operator[](2), wo = args.dst_layout->operator[](3); size_t co; bool trans_oc; if (param.format == 
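get_workspace_bundle, whose definition is rewrapped above, returns an empty bundle when a preprocessed filter is supplied and otherwise reserves room for the reordered filter; get_workspace_in_bytes simply sums the bundle. A simplified sketch of that contract with stand-in types, not the real WorkspaceBundle/SizeArgs:

#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

struct SizeArgs {
    bool has_preprocessed_filter;
    size_t filter_bytes;
};

struct WorkspaceBundle {
    std::vector<size_t> sizes;
    size_t total_size_in_bytes() const {
        return std::accumulate(sizes.begin(), sizes.end(), size_t(0));
    }
};

WorkspaceBundle get_workspace_bundle(const SizeArgs& args) {
    if (args.has_preprocessed_filter) {
        // filter already reordered offline: no runtime scratch needed
        return WorkspaceBundle{{}};
    }
    // otherwise reserve one chunk to hold the reordered filter
    return WorkspaceBundle{{args.filter_bytes}};
}

size_t get_workspace_in_bytes(const SizeArgs& args) {
    return get_workspace_bundle(args).total_size_in_bytes();
}

int main() {
    std::printf("%zu %zu\n", get_workspace_in_bytes({true, 4096}),
                get_workspace_in_bytes({false, 4096}));
}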
Format::NCHW32) { @@ -135,26 +125,22 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( filter_ptr = reinterpret_cast(args.workspace.raw_ptr); // filter: KCRS32 => CRSK32 and reorder oc cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( - filter_ptr, - reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, - fh, fw, trans_oc, stream); + filter_ptr, reinterpret_cast(args.filter_tensor->raw_ptr), co, + ci, fh, fw, trans_oc, stream); } else { - filter_ptr = reinterpret_cast( - args.preprocessed_filter->tensors[0].raw_ptr); + filter_ptr = + reinterpret_cast(args.preprocessed_filter->tensors[0].raw_ptr); } float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; // \note these constants of cutlass epilogue will be passed to method // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, // a different dtype here results in undefined epilogue behaviors - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; int8_t* z_dev_ptr = nullptr; float gamma = 0.0; if (args.z_layout->ndim > 0) { @@ -166,25 +152,25 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( bool use_conv_filter_unity_opt = (fh == 1 && fw == 1); bool without_shared_load = (param.format == Format::NCHW32); - const auto* op = get_cutlass_conv_op(args, ConvOperator::kFprop, - ConvType::kConvolution, - use_conv_filter_unity_opt, without_shared_load); + const auto* op = get_cutlass_conv_op( + args, ConvOperator::kFprop, ConvType::kConvolution, + use_conv_filter_unity_opt, without_shared_load); execute_cutlass_conv_op( op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, - z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, - fw, ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, - &theta, &threshold, &dst_scale, stream); + z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, + wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, + &threshold, &dst_scale, stream); after_kernel_launch(); } std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( AlgoParam algo_param) { - return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m, - algo_param.threadblock_n, algo_param.threadblock_k, - algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, - algo_param.stage); + return ssprintf( + "%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m, algo_param.threadblock_n, + algo_param.threadblock_k, algo_param.warp_m, algo_param.warp_n, + algo_param.warp_k, algo_param.stage); } size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm:: @@ -218,10 +204,9 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( cudaStream_t stream = cuda_stream(args.opr->handle()); // filter: KCRS32 => CRSK32 and reorder oc cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( - reinterpret_cast( - args.preprocessed_filter->tensors[0].raw_ptr), - reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, - fw, trans_oc, stream); + reinterpret_cast(args.preprocessed_filter->tensors[0].raw_ptr), + reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, fw, + trans_oc, stream); } #endif diff --git 
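The note kept in the hunk above warns that the epilogue constants are handed to execute_cutlass_conv_op by pointer and reinterpreted as ElementCompute*, so declaring them with a different dtype on the caller side is undefined behavior. A minimal sketch of why the declared type must match the interpreted type when a constant crosses an opaque-pointer boundary:

#include <cstdio>

// the callee reinterprets an opaque pointer as ElementCompute* (float here),
// so the caller must really have stored a float behind that pointer
using ElementCompute = float;

void consume_epilogue_constant(const void* alpha_ptr) {
    float alpha = *static_cast<const ElementCompute*>(alpha_ptr);
    std::printf("alpha = %f\n", alpha);
}

int main() {
    float alpha_ok = 0.125f;       // matches ElementCompute: well defined
    consume_epilogue_constant(&alpha_ok);

    // double alpha_bad = 0.125;   // would be read as a float: undefined behavior
    // consume_epilogue_constant(&alpha_bad);
}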
a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp index 83a0c341..e9ee4822 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp @@ -20,8 +20,7 @@ using namespace cuda; bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.bias_layout->ndim <= 0) @@ -41,18 +40,15 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( valid_format |= param.format == Format::NCHW4_NCHW && args.bias_layout->dtype.enumv() == DTypeEnum::Float32 && args.dst_layout->dtype.enumv() == DTypeEnum::Float32; - valid_format |= - param.format == Format::NCHW4_NHWC && - args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32 && - (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || - args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); + valid_format |= param.format == Format::NCHW4_NHWC && + args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32 && + (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || + args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); valid_format |= param.format == Format::NCHW4; if (!valid_format) return false; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](1) * 4, - hi = args.src_layout->operator[](2), - wi = args.src_layout->operator[](3); + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 4, + hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); size_t co; size_t dst_spatial_pos; if (param.format == Format::NCHW4) { @@ -78,12 +74,11 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, - bias_dtype = args.bias_layout->dtype, - dst_dtype = args.dst_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8); + auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, + bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8); available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 && (dst_dtype.enumv() == DTypeEnum::QuantizedS8 || dst_dtype.enumv() == DTypeEnum::QuantizedS4 || @@ -109,9 +104,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( return available; } -WorkspaceBundle -ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_workspace_bundle( - dt_byte* raw_ptr, const SizeArgs& args) const { +WorkspaceBundle ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm:: + get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const { if (args.preprocessed_filter) { return WorkspaceBundle{raw_ptr, {}}; } else { @@ -120,8 +114,7 @@ ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_workspace_bundle( } } -size_t -ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_workspace_in_bytes( const SizeArgs& 
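The is_available checks reflowed above accumulate a valid_format flag: NCHW4_NCHW only pairs with Float32 bias/dst, NCHW4_NHWC with QuantizedS32 bias and 4-bit outputs, and plain NCHW4 is accepted and dtype-checked later. A standalone predicate sketch of that chain, using stand-in enums rather than megdnn's:

#include <cstdio>

enum class Format { NCHW4, NCHW4_NCHW, NCHW4_NHWC };
enum class DType { Float32, QuantizedS32, QuantizedS8, QuantizedS4, Quantized4Asymm };

// mirrors the chain of valid_format |= ... checks above
bool format_dtype_compatible(Format fmt, DType bias, DType dst) {
    bool ok = false;
    ok |= fmt == Format::NCHW4_NCHW && bias == DType::Float32 && dst == DType::Float32;
    ok |= fmt == Format::NCHW4_NHWC && bias == DType::QuantizedS32 &&
          (dst == DType::QuantizedS4 || dst == DType::Quantized4Asymm);
    ok |= fmt == Format::NCHW4;  // plain NCHW4 is checked against dtypes separately
    return ok;
}

int main() {
    std::printf("%d %d\n",
                format_dtype_compatible(Format::NCHW4_NHWC, DType::QuantizedS32,
                                        DType::QuantizedS4),
                format_dtype_compatible(Format::NCHW4_NCHW, DType::QuantizedS32,
                                        DType::QuantizedS8));
}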
args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -131,10 +124,8 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](1) * 4, - hi = args.src_layout->operator[](2), - wi = args.src_layout->operator[](3); + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 4, + hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); size_t co, dst_spatial_pos; if (param.format == Format::NCHW4) { co = args.dst_layout->operator[](1) * 4; @@ -169,17 +160,15 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ts_src.layout = src; ts_dst.raw_ptr = args.workspace.raw_ptr; ts_dst.layout = dst; - auto&& transpose = - args.opr->handle()->create_operator(); + auto&& transpose = args.opr->handle()->create_operator(); transpose->exec(ts_src, ts_dst); } else { - filter_ptr = reinterpret_cast( - args.preprocessed_filter->tensors[0].raw_ptr); + filter_ptr = + reinterpret_cast(args.preprocessed_filter->tensors[0].raw_ptr); } float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale; + filter_scale = args.filter_layout->dtype.param().scale; // \note these constants of cutlass epilogue will be passed to method // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, @@ -190,14 +179,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( float gamma = 0.f; float theta = 0.f; if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - theta = args.dst_layout->dtype.param() - .zero_point; + theta = args.dst_layout->dtype.param().zero_point; } if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { - megdnn_assert(args.dst_layout->dtype.category() == - DTypeCategory::QUANTIZED); - float bias_scale = - args.bias_layout->dtype.param().scale; + megdnn_assert(args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED); + float bias_scale = args.bias_layout->dtype.param().scale; dst_scale = get_scale(args.dst_layout->dtype); alpha /= dst_scale, beta = bias_scale / dst_scale; } @@ -207,15 +193,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( z_ptr = args.z_tensor->raw_ptr; gamma = 1.f; if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { - megdnn_assert(args.dst_layout->dtype.category() == - DTypeCategory::QUANTIZED); + megdnn_assert( + args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED); float z_scale = get_scale(args.z_layout->dtype); gamma = z_scale / dst_scale; } if (args.z_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { uint8_t z_zero = - args.z_layout->dtype.param() - .zero_point; + args.z_layout->dtype.param().zero_point; delta = -z_zero * gamma; } } @@ -228,10 +213,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( use_conv_filter_unity_opt, without_shared_load); execute_cutlass_conv_op( - op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, - z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, - ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, - &theta, &threshold, &dst_scale, stream); + op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, z_ptr, + args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, + pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, 
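The residual-input handling rewrapped above sets gamma = z_scale / dst_scale, and for a Quantized4Asymm z also delta = -z_zero * gamma, while theta carries the destination zero point. A small sketch of how those three constants are derived, with assumed example values:

#include <cstdio>

// Constants derived for the residual (z) input:
//   gamma folds z's scale into the destination scale,
//   delta cancels z's zero point,
//   theta re-adds the destination zero point for asymmetric outputs.
struct EpilogueConstants {
    float gamma = 0.f, delta = 0.f, theta = 0.f;
};

EpilogueConstants make_z_constants(bool has_z, float z_scale, int z_zero,
                                   float dst_scale, int dst_zero) {
    EpilogueConstants c;
    c.theta = static_cast<float>(dst_zero);
    if (has_z) {
        c.gamma = z_scale / dst_scale;
        c.delta = -static_cast<float>(z_zero) * c.gamma;  // remove z's zero point
    }
    return c;
}

int main() {
    auto c = make_z_constants(true, /*z_scale=*/0.1f, /*z_zero=*/8,
                              /*dst_scale=*/0.05f, /*dst_zero=*/8);
    std::printf("gamma=%f delta=%f theta=%f\n", c.gamma, c.delta, c.theta);
}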
&threshold, + &dst_scale, stream); after_kernel_launch(); } @@ -241,9 +226,8 @@ size_t ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm:: return 0_z; } -SmallVector ConvBiasForwardImpl:: - AlgoInt8NCHW4DotProdImplicitGemm::deduce_preprocessed_filter_layout( - const SizeArgs& args) const { +SmallVector ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm:: + deduce_preprocessed_filter_layout(const SizeArgs& args) const { return {args.filter_layout->collapse_contiguous()}; } @@ -252,10 +236,8 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](1) * 4, - hi = args.src_layout->operator[](2), - wi = args.src_layout->operator[](3); + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 4, + hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); size_t co, dst_spatial_pos; if (param.format == Format::NCHW4) { co = args.dst_layout->operator[](1) * 4; diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_imma.cpp index f1aaa2fb..f6a89f81 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_imma.cpp @@ -10,9 +10,9 @@ */ #include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/convolution_helper/bias_visitor.cuh" #include "src/common/conv_bias.h" +#include "src/cuda/convolution_helper/bias_visitor.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -20,8 +20,7 @@ using namespace cuda; #if CUDA_VERSION >= 10000 bool ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::is_available( const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.bias_layout->ndim <= 0) @@ -34,26 +33,23 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::is_available( bool available = true; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - if (!check_bias_share_in_channel(*(args.bias_layout), - param.format)) + if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; if (param.format != Format::NCHW4) return false; - UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation available &= param.mode == Mode::CROSS_CORRELATION; // check data type - auto src_dtype = args.src_layout->dtype, - filter_dtype = args.filter_layout->dtype, - bias_dtype = args.bias_layout->dtype, - dst_dtype = args.dst_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype, + bias_dtype = args.bias_layout->dtype, dst_dtype = args.dst_layout->dtype; + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + bias_dtype.enumv() == DTypeEnum::QuantizedS32 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); 
// check layout available &= (ci % 16 == 0); // TODO: support dialtion @@ -64,9 +60,8 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::is_available( return available; } -WorkspaceBundle -ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::get_workspace_bundle( - dt_byte* raw_ptr, const SizeArgs& args) const { +WorkspaceBundle ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm:: + get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const { size_t ws_size_src = args.src_layout->span().dist_byte(); size_t ws_size_filter = args.filter_layout->span().dist_byte(); size_t ws_size_dst = args.dst_layout->span().dist_byte(); @@ -78,8 +73,7 @@ ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::get_workspace_bundle( return WorkspaceBundle{raw_ptr, {ws_size_src, ws_size_filter, ws_size_dst}}; } -size_t -ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -89,8 +83,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + UNPACK_CONV_BIAS_NCHW4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); auto ws_src = ws.get(0); auto ws_filter = ws.get(1); @@ -108,11 +101,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( ts_src.layout = src; ts_dst.raw_ptr = ws_src; ts_dst.layout = dst; - auto&& transpose = - args.opr->handle()->create_operator(); + auto&& transpose = args.opr->handle()->create_operator(); transpose->exec(ts_src, ts_dst); } - + // reformat filter from nchw4 to chwn4 { TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()}; @@ -124,26 +116,21 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( ts_src.layout = src; ts_dst.raw_ptr = ws_filter; ts_dst.layout = dst; - auto&& transpose = - args.opr->handle()->create_operator(); + auto&& transpose = args.opr->handle()->create_operator(); transpose->exec(ts_src, ts_dst); } convolution::ConvParam kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ho = ho, - kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, - kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, + kern_param.pw = pw, kern_param.sh = sh, kern_param.sw = sw, kern_param.fh = fh, kern_param.fw = fw; float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale = args.dst_layout->dtype.param().scale; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale; // process z int8_t* z_dev_ptr = nullptr; @@ -160,8 +147,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( ts_src.layout = src; ts_dst.raw_ptr = ws_z; ts_dst.layout = dst; - auto&& transpose = - 
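The NCHW4-to-CHWN4 reformats above view each group of four int8 channel values as one int32 element and then perform a plain 2-D transpose through a relayout operator. A CPU-side sketch of the same idea (the real code delegates the transpose to a relayout opr on the GPU stream):

#include <cstdint>
#include <cstdio>
#include <vector>

// View NCHW4 data as an int32 matrix of shape (N, C/4*H*W): each int32 packs
// four int8 values along the innermost channel axis. Transposing that matrix
// yields CHWN4, shape (C/4*H*W, N), with each packed quadruple kept intact.
std::vector<int32_t> nchw4_to_chwn4(const std::vector<int32_t>& src, size_t n,
                                    size_t chw4 /* = C/4 * H * W */) {
    std::vector<int32_t> dst(src.size());
    for (size_t i = 0; i < n; ++i)
        for (size_t j = 0; j < chw4; ++j)
            dst[j * n + i] = src[i * chw4 + j];
    return dst;
}

int main() {
    // 2 images, 3 packed elements each (values are arbitrary)
    std::vector<int32_t> src = {1, 2, 3, 4, 5, 6};
    auto dst = nchw4_to_chwn4(src, 2, 3);
    for (int32_t v : dst) std::printf("%d ", v);
    std::printf("\n");
}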
args.opr->handle()->create_operator(); + auto&& transpose = args.opr->handle()->create_operator(); transpose->exec(ts_src, ts_dst); z_dev_ptr = reinterpret_cast(ws_z); float z_scale = args.z_layout->dtype.param().scale; @@ -172,10 +158,9 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( bias_visitor.bias = args.bias_tensor->compatible_ptr(); ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::dispatch_nonlinear_mode< convolution::PerChannelBiasVisitor>( - reinterpret_cast(ws_src), - reinterpret_cast(ws_filter), bias_visitor, z_dev_ptr, - reinterpret_cast(ws_dst), kern_param, alpha, beta, gamma, - dst_scale, stream, param.nonlineMode, m_mma_tile_size); + reinterpret_cast(ws_src), reinterpret_cast(ws_filter), + bias_visitor, z_dev_ptr, reinterpret_cast(ws_dst), kern_param, + alpha, beta, gamma, dst_scale, stream, param.nonlineMode, m_mma_tile_size); // reformat chwn4 to nchw4 { @@ -188,8 +173,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( ts_src.layout = src; ts_dst.raw_ptr = args.dst_tensor->raw_ptr; ts_dst.layout = dst; - auto&& transpose = - args.opr->handle()->create_operator(); + auto&& transpose = args.opr->handle()->create_operator(); transpose->exec(ts_src, ts_dst); } } diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nhwc_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nhwc_imma.cpp index 8e35d1e9..3efa1c05 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nhwc_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nhwc_imma.cpp @@ -52,10 +52,8 @@ bool ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available( auto dst_dtype = args.dst_layout->dtype.enumv(); - if (!(dst_dtype == DTypeEnum::QuantizedS8 || - dst_dtype == DTypeEnum::QuantizedS4 || - dst_dtype == DTypeEnum::Quantized4Asymm || - dst_dtype == DTypeEnum::Float32)) + if (!(dst_dtype == DTypeEnum::QuantizedS8 || dst_dtype == DTypeEnum::QuantizedS4 || + dst_dtype == DTypeEnum::Quantized4Asymm || dst_dtype == DTypeEnum::Float32)) return false; if (!(args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32 || @@ -82,10 +80,10 @@ bool ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available( return false; bool use_conv_filter_unity_opt = (fh == 1 && fw == 1); - bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && - (m_algo_param.threadblock_n == 16 || - (m_algo_param.threadblock_n == 32 && - dst_dtype != DTypeEnum::Float32))); + bool without_shared_load = + ((co % m_algo_param.threadblock_n == 0) && + (m_algo_param.threadblock_n == 16 || + (m_algo_param.threadblock_n == 32 && dst_dtype != DTypeEnum::Float32))); const auto* op = get_cutlass_conv_op( args, ConvOperator::kFprop, ConvType::kConvolution, use_conv_filter_unity_opt, without_shared_load); @@ -95,8 +93,7 @@ bool ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available( return true; } -size_t -ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { if (args.preprocessed_filter) { return 0; @@ -121,12 +118,10 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec_preprocess( reorder_filter(args, m_algo_param.access_size, filter_ptr); } -std::tuple -ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::get_constants( - const ExecArgs& args) const { +std::tuple ConvBiasForwardImpl:: + AlgoInt8NHWCIMMAImplicitGemm::get_constants(const ExecArgs& args) const { float src_scale = args.src_layout->dtype.param().scale, - filter_scale = - 
args.filter_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, bias_scale = 1.f, dst_scale; if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { @@ -140,18 +135,15 @@ ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::get_constants( } else if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) { dst_scale = args.dst_layout->dtype.param().scale; } else if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - dst_scale = - args.dst_layout->dtype.param().scale; - dst_zero = args.dst_layout->dtype.param() - .zero_point; + dst_scale = args.dst_layout->dtype.param().scale; + dst_zero = args.dst_layout->dtype.param().zero_point; } else { // DTypeEnum::Float32 megdnn_assert(args.dst_layout->dtype.enumv() == DTypeEnum::Float32); dst_scale = 1.f; } - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale, gamma = 0.f, delta = 0.f, - theta = dst_zero; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale, + gamma = 0.f, delta = 0.f, theta = dst_zero; if (args.z_layout->ndim > 0) { float z_scale; @@ -162,11 +154,9 @@ ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::get_constants( z_scale = args.z_layout->dtype.param().scale; gamma = z_scale / dst_scale; } else if (args.z_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - z_scale = - args.z_layout->dtype.param().scale; + z_scale = args.z_layout->dtype.param().scale; uint8_t z_zero = - args.z_layout->dtype.param() - .zero_point; + args.z_layout->dtype.param().zero_point; gamma = z_scale / dst_scale; delta = -z_zero * gamma; } else { // DTypeEnum::Float32 @@ -175,8 +165,7 @@ ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::get_constants( } } - if (args.opr->param().nonlineMode == - param::ConvBias::NonlineMode::IDENTITY) { + if (args.opr->param().nonlineMode == param::ConvBias::NonlineMode::IDENTITY) { delta += theta; theta = 0.f; } @@ -188,12 +177,9 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( const ExecArgs& args) const { auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](3), - hi = args.src_layout->operator[](1), - wi = args.src_layout->operator[](2); - size_t co = args.dst_layout->operator[](3), - ho = args.dst_layout->operator[](1), + size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](3), + hi = args.src_layout->operator[](1), wi = args.src_layout->operator[](2); + size_t co = args.dst_layout->operator[](3), ho = args.dst_layout->operator[](1), wo = args.dst_layout->operator[](2); UNPACK_CONV_PARAMETER(fm, param); MARK_USED_VAR @@ -225,18 +211,17 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( auto dst_dtype = args.dst_layout->dtype.enumv(); - bool without_shared_load = ((co % m_algo_param.threadblock_n == 0) && - (m_algo_param.threadblock_n == 16 || - (m_algo_param.threadblock_n == 32 && - dst_dtype != DTypeEnum::Float32))); + bool without_shared_load = + ((co % m_algo_param.threadblock_n == 0) && + (m_algo_param.threadblock_n == 16 || + (m_algo_param.threadblock_n == 32 && dst_dtype != DTypeEnum::Float32))); if (dst_dtype == DTypeEnum::QuantizedS8) { // DTypeEnum::QuantizedS8 dst_scale = args.dst_layout->dtype.param().scale; } else if (dst_dtype == DTypeEnum::QuantizedS4) { dst_scale = args.dst_layout->dtype.param().scale; } else if (dst_dtype == DTypeEnum::Quantized4Asymm) { - dst_scale = - args.dst_layout->dtype.param().scale; + dst_scale = 
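The IDENTITY branch rewrapped above relies on the comment that the identity epilogue computes alpha * accumulator + beta * bias + gamma * source + delta with no separate theta term, so theta is folded into delta beforehand. A tiny sketch of that algebraic folding:

#include <cstdio>

// Non-identity epilogue: act(alpha*acc + beta*bias + gamma*z + delta) + theta.
// The identity epilogue omits the trailing "+ theta", so theta is merged into
// delta up front (delta += theta; theta = 0), exactly as the branch above does.
float epilogue_identity(float acc, float bias, float z, float alpha, float beta,
                        float gamma, float delta) {
    return alpha * acc + beta * bias + gamma * z + delta;
}

int main() {
    float alpha = 0.5f, beta = 0.25f, gamma = 0.1f, delta = 0.f, theta = 8.f;
    delta += theta;  // fold the output zero point into delta
    theta = 0.f;
    std::printf("%f\n", epilogue_identity(10.f, 4.f, 2.f, alpha, beta, gamma, delta));
}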
args.dst_layout->dtype.param().scale; } else { // DTypeEnum::Float32 dst_scale = 1.f; } @@ -247,26 +232,26 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( args, ConvOperator::kFprop, ConvType::kConvolution, use_conv_filter_unity_opt, without_shared_load); - execute_cutlass_conv_op(op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, - z_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, - ci, co, fh, fw, ho, wo, ph, pw, sh, sw, dh, dw, - &alpha, &beta, &gamma, &delta, &theta, &threshold, - &dst_scale, stream); + execute_cutlass_conv_op( + op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, + args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, + pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, + &dst_scale, stream); after_kernel_launch(); } std::string ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::to_string( AlgoParam algo_param) { - return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, - algo_param.threadblock_n, algo_param.threadblock_k, - algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, - algo_param.stage, algo_param.access_size); + return ssprintf( + "%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, + algo_param.threadblock_n, algo_param.threadblock_k, algo_param.warp_m, + algo_param.warp_n, algo_param.warp_k, algo_param.stage, + algo_param.access_size); } void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( - const ExecArgs& args, const int iterleaved, - void* reordered_filter) const { + const ExecArgs& args, const int iterleaved, void* reordered_filter) const { size_t co = args.filter_layout->operator[](0), ci = args.filter_layout->operator[](3), fh = args.filter_layout->operator[](1), @@ -276,18 +261,19 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( // reformat filter from nhwc to ncxhwx and reorder oc // use trans_oc threadblock_n must be 16 or 32 and src dtype == dest dtype - bool trans_oc = ((co % m_algo_param.threadblock_n == 0) && - (m_algo_param.threadblock_n == 16 || - (m_algo_param.threadblock_n == 32 && - args.dst_layout->dtype.enumv() != DTypeEnum::Float32))); + bool trans_oc = + ((co % m_algo_param.threadblock_n == 0) && + (m_algo_param.threadblock_n == 16 || + (m_algo_param.threadblock_n == 32 && + args.dst_layout->dtype.enumv() != DTypeEnum::Float32))); uint32_t oc_iterleaved = (m_algo_param.threadblock_n == 32) ? 
32 : 16; uint32_t alignbits = iterleaved * 8; cutlass_wrapper::reorder_nhwc_imma_filter<8>( reinterpret_cast(reordered_filter), - reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, - fw, trans_oc, alignbits, oc_iterleaved, stream); + reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, fw, + trans_oc, alignbits, oc_iterleaved, stream); } #endif diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp index 1be3ea4e..a8c68d44 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp @@ -19,8 +19,8 @@ using namespace cuda; using namespace convolution; #if CUDA_VERSION >= 10020 -size_t ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm:: - get_workspace_in_bytes(const SizeArgs& args) const { +size_t ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_workspace_in_bytes( + const SizeArgs& args) const { if (args.preprocessed_filter) { return 0; } else { @@ -43,9 +43,8 @@ size_t ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm:: return ws_size_reduce_filter; } -SmallVector ConvBiasForwardImpl:: - AlgoUInt4Int4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout( - const SizeArgs& args) const { +SmallVector ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm:: + deduce_preprocessed_filter_layout(const SizeArgs& args) const { return {args.filter_layout->collapse_contiguous(), args.bias_layout->collapse_contiguous()}; } @@ -62,9 +61,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( update_bias(args, bias_ptr, reduce_filter_ptr, reduce_workspace); } -std::tuple -ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::prepare_filter_bias( - const ExecArgs& args) const { +std::tuple ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm:: + prepare_filter_bias(const ExecArgs& args) const { void* filter_ptr = nullptr; void* bias_ptr = nullptr; if (args.preprocessed_filter) { @@ -74,45 +72,35 @@ ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::prepare_filter_bias( return {filter_ptr, bias_ptr}; } else { filter_ptr = reinterpret_cast(args.workspace.raw_ptr); - bias_ptr = - reinterpret_cast(args.workspace.raw_ptr + - args.filter_layout->span().dist_byte()); - void* reduce_filter_ptr = - reinterpret_cast(args.workspace.raw_ptr + - args.filter_layout->span().dist_byte() + - args.bias_layout->span().dist_byte()); - void* reduce_workspace = - reinterpret_cast(args.workspace.raw_ptr + - args.filter_layout->span().dist_byte() + - args.bias_layout->span().dist_byte() + - args.bias_layout->span().dist_byte()); + bias_ptr = reinterpret_cast( + args.workspace.raw_ptr + args.filter_layout->span().dist_byte()); + void* reduce_filter_ptr = reinterpret_cast( + args.workspace.raw_ptr + args.filter_layout->span().dist_byte() + + args.bias_layout->span().dist_byte()); + void* reduce_workspace = reinterpret_cast( + args.workspace.raw_ptr + args.filter_layout->span().dist_byte() + + args.bias_layout->span().dist_byte() + + args.bias_layout->span().dist_byte()); reorder_filter(args, filter_ptr); update_bias(args, bias_ptr, reduce_filter_ptr, reduce_workspace); } return {filter_ptr, bias_ptr}; } -std::tuple -ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants( - const ExecArgs& args) const { - float src_scale = - args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - 
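prepare_filter_bias, reflowed above, carves the reordered filter, updated bias, reduce-filter buffer and reduction scratch out of one raw workspace pointer by adding span sizes. A simplified running-offset sketch of that sub-allocation (stand-in sizes, not the real spans):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Carve several sub-buffers out of a single raw workspace allocation by
// keeping a running byte offset.
struct WorkspaceCursor {
    uint8_t* base;
    size_t offset = 0;
    void* take(size_t bytes) {
        void* p = base + offset;
        offset += bytes;
        return p;
    }
};

int main() {
    std::vector<uint8_t> storage(1 << 10);  // stand-in for the raw workspace
    WorkspaceCursor ws{storage.data()};
    void* filter_ptr = ws.take(256);         // reordered filter
    void* bias_ptr = ws.take(64);            // updated bias
    void* reduce_filter_ptr = ws.take(64);   // per-channel filter reduction
    void* reduce_workspace = ws.take(128);   // scratch for the reduction kernel
    std::printf("%p %p %p %p\n", filter_ptr, bias_ptr, reduce_filter_ptr,
                reduce_workspace);
}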
args.bias_layout->dtype.param().scale, - dst_scale = - args.dst_layout->dtype.param().scale; +std::tuple ConvBiasForwardImpl:: + AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants(const ExecArgs& args) const { + float src_scale = args.src_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, + dst_scale = args.dst_layout->dtype.param().scale; uint8_t dst_zero = args.dst_layout->dtype.param().zero_point; - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale, gamma = 0.f, delta = 0.f, - theta = dst_zero; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale, + gamma = 0.f, delta = 0.f, theta = dst_zero; if (args.z_layout->ndim > 0) { - float z_scale = - args.z_layout->dtype.param().scale; + float z_scale = args.z_layout->dtype.param().scale; gamma = z_scale / dst_scale; uint8_t z_zero = args.z_layout->dtype.param().zero_point; @@ -121,8 +109,7 @@ ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::get_constants( // identity epilogue has no theta: // alpha * accumulator + beta * bias + gamma * source + delta - if (args.opr->param().nonlineMode == - param::ConvBias::NonlineMode::IDENTITY) { + if (args.opr->param().nonlineMode == param::ConvBias::NonlineMode::IDENTITY) { delta += theta; theta = 0.f; } @@ -141,14 +128,12 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( auto&& stream = cuda_stream(args.opr->handle()); int src_zero_point = - args.src_tensor->layout.dtype.param() - .zero_point; + args.src_tensor->layout.dtype.param().zero_point; do_dispatch_reduce_filter_and_update_bias_4bit( reinterpret_cast(args.filter_tensor->raw_ptr), args.bias_tensor->compatible_ptr(), co, ci * fh * fw / 8, reinterpret_cast(updated_bias), - reinterpret_cast(reduce_workspace), src_zero_point, - stream); + reinterpret_cast(reduce_workspace), src_zero_point, stream); } #endif diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp index 73bb4b17..f6276d80 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp @@ -19,8 +19,7 @@ using namespace cuda; using namespace convolution; #if CUDA_VERSION >= 10020 -size_t -ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_workspace_in_bytes( +size_t ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { if (args.preprocessed_filter) { return 0; @@ -44,9 +43,8 @@ size_t ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm:: return ws_size_reduce_filter; } -SmallVector ConvBiasForwardImpl:: - AlgoUInt4Int4NHWCIMMAImplicitGemm::deduce_preprocessed_filter_layout( - const SizeArgs& args) const { +SmallVector ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm:: + deduce_preprocessed_filter_layout(const SizeArgs& args) const { return {args.filter_layout->collapse_contiguous(), args.bias_layout->collapse_contiguous()}; } @@ -63,9 +61,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( update_bias(args, bias_ptr, reduce_filter_ptr, reduce_workspace); } -std::tuple -ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::prepare_filter_bias( - const ExecArgs& args) const { +std::tuple ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm:: + prepare_filter_bias(const ExecArgs& args) const { void* filter_ptr = nullptr; void* bias_ptr = 
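update_bias, reflowed above, passes the source zero point into a reduce-filter-and-update-bias kernel. With an asymmetric source the integer convolution picks up a constant -src_zero * sum(filter) per output channel, which can be folded into the bias once. The CPU sketch below is an assumption about that math, not the real 4-bit kernel:

#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed correction: updated_bias[oc] = bias[oc] - src_zero * sum_k filter[oc][k].
std::vector<int32_t> update_bias(const std::vector<int32_t>& bias,
                                 const std::vector<int8_t>& filter, size_t co,
                                 size_t k /* = ci*fh*fw */, int src_zero) {
    std::vector<int32_t> updated(co);
    for (size_t oc = 0; oc < co; ++oc) {
        int32_t sum = 0;
        for (size_t i = 0; i < k; ++i)
            sum += filter[oc * k + i];
        updated[oc] = bias[oc] - src_zero * sum;
    }
    return updated;
}

int main() {
    auto b = update_bias({100, -50}, {1, 2, 3, 4, -1, -2, -3, -4}, 2, 4, 8);
    std::printf("%d %d\n", b[0], b[1]);  // 100 - 8*10 = 20, -50 - 8*(-10) = 30
}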
nullptr; if (args.preprocessed_filter) { @@ -75,65 +72,52 @@ ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::prepare_filter_bias( return {filter_ptr, bias_ptr}; } else { filter_ptr = reinterpret_cast(args.workspace.raw_ptr); - bias_ptr = - reinterpret_cast(args.workspace.raw_ptr + - args.filter_layout->span().dist_byte()); - void* reduce_filter_ptr = - reinterpret_cast(args.workspace.raw_ptr + - args.filter_layout->span().dist_byte() + - args.bias_layout->span().dist_byte()); - void* reduce_workspace = - reinterpret_cast(args.workspace.raw_ptr + - args.filter_layout->span().dist_byte() + - args.bias_layout->span().dist_byte() + - args.bias_layout->span().dist_byte()); + bias_ptr = reinterpret_cast( + args.workspace.raw_ptr + args.filter_layout->span().dist_byte()); + void* reduce_filter_ptr = reinterpret_cast( + args.workspace.raw_ptr + args.filter_layout->span().dist_byte() + + args.bias_layout->span().dist_byte()); + void* reduce_workspace = reinterpret_cast( + args.workspace.raw_ptr + args.filter_layout->span().dist_byte() + + args.bias_layout->span().dist_byte() + + args.bias_layout->span().dist_byte()); reorder_filter(args, m_algo_param.access_size, filter_ptr); update_bias(args, bias_ptr, reduce_filter_ptr, reduce_workspace); } return {filter_ptr, bias_ptr}; } -std::tuple -ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants( - const ExecArgs& args) const { - float src_scale = - args.src_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - bias_scale = - args.bias_layout->dtype.param().scale, +std::tuple ConvBiasForwardImpl:: + AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants(const ExecArgs& args) const { + float src_scale = args.src_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + bias_scale = args.bias_layout->dtype.param().scale, dst_scale; uint8_t dst_zero = 0; if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - dst_scale = - args.dst_layout->dtype.param().scale; + dst_scale = args.dst_layout->dtype.param().scale; - dst_zero = args.dst_layout->dtype.param() - .zero_point; + dst_zero = args.dst_layout->dtype.param().zero_point; } else { // DTypeEnum::QuantizedS8 megdnn_assert(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS8); dst_scale = args.dst_layout->dtype.param().scale; } - float alpha = src_scale * filter_scale / dst_scale, - beta = bias_scale / dst_scale, gamma = 0.f, delta = 0.f, - theta = dst_zero; + float alpha = src_scale * filter_scale / dst_scale, beta = bias_scale / dst_scale, + gamma = 0.f, delta = 0.f, theta = dst_zero; if (args.z_layout->ndim > 0) { float z_scale; if (args.z_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) { - z_scale = - args.z_layout->dtype.param().scale; + z_scale = args.z_layout->dtype.param().scale; uint8_t z_zero = - args.z_layout->dtype.param() - .zero_point; + args.z_layout->dtype.param().zero_point; gamma = z_scale / dst_scale; delta = -z_zero * gamma; } else { // DTypeEnum::QuantizedS8 - megdnn_assert(args.z_layout->dtype.enumv() == - DTypeEnum::QuantizedS8); + megdnn_assert(args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS8); z_scale = args.z_layout->dtype.param().scale; gamma = z_scale / dst_scale; } @@ -141,8 +125,7 @@ ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::get_constants( // identity epilogue has no theta: // alpha * accumulator + beta * bias + gamma * source + delta - if (args.opr->param().nonlineMode == - param::ConvBias::NonlineMode::IDENTITY) { + if (args.opr->param().nonlineMode 
== param::ConvBias::NonlineMode::IDENTITY) { delta += theta; theta = 0.f; } @@ -161,14 +144,12 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( auto&& stream = cuda_stream(args.opr->handle()); int src_zero_point = - args.src_tensor->layout.dtype.param() - .zero_point; + args.src_tensor->layout.dtype.param().zero_point; do_dispatch_reduce_filter_and_update_bias_4bit( reinterpret_cast(args.filter_tensor->raw_ptr), args.bias_tensor->compatible_ptr(), co, ci * fh * fw / 8, reinterpret_cast(updated_bias), - reinterpret_cast(reduce_workspace), src_zero_point, - stream); + reinterpret_cast(reduce_workspace), src_zero_point, stream); } #endif diff --git a/dnn/src/cuda/conv_bias/inplace_matmul.cpp b/dnn/src/cuda/conv_bias/inplace_matmul.cpp index 5eb70683..daf1d835 100644 --- a/dnn/src/cuda/conv_bias/inplace_matmul.cpp +++ b/dnn/src/cuda/conv_bias/inplace_matmul.cpp @@ -16,8 +16,7 @@ using namespace megdnn; using namespace cuda; -bool ConvBiasForwardImpl::AlgoInplaceMatmul::is_available( - const SizeArgs& args) const { +bool ConvBiasForwardImpl::AlgoInplaceMatmul::is_available(const SizeArgs& args) const { if (args.z_layout->ndim > 0) return false; @@ -32,24 +31,22 @@ size_t ConvBiasForwardImpl::AlgoInplaceMatmul::get_workspace_in_bytes( auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); return dst_layout.span().dist_byte(); } return 0; } void ConvBiasForwardImpl::AlgoInplaceMatmul::exec(const ExecArgs& args) const { - WorkspaceBundle bundle{args.workspace.raw_ptr, - {get_workspace_in_bytes(args)}}; + WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; auto conv_dst_tensor = *args.dst_tensor; if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(0); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } { @@ -63,14 +60,14 @@ void ConvBiasForwardImpl::AlgoInplaceMatmul::exec(const ExecArgs& args) const { conv_bias::exec_inplace_matmul_fwd( args.src_tensor->ptr(), args.filter_tensor->ptr(), - conv_dst_tensor.ptr(), N, - args.src_layout->stride[0], conv_dst_tensor.layout.stride[0], - IC, IH, IW, OC, OH, OW, FH, FW, fm.padding[0], fm.padding[1], - fm.stride[0], fm.stride[1], !fm.should_flip, stream); + conv_dst_tensor.ptr(), N, args.src_layout->stride[0], + conv_dst_tensor.layout.stride[0], IC, IH, IW, OC, OH, OW, FH, FW, + fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], + !fm.should_flip, stream); } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_hswish.cu index ede57e2a..c1831e11 100644 --- 
a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_id.cu index af52cbda..413297a8 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_relu.cu index b9a31403..22ef1299 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_hswish.cu 
b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_hswish.cu index 12cc85e9..adf6303a 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_id.cu index c833f90a..5cedfb3b 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_relu.cu index 4bffef5a..d9a21e22 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_ld_64bit_unroll_width< + PerChannelBiasVisitor, + 
IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_hswish.cu index 5bf4f141..495af1f9 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_hswish.cu @@ -1,13 +1,10 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_id.cu index a606532c..79317029 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_id.cu @@ -1,13 +1,11 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_relu.cu index 8b77a855..6fefc2a6 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_per_chan_relu.cu @@ -1,13 +1,10 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_hswish.cu 
b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_hswish.cu index 3bc09c2f..995a49a2 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_id.cu index 37c49d07..534bc713 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_relu.cu index e27ac734..6396bb2e 100644 --- a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float 
alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_hswish.cu index 3120bc1f..6adefd69 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_id.cu index 83e57e3e..1250c4fe 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_relu.cu index 6a9f1bb9..928cda8a 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + 
do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu index 6b82068a..1e480de6 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_id.cu index 245638d4..03740ee6 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu index 19cc9b99..c2c5692c 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu +++ 
b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu index af021e23..5050e4f4 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_id.cu index b65517ff..9487ef27 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + 
epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_relu.cu index 316cb9e8..14878d0d 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma16x16x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_hswish.cu index 8e5e3305..03dc727f 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_id.cu index 30629aa7..abc3163d 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, 
- float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_relu.cu index e556a537..1fd9518d 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu index f6f2f92b..033b275d 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_id.cu index 76ca26cb..84d69432 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_id.cu @@ -1,13 
+1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu index bfb3f746..94e902b4 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu index d1534545..57f6aff9 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git 
a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_id.cu index 0d5b2961..b7d7c77d 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_relu.cu index d6346834..eb2c16f5 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma32x8x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_hswish.cu index 74adeec8..1919c75a 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float 
beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_id.cu index 394b986f..39c6be57 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_relu.cu index c57eb845..fef20fb6 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu index a6ff09c2..8c5777fe 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include 
"../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_id.cu index b86b7065..b5d7f31e 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu index 742b3291..5bce7318 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_reorder_filter< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git 
a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu index 7c9500fe..097976fa 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_hswish.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_id.cu index c31552e1..9ecd9e2c 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_id.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_id.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_relu.cu index 6fdaad98..31760c3d 100644 --- a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_relu.cu +++ b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width_per_chan_relu.cu @@ -1,13 +1,13 @@ // generated by gen_cuda_conv_bias_kern_impls.py #include "../conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width.cuinl" -template void megdnn::cuda::conv_bias_int8::do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width>>( - const int8_t* d_src, - const int8_t* d_filter, - PerChannelBiasVisitor bias, - 
IConvEpilogue> epilogue, - const ConvParam& param, - float alpha, - float beta, - cudaStream_t stream); +template void megdnn::cuda::conv_bias_int8:: + do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width< + PerChannelBiasVisitor, + IConvEpilogue< + Activation>>( + const int8_t* d_src, const int8_t* d_filter, PerChannelBiasVisitor bias, + IConvEpilogue< + Activation> + epilogue, + const ConvParam& param, float alpha, float beta, cudaStream_t stream); diff --git a/dnn/src/cuda/conv_bias/matmul.cpp b/dnn/src/cuda/conv_bias/matmul.cpp index a28ab6fd..7b96a917 100644 --- a/dnn/src/cuda/conv_bias/matmul.cpp +++ b/dnn/src/cuda/conv_bias/matmul.cpp @@ -10,12 +10,12 @@ * implied. */ +#include "src/common/algo_base.h" #include "src/common/conv_bias.h" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/helper.h" #include "src/cuda/conv_bias/matmul/im2col.cuh" #include "src/cuda/utils.h" -#include "src/common/algo_base.h" using namespace megdnn; using namespace cuda; @@ -27,8 +27,8 @@ std::pair sub_opr_config( const TensorLayout& src_layout, const TensorLayout& filter_layout, const TensorLayout& dst_layout, const ConvBiasForwardImpl* opr) { size_t N = src_layout.shape[0], IC = fm.icpg, OC = fm.ocpg, - OH = dst_layout.shape[2], OW = dst_layout.shape[3], - FH = fm.spatial[0], FW = fm.spatial[1]; + OH = dst_layout.shape[2], OW = dst_layout.shape[3], FH = fm.spatial[0], + FW = fm.spatial[1]; megdnn_assert(src_layout.dtype.category() == DTypeCategory::FLOAT); TensorLayout Al({OC, IC * FH * FW}, filter_layout.dtype), @@ -45,26 +45,25 @@ std::pair sub_opr_config( std::pair> prepare_sub_opr( const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) { auto matmul_opr = args.handle->create_operator(); - set_execution_policy(args.opr, - matmul_opr.get()); - auto&& config = - sub_opr_config(args.filter_meta, *args.src_layout, - *args.filter_layout, *args.dst_layout, args.opr); + set_execution_policy( + args.opr, matmul_opr.get()); + auto&& config = sub_opr_config( + args.filter_meta, *args.src_layout, *args.filter_layout, *args.dst_layout, + args.opr); matmul_opr->param() = config.second; return {config.first, std::move(matmul_opr)}; } } // namespace -std::vector -ConvBiasForwardImpl::AlgoMatmul::get_subopr_list( +std::vector ConvBiasForwardImpl::AlgoMatmul::get_subopr_list( const TensorLayoutArray& layouts, const OperatorBase* opr) const { const ConvBiasForwardImpl* conv_bias_opr = static_cast(opr); - CanonizedFilterMeta fm = conv_bias_opr->make_canonized_filter_meta( - layouts[0].ndim, layouts[1]); - auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[4], - conv_bias_opr); + CanonizedFilterMeta fm = + conv_bias_opr->make_canonized_filter_meta(layouts[0].ndim, layouts[1]); + auto&& config = + sub_opr_config(fm, layouts[0], layouts[1], layouts[4], conv_bias_opr); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); @@ -81,8 +80,8 @@ bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const { auto&& fm = args.filter_meta; return args.filter_meta.format == Param::Format::NCHW && - args.src_layout->dtype.category() == DTypeCategory::FLOAT && - fm.group == 1 && fm.spatial_ndim == 2; + args.src_layout->dtype.category() == DTypeCategory::FLOAT && fm.group == 1 && + fm.spatial_ndim == 2; } WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul::get_workspace_bundle( @@ -91,9 +90,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul::get_workspace_bundle( SmallVector sizes; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { 
dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); sizes.push_back(dst_layout.span().dist_byte()); } @@ -121,9 +119,9 @@ void ConvBiasForwardImpl::AlgoMatmul::exec(const ExecArgs& args) const { if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } ExecArgs conv_args = args; @@ -142,33 +140,32 @@ void ConvBiasForwardImpl::AlgoMatmul::exec(const ExecArgs& args) const { megdnn_assert_internal(0); } } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } template void ConvBiasForwardImpl::AlgoMatmul::exec_internal( const ExecArgs& args, const WorkspaceBundle& bundle) { auto&& fm = args.filter_meta; - size_t N = args.src_layout->shape[0], IC = fm.icpg, - IH = args.src_layout->shape[2], IW = args.src_layout->shape[3], - OC = fm.ocpg, OH = args.dst_tensor->layout.shape[2], - OW = args.dst_tensor->layout.shape[3], FH = fm.spatial[0], - FW = fm.spatial[1], PH = fm.padding[0], PW = fm.padding[1], - SH = fm.stride[0], SW = fm.stride[1], DH = fm.dilation[0], - DW = fm.dilation[1]; + size_t N = args.src_layout->shape[0], IC = fm.icpg, IH = args.src_layout->shape[2], + IW = args.src_layout->shape[3], OC = fm.ocpg, + OH = args.dst_tensor->layout.shape[2], OW = args.dst_tensor->layout.shape[3], + FH = fm.spatial[0], FW = fm.spatial[1], PH = fm.padding[0], + PW = fm.padding[1], SH = fm.stride[0], SW = fm.stride[1], + DH = fm.dilation[0], DW = fm.dilation[1]; auto stream = cuda_stream(args.handle); T* dst_t = static_cast(bundle.get(0)); T* col = static_cast(bundle.get(1)); - conv_bias::im2col(args.src_tensor->ptr(), col, N, - args.src_layout->stride[0], IC, IH, IW, FH, FW, OH, OW, - PH, PW, SH, SW, DH, DW, stream); + conv_bias::im2col( + args.src_tensor->ptr(), col, N, args.src_layout->stride[0], IC, IH, IW, + FH, FW, OH, OW, PH, PW, SH, SW, DH, DW, stream); auto config = prepare_sub_opr(args); - TensorND A(args.filter_tensor->ptr(), config.first[0]), - B(col, config.first[1]), C(dst_t, config.first[2]); + TensorND A(args.filter_tensor->ptr(), config.first[0]), B(col, config.first[1]), + C(dst_t, config.first[2]); size_t matmul_ws_idx = 2; if (fm.should_flip) { conv_bias::flip_filter(args, bundle.get_workspace(2), A.raw_ptr); @@ -177,8 +174,7 @@ void ConvBiasForwardImpl::AlgoMatmul::exec_internal( config.second->exec(A, B, C, bundle.get_workspace(matmul_ws_idx)); - TensorLayout C2l({OC * OH * OW, N}, typename DTypeTrait::dtype()), - C3l = C2l; + TensorLayout C2l({OC * OH * OW, N}, typename DTypeTrait::dtype()), C3l = C2l; C3l.stride[0] = 1; C3l.stride[1] = args.dst_tensor->layout.stride[0]; TensorND C2(dst_t, C2l); diff --git a/dnn/src/cuda/conv_bias/matmul/im2col.cu b/dnn/src/cuda/conv_bias/matmul/im2col.cu index 2a5be3b4..57cff64f 100644 --- a/dnn/src/cuda/conv_bias/matmul/im2col.cu +++ b/dnn/src/cuda/conv_bias/matmul/im2col.cu @@ -18,12 +18,10 @@ using 
namespace cuda; namespace { template -__global__ void im2col_kernel(const T* im, T* col, uint32_t N, uint32_t INP_BS, - uint32_t IC, uint32_t IH, uint32_t IW, - uint32_t FH, uint32_t FW, uint32_t OH, - uint32_t OW, uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW, uint32_t DH, - uint32_t DW) { +__global__ void im2col_kernel( + const T* im, T* col, uint32_t N, uint32_t INP_BS, uint32_t IC, uint32_t IH, + uint32_t IW, uint32_t FH, uint32_t FW, uint32_t OH, uint32_t OW, uint32_t PH, + uint32_t PW, uint32_t SH, uint32_t SW, uint32_t DH, uint32_t DW) { uint32_t n = threadIdx.x + blockIdx.y * blockDim.x; uint32_t ow = threadIdx.y + blockIdx.z * blockDim.y; uint32_t oh = blockIdx.x % OH; @@ -34,19 +32,17 @@ __global__ void im2col_kernel(const T* im, T* col, uint32_t N, uint32_t INP_BS, uint32_t didx = blockIdx.x * OW * N + ow * N + n; uint32_t ih = -PH + oh * SH + fh * DH; uint32_t iw = -PW + ow * SW + fw * DW; - col[didx] = (ih < IH && iw < IW - ? im[n * INP_BS + ic * IH * IW + ih * IW + iw] - : T(0.0f)); + col[didx] = + (ih < IH && iw < IW ? im[n * INP_BS + ic * IH * IW + ih * IW + iw] + : T(0.0f)); } } template -__global__ void col2im_kernel(const T* col, T* im, uint32_t N, uint32_t INP_BS, - uint32_t IC, uint32_t IH, uint32_t IW, - uint32_t FH, uint32_t FW, uint32_t OH, - uint32_t OW, uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW, uint32_t DH, - uint32_t DW) { +__global__ void col2im_kernel( + const T* col, T* im, uint32_t N, uint32_t INP_BS, uint32_t IC, uint32_t IH, + uint32_t IW, uint32_t FH, uint32_t FW, uint32_t OH, uint32_t OW, uint32_t PH, + uint32_t PW, uint32_t SH, uint32_t SW, uint32_t DH, uint32_t DW) { uint32_t iw = threadIdx.x + blockIdx.y * blockDim.x; uint32_t ih = threadIdx.y + blockIdx.z * blockDim.y; uint32_t ic = blockIdx.x % IC; @@ -63,9 +59,9 @@ __global__ void col2im_kernel(const T* col, T* im, uint32_t N, uint32_t INP_BS, uint32_t anchorw = iw + PW - fw * DW; if (anchorw < OW * SW && anchorw % SW == 0) { uint32_t ow = anchorw / SW; - res += col[ic * FH * FW * OH * OW * N + - fh * FW * OH * OW * N + fw * OH * OW * N + - oh * OW * N + ow * N + n]; + res += + col[ic * FH * FW * OH * OW * N + fh * FW * OH * OW * N + + fw * OH * OW * N + oh * OW * N + ow * N + n]; } } } @@ -77,35 +73,32 @@ __global__ void col2im_kernel(const T* col, T* im, uint32_t N, uint32_t INP_BS, } // anonymous namespace template -void conv_bias::im2col(const T* im, T* col, size_t N, size_t INP_BS, size_t IC, - size_t IH, size_t IW, size_t FH, size_t FW, size_t OH, - size_t OW, size_t PH, size_t PW, size_t SH, size_t SW, - size_t DH, size_t DW, cudaStream_t stream) { +void conv_bias::im2col( + const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, cudaStream_t stream) { dim3 threads(NR_THREADS_X, NR_THREADS_Y); // dim3 blocks(DIVUP(N, NR_THREADS_X), DIVUP(OW, NR_THREADS_Y), // IC*FH*FW*OH); IC*FH*FW*OH can be larger than 65536; shuffling blocks // dimensions to put IC*FH*FW*OH to the first dimension. 
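// A minimal standalone sketch of the launch-shape workaround described in the comment
// above (assumed block sizes and example shapes; CUDA caps gridDim.y and gridDim.z at
// 65535 while gridDim.x can be far larger, so the potentially huge IC*FH*FW*OH product
// has to be the first grid dimension):
//
//   #define DIVUP(a, b) (((a) + (b) - 1) / (b))
//   const uint32_t NR_THREADS_X = 32, NR_THREADS_Y = 32;            // assumed values
//   uint32_t N = 16, OW = 56, IC = 256, FH = 7, FW = 7, OH = 56;    // example shapes
//   dim3 threads(NR_THREADS_X, NR_THREADS_Y);
//   // rejected at launch: IC * FH * FW * OH == 702464 > 65535 (z-dimension cap)
//   // dim3 blocks(DIVUP(N, NR_THREADS_X), DIVUP(OW, NR_THREADS_Y), IC * FH * FW * OH);
//   dim3 blocks(IC * FH * FW * OH, DIVUP(N, NR_THREADS_X), DIVUP(OW, NR_THREADS_Y));
//   // im2col_kernel then recovers the logical index from blockIdx.x
//   // (e.g. `oh = blockIdx.x % OH`), while n and ow come from the y/z dimensions.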
- dim3 blocks(IC * FH * FW * OH, DIVUP(N, NR_THREADS_X), - DIVUP(OW, NR_THREADS_Y)); - im2col_kernel<<>>(im, col, N, INP_BS, IC, IH, - IW, FH, FW, OH, OW, PH, PW, - SH, SW, DH, DW); + dim3 blocks(IC * FH * FW * OH, DIVUP(N, NR_THREADS_X), DIVUP(OW, NR_THREADS_Y)); + im2col_kernel<<>>( + im, col, N, INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, DW); after_kernel_launch(); } template -void conv_bias::col2im(const T* col, T* im, size_t N, size_t INP_BS, size_t IC, - size_t IH, size_t IW, size_t FH, size_t FW, size_t OH, - size_t OW, size_t PH, size_t PW, size_t SH, size_t SW, - size_t DH, size_t DW, cudaStream_t stream) { +void conv_bias::col2im( + const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, cudaStream_t stream) { dim3 threads(NR_THREADS_X, NR_THREADS_Y); // (x, y, z) is shuffled to (y, z, x) to bypass CUDA launch shape // limitation. dim3 blocks(DIVUP(IW, NR_THREADS_X), DIVUP(IH, NR_THREADS_Y), // N*IC); dim3 blocks(N * IC, DIVUP(IW, NR_THREADS_X), DIVUP(IH, NR_THREADS_Y)); - col2im_kernel<<>>(col, im, N, INP_BS, IC, IH, - IW, FH, FW, OH, OW, PH, PW, - SH, SW, DH, DW); + col2im_kernel<<>>( + col, im, N, INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, DW); after_kernel_launch(); } @@ -113,17 +106,17 @@ namespace megdnn { namespace cuda { namespace conv_bias { -#define DO_INST(T) \ - template void im2col(const T* im, T* col, size_t N, size_t INP_BS, \ - size_t IC, size_t IH, size_t IW, size_t FH, \ - size_t FW, size_t OH, size_t OW, size_t PH, \ - size_t PW, size_t SH, size_t SW, size_t DH, \ - size_t DW, cudaStream_t stream); \ - template void col2im(const T* col, T* im, size_t N, size_t INP_BS, \ - size_t IC, size_t IH, size_t IW, size_t FH, \ - size_t FW, size_t OH, size_t OW, size_t PH, \ - size_t PW, size_t SH, size_t SW, size_t DH, \ - size_t DW, cudaStream_t stream); +#define DO_INST(T) \ + template void im2col( \ + const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, \ + size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, \ + size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, \ + cudaStream_t stream); \ + template void col2im( \ + const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, \ + size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, \ + size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, \ + cudaStream_t stream); #define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype) diff --git a/dnn/src/cuda/conv_bias/matmul/im2col.cuh b/dnn/src/cuda/conv_bias/matmul/im2col.cuh index 1ecbd5ea..88123f46 100644 --- a/dnn/src/cuda/conv_bias/matmul/im2col.cuh +++ b/dnn/src/cuda/conv_bias/matmul/im2col.cuh @@ -19,16 +19,18 @@ namespace conv_bias { //! 
col is of shape (ic*fh*fw, oh*ow*n) template -void im2col(const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, - size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, - size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, // dilation - cudaStream_t stream); +void im2col( + const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, // dilation + cudaStream_t stream); template -void col2im(const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, - size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, - size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, // dilation - cudaStream_t stream); +void col2im( + const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, // dilation + cudaStream_t stream); } // namespace conv_bias } // namespace cuda diff --git a/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cu b/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cu index ec4da730..4fe99fea 100644 --- a/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cu +++ b/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cu @@ -15,13 +15,11 @@ namespace { template -__global__ void im2col_kern(const int8_t* __restrict src, - int8_t* __restrict unrolled, uint32_t N, - uint32_t IH, uint32_t IW, uint32_t IC, uint32_t IWS, - uint32_t OH, uint32_t OW, uint32_t OC, uint32_t OWS, - uint32_t FH, uint32_t FW, uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW, uint32_t DH, uint32_t DW, - uint32_t LD) { +__global__ void im2col_kern( + const int8_t* __restrict src, int8_t* __restrict unrolled, uint32_t N, + uint32_t IH, uint32_t IW, uint32_t IC, uint32_t IWS, uint32_t OH, uint32_t OW, + uint32_t OC, uint32_t OWS, uint32_t FH, uint32_t FW, uint32_t PH, uint32_t PW, + uint32_t SH, uint32_t SW, uint32_t DH, uint32_t DW, uint32_t LD) { uint32_t ic = blockIdx.x * 32 + threadIdx.x; uint32_t ow = blockIdx.y * 4 + threadIdx.y; uint32_t oh = blockIdx.z * 4 + threadIdx.z; @@ -44,22 +42,19 @@ __global__ void im2col_kern(const int8_t* __restrict src, } // anonymous namespace -void megdnn::cuda::im2col_nhwc_int8(const int8_t* src, int8_t* unrolled, - uint32_t N, uint32_t IH, uint32_t IW, - uint32_t IC, uint32_t IWS, uint32_t OH, - uint32_t OW, uint32_t OC, uint32_t OWS, - uint32_t FH, uint32_t FW, uint32_t PH, - uint32_t PW, uint32_t SH, uint32_t SW, - uint32_t DH, uint32_t DW, uint32_t LD, - bool flip, cudaStream_t stream) { +void megdnn::cuda::im2col_nhwc_int8( + const int8_t* src, int8_t* unrolled, uint32_t N, uint32_t IH, uint32_t IW, + uint32_t IC, uint32_t IWS, uint32_t OH, uint32_t OW, uint32_t OC, uint32_t OWS, + uint32_t FH, uint32_t FW, uint32_t PH, uint32_t PW, uint32_t SH, uint32_t SW, + uint32_t DH, uint32_t DW, uint32_t LD, bool flip, cudaStream_t stream) { dim3 nthreads = dim3(32, 4, 4); dim3 nblocks = dim3(DIVUP(IC, 32), DIVUP(OW, 4), DIVUP(OH, 4)); - void (*kern_ptr)(const int8_t* __restrict src, int8_t* __restrict unrolled, - uint32_t N, uint32_t IH, uint32_t IW, uint32_t IC, - uint32_t IWS, uint32_t OH, uint32_t OW, uint32_t OC, - uint32_t OWS, uint32_t FH, uint32_t FW, uint32_t PH, - uint32_t PW, uint32_t SH, uint32_t SW, uint32_t DH, - uint32_t DW, uint32_t LD); + void (*kern_ptr)( + const int8_t* __restrict src, int8_t* __restrict unrolled, uint32_t N, + uint32_t IH, uint32_t IW, uint32_t IC, uint32_t 
IWS, uint32_t OH, + uint32_t OW, uint32_t OC, uint32_t OWS, uint32_t FH, uint32_t FW, + uint32_t PH, uint32_t PW, uint32_t SH, uint32_t SW, uint32_t DH, + uint32_t DW, uint32_t LD); if (flip) { kern_ptr = im2col_kern; } else { @@ -67,8 +62,8 @@ void megdnn::cuda::im2col_nhwc_int8(const int8_t* src, int8_t* unrolled, } for (size_t n = 0; n < N; ++n) { kern_ptr<<>>( - src + n * IH * IW * IWS, unrolled + n * OH * OW * LD, N, IH, IW, - IC, IWS, OH, OW, OC, OWS, FH, FW, PH, PW, SH, SW, DH, DW, LD); + src + n * IH * IW * IWS, unrolled + n * OH * OW * LD, N, IH, IW, IC, + IWS, OH, OW, OC, OWS, FH, FW, PH, PW, SH, SW, DH, DW, LD); } after_kernel_launch(); } diff --git a/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh b/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh index fbfb91a4..e6f3b9cf 100644 --- a/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh +++ b/dnn/src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh @@ -15,12 +15,11 @@ namespace megdnn { namespace cuda { -void im2col_nhwc_int8(const int8_t* src, int8_t* unrolled, uint32_t N, - uint32_t IH, uint32_t IW, uint32_t IC, uint32_t IWS, - uint32_t OH, uint32_t OW, uint32_t OC, uint32_t OWS, - uint32_t FH, uint32_t FW, uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW, uint32_t DH, uint32_t DW, - uint32_t LD, bool flip, cudaStream_t stream); +void im2col_nhwc_int8( + const int8_t* src, int8_t* unrolled, uint32_t N, uint32_t IH, uint32_t IW, + uint32_t IC, uint32_t IWS, uint32_t OH, uint32_t OW, uint32_t OC, uint32_t OWS, + uint32_t FH, uint32_t FW, uint32_t PH, uint32_t PW, uint32_t SH, uint32_t SW, + uint32_t DH, uint32_t DW, uint32_t LD, bool flip, cudaStream_t stream); } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cu b/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cu index 750295b1..af6c1860 100644 --- a/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cu +++ b/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cu @@ -27,9 +27,7 @@ struct BufferFetcherTexture { struct BufferFetcherRaw { const float* ptr; - __device__ __forceinline__ float get(uint32_t offset) { - return ptr[offset]; - } + __device__ __forceinline__ float get(uint32_t offset) { return ptr[offset]; } }; struct BufferFetcherTextureHost { @@ -61,8 +59,7 @@ BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n) { cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); cudaTextureDesc tex_desc; memset(&tex_desc, 0, sizeof(cudaTextureDesc)); - if (cudaCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL) == - cudaSuccess) { + if (cudaCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL) == cudaSuccess) { val.tex = tex_obj; init_succ = true; } else { @@ -72,10 +69,10 @@ BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n) { template struct KernelPtr { - typedef void (*type)(BufferFetcher, BufferFetcher, float*, uint32_t, - uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t); + typedef void (*type)( + BufferFetcher, BufferFetcher, float*, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t); }; //! 1 -> 0xffffffff, 0 -> 0x00000000 @@ -90,9 +87,8 @@ union FloatAndU32 { //! 
\p mask must be either all 1 or 0 bits template -__device__ __forceinline__ float visit_with_mask(BufferFetcher buf, - uint32_t offset, - uint32_t mask) { +__device__ __forceinline__ float visit_with_mask( + BufferFetcher buf, uint32_t offset, uint32_t mask) { FloatAndU32 f; f.f = buf.get(offset & mask); f.u &= mask; @@ -100,14 +96,12 @@ __device__ __forceinline__ float visit_with_mask(BufferFetcher buf, } template -__global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, float* dst, - const uint32_t INP_BS, const uint32_t OUT_BS, - const uint32_t IC, const uint32_t IH, - const uint32_t IW, const uint32_t OC, - const uint32_t OH, const uint32_t OW, - const uint32_t FH, const uint32_t FW, - const uint32_t SH, const uint32_t SW, - const uint32_t PH, const uint32_t PW) { +__global__ void conv_kernel( + BufferFetcher src, BufferFetcher filter, float* dst, const uint32_t INP_BS, + const uint32_t OUT_BS, const uint32_t IC, const uint32_t IH, const uint32_t IW, + const uint32_t OC, const uint32_t OH, const uint32_t OW, const uint32_t FH, + const uint32_t FW, const uint32_t SH, const uint32_t SW, const uint32_t PH, + const uint32_t PW) { const uint32_t BM = BY < BX ? BY : BX; // BY*BX == 256 // (OC) * (IC*FH*FW) * (OH*OW) @@ -185,14 +179,18 @@ __global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, float* dst, if (tidy < BM) { uint32_t tmp = offsetB + (ic * IH + (fh2)) * IW + (fw2), ok = bool_as_mask(tidy + i < heightB), - p0 = bool_as_mask(fh2 + oh0 >= PH && fh2 + oh0 < IH + PH && - fw2 + ow0 >= PW && fw2 + ow0 < IW + PW), - p1 = bool_as_mask(fh2 + oh1 >= PH && fh2 + oh1 < IH + PH && - fw2 + ow1 >= PW && fw2 + ow1 < IW + PW), - p2 = bool_as_mask(fh2 + oh2 >= PH && fh2 + oh2 < IH + PH && - fw2 + ow2 >= PW && fw2 + ow2 < IW + PW), - p3 = bool_as_mask(fh2 + oh3 >= PH && fh2 + oh3 < IH + PH && - fw2 + ow3 >= PW && fw2 + ow3 < IW + PW); + p0 = bool_as_mask( + fh2 + oh0 >= PH && fh2 + oh0 < IH + PH && + fw2 + ow0 >= PW && fw2 + ow0 < IW + PW), + p1 = bool_as_mask( + fh2 + oh1 >= PH && fh2 + oh1 < IH + PH && + fw2 + ow1 >= PW && fw2 + ow1 < IW + PW), + p2 = bool_as_mask( + fh2 + oh2 >= PH && fh2 + oh2 < IH + PH && + fw2 + ow2 >= PW && fw2 + ow2 < IW + PW), + p3 = bool_as_mask( + fh2 + oh3 >= PH && fh2 + oh3 < IH + PH && + fw2 + ow3 >= PW && fw2 + ow3 < IW + PW); localB[tidy][tidx].x = visit_with_mask(src, tmp + op0, ok & p0); localB[tidy][tidx].y = visit_with_mask(src, tmp + op1, ok & p1); localB[tidy][tidx].z = visit_with_mask(src, tmp + op2, ok & p2); @@ -288,10 +286,10 @@ __global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, float* dst, } // anonymous namespace void conv_bias::exec_inplace_matmul_fwd( - const float* src, const float* filter, float* dst, size_t N, - size_t INP_BS, size_t OUT_BS, size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, - size_t PW, size_t SH, size_t SW, bool is_xcorr, cudaStream_t stream) { + const float* src, const float* filter, float* dst, size_t N, size_t INP_BS, + size_t OUT_BS, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t PH, size_t PW, size_t SH, size_t SW, bool is_xcorr, + cudaStream_t stream) { BufferFetcherTextureHost src_tex(const_cast(src), N * INP_BS), filter_tex(const_cast(filter), OC * IC * FH * FW); @@ -317,32 +315,31 @@ void conv_bias::exec_inplace_matmul_fwd( } else { BX = BY = 16; } - dim3 blocks((OH * OW + BX * 4 - 1) / (BX * 4), (OC + BY * 4 - 1) / (BY * 4), - N); + dim3 blocks((OH * OW + BX * 4 - 1) / (BX * 4), 
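The bool_as_mask / visit_with_mask pair above replaces the usual "if inside, load, else zero" with branch-free arithmetic: the predicate becomes an all-ones or all-zeros word, the offset is forced to 0 when the tap is outside the image, and the loaded value is zeroed by the same mask. A minimal host-side sketch of the idiom; the kernel does the bit reinterpretation through the FloatAndU32 union and reads through a BufferFetcher rather than a raw pointer.

#include <cstdint>
#include <cstring>

inline uint32_t bool_as_mask_ref(bool v) {
    return v ? 0xffffffffu : 0u;  // 1 -> all ones, 0 -> all zeros
}

// mask must be all-ones or all-zeros; out-of-range taps read element 0
// (a harmless in-bounds address) and are then zeroed bit-wise.
inline float visit_with_mask_ref(const float* buf, uint32_t offset, uint32_t mask) {
    float f = buf[offset & mask];
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));  // host-safe stand-in for the union
    u &= mask;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

The padding predicates p0..p3 above are built exactly this way, so all four loads into localB can be issued without a divergent branch.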
(OC + BY * 4 - 1) / (BY * 4), N); dim3 threads(BX, BY); -#define DISPATCH_BX_BY(BX, BY) \ - do { \ - if (src_tex.init_succ) { \ - KernelPtr::type kptr; \ - if (is_xcorr) { \ - kptr = conv_kernel; \ - } else { \ - kptr = conv_kernel; \ - } \ - kptr<<>>( \ - src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \ - IW, OC, OH, OW, FH, FW, SH, SW, PH, PW); \ - } else { \ - KernelPtr::type kptr; \ - if (is_xcorr) { \ - kptr = conv_kernel; \ - } else { \ - kptr = conv_kernel; \ - } \ - kptr<<>>( \ - src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \ - OH, OW, FH, FW, SH, SW, PH, PW); \ - } \ +#define DISPATCH_BX_BY(BX, BY) \ + do { \ + if (src_tex.init_succ) { \ + KernelPtr::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel; \ + } else { \ + kptr = conv_kernel; \ + } \ + kptr<<>>( \ + src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \ + OH, OW, FH, FW, SH, SW, PH, PW); \ + } else { \ + KernelPtr::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel; \ + } else { \ + kptr = conv_kernel; \ + } \ + kptr<<>>( \ + src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, OH, OW, \ + FH, FW, SH, SW, PH, PW); \ + } \ } while (0) #define DISPATCH_BX(BX) \ do { \ diff --git a/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cuh b/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cuh index 121bc29e..8c0aab1a 100644 --- a/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cuh +++ b/dnn/src/cuda/conv_bias/matmul/inplace_matmul_impl.cuh @@ -18,12 +18,11 @@ namespace megdnn { namespace cuda { namespace conv_bias { -void exec_inplace_matmul_fwd(const float* src, const float* filter, float* dst, - size_t N, size_t INP_BS, size_t OUT_BS, size_t IC, - size_t IH, size_t IW, size_t OC, size_t OH, - size_t OW, size_t FH, size_t FW, size_t PH, - size_t PW, size_t SH, size_t SW, bool is_xcorr, - cudaStream_t stream); +void exec_inplace_matmul_fwd( + const float* src, const float* filter, float* dst, size_t N, size_t INP_BS, + size_t OUT_BS, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t PH, size_t PW, size_t SH, size_t SW, bool is_xcorr, + cudaStream_t stream); } // namespace conv_bias } // namespace cuda diff --git a/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp b/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp index 5e5a5ad1..950a3829 100644 --- a/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp +++ b/dnn/src/cuda/conv_bias/matmul_8x8x32.cpp @@ -9,16 +9,15 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
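The launch geometry chosen above gives every thread a 4x4 patch of the (OC) x (OH*OW) output matrix, with one grid z-slice per batch sample, which is why both grid dimensions are rounded up by BX*4 and BY*4. A small arithmetic check with made-up shapes; the numbers are illustrative only.

#include <cstdio>

int main() {
    unsigned BX = 16, BY = 16;                  // one of the block shapes picked above
    unsigned OH = 56, OW = 56, OC = 64, N = 8;  // illustrative tensor sizes
    unsigned gx = (OH * OW + BX * 4 - 1) / (BX * 4);  // spatial tiles per image
    unsigned gy = (OC + BY * 4 - 1) / (BY * 4);       // output-channel tiles
    std::printf("grid = (%u, %u, %u), block = (%u, %u)\n", gx, gy, N, BX, BY);
    // prints: grid = (49, 1, 8), block = (16, 16)
}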
*/ #include "src/common/conv_bias.h" -#include "src/cuda/utils.h" -#include "src/cuda/utils.cuh" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh" +#include "src/cuda/utils.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; -bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available( - const SizeArgs& args) const { +bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(const SizeArgs& args) const { if (args.z_layout->ndim > 0) return false; if (!is_compute_capability_required(6, 1)) @@ -32,9 +31,8 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available( auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); } using NonlineMode = param::ConvBias::NonlineMode; @@ -48,8 +46,7 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available( (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 && dst_layout.dtype.enumv() == DTypeEnum::QuantizedS32)) && fm.group == 1 && fm.spatial_ndim == 2 && - (fm.format == Param::Format::NHWC || - fm.format == Param::Format::NCHW4); + (fm.format == Param::Format::NHWC || fm.format == Param::Format::NCHW4); return available; }; @@ -57,8 +54,7 @@ template WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul8x8x32::get_bundle( const SizeArgs& args) const { size_t src_unroll_part, filter_reshape_part; - size_t relayout_src_part = 0, relayout_filter_part = 0, - relayout_dst_part = 0; + size_t relayout_src_part = 0, relayout_filter_part = 0, relayout_dst_part = 0; auto&& fm = args.filter_meta; size_t n, ih, iw, oh, ow, fh, fw, ic, oc; n = args.dst_layout->shape[0]; @@ -95,16 +91,15 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul8x8x32::get_bundle( filter_reshape_part = 0; } - SmallVector sizes = {src_unroll_part, filter_reshape_part, - relayout_src_part, relayout_filter_part, - relayout_dst_part}; + SmallVector sizes = { + src_unroll_part, filter_reshape_part, relayout_src_part, + relayout_filter_part, relayout_dst_part}; auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); sizes.push_back(dst_layout.span().dist_byte()); } @@ -124,8 +119,7 @@ size_t ConvBiasForwardImpl::AlgoMatmul8x8x32::get_workspace_in_bytes( } template -void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal( - const ExecArgs& args) const { +void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) const { auto stream = args.handle->stream(); auto cublas_handle = args.handle->cublas_handle(); auto alpha = args.handle->one_device_i32(); @@ -141,11 +135,9 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal( filter_tensor = *args.filter_tensor; } else { // NCHW4 - auto to_nhwc = [](const TensorLayout& layout, - void* raw_ptr) -> TensorND { + auto to_nhwc = [](const TensorLayout& layout, void* raw_ptr) -> TensorND { return {raw_ptr, - {{layout[0], layout[2], layout[3], layout[1] * 4}, - layout.dtype}}; + {{layout[0], layout[2], layout[3], layout[1] * 4}, layout.dtype}}; }; src_tensor = 
to_nhwc(*args.src_layout, bundle.get(2)); filter_tensor = to_nhwc(args.filter_tensor->layout, bundle.get(3)); @@ -156,17 +148,13 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal( W = src.layout[3]; args.handle->relayout_opr()->exec( {src.raw_ptr, - TensorLayout{{N, H, W, C / 4, 4}, - { - src.layout.stride[0], - src.layout.stride[2], - src.layout.stride[3], - src.layout.stride[1], - src.layout.stride[4] - }, - src.layout.dtype}}, - {dst_ptr, - TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}}); + TensorLayout{ + {N, H, W, C / 4, 4}, + {src.layout.stride[0], src.layout.stride[2], + src.layout.stride[3], src.layout.stride[1], + src.layout.stride[4]}, + src.layout.dtype}}, + {dst_ptr, TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}}); }; relayout(*args.src_tensor, src_tensor.raw_ptr); relayout(*args.filter_tensor, filter_tensor.raw_ptr); @@ -194,9 +182,9 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal( if (need_src_unroll(args)) { inp0 = static_cast(bundle.get(0)); inp0_stride = LD; - im2col_nhwc_int8(src_tensor.compatible_ptr(), inp0, N, IH, IW, - IC, IWS, OH, OW, OC, OWS, FH, FW, PH, PW, SH, SW, DH, - DW, LD, fm.should_flip, stream); + im2col_nhwc_int8( + src_tensor.compatible_ptr(), inp0, N, IH, IW, IC, IWS, OH, OW, + OC, OWS, FH, FW, PH, PW, SH, SW, DH, DW, LD, fm.should_flip, stream); } else { inp0 = src_tensor.compatible_ptr(); inp0_stride = IWS; @@ -206,27 +194,28 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal( inp1 = static_cast(bundle.get(1)); cuda_check(cudaMemcpy2DAsync( inp1, LD * sizeof(int8_t), filter_tensor.raw_ptr, - FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), - OC, cudaMemcpyDeviceToDevice, stream)); + FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), OC, + cudaMemcpyDeviceToDevice, stream)); inp1_stride = LD; } else { inp1 = filter_tensor.compatible_ptr(); inp1_stride = FH * FW * IC; } - cublas_check(cublasGemmEx(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, - N * OH * OW, FH * FW * IC, alpha, inp1, CUDA_R_8I, - inp1_stride, inp0, CUDA_R_8I, inp0_stride, beta, - dst_tensor.compatible_ptr(), CUDA_R_32I, - OWS, CUDA_R_32I, CUBLAS_GEMM_DFALT)); + cublas_check(cublasGemmEx( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, N * OH * OW, FH * FW * IC, + alpha, inp1, CUDA_R_8I, inp1_stride, inp0, CUDA_R_8I, inp0_stride, beta, + dst_tensor.compatible_ptr(), CUDA_R_32I, OWS, CUDA_R_32I, + CUBLAS_GEMM_DFALT)); if (format == Param::Format::NCHW4) { args.handle->relayout_opr()->exec( {dst_tensor.compatible_ptr(), - TensorLayout{{N, OC / 4, OH, OW, 4}, - {static_cast(OH * OW * OC), 4, - static_cast(OC * OW), - static_cast(OC), 1}, - dst_tensor.layout.dtype}}, + TensorLayout{ + {N, OC / 4, OH, OW, 4}, + {static_cast(OH * OW * OC), 4, + static_cast(OC * OW), static_cast(OC), + 1}, + dst_tensor.layout.dtype}}, *args.dst_tensor); } } @@ -240,9 +229,9 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } conv_args.dst_tensor = &conv_dst_tensor; conv_args.dst_layout = &conv_dst_tensor.layout; @@ -252,9 +241,9 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& 
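The cublasGemmEx call above is the whole convolution once im2col has run: an OC x (N*OH*OW) int32 result from two int8 operands with k = FH*FW*IC. Pulled out on its own it looks like the sketch below, which mirrors the call in the diff (the pre-cuBLAS-11 form where the compute type is a cudaDataType, and CUBLAS_GEMM_DFALT). The helper name and parameter names are illustrative; alpha and beta are expected to be device-resident int32 constants 1 and 0 (the diff takes alpha from one_device_i32()), implying device pointer mode on the handle.

#include <cublas_v2.h>
#include <cstdint>

// C(int32, OC x spatial) = A(int8, k x OC)^T * B(int8, k x spatial),
// with k = FH*FW*IC and spatial = N*OH*OW.  Column-major as cuBLAS sees it,
// so the leading dimensions are the strides of the unrolled filter and the
// im2col buffer, and ldc is the output channel stride.
cublasStatus_t int8_conv_gemm(
        cublasHandle_t handle, const int8_t* filter, int ld_filter,
        const int8_t* col, int ld_col, int32_t* dst, int ld_dst,
        const int32_t* alpha /* 1, device */, const int32_t* beta /* 0, device */,
        int OC, int spatial, int k) {
    return cublasGemmEx(
            handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, spatial, k, alpha,
            filter, CUDA_R_8I, ld_filter, col, CUDA_R_8I, ld_col, beta,
            dst, CUDA_R_32I, ld_dst, CUDA_R_32I, CUBLAS_GEMM_DFALT);
}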
args) const { if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); conv_dst_tensor.layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - conv_dst_tensor.layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); } conv_args.dst_tensor = &conv_dst_tensor; conv_args.dst_layout = &conv_dst_tensor.layout; @@ -266,9 +255,9 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { // NCHW4 exec_internal(conv_args); } - handle_bias_and_nonlinear(args.handle, args.nonlinear_mode, - &conv_dst_tensor, args.dst_tensor, - args.bias_tensor); + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); } bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_filter_reshape( @@ -299,8 +288,9 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_src_unroll( } auto&& fm = args.filter_meta; - return !(fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.stride[0] == 1 && - fm.stride[1] == 1 && fm.padding[0] == 0 && fm.padding[1] == 0 && - stride % 4 == 0); + return !( + fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.stride[0] == 1 && + fm.stride[1] == 1 && fm.padding[0] == 0 && fm.padding[1] == 0 && + stride % 4 == 0); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/opr_impl.cpp b/dnn/src/cuda/conv_bias/opr_impl.cpp index 84fcd37c..bf374278 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.cpp +++ b/dnn/src/cuda/conv_bias/opr_impl.cpp @@ -9,70 +9,60 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ +#include "src/cuda/conv_bias/opr_impl.h" #include "megdnn/dtype.h" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/helper.h" -#include "src/cuda/conv_bias/opr_impl.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" -#include "src/common/conv_bias.h" #include "src/common/algo_chooser.h" +#include "src/common/conv_bias.h" #include "src/cuda/cudnn_with_check.h" namespace megdnn { namespace cuda { -void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, - const PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) { - check_exec_allow_noncontiguous(src.layout, filter.layout, bias.layout, - z.layout, dst.layout, workspace.size, - preprocessed_filter); - AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace, - preprocessed_filter); - auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout, - z.layout, dst.layout); +void ConvBiasForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, + _megdnn_tensor_in z, _megdnn_tensor_out dst, + const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) { + check_exec_allow_noncontiguous( + src.layout, filter.layout, bias.layout, z.layout, dst.layout, + workspace.size, preprocessed_filter); + AlgoBase::ExecArgs args( + this, src, filter, bias, z, dst, workspace, preprocessed_filter); + auto algo = get_algorithm( + this, src.layout, filter.layout, bias.layout, z.layout, dst.layout); algo->exec(args); }; -std::vector -ConvBiasForwardImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst) { +std::vector 
ConvBiasForwardImpl::get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst) { return megdnn::get_all_algorithms( {this, src, filter, bias, z, dst}); } -std::vector -ConvBiasForwardImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& bias, - const TensorLayout& z, - const TensorLayout& dst) { +std::vector ConvBiasForwardImpl::get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst) { return megdnn::get_all_algorithms_safe( {this, src, filter, bias, z, dst}); } ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { using namespace conv_bias; AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; auto dst_layout = *args.dst_layout; if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { dst_layout.dtype = DType(); - args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype, - args.filter_layout->dtype, - dst_layout.dtype); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); } auto conv_args = args; @@ -89,26 +79,25 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( auto get_cudnn_algo = [this, &conv_args, &args, workspace_limit_in_bytes, positive_attr, negative_attr]( - const thin_function& - cb) -> AlgoBase* { + const thin_function& cb) + -> AlgoBase* { auto cudnn_handle = cuda::cudnn_handle(this->handle()); CUDNNForwardDescs desc; conv_args.init_conv_desc(desc); #if CUDNN_MAJOR >= 7 int max_count = 0; - cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, - &max_count)); + cudnn_check( + cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &max_count)); SmallVector algo_perf(max_count); int ret_count = 0; cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7( cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, - desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count, - &ret_count, algo_perf.data())); + desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count, &ret_count, + algo_perf.data())); for (int i = 0; i < ret_count; ++i) { auto conv_bias_algo = cb(algo_perf[i].algo); if (conv_bias_algo->is_available_attribute( - args, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return conv_bias_algo; } } @@ -117,13 +106,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( cudnn_check(cudnnGetConvolutionForwardAlgorithm( cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, desc.conv_desc.conv_desc, desc.dst_desc.desc, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_in_bytes, &algo)); + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit_in_bytes, + &algo)); auto conv_bias_algo = cb(algo); - if (conv_bias_algo->is_available_attribute(args, positive_attr, - negative_attr, - workspace_limit_in_bytes)) + 
if (conv_bias_algo->is_available_attribute( + args, positive_attr, negative_attr, workspace_limit_in_bytes)) return conv_bias_algo; #endif return nullptr; @@ -133,20 +121,18 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( negative_attr](const AlgoBase::SizeArgs& size_arg) -> ConvBiasForwardImpl::AlgoBase* { if (sm_algo_pack.batched_matmul.is_available_attribute( - size_arg, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + size_arg, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.batched_matmul; } return nullptr; }; - const bool is_chanwise = - (args.filter_meta.format == Param::Format::NCHW && - args.filter_meta.group == src[1]) || - (args.filter_meta.format == Param::Format::NCHW4 && - args.filter_meta.group == src[1] * 4) || - (args.filter_meta.format == Param::Format::NCHW32 && - args.filter_meta.group == src[1] * 32); + const bool is_chanwise = (args.filter_meta.format == Param::Format::NCHW && + args.filter_meta.group == src[1]) || + (args.filter_meta.format == Param::Format::NCHW4 && + args.filter_meta.group == src[1] * 4) || + (args.filter_meta.format == Param::Format::NCHW32 && + args.filter_meta.group == src[1] * 32); // prefer special chanwise impl since as the group conv of cudnn // whose version is lower than v7.5.0 is still slower than our // implementation in many channel-wise cases @@ -156,19 +142,17 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( const int hw_size = src[2] * src[3]; //! choose dnn when stride != 1, may need calibrate for different cudnn //! version - const bool prefer_dnn_chanwise = - slow_cudnn_chanwise_impl || args.filter_meta.stride[0] != 1 || - args.filter_meta.stride[1] != 1 || hw_size < 512; + const bool prefer_dnn_chanwise = slow_cudnn_chanwise_impl || + args.filter_meta.stride[0] != 1 || + args.filter_meta.stride[1] != 1 || hw_size < 512; //! avoid bad case in cudnn, check dnn chanwise impl first if (is_chanwise) { if (prefer_dnn_chanwise) { if (sm_algo_pack.chanwise.is_available_attribute( - args, positive_attr, negative_attr, - workspace_limit_in_bytes)) + args, positive_attr, negative_attr, workspace_limit_in_bytes)) return &sm_algo_pack.chanwise; if (sm_algo_pack.chanwise8x8x32.is_available_attribute( - args, positive_attr, negative_attr, - workspace_limit_in_bytes)) + args, positive_attr, negative_attr, workspace_limit_in_bytes)) return &sm_algo_pack.chanwise8x8x32; } else { conv_args.dst_layout = &dst_layout; @@ -183,8 +167,8 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( //! Prefer CUDNN CONVBIAS. 
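get_cudnn_algo above is built around the cuDNN 7 query pattern: ask for the maximum candidate count, fetch the ranked cudnnConvolutionFwdAlgoPerf_t list, and take the first entry that is usable within the workspace limit (here additionally wrapped into a ConvBias algorithm and checked for attributes). The standalone pattern looks roughly like the sketch below; pick_fwd_algo, the fallback choice and the omitted error-check macro are illustrative.

#include <cudnn.h>
#include <vector>

cudnnConvolutionFwdAlgo_t pick_fwd_algo(
        cudnnHandle_t handle, cudnnTensorDescriptor_t src,
        cudnnFilterDescriptor_t filter, cudnnConvolutionDescriptor_t conv,
        cudnnTensorDescriptor_t dst, size_t workspace_limit) {
    int max_count = 0, ret_count = 0;
    cudnnGetConvolutionForwardAlgorithmMaxCount(handle, &max_count);
    std::vector<cudnnConvolutionFwdAlgoPerf_t> perf(max_count);
    cudnnGetConvolutionForwardAlgorithm_v7(
            handle, src, filter, conv, dst, max_count, &ret_count, perf.data());
    for (int i = 0; i < ret_count; ++i) {
        // entries come back sorted by expected performance
        if (perf[i].status == CUDNN_STATUS_SUCCESS && perf[i].memory <= workspace_limit)
            return perf[i].algo;
    }
    return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;  // conservative fallback
}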
bool cudnn_conv_bias_act_supported = false; for (auto&& algo : sm_algo_pack.cudnn_conv_bias_activations) { - if (algo.is_available_attribute(args, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + if (algo.is_available_attribute( + args, positive_attr, negative_attr, workspace_limit_in_bytes)) { cudnn_conv_bias_act_supported = true; break; } @@ -233,39 +217,33 @@ const char* ConvBiasForwardImpl::get_algorithm_set_name() const { } size_t ConvBiasForwardImpl::get_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) { TensorLayoutArray layouts{src, filter, bias, z, dst}; HeuristicCache::Key key{this->handle(), this->get_opr_type(), - layouts.data(), layouts.size(), &this->param(), - sizeof(this->param())}; + layouts.data(), layouts.size(), + &this->param(), sizeof(this->param())}; auto rst = HeuristicCache::instance().get(key); if (rst.policy.algo.valid()) { return rst.workspace; } - AlgoBase::SizeArgs args{ - this, src, filter, bias, z, dst, preprocessed_filter}; - return get_algorithm(this, src, filter, bias, z, dst) - ->get_workspace_in_bytes(args); + AlgoBase::SizeArgs args{this, src, filter, bias, z, dst, preprocessed_filter}; + return get_algorithm(this, src, filter, bias, z, dst)->get_workspace_in_bytes(args); }; size_t ConvBiasForwardImpl::get_preprocess_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst) { + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; return get_algorithm(this, src, filter, bias, z, dst) ->get_preprocess_workspace_in_bytes(args); } -SmallVector -ConvBiasForwardImpl::deduce_preprocessed_filter_layout( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst) { +SmallVector ConvBiasForwardImpl::deduce_preprocessed_filter_layout( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + const TensorLayout& z, const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; return get_algorithm(this, src, filter, bias, z, dst) ->deduce_preprocessed_filter_layout(args); @@ -276,12 +254,11 @@ void ConvBiasForwardImpl::exec_preprocess( _megdnn_tensor_in bias, const TensorLayout& z_layout, const TensorLayout& dst_layout, PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) { - TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout}, - z{nullptr, z_layout}; - AlgoBase::ExecArgs args(this, src, filter, bias, z, dst, workspace, - preprocessed_filter); - auto algo = get_algorithm(this, src.layout, filter.layout, bias.layout, - z.layout, dst.layout); + TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout}, z{nullptr, z_layout}; + AlgoBase::ExecArgs args( + this, src, filter, bias, z, dst, workspace, preprocessed_filter); + auto algo = get_algorithm( + this, src.layout, filter.layout, bias.layout, z.layout, dst.layout); return algo->exec_preprocess(args); } diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 06ee9a95..3bb2b6fb 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h 
+++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -19,28 +19,26 @@ namespace cuda { class ConvBiasForwardImpl : public ConvBiasForward { public: using ConvBiasForward::ConvBiasForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in bias, _megdnn_tensor_in z, - _megdnn_tensor_out dst, - const PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&, - const PreprocessedFilter*) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, + _megdnn_tensor_in z, _megdnn_tensor_out dst, + const PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, + const PreprocessedFilter*) override; - size_t get_preprocess_workspace_in_bytes(const TensorLayout&, - const TensorLayout&, - const TensorLayout&, - const TensorLayout&, - const TensorLayout&) override; + size_t get_preprocess_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&) override; SmallVector deduce_preprocessed_filter_layout( const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&, const TensorLayout&) override; - void exec_preprocess(const TensorLayout&, _megdnn_tensor_in, - _megdnn_tensor_in, const TensorLayout&, - const TensorLayout&, PreprocessedFilter*, - _megdnn_workspace) override; + void exec_preprocess( + const TensorLayout&, _megdnn_tensor_in, _megdnn_tensor_in, + const TensorLayout&, const TensorLayout&, PreprocessedFilter*, + _megdnn_workspace) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -90,13 +88,11 @@ public: const TensorLayout& dst) override; Algorithm* get_algorithm_heuristic( const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& bias, const TensorLayout& z, - const TensorLayout& dst, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, + const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) override; private: - static AlgoPack sm_algo_pack; }; diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp b/dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp index 5f70adb2..d7de5259 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma.cpp @@ -15,8 +15,8 @@ #include "./quint4x4x32_wmma/activation_u4.cuh" #include "./quint4x4x32_wmma/reduce_with_scale_data.cuh" -#include "./reduce_filter.cuh" #include "./quint4x4x32_wmma/wmma_conv_integer_u4.cuh" +#include "./reduce_filter.cuh" using namespace megdnn; using namespace cuda; @@ -25,8 +25,7 @@ using namespace activation_u4; #if CUDA_VERSION >= 10000 bool ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::is_available( const SizeArgs& args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } if (args.z_layout->ndim > 0) @@ -54,8 +53,9 @@ bool ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::is_available( available &= (param.format == Param::Format::NCHW8); // device support sm_75 auto&& device_prop = current_device_prop(); - available 
&= (device_prop.major > 7 || - (device_prop.major == 7 && device_prop.minor >= 5)); + available &= + (device_prop.major > 7 || + (device_prop.major == 7 && device_prop.minor >= 5)); // nonlinmode should be RELU or Identity available &= param.nonlineMode == Param::NonlineMode::RELU || param.nonlineMode == Param::NonlineMode::IDENTITY; @@ -72,8 +72,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_bundle( IC = args.filter_layout->operator[](1) * 8, FH = args.filter_layout->operator[](2), FW = args.filter_layout->operator[](3); - size_t OH = args.dst_layout->operator[](2), - OW = args.dst_layout->operator[](3); + size_t OH = args.dst_layout->operator[](2), OW = args.dst_layout->operator[](3); size_t ws_size_zp_filter = OC * sizeof(int32_t); // for reduce filter @@ -85,8 +84,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_bundle( size_t ws_size_relayout_filter = get_workspace_in_bytes_do_conv(args); if (ws_size_relayout_filter > 0) { WorkspaceBundle ws{ - raw_ptr, - {ws_size_zp_filter, ws_size_zp_data, ws_size_relayout_filter}}; + raw_ptr, {ws_size_zp_filter, ws_size_zp_data, ws_size_relayout_filter}}; return ws; } WorkspaceBundle ws{raw_ptr, {ws_size_zp_filter, ws_size_zp_data}}; @@ -100,8 +98,7 @@ size_t ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_in_bytes( bool ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::use_kernel_fhxfw( const SizeArgs& args) const { - return (args.filter_meta.spatial[0] == 3 && - args.filter_meta.spatial[1] == 3); + return (args.filter_meta.spatial[0] == 3 && args.filter_meta.spatial[1] == 3); } size_t ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_in_bytes_do_conv( @@ -115,27 +112,21 @@ size_t ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::get_workspace_in_bytes_do_conv( return OC * IC * FH * FW / 2; } -void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec( - const ExecArgs& args) const { +void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(const ExecArgs& args) const { auto&& handle = concrete_handle(args.opr->handle()); auto&& ws_bundle = get_workspace_bundle(args.workspace.raw_ptr, args); auto&& ws_zp_filter = ws_bundle.get_workspace(0); auto&& ws_zp_data = ws_bundle.get_workspace(1); - size_t N = args.src_layout->operator[](0), - IC = args.src_layout->operator[](1) * 8, - IH = args.src_layout->operator[](2), - IW = args.src_layout->operator[](3), - OC = args.filter_layout->operator[](0), - FH = args.filter_meta.spatial[0], FW = args.filter_meta.spatial[1], - OH = args.dst_layout->operator[](2), - OW = args.dst_layout->operator[](3), - PH = args.filter_meta.padding[0], PW = args.filter_meta.padding[1], - SH = args.filter_meta.stride[0], SW = args.filter_meta.stride[1]; - int32_t zp_data = - args.src_layout->dtype.param().zero_point; + size_t N = args.src_layout->operator[](0), IC = args.src_layout->operator[](1) * 8, + IH = args.src_layout->operator[](2), IW = args.src_layout->operator[](3), + OC = args.filter_layout->operator[](0), FH = args.filter_meta.spatial[0], + FW = args.filter_meta.spatial[1], OH = args.dst_layout->operator[](2), + OW = args.dst_layout->operator[](3), PH = args.filter_meta.padding[0], + PW = args.filter_meta.padding[1], SH = args.filter_meta.stride[0], + SW = args.filter_meta.stride[1]; + int32_t zp_data = args.src_layout->dtype.param().zero_point; int32_t zp_filter = - args.filter_layout->dtype.param() - .zero_point; + args.filter_layout->dtype.param().zero_point; int32_t zp_data_filter = zp_data * zp_filter * FH * FW * IC; auto&& stream = cuda_stream(handle); // zp filter @@ 
-144,9 +135,8 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec( FH * FW * IC / 8, ws_zp_filter.ptr(), stream); // zp data do_dispatch_reduce_with_scale_data_u4( - ws_zp_data.ptr(), - static_cast(args.src_tensor->raw_ptr), N, IH, IW, OH, OW, - PH, PW, FH, FW, SH, SW, IC, -zp_filter, + ws_zp_data.ptr(), static_cast(args.src_tensor->raw_ptr), + N, IH, IW, OH, OW, PH, PW, FH, FW, SH, SW, IC, -zp_filter, static_cast(zp_data), stream); // do conv @@ -154,17 +144,16 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec( wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_fhxfw( static_cast(args.src_tensor->raw_ptr), static_cast(args.filter_tensor->raw_ptr), - args.dst_tensor->compatible_ptr(), N, IH, IW, OH, OW, - PH, PW, IC, OC, FH, FW, SH, SW, static_cast(zp_data), - stream); + args.dst_tensor->compatible_ptr(), N, IH, IW, OH, OW, PH, PW, + IC, OC, FH, FW, SH, SW, static_cast(zp_data), stream); } else { auto&& ws_relayout_filter = ws_bundle.get_workspace(2); wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_1xfw( static_cast(args.src_tensor->raw_ptr), static_cast(args.filter_tensor->raw_ptr), args.dst_tensor->compatible_ptr(), - ws_relayout_filter.ptr(), N, IH, IW, OH, OW, PH, PW, - IC, OC, FH, FW, SH, SW, static_cast(zp_data), stream); + ws_relayout_filter.ptr(), N, IH, IW, OH, OW, PH, PW, IC, OC, + FH, FW, SH, SW, static_cast(zp_data), stream); } // do activation int s0 = args.bias_layout->stride[0], s1 = args.bias_layout->stride[1], @@ -179,13 +168,13 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec( if (param.nonlineMode == Param::NonlineMode::RELU) { do_dispatch_activation_u4( args.dst_tensor->compatible_ptr(), visitor, - ws_zp_data.ptr(), ws_zp_filter.ptr(), - zp_data_filter, N, OC, OH, OW, stream); + ws_zp_data.ptr(), ws_zp_filter.ptr(), zp_data_filter, + N, OC, OH, OW, stream); } else if (param.nonlineMode == Param::NonlineMode::IDENTITY) { do_dispatch_activation_u4( args.dst_tensor->compatible_ptr(), visitor, - ws_zp_data.ptr(), ws_zp_filter.ptr(), - zp_data_filter, N, OC, OH, OW, stream); + ws_zp_data.ptr(), ws_zp_filter.ptr(), zp_data_filter, + N, OC, OH, OW, stream); } } #endif diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cpp b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cpp index 4886406d..41004eb1 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cpp +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cpp @@ -24,8 +24,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "src/cuda/utils.h" #include "src/cuda/query_blocksize.cuh" +#include "src/cuda/utils.h" namespace megdnn { namespace cuda { @@ -35,10 +35,9 @@ namespace activation_u4 { * Cuda 3D launch config to ensure maximize occupancy we should use for a kernel * launch. 
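The three zero-point buffers prepared in the exec above come from expanding the quantized convolution. With 4-bit inputs x (zero point zx) and weights w (zero point zw), the sum over a window of (x - zx)(w - zw) equals sum(x*w) - zw*sum(x) - zx*sum(w) + zx*zw*K with K = FH*FW*IC: the WMMA kernel produces sum(x*w), ws_zp_data holds the -zw*sum(x) term per output pixel (hence the -zp_filter scale passed to the reduction), ws_zp_filter carries the matching per-output-channel filter term, and zp_data_filter is the constant zx*zw*K added back by the activation kernel. A scalar check of the identity, with made-up numbers:

#include <cassert>

int main() {
    const int K = 4, zx = 3, zw = 5;  // illustrative zero points
    int x[K] = {7, 0, 12, 4}, w[K] = {1, 9, 2, 6};
    int direct = 0, xw = 0, sx = 0, sw = 0;
    for (int i = 0; i < K; ++i) {
        direct += (x[i] - zx) * (w[i] - zw);  // the convolution we actually want
        xw += x[i] * w[i];                    // what the u4 WMMA kernel computes
        sx += x[i];
        sw += w[i];
    }
    assert(direct == xw - zw * sx - zx * sw + zx * zw * K);
    return 0;
}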
*/ -void get_launch_config(const void* kern, int dimx, int dimy, int dimz, - dim3& blocks, dim3& grids) { - auto config = - query_launch_config_for_kernel(reinterpret_cast(kern)); +void get_launch_config( + const void* kern, int dimx, int dimy, int dimz, dim3& blocks, dim3& grids) { + auto config = query_launch_config_for_kernel(reinterpret_cast(kern)); int block_size = config.block_size; int grid_size = config.grid_size; auto&& device_prop = current_device_prop(); @@ -50,16 +49,15 @@ void get_launch_config(const void* kern, int dimx, int dimy, int dimz, int z_grid_limit = device_prop.maxGridSize[2]; #define MIN3(a, b, c) std::min({(a), (b), (c)}) uint32_t blkx = MIN3(dimx, block_size, x_thread_limit); - uint32_t blky = - MIN3(dimy, std::max(block_size / (int)(blkx), 1), y_thread_limit); - uint32_t blkz = - MIN3(dimz, std::max(block_size / ((int)blkx * (int)blky), 1), - z_thread_limit); + uint32_t blky = MIN3(dimy, std::max(block_size / (int)(blkx), 1), y_thread_limit); + uint32_t blkz = MIN3( + dimz, std::max(block_size / ((int)blkx * (int)blky), 1), z_thread_limit); uint32_t gridx = MIN3(grid_size, DIVUP((int)dimx, (int)blkx), x_grid_limit); - uint32_t gridy = MIN3(DIVUP(grid_size, (int)gridx), DIVUP(dimy, (int)blky), - y_grid_limit); - uint32_t gridz = MIN3(DIVUP(grid_size, (int)(gridx * gridy)), - DIVUP(dimz, (int)blkz), z_grid_limit); + uint32_t gridy = + MIN3(DIVUP(grid_size, (int)gridx), DIVUP(dimy, (int)blky), y_grid_limit); + uint32_t gridz = + MIN3(DIVUP(grid_size, (int)(gridx * gridy)), DIVUP(dimz, (int)blkz), + z_grid_limit); #undef MIN3 grids = dim3{gridx, gridy, gridz}; diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cu b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cu index 83fc87e3..708a98e7 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cu +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cu @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. 
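get_launch_config above distributes a block size suggested by the occupancy query across three dimensions, never exceeding the problem extents or the device's per-dimension limits, then sizes the grid the same way. Factored out of the surrounding code it is just the min/max arithmetic below; the limits are passed in explicitly here, whereas the real code reads them from cudaDeviceProp and query_launch_config_for_kernel.

#include <algorithm>

struct Launch3D { unsigned bx, by, bz, gx, gy, gz; };

inline unsigned divup(unsigned a, unsigned b) { return (a + b - 1) / b; }

Launch3D make_launch_config(
        unsigned dimx, unsigned dimy, unsigned dimz,        // problem extents
        unsigned block_size, unsigned grid_size,            // from the occupancy query
        unsigned tx_lim, unsigned ty_lim, unsigned tz_lim,  // maxThreadsDim[0..2]
        unsigned gx_lim, unsigned gy_lim, unsigned gz_lim)  // maxGridSize[0..2]
{
    Launch3D c;
    c.bx = std::min({dimx, block_size, tx_lim});
    c.by = std::min({dimy, std::max(block_size / c.bx, 1u), ty_lim});
    c.bz = std::min({dimz, std::max(block_size / (c.bx * c.by), 1u), tz_lim});
    c.gx = std::min({grid_size, divup(dimx, c.bx), gx_lim});
    c.gy = std::min({divup(grid_size, c.gx), divup(dimy, c.by), gy_lim});
    c.gz = std::min({divup(grid_size, c.gx * c.gy), divup(dimz, c.bz), gz_lim});
    return c;
}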
* - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -43,16 +44,15 @@ using namespace activation_u4; namespace { __host__ __device__ __forceinline__ int4 operator+(int4 lval, int4 rval) { - return make_int4(lval.x + rval.x, lval.y + rval.y, lval.z + rval.z, - lval.w + rval.w); + return make_int4( + lval.x + rval.x, lval.y + rval.y, lval.z + rval.z, lval.w + rval.w); } template -__global__ void kern_activation_u4(int32_t* dst, const int32_t* zp_data, - const int32_t* zp_filter, - int32_t zp_data_filter, int batch_size, - int OC, int OH, int OW, - BiasVisitor visitor) { +__global__ void kern_activation_u4( + int32_t* dst, const int32_t* zp_data, const int32_t* zp_filter, + int32_t zp_data_filter, int batch_size, int OC, int OH, int OW, + BiasVisitor visitor) { const int ow = blockIdx.x * blockDim.x + threadIdx.x; const int oh = blockIdx.y * blockDim.y + threadIdx.y; const int bc = blockIdx.z * blockDim.z + threadIdx.z; @@ -62,8 +62,7 @@ __global__ void kern_activation_u4(int32_t* dst, const int32_t* zp_data, const int batch = bc / oc_blks; const int oc_blk = bc % oc_blks; - int32_t* dptr = dst + batch * OC * OH * OW + - oc_blk * OH * OW * subbytes_per_pixel + + int32_t* dptr = dst + batch * OC * OH * OW + oc_blk * OH * OW * subbytes_per_pixel + oh * OW * subbytes_per_pixel + ow * subbytes_per_pixel; if (batch >= batch_size || oh >= OH || ow >= OW) return; @@ -87,27 +86,28 @@ __global__ void kern_activation_u4(int32_t* dst, const int32_t* zp_data, } // namespace template -void do_dispatch_activation_u4(int32_t* dst, BiasVisitor visitor, - const int32_t* zp_data, const int32_t* zp_filter, - int32_t zp_data_filter, int batch_size, int co, - int ho, int wo, cudaStream_t stream) { - void (*fptr)(int32_t*, const int32_t*, const int32_t*, int32_t, int, int OC, - int, int, BiasVisitor) = kern_activation_u4; +void do_dispatch_activation_u4( + int32_t* dst, BiasVisitor visitor, const int32_t* zp_data, + const int32_t* zp_filter, int32_t 
zp_data_filter, int batch_size, int co, + int ho, int wo, cudaStream_t stream) { + void (*fptr)( + int32_t*, const int32_t*, const int32_t*, int32_t, int, int OC, int, int, + BiasVisitor) = kern_activation_u4; dim3 grids{0, 0, 0}; dim3 blocks{0, 0, 0}; - get_launch_config(reinterpret_cast(fptr), wo, ho, - batch_size * co / 8, blocks, grids); + get_launch_config( + reinterpret_cast(fptr), wo, ho, batch_size * co / 8, blocks, + grids); kern_activation_u4<<>>( - dst, zp_data, zp_filter, zp_data_filter, batch_size, co, ho, wo, - visitor); + dst, zp_data, zp_filter, zp_data_filter, batch_size, co, ho, wo, visitor); after_kernel_launch(); } -#define INST(_op) \ - template void do_dispatch_activation_u4<_op>( \ - int32_t * dst, BiasVisitor visitor, const int32_t* zp_data, \ - const int32_t* zp_filter, int32_t zp_data_filter, int batch_size, \ - int co, int ho, int wo, cudaStream_t stream); +#define INST(_op) \ + template void do_dispatch_activation_u4<_op>( \ + int32_t * dst, BiasVisitor visitor, const int32_t* zp_data, \ + const int32_t* zp_filter, int32_t zp_data_filter, int batch_size, int co, \ + int ho, int wo, cudaStream_t stream); INST(ActivationRELU); INST(ActivationIdentity); diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cuh b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cuh index 162f031f..c46bf70a 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cuh +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/activation_u4.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -39,8 +40,8 @@ namespace megdnn { namespace cuda { namespace activation_u4 { -void get_launch_config(const void* kern, int dimx, int dimy, int dimz, - dim3& blocks, dim3& grids); +void get_launch_config( + const void* kern, int dimx, int dimy, int dimz, dim3& blocks, dim3& grids); struct BiasVisitor { const int32_t* bias_ptr; @@ -49,10 +50,8 @@ struct BiasVisitor { int height_stride; int width_stride; #ifdef MEGDNN_CC_CUDA - __host__ __device__ __forceinline__ const int32_t* ptr(int batch, - int oc_blk, int oh, - int ow, - int oc_remain) { + __host__ __device__ __forceinline__ const int32_t* ptr( + int batch, int oc_blk, int oh, int ow, int oc_remain) { return bias_ptr + batch * batch_stride + oc_blk * channel_stride + oh * height_stride + ow * width_stride + oc_remain; } @@ -74,18 +73,16 @@ struct ActivationRELU { struct ActivationIdentity { #ifdef MEGDNN_CC_CUDA - __host__ __device__ __forceinline__ static int4 apply(int4 in) { - return in; - } + __host__ __device__ __forceinline__ static int4 apply(int4 in) { return in; } #endif }; } // namespace activation_u4 template -void do_dispatch_activation_u4(int32_t* dst, activation_u4::BiasVisitor visitor, - const int32_t* zp_data, const int32_t* zp_filter, - int32_t zp_data_filter, int batch_size, int co, - int ho, int wo, cudaStream_t stream); +void do_dispatch_activation_u4( + int32_t* dst, activation_u4::BiasVisitor visitor, const int32_t* zp_data, + const int32_t* zp_filter, int32_t zp_data_filter, int batch_size, int co, + int ho, int wo, cudaStream_t stream); } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cu b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cu index 99853cf2..6041f267 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cu +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cu @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** @@ -43,35 +44,36 @@ using namespace wmma_conv_integer_subbyte; namespace { -template +template < + typename ConvConfig, size_t thread_blk_x, size_t thread_blk_y, + size_t pixels_per_thread_x, size_t pixels_per_thread_y> struct TileCounter { - MEGDNN_STATIC_ASSERT(thread_blk_x % WARP_SIZE == 0, - "thread block size in dim x not divided by warpSize"); + MEGDNN_STATIC_ASSERT( + thread_blk_x % WARP_SIZE == 0, + "thread block size in dim x not divided by warpSize"); static const size_t spatial_tile_x = thread_blk_x * pixels_per_thread_x; static const size_t spatial_tile_y = thread_blk_y * pixels_per_thread_y; static const size_t global_load_tile_x = (spatial_tile_x - 1) * ConvConfig::SW + ConvConfig::FW; static const size_t global_load_tile_y = (spatial_tile_y - 1) * ConvConfig::SH + ConvConfig::FH; - static const size_t reg_cache_x = - (global_load_tile_x + WARP_SIZE - 1) / WARP_SIZE; - static const size_t warps_per_block = - (thread_blk_x * thread_blk_y) / WARP_SIZE; + static const size_t reg_cache_x = (global_load_tile_x + WARP_SIZE - 1) / WARP_SIZE; + static const size_t warps_per_block = (thread_blk_x * thread_blk_y) / WARP_SIZE; static const size_t reg_cache_y = (global_load_tile_y + warps_per_block - 1) / warps_per_block; static const size_t smem_stride = global_load_tile_x + (global_load_tile_x % 2 == 0); }; -template +template < + typename ConvConfig_, size_t thread_blk_x, size_t thread_blk_y, + size_t pixels_per_thread_x, size_t pixels_per_thread_y> __global__ void reduce_in_spatial_block_and_along_input_channel_with_scale_u4( - int32_t* __restrict__ dst, const uint8_t* __restrict__ src, int IC, - int IH, int IW, int OH, int OW, int PH, int PW, int32_t scale, - int32_t zero) { - typedef TileCounter + int32_t* __restrict__ dst, const uint8_t* __restrict__ src, int IC, int IH, + int IW, int OH, int OW, int PH, int PW, int32_t scale, int32_t zero) { + typedef TileCounter< + ConvConfig_, thread_blk_x, thread_blk_y, pixels_per_thread_x, + pixels_per_thread_y> TileCounter_; const int bidx = blockIdx.x; @@ -87,12 +89,11 @@ __global__ void reduce_in_spatial_block_and_along_input_channel_with_scale_u4( const uint8_t* __restrict__ sptr = src + bidz * IC * IH * IW / 2 + (ih_base * IW + iw_base) * 4; - __shared__ uint8_t smem[TileCounter_::global_load_tile_y] - [TileCounter_::smem_stride * 4]; + __shared__ uint8_t + smem[TileCounter_::global_load_tile_y][TileCounter_::smem_stride * 4]; uint32_t reg_cache[TileCounter_::reg_cache_y][TileCounter_::reg_cache_x]; int32_t acc[pixels_per_thread_y][pixels_per_thread_x]; - int32_t* __restrict__ dptr = - dst + bidz * OH * OW + ow_start + oh_start * OW; + int32_t* __restrict__ dptr = dst + bidz * OH * OW + ow_start + oh_start * OW; const int tid = tidy * thread_blk_x + tidx; const int idx_in_warp = tid % WARP_SIZE; @@ -147,10 +148,9 @@ __global__ void reduce_in_spatial_block_and_along_input_channel_with_scale_u4( for (int j = 0; j < TileCounter_::reg_cache_x; ++j) { int iw = idx_in_warp + j * WARP_SIZE; int ih = warp_id + i * TileCounter_::warps_per_block; - if (ih_base + ih >= 0 && ih_base + ih < IH && - iw_base + iw >= 0 && iw_base + iw < IW) { - reg_cache[i][j] = - *(const uint32_t*)(&sptr[(ih * IW + iw) * 4]); + if (ih_base + ih >= 0 && ih_base + ih < IH && iw_base + iw >= 0 && + iw_base + iw < IW) { + reg_cache[i][j] = *(const uint32_t*)(&sptr[(ih * IW + iw) * 4]); } else { reg_cache[i][j] = zero; } @@ -168,8 +168,7 @@ __global__ void 
reduce_in_spatial_block_and_along_input_channel_with_scale_u4( for (int fh = 0; fh < ConvConfig_::FH; ++fh) { #pragma unroll for (int fw = 0; fw < ConvConfig_::FW; ++fw) { - uint32_t sdata = - *(uint32_t*)(&smem[y + fh][(x + fw) * 4]); + uint32_t sdata = *(uint32_t*)(&smem[y + fh][(x + fw) * 4]); #pragma unroll for (int r = 0; r < 8; r++) { uint8_t val = (sdata & 0xF); @@ -212,8 +211,9 @@ __global__ void reduce_in_spatial_block_and_along_input_channel_with_scale_u4( } } -template +template < + typename ConvConfig, size_t thread_blk_x, size_t thread_blk_y, + size_t pixels_per_thread_x, size_t pixels_per_thread_y> struct LargeChannelTileCounter { static const size_t spatial_tile_x = thread_blk_x * pixels_per_thread_x; static const size_t spatial_tile_y = pixels_per_thread_y; @@ -221,13 +221,10 @@ struct LargeChannelTileCounter { (spatial_tile_x - 1) * ConvConfig::SW + ConvConfig::FW; static const size_t global_load_tile_y = (spatial_tile_y - 1) * ConvConfig::SH + ConvConfig::FH; - static const size_t reg_cache_x = - (global_load_tile_x + WARP_SIZE - 1) / WARP_SIZE; - static const size_t warps_per_block = - (thread_blk_x * thread_blk_y) / WARP_SIZE; + static const size_t reg_cache_x = (global_load_tile_x + WARP_SIZE - 1) / WARP_SIZE; + static const size_t warps_per_block = (thread_blk_x * thread_blk_y) / WARP_SIZE; static const size_t reg_cache_y = - (global_load_tile_y * thread_blk_y + warps_per_block - 1) / - warps_per_block; + (global_load_tile_y * thread_blk_y + warps_per_block - 1) / warps_per_block; static const size_t smem_stride = global_load_tile_x + (global_load_tile_x % 2 == 0); static const size_t reduce_dim_0 = thread_blk_y; @@ -235,15 +232,16 @@ struct LargeChannelTileCounter { static const size_t reduce_dim_2 = thread_blk_x * pixels_per_thread_x; }; -template +template < + typename ConvConfig_, size_t thread_blk_x, size_t thread_blk_y, + size_t pixels_per_thread_x, size_t pixels_per_thread_y> __global__ void reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( - int32_t* __restrict__ dst, const uint8_t* __restrict__ src, int IC, - int IH, int IW, int OH, int OW, int PH, int PW, int32_t scale, - int32_t zero) { - typedef LargeChannelTileCounter + int32_t* __restrict__ dst, const uint8_t* __restrict__ src, int IC, int IH, + int IW, int OH, int OW, int PH, int PW, int32_t scale, int32_t zero) { + typedef LargeChannelTileCounter< + ConvConfig_, thread_blk_x, thread_blk_y, pixels_per_thread_x, + pixels_per_thread_y> TileCounter_; const int bidx = blockIdx.x; @@ -251,8 +249,8 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( const int tidx = threadIdx.x; const int tidy = threadIdx.y; - const int blocks_per_row = (OW + TileCounter_::spatial_tile_x - 1) / - TileCounter_::spatial_tile_x; + const int blocks_per_row = + (OW + TileCounter_::spatial_tile_x - 1) / TileCounter_::spatial_tile_x; const int bidw = bidx % blocks_per_row; const int bidh = bidx / blocks_per_row; @@ -265,14 +263,12 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( __shared__ uint8_t smem[thread_blk_y][TileCounter_::global_load_tile_y] [TileCounter_::smem_stride * 4]; - __shared__ int32_t - s_reduce[TileCounter_::reduce_dim_0][TileCounter_::reduce_dim_1] - [TileCounter_::reduce_dim_2 + 1]; + __shared__ int32_t s_reduce[TileCounter_::reduce_dim_0][TileCounter_::reduce_dim_1] + [TileCounter_::reduce_dim_2 + 1]; uint32_t reg_cache[TileCounter_::reg_cache_y][TileCounter_::reg_cache_x]; int32_t 
acc[pixels_per_thread_y][pixels_per_thread_x]; - int32_t* __restrict__ dptr = - dst + bidz * OH * OW + ow_start + oh_start * OW; + int32_t* __restrict__ dptr = dst + bidz * OH * OW + ow_start + oh_start * OW; const int tid = tidy * thread_blk_x + tidx; const int idx_in_warp = tid % WARP_SIZE; @@ -299,10 +295,10 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( iw_base + iw < IW) { reg_cache[i][j] = 0; if (ic_blk < ic_blks) - reg_cache[i][j] = - *(const uint32_t*)(&sptr[(ic_blk * IH * IW + - ih * IW + iw) * - 4]); + reg_cache[i][j] = *(const uint32_t*)(&sptr + [(ic_blk * IH * IW + + ih * IW + iw) * + 4]); } else { reg_cache[i][j] = (ic_blk < ic_blks) ? zero : 0; } @@ -339,14 +335,15 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( int ih = hc % TileCounter_::global_load_tile_y; int ic_blk = hc / TileCounter_::global_load_tile_y; int g_ic_blk = ic_blk + c * thread_blk_y; - if (ih_base + ih >= 0 && ih_base + ih < IH && - iw_base + iw >= 0 && iw_base + iw < IW) { + if (ih_base + ih >= 0 && ih_base + ih < IH && iw_base + iw >= 0 && + iw_base + iw < IW) { reg_cache[i][j] = 0; if (g_ic_blk < ic_blks) reg_cache[i][j] = - *(const uint32_t*)(&sptr[(ic_blk * IH * IW + - ih * IW + iw) * - 4]); + *(const uint32_t*)(&sptr + [(ic_blk * IH * IW + + ih * IW + iw) * + 4]); } else { reg_cache[i][j] = (g_ic_blk < ic_blks) ? zero : 0; } @@ -387,10 +384,8 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( int hc = warp_id + i * TileCounter_::warps_per_block; int ih = hc % TileCounter_::global_load_tile_y; int ic_blk = hc / TileCounter_::global_load_tile_y; - if (ic_blk < thread_blk_y && - x < TileCounter_::global_load_tile_x) { - *(uint32_t*)(&smem[ic_blk][ih][x * 4]) = - reg_cache[i][j]; + if (ic_blk < thread_blk_y && x < TileCounter_::global_load_tile_x) { + *(uint32_t*)(&smem[ic_blk][ih][x * 4]) = reg_cache[i][j]; } } } @@ -434,8 +429,7 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( int x = j * thread_blk_x + tidx; int y = i; if (oh_start + y < OH && ow_start + x < OW) { - dptr[y * OW + x] = - s_reduce[0][i][tidx + j * thread_blk_x] * scale; + dptr[y * OW + x] = s_reduce[0][i][tidx + j * thread_blk_x] * scale; } } } @@ -445,9 +439,9 @@ reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels( } // namespace void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( - int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, - int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, - int32_t scale, uint8_t zp_data, cudaStream_t stream) { + int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, int oh, + int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, int32_t scale, + uint8_t zp_data, cudaStream_t stream) { zp_data = (zp_data << 4) | zp_data; int32_t zero = (zp_data << 24) | (zp_data << 16) | (zp_data << 8) | zp_data; if (fh == 3 && fw == 3 && sh == 1 && sw == 1) { @@ -458,8 +452,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( constexpr size_t pixels_per_thread_x_ = 4; constexpr size_t pixels_per_thread_y_ = 2; - typedef TileCounter + typedef TileCounter< + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -475,10 +470,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; reduce_in_spatial_block_and_along_input_channel_with_scale_u4< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, 
pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, oh, - ow, ph, pw, scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } else { if (iw <= 32) { constexpr size_t thread_blk_x_ = WARP_SIZE / 2; @@ -487,8 +481,8 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( constexpr size_t pixels_per_thread_y_ = 4; typedef LargeChannelTileCounter< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -504,11 +498,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, - oh, ow, ph, pw, - scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } else { constexpr size_t thread_blk_x_ = WARP_SIZE / 2; constexpr size_t thread_blk_y_ = 4; @@ -516,8 +508,8 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( constexpr size_t pixels_per_thread_y_ = 4; typedef LargeChannelTileCounter< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -533,11 +525,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, - oh, ow, ph, pw, - scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } } } else if (fh == 5 && fw == 5 && sh == 1 && sw == 1) { @@ -548,8 +538,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( constexpr size_t pixels_per_thread_x_ = 4; constexpr size_t pixels_per_thread_y_ = 2; - typedef TileCounter + typedef TileCounter< + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -565,10 +556,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; reduce_in_spatial_block_and_along_input_channel_with_scale_u4< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, oh, - ow, ph, pw, scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } else { if (iw <= 32) { constexpr size_t thread_blk_x_ = WARP_SIZE / 2; @@ -577,8 +567,8 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( constexpr size_t pixels_per_thread_y_ = 4; typedef LargeChannelTileCounter< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -594,11 +584,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; 
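At the top of this dispatcher the u4 zero point is widened into a full 32-bit word (zp_data = (zp_data << 4) | zp_data, then that byte replicated four times into zero), so the kernels' out-of-bounds fallback reg_cache[i][j] = zero reads as eight zero-point nibbles and a fully padded pixel contributes exactly 8 * zp to the per-pixel sum before scaling. A small host-side sketch of that packing and of the nibble summation the reduce kernels perform; the helper names are mine, not identifiers from the patch:

#include <cstdint>
#include <cstdio>

// Replicate a u4 zero point into a 32-bit word: eight identical nibbles,
// mirroring the zp_data/zero setup in the dispatcher.
static uint32_t broadcast_zero_point_u4(uint8_t zp) {
    uint8_t b = static_cast<uint8_t>((zp << 4) | (zp & 0xF));
    return (uint32_t(b) << 24) | (uint32_t(b) << 16) | (uint32_t(b) << 8) | uint32_t(b);
}

// Sum the eight u4 values packed in one 32-bit word, as the kernels' inner
// `for (int r = 0; r < 8; r++)` loop does before applying `scale`.
static int32_t sum_packed_u4(uint32_t word) {
    int32_t acc = 0;
    for (int r = 0; r < 8; ++r) {
        acc += static_cast<int32_t>(word & 0xFu);
        word >>= 4;
    }
    return acc;
}

int main() {
    uint8_t zp = 7;
    uint32_t zero = broadcast_zero_point_u4(zp);
    // A fully padded pixel (the reg_cache[i][j] = zero fallback) adds 8 * zp.
    printf("packed sum = %d, expected = %d\n", sum_packed_u4(zero), 8 * zp);
    return 0;
}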
reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, - oh, ow, ph, pw, - scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } else { constexpr size_t thread_blk_x_ = WARP_SIZE / 2; @@ -607,8 +595,8 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( constexpr size_t pixels_per_thread_y_ = 4; typedef LargeChannelTileCounter< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -624,11 +612,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, - oh, ow, ph, pw, - scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } } } else if (fh == 7 && fw == 7 && sh == 1 && sw == 1) { @@ -639,8 +625,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( constexpr size_t pixels_per_thread_x_ = 4; constexpr size_t pixels_per_thread_y_ = 2; - typedef TileCounter + typedef TileCounter< + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -656,19 +643,18 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; reduce_in_spatial_block_and_along_input_channel_with_scale_u4< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, oh, - ow, ph, pw, scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } else { constexpr size_t thread_blk_x_ = WARP_SIZE / 2; constexpr size_t thread_blk_y_ = 8; constexpr size_t pixels_per_thread_x_ = 1; constexpr size_t pixels_per_thread_y_ = 4; - typedef LargeChannelTileCounter + typedef LargeChannelTileCounter< + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_> TileCounter_; dim3 gridDim; @@ -684,10 +670,9 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_data_u4( gridDim.z = batch_size; reduce_in_spatial_block_and_along_input_channel_with_scale_u4_large_channels< - ConvConfig_, thread_blk_x_, thread_blk_y_, - pixels_per_thread_x_, pixels_per_thread_y_> - <<>>(dst, src, ic, ih, iw, oh, - ow, ph, pw, scale, zero); + ConvConfig_, thread_blk_x_, thread_blk_y_, pixels_per_thread_x_, + pixels_per_thread_y_><<>>( + dst, src, ic, ih, iw, oh, ow, ph, pw, scale, zero); } } after_kernel_launch(); diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cuh b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cuh index 80b00563..28d90f09 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cuh +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/reduce_with_scale_data.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** @@ -38,9 +39,9 @@ namespace megdnn { namespace cuda { void do_dispatch_reduce_with_scale_data_u4( - int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, - int oh, int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, - int32_t scale, uint8_t zp_data, cudaStream_t stream); + int32_t* dst, const uint8_t* src, int batch_size, int ih, int iw, int oh, + int ow, int ph, int pw, int fh, int fw, int sh, int sw, int ic, int32_t scale, + uint8_t zp_data, cudaStream_t stream); } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4.cuh b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4.cuh index 78b54b0a..c0d4983a 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4.cuh +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -58,21 +59,16 @@ struct ConvConfig { static int const SW = SW_; }; -void _do_wmma_conv_integer_subbyte_1xfw(const uint8_t* d_data, - const uint8_t* d_filter, int32_t* d_out, - uint8_t* workspace, int batch_size, - int hi, int wi, int ho, int wo, int ph, - int pw, int ci, int co, int fh, int fw, - int sh, int sw, uint8_t zp_data, - cudaStream_t stream); +void _do_wmma_conv_integer_subbyte_1xfw( + const uint8_t* d_data, const uint8_t* d_filter, int32_t* d_out, + uint8_t* workspace, int batch_size, int hi, int wi, int ho, int wo, int ph, + int pw, int ci, int co, int fh, int fw, int sh, int sw, uint8_t zp_data, + cudaStream_t stream); -void _do_wmma_conv_integer_subbyte_fhxfw(const uint8_t* d_data, - const uint8_t* d_filter, - int32_t* d_out, int batch_size, int hi, - int wi, int ho, int wo, int ph, int pw, - int ci, int co, int fh, int fw, int sh, - int sw, uint8_t zp_data, - cudaStream_t stream); +void _do_wmma_conv_integer_subbyte_fhxfw( + const uint8_t* d_data, const uint8_t* d_filter, int32_t* d_out, int batch_size, + int hi, int wi, int ho, int wo, int ph, int pw, int ci, int co, int fh, int fw, + int sh, int sw, uint8_t zp_data, cudaStream_t stream); } // namespace wmma_conv_integer_subbyte } // namespace cuda diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_1xfw.cu b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_1xfw.cu index bea3036a..9040add0 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_1xfw.cu +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_1xfw.cu @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. 
+ * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** @@ -47,8 +48,9 @@ using namespace wmma_conv_integer_subbyte; namespace wmma_conv_integer_subbyte_1xfw { -template +template < + int WARPS_W_, int WARPS_OC_, int OUT_CHANNELS_PER_WARP_, int OH_PER_WARP_, + int IC_UNROLL_SIZE_> struct BlockConfig { static int const WARPS_W = WARPS_W_; static int const WARPS_OC = WARPS_OC_; @@ -111,10 +113,9 @@ struct ConvDataGlobal2ShareMemVisitor { copy_t reg_cache[DataCount::LANES_PER_WARP]; - __device__ ConvDataGlobal2ShareMemVisitor(uint8_t* smem, - const uint8_t* g_ptr, int IH, - int IW, int b_ih, int b_iw, - copy_t zero) + __device__ ConvDataGlobal2ShareMemVisitor( + uint8_t* smem, const uint8_t* g_ptr, int IH, int IW, int b_ih, int b_iw, + copy_t zero) : smem{smem}, g_ptr{g_ptr}, IH{IH}, @@ -137,10 +138,8 @@ struct ConvDataGlobal2ShareMemVisitor { int row = i * BlockConfig_::WARPS_PER_BLOCK + warp_id; int ci_idx = row / DataCount_::LANES_PER_SLICE; int hi_idx = row - ci_idx * DataCount_::LANES_PER_SLICE; - if (idx % ConvConfig_::FH != 0 && - hi_idx < BlockConfig_::OH_PER_WARP - 1) { - int y = (hi_idx + - 1) * DataCount::LANE_SIZE + + if (idx % ConvConfig_::FH != 0 && hi_idx < BlockConfig_::OH_PER_WARP - 1) { + int y = (hi_idx + 1) * DataCount::LANE_SIZE + tid_in_warp; int x = ci_idx * 8; if (tid_in_warp < DataCount_::LANE_SIZE) @@ -148,12 +147,12 @@ struct ConvDataGlobal2ShareMemVisitor { } else { bool cond = ((b_iw + tid_in_warp) >= 0) && ((b_iw + tid_in_warp) < IW) && - ((b_ih_base + hi_idx) >= 0) && - ((b_ih_base + hi_idx) < IH); + ((b_ih_base + hi_idx) >= 0) && ((b_ih_base + hi_idx) < IH); if (cond) { - copy_t val = *(copy_t*)(&g_ptr[(ci_idx * ci_stride + - hi_idx * hi_stride + col) / - 2]); + copy_t val = *(copy_t*)(&g_ptr + [(ci_idx * ci_stride + + hi_idx * hi_stride + col) / + 2]); reg_cache[i] = val; } else { reg_cache[i] = zero; @@ -164,18 +163,15 @@ struct ConvDataGlobal2ShareMemVisitor { __device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; - i < DataCount::LANES_PER_WARP; ++i) { + for (int i = 0; i < DataCount::LANES_PER_WARP; ++i) { if (tid_in_warp < DataCount::LANE_SIZE) { int row = i * BlockConfig_::WARPS_PER_BLOCK + warp_id; int ci_idx = - row / - DataCount::LANES_PER_SLICE; + row / DataCount::LANES_PER_SLICE; int hi_idx = - row - ci_idx * DataCount::LANES_PER_SLICE; - int y = hi_idx * DataCount::LANE_SIZE + + row - + ci_idx * DataCount::LANES_PER_SLICE; + int y = hi_idx * DataCount::LANE_SIZE + tid_in_warp; int x = ci_idx * 8; *(copy_t*)(get_smem_ptr(y, x)) = reg_cache[i]; @@ -184,19 +180,16 @@ struct ConvDataGlobal2ShareMemVisitor { } __device__ __forceinline__ uint8_t* get_smem_ptr(int y, int x) { - return &smem[(y * DataCount::SMEM_DATA_STRIDE + - x) / - 2]; + return &smem + [(y * DataCount::SMEM_DATA_STRIDE + x) / 2]; } - __device__ __forceinline__ void inc_stage() { + __device__ __forceinline__ void inc_stage() { idx++; - g_ptr += idx % ConvConfig_::FH == 0 - ? (BlockConfig_::IC_BLKS * ci_stride - - (ConvConfig_::FH - 1) * hi_stride) / - 2 - : hi_stride / 2; + g_ptr += idx % ConvConfig_::FH == 0 ? 
(BlockConfig_::IC_BLKS * ci_stride - + (ConvConfig_::FH - 1) * hi_stride) / + 2 + : hi_stride / 2; } }; @@ -215,28 +208,23 @@ struct ConvFilterGlobal2ShareMemVisitor { copy_t reg_cache[FilterCount::REG_FILTER_ROW] [FilterCount::REG_FILTER_COL]; - __device__ ConvFilterGlobal2ShareMemVisitor(uint8_t* smem, - const uint8_t* g_ptr, - int co_stride, int co_remain) - : smem{smem}, - g_ptr{g_ptr}, - co_stride{co_stride}, - co_remain{co_remain} {} + __device__ ConvFilterGlobal2ShareMemVisitor( + uint8_t* smem, const uint8_t* g_ptr, int co_stride, int co_remain) + : smem{smem}, g_ptr{g_ptr}, co_stride{co_stride}, co_remain{co_remain} {} __device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; - i < FilterCount::REG_FILTER_ROW; ++i) { + for (int i = 0; i < FilterCount::REG_FILTER_ROW; + ++i) { #pragma unroll - for (int j = 0; - j < FilterCount::REG_FILTER_COL; + for (int j = 0; j < FilterCount::REG_FILTER_COL; ++j) { int y = BlockConfig_::WARPS_PER_BLOCK * i + warp_id; int x = WARP_SIZE * j + tid_in_warp; bool valid = (y < - FilterCount::OUT_CHANNELS_PER_BLOCK) && + FilterCount< + ConvConfig_, BlockConfig_>::OUT_CHANNELS_PER_BLOCK) && (x < BlockConfig_::IC_BLKS * ConvConfig_::FW) && (y < co_remain); if (valid) { @@ -251,18 +239,17 @@ struct ConvFilterGlobal2ShareMemVisitor { __device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; - i < FilterCount::REG_FILTER_ROW; ++i) { + for (int i = 0; i < FilterCount::REG_FILTER_ROW; + ++i) { #pragma unroll - for (int j = 0; - j < FilterCount::REG_FILTER_COL; + for (int j = 0; j < FilterCount::REG_FILTER_COL; ++j) { int y = BlockConfig_::WARPS_PER_BLOCK * i + warp_id; int x = WARP_SIZE * j + tid_in_warp; bool bounds = (y < - FilterCount::OUT_CHANNELS_PER_BLOCK) && + FilterCount< + ConvConfig_, BlockConfig_>::OUT_CHANNELS_PER_BLOCK) && (x < BlockConfig_::IC_BLKS * ConvConfig_::FW); copy_t val = reg_cache[i][j]; if (bounds) @@ -272,10 +259,9 @@ struct ConvFilterGlobal2ShareMemVisitor { } __device__ __forceinline__ uint8_t* get_smem_ptr(int y, int x) { - return &smem[(y * FilterCount::SMEM_FILTER_STRIDE + - x) / - 2]; + return &smem + [(y * FilterCount::SMEM_FILTER_STRIDE + x) / + 2]; } __device__ __forceinline__ void inc_stage() { @@ -284,19 +270,19 @@ struct ConvFilterGlobal2ShareMemVisitor { }; template -__device__ inline void -calc(wmma::fragment - data_frag[OH_PER_WARP], - wmma::fragment - filter_frag[OUT_CHANNELS_PER_WARP], - wmma::fragment - acc_frag[OUT_CHANNELS_PER_WARP][OH_PER_WARP]) { +__device__ inline void calc( + wmma::fragment + data_frag[OH_PER_WARP], + wmma::fragment + filter_frag[OUT_CHANNELS_PER_WARP], + wmma::fragment + acc_frag[OUT_CHANNELS_PER_WARP][OH_PER_WARP]) { #pragma unroll for (int i = 0; i < OUT_CHANNELS_PER_WARP; ++i) { #pragma unroll for (int j = 0; j < OH_PER_WARP; ++j) { - wmma::mma_sync(acc_frag[i][j], filter_frag[i], data_frag[j], - acc_frag[i][j]); + wmma::mma_sync( + acc_frag[i][j], filter_frag[i], data_frag[j], acc_frag[i][j]); } } } @@ -307,11 +293,9 @@ struct enable_kernel_partial_spec; template struct enable_kernel_partial_spec { static __device__ inline void load_share_mem( - wmma::fragment + wmma::fragment data_frag[BlockConfig_::OH_PER_WARP], - wmma::fragment + wmma::fragment filter_frag[BlockConfig_::OUT_CHANNELS_PER_WARP], ConvDataGlobal2ShareMemVisitor& gbl2smem_data_visitor, @@ -319,36 +303,30 @@ struct enable_kernel_partial_spec { gbl2smem_filter_visitor, int data_spatial_idx, int fw, int ic_blk) { const int warp_y = threadIdx.y; - uint8_t* __restrict__ s_ptr_data = 
gbl2smem_data_visitor.get_smem_ptr( - data_spatial_idx, ic_blk * WMMA_K); - uint8_t* __restrict__ s_ptr_filter = - gbl2smem_filter_visitor.get_smem_ptr( - warp_y * WMMA_M, - fw * WMMA_K * BlockConfig_::IC_UNROLL_SIZE + - ic_blk * WMMA_K); + uint8_t* __restrict__ s_ptr_data = + gbl2smem_data_visitor.get_smem_ptr(data_spatial_idx, ic_blk * WMMA_K); + uint8_t* __restrict__ s_ptr_filter = gbl2smem_filter_visitor.get_smem_ptr( + warp_y * WMMA_M, + fw * WMMA_K * BlockConfig_::IC_UNROLL_SIZE + ic_blk * WMMA_K); #pragma unroll for (int i = 0; i < BlockConfig_::OH_PER_WARP; ++i) { wmma::load_matrix_sync( data_frag[i], - s_ptr_data + - i * - DataCount::LANE_SIZE * - DataCount::SMEM_DATA_STRIDE / - 2, + s_ptr_data + i * DataCount::LANE_SIZE * + DataCount:: + SMEM_DATA_STRIDE / + 2, DataCount::SMEM_DATA_STRIDE); } #pragma unroll for (int j = 0; j < BlockConfig_::OUT_CHANNELS_PER_WARP; ++j) { wmma::load_matrix_sync( filter_frag[j], - s_ptr_filter + - j * WMMA_M * BlockConfig_::WARPS_OC * - FilterCount:: - SMEM_FILTER_STRIDE / - 2, + s_ptr_filter + j * WMMA_M * BlockConfig_::WARPS_OC * + FilterCount:: + SMEM_FILTER_STRIDE / + 2, FilterCount::SMEM_FILTER_STRIDE); } } @@ -359,11 +337,9 @@ struct enable_kernel_partial_spec { gbl2smem_data_visitor, ConvFilterGlobal2ShareMemVisitor& gbl2smem_filter_visitor, - wmma::fragment + wmma::fragment data_frag[2][BlockConfig_::OH_PER_WARP], - wmma::fragment + wmma::fragment filter_frag[2][BlockConfig_::OUT_CHANNELS_PER_WARP], wmma::fragment acc_frag[BlockConfig_::OUT_CHANNELS_PER_WARP] @@ -380,32 +356,29 @@ struct enable_kernel_partial_spec { #pragma unroll for (; loop_count < BlockConfig_::IC_UNROLL_SIZE * ConvConfig_::FW - 1; loop_count++) { - calc(data_frag[loop_count % 2], - filter_frag[loop_count % 2], - acc_frag); + calc( + data_frag[loop_count % 2], filter_frag[loop_count % 2], acc_frag); int fw = (loop_count + 1) / BlockConfig_::IC_UNROLL_SIZE; int ic_blk = (loop_count + 1) % BlockConfig_::IC_UNROLL_SIZE; int data_spatial_idx = data_spatial_idx_base + fw; - load_share_mem(data_frag[(loop_count + 1) % 2], - filter_frag[(loop_count + 1) % 2], - gbl2smem_data_visitor, gbl2smem_filter_visitor, - data_spatial_idx, fw, ic_blk); + load_share_mem( + data_frag[(loop_count + 1) % 2], filter_frag[(loop_count + 1) % 2], + gbl2smem_data_visitor, gbl2smem_filter_visitor, data_spatial_idx, + fw, ic_blk); } calc( - data_frag[(loop_count % 2)], filter_frag[(loop_count % 2)], - acc_frag); + data_frag[(loop_count % 2)], filter_frag[(loop_count % 2)], acc_frag); if (!last_slice) { __syncthreads(); gbl2smem_data_visitor.commit(); gbl2smem_filter_visitor.commit(); __syncthreads(); - load_share_mem(data_frag[0], filter_frag[0], gbl2smem_data_visitor, - gbl2smem_filter_visitor, data_spatial_idx_base, 0, - 0); + load_share_mem( + data_frag[0], filter_frag[0], gbl2smem_data_visitor, + gbl2smem_filter_visitor, data_spatial_idx_base, 0, 0); } } }; @@ -413,8 +386,8 @@ struct enable_kernel_partial_spec { template __global__ void convolution_template_device_u4( const uint8_t* __restrict__ data, const uint8_t* __restrict__ filter, - int32_t* __restrict__ out, int N, int IH, int IW, int OH, int OW, - int PH, int PW, int IC, int OC, int32_t zero) { + int32_t* __restrict__ out, int N, int IH, int IW, int OH, int OW, int PH, + int PW, int IC, int OC, int32_t zero) { typedef enable_kernel_partial_spec caller; constexpr size_t IC_BLKS = BlockConfig_::IC_BLKS; constexpr size_t OUT_CHANNELS_PER_BLOCK = @@ -438,31 +411,26 @@ __global__ void convolution_template_device_u4( const uint8_t* __restrict__ 
g_ptr_data = data + bidz * IC * IH * IW / 2 + (b_ih * IW + b_iw) * 4; const uint8_t* __restrict__ g_ptr_filter = - filter + bidy * OUT_CHANNELS_PER_BLOCK * ConvConfig_::FH * - ConvConfig_::FW * IC / 2; + filter + + bidy * OUT_CHANNELS_PER_BLOCK * ConvConfig_::FH * ConvConfig_::FW * IC / 2; const int co_remain = OC - bidy * OUT_CHANNELS_PER_BLOCK; - int32_t* __restrict__ g_ptr_out = out + bidz * OC * OH * OW + - oc_start * OH * OW + + int32_t* __restrict__ g_ptr_out = out + bidz * OC * OH * OW + oc_start * OH * OW + (b_oh * OW + ow_start) * WMMA_M; + __shared__ uint8_t smem_data[DataCount::SMEM_DATA_ROW] + [DataCount::SMEM_DATA_COL]; __shared__ uint8_t - smem_data[DataCount::SMEM_DATA_ROW] - [DataCount::SMEM_DATA_COL]; - __shared__ uint8_t smem_filter - [FilterCount::SMEM_FILTER_ROW] - [FilterCount::SMEM_FILTER_COL]; - - ConvDataGlobal2ShareMemVisitor - gbl2smem_data_visitor{smem_data[0], g_ptr_data, IH, IW, - b_ih, b_iw, zero}; - ConvFilterGlobal2ShareMemVisitor - gbl2smem_filter_visitor{smem_filter[0], g_ptr_filter, - IC / 2 * ConvConfig_::FH * ConvConfig_::FW, - co_remain}; + smem_filter[FilterCount::SMEM_FILTER_ROW] + [FilterCount::SMEM_FILTER_COL]; + + ConvDataGlobal2ShareMemVisitor gbl2smem_data_visitor{ + smem_data[0], g_ptr_data, IH, IW, b_ih, b_iw, zero}; + ConvFilterGlobal2ShareMemVisitor gbl2smem_filter_visitor{ + smem_filter[0], g_ptr_filter, IC / 2 * ConvConfig_::FH * ConvConfig_::FW, + co_remain}; wmma::fragment - acc_frag[BlockConfig_::OUT_CHANNELS_PER_WARP] - [BlockConfig_::OH_PER_WARP]; + acc_frag[BlockConfig_::OUT_CHANNELS_PER_WARP][BlockConfig_::OH_PER_WARP]; wmma::fragment data_frag[2][BlockConfig_::OH_PER_WARP]; wmma::fragment @@ -483,32 +451,33 @@ __global__ void convolution_template_device_u4( __syncthreads(); - caller::load_share_mem(data_frag[0], filter_frag[0], gbl2smem_data_visitor, - gbl2smem_filter_visitor, warp_x * WMMA_N, 0, 0); + caller::load_share_mem( + data_frag[0], filter_frag[0], gbl2smem_data_visitor, + gbl2smem_filter_visitor, warp_x * WMMA_N, 0, 0); int ic_blocks = (IC / 8 + IC_BLKS - 1) / IC_BLKS * ConvConfig_::FH - 1; #pragma unroll for (int ci_blk = 0; ci_blk < ic_blocks; ci_blk++) { - caller::consume_slice(gbl2smem_data_visitor, - gbl2smem_filter_visitor, data_frag, - filter_frag, acc_frag); + caller::consume_slice( + gbl2smem_data_visitor, gbl2smem_filter_visitor, data_frag, filter_frag, + acc_frag); } - caller::consume_slice(gbl2smem_data_visitor, gbl2smem_filter_visitor, - data_frag, filter_frag, acc_frag); + caller::consume_slice( + gbl2smem_data_visitor, gbl2smem_filter_visitor, data_frag, filter_frag, + acc_frag); // store #pragma unroll for (int i = 0; i < BlockConfig_::OUT_CHANNELS_PER_WARP; ++i) { #pragma unroll for (int j = 0; j < BlockConfig_::OH_PER_WARP; ++j) { - if (b_oh + j < OH && - oc_start + i * BlockConfig_::WARPS_OC * WMMA_M < OC && + if (b_oh + j < OH && oc_start + i * BlockConfig_::WARPS_OC * WMMA_M < OC && ow_start < OW) { - wmma::store_matrix_sync(&g_ptr_out[i * BlockConfig_::WARPS_OC * - WMMA_M * OH * OW + - j * OW * WMMA_M], - acc_frag[i][j], WMMA_M, - wmma::mem_col_major); + wmma::store_matrix_sync( + &g_ptr_out + [i * BlockConfig_::WARPS_OC * WMMA_M * OH * OW + + j * OW * WMMA_M], + acc_frag[i][j], WMMA_M, wmma::mem_col_major); } } } @@ -517,15 +486,14 @@ __global__ void convolution_template_device_u4( template __global__ void convolution_template_device_u4( const uint8_t* __restrict__ /* data */, - const uint8_t* __restrict__ /* filter */, - int32_t* __restrict__ /* out */, int /* N */, int /* IH */, - int /* IW */, 
int /* OH */, int /* OW */, int /* PH */, int /* PW */, - int /* IC */, int /* OC */, int32_t /* zero */) {} + const uint8_t* __restrict__ /* filter */, int32_t* __restrict__ /* out */, + int /* N */, int /* IH */, int /* IW */, int /* OH */, int /* OW */, + int /* PH */, int /* PW */, int /* IC */, int /* OC */, int32_t /* zero */) {} #endif -__global__ void reorder_kernel(const uint32_t* __restrict__ src, - uint32_t* __restrict__ dst, int rows, int cols, - int fh, int fw, int ic_blks) { +__global__ void reorder_kernel( + const uint32_t* __restrict__ src, uint32_t* __restrict__ dst, int rows, + int cols, int fh, int fw, int ic_blks) { const int tidx = blockIdx.x * blockDim.x + threadIdx.x; const int tidy = blockIdx.y * blockDim.y + threadIdx.y; const uint32_t* __restrict__ sptr = src + tidy * cols + tidx; @@ -546,12 +514,11 @@ __global__ void reorder_kernel(const uint32_t* __restrict__ src, using namespace wmma_conv_integer_subbyte_1xfw; -void megdnn::cuda::wmma_conv_integer_subbyte:: - _do_wmma_conv_integer_subbyte_1xfw( - const uint8_t* d_data, const uint8_t* d_filter, int32_t* d_out, - uint8_t* workspace, int batch_size, int hi, int wi, int ho, - int wo, int ph, int pw, int ci, int co, int fh, int fw, int sh, - int sw, uint8_t zp_data, cudaStream_t stream) { +void megdnn::cuda::wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_1xfw( + const uint8_t* d_data, const uint8_t* d_filter, int32_t* d_out, + uint8_t* workspace, int batch_size, int hi, int wi, int ho, int wo, int ph, + int pw, int ci, int co, int fh, int fw, int sh, int sw, uint8_t zp_data, + cudaStream_t stream) { cuda_check(cudaDeviceSetCacheConfig(cudaFuncCachePreferShared)); cuda_check(cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte)); zp_data = (zp_data << 4) | zp_data; @@ -563,8 +530,8 @@ void megdnn::cuda::wmma_conv_integer_subbyte:: int by = (co + ty - 1) / ty; reorder_kernel<<>>( reinterpret_cast(d_filter), - reinterpret_cast(workspace), co, ci * fh * fw / 8, - fh, fw, ic_blks); + reinterpret_cast(workspace), co, ci * fh * fw / 8, fh, fw, + ic_blks); }; if (fh == 3 && fw == 3 && sh == 1 && sw == 1) { @@ -590,17 +557,17 @@ void megdnn::cuda::wmma_conv_integer_subbyte:: gridDim.y = blocks_per_out_channel; gridDim.z = batch_size; - typedef BlockConfig + typedef BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, ic_unroll_size> BlockConfig_; _do_dispatch_reorder_kernel(BlockConfig_::IC_BLKS); convolution_template_device_u4< ConvConfig<3, 3, 1, 1>, - BlockConfig> - <<>>(d_data, workspace, d_out, - batch_size, hi, wi, ho, wo, - ph, pw, ci, co, zero); + BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, + ic_unroll_size>><<>>( + d_data, workspace, d_out, batch_size, hi, wi, ho, wo, ph, pw, ci, co, + zero); } else if (fh == 5 && fw == 5 && sh == 1 && sw == 1) { constexpr size_t warps_w = 2; @@ -625,17 +592,17 @@ void megdnn::cuda::wmma_conv_integer_subbyte:: gridDim.y = blocks_per_out_channel; gridDim.z = batch_size; - typedef BlockConfig + typedef BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, ic_unroll_size> BlockConfig_; _do_dispatch_reorder_kernel(BlockConfig_::IC_BLKS); convolution_template_device_u4< ConvConfig<5, 5, 1, 1>, - BlockConfig> - <<>>(d_data, workspace, d_out, - batch_size, hi, wi, ho, wo, - ph, pw, ci, co, zero); + BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, + ic_unroll_size>><<>>( + d_data, workspace, d_out, batch_size, hi, wi, ho, wo, ph, pw, ci, co, + zero); } else if (fh == 7 && fw == 7 && sh == 1 && 
sw == 1) { constexpr size_t warps_w = 2; constexpr size_t warps_oc = 4; @@ -659,17 +626,17 @@ void megdnn::cuda::wmma_conv_integer_subbyte:: gridDim.y = blocks_per_out_channel; gridDim.z = batch_size; - typedef BlockConfig + typedef BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, ic_unroll_size> BlockConfig_; _do_dispatch_reorder_kernel(BlockConfig_::IC_BLKS); convolution_template_device_u4< ConvConfig<7, 7, 1, 1>, - BlockConfig> - <<>>(d_data, workspace, d_out, - batch_size, hi, wi, ho, wo, - ph, pw, ci, co, zero); + BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, + ic_unroll_size>><<>>( + d_data, workspace, d_out, batch_size, hi, wi, ho, wo, ph, pw, ci, co, + zero); } after_kernel_launch(); } diff --git a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_fhxfw.cu b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_fhxfw.cu index d63d20e0..8b0d56e7 100644 --- a/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_fhxfw.cu +++ b/dnn/src/cuda/conv_bias/quint4x4x32_wmma/wmma_conv_integer_u4_fhxfw.cu @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -48,8 +49,9 @@ using namespace wmma_conv_integer_subbyte; namespace wmma_conv_integer_subbyte_fhxfw { -template +template < + int WARPS_W_, int WARPS_OC_, int OUT_CHANNELS_PER_WARP_, int OH_PER_WARP_, + int IC_UNROLL_SIZE_> struct BlockConfig { static int const WARPS_W = WARPS_W_; static int const WARPS_OC = WARPS_OC_; @@ -84,21 +86,17 @@ struct FilterCount { WMMA_M * BlockConfig::WARPS_OC * BlockConfig::OUT_CHANNELS_PER_WARP; static int const SMEM_FILTER_ROW = OUT_CHANNELS_PER_BLOCK; static int const SMEM_SKEW = - ((ConvConfig::FH * ConvConfig::FW * BlockConfig::IC_UNROLL_SIZE) % - 2 == - 0) * + ((ConvConfig::FH * ConvConfig::FW * BlockConfig::IC_UNROLL_SIZE) % 2 == 0) * SKEW; static int const SMEM_FILTER_COL = - (BlockConfig::IC_BLKS * ConvConfig::FH * ConvConfig::FW * 8 + - SMEM_SKEW) / + (BlockConfig::IC_BLKS * ConvConfig::FH * ConvConfig::FW * 8 + SMEM_SKEW) / 2; static int const SMEM_FILTER_STRIDE = SMEM_FILTER_COL * 2; static int const REG_FILTER_ROW = (SMEM_FILTER_ROW + BlockConfig::WARPS_PER_BLOCK - 1) / BlockConfig::WARPS_PER_BLOCK; static int const REG_FILTER_COL = - (BlockConfig::IC_BLKS * ConvConfig::FH * ConvConfig::FW + - WARP_SIZE - 1) / + (BlockConfig::IC_BLKS * ConvConfig::FH * ConvConfig::FW + WARP_SIZE - 1) / WARP_SIZE; }; @@ -120,10 +118,9 @@ struct ConvDataGlobal2ShareMemVisitor { copy_t reg_cache[DataCount::LANES_PER_WARP]; - __device__ ConvDataGlobal2ShareMemVisitor(uint8_t* smem, - const uint8_t* g_ptr, int IH, - int IW, int b_ih, int b_iw, - copy_t zero) + __device__ ConvDataGlobal2ShareMemVisitor( + uint8_t* smem, const uint8_t* g_ptr, int IH, int IW, int b_ih, int b_iw, + copy_t zero) : smem{smem}, g_ptr{g_ptr}, b_ih{b_ih}, @@ -140,21 +137,18 @@ struct ConvDataGlobal2ShareMemVisitor { int col = (tid_in_warp << 3); // read input from global memory without boundary check #pragma unroll - for (int i = 0; - i < DataCount::LANES_PER_WARP; ++i) { + for (int i = 0; i < DataCount::LANES_PER_WARP; ++i) { int row = i * BlockConfig_::WARPS_PER_BLOCK + warp_id; - int ci_idx = - row / DataCount::LANES_PER_SLICE; - int hi_idx = - row - ci_idx * DataCount::LANES_PER_SLICE; - bool bounds = ((b_iw + tid_in_warp) >= 0) && - ((b_iw + tid_in_warp) < IW) && + int ci_idx = row / DataCount::LANES_PER_SLICE; + int hi_idx = row - + ci_idx * DataCount::LANES_PER_SLICE; + bool bounds = ((b_iw + tid_in_warp) >= 0) && ((b_iw + tid_in_warp) < IW) && ((b_ih + hi_idx) >= 0) && ((b_ih + hi_idx) < IH); if (bounds) { - copy_t val = *(copy_t*)(&g_ptr[(ci_idx * ci_stride + - hi_idx * hi_stride + col) / - 2]); + copy_t val = *(copy_t*)(&g_ptr + [(ci_idx * ci_stride + + hi_idx * hi_stride + col) / + 2]); reg_cache[i] = val; } else { 
reg_cache[i] = zero; @@ -164,18 +158,15 @@ struct ConvDataGlobal2ShareMemVisitor { __device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; - i < DataCount::LANES_PER_WARP; ++i) { + for (int i = 0; i < DataCount::LANES_PER_WARP; ++i) { if (tid_in_warp < DataCount::LANE_SIZE) { int row = i * BlockConfig_::WARPS_PER_BLOCK + warp_id; int ci_idx = - row / - DataCount::LANES_PER_SLICE; + row / DataCount::LANES_PER_SLICE; int hi_idx = - row - ci_idx * DataCount::LANES_PER_SLICE; - int y = hi_idx * DataCount::LANE_SIZE + + row - + ci_idx * DataCount::LANES_PER_SLICE; + int y = hi_idx * DataCount::LANE_SIZE + tid_in_warp; int x = ci_idx * 8; *(copy_t*)(get_smem_ptr(y, x)) = reg_cache[i]; @@ -184,10 +175,8 @@ struct ConvDataGlobal2ShareMemVisitor { } __device__ __forceinline__ uint8_t* get_smem_ptr(int y, int x) { - return &smem[(y * DataCount::SMEM_DATA_STRIDE + - x) / - 2]; + return &smem + [(y * DataCount::SMEM_DATA_STRIDE + x) / 2]; } __device__ __forceinline__ void inc_stage() { @@ -211,10 +200,8 @@ struct ConvFilterGlobal2ShareMemVisitor { copy_t reg_cache[FilterCount::REG_FILTER_ROW] [FilterCount::REG_FILTER_COL]; - __device__ ConvFilterGlobal2ShareMemVisitor(uint8_t* smem, - const uint8_t* g_ptr, - int co_stride, int co_remain, - int idx) + __device__ ConvFilterGlobal2ShareMemVisitor( + uint8_t* smem, const uint8_t* g_ptr, int co_stride, int co_remain, int idx) : smem{smem}, g_ptr{g_ptr}, co_stride{co_stride}, @@ -222,22 +209,20 @@ struct ConvFilterGlobal2ShareMemVisitor { idx{idx} {} __device__ __forceinline__ void copy() { - int ci_remain = - idx < BlockConfig_::IC_BLKS ? idx : BlockConfig_::IC_BLKS; + int ci_remain = idx < BlockConfig_::IC_BLKS ? idx : BlockConfig_::IC_BLKS; #pragma unroll - for (int i = 0; - i < FilterCount::REG_FILTER_ROW; ++i) { + for (int i = 0; i < FilterCount::REG_FILTER_ROW; + ++i) { #pragma unroll - for (int j = 0; - j < FilterCount::REG_FILTER_COL; + for (int j = 0; j < FilterCount::REG_FILTER_COL; ++j) { int y = BlockConfig_::WARPS_PER_BLOCK * i + warp_id; int x = WARP_SIZE * j + tid_in_warp; bool valid = (x < ci_remain * ConvConfig_::FH * ConvConfig_::FW) && (y < - FilterCount::OUT_CHANNELS_PER_BLOCK) && + FilterCount< + ConvConfig_, BlockConfig_>::OUT_CHANNELS_PER_BLOCK) && (y < co_remain); if (valid) { copy_t val = *(copy_t*)(&g_ptr[y * co_stride + x * 4]); @@ -251,11 +236,10 @@ struct ConvFilterGlobal2ShareMemVisitor { __device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; - i < FilterCount::REG_FILTER_ROW; ++i) { + for (int i = 0; i < FilterCount::REG_FILTER_ROW; + ++i) { #pragma unroll - for (int j = 0; - j < FilterCount::REG_FILTER_COL; + for (int j = 0; j < FilterCount::REG_FILTER_COL; ++j) { int y = BlockConfig_::WARPS_PER_BLOCK * i + warp_id; int x = WARP_SIZE * j + tid_in_warp; @@ -263,15 +247,13 @@ struct ConvFilterGlobal2ShareMemVisitor { int ci_blk = x / (ConvConfig_::FH * ConvConfig_::FW); int ci_inner_blk = (ci_blk & 0x3); int ci_outer_blk = (ci_blk >> 2); - int s_x = ci_outer_blk * IC_BLK * ConvConfig_::FH * - ConvConfig_::FW + + int s_x = ci_outer_blk * IC_BLK * ConvConfig_::FH * ConvConfig_::FW + spatial_idx * IC_BLK + ci_inner_blk; bool bounds = (y < - FilterCount::OUT_CHANNELS_PER_BLOCK) && - (x < BlockConfig_::IC_BLKS * ConvConfig_::FH * - ConvConfig_::FW); + FilterCount< + ConvConfig_, BlockConfig_>::OUT_CHANNELS_PER_BLOCK) && + (x < BlockConfig_::IC_BLKS * ConvConfig_::FH * ConvConfig_::FW); if (bounds) *(copy_t*)get_smem_ptr(y, s_x * 8) = reg_cache[i][j]; } @@ -279,10 +261,9 @@ struct 
ConvFilterGlobal2ShareMemVisitor { } __device__ __forceinline__ uint8_t* get_smem_ptr(int y, int x) { - return &smem[(y * FilterCount::SMEM_FILTER_STRIDE + - x) / - 2]; + return &smem + [(y * FilterCount::SMEM_FILTER_STRIDE + x) / + 2]; } __device__ __forceinline__ void inc_stage() { @@ -293,11 +274,9 @@ struct ConvFilterGlobal2ShareMemVisitor { template __device__ inline void load_share_mem( - wmma::fragment + wmma::fragment data_frag[BlockConfig_::OH_PER_WARP], - wmma::fragment + wmma::fragment filter_frag[BlockConfig_::OUT_CHANNELS_PER_WARP], ConvDataGlobal2ShareMemVisitor& gbl2smem_data_visitor, @@ -305,12 +284,11 @@ __device__ inline void load_share_mem( gbl2smem_filter_visitor, int data_spatial_idx, int filter_spatial_idx, int ic_blk) { const int warp_y = threadIdx.y; - uint8_t* __restrict__ s_ptr_data = gbl2smem_data_visitor.get_smem_ptr( - data_spatial_idx, ic_blk * WMMA_K); + uint8_t* __restrict__ s_ptr_data = + gbl2smem_data_visitor.get_smem_ptr(data_spatial_idx, ic_blk * WMMA_K); uint8_t* __restrict__ s_ptr_filter = gbl2smem_filter_visitor.get_smem_ptr( - warp_y * WMMA_M, - ic_blk * WMMA_K * ConvConfig_::FH * ConvConfig_::FW + - filter_spatial_idx * WMMA_K); + warp_y * WMMA_M, ic_blk * WMMA_K * ConvConfig_::FH * ConvConfig_::FW + + filter_spatial_idx * WMMA_K); #pragma unroll for (int i = 0; i < BlockConfig_::OH_PER_WARP; ++i) { @@ -318,8 +296,7 @@ __device__ inline void load_share_mem( data_frag[i], s_ptr_data + i * DataCount::LANE_SIZE * - DataCount::SMEM_DATA_STRIDE / + DataCount::SMEM_DATA_STRIDE / 2, DataCount::SMEM_DATA_STRIDE); } @@ -329,27 +306,27 @@ __device__ inline void load_share_mem( filter_frag[j], s_ptr_filter + j * WMMA_M * BlockConfig_::WARPS_OC * - FilterCount::SMEM_FILTER_STRIDE / + FilterCount< + ConvConfig_, BlockConfig_>::SMEM_FILTER_STRIDE / 2, FilterCount::SMEM_FILTER_STRIDE); } } template -__device__ inline void -calc(wmma::fragment - data_frag[OH_PER_WARP], - wmma::fragment - filter_frag[OUT_CHANNELS_PER_WARP], - wmma::fragment - acc_frag[OUT_CHANNELS_PER_WARP][OH_PER_WARP]) { +__device__ inline void calc( + wmma::fragment + data_frag[OH_PER_WARP], + wmma::fragment + filter_frag[OUT_CHANNELS_PER_WARP], + wmma::fragment + acc_frag[OUT_CHANNELS_PER_WARP][OH_PER_WARP]) { #pragma unroll for (int i = 0; i < OUT_CHANNELS_PER_WARP; ++i) { #pragma unroll for (int j = 0; j < OH_PER_WARP; ++j) { - wmma::mma_sync(acc_frag[i][j], filter_frag[i], data_frag[j], - acc_frag[i][j]); + wmma::mma_sync( + acc_frag[i][j], filter_frag[i], data_frag[j], acc_frag[i][j]); } } } @@ -360,11 +337,9 @@ __device__ void consume_slice( gbl2smem_data_visitor, ConvFilterGlobal2ShareMemVisitor& gbl2smem_filter_visitor, - wmma::fragment + wmma::fragment data_frag[2][BlockConfig_::OH_PER_WARP], - wmma::fragment + wmma::fragment filter_frag[2][BlockConfig_::OUT_CHANNELS_PER_WARP], wmma::fragment acc_frag[BlockConfig_::OUT_CHANNELS_PER_WARP] @@ -383,27 +358,23 @@ __device__ void consume_slice( BlockConfig_::IC_UNROLL_SIZE * ConvConfig_::FH * ConvConfig_::FW - 1; loop_count++) { calc( - data_frag[loop_count % 2], filter_frag[loop_count % 2], - acc_frag); + data_frag[loop_count % 2], filter_frag[loop_count % 2], acc_frag); - int filter_spatial_idx = - (loop_count + 1) % (ConvConfig_::FH * ConvConfig_::FW); + int filter_spatial_idx = (loop_count + 1) % (ConvConfig_::FH * ConvConfig_::FW); int ic_blk = (loop_count + 1) / (ConvConfig_::FH * ConvConfig_::FW); int fh = filter_spatial_idx / ConvConfig_::FW; int fw = filter_spatial_idx % ConvConfig_::FW; - int data_spatial_idx = - 
data_spatial_idx_base + - fh * DataCount::LANE_SIZE + fw; + int data_spatial_idx = data_spatial_idx_base + + fh * DataCount::LANE_SIZE + + fw; load_share_mem( - data_frag[(loop_count + 1) % 2], - filter_frag[(loop_count + 1) % 2], gbl2smem_data_visitor, - gbl2smem_filter_visitor, data_spatial_idx, filter_spatial_idx, - ic_blk); + data_frag[(loop_count + 1) % 2], filter_frag[(loop_count + 1) % 2], + gbl2smem_data_visitor, gbl2smem_filter_visitor, data_spatial_idx, + filter_spatial_idx, ic_blk); } calc( - data_frag[(loop_count % 2)], filter_frag[(loop_count % 2)], - acc_frag); + data_frag[(loop_count % 2)], filter_frag[(loop_count % 2)], acc_frag); if (!last_slice) { __syncthreads(); gbl2smem_data_visitor.commit(); @@ -473,8 +444,8 @@ __device__ void consume_slice_no_reg_cache( template __global__ void convolution_template_device_u4( const uint8_t* __restrict__ data, const uint8_t* __restrict__ filter, - int32_t* __restrict__ out, int N, int IH, int IW, int OH, int OW, - int PH, int PW, int IC, int OC, int32_t zero) { + int32_t* __restrict__ out, int N, int IH, int IW, int OH, int OW, int PH, + int PW, int IC, int OC, int32_t zero) { constexpr size_t IC_BLKS = BlockConfig_::IC_BLKS; constexpr size_t OUT_CHANNELS_PER_BLOCK = FilterCount::OUT_CHANNELS_PER_BLOCK; @@ -497,36 +468,31 @@ __global__ void convolution_template_device_u4( const uint8_t* __restrict__ g_ptr_data = data + bidz * IC * IH * IW / 2 + (b_ih * IW + b_iw) * 8 / 2; const uint8_t* __restrict__ g_ptr_filter = - filter + bidy * OUT_CHANNELS_PER_BLOCK * ConvConfig_::FH * - ConvConfig_::FW * IC / 2; + filter + + bidy * OUT_CHANNELS_PER_BLOCK * ConvConfig_::FH * ConvConfig_::FW * IC / 2; const int co_remain = OC - bidy * OUT_CHANNELS_PER_BLOCK; - int32_t* __restrict__ g_ptr_out = out + bidz * OC * OH * OW + - oc_start * OH * OW + + int32_t* __restrict__ g_ptr_out = out + bidz * OC * OH * OW + oc_start * OH * OW + (b_oh * OW + ow_start) * WMMA_M; const int icb = IC / 8; + __shared__ uint8_t smem_data[DataCount::SMEM_DATA_ROW] + [DataCount::SMEM_DATA_COL]; __shared__ uint8_t - smem_data[DataCount::SMEM_DATA_ROW] - [DataCount::SMEM_DATA_COL]; - __shared__ uint8_t smem_filter - [FilterCount::SMEM_FILTER_ROW] - [FilterCount::SMEM_FILTER_COL]; + smem_filter[FilterCount::SMEM_FILTER_ROW] + [FilterCount::SMEM_FILTER_COL]; wmma::fragment - acc_frag[BlockConfig_::OUT_CHANNELS_PER_WARP] - [BlockConfig_::OH_PER_WARP]; + acc_frag[BlockConfig_::OUT_CHANNELS_PER_WARP][BlockConfig_::OH_PER_WARP]; wmma::fragment data_frag[2][BlockConfig_::OH_PER_WARP]; wmma::fragment filter_frag[2][BlockConfig_::OUT_CHANNELS_PER_WARP]; - ConvDataGlobal2ShareMemVisitor - gbl2smem_data_visitor{smem_data[0], g_ptr_data, IH, IW, - b_ih, b_iw, zero}; - ConvFilterGlobal2ShareMemVisitor - gbl2smem_filter_visitor{smem_filter[0], g_ptr_filter, - IC / 2 * ConvConfig_::FH * ConvConfig_::FW, - co_remain, icb}; + ConvDataGlobal2ShareMemVisitor gbl2smem_data_visitor{ + smem_data[0], g_ptr_data, IH, IW, b_ih, b_iw, zero}; + ConvFilterGlobal2ShareMemVisitor gbl2smem_filter_visitor{ + smem_filter[0], g_ptr_filter, IC / 2 * ConvConfig_::FH * ConvConfig_::FW, + co_remain, icb}; #pragma unroll for (int i = 0; i < BlockConfig_::OUT_CHANNELS_PER_WARP; ++i) { @@ -550,26 +516,25 @@ __global__ void convolution_template_device_u4( #pragma unroll for (int ci_blk = 0; ci_blk < ic_blocks; ci_blk++) { consume_slice( - gbl2smem_data_visitor, gbl2smem_filter_visitor, data_frag, - filter_frag, acc_frag); + gbl2smem_data_visitor, gbl2smem_filter_visitor, data_frag, filter_frag, + acc_frag); } 
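The loop above and the consume_slice call that follows it form a two-stage software pipeline: data_frag and filter_frag are double-buffered, so the WMMA math on buffer loop_count % 2 overlaps with loading buffer (loop_count + 1) % 2 from shared memory, and the extra call after the loop drains the final buffer. A self-contained toy kernel showing just that ping-pong indexing; the kernel and its trivial sum are illustrative only, not code from the patch:

#include <cstdio>
#include <cuda_runtime.h>

// Two-slot buffer: consume slot k % 2 while prefetching slot (k + 1) % 2,
// then drain the last slot after the loop -- the same shape as the
// consume_slice main loop plus its trailing call.
__global__ void pipelined_sum(const int* src, int* dst, int steps) {
    int buf[2];
    int acc = 0;
    buf[0] = src[0];                    // prologue: first slice preloaded
    int k = 0;
    for (; k < steps - 1; ++k) {
        acc += buf[k % 2];              // compute on the ready slice
        buf[(k + 1) % 2] = src[k + 1];  // fetch the next slice
    }
    acc += buf[k % 2];                  // epilogue: drain the pipeline
    if (threadIdx.x == 0 && blockIdx.x == 0)
        *dst = acc;
}

int main() {
    int h_src[4] = {1, 2, 3, 4}, h_dst = 0;
    int *d_src = nullptr, *d_dst = nullptr;
    cudaMalloc((void**)&d_src, sizeof(h_src));
    cudaMalloc((void**)&d_dst, sizeof(int));
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
    pipelined_sum<<<1, 1>>>(d_src, d_dst, 4);
    cudaMemcpy(&h_dst, d_dst, sizeof(int), cudaMemcpyDeviceToHost);
    printf("%d\n", h_dst);  // expected: 10
    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}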
consume_slice( - gbl2smem_data_visitor, gbl2smem_filter_visitor, data_frag, - filter_frag, acc_frag); + gbl2smem_data_visitor, gbl2smem_filter_visitor, data_frag, filter_frag, + acc_frag); // store #pragma unroll for (int i = 0; i < BlockConfig_::OUT_CHANNELS_PER_WARP; ++i) { #pragma unroll for (int j = 0; j < BlockConfig_::OH_PER_WARP; ++j) { - if (b_oh + j < OH && - oc_start + i * BlockConfig_::WARPS_OC * WMMA_M < OC && + if (b_oh + j < OH && oc_start + i * BlockConfig_::WARPS_OC * WMMA_M < OC && ow_start < OW) { - wmma::store_matrix_sync(&g_ptr_out[i * BlockConfig_::WARPS_OC * - WMMA_M * OH * OW + - j * OW * WMMA_M], - acc_frag[i][j], WMMA_M, - wmma::mem_col_major); + wmma::store_matrix_sync( + &g_ptr_out + [i * BlockConfig_::WARPS_OC * WMMA_M * OH * OW + + j * OW * WMMA_M], + acc_frag[i][j], WMMA_M, wmma::mem_col_major); } } } @@ -578,21 +543,18 @@ __global__ void convolution_template_device_u4( template __global__ void convolution_template_device_u4( const uint8_t* __restrict__ /* data */, - const uint8_t* __restrict__ /* filter */, - int32_t* __restrict__ /* out */, int /* N */, int /* IH */, - int /* IW */, int /* OH */, int /* OW */, int /* PH */, int /* PW */, - int /* IC */, int /* OC */, int32_t /* zero */) {} + const uint8_t* __restrict__ /* filter */, int32_t* __restrict__ /* out */, + int /* N */, int /* IH */, int /* IW */, int /* OH */, int /* OW */, + int /* PH */, int /* PW */, int /* IC */, int /* OC */, int32_t /* zero */) {} #endif } // namespace wmma_conv_integer_subbyte_fhxfw using namespace wmma_conv_integer_subbyte_fhxfw; -void megdnn::cuda::wmma_conv_integer_subbyte:: - _do_wmma_conv_integer_subbyte_fhxfw( - const uint8_t* d_data, const uint8_t* d_filter, int32_t* d_out, - int batch_size, int hi, int wi, int ho, int wo, int ph, int pw, - int ci, int co, int fh, int fw, int sh, int sw, uint8_t zp_data, - cudaStream_t stream) { +void megdnn::cuda::wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_fhxfw( + const uint8_t* d_data, const uint8_t* d_filter, int32_t* d_out, int batch_size, + int hi, int wi, int ho, int wo, int ph, int pw, int ci, int co, int fh, int fw, + int sh, int sw, uint8_t zp_data, cudaStream_t stream) { cuda_check(cudaDeviceSetCacheConfig(cudaFuncCachePreferShared)); cuda_check(cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte)); zp_data = (zp_data << 4) | zp_data; @@ -622,11 +584,11 @@ void megdnn::cuda::wmma_conv_integer_subbyte:: convolution_template_device_u4< ConvConfig<3, 3, 1, 1>, - BlockConfig> - <<>>(d_data, d_filter, d_out, - batch_size, hi, wi, ho, wo, - ph, pw, ci, co, zero); + BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, + ic_unroll_size>><<>>( + d_data, d_filter, d_out, batch_size, hi, wi, ho, wo, ph, pw, ci, co, + zero); } else if (fh == 5 && fw == 5 && sh == 1 && sw == 1) { constexpr size_t warps_w = 2; constexpr size_t warps_oc = 4; @@ -652,11 +614,11 @@ void megdnn::cuda::wmma_conv_integer_subbyte:: convolution_template_device_u4< ConvConfig<5, 5, 1, 1>, - BlockConfig> - <<>>(d_data, d_filter, d_out, - batch_size, hi, wi, ho, wo, - ph, pw, ci, co, zero); + BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, + ic_unroll_size>><<>>( + d_data, d_filter, d_out, batch_size, hi, wi, ho, wo, ph, pw, ci, co, + zero); } else if (fh == 7 && fw == 7 && sh == 1 && sw == 1) { constexpr size_t warps_w = 2; constexpr size_t warps_oc = 2; @@ -682,11 +644,11 @@ void megdnn::cuda::wmma_conv_integer_subbyte:: convolution_template_device_u4< ConvConfig<7, 7, 1, 1>, - BlockConfig> - 
<<>>(d_data, d_filter, d_out, - batch_size, hi, wi, ho, wo, - ph, pw, ci, co, zero); + BlockConfig< + warps_w, warps_oc, out_channels_per_warp, oh_per_warp, + ic_unroll_size>><<>>( + d_data, d_filter, d_out, batch_size, hi, wi, ho, wo, ph, pw, ci, co, + zero); } after_kernel_launch(); } diff --git a/dnn/src/cuda/conv_bias/reduce_filter.cu b/dnn/src/cuda/conv_bias/reduce_filter.cu index f0359430..779c3184 100644 --- a/dnn/src/cuda/conv_bias/reduce_filter.cu +++ b/dnn/src/cuda/conv_bias/reduce_filter.cu @@ -56,22 +56,18 @@ struct ReduceWithScaleInt4Op { static const wtype INIT = 0; #if MEGDNN_CC_CUDA - __host__ __device__ void write(uint32_t idx, wtype val) { - dst[idx] = val * scale; - } + __host__ __device__ void write(uint32_t idx, wtype val) { dst[idx] = val * scale; } __host__ __device__ static wtype apply(wtype a, wtype b) { return a + b; } __device__ wtype read(uint32_t idx) { constexpr uint32_t subbytes_per_pixel = 8; - const uint32_t* sptr = - (const uint32_t*)(src + subbytes_per_pixel * idx / 2); + const uint32_t* sptr = (const uint32_t*)(src + subbytes_per_pixel * idx / 2); uint32_t val = *sptr; int32_t ret = 0; #pragma unroll for (int j = 0; j < 8; j++) { - ret += integer_subbyte::unpack_integer_4bits(val, - (j << 2)); + ret += integer_subbyte::unpack_integer_4bits(val, (j << 2)); } return ret; } @@ -96,14 +92,12 @@ struct ReduceUpdateBiasInt4Op { __device__ wtype read(uint32_t idx) { constexpr uint32_t subbytes_per_pixel = 8; - const uint32_t* fptr = - (const uint32_t*)(filter + subbytes_per_pixel * idx / 2); + const uint32_t* fptr = (const uint32_t*)(filter + subbytes_per_pixel * idx / 2); uint32_t val = *fptr; int32_t ret = 0; #pragma unroll for (int j = 0; j < 8; j++) { - ret += integer_subbyte::unpack_integer_4bits(val, - (j << 2)); + ret += integer_subbyte::unpack_integer_4bits(val, (j << 2)); } return ret; } @@ -114,8 +108,8 @@ struct ReduceUpdateBiasInt4Op { template void megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit( - const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, - int32_t* dst, cudaStream_t stream) { + const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, int32_t* dst, + cudaStream_t stream) { // rows = OC // cols is measured in pixels, i.e. 
IC * FH * FW / 8, a pixel consists of 8 // subbyte data, @@ -127,14 +121,13 @@ void megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit( static_cast(stream); static_cast(rows); static_cast(cols); - run_reduce, false>(dst + rows, rows, cols, - 1, stream, op); + run_reduce, false>( + dst + rows, rows, cols, 1, stream, op); } -#define INST(signedness) \ - template void \ - megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit( \ - const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, \ +#define INST(signedness) \ + template void megdnn::cuda::do_dispatch_reduce_with_scale_filter_4bit( \ + const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, \ int32_t* dst, cudaStream_t stream) INST(false); INST(true); @@ -142,31 +135,31 @@ INST(true); template void megdnn::cuda::do_dispatch_reduce_filter_and_update_bias_4bit( - const uint8_t* filter, const int32_t* src_bias, uint32_t rows, - uint32_t cols, int32_t* dst_bias, int32_t* workspace, - int32_t zero_point, cudaStream_t stream) { + const uint8_t* filter, const int32_t* src_bias, uint32_t rows, uint32_t cols, + int32_t* dst_bias, int32_t* workspace, int32_t zero_point, + cudaStream_t stream) { ReduceUpdateBiasInt4Op op; op.filter = filter; op.src_bias = src_bias; op.dst_bias = dst_bias; op.zero_point = zero_point; - run_reduce, false>(workspace, rows, cols, - 1, stream, op); + run_reduce, false>( + workspace, rows, cols, 1, stream, op); } -#define INST(signedness) \ - template void \ - megdnn::cuda::do_dispatch_reduce_filter_and_update_bias_4bit( \ - const uint8_t* filter, const int32_t* src_bias, uint32_t rows, \ - uint32_t cols, int32_t* dst_bias, int32_t* workspace, \ - int32_t zero_point, cudaStream_t stream) +#define INST(signedness) \ + template void \ + megdnn::cuda::do_dispatch_reduce_filter_and_update_bias_4bit( \ + const uint8_t* filter, const int32_t* src_bias, uint32_t rows, \ + uint32_t cols, int32_t* dst_bias, int32_t* workspace, int32_t zero_point, \ + cudaStream_t stream) INST(false); INST(true); #undef INST -size_t megdnn::cuda::do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, - size_t C) { +size_t megdnn::cuda::do_dispatch_reduce_workspace_in_bytes( + size_t A, size_t B, size_t C) { return get_reduce_workspace_in_bytes>(A, B, C); } diff --git a/dnn/src/cuda/conv_bias/reduce_filter.cuh b/dnn/src/cuda/conv_bias/reduce_filter.cuh index cd42ba6b..a857ff25 100644 --- a/dnn/src/cuda/conv_bias/reduce_filter.cuh +++ b/dnn/src/cuda/conv_bias/reduce_filter.cuh @@ -45,16 +45,14 @@ namespace megdnn { namespace cuda { template -void do_dispatch_reduce_with_scale_filter_4bit(const uint8_t* src, - int32_t scale, uint32_t rows, - uint32_t cols, int32_t* dst, - cudaStream_t stream); +void do_dispatch_reduce_with_scale_filter_4bit( + const uint8_t* src, int32_t scale, uint32_t rows, uint32_t cols, int32_t* dst, + cudaStream_t stream); template void do_dispatch_reduce_filter_and_update_bias_4bit( - const uint8_t* filter, const int32_t* src_bias, uint32_t rows, - uint32_t cols, int32_t* dst_bias, int32_t* workspace, int zero_point, - cudaStream_t stream); + const uint8_t* filter, const int32_t* src_bias, uint32_t rows, uint32_t cols, + int32_t* dst_bias, int32_t* workspace, int zero_point, cudaStream_t stream); size_t do_dispatch_reduce_workspace_in_bytes(size_t A, size_t B, size_t C); diff --git a/dnn/src/cuda/convolution/backward_data/algo.cpp b/dnn/src/cuda/convolution/backward_data/algo.cpp index 2bde8791..1196ceb2 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.cpp +++ 
b/dnn/src/cuda/convolution/backward_data/algo.cpp @@ -56,15 +56,14 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl) -ConvolutionBackwardDataImpl::AlgoCUDNN* -ConvolutionBackwardDataImpl::AlgoPack::cudnn_from_enum( - cudnnConvolutionBwdDataAlgo_t algo) { +ConvolutionBackwardDataImpl::AlgoCUDNN* ConvolutionBackwardDataImpl::AlgoPack:: + cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo) { for (auto&& i : cudnn) { if (i.cudnn_enum() == algo) return &i; } - megdnn_throw(ssprintf("can not find cudnn bwd_data algorithm %d", - static_cast(algo))); + megdnn_throw(ssprintf( + "can not find cudnn bwd_data algorithm %d", static_cast(algo))); } ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; @@ -72,8 +71,9 @@ ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( const ConvolutionBackwardDataImpl* o, const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) - : SizeArgs(o, filter, o->make_canonized_filter_meta(grad.ndim, filter), - diff, grad) {} + : SizeArgs( + o, filter, o->make_canonized_filter_meta(grad.ndim, filter), diff, + grad) {} ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( const ConvolutionBackwardDataImpl* o, const TensorLayout& filter, @@ -88,8 +88,7 @@ ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( ConvolutionBackwardDataImpl::AlgoBase::ExecArgs::ExecArgs( const ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) : SizeArgs(opr, filter.layout, diff.layout, grad.layout), filter_tensor{&filter}, diff_tensor{&diff}, @@ -104,9 +103,9 @@ std::string ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::to_string() const { "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s", fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], diff_layout->to_string().c_str(), grad_layout->to_string().c_str(), - fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], - fm.dilation[0], fm.dilation[1], !fm.should_flip, - diff_layout->dtype.name(), grad_layout->dtype.name()); + fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], fm.dilation[0], + fm.dilation[1], !fm.should_flip, diff_layout->dtype.name(), + grad_layout->dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h index 821ae5f2..d46a5652 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.h +++ b/dnn/src/cuda/convolution/backward_data/algo.h @@ -56,26 +56,26 @@ public: void init_desc(convolution::CUDNNBwdDataDescs& desc) const { desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); } - SizeArgs(const ConvolutionBackwardDataImpl* opr, - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad); - SizeArgs(const ConvolutionBackwardDataImpl* opr, - const TensorLayout& filter, - const CanonizedFilterMeta& filter_meta, - const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + const ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + const ConvolutionBackwardDataImpl* opr, const TensorLayout& filter, + const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, + const TensorLayout& grad); convolution::ForwardSizeArgs 
as_fwd_args() const { - return {handle, grad_layout, filter_layout, filter_meta, - diff_layout}; + return {handle, grad_layout, filter_layout, filter_meta, diff_layout}; } }; struct ExecArgs : public SizeArgs { const TensorND *filter_tensor, *diff_tensor, *grad_tensor; Workspace workspace; - ExecArgs(const ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace); + ExecArgs( + const ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -91,17 +91,16 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd data algo %s: " - "required workspace %zu bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "conv bwd data algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); return *this; } @@ -113,10 +112,10 @@ class ConvolutionBackwardDataImpl::AlgoCUDNN final : public AlgoBase { CudnnAlgoPack::Attr m_attr; public: - AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) - : m_cudnn_enum(cudnn_enum) { - megdnn_assert(CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) != - CudnnAlgoPack::conv_bwd_data_algos().end()); + AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { + megdnn_assert( + CudnnAlgoPack::conv_bwd_data_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_bwd_data_algos().end()); m_attr = CudnnAlgoPack::conv_bwd_data_algos().at(cudnn_enum); } @@ -158,14 +157,12 @@ public: void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } }; @@ -177,9 +174,7 @@ public: const char* name() const override { return "CHANNEL_WISE"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } }; class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase { @@ -191,8 +186,7 @@ public: const char* name() const override { return "CHANNEL_WISE_SMALL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } }; @@ -203,16 +197,11 @@ public: void exec(const ExecArgs& args) const override; 
std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; - const char* name() const override { - return "CONVOLUTION_BACKWARD_DATD_BFLOAT16"; - } + const char* name() const override { return "CONVOLUTION_BACKWARD_DATD_BFLOAT16"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -220,25 +209,19 @@ private: }; //! implement group conv by another algo -class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final - : public AlgoBase { +class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; - const char* name() const override { - return "CUDA:GROUP_CONV_BACKWARD_DATA"; - } + const char* name() const override { return "CUDA:GROUP_CONV_BACKWARD_DATA"; } MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -256,26 +239,24 @@ public: int warp_k; int stage; std::string to_string() { - return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, - threadblock_n, threadblock_k, warp_m, warp_n, - warp_k, stage); + return ssprintf( + "_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n, + threadblock_k, warp_m, warp_n, warp_k, stage); } }; AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param) : m_algo_param{algo_param}, - m_name{ssprintf("INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", - m_algo_param.to_string().c_str())} {} + m_name{ssprintf( + "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM%s", + m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8) private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; const void* get_available_op(const SizeArgs& args) const; AlgoParam m_algo_param; std::string m_name; @@ -287,16 +268,12 @@ public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - const char* name() const override { - return "INT8_NCHW_DOTPROD_IMPLICIT_GEMM"; - } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + const char* name() const override { return "INT8_NCHW_DOTPROD_IMPLICIT_GEMM"; } + AlgoAttribute attribute() const 
override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW_DOTPROD_INT8); + private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; const void* get_available_op(const SizeArgs& args) const; }; @@ -313,29 +290,27 @@ public: int stage; int access_size; std::string to_string() { - return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, - threadblock_n, threadblock_k, warp_m, warp_n, - warp_k, stage, access_size); + return ssprintf( + "_%dX%dX%d_%dX%dX%d_%dstage_%d", threadblock_m, threadblock_n, + threadblock_k, warp_m, warp_n, warp_k, stage, access_size); } }; AlgoInt8NHWCIMMAImplicitGemm(AlgoParam algo_param) : m_algo_param{algo_param}, - m_name{ssprintf("INT8_NHWC_IMMA_IMPLICIT_GEMM%s", - m_algo_param.to_string().c_str())} {} + m_name{ssprintf( + "INT8_NHWC_IMMA_IMPLICIT_GEMM%s", + m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NHWC_IMMA_INT8) private: - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; const void* get_available_op(const SizeArgs& args) const; - void reorder_filter(const ExecArgs& args, const int iterleaved, - int8_t* reordered_filter) const; + void reorder_filter( + const ExecArgs& args, const int iterleaved, int8_t* reordered_filter) const; AlgoParam m_algo_param; std::string m_name; }; diff --git a/dnn/src/cuda/convolution/backward_data/bfloat16.cpp b/dnn/src/cuda/convolution/backward_data/bfloat16.cpp index 53830dbd..c4798a40 100644 --- a/dnn/src/cuda/convolution/backward_data/bfloat16.cpp +++ b/dnn/src/cuda/convolution/backward_data/bfloat16.cpp @@ -20,8 +20,7 @@ using namespace convolution; namespace { std::pair sub_opr_config( - const TensorLayoutArray& layouts, - const ConvolutionBackwardDataImpl* opr) { + const TensorLayoutArray& layouts, const ConvolutionBackwardDataImpl* opr) { megdnn_assert(layouts.size() >= 3); std::pair ret; ret.first = layouts; @@ -35,34 +34,30 @@ std::pair sub_opr_config( change_dtype(ret.first[2]); ret.second = opr->param(); - ret.second.compute_mode = - ConvolutionBackwardData::Param::ComputeMode::DEFAULT; + ret.second.compute_mode = ConvolutionBackwardData::Param::ComputeMode::DEFAULT; return ret; } -std::pair> -prepare_sub_opr(const ConvolutionBackwardDataImpl::AlgoBase::SizeArgs& args) { - auto conv_back_data_opr = - args.handle->create_operator(); +std::pair> prepare_sub_opr( + const ConvolutionBackwardDataImpl::AlgoBase::SizeArgs& args) { + auto conv_back_data_opr = args.handle->create_operator(); auto&& config = sub_opr_config( - {*args.filter_layout, *args.diff_layout, *args.grad_layout}, - args.opr); + {*args.filter_layout, *args.diff_layout, *args.grad_layout}, args.opr); conv_back_data_opr->param() = config.second; return {config.first, std::move(conv_back_data_opr)}; } } // namespace -std::vector -ConvolutionBackwardDataImpl::AlgoBFloat16::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* 
opr) const { +std::vector ConvolutionBackwardDataImpl::AlgoBFloat16:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { auto&& config = sub_opr_config( layouts, static_cast(opr)); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVOLUTION_BACKWARD_DATA, param_str, - config.first}}; + return {{Algorithm::OprType::CONVOLUTION_BACKWARD_DATA, param_str, config.first}}; } bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available( @@ -70,17 +65,16 @@ bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available( auto config = prepare_sub_opr(args); return args.diff_layout->dtype == args.filter_layout->dtype && args.diff_layout->dtype == dtype::BFloat16() && - get_algorithm(static_cast( - config.second.get()), - config.first[0], config.first[1], config.first[2]); + get_algorithm( + static_cast(config.second.get()), + config.first[0], config.first[1], config.first[2]); } WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle( void* ptr, const SizeArgs& args) const { auto config = prepare_sub_opr(args); SmallVector sizes; - auto get_workspace = [&sizes](const TensorLayout& src, - const TensorLayout& dst) { + auto get_workspace = [&sizes](const TensorLayout& src, const TensorLayout& dst) { if (src.dtype != dst.dtype) { sizes.push_back(dst.span().dist_byte()); } @@ -99,8 +93,7 @@ size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( - const ExecArgs& args) const { +void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(const ExecArgs& args) const { TensorND ffilter_tensor = *args.filter_tensor; TensorND fdiff_tensor = *args.diff_tensor; TensorND fgrad_tensor = *args.grad_tensor; @@ -113,8 +106,8 @@ void ConvolutionBackwardDataImpl::AlgoBFloat16::exec( } { auto config = prepare_sub_opr(args); - config.second->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor, - cvter.workspace()); + config.second->exec( + ffilter_tensor, fdiff_tensor, fgrad_tensor, cvter.workspace()); } { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); } } diff --git a/dnn/src/cuda/convolution/backward_data/chanwise.cpp b/dnn/src/cuda/convolution/backward_data/chanwise.cpp index a5a7f510..d9bbc1a7 100644 --- a/dnn/src/cuda/convolution/backward_data/chanwise.cpp +++ b/dnn/src/cuda/convolution/backward_data/chanwise.cpp @@ -20,8 +20,7 @@ using namespace convolution; bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available( const SizeArgs& args) const { - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } if ((args.diff_layout->dtype == args.filter_layout->dtype && @@ -42,16 +41,14 @@ size_t ConvolutionBackwardDataImpl::AlgoChanwise::get_workspace_in_bytes( return 0; } -void ConvolutionBackwardDataImpl::AlgoChanwise::exec( - const ExecArgs& args) const { +void ConvolutionBackwardDataImpl::AlgoChanwise::exec(const ExecArgs& args) const { auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args()); auto stream = cuda_stream(args.handle); switch (args.diff_layout->dtype.enumv()) { case DTypeEnum::Float32: - return chanwise::run_bwd_data(args.grad_tensor->ptr(), - args.diff_tensor->ptr(), - args.filter_tensor->ptr(), - kparam, stream); + return chanwise::run_bwd_data( + args.grad_tensor->ptr(), args.diff_tensor->ptr(), + 
args.filter_tensor->ptr(), kparam, stream); case DTypeEnum::Float16: #if CUDA_VERSION >= 9000 @@ -59,8 +56,8 @@ void ConvolutionBackwardDataImpl::AlgoChanwise::exec( return chanwise::run_bwd_data( static_cast<__half*>(args.grad_tensor->raw_ptr), static_cast<__half*>(args.diff_tensor->raw_ptr), - static_cast<__half*>(args.filter_tensor->raw_ptr), - kparam, stream); + static_cast<__half*>(args.filter_tensor->raw_ptr), kparam, + stream); } else { return chanwise::run_bwd_data( args.grad_tensor->ptr(), @@ -68,10 +65,10 @@ void ConvolutionBackwardDataImpl::AlgoChanwise::exec( args.filter_tensor->ptr(), kparam, stream); } #else - return chanwise::run_bwd_data(args.grad_tensor->ptr(), - args.diff_tensor->ptr(), - args.filter_tensor->ptr(), - kparam, stream); + return chanwise::run_bwd_data( + args.grad_tensor->ptr(), + args.diff_tensor->ptr(), + args.filter_tensor->ptr(), kparam, stream); #endif default: diff --git a/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp b/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp index 755f5359..37b2f9eb 100644 --- a/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp +++ b/dnn/src/cuda/convolution/backward_data/chanwise_small.cpp @@ -11,8 +11,8 @@ */ #include "src/cuda/convolution/backward_data/algo.h" -#include "src/cuda/utils.h" #include "src/cuda/convolution/chanwise/kern.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -21,17 +21,16 @@ using namespace convolution; namespace { inline bool is_available_small(const chanwise::Param& param) { return param.chl_mul == 1 && param.stride_h == 1 && param.stride_w == 1 && - param.src_h <= 32 && param.src_w <= 32 && - param.src_h == param.out_h && param.src_w == param.out_w && - param.pad_h < param.flt_h && param.pad_w < param.flt_w && + param.src_h <= 32 && param.src_w <= 32 && param.src_h == param.out_h && + param.src_w == param.out_w && param.pad_h < param.flt_h && + param.pad_w < param.flt_w && param.flt_h * param.flt_w <= (param.src_h + 1) / 2 * param.src_w; } } // anonymous namespace bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available( const SizeArgs& args) const { - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } if ((args.diff_layout->dtype == args.filter_layout->dtype && @@ -58,23 +57,20 @@ size_t ConvolutionBackwardDataImpl::AlgoChanwiseSmall::get_workspace_in_bytes( return 0; } -void ConvolutionBackwardDataImpl::AlgoChanwiseSmall::exec( - const ExecArgs& args) const { +void ConvolutionBackwardDataImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args()); auto stream = cuda_stream(args.handle); switch (args.grad_layout->dtype.enumv()) { case DTypeEnum::Float32: return chanwise::run_bwd_data_small( - args.grad_tensor->ptr(), - args.diff_tensor->ptr(), + args.grad_tensor->ptr(), args.diff_tensor->ptr(), args.filter_tensor->ptr(), kparam, stream); #if CUDA_VERSION >= 9000 case DTypeEnum::Float16: return chanwise::run_bwd_data_small( static_cast(args.grad_tensor->raw_ptr), static_cast(args.diff_tensor->raw_ptr), - static_cast(args.filter_tensor->raw_ptr), kparam, - stream); + static_cast(args.filter_tensor->raw_ptr), kparam, stream); #endif default: break; diff --git a/dnn/src/cuda/convolution/backward_data/cudnn.cpp b/dnn/src/cuda/convolution/backward_data/cudnn.cpp index 5c51cf66..48099d04 100644 --- 
a/dnn/src/cuda/convolution/backward_data/cudnn.cpp +++ b/dnn/src/cuda/convolution/backward_data/cudnn.cpp @@ -11,21 +11,19 @@ #include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/cudnn_wrapper.h" -#include "src/cuda/convolution/helper.h" #include "src/cuda/conv_bias/helper.h" +#include "src/cuda/convolution/helper.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution; -bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available( - const SizeArgs &args) const { +bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(const SizeArgs& args) const { if (args.filter_meta.format != Param::Format::NCHW && args.filter_meta.format != Param::Format::NHWC) { - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } } @@ -35,9 +33,11 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available( TensorLayout bias_layout, z_layout; conv_bias::CanonizedFilterMeta meta; meta.copy_from(args.filter_meta); - conv_bias::BiasForwardSizeArgs bias_args{args.handle, - args.grad_layout, args.filter_layout, &bias_layout, - &z_layout, meta, args.diff_layout, param::ConvBias::NonlineMode::IDENTITY, + conv_bias::BiasForwardSizeArgs bias_args{ + args.handle, args.grad_layout, + args.filter_layout, &bias_layout, + &z_layout, meta, + args.diff_layout, param::ConvBias::NonlineMode::IDENTITY, }; if (!conv_bias::is_cudnn_supported(bias_args)) return false; @@ -45,53 +45,37 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - args.handle->cudnn_handle(), - D.filter_desc.desc, - D.diff_desc.desc, - D.conv_desc.desc, - D.grad_desc.desc, - m_cudnn_enum, - &workspace_size); + args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); return status == CUDNN_STATUS_SUCCESS; } size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes( - const SizeArgs &args) const { + const SizeArgs& args) const { CUDNNBwdDataDescs D; args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - args.handle->cudnn_handle(), - D.filter_desc.desc, - D.diff_desc.desc, - D.conv_desc.desc, - D.grad_desc.desc, - m_cudnn_enum, - &workspace_size); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, + args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv bwd_data get workspace failed: %s; info: %s", cudnnGetErrorString(status), args.to_string().c_str()); return workspace_size; } -void ConvolutionBackwardDataImpl::AlgoCUDNN::exec( - const ExecArgs &args) const { +void ConvolutionBackwardDataImpl::AlgoCUDNN::exec(const ExecArgs& args) const { CUDNNBwdDataDescs D; args.init_desc(D); float alpha = 1.0f, beta = 0.0f; - auto status = cudnnConvolutionBackwardData(args.handle->cudnn_handle(), - &alpha, - D.filter_desc.desc, args.filter_tensor->raw_ptr, - D.diff_desc.desc, args.diff_tensor->raw_ptr, - D.conv_desc.desc, - m_cudnn_enum, - args.workspace.raw_ptr, - args.workspace.size, - &beta, - D.grad_desc.desc, - args.grad_tensor->raw_ptr); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv bwd_data failed: %s; info: %s", + auto status = 
cudnnConvolutionBackwardData( + args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, + args.filter_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, + D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, + &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", cudnnGetErrorString(status), args.to_string().c_str()); } diff --git a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu index 5d26536e..d333d474 100644 --- a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu +++ b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cu @@ -68,8 +68,7 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( #pragma unroll for (int i = 0; i < interleaved; i++) { src_value[i] = *reinterpret_cast( - src + (ocb * interleaved + i) * FHFW * IC + fhfw * IC + - icb * 4); + src + (ocb * interleaved + i) * FHFW * IC + fhfw * IC + icb * 4); } auto trans = transpose_int8_interleavedx4(); @@ -77,9 +76,9 @@ __global__ void reorder_filter_nhwc_to_cnxhwx_kernel( #pragma unroll for (int i = 0; i < 4; i++) { - *reinterpret_cast(dst + (icb * 4 + i) * FHFW * OC + - (ocb * FHFW + fhfw) * interleaved) = - dst_value[i]; + *reinterpret_cast( + dst + (icb * 4 + i) * FHFW * OC + + (ocb * FHFW + fhfw) * interleaved) = dst_value[i]; } } } @@ -90,8 +89,7 @@ void megdnn::cuda::deconv::reorder_filter_nc4hw4_to_n4hwc4( int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, cudaStream_t stream) { dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y, 1); - dim3 blocks(DIVUP(FH * FW, BLOCKSIZE_X), DIVUP(IC / 4, BLOCKSIZE_Y), - OC / 4); + dim3 blocks(DIVUP(FH * FW, BLOCKSIZE_X), DIVUP(IC / 4, BLOCKSIZE_Y), OC / 4); reorder_filter_nc4hw4_to_n4hwc4_kernel<<>>( dst, src, OC, IC, FH * FW); @@ -107,16 +105,13 @@ void megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx( if (interleaved == 4) { reorder_filter_nhwc_to_cnxhwx_kernel<4, int> - <<>>(dst, src, OC, IC, - FH * FW); + <<>>(dst, src, OC, IC, FH * FW); } else if (interleaved == 8) { reorder_filter_nhwc_to_cnxhwx_kernel<8, int2> - <<>>(dst, src, OC, IC, - FH * FW); + <<>>(dst, src, OC, IC, FH * FW); } else { reorder_filter_nhwc_to_cnxhwx_kernel<16, int4> - <<>>(dst, src, OC, IC, - FH * FW); + <<>>(dst, src, OC, IC, FH * FW); } after_kernel_launch(); } diff --git a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh index ea7baaee..3e5f2d11 100644 --- a/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh +++ b/dnn/src/cuda/convolution/backward_data/deconv_int8_helper.cuh @@ -16,13 +16,13 @@ namespace megdnn { namespace cuda { namespace deconv { -void reorder_filter_nc4hw4_to_n4hwc4(int8_t* dst, const int8_t* src, - uint32_t OC, uint32_t IC, uint32_t FH, - uint32_t FW, cudaStream_t stream); +void reorder_filter_nc4hw4_to_n4hwc4( + int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, + uint32_t FW, cudaStream_t stream); -void reorder_filter_nhwc_to_cnxhwx(int8_t* dst, const int8_t* src, uint32_t OC, - uint32_t IC, uint32_t FH, uint32_t FW, - uint32_t interleaved, cudaStream_t stream); +void reorder_filter_nhwc_to_cnxhwx( + int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH, + uint32_t FW, uint32_t interleaved, cudaStream_t stream); } // namespace deconv } // namespace cuda diff --git a/dnn/src/cuda/convolution/backward_data/group_conv.cpp 
b/dnn/src/cuda/convolution/backward_data/group_conv.cpp index a87b01ec..b57b0de1 100644 --- a/dnn/src/cuda/convolution/backward_data/group_conv.cpp +++ b/dnn/src/cuda/convolution/backward_data/group_conv.cpp @@ -38,10 +38,9 @@ std::pair sub_opr_config( return ret; } -std::pair> -prepare_sub_opr(const ConvolutionBackwardDataImpl::AlgoBase::SizeArgs& args) { - auto conv_bwd_data_opr = - args.handle->create_operator(); +std::pair> prepare_sub_opr( + const ConvolutionBackwardDataImpl::AlgoBase::SizeArgs& args) { + auto conv_bwd_data_opr = args.handle->create_operator(); set_execution_policy( args.opr, conv_bwd_data_opr.get()); auto&& config = sub_opr_config(args); @@ -51,9 +50,9 @@ prepare_sub_opr(const ConvolutionBackwardDataImpl::AlgoBase::SizeArgs& args) { } } // namespace -std::vector -ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector ConvolutionBackwardDataImpl::AlgoGroupConvGeneral:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { AlgoBase::SizeArgs args{ static_cast(opr), layouts[0], layouts[1], layouts[2]}; @@ -61,8 +60,7 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::get_subopr_list( std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVOLUTION_BACKWARD_DATA, param_str, - config.first}}; + return {{Algorithm::OprType::CONVOLUTION_BACKWARD_DATA, param_str, config.first}}; } bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available( @@ -76,8 +74,7 @@ bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available( if (args.filter_meta.group <= 1) return false; - if (args.filter_meta.format != - megdnn::param::Convolution::Format::NCHW) { + if (args.filter_meta.format != megdnn::param::Convolution::Format::NCHW) { return false; } @@ -88,8 +85,7 @@ bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available( config.first[0], config.first[1], config.first[2]); } -WorkspaceBundle -ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_bundle( +WorkspaceBundle ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_bundle( void* ptr, const SizeArgs& args) const { auto config = prepare_sub_opr(args); size_t sizes = config.second->get_workspace_in_bytes( @@ -97,8 +93,7 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_bundle( return {ptr, {sizes}}; } -size_t -ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( +size_t ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( const SizeArgs& args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -118,10 +113,10 @@ void ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::exec( auto strd_flt = fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * tfilter.layout.dtype.size(), - strd_diff = tdiff.layout.stride[c_pos] * fm.ocpg * - tdiff.layout.dtype.size(), - strd_grad = (tgrad.layout.stride[c_pos] * fm.icpg * - tgrad.layout.dtype.size()); + strd_diff = + tdiff.layout.stride[c_pos] * fm.ocpg * tdiff.layout.dtype.size(), + strd_grad = + (tgrad.layout.stride[c_pos] * fm.icpg * tgrad.layout.dtype.size()); auto grp = args.filter_meta.group; for (uint32_t g = 0; g < grp; ++g) { diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp index fc4107db..af063c85 100644 --- 
a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw4_dp4a.cpp @@ -20,16 +20,15 @@ using namespace megdnn; using namespace cuda; -const void* -ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_available_op( - const SizeArgs& args) const { +const void* ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: + get_available_op(const SizeArgs& args) const { using namespace cutlass::library; auto&& fm = args.filter_meta; size_t sh = fm.stride[0], sw = fm.stride[1]; cutlass::conv::SpecialOptimizeDesc special_optimization = - (sh == 2 && sw == 2) ? cutlass::conv::SpecialOptimizeDesc:: - DECONV_DOUBLE_UPSAMPLING - : cutlass::conv::SpecialOptimizeDesc::NONE; + (sh == 2 && sw == 2) + ? cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING + : cutlass::conv::SpecialOptimizeDesc::NONE; ConvolutionKey key{ cutlass::conv::Operator::kDgrad, NumericTypeID::kS8, @@ -57,26 +56,25 @@ ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::get_available_op( return (void*)Singleton::get().operation_table.find_op(key); } -bool ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm:: - is_available(const SizeArgs& args) const { +bool ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( + const SizeArgs& args) const { auto&& fm = args.filter_meta; if (fm.format != Param::Format::NCHW4) return false; - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } bool available = true; - auto src_dtype = args.diff_layout->dtype, - filter_dtype = args.filter_layout->dtype, + auto src_dtype = args.diff_layout->dtype, filter_dtype = args.filter_layout->dtype, dst_dtype = args.grad_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // TODO support group deconv int8 available &= (fm.group == 1); // mode must be cross correlation @@ -113,11 +111,9 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( auto&& fm = args.filter_meta; size_t n = args.diff_layout->operator[](0), co = args.diff_layout->operator[](1) * 4, - ho = args.diff_layout->operator[](2), - wo = args.diff_layout->operator[](3); + ho = args.diff_layout->operator[](2), wo = args.diff_layout->operator[](3); size_t ci = args.grad_layout->operator[](1) * 4, - hi = args.grad_layout->operator[](2), - wi = args.grad_layout->operator[](3); + hi = args.grad_layout->operator[](2), wi = args.grad_layout->operator[](3); size_t fh = fm.spatial[0], fw = fm.spatial[1]; size_t sh = fm.stride[0], sw = fm.stride[1]; size_t ph = fm.padding[0], pw = fm.padding[1]; @@ -131,22 +127,19 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( filter_ptr = reinterpret_cast(args.workspace.raw_ptr); // reformat filter from nc4hw4 to n4hwc4 megdnn::cuda::deconv::reorder_filter_nc4hw4_to_n4hwc4( - filter_ptr, args.filter_tensor->compatible_ptr(), co, - ci, fh, fw, stream); + filter_ptr, args.filter_tensor->compatible_ptr(), co, ci, fh, + fw, stream); } - float diff_scale = - args.diff_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - 
grad_scale = - args.grad_layout->dtype.param().scale; + float diff_scale = args.diff_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + grad_scale = args.grad_layout->dtype.param().scale; // \note these constants of cutlass epilogue will be passed to struct // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, // a different dtype here results in undefined epilogue behaviors - float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, - gamma = 0.f, delta = 0.f; + float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, gamma = 0.f, + delta = 0.f; using namespace cutlass::library; diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp index 50ebde32..02ad9b15 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nchw_dp4a.cpp @@ -19,16 +19,15 @@ using namespace megdnn; using namespace cuda; -const void* -ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::get_available_op( - const SizeArgs& args) const { +const void* ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm:: + get_available_op(const SizeArgs& args) const { using namespace cutlass::library; auto&& fm = args.filter_meta; size_t sh = fm.stride[0], sw = fm.stride[1]; cutlass::conv::SpecialOptimizeDesc special_optimization = - (sh == 2 && sw == 2) ? cutlass::conv::SpecialOptimizeDesc:: - DECONV_DOUBLE_UPSAMPLING - : cutlass::conv::SpecialOptimizeDesc::NONE; + (sh == 2 && sw == 2) + ? cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING + : cutlass::conv::SpecialOptimizeDesc::NONE; // only use 16x64x8_16x64x8_2stages impl ConvolutionKey key{ cutlass::conv::Operator::kDgrad, @@ -63,25 +62,23 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::is_available( if (fm.format != Param::Format::NCHW) return false; - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } bool available = true; - auto src_dtype = args.diff_layout->dtype, - filter_dtype = args.filter_layout->dtype, + auto src_dtype = args.diff_layout->dtype, filter_dtype = args.filter_layout->dtype, dst_dtype = args.grad_layout->dtype; - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // TODO support group deconv int8 available &= (fm.group == 1); // ic and oc must be multiples of 4 - available &= - ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); + available &= ((fm.group * fm.icpg) % 4 == 0 && (fm.group * fm.ocpg) % 4 == 0); // mode must be cross correlation available &= !fm.should_flip; // mode must be 2D @@ -117,12 +114,9 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( const ExecArgs& args) const { auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.diff_layout->operator[](0), - co = args.diff_layout->operator[](1), - ho = args.diff_layout->operator[](2), - wo = args.diff_layout->operator[](3); - size_t ci = args.grad_layout->operator[](1), - hi = args.grad_layout->operator[](2), + size_t n = 
args.diff_layout->operator[](0), co = args.diff_layout->operator[](1), + ho = args.diff_layout->operator[](2), wo = args.diff_layout->operator[](3); + size_t ci = args.grad_layout->operator[](1), hi = args.grad_layout->operator[](2), wi = args.grad_layout->operator[](3); size_t fh = fm.spatial[0], fw = fm.spatial[1]; size_t sh = fm.stride[0], sw = fm.stride[1]; @@ -144,10 +138,9 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( exec_src = exec_src.dimshuffle({0, 3, 4, 2, 1}); - auto&& relayout = - args.opr->handle()->create_operator(); - relayout->exec({args.filter_tensor->raw_ptr, exec_src}, - {inner_filter_ptr, exec_dst}); + auto&& relayout = args.opr->handle()->create_operator(); + relayout->exec( + {args.filter_tensor->raw_ptr, exec_src}, {inner_filter_ptr, exec_dst}); } { inner_diff_ptr = reinterpret_cast(bundle.get(1)); @@ -157,25 +150,21 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( exec_src = exec_src.dimshuffle({0, 1, 3, 4, 2}); - auto&& relayout = - args.opr->handle()->create_operator(); - relayout->exec({args.diff_tensor->raw_ptr, exec_src}, - {inner_diff_ptr, exec_dst}); + auto&& relayout = args.opr->handle()->create_operator(); + relayout->exec( + {args.diff_tensor->raw_ptr, exec_src}, {inner_diff_ptr, exec_dst}); } int8_t* inner_grad_ptr = reinterpret_cast(bundle.get(2)); - float diff_scale = - args.diff_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - grad_scale = - args.grad_layout->dtype.param().scale; + float diff_scale = args.diff_layout->dtype.param().scale, + filter_scale = args.filter_layout->dtype.param().scale, + grad_scale = args.grad_layout->dtype.param().scale; // \note these constants of cutlass epilogue will be passed to struct // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, a // different dtype here results in undefined epilogue behaviors - float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, - gamma = 0.f, delta = 0.f; + float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, gamma = 0.f, + delta = 0.f; using namespace cutlass::library; @@ -205,10 +194,9 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( exec_src = exec_src.dimshuffle({0, 1, 4, 2, 3}); - auto&& relayout = - args.opr->handle()->create_operator(); - relayout->exec({inner_grad_ptr, exec_src}, - {args.grad_tensor->raw_ptr, exec_dst}); + auto&& relayout = args.opr->handle()->create_operator(); + relayout->exec( + {inner_grad_ptr, exec_src}, {args.grad_tensor->raw_ptr, exec_dst}); } } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp index f2e7903f..04606fb2 100644 --- a/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp +++ b/dnn/src/cuda/convolution/backward_data/implicit_gemm_int8_nhwc_imma.cpp @@ -20,24 +20,24 @@ using namespace megdnn; using namespace cuda; -const void* -ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_available_op( +const void* ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_available_op( const SizeArgs& args) const { using namespace cutlass::library; auto&& fm = args.filter_meta; size_t sh = fm.stride[0], sw = fm.stride[1]; cutlass::conv::SpecialOptimizeDesc special_optimization = - (sh == 2 && sw == 2) ? 
cutlass::conv::SpecialOptimizeDesc:: - DECONV_DOUBLE_UPSAMPLING - : cutlass::conv::SpecialOptimizeDesc::NONE; + (sh == 2 && sw == 2) + ? cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING + : cutlass::conv::SpecialOptimizeDesc::NONE; LayoutTypeID filter_layout; if (m_algo_param.access_size == 16) { filter_layout = LayoutTypeID::kTensorCK16RS16; } else if (m_algo_param.access_size == 8) { filter_layout = LayoutTypeID::kTensorCK8RS8; } else { - megdnn_assert(m_algo_param.access_size == 4, "invalid access_size: %d", - m_algo_param.access_size); + megdnn_assert( + m_algo_param.access_size == 4, "invalid access_size: %d", + m_algo_param.access_size); filter_layout = LayoutTypeID::kTensorCK4RS4; } ConvolutionKey key{ @@ -73,22 +73,21 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available( if (fm.format != Param::Format::NHWC) return false; - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } bool available = true; - auto src_dtype = args.diff_layout->dtype, - filter_dtype = args.filter_layout->dtype, + auto src_dtype = args.diff_layout->dtype, filter_dtype = args.filter_layout->dtype, dst_dtype = args.grad_layout->dtype; size_t co = args.diff_layout->operator[](3); size_t ci = args.grad_layout->operator[](3); - available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && - filter_dtype.enumv() == DTypeEnum::QuantizedS8 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8); + available &= + (src_dtype.enumv() == DTypeEnum::QuantizedS8 && + filter_dtype.enumv() == DTypeEnum::QuantizedS8 && + dst_dtype.enumv() == DTypeEnum::QuantizedS8); // TODO support group deconv int8 available &= (fm.group == 1); // mode must be cross correlation @@ -111,9 +110,8 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::is_available( return available; } -WorkspaceBundle -ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::get_workspace_bundle( - dt_byte* raw_ptr, const SizeArgs& args) const { +WorkspaceBundle ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm:: + get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const { size_t ws_filter = args.filter_layout->span().dist_byte(); return WorkspaceBundle{raw_ptr, {ws_filter}}; } @@ -127,12 +125,9 @@ void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( const ExecArgs& args) const { auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - size_t n = args.diff_layout->operator[](0), - co = args.diff_layout->operator[](3), - ho = args.diff_layout->operator[](1), - wo = args.diff_layout->operator[](2); - size_t ci = args.grad_layout->operator[](3), - hi = args.grad_layout->operator[](1), + size_t n = args.diff_layout->operator[](0), co = args.diff_layout->operator[](3), + ho = args.diff_layout->operator[](1), wo = args.diff_layout->operator[](2); + size_t ci = args.grad_layout->operator[](3), hi = args.grad_layout->operator[](1), wi = args.grad_layout->operator[](2); size_t fh = fm.spatial[0], fw = fm.spatial[1]; size_t sh = fm.stride[0], sw = fm.stride[1]; @@ -149,18 +144,15 @@ void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( reorder_filter(args, m_algo_param.access_size, filter_ptr); } - float diff_scale = - args.diff_layout->dtype.param().scale, - filter_scale = - args.filter_layout->dtype.param().scale, - grad_scale = - args.grad_layout->dtype.param().scale; + float diff_scale = args.diff_layout->dtype.param().scale, + 
filter_scale = args.filter_layout->dtype.param().scale, + grad_scale = args.grad_layout->dtype.param().scale; // \note these constants of cutlass epilogue will be passed to struct // `ConvolutionArguments` by pointer and interpreted as ElementCompute*, // a different dtype here results in undefined epilogue behaviors - float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, - gamma = 0.f, delta = 0.f; + float alpha = diff_scale * filter_scale / grad_scale, beta = 0.f, gamma = 0.f, + delta = 0.f; using namespace cutlass::library; @@ -188,8 +180,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( } void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( - const ExecArgs& args, const int interleaved, - int8_t* reordered_filter) const { + const ExecArgs& args, const int interleaved, int8_t* reordered_filter) const { auto&& fm = args.filter_meta; size_t co = args.diff_layout->operator[](3); size_t ci = args.grad_layout->operator[](3); @@ -197,8 +188,8 @@ void ConvolutionBackwardDataImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( auto&& stream = cuda_stream(args.opr->handle()); megdnn::cuda::deconv::reorder_filter_nhwc_to_cnxhwx( - reordered_filter, args.filter_tensor->compatible_ptr(), co, - ci, fh, fw, interleaved, stream); + reordered_filter, args.filter_tensor->compatible_ptr(), co, ci, fh, + fw, interleaved, stream); } void ConvolutionBackwardDataImpl::AlgoPack::fill_int8_imma_algos() { diff --git a/dnn/src/cuda/convolution/backward_data/matmul.cpp b/dnn/src/cuda/convolution/backward_data/matmul.cpp index 165e60f4..a328ca0b 100644 --- a/dnn/src/cuda/convolution/backward_data/matmul.cpp +++ b/dnn/src/cuda/convolution/backward_data/matmul.cpp @@ -24,11 +24,9 @@ namespace { std::pair sub_opr_config( const ConvolutionBackwardDataImpl::CanonizedFilterMeta& fm, const TensorLayout& filter_layout, const TensorLayout& diff_layout, - const TensorLayout& grad_layout, - const ConvolutionBackwardDataImpl* opr) { - size_t N = grad_layout.shape[0], IC = fm.icpg, - OC = fm.ocpg, OH = diff_layout.shape[2], - OW = diff_layout.shape[3], FH = fm.spatial[0], + const TensorLayout& grad_layout, const ConvolutionBackwardDataImpl* opr) { + size_t N = grad_layout.shape[0], IC = fm.icpg, OC = fm.ocpg, + OH = diff_layout.shape[2], OW = diff_layout.shape[3], FH = fm.spatial[0], FW = fm.spatial[1]; megdnn_assert(filter_layout.dtype.enumv() == diff_layout.dtype.enumv()); @@ -36,8 +34,7 @@ std::pair sub_opr_config( Bl({IC * FH * FW, OH * OW * N}, filter_layout.dtype), Cl({OC, OH * OW * N}, filter_layout.dtype); MatrixMulForward::Param param; - if (opr->param().compute_mode == - param::Convolution::ComputeMode::FLOAT32) { + if (opr->param().compute_mode == param::Convolution::ComputeMode::FLOAT32) { param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32; } @@ -50,32 +47,31 @@ std::pair> prepare_sub_opr( auto matmul_opr = args.handle->create_operator(); set_execution_policy( args.opr, matmul_opr.get()); - auto&& config = - sub_opr_config(args.filter_meta, *args.filter_layout, - *args.diff_layout, *args.grad_layout, args.opr); + auto&& config = sub_opr_config( + args.filter_meta, *args.filter_layout, *args.diff_layout, *args.grad_layout, + args.opr); matmul_opr->param() = config.second; return {config.first, std::move(matmul_opr)}; } } // namespace -std::vector -ConvolutionBackwardDataImpl::AlgoMatmul::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector ConvolutionBackwardDataImpl::AlgoMatmul:: + get_subopr_list( 
+ const TensorLayoutArray& layouts, const OperatorBase* opr) const { const ConvolutionBackwardDataImpl* conv_backward_data_opr = static_cast(opr); CanonizedFilterMeta fm = conv_backward_data_opr->make_canonized_filter_meta( layouts[2].ndim, layouts[0]); - auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[2], - conv_backward_data_opr); + auto&& config = sub_opr_config( + fm, layouts[0], layouts[1], layouts[2], conv_backward_data_opr); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}}; } -bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available( - const SizeArgs& args) const { +bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(const SizeArgs& args) const { if (args.diff_layout->dtype == args.filter_layout->dtype && args.diff_layout->dtype == dtype::BFloat16()) { return false; @@ -110,16 +106,14 @@ void ConvolutionBackwardDataImpl::AlgoMatmul::exec(const ExecArgs& args) const { } template -void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal( - const ExecArgs& args) { +void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal(const ExecArgs& args) { auto&& fm = args.filter_meta; size_t N = args.grad_layout->shape[0], IC = fm.icpg, IH = args.grad_layout->shape[2], IW = args.grad_layout->shape[3], OC = fm.ocpg, OH = args.diff_layout->shape[2], - OW = args.diff_layout->shape[3], FH = fm.spatial[0], - FW = fm.spatial[1], PH = fm.padding[0], PW = fm.padding[1], - SH = fm.stride[0], SW = fm.stride[1], DH = fm.dilation[0], - DW = fm.dilation[1]; + OW = args.diff_layout->shape[3], FH = fm.spatial[0], FW = fm.spatial[1], + PH = fm.padding[0], PW = fm.padding[1], SH = fm.stride[0], SW = fm.stride[1], + DH = fm.dilation[0], DW = fm.dilation[1]; auto stream = cuda_stream(args.handle); auto config = prepare_sub_opr(args); @@ -144,13 +138,12 @@ void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal( { // take gemm grad TensorLayout Al({OC, IC * FH * FW}, typename DTypeTrait::dtype()), - Bl({IC * FH * FW, OH * OW * N}, - typename DTypeTrait::dtype()), + Bl({IC * FH * FW, OH * OW * N}, typename DTypeTrait::dtype()), Cl({OC, OH * OW * N}, typename DTypeTrait::dtype()); TensorND A(args.filter_tensor->ptr(), Al), B(col, Bl), C(diff_t, Cl); if (fm.should_flip) { - convolution::flip_filter(args.as_fwd_args(), - wbundle.get_workspace(2), A.raw_ptr); + convolution::flip_filter( + args.as_fwd_args(), wbundle.get_workspace(2), A.raw_ptr); config.second->exec(A, C, B, wbundle.get_workspace(3)); } else { config.second->exec(A, C, B, wbundle.get_workspace(2)); @@ -158,9 +151,9 @@ void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal( } { // col2im - convolution::col2im(col, args.grad_tensor->ptr(), N, - args.grad_layout->stride[0], IC, IH, IW, FH, FW, - OH, OW, PH, PW, SH, SW, DH, DW, stream); + convolution::col2im( + col, args.grad_tensor->ptr(), N, args.grad_layout->stride[0], IC, IH, + IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, DW, stream); } } diff --git a/dnn/src/cuda/convolution/backward_filter/algo.cpp b/dnn/src/cuda/convolution/backward_filter/algo.cpp index 9bf5c4ed..ba4c886c 100644 --- a/dnn/src/cuda/convolution/backward_filter/algo.cpp +++ b/dnn/src/cuda/convolution/backward_filter/algo.cpp @@ -19,10 +19,10 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { non_cudnn_algos.push_back(&chanwise); non_cudnn_algos.push_back(&matmul); - all_algos.push_back(&chanwise); // prefer chanwise + all_algos.push_back(&chanwise); // prefer chanwise fill_cudnn_algos(); - 
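The matmul-based backward-data algorithm above reduces the problem to a single GEMM over the (OC), (IC*FH*FW) and (OH*OW*N) extents, followed by col2im. A shape-only sketch with hypothetical sizes and no megdnn types:

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Hypothetical single-group NCHW backward-data problem.
        size_t N = 2, IC = 16, OC = 32, OH = 28, OW = 28, FH = 3, FW = 3;

        // Operand shapes used by the matmul algorithm above:
        //   A: filter        (OC, IC*FH*FW)
        //   B: "col" buffer  (IC*FH*FW, OH*OW*N)  -- produced by the GEMM, consumed by col2im
        //   C: reshaped diff (OC, OH*OW*N)
        size_t M = OC, K = IC * FH * FW, L = OH * OW * N;
        std::printf("A: %zux%zu  B: %zux%zu  C: %zux%zu\n", M, K, K, L, M, L);
        std::printf("col workspace elements: %zu\n", K * L);
        return 0;
    }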
for (auto &&i: cudnn) { + for (auto&& i : cudnn) { all_algos.push_back(&i); } all_algos.push_back(&matmul); @@ -38,27 +38,22 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionBackwardFilterImpl) -ConvolutionBackwardFilterImpl::AlgoCUDNN* -ConvolutionBackwardFilterImpl::AlgoPack::cudnn_from_enum( - cudnnConvolutionBwdFilterAlgo_t algo) { - for (auto &&i: cudnn) { +ConvolutionBackwardFilterImpl::AlgoCUDNN* ConvolutionBackwardFilterImpl::AlgoPack:: + cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo) { + for (auto&& i : cudnn) { if (i.cudnn_enum() == algo) return &i; } - megdnn_throw(ssprintf("can not find cudnn bwd_filter algorithm %d", - static_cast(algo))); + megdnn_throw(ssprintf( + "can not find cudnn bwd_filter algorithm %d", static_cast(algo))); } -ConvolutionBackwardFilterImpl::AlgoPack -ConvolutionBackwardFilterImpl::sm_algo_pack; +ConvolutionBackwardFilterImpl::AlgoPack ConvolutionBackwardFilterImpl::sm_algo_pack; ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( - const ConvolutionBackwardFilterImpl *o, - const TensorLayout &src, const TensorLayout &diff, - const TensorLayout &grad): - SizeArgs(o, src, diff, grad, o->make_canonized_filter_meta(src.ndim, grad)) -{ -} + const ConvolutionBackwardFilterImpl* o, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad) + : SizeArgs(o, src, diff, grad, o->make_canonized_filter_meta(src.ndim, grad)) {} ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( const ConvolutionBackwardFilterImpl* o, const TensorLayout& src, @@ -72,29 +67,24 @@ ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( opr{o} {} ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( - const ConvolutionBackwardFilterImpl *opr, - _megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace): - SizeArgs(opr, src.layout, diff.layout, grad.layout), - src_tensor{&src}, diff_tensor{&diff}, grad_tensor{&grad}, - workspace{workspace} -{ -} + const ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) + : SizeArgs(opr, src.layout, diff.layout, grad.layout), + src_tensor{&src}, + diff_tensor{&diff}, + grad_tensor{&grad}, + workspace{workspace} {} -std::string -ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { - auto &&fm = grad_filter_meta; +std::string ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { + auto&& fm = grad_filter_meta; MEGDNN_MARK_USED_VAR(fm); return ssprintf( "src=%s diff=%s grad_filter=%u{%u,%u,%u,%u}, " "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s", - src_layout->to_string().c_str(), diff_layout->to_string().c_str(), - fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], - fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], - fm.dilation[0], fm.dilation[1], !fm.should_flip, - src_layout->dtype.name(), diff_layout->dtype.name()); + src_layout->to_string().c_str(), diff_layout->to_string().c_str(), fm.group, + fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.padding[0], + fm.padding[1], fm.stride[0], fm.stride[1], fm.dilation[0], fm.dilation[1], + !fm.should_flip, src_layout->dtype.name(), diff_layout->dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/backward_filter/algo.h b/dnn/src/cuda/convolution/backward_filter/algo.h index bb83c4c5..9f8cd758 100644 --- a/dnn/src/cuda/convolution/backward_filter/algo.h +++ 
b/dnn/src/cuda/convolution/backward_filter/algo.h @@ -51,26 +51,26 @@ public: void init_desc(convolution::CUDNNBwdFilterDescs& desc) const { desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); } - SizeArgs(const ConvolutionBackwardFilterImpl* opr, - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad); - SizeArgs(const ConvolutionBackwardFilterImpl* opr, - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, - const CanonizedFilterMeta& grad_meta); + SizeArgs( + const ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + const ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad, + const CanonizedFilterMeta& grad_meta); convolution::ForwardSizeArgs as_fwd_args() const { - return {handle, src_layout, grad_layout, grad_filter_meta, - diff_layout}; + return {handle, src_layout, grad_layout, grad_filter_meta, diff_layout}; } }; struct ExecArgs : public SizeArgs { const TensorND *src_tensor, *diff_tensor, *grad_tensor; Workspace workspace; - ExecArgs(const ConvolutionBackwardFilterImpl* opr, - _megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace); + ExecArgs( + const ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -86,17 +86,16 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd filter algo %s: " - "required workspace %zu bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "conv bwd filter algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); return *this; } @@ -108,10 +107,10 @@ class ConvolutionBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { CudnnAlgoPack::Attr m_attr; public: - AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) - : m_cudnn_enum(cudnn_enum) { - megdnn_assert(CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) != - CudnnAlgoPack::conv_bwd_flt_algos().end()); + AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { + megdnn_assert( + CudnnAlgoPack::conv_bwd_flt_algos().find(cudnn_enum) != + CudnnAlgoPack::conv_bwd_flt_algos().end()); m_attr = CudnnAlgoPack::conv_bwd_flt_algos().at(cudnn_enum); } @@ -154,14 +153,12 @@ public: void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return 
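is_available_attribute and check_workspace above filter candidate algorithms by attribute bits and by workspace budget. A simplified, self-contained sketch of that bitmask filtering; the enum values and helper names here are stand-ins, not megdnn's definitions:

    #include <cstdint>
    #include <cstdio>

    enum class Attr : uint32_t { DEFAULT = 0, REPRODUCIBLE = 1u << 0, NAIVE = 1u << 1 };

    constexpr uint32_t bits(Attr a) { return static_cast<uint32_t>(a); }

    // "all" requires every requested bit to be present; "any" rejects if any listed bit is set.
    bool contain_all(uint32_t have, Attr want) { return (have & bits(want)) == bits(want); }
    bool contain_any(uint32_t have, Attr avoid) { return (have & bits(avoid)) != 0; }

    int main() {
        uint32_t algo_attr = bits(Attr::REPRODUCIBLE);
        bool usable = contain_all(algo_attr, Attr::REPRODUCIBLE) &&  // positive_attr
                      !contain_any(algo_attr, Attr::NAIVE);          // negative_attr
        std::printf("algo selected: %s\n", usable ? "yes" : "no");
        return 0;
    }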
AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } }; @@ -173,9 +170,7 @@ public: const char* name() const override { return "CHANNEL_WISE"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } }; class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { @@ -185,16 +180,11 @@ public: void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; - const char* name() const override { - return "CONVOLUTION_BACKWARD_FILTER_BFLOAT16"; - } + const char* name() const override { return "CONVOLUTION_BACKWARD_FILTER_BFLOAT16"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) @@ -203,24 +193,18 @@ private: }; //! implement group conv by another algo -class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final - : public AlgoBase { +class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; - const char* name() const override { - return "CUDA:GROUP_CONV_BACKWARD_FILTER"; - } + const char* name() const override { return "CUDA:GROUP_CONV_BACKWARD_FILTER"; } MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; diff --git a/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp b/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp index bbd4ac1f..58ae0413 100644 --- a/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp +++ b/dnn/src/cuda/convolution/backward_filter/bfloat16.cpp @@ -20,9 +20,8 @@ using namespace cuda; using namespace convolution; namespace { -std::pair -sub_opr_config(const TensorLayoutArray& layouts, - const ConvolutionBackwardFilterImpl* opr) { +std::pair sub_opr_config( + const TensorLayoutArray& layouts, const ConvolutionBackwardFilterImpl* opr) { megdnn_assert(layouts.size() >= 3); std::pair ret; ret.first = layouts; @@ -36,13 +35,12 @@ sub_opr_config(const TensorLayoutArray& layouts, change_dtype(ret.first[2]); ret.second = opr->param(); - ret.second.compute_mode = - ConvolutionBackwardFilter::Param::ComputeMode::DEFAULT; + ret.second.compute_mode = ConvolutionBackwardFilter::Param::ComputeMode::DEFAULT; return ret; } -std::pair> -prepare_sub_opr(const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) { +std::pair> prepare_sub_opr( + const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) { auto conv_back_filter_opr = args.handle->create_operator(); @@ -54,16 +52,15 @@ prepare_sub_opr(const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) { } } // namespace -std::vector 
-ConvolutionBackwardFilterImpl::AlgoBFloat16::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector ConvolutionBackwardFilterImpl::AlgoBFloat16:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { auto&& config = sub_opr_config( layouts, static_cast(opr)); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER, param_str, - config.first}}; + return {{Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER, param_str, config.first}}; } bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available( @@ -71,18 +68,16 @@ bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available( auto config = prepare_sub_opr(args); return args.src_layout->dtype == args.diff_layout->dtype && args.src_layout->dtype == dtype::BFloat16() && - get_algorithm(static_cast( - config.second.get()), - config.first[0], config.first[1], config.first[2]); + get_algorithm( + static_cast(config.second.get()), + config.first[0], config.first[1], config.first[2]); } -WorkspaceBundle -ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle( +WorkspaceBundle ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle( void* ptr, const SizeArgs& args) const { auto config = prepare_sub_opr(args); SmallVector sizes; - auto get_workspace = [&sizes](const TensorLayout& src, - const TensorLayout& dst) { + auto get_workspace = [&sizes](const TensorLayout& src, const TensorLayout& dst) { if (src.dtype != dst.dtype) { sizes.push_back(dst.span().dist_byte()); } @@ -102,8 +97,7 @@ size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( - const ExecArgs& args) const { +void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec(const ExecArgs& args) const { TensorND fsrc_tensor = *args.src_tensor; TensorND fdiff_tensor = *args.diff_tensor; TensorND fgrad_tensor = *args.grad_tensor; @@ -116,8 +110,7 @@ void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( } { auto config = prepare_sub_opr(args); - config.second->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, - cvter.workspace()); + config.second->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, cvter.workspace()); } { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); } } diff --git a/dnn/src/cuda/convolution/backward_filter/chanwise.cpp b/dnn/src/cuda/convolution/backward_filter/chanwise.cpp index e6d6893d..b18e44d5 100644 --- a/dnn/src/cuda/convolution/backward_filter/chanwise.cpp +++ b/dnn/src/cuda/convolution/backward_filter/chanwise.cpp @@ -10,54 +10,50 @@ */ #include "./algo.h" -#include "src/cuda/utils.h" #include "src/cuda/convolution/chanwise/kern.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution; bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available( - const SizeArgs &args) const { - if (!args.src_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } if (args.src_layout->dtype == args.src_layout->dtype && args.diff_layout->dtype == dtype::BFloat16()) { return false; } - auto &&fm = args.grad_filter_meta; + auto&& fm = args.grad_filter_meta; return fm.format == Param::Format::NCHW && - args.diff_layout->dtype.category() == DTypeCategory::FLOAT && - 
fm.spatial_ndim == 2 && fm.icpg == 1 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - !fm.should_flip; + args.diff_layout->dtype.category() == DTypeCategory::FLOAT && + fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && !fm.should_flip; } size_t ConvolutionBackwardFilterImpl::AlgoChanwise::get_workspace_in_bytes( - const SizeArgs &) const { + const SizeArgs&) const { return 0; } -void ConvolutionBackwardFilterImpl::AlgoChanwise::exec( - const ExecArgs &args) const { +void ConvolutionBackwardFilterImpl::AlgoChanwise::exec(const ExecArgs& args) const { auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args()); auto stream = cuda_stream(args.handle); switch (args.diff_layout->dtype.enumv()) { - case DTypeEnum::Float32: - return chanwise::run_bwd_filter(args.grad_tensor->ptr(), - args.src_tensor->ptr(), - args.diff_tensor->ptr(), - kparam, stream); - case DTypeEnum::Float16: + case DTypeEnum::Float32: + return chanwise::run_bwd_filter( + args.grad_tensor->ptr(), args.src_tensor->ptr(), + args.diff_tensor->ptr(), kparam, stream); + case DTypeEnum::Float16: #if CUDA_VERSION >= 9000 if (is_compute_capability_required(5, 3)) { - return chanwise::run_bwd_filter( - static_cast<__half*>(args.grad_tensor->raw_ptr), - static_cast<__half*>(args.src_tensor->raw_ptr), - static_cast<__half*>(args.diff_tensor->raw_ptr), - kparam, stream); + return chanwise::run_bwd_filter( + static_cast<__half*>(args.grad_tensor->raw_ptr), + static_cast<__half*>(args.src_tensor->raw_ptr), + static_cast<__half*>(args.diff_tensor->raw_ptr), kparam, + stream); } else { return chanwise::run_bwd_filter( args.grad_tensor->ptr(), @@ -65,10 +61,10 @@ void ConvolutionBackwardFilterImpl::AlgoChanwise::exec( args.diff_tensor->ptr(), kparam, stream); } #else - return chanwise::run_bwd_filter(args.grad_tensor->ptr(), - args.src_tensor->ptr(), - args.diff_tensor->ptr(), - kparam, stream); + return chanwise::run_bwd_filter( + args.grad_tensor->ptr(), + args.src_tensor->ptr(), + args.diff_tensor->ptr(), kparam, stream); #endif default: @@ -78,4 +74,3 @@ void ConvolutionBackwardFilterImpl::AlgoChanwise::exec( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution/backward_filter/cudnn.cpp b/dnn/src/cuda/convolution/backward_filter/cudnn.cpp index 75ff3874..14731fce 100644 --- a/dnn/src/cuda/convolution/backward_filter/cudnn.cpp +++ b/dnn/src/cuda/convolution/backward_filter/cudnn.cpp @@ -11,21 +11,20 @@ #include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/cudnn_wrapper.h" -#include "src/cuda/convolution/helper.h" #include "src/cuda/conv_bias/helper.h" +#include "src/cuda/convolution/helper.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution; bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( - const SizeArgs &args) const { + const SizeArgs& args) const { if (args.grad_filter_meta.format != Param::Format::NCHW && args.grad_filter_meta.format != Param::Format::NHWC) { - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } } @@ -34,9 +33,11 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( TensorLayout bias_layout, z_layout; conv_bias::CanonizedFilterMeta meta; meta.copy_from(args.grad_filter_meta); - conv_bias::BiasForwardSizeArgs bias_args{args.handle, - args.src_layout, args.grad_layout, &bias_layout, - &z_layout, meta, 
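The availability checks above restrict AlgoChanwise to contiguous NCHW float tensors with a pure depthwise layout (icpg == 1), unit dilation and no filter flip. Restated as a free function over plain values; the struct is illustrative, megdnn passes TensorLayout and CanonizedFilterMeta instead:

    #include <cstdio>

    struct ChanwiseBwdFilterProblem {
        bool src_contig, diff_contig, diff_is_float, format_nchw, should_flip;
        int spatial_ndim, icpg, dilation_h, dilation_w;
    };

    bool chanwise_bwd_filter_available(const ChanwiseBwdFilterProblem& p) {
        return p.src_contig && p.diff_contig && p.format_nchw && p.diff_is_float &&
               p.spatial_ndim == 2 && p.icpg == 1 &&      // depthwise: one input channel per group
               p.dilation_h == 1 && p.dilation_w == 1 &&  // no dilation
               !p.should_flip;                            // cross-correlation only
    }

    int main() {
        ChanwiseBwdFilterProblem p{true, true, true, true, false, 2, 1, 1, 1};
        std::printf("chanwise available: %d\n", chanwise_bwd_filter_available(p));
        return 0;
    }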
args.diff_layout, param::ConvBias::NonlineMode::IDENTITY, + conv_bias::BiasForwardSizeArgs bias_args{ + args.handle, args.src_layout, + args.grad_layout, &bias_layout, + &z_layout, meta, + args.diff_layout, param::ConvBias::NonlineMode::IDENTITY, }; if (!conv_bias::is_cudnn_supported(bias_args)) return false; @@ -44,58 +45,42 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( - args.handle->cudnn_handle(), - D.src_desc.desc, - D.diff_desc.desc, - D.conv_desc.desc, - D.grad_desc.desc, - m_cudnn_enum, - &workspace_size); + args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); return status == CUDNN_STATUS_SUCCESS; } size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes( - const SizeArgs &args) const { + const SizeArgs& args) const { CUDNNBwdFilterDescs D; args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( - args.handle->cudnn_handle(), - D.src_desc.desc, - D.diff_desc.desc, - D.conv_desc.desc, - D.grad_desc.desc, - m_cudnn_enum, - &workspace_size); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, + args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv bwd_filter get workspace failed: %s; info: %s", cudnnGetErrorString(status), args.to_string().c_str()); return workspace_size; } -void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec( - const ExecArgs &args) const { +void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec(const ExecArgs& args) const { CUDNNBwdFilterDescs D; args.init_desc(D); float alpha = 1.0f, beta = 0.0f; - auto status = cudnnConvolutionBackwardFilter(args.handle->cudnn_handle(), - &alpha, - D.src_desc.desc, args.src_tensor->raw_ptr, - D.diff_desc.desc, args.diff_tensor->raw_ptr, - D.conv_desc.desc, - m_cudnn_enum, - args.workspace.raw_ptr, - args.workspace.size, - &beta, - D.grad_desc.desc, - args.grad_tensor->raw_ptr); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv bwd_data failed: %s; info: %s", + auto status = cudnnConvolutionBackwardFilter( + args.handle->cudnn_handle(), &alpha, D.src_desc.desc, + args.src_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, + D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, + &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", cudnnGetErrorString(status), args.to_string().c_str()); } void ConvolutionBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { - for(auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) { + for (auto&& algo : CudnnAlgoPack::conv_bwd_flt_algos()) { cudnn.push_back(algo.first); } } diff --git a/dnn/src/cuda/convolution/backward_filter/group_conv.cpp b/dnn/src/cuda/convolution/backward_filter/group_conv.cpp index 548e54b4..4b3ef295 100644 --- a/dnn/src/cuda/convolution/backward_filter/group_conv.cpp +++ b/dnn/src/cuda/convolution/backward_filter/group_conv.cpp @@ -37,8 +37,8 @@ std::pair sub_opr_config( return ret; } -std::pair> -prepare_sub_opr(const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) { +std::pair> prepare_sub_opr( + const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) { auto conv_bwd_filter_opr = args.handle->create_operator(); set_execution_policy( @@ -50,9 
+50,9 @@ prepare_sub_opr(const ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs& args) { } } // namespace -std::vector -ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { AlgoBase::SizeArgs args{ static_cast(opr), layouts[0], layouts[1], layouts[2]}; @@ -60,20 +60,18 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::get_subopr_list( std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER, param_str, - config.first}}; + return {{Algorithm::OprType::CONVOLUTION_BACKWARD_FILTER, param_str, config.first}}; } bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available( - const SizeArgs &args) const { + const SizeArgs& args) const { if (args.src_layout->dtype == args.src_layout->dtype && args.diff_layout->dtype == dtype::BFloat16()) { return false; } if (args.grad_filter_meta.group <= 1) return false; - if (args.grad_filter_meta.format != - megdnn::param::Convolution::Format::NCHW) { + if (args.grad_filter_meta.format != megdnn::param::Convolution::Format::NCHW) { return false; } @@ -84,17 +82,15 @@ bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available( config.first[0], config.first[1], config.first[2]); } -WorkspaceBundle -ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::get_workspace_bundle( - void* ptr, const SizeArgs& args) const { +WorkspaceBundle ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral:: + get_workspace_bundle(void* ptr, const SizeArgs& args) const { auto config = prepare_sub_opr(args); size_t sizes = config.second->get_workspace_in_bytes( config.first[0], config.first[1], config.first[2]); return {ptr, {sizes}}; } -size_t -ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( +size_t ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( const SizeArgs& args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -113,10 +109,9 @@ void ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::exec( auto&& fm = args.grad_filter_meta; - auto strd_src = tsrc.layout.stride[c_pos] * fm.icpg * - tsrc.layout.dtype.size(), - strd_diff = tdiff.layout.stride[c_pos] * fm.ocpg * - tdiff.layout.dtype.size(), + auto strd_src = tsrc.layout.stride[c_pos] * fm.icpg * tsrc.layout.dtype.size(), + strd_diff = + tdiff.layout.stride[c_pos] * fm.ocpg * tdiff.layout.dtype.size(), strd_grad = fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * tgrad.layout.dtype.size(); @@ -131,4 +126,3 @@ void ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::exec( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution/backward_filter/matmul.cpp b/dnn/src/cuda/convolution/backward_filter/matmul.cpp index 4c4a2253..529d61e5 100644 --- a/dnn/src/cuda/convolution/backward_filter/matmul.cpp +++ b/dnn/src/cuda/convolution/backward_filter/matmul.cpp @@ -23,11 +23,9 @@ namespace { std::pair sub_opr_config( const ConvolutionBackwardFilterImpl::CanonizedFilterMeta& fm, const TensorLayout& src_layout, const TensorLayout& diff_layout, - const TensorLayout& grad_layout, - const ConvolutionBackwardFilterImpl* opr) { - size_t N = src_layout.shape[0], IC = fm.icpg, - OC = fm.ocpg, OH = diff_layout.shape[2], - OW = diff_layout.shape[3], FH = fm.spatial[0], + const TensorLayout& grad_layout, 
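AlgoGroupConvGeneral above runs the single-group sub-operator once per group by stepping the raw src/diff/grad pointers with fixed byte strides. The same arithmetic with hypothetical contiguous NCHW shapes:

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Hypothetical grouped backward-filter problem, contiguous NCHW float32.
        size_t group = 4, icpg = 8, ocpg = 16, fh = 3, fw = 3;
        size_t ih = 32, iw = 32, oh = 30, ow = 30;
        size_t dtype_size = sizeof(float);

        // For contiguous NCHW the channel-axis stride is H*W elements.
        size_t src_chan_stride = ih * iw, diff_chan_stride = oh * ow;

        // Per-group byte steps, mirroring strd_src / strd_diff / strd_grad above.
        size_t strd_src = src_chan_stride * icpg * dtype_size;
        size_t strd_diff = diff_chan_stride * ocpg * dtype_size;
        size_t strd_grad = icpg * ocpg * fh * fw * dtype_size;

        for (size_t g = 0; g < group; ++g)
            std::printf("group %zu: src +%zu B, diff +%zu B, grad +%zu B\n", g,
                        g * strd_src, g * strd_diff, g * strd_grad);
        return 0;
    }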
const ConvolutionBackwardFilterImpl* opr) { + size_t N = src_layout.shape[0], IC = fm.icpg, OC = fm.ocpg, + OH = diff_layout.shape[2], OW = diff_layout.shape[3], FH = fm.spatial[0], FW = fm.spatial[1]; megdnn_assert(src_layout.dtype.enumv() == diff_layout.dtype.enumv()); @@ -35,8 +33,7 @@ std::pair sub_opr_config( Bl({IC * FH * FW, OH * OW * N}, src_layout.dtype), Cl({OC, OH * OW * N}, src_layout.dtype); MatrixMulForward::Param param; - if (opr->param().compute_mode == - param::Convolution::ComputeMode::FLOAT32) { + if (opr->param().compute_mode == param::Convolution::ComputeMode::FLOAT32) { param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32; } @@ -50,25 +47,24 @@ std::pair> prepare_sub_opr( set_execution_policy( args.opr, matmul_opr.get()); - auto&& config = - sub_opr_config(args.grad_filter_meta, *args.src_layout, - *args.diff_layout, *args.grad_layout, args.opr); + auto&& config = sub_opr_config( + args.grad_filter_meta, *args.src_layout, *args.diff_layout, + *args.grad_layout, args.opr); matmul_opr->param() = config.second; return {config.first, std::move(matmul_opr)}; } } // namespace -std::vector -ConvolutionBackwardFilterImpl::AlgoMatmul::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector ConvolutionBackwardFilterImpl::AlgoMatmul:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { const ConvolutionBackwardFilterImpl* conv_backward_filter_opr = static_cast(opr); - CanonizedFilterMeta fm = - conv_backward_filter_opr->make_canonized_filter_meta( - layouts[0].ndim, layouts[2]); - auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[2], - conv_backward_filter_opr); + CanonizedFilterMeta fm = conv_backward_filter_opr->make_canonized_filter_meta( + layouts[0].ndim, layouts[2]); + auto&& config = sub_opr_config( + fm, layouts[0], layouts[1], layouts[2], conv_backward_filter_opr); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); @@ -97,8 +93,7 @@ size_t ConvolutionBackwardFilterImpl::AlgoMatmul::get_workspace_in_bytes( return WorkspaceBundle(nullptr, sizes).total_size_in_bytes(); } -void ConvolutionBackwardFilterImpl::AlgoMatmul::exec( - const ExecArgs& args) const { +void ConvolutionBackwardFilterImpl::AlgoMatmul::exec(const ExecArgs& args) const { #define cb(DType) \ if (args.diff_layout->dtype == DType()) { \ using ctype = typename DTypeTrait::ctype; \ @@ -112,16 +107,14 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec( } template -void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal( - const ExecArgs& args) { +void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal(const ExecArgs& args) { auto&& fm = args.grad_filter_meta; - size_t N = args.src_layout->shape[0], IC = fm.icpg, - IH = args.src_layout->shape[2], IW = args.src_layout->shape[3], - OC = fm.ocpg, OH = args.diff_layout->shape[2], - OW = args.diff_layout->shape[3], FH = fm.spatial[0], - FW = fm.spatial[1], PH = fm.padding[0], PW = fm.padding[1], - SH = fm.stride[0], SW = fm.stride[1], DH = fm.dilation[0], - DW = fm.dilation[1]; + size_t N = args.src_layout->shape[0], IC = fm.icpg, IH = args.src_layout->shape[2], + IW = args.src_layout->shape[3], OC = fm.ocpg, + OH = args.diff_layout->shape[2], OW = args.diff_layout->shape[3], + FH = fm.spatial[0], FW = fm.spatial[1], PH = fm.padding[0], + PW = fm.padding[1], SH = fm.stride[0], SW = fm.stride[1], + DH = fm.dilation[0], DW = fm.dilation[1]; auto stream = cuda_stream(args.handle); auto config = prepare_sub_opr(args); @@ 
-145,15 +138,14 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal( } { // im2col - convolution::im2col(args.src_tensor->ptr(), col, N, - args.src_tensor->layout.stride[0], IC, IH, IW, - FH, FW, OH, OW, PH, PW, SH, SW, DH, DW, stream); + convolution::im2col( + args.src_tensor->ptr(), col, N, args.src_tensor->layout.stride[0], + IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, DW, stream); } { // take gemm grad TensorLayout Al({OC, IC * FH * FW}, typename DTypeTrait::dtype()), - Bl({IC * FH * FW, OH * OW * N}, - typename DTypeTrait::dtype()), + Bl({IC * FH * FW, OH * OW * N}, typename DTypeTrait::dtype()), Cl({OC, OH * OW * N}, typename DTypeTrait::dtype()); TensorND A(args.grad_tensor->ptr(), Al), B(col, Bl), C(diff_t, Cl); if (fm.should_flip) { diff --git a/dnn/src/cuda/convolution/chanwise/bwd_data.cu b/dnn/src/cuda/convolution/chanwise/bwd_data.cu index d2156218..ea58a1cf 100644 --- a/dnn/src/cuda/convolution/chanwise/bwd_data.cu +++ b/dnn/src/cuda/convolution/chanwise/bwd_data.cu @@ -24,10 +24,9 @@ namespace { // grid idx is (inp_chl, worker_index) // each y-slice of a block works on an (N, IH, IW) spatial image at given // inp_chl -template -__global__ void kern_bwd_data_float(T* src_grad, const T* dst_grad, - const T* flt_tot, Param param) { +template +__global__ void kern_bwd_data_float( + T* src_grad, const T* dst_grad, const T* flt_tot, Param param) { // extern __shared__ of dt_float16 does not work extern __shared__ uint8_t flt_storage[]; @@ -78,8 +77,7 @@ __global__ void kern_bwd_data_float(T* src_grad, const T* dst_grad, const T* pd = dst_grad_base + oh * OW + ow; const T* pf = flt + fh * FW + fw; #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { sum += *pd * *pf; pd += OH * OW; pf += FSIZE; @@ -110,10 +108,9 @@ __global__ void kern_bwd_data_float(T* src_grad, const T* dst_grad, } #if CUDA_VERSION >= 9000 -template -__global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, - const __half* flt_tot, Param param) { +template +__global__ void kern_bwd_data_hf( + __half* src_grad, const __half* dst_grad, const __half* flt_tot, Param param) { extern __shared__ uint8_t flt_storage[]; __half* const flt = reinterpret_cast<__half*>(flt_storage); @@ -141,8 +138,7 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, out_idx = div_mod(out_idx, IH, ih); n = out_idx; - const __half* dst_grad_base = - dst_grad + n * (IC * CHL_MUL * OH * OW); + const __half* dst_grad_base = dst_grad + n * (IC * CHL_MUL * OH * OW); __half2 sum{0.0, 0.0}; __half2 pd2{0.0, 0.0}; @@ -167,10 +163,8 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, if (FW == 3) { #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { - __half2 flt0 = {0.0, *(pf)}, - flt1 = {*(pf), *(pf + 1)}, + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { + __half2 flt0 = {0.0, *(pf)}, flt1 = {*(pf), *(pf + 1)}, flt2 = {*(pf + 1), *(pf + 2)}, flt3 = {*(pf + 2), 0.0}; uint32_t ow = owmin; @@ -209,10 +203,8 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, } } else if (FW == 5) { #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { - __half2 flt0 = {0.0, *(pf)}, - flt1 = {*(pf), *(pf + 1)}, + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { + __half2 flt0 = {0.0, *(pf)}, flt1 = {*(pf), *(pf + 1)}, flt2 = {*(pf + 1), *(pf + 2)}, flt3 = {*(pf + 2), *(pf + 3)}, flt4 = {*(pf + 3), *(pf + 
4)}, @@ -272,8 +264,7 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, } } else { #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { #pragma unroll for (uint32_t dow = 0; dow <= FW; ++dow) { uint32_t ow = owmin + dow; @@ -309,8 +300,7 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, const __half* pd = dst_grad_base + oh * OW + owmin_x; const __half* pf = flt + fh * FW + fw; #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { pd2.x = *pd; pd2.y = 0.0; pf2.x = *pf; @@ -326,8 +316,7 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, const __half* pd = dst_grad_base + oh * OW + owmax_y; const __half* pf = flt + fh * FW + fw; #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { pd2.x = 0.0; pd2.y = *pd; pf2.x = 0.0; @@ -346,8 +335,7 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, const __half* pd = dst_grad_base + oh * OW + ow; const __half* pf = flt + fh * FW + fw; #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { pd2.x = *pd; pd2.y = *pd; pf2.x = *pf; @@ -371,8 +359,7 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, out_idx = div_mod(out_idx, IH, ih); n = out_idx; - const __half* dst_grad_base = - dst_grad + n * (IC * CHL_MUL * OH * OW); + const __half* dst_grad_base = dst_grad + n * (IC * CHL_MUL * OH * OW); __half sum(0); @@ -391,12 +378,11 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, uint32_t ow = owmin + dow; if (ow <= owmax) { uint32_t fw = iw - ow * SW + PW; - const __half* pd = - dst_grad_base + oh * OW + ow; + const __half* pd = dst_grad_base + oh * OW + ow; const __half* pf = flt + fh * FW + fw; #pragma unroll - for (uint32_t chl_mul = 0; - chl_mul < CHL_MUL; ++chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; + ++chl_mul) { sum = fma(*pd, *pf, sum); pd += OH * OW; pf += FSIZE; @@ -415,8 +401,7 @@ __global__ void kern_bwd_data_hf(__half* src_grad, const __half* dst_grad, const __half* pd = dst_grad_base + oh * OW + ow; const __half* pf = flt + fh * FW + fw; #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { sum = fma(*pd, *pf, sum); pd += OH * OW; pf += FSIZE; @@ -500,31 +485,30 @@ namespace convolution { namespace chanwise { template -void run_bwd_data(T* src_grad, const T* dst_grad, const T* flt, - const Param& param, cudaStream_t stream) { +void run_bwd_data( + T* src_grad, const T* dst_grad, const T* flt, const Param& param, + cudaStream_t stream) { void (*kern)(T*, const T*, const T*, Param); kern = get_kern(param).f; int nr_thread = query_blocksize_for_kernel(kern), nr_out_dimx = param.src_h * param.src_w * param.batch; - dim3 nr_block(param.src_chl, - std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); + dim3 nr_block(param.src_chl, std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); - kern<<>>(src_grad, dst_grad, flt, - param); + kern<<>>(src_grad, dst_grad, flt, param); after_kernel_launch(); } -template void run_bwd_data(float*, const float*, const float*, const Param&, - 
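run_bwd_data above derives its launch configuration directly from the problem shape: grid.x walks input channels, grid.y splits the (N, IH, IW) work, and shared memory holds one input channel's filter taps. A host-side sketch of that arithmetic; the fixed block size here stands in for query_blocksize_for_kernel:

    #include <algorithm>
    #include <cstdio>

    int main() {
        // Hypothetical depthwise backward-data problem.
        int batch = 8, src_chl = 64, src_h = 56, src_w = 56;
        int chl_mul = 1, flt_h = 3, flt_w = 3;

        int nr_thread = 256;  // stand-in for query_blocksize_for_kernel(kern)
        int nr_out_dimx = src_h * src_w * batch;

        // grid.x iterates input channels, grid.y splits the (N, IH, IW) work.
        int grid_x = src_chl;
        int grid_y = std::min(512, std::max(nr_out_dimx / (nr_thread * 4), 1));

        // each block caches one channel's filter taps in shared memory
        size_t shared = size_t(chl_mul) * flt_h * flt_w * sizeof(float);

        std::printf("grid=(%d,%d) block=%d shared=%zu B\n", grid_x, grid_y, nr_thread, shared);
        return 0;
    }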
cudaStream_t); +template void run_bwd_data( + float*, const float*, const float*, const Param&, cudaStream_t); #if CUDA_VERSION >= 9000 -template void run_bwd_data(__half*, const __half*, const __half*, const Param&, - cudaStream_t); +template void run_bwd_data( + __half*, const __half*, const __half*, const Param&, cudaStream_t); #endif -template void run_bwd_data(dt_float16*, const dt_float16*, const dt_float16*, - const Param&, cudaStream_t); +template void run_bwd_data( + dt_float16*, const dt_float16*, const dt_float16*, const Param&, cudaStream_t); } // namespace chanwise } // namespace convolution @@ -532,4 +516,3 @@ template void run_bwd_data(dt_float16*, const dt_float16*, const dt_float16*, } // namespace megdnn // vim: syntax=cuda.doxygen - diff --git a/dnn/src/cuda/convolution/chanwise/bwd_filter.cu b/dnn/src/cuda/convolution/chanwise/bwd_filter.cu index 9a36443e..911806fe 100644 --- a/dnn/src/cuda/convolution/chanwise/bwd_filter.cu +++ b/dnn/src/cuda/convolution/chanwise/bwd_filter.cu @@ -11,8 +11,8 @@ #include "./kern.cuh" #include "./kern_helper.cuh" -#include "src/cuda/cub/util_ptx.cuh" #include "cuda_fp16.h" +#include "src/cuda/cub/util_ptx.cuh" #include "src/cuda/fp16_help.cuh" const uint32_t WARP_SIZE = 32, BATCH_UNROLL = 4; @@ -32,22 +32,17 @@ namespace { * \tparam nr_thpf number of threads for one element in the filter; must be * power of 2; */ -template +template __global__ void kern_bwd_filter_float( T* flt_grad, const T* src, const T* dst_grad, Param param) { - - const uint32_t - N = param.batch, IC = param.src_chl, IH = param.src_h, IW = param.src_w, - CHL_MUL = param.chl_mul, - FH = param.flt_h, FW = param.flt_w, - PH = param.pad_h, PW = param.pad_w, - SH = param.stride_h, SW = param.stride_w, - OH = param.out_h, OW = param.out_w, - SRC_BATCH_STRIDE = IC * IH * IW, - DST_BATCH_STRIDE = IC * CHL_MUL * OH * OW, - BLKDIM_X = blockDim.x / nr_thpf, - THREADID_X = threadIdx.x / nr_thpf, - OUT_IDX = blockIdx.x * BLKDIM_X + THREADID_X; + const uint32_t N = param.batch, IC = param.src_chl, IH = param.src_h, + IW = param.src_w, CHL_MUL = param.chl_mul, FH = param.flt_h, + FW = param.flt_w, PH = param.pad_h, PW = param.pad_w, + SH = param.stride_h, SW = param.stride_w, OH = param.out_h, + OW = param.out_w, SRC_BATCH_STRIDE = IC * IH * IW, + DST_BATCH_STRIDE = IC * CHL_MUL * OH * OW, + BLKDIM_X = blockDim.x / nr_thpf, THREADID_X = threadIdx.x / nr_thpf, + OUT_IDX = blockIdx.x * BLKDIM_X + THREADID_X; uint32_t ic, chl_mul, fh, fw; { @@ -63,18 +58,15 @@ __global__ void kern_bwd_filter_float( src += ic * IH * IW; dst_grad += (ic * CHL_MUL + chl_mul) * OH * OW; - const uint32_t - oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, - oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), - ow_lo = max(int32_t(PW - fw + SW - 1), 0) / SW, - ow_hi = min((IW - 1 + PW - fw) / SW + 1, OW), - oblk_h = oh_hi - oh_lo, - oblk_w = ow_hi - ow_lo, - oblk_tot = oblk_h * oblk_w * ((N + BATCH_UNROLL - 1) / BATCH_UNROLL), - tid = threadIdx.x % nr_thpf; - - if (IH + PH < fh + 1 || oh_lo >= oh_hi || - IW + PW < fw + 1 || ow_lo >= ow_hi) { + const uint32_t oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, + oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), + ow_lo = max(int32_t(PW - fw + SW - 1), 0) / SW, + ow_hi = min((IW - 1 + PW - fw) / SW + 1, OW), oblk_h = oh_hi - oh_lo, + oblk_w = ow_hi - ow_lo, + oblk_tot = oblk_h * oblk_w * ((N + BATCH_UNROLL - 1) / BATCH_UNROLL), + tid = threadIdx.x % nr_thpf; + + if (IH + PH < fh + 1 || oh_lo >= oh_hi || IW + PW < fw + 1 || ow_lo >= ow_hi) { if (!tid) flt_grad[OUT_IDX] = 0; 
return; @@ -86,12 +78,11 @@ __global__ void kern_bwd_filter_float( n = div_mod(div_mod(oblk_idx, oblk_w, ow), oblk_h, oh) * BATCH_UNROLL; oh += oh_lo; ow += ow_lo; - uint32_t ih = oh * SH - PH + fh, - iw = ow * SW - PW + fw, + uint32_t ih = oh * SH - PH + fh, iw = ow * SW - PW + fw, soff = ih * IW + iw + n * SRC_BATCH_STRIDE, doff = oh * OW + ow + n * DST_BATCH_STRIDE; #pragma unroll - for (uint32_t i = 0; i < BATCH_UNROLL; ++ i) { + for (uint32_t i = 0; i < BATCH_UNROLL; ++i) { if (!i || n + i < N) { sum += src[soff] * dst_grad[doff]; } @@ -129,50 +120,44 @@ __global__ void kern_bwd_filter_float( } #if CUDA_VERSION >= 9000 -template +template __global__ void kern_bwd_filter_hf( - __half* flt_grad, const __half* src, const __half* dst_grad, Param param) { - const uint32_t - N = param.batch, IC = param.src_chl, IH = param.src_h, IW = param.src_w, - CHL_MUL = param.chl_mul, - FH = param.flt_h, FW = param.flt_w, - PH = param.pad_h, PW = param.pad_w, - SH = param.stride_h, SW = param.stride_w, - OH = param.out_h, OW = param.out_w, - SRC_BATCH_STRIDE = IC * IH * IW, - DST_BATCH_STRIDE = IC * CHL_MUL * OH * OW, - BLKDIM_X = (blockDim.x / nr_thpf) * 2, - THREADID_X = (threadIdx.x / nr_thpf) * 2, - OUT_IDX = blockIdx.x * BLKDIM_X + THREADID_X, - LAST_IDX = FH * FW * CHL_MUL * IC, - tid = threadIdx.x % nr_thpf; + __half* flt_grad, const __half* src, const __half* dst_grad, Param param) { + const uint32_t N = param.batch, IC = param.src_chl, IH = param.src_h, + IW = param.src_w, CHL_MUL = param.chl_mul, FH = param.flt_h, + FW = param.flt_w, PH = param.pad_h, PW = param.pad_w, + SH = param.stride_h, SW = param.stride_w, OH = param.out_h, + OW = param.out_w, SRC_BATCH_STRIDE = IC * IH * IW, + DST_BATCH_STRIDE = IC * CHL_MUL * OH * OW, + BLKDIM_X = (blockDim.x / nr_thpf) * 2, + THREADID_X = (threadIdx.x / nr_thpf) * 2, + OUT_IDX = blockIdx.x * BLKDIM_X + THREADID_X, + LAST_IDX = FH * FW * CHL_MUL * IC, tid = threadIdx.x % nr_thpf; __half2 sum2{0.0, 0.0}; - if (OUT_IDX % FW != FW - 1) { - uint32_t ic, chl_mul, fh, fw; - { - uint32_t i = OUT_IDX; - i = div_mod(i, FW, fw); - i = div_mod(i, FH, fh); - i = div_mod(i, CHL_MUL, chl_mul); - ic = i; - } - if (ic >= IC) { - return; - } - src += ic * IH * IW; - dst_grad += (ic * CHL_MUL + chl_mul) * OH * OW; - - const uint32_t - oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, - oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), - ow_lox = max(int32_t(PW - fw + SW - 1), 0) / SW, - ow_loy = max(int32_t(PW - fw + SW - 2), 0) / SW, - ow_hix = min((IW - 1 + PW - fw) / SW + 1, OW), - ow_hiy = min((IW - 2 + PW - fw) / SW + 1, OW), - oblk_h = oh_hi - oh_lo, - oblk_wx = ow_hix - ow_lox, - oblk_wy = ow_hiy - ow_loy; + if (OUT_IDX % FW != FW - 1) { + uint32_t ic, chl_mul, fh, fw; + { + uint32_t i = OUT_IDX; + i = div_mod(i, FW, fw); + i = div_mod(i, FH, fh); + i = div_mod(i, CHL_MUL, chl_mul); + ic = i; + } + if (ic >= IC) { + return; + } + src += ic * IH * IW; + dst_grad += (ic * CHL_MUL + chl_mul) * OH * OW; + + const uint32_t oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, + oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), + ow_lox = max(int32_t(PW - fw + SW - 1), 0) / SW, + ow_loy = max(int32_t(PW - fw + SW - 2), 0) / SW, + ow_hix = min((IW - 1 + PW - fw) / SW + 1, OW), + ow_hiy = min((IW - 2 + PW - fw) / SW + 1, OW), + oblk_h = oh_hi - oh_lo, oblk_wx = ow_hix - ow_lox, + oblk_wy = ow_hiy - ow_loy; if (IH + PH < fh + 1 || oh_lo >= oh_hi || IW + PW < fw + 1) { if (!tid) { flt_grad[OUT_IDX] = 0; @@ -180,123 +165,121 @@ __global__ void kern_bwd_filter_hf( } return; } - - if 
(ow_lox >= ow_hix) { - if (!tid) - flt_grad[OUT_IDX] = 0; - } - - if (IW + PW < fw + 2 || ow_loy >= ow_hiy) { - if (!tid) - flt_grad[OUT_IDX + 1] = 0; + + if (ow_lox >= ow_hix) { + if (!tid) + flt_grad[OUT_IDX] = 0; + } + + if (IW + PW < fw + 2 || ow_loy >= ow_hiy) { + if (!tid) + flt_grad[OUT_IDX + 1] = 0; if (ow_lox >= ow_hix) return; - } - - sum2.x = 0.0; - sum2.y = 0.0; - __half2 src2{0.0, 0.0}; - __half2 dst2{0.0, 0.0}; - - const uint32_t - oblk_w = max(ow_hix, ow_hiy) - min(ow_lox, ow_loy), - oblk_tot = oblk_h * oblk_w * ((N + BATCH_UNROLL - 1) / BATCH_UNROLL); - - for (uint32_t oblk_idx = tid; oblk_idx < oblk_tot; oblk_idx += nr_thpf) { - uint32_t n_x, n_y, oh, ow_x, ow_y; - n_x = div_mod(div_mod(oblk_idx, oblk_wx, ow_x), oblk_h, oh) * BATCH_UNROLL; - n_y = div_mod(div_mod(oblk_idx, oblk_wy, ow_y), oblk_h, oh) * BATCH_UNROLL; - oh += oh_lo; - ow_x += ow_lox; - ow_y += ow_loy; - uint32_t ih = oh * SH - PH + fh, - iw_x = ow_x * SW - PW + fw, - iw_y = ow_y * SW - PW + fw + 1, - soff_x = ih * IW + iw_x + n_x * SRC_BATCH_STRIDE, - soff_y = ih * IW + iw_y + n_y * SRC_BATCH_STRIDE, - doff_x = oh * OW + ow_x + n_x * DST_BATCH_STRIDE, - doff_y = oh * OW + ow_y + n_y * DST_BATCH_STRIDE; + } + + sum2.x = 0.0; + sum2.y = 0.0; + __half2 src2{0.0, 0.0}; + __half2 dst2{0.0, 0.0}; + + const uint32_t oblk_w = max(ow_hix, ow_hiy) - min(ow_lox, ow_loy), + oblk_tot = oblk_h * oblk_w * + ((N + BATCH_UNROLL - 1) / BATCH_UNROLL); + + for (uint32_t oblk_idx = tid; oblk_idx < oblk_tot; oblk_idx += nr_thpf) { + uint32_t n_x, n_y, oh, ow_x, ow_y; + n_x = div_mod(div_mod(oblk_idx, oblk_wx, ow_x), oblk_h, oh) * BATCH_UNROLL; + n_y = div_mod(div_mod(oblk_idx, oblk_wy, ow_y), oblk_h, oh) * BATCH_UNROLL; + oh += oh_lo; + ow_x += ow_lox; + ow_y += ow_loy; + uint32_t ih = oh * SH - PH + fh, iw_x = ow_x * SW - PW + fw, + iw_y = ow_y * SW - PW + fw + 1, + soff_x = ih * IW + iw_x + n_x * SRC_BATCH_STRIDE, + soff_y = ih * IW + iw_y + n_y * SRC_BATCH_STRIDE, + doff_x = oh * OW + ow_x + n_x * DST_BATCH_STRIDE, + doff_y = oh * OW + ow_y + n_y * DST_BATCH_STRIDE; #pragma unroll - for (uint32_t i = 0; i < BATCH_UNROLL; ++ i) { - if (!i || n_x + i < N || n_y + i < N) { - src2.x = 0.0; - src2.y = 0.0; - dst2.x = 0.0; - dst2.y = 0.0; - if (n_x + i < N && ow_x < ow_hix) { - src2.x = src[soff_x]; - dst2.x = dst_grad[doff_x]; - } - if (n_y + i < N && ow_y < ow_hiy) { - src2.y = src[soff_y]; - dst2.y = dst_grad[doff_y]; - } - sum2 = fma2(src2, dst2, sum2); - } - soff_x += SRC_BATCH_STRIDE; - soff_y += SRC_BATCH_STRIDE; - doff_x += DST_BATCH_STRIDE; - doff_y += DST_BATCH_STRIDE; - } - } - } else { - for (size_t offset = 0; offset < 2; ++ offset) { - uint32_t ic, chl_mul, fh, fw; - { - uint32_t i = OUT_IDX + offset; - i = div_mod(i, FW, fw); - i = div_mod(i, FH, fh); - i = div_mod(i, CHL_MUL, chl_mul); - ic = i; - } - if (ic >= IC) { - if (offset == 0) + for (uint32_t i = 0; i < BATCH_UNROLL; ++i) { + if (!i || n_x + i < N || n_y + i < N) { + src2.x = 0.0; + src2.y = 0.0; + dst2.x = 0.0; + dst2.y = 0.0; + if (n_x + i < N && ow_x < ow_hix) { + src2.x = src[soff_x]; + dst2.x = dst_grad[doff_x]; + } + if (n_y + i < N && ow_y < ow_hiy) { + src2.y = src[soff_y]; + dst2.y = dst_grad[doff_y]; + } + sum2 = fma2(src2, dst2, sum2); + } + soff_x += SRC_BATCH_STRIDE; + soff_y += SRC_BATCH_STRIDE; + doff_x += DST_BATCH_STRIDE; + doff_y += DST_BATCH_STRIDE; + } + } + } else { + for (size_t offset = 0; offset < 2; ++offset) { + uint32_t ic, chl_mul, fh, fw; + { + uint32_t i = OUT_IDX + offset; + i = div_mod(i, FW, fw); + i = div_mod(i, FH, fh); + i 
= div_mod(i, CHL_MUL, chl_mul); + ic = i; + } + if (ic >= IC) { + if (offset == 0) return; else break; - } - const uint32_t - oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, - oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), - ow_lo = max(int32_t(PW - fw + SW - 1), 0) / SW, - ow_hi = min((IW - 1 + PW - fw) / SW + 1, OW), - oblk_h = oh_hi - oh_lo, - oblk_w = ow_hi - ow_lo, - oblk_tot = oblk_h * oblk_w * ((N + BATCH_UNROLL - 1) / BATCH_UNROLL); - - if (IH + PH < fh + 1 || oh_lo >= oh_hi || - IW + PW < fw + 1 || ow_lo >= ow_hi) { - if (!tid) - flt_grad[OUT_IDX + offset] = 0; - continue; - } - - __half sum(0.0); - - for (uint32_t oblk_idx = tid; oblk_idx < oblk_tot; oblk_idx += nr_thpf) { - uint32_t n, oh, ow; - n = div_mod(div_mod(oblk_idx, oblk_w, ow), oblk_h, oh) * BATCH_UNROLL; - oh += oh_lo; - ow += ow_lo; - uint32_t ih = oh * SH - PH + fh, - iw = ow * SW - PW + fw, - soff = ic * IH * IW + ih * IW + iw + n * SRC_BATCH_STRIDE, - doff = (ic * CHL_MUL + chl_mul) * OH * OW + oh * OW + ow + n * DST_BATCH_STRIDE; + } + const uint32_t oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, + oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), + ow_lo = max(int32_t(PW - fw + SW - 1), 0) / SW, + ow_hi = min((IW - 1 + PW - fw) / SW + 1, OW), + oblk_h = oh_hi - oh_lo, oblk_w = ow_hi - ow_lo, + oblk_tot = oblk_h * oblk_w * + ((N + BATCH_UNROLL - 1) / BATCH_UNROLL); + + if (IH + PH < fh + 1 || oh_lo >= oh_hi || IW + PW < fw + 1 || + ow_lo >= ow_hi) { + if (!tid) + flt_grad[OUT_IDX + offset] = 0; + continue; + } + + __half sum(0.0); + + for (uint32_t oblk_idx = tid; oblk_idx < oblk_tot; oblk_idx += nr_thpf) { + uint32_t n, oh, ow; + n = div_mod(div_mod(oblk_idx, oblk_w, ow), oblk_h, oh) * BATCH_UNROLL; + oh += oh_lo; + ow += ow_lo; + uint32_t ih = oh * SH - PH + fh, iw = ow * SW - PW + fw, + soff = ic * IH * IW + ih * IW + iw + n * SRC_BATCH_STRIDE, + doff = (ic * CHL_MUL + chl_mul) * OH * OW + oh * OW + ow + + n * DST_BATCH_STRIDE; #pragma unroll - for (uint32_t i = 0; i < BATCH_UNROLL; ++ i) { - if (!i || n + i < N) { - sum = fma(src[soff], dst_grad[doff], sum); - } - soff += SRC_BATCH_STRIDE; - doff += DST_BATCH_STRIDE; - } - } + for (uint32_t i = 0; i < BATCH_UNROLL; ++i) { + if (!i || n + i < N) { + sum = fma(src[soff], dst_grad[doff], sum); + } + soff += SRC_BATCH_STRIDE; + doff += DST_BATCH_STRIDE; + } + } if (!offset) sum2.x = sum; if (offset) sum2.y = sum; - } - } + } + } if (nr_thpf == 1) { flt_grad[OUT_IDX] = sum2.x; @@ -405,72 +388,67 @@ namespace cuda { namespace convolution { namespace chanwise { template -void run_bwd_filter(T *filter_grad, const T *src, const T *dst_grad, - const Param ¶m, cudaStream_t stream) { - void (*kern)(T*, const T*, const T*, Param) = NULL; - uint32_t - nr_thread = query_blocksize_for_kernel(get_kern(1024).f), - nr_thpf = std::min(nr_thread, - std::max( - 1, - param.out_h * param.out_w * param.batch / - (BATCH_UNROLL * 16))); - // find nearest power-of-2 of nr_thpf - do { -#define CK(_n) \ - if (nr_thpf >= _n) { \ - kern = get_kern(_n).f; \ - nr_thpf = _n; \ - break; \ - } - CK(1<<10); - CK(1<<9); - CK(1<<8); - CK(1<<7); - CK(1<<6); - CK(1<<5); - CK(1<<4); - CK(1<<3); - CK(1<<2); - CK(1<<1); - CK(1<<0); +void run_bwd_filter( + T* filter_grad, const T* src, const T* dst_grad, const Param& param, + cudaStream_t stream) { + void (*kern)(T*, const T*, const T*, Param) = NULL; + uint32_t nr_thread = query_blocksize_for_kernel(get_kern(1024).f), + nr_thpf = std::min( + nr_thread, std::max( + 1, param.out_h * param.out_w * param.batch / + (BATCH_UNROLL * 16))); + // find nearest 
power-of-2 of nr_thpf + do { +#define CK(_n) \ + if (nr_thpf >= _n) { \ + kern = get_kern(_n).f; \ + nr_thpf = _n; \ + break; \ + } + CK(1 << 10); + CK(1 << 9); + CK(1 << 8); + CK(1 << 7); + CK(1 << 6); + CK(1 << 5); + CK(1 << 4); + CK(1 << 3); + CK(1 << 2); + CK(1 << 1); + CK(1 << 0); #undef CK - } while(0); - - megdnn_assert(kern); - nr_thread = query_blocksize_for_kernel(kern); - - uint32_t nr_flt_per_blk = nr_thread / nr_thpf; - while (nr_flt_per_blk * nr_thpf % WARP_SIZE) - --nr_flt_per_blk; - megdnn_assert(nr_flt_per_blk); - - int nr_block = DIVUP( - param.flt_h * param.flt_w * param.src_chl * param.chl_mul, - nr_flt_per_blk); - nr_thread = nr_flt_per_blk * nr_thpf; - uint32_t shared = nr_thread * 2 * sizeof(T); - kern <<< nr_block, nr_thread, shared, stream >>> ( - filter_grad, src, dst_grad, param); - after_kernel_launch(); + } while (0); + + megdnn_assert(kern); + nr_thread = query_blocksize_for_kernel(kern); + + uint32_t nr_flt_per_blk = nr_thread / nr_thpf; + while (nr_flt_per_blk * nr_thpf % WARP_SIZE) + --nr_flt_per_blk; + megdnn_assert(nr_flt_per_blk); + + int nr_block = DIVUP( + param.flt_h * param.flt_w * param.src_chl * param.chl_mul, nr_flt_per_blk); + nr_thread = nr_flt_per_blk * nr_thpf; + uint32_t shared = nr_thread * 2 * sizeof(T); + kern<<>>(filter_grad, src, dst_grad, param); + after_kernel_launch(); } -template void run_bwd_filter(float*, const float*, const float*, const Param&, - cudaStream_t); +template void run_bwd_filter( + float*, const float*, const float*, const Param&, cudaStream_t); #if CUDA_VERSION >= 9000 -template void run_bwd_filter(__half*, const __half*, const __half*, const Param&, - cudaStream_t); +template void run_bwd_filter( + __half*, const __half*, const __half*, const Param&, cudaStream_t); #endif -template void run_bwd_filter(dt_float16*, const dt_float16*, const dt_float16*, - const Param&, cudaStream_t); - -} // namespace chanwise -} // namespace convolution -} // namespace cuda -} // namespace megdnn +template void run_bwd_filter( + dt_float16*, const dt_float16*, const dt_float16*, const Param&, cudaStream_t); +} // namespace chanwise +} // namespace convolution +} // namespace cuda +} // namespace megdnn // vim: syntax=cuda.doxygen - diff --git a/dnn/src/cuda/convolution/chanwise/bwd_small.cu b/dnn/src/cuda/convolution/chanwise/bwd_small.cu index 17ce9447..c4be0323 100644 --- a/dnn/src/cuda/convolution/chanwise/bwd_small.cu +++ b/dnn/src/cuda/convolution/chanwise/bwd_small.cu @@ -55,17 +55,18 @@ enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; // one each in the lower and upper half of a tile. // Backprop input direction is the same as forward direction with the filter // rotated by 180°. -template +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> __global__ void #if __CUDA_ARCH__ >= 750 __launch_bounds__(1024, 1) #else __launch_bounds__(1024, 2) #endif - DepthwiseConv2dGPUKernelNCHWSmall(const Param param, const T* input, - const T* filter, T* output) { + DepthwiseConv2dGPUKernelNCHWSmall( + const Param param, const T* input, const T* filter, T* output) { // Holds block plus halo and filter data for blockDim.z depths. 
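The CK(1 << 10) .. CK(1 << 0) chain above snaps nr_thpf, already bounded by min(nr_thread, max(1, out_h*out_w*batch / (BATCH_UNROLL*16))), down to the largest power of two it still reaches. An equivalent standalone helper:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Largest power of two <= v, capped at 1024 -- what the CK(...) chain selects.
    uint32_t floor_pow2_le_1024(uint32_t v) {
        for (uint32_t p = 1u << 10; p >= 1; p >>= 1)
            if (v >= p)
                return p;
        return 0;  // unreachable for v >= 1; the launcher keeps nr_thpf >= 1
    }

    int main() {
        const uint32_t BATCH_UNROLL = 4;
        uint32_t nr_thread = 256, out_h = 30, out_w = 30, batch = 8;
        uint32_t nr_thpf = std::min(
                nr_thread, std::max(1u, out_h * out_w * batch / (BATCH_UNROLL * 16)));
        std::printf("nr_thpf=%u -> %u threads per filter element\n",
                    nr_thpf, floor_pow2_le_1024(nr_thpf));
        return 0;
    }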
extern __shared__ __align__(8) unsigned char shared_memory[]; static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); @@ -75,12 +76,10 @@ __launch_bounds__(1024, 2) const int in_height = static_cast(param.src_h); const int in_width = static_cast(param.src_w); const int in_depth = static_cast(param.src_chl); - const int filter_height = kKnownFilterHeight < 0 - ? static_cast(param.flt_h) - : kKnownFilterHeight; - const int filter_width = kKnownFilterWidth < 0 - ? static_cast(param.flt_w) - : kKnownFilterWidth; + const int filter_height = + kKnownFilterHeight < 0 ? static_cast(param.flt_h) : kKnownFilterHeight; + const int filter_width = + kKnownFilterWidth < 0 ? static_cast(param.flt_w) : kKnownFilterWidth; const int pad_height = static_cast(param.pad_h); const int pad_width = static_cast(param.pad_w); @@ -160,8 +159,7 @@ __launch_bounds__(1024, 2) if (filter_write_offset != 0) { const int filter_offset = - (channel + filter_channel) % in_depth * filter_pixels + - filter_pix; + (channel + filter_channel) % in_depth * filter_pixels + filter_pix; shared_data[filter_write_offset] = *(filter_offset + filter); } @@ -202,23 +200,23 @@ __launch_bounds__(1024, 2) } } -template -void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, - const T* filter, T* output, - cudaStream_t stream) { +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> +void LaunchDepthwiseConv2dGPUSmall( + const Param& param, const T* input, const T* filter, T* output, + cudaStream_t stream) { const int block_height = (param.src_h + 1) / 2; dim3 block_dim; int block_count; void (*kernel)(const Param, const T*, const T*, T*); block_dim = dim3(param.src_w, block_height, kBlockDepth); - block_count = - DIVUP(param.batch * param.src_chl * param.chl_mul, kBlockDepth) * - kBlockDepth; + block_count = DIVUP(param.batch * param.src_chl * param.chl_mul, kBlockDepth) * + kBlockDepth; kernel = DepthwiseConv2dGPUKernelNCHWSmall< - T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, - kBlockDepth, kKnownEvenHeight>; + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, + kKnownEvenHeight>; const int tile_width = param.src_w + param.flt_w - 1; const int tile_height = block_height * 2 + param.flt_h - 1; const int tile_pixels = tile_height * tile_width; @@ -227,48 +225,51 @@ void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, kBlockDepth * (tile_pixels + filter_pixels) * sizeof(T); const int num_outputs = param.out_h * param.out_w * block_count; - block_count = GetFixedBlockSize(num_outputs, kernel, shared_memory_size, - block_dim.x * block_dim.y * block_dim.z); + block_count = GetFixedBlockSize( + num_outputs, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); kernel<<>>( param, input, filter, output); after_kernel_launch(); } -template -void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, - const T* filter, T* output, - cudaStream_t stream) { +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth> +void LaunchDepthwiseConv2dGPUSmall( + const Param& param, const T* input, const T* filter, T* output, + cudaStream_t stream) { if (param.src_h & 1) { return LaunchDepthwiseConv2dGPUSmall< - T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, - kBlockDepth, false>(param, input, filter, output, stream); + T, T2, kDirection, kKnownFilterWidth, 
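DepthwiseConv2dGPUKernelNCHWSmall above stages an input tile plus halo and the per-channel filter in shared memory, and LaunchDepthwiseConv2dGPUSmall sizes that buffer from the block geometry. The same arithmetic with hypothetical shapes, assuming filter_pixels is flt_h * flt_w:

    #include <cstdio>

    int main() {
        // Hypothetical small depthwise problem (float, one 3x3 filter per channel).
        int src_h = 28, src_w = 28, flt_h = 3, flt_w = 3, kBlockDepth = 2;

        int block_height = (src_h + 1) / 2;              // each thread covers one pixel per half
        int tile_width = src_w + flt_w - 1;              // input block plus horizontal halo
        int tile_height = block_height * 2 + flt_h - 1;  // both halves plus vertical halo
        int tile_pixels = tile_height * tile_width;
        int filter_pixels = flt_h * flt_w;               // assumption: per-channel filter taps

        size_t shared = size_t(kBlockDepth) * (tile_pixels + filter_pixels) * sizeof(float);
        std::printf("tile %dx%d, shared memory per block: %zu bytes\n",
                    tile_height, tile_width, shared);
        return 0;
    }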
kKnownFilterHeight, kBlockDepth, + false>(param, input, filter, output, stream); } else { return LaunchDepthwiseConv2dGPUSmall< - T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, - kBlockDepth, true>(param, input, filter, output, stream); + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, + true>(param, input, filter, output, stream); } } -template -void LaunchDepthwiseConv2dGPUSmall(const Param& param, const T* input, - const T* filter, T* output, - cudaStream_t stream) { +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, + int kKnownFilterWidth, int kKnownFilterHeight> +void LaunchDepthwiseConv2dGPUSmall( + const Param& param, const T* input, const T* filter, T* output, + cudaStream_t stream) { // Maximize (power of two) kBlockDepth while keeping a block within 1024 // threads (2 pixels per thread). const int block_pixels = (param.src_h + 1) / 2 * param.src_w; if (block_pixels > 256) { - LaunchDepthwiseConv2dGPUSmall( + LaunchDepthwiseConv2dGPUSmall< + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, 2>( param, input, filter, output, stream); } else if (block_pixels > 128) { - LaunchDepthwiseConv2dGPUSmall( + LaunchDepthwiseConv2dGPUSmall< + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, 4>( param, input, filter, output, stream); } else { - LaunchDepthwiseConv2dGPUSmall( + LaunchDepthwiseConv2dGPUSmall< + T, T2, kDirection, kKnownFilterWidth, kKnownFilterHeight, 8>( param, input, filter, output, stream); } } @@ -281,35 +282,34 @@ namespace convolution { namespace chanwise { // ===================================bwd data================================== -#define LAUNCH(type, type2) \ - if (param.flt_h == 3 && param.flt_w == 3) { \ - LaunchDepthwiseConv2dGPUSmall< \ - type, type2, DepthwiseConv2dDirection::DIRECTION_BACKWARD, 3, \ - 3>(param, dst_grad, flt, src_grad, stream); \ - } else { \ - LaunchDepthwiseConv2dGPUSmall< \ - type, type2, DepthwiseConv2dDirection::DIRECTION_BACKWARD, -1, \ - -1>(param, dst_grad, flt, src_grad, stream); \ +#define LAUNCH(type, type2) \ + if (param.flt_h == 3 && param.flt_w == 3) { \ + LaunchDepthwiseConv2dGPUSmall< \ + type, type2, DepthwiseConv2dDirection::DIRECTION_BACKWARD, 3, 3>( \ + param, dst_grad, flt, src_grad, stream); \ + } else { \ + LaunchDepthwiseConv2dGPUSmall< \ + type, type2, DepthwiseConv2dDirection::DIRECTION_BACKWARD, -1, -1>( \ + param, dst_grad, flt, src_grad, stream); \ } template <> -void run_bwd_data_small(float* src_grad, const float* dst_grad, - const float* flt, const Param& param, - cudaStream_t stream) { +void run_bwd_data_small( + float* src_grad, const float* dst_grad, const float* flt, const Param& param, + cudaStream_t stream) { LAUNCH(float, float2); } #if CUDA_VERSION >= 9000 template <> -void run_bwd_data_small(__half* src_grad, const __half* dst_grad, - const __half* flt, const Param& param, - cudaStream_t stream) { +void run_bwd_data_small( + __half* src_grad, const __half* dst_grad, const __half* flt, const Param& param, + cudaStream_t stream) { LAUNCH(__half, __half2); } #endif #undef LAUNCH - } // namespace chanwise } // namespace convolution } // namespace cuda diff --git a/dnn/src/cuda/convolution/chanwise/kern.cuh b/dnn/src/cuda/convolution/chanwise/kern.cuh index 7bba345d..4c794944 100644 --- a/dnn/src/cuda/convolution/chanwise/kern.cuh +++ b/dnn/src/cuda/convolution/chanwise/kern.cuh @@ -12,8 +12,8 @@ #include "src/cuda/utils.cuh" -#include #include +#include #if MEGDNN_CC_HOST #include "src/cuda/convolution/helper.h" @@ -24,54 +24,53 
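A note on the dispatch just above: kBlockDepth is chosen as the largest power of two (8, 4, or 2) that keeps block_pixels * kBlockDepth within the 1024-thread block limit (each thread covers two output rows), and the dynamic shared memory holds one padded input tile plus one filter per channel slice handled by the block. A standalone sketch of those two computations, assuming float data and hypothetical helper names (illustration only):

#include <cstdio>

// Largest power-of-two channel depth per block that keeps
// block_pixels * depth within 1024 threads (each thread covers 2 rows).
static int choose_block_depth(int src_h, int src_w) {
    const int block_pixels = (src_h + 1) / 2 * src_w;
    if (block_pixels > 256) return 2;  // fits 1024 threads for <= 512 pixels
    if (block_pixels > 128) return 4;  // 4 * 256 = 1024
    return 8;                          // 8 * 128 = 1024
}

// Dynamic shared memory: one padded input tile plus one filter per channel
// slice handled by the block (mirrors the formula in the launcher above).
static size_t shared_mem_bytes(int src_h, int src_w, int flt_h, int flt_w,
                               int block_depth, size_t elem_size) {
    const int block_height = (src_h + 1) / 2;
    const int tile_w = src_w + flt_w - 1;
    const int tile_h = block_height * 2 + flt_h - 1;
    return block_depth * (size_t(tile_h) * tile_w + size_t(flt_h) * flt_w) *
           elem_size;
}

int main() {
    const int depth = choose_block_depth(28, 28);  // e.g. a 28x28 feature map
    std::printf("kBlockDepth=%d, smem=%zu bytes\n", depth,
                shared_mem_bytes(28, 28, 3, 3, depth, sizeof(float)));
}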
@@ namespace cuda { namespace convolution { namespace chanwise { - struct Param { - uint32_t batch, src_chl, src_h, src_w, - chl_mul, flt_h, flt_w, - out_h, out_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w; +struct Param { + uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h, + pad_w, stride_h, stride_w, dilation_h, dilation_w; #if MEGDNN_CC_HOST - static Param from_fwd_args(const ForwardSizeArgs &args) { + static Param from_fwd_args(const ForwardSizeArgs& args) { #define U(v) static_cast(v) - auto &&src = args.src_layout->shape; - auto &&dst = args.dst_layout->shape; - auto &&fm = args.filter_meta; - size_t c_pos, hw_pos; - if (fm.format == param::Convolution::Format::NCHW) { - c_pos = 1; - hw_pos = 2; - } else { - c_pos = 3; - hw_pos = 1; - } - return { - U(src[0]), U(src[c_pos]), U(src[hw_pos]), U(src[hw_pos+1]), - U(fm.ocpg), U(fm.spatial[0]), U(fm.spatial[1]), - U(dst[hw_pos]), U(dst[hw_pos+1]), - U(fm.padding[0]), U(fm.padding[1]), - U(fm.stride[0]), U(fm.stride[1]), - U(fm.dilation[0]), U(fm.dilation[1]), - }; -#undef U + auto&& src = args.src_layout->shape; + auto&& dst = args.dst_layout->shape; + auto&& fm = args.filter_meta; + size_t c_pos, hw_pos; + if (fm.format == param::Convolution::Format::NCHW) { + c_pos = 1; + hw_pos = 2; + } else { + c_pos = 3; + hw_pos = 1; } + return { + U(src[0]), U(src[c_pos]), U(src[hw_pos]), + U(src[hw_pos + 1]), U(fm.ocpg), U(fm.spatial[0]), + U(fm.spatial[1]), U(dst[hw_pos]), U(dst[hw_pos + 1]), + U(fm.padding[0]), U(fm.padding[1]), U(fm.stride[0]), + U(fm.stride[1]), U(fm.dilation[0]), U(fm.dilation[1]), + }; +#undef U + } #endif - }; +}; - template - void run_bwd_data_small(T *src_grad, const T *dst_grad, const T *flt, - const Param ¶m, cudaStream_t stream); +template +void run_bwd_data_small( + T* src_grad, const T* dst_grad, const T* flt, const Param& param, + cudaStream_t stream); - template - void run_bwd_data(T *src_grad, const T *dst_grad, const T *flt, - const Param ¶m, cudaStream_t stream); +template +void run_bwd_data( + T* src_grad, const T* dst_grad, const T* flt, const Param& param, + cudaStream_t stream); - template - void run_bwd_filter(T *filter_grad, const T *src, const T *dst_grad, - const Param ¶m, cudaStream_t stream); +template +void run_bwd_filter( + T* filter_grad, const T* src, const T* dst_grad, const Param& param, + cudaStream_t stream); -} // namespace chanwise -} // namespace convolution -} // namespace cuda -} // namespace megdnn +} // namespace chanwise +} // namespace convolution +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution/chanwise/kern_helper.cuh b/dnn/src/cuda/convolution/chanwise/kern_helper.cuh index 84516fee..e58c4ce4 100644 --- a/dnn/src/cuda/convolution/chanwise/kern_helper.cuh +++ b/dnn/src/cuda/convolution/chanwise/kern_helper.cuh @@ -10,9 +10,9 @@ */ #pragma once +#include "megdnn/dtype.h" #include "src/cuda/query_blocksize.cuh" #include "src/cuda/utils.cuh" -#include "megdnn/dtype.h" #include #include @@ -23,33 +23,30 @@ namespace cuda { namespace convolution { namespace chanwise { - /*! - * \brief return a / b and set mod to a % b - */ - __device__ __forceinline__ uint32_t div_mod( - uint32_t a, uint32_t b, uint32_t &mod) { - uint32_t ret = a / b; - mod = a - ret * b; - return ret; - } - - /*! 
- * \brief copy a 2D matrix by all threads in a block - * \param rs row stride - */ - template - __device__ __forceinline__ void block_memcpy( - T *dst, const T *src, uint32_t size) { - for (uint32_t i = threadIdx.x; i < size; i += blockDim.x) { - dst[i] = src[i]; - } - __syncthreads(); +/*! + * \brief return a / b and set mod to a % b + */ +__device__ __forceinline__ uint32_t div_mod(uint32_t a, uint32_t b, uint32_t& mod) { + uint32_t ret = a / b; + mod = a - ret * b; + return ret; +} + +/*! + * \brief copy a 2D matrix by all threads in a block + * \param rs row stride + */ +template +__device__ __forceinline__ void block_memcpy(T* dst, const T* src, uint32_t size) { + for (uint32_t i = threadIdx.x; i < size; i += blockDim.x) { + dst[i] = src[i]; } + __syncthreads(); +} -} // namespace chanwise -} // namespace convolution -} // namespace cuda -} // namespace megdnn +} // namespace chanwise +} // namespace convolution +} // namespace cuda +} // namespace megdnn // vim: syntax=cuda.doxygen - diff --git a/dnn/src/cuda/convolution/chanwise/launch_config.cpp b/dnn/src/cuda/convolution/chanwise/launch_config.cpp index 5479fc13..2ad0cf83 100644 --- a/dnn/src/cuda/convolution/chanwise/launch_config.cpp +++ b/dnn/src/cuda/convolution/chanwise/launch_config.cpp @@ -16,9 +16,9 @@ using namespace megdnn; using namespace cuda; using namespace convolution; -int chanwise::GetFixedBlockSize1(int work_element_count, const void* func, - int dynamic_shared_memory_size, - int fixed_block_size) { +int chanwise::GetFixedBlockSize1( + int work_element_count, const void* func, int dynamic_shared_memory_size, + int fixed_block_size) { int block_count = 0; cuda_check(cudaOccupancyMaxActiveBlocksPerMultiprocessor( diff --git a/dnn/src/cuda/convolution/chanwise/launch_config.cuh b/dnn/src/cuda/convolution/chanwise/launch_config.cuh index 997b438a..889a6100 100644 --- a/dnn/src/cuda/convolution/chanwise/launch_config.cuh +++ b/dnn/src/cuda/convolution/chanwise/launch_config.cuh @@ -16,15 +16,17 @@ namespace cuda { namespace convolution { namespace chanwise { -int GetFixedBlockSize1(int work_element_count, const void* func, - int dynamic_shared_memory_size, int fixed_block_size); +int GetFixedBlockSize1( + int work_element_count, const void* func, int dynamic_shared_memory_size, + int fixed_block_size); template -int GetFixedBlockSize(int work_element_count, DeviceFunc func, - int dynamic_shared_memory_size, int fixed_block_size) { - return GetFixedBlockSize1(work_element_count, - reinterpret_cast(func), - dynamic_shared_memory_size, fixed_block_size); +int GetFixedBlockSize( + int work_element_count, DeviceFunc func, int dynamic_shared_memory_size, + int fixed_block_size) { + return GetFixedBlockSize1( + work_element_count, reinterpret_cast(func), + dynamic_shared_memory_size, fixed_block_size); } } // namespace chanwise diff --git a/dnn/src/cuda/convolution/forward/algos.cpp b/dnn/src/cuda/convolution/forward/algos.cpp index 039f071e..52a7c347 100644 --- a/dnn/src/cuda/convolution/forward/algos.cpp +++ b/dnn/src/cuda/convolution/forward/algos.cpp @@ -11,18 +11,18 @@ */ #include "src/cuda/convolution/forward/algos.h" -#include "src/cuda/conv_bias/opr_impl.h" -#include "src/cuda/conv_bias/algo.h" #include "src/common/algo_base.h" #include "src/common/algo_chooser.h" +#include "src/cuda/conv_bias/algo.h" +#include "src/cuda/conv_bias/opr_impl.h" using namespace megdnn; using namespace cuda; namespace { std::pair sub_opr_config( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, const 
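GetFixedBlockSize1 above sizes the grid from occupancy: cudaOccupancyMaxActiveBlocksPerMultiprocessor reports how many blocks of the fixed size (with the requested dynamic shared memory) fit on one SM. The rest of the function is elided in this hunk, so the clamping policy below is only an assumption; a minimal sketch against a dummy kernel:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 0.f;
}

// Grid size = min(blocks needed for the work, blocks the GPU can keep
// resident at this block size / dynamic-shared-memory budget). The min()
// policy is the editor's assumption, not necessarily megdnn's exact rule.
int occupancy_limited_blocks(int work_elements, int block_size,
                             size_t dyn_smem_bytes) {
    int device = 0, sm_count = 0, blocks_per_sm = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &blocks_per_sm, reinterpret_cast<const void*>(&dummy_kernel),
            block_size, dyn_smem_bytes);
    const int resident = blocks_per_sm * sm_count;
    const int needed = (work_elements + block_size - 1) / block_size;
    return needed < resident ? needed : resident;
}

int main() {
    std::printf("blocks=%d\n", occupancy_limited_blocks(1 << 20, 256, 0));
}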
ConvolutionForwardImpl* opr) { + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + const ConvolutionForwardImpl* opr) { auto conv_param = opr->param(); DType bias_type; if (src.dtype.enumv() == DTypeEnum::QuantizedS8) { @@ -35,8 +35,9 @@ std::pair sub_opr_config( src.dtype.param().scale * filter.dtype.param().scale); - } else if (src.dtype.enumv() == DTypeEnum::Uint8 || - src.dtype.enumv() == DTypeEnum::Int8) { + } else if ( + src.dtype.enumv() == DTypeEnum::Uint8 || + src.dtype.enumv() == DTypeEnum::Int8) { bias_type = dtype::Int32{}; } else if (src.dtype.enumv() == DTypeEnum::Quantized4Asymm) { bias_type = dtype::QuantizedS32( @@ -49,17 +50,18 @@ std::pair sub_opr_config( } std::pair ret; - ret.second = {param::ConvBias::NonlineMode::IDENTITY, - conv_param.mode, - conv_param.sparse, - conv_param.format, - conv_param.pad_h, - conv_param.pad_w, - conv_param.stride_h, - conv_param.stride_w, - conv_param.dilate_h, - conv_param.dilate_w, - conv_param.compute_mode}; + ret.second = { + param::ConvBias::NonlineMode::IDENTITY, + conv_param.mode, + conv_param.sparse, + conv_param.format, + conv_param.pad_h, + conv_param.pad_w, + conv_param.stride_h, + conv_param.stride_w, + conv_param.dilate_h, + conv_param.dilate_w, + conv_param.compute_mode}; ret.first.push_back(TensorLayout({}, bias_type)); ret.first.push_back(TensorLayout({}, dst.dtype)); return ret; @@ -72,8 +74,7 @@ std::pair> prepare_sub_opr( args.opr, conv_bias_opr.get()); auto&& config = sub_opr_config( - *args.layout_src, *args.layout_filter, *args.layout_dst, - args.opr); + *args.layout_src, *args.layout_filter, *args.layout_dst, args.opr); conv_bias_opr->param() = config.second; return {config.first, std::move(conv_bias_opr)}; @@ -93,16 +94,14 @@ ConvolutionForwardImpl::AlgoPack ConvolutionForwardImpl::sm_algo_pack; MEGDNN_DEF_GET_ALGO_FROM_DESC(ConvolutionForwardImpl) -ConvolutionForwardImpl::AlgoBase::SizeArgs::SizeArgs(ConvolutionForwardImpl* o, - const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) +ConvolutionForwardImpl::AlgoBase::SizeArgs::SizeArgs( + ConvolutionForwardImpl* o, const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) : opr{o}, layout_src{&src}, layout_filter{&filter}, layout_dst{&dst} {} ConvolutionForwardImpl::AlgoBase::ExecArgs::ExecArgs( - ConvolutionForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_out dst, - _megdnn_workspace workspace) + ConvolutionForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_out dst, _megdnn_workspace workspace) : SizeArgs(opr, src.layout, filter.layout, dst.layout), tensor_src{src}, tensor_filter{filter}, @@ -110,51 +109,45 @@ ConvolutionForwardImpl::AlgoBase::ExecArgs::ExecArgs( workspace{workspace} {} std::string ConvolutionForwardImpl::AlgoBase::SizeArgs::to_string() const { - return ssprintf("src=%s, filter=%s, dst=%s", - layout_src->to_string().c_str(), - layout_filter->to_string().c_str(), - layout_dst->to_string().c_str()); + return ssprintf( + "src=%s, filter=%s, dst=%s", layout_src->to_string().c_str(), + layout_filter->to_string().c_str(), layout_dst->to_string().c_str()); } /* ===================== default algo ===================== */ -std::vector -ConvolutionForwardImpl::AlgoDefault::get_subopr_list( +std::vector ConvolutionForwardImpl::AlgoDefault::get_subopr_list( const TensorLayoutArray& layouts, const OperatorBase* opr) const { - auto&& config = - sub_opr_config(layouts[0], layouts[1], layouts[2], - static_cast(opr)); 
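One detail worth calling out in sub_opr_config above: for quantized int8 source and filter, the bias tensor handed to the conv-bias sub-operator is QuantizedS32 with scale src_scale * filter_scale, because that is the scale at which int8 x int8 products accumulate. A tiny sketch of that rule (types and names are illustrative, not megdnn's):

#include <cstdio>

struct QuantParam {
    float scale;  // real_value = scale * stored_integer
};

// For s8 x s8 -> s32 accumulation, each product carries the combined scale,
// so a bias added to the accumulator must use src_scale * filter_scale.
QuantParam bias_param_for_qint8(QuantParam src, QuantParam filter) {
    return QuantParam{src.scale * filter.scale};
}

int main() {
    QuantParam bias = bias_param_for_qint8({0.5f}, {0.25f});
    std::printf("bias scale = %g\n", bias.scale);  // 0.125
}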
+ auto&& config = sub_opr_config( + layouts[0], layouts[1], layouts[2], + static_cast(opr)); - TensorLayoutArray conv_bias_layouts = {layouts[0], layouts[1], - config.first[0], config.first[1], - layouts[2]}; + TensorLayoutArray conv_bias_layouts = { + layouts[0], layouts[1], config.first[0], config.first[1], layouts[2]}; std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVBIAS_FORWARD, param_str, - conv_bias_layouts}}; + return {{Algorithm::OprType::CONVBIAS_FORWARD, param_str, conv_bias_layouts}}; } -bool ConvolutionForwardImpl::AlgoDefault::is_available( - const SizeArgs& args) const { +bool ConvolutionForwardImpl::AlgoDefault::is_available(const SizeArgs& args) const { auto config = prepare_sub_opr(args); - return get_algorithm(static_cast(config.second.get()), - *args.layout_src, *args.layout_filter, config.first[0], - config.first[1], *args.layout_dst); + return get_algorithm( + static_cast(config.second.get()), *args.layout_src, + *args.layout_filter, config.first[0], config.first[1], *args.layout_dst); } - size_t ConvolutionForwardImpl::AlgoDefault::get_workspace_in_bytes( const SizeArgs& args) const { auto config = prepare_sub_opr(args); return config.second->get_workspace_in_bytes( - *args.layout_src, *args.layout_filter, config.first[0], - config.first[1], *args.layout_dst, nullptr); + *args.layout_src, *args.layout_filter, config.first[0], config.first[1], + *args.layout_dst, nullptr); } void ConvolutionForwardImpl::AlgoDefault::exec(const ExecArgs& args) const { auto config = prepare_sub_opr(args); - config.second->exec(args.tensor_src, args.tensor_filter, - {nullptr, config.first[0]}, {nullptr, config.first[1]}, - args.tensor_dst, nullptr, args.workspace); + config.second->exec( + args.tensor_src, args.tensor_filter, {nullptr, config.first[0]}, + {nullptr, config.first[1]}, args.tensor_dst, nullptr, args.workspace); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/forward/algos.h b/dnn/src/cuda/convolution/forward/algos.h index c9984b1e..7825905d 100644 --- a/dnn/src/cuda/convolution/forward/algos.h +++ b/dnn/src/cuda/convolution/forward/algos.h @@ -31,7 +31,6 @@ protected: ~AlgoBase() = default; public: - enum class AlgoType : uint32_t { CUDA_DEFAULT, }; @@ -44,16 +43,18 @@ public: const TensorLayout *layout_src, *layout_filter, *layout_dst; std::string to_string() const; - SizeArgs(ConvolutionForwardImpl* opr, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& dst); + SizeArgs( + ConvolutionForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const TensorLayout& dst); }; struct ExecArgs : public SizeArgs { TensorND tensor_src, tensor_filter, tensor_dst; Workspace workspace; - ExecArgs(ConvolutionForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_out dst, - _megdnn_workspace workspace); + ExecArgs( + ConvolutionForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; @@ -69,17 +70,16 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& 
check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "convolution fwd algo %s: required workspace %zu bytes, " - "got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "convolution fwd algo %s: required workspace %zu bytes, " + "got %zu", + name(), req, workspace.size); return *this; } }; @@ -91,13 +91,10 @@ public: size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; const char* name() const override { return "DEFAULT"; } void exec(const ExecArgs&) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; MEGDNN_DECL_ALGO_TYPE(CUDA_DEFAULT) }; diff --git a/dnn/src/cuda/convolution/helper.cpp b/dnn/src/cuda/convolution/helper.cpp index 747e90b3..20e2ffc3 100644 --- a/dnn/src/cuda/convolution/helper.cpp +++ b/dnn/src/cuda/convolution/helper.cpp @@ -15,7 +15,7 @@ using namespace megdnn; using namespace cuda; using namespace convolution; -bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) { +bool convolution::is_cudnn_supported(const ForwardSizeArgs& args) { if (args.src_layout->dtype == args.filter_layout->dtype && args.src_layout->dtype == dtype::BFloat16()) { return false; @@ -30,11 +30,12 @@ bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) { // for NHWC as well. if (args.filter_meta.format == param::Convolution::Format::NCHW4) { if (args.dst_layout->dtype.enumv() != DTypeEnum::Int8 && - args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) { + args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS8) { return false; } - } else if (args.filter_meta.format != param::Convolution::Format::NCHW && - args.filter_meta.format != param::Convolution::Format::NHWC) { + } else if ( + args.filter_meta.format != param::Convolution::Format::NCHW && + args.filter_meta.format != param::Convolution::Format::NHWC) { return false; } auto& fm = args.filter_meta; @@ -50,37 +51,32 @@ bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) { } SmallVector convolution::matmul_get_workspace_bundle( - const ForwardSizeArgs &args) { + const ForwardSizeArgs& args) { auto dtype = args.src_layout->dtype; - auto &&fm = args.filter_meta; + auto&& fm = args.filter_meta; megdnn_assert(fm.group == 1); auto N = args.src_layout->shape[0]; - auto OC = fm.ocpg, - IC = fm.icpg, - FH = fm.spatial[0], - FW = fm.spatial[1]; - auto OH = args.dst_layout->shape[2], - OW = args.dst_layout->shape[3]; + auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; + auto OH = args.dst_layout->shape[2], OW = args.dst_layout->shape[3]; SmallVector sizes{ dtype.size() * args.dst_layout->total_nr_elems(), - dtype.size() * IC*FH*FW*OH*OW*N - }; + dtype.size() * IC * FH * FW * OH * OW * N}; if (args.filter_meta.should_flip) { sizes.push_back(dtype.size() * OC * IC * FH * FW); } return sizes; } -void convolution::flip_filter(const ForwardSizeArgs &args, - const Workspace &workspace, void *&raw_ptr) { - auto &&fm = args.filter_meta; +void convolution::flip_filter( + const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { + auto&& fm = args.filter_meta; megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); auto OC = 
fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; auto dtype = fm.dtype; megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); TensorND src{raw_ptr, {{OC, IC, FH, FW}, dtype}}, - dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; + dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; dst.layout.stride[2] = -dst.layout.stride[2]; dst.layout.stride[3] = -dst.layout.stride[3]; args.handle->relayout_opr()->exec(src, dst); diff --git a/dnn/src/cuda/convolution/helper.h b/dnn/src/cuda/convolution/helper.h index a8bc57c8..2036f38a 100644 --- a/dnn/src/cuda/convolution/helper.h +++ b/dnn/src/cuda/convolution/helper.h @@ -11,91 +11,84 @@ #pragma once #include "./opr_impl.h" +#include "src/common/algo_chooser.h" +#include "src/common/utils.h" #include "src/cuda/cudnn_wrapper.h" #include "src/cuda/handle.h" -#include "src/common/utils.h" -#include "src/common/algo_chooser.h" namespace megdnn { namespace cuda { namespace convolution { - using CanonizedFilterMeta = ConvolutionForward::CanonizedFilterMeta; +using CanonizedFilterMeta = ConvolutionForward::CanonizedFilterMeta; - //! conv size descriptor in the forward view - struct ForwardSizeArgs { - HandleImpl *handle; - const TensorLayout *src_layout; - const TensorLayout *filter_layout; - CanonizedFilterMeta filter_meta; - const TensorLayout *dst_layout; - }; +//! conv size descriptor in the forward view +struct ForwardSizeArgs { + HandleImpl* handle; + const TensorLayout* src_layout; + const TensorLayout* filter_layout; + CanonizedFilterMeta filter_meta; + const TensorLayout* dst_layout; +}; - //! whether cudnn is supported for a filter meta - bool is_cudnn_supported(const ForwardSizeArgs &args); +//! whether cudnn is supported for a filter meta +bool is_cudnn_supported(const ForwardSizeArgs& args); - //! get workspace bundle for matmul algo - SmallVector matmul_get_workspace_bundle( - const ForwardSizeArgs& args); +//! 
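For readers skimming flip_filter above: pointing dst at the last spatial element of each filter plane and negating the two spatial strides makes the relayout write a spatially reversed copy of every (oc, ic) plane. A plain reference version of the same transform on a contiguous (OC, IC, FH, FW) buffer (illustration only, not the library's code):

#include <cstdio>
#include <vector>

// Reverse each filter spatially: dst[oc][ic][FH-1-fh][FW-1-fw] = src[oc][ic][fh][fw].
void flip_filter_reference(const std::vector<float>& src,
                           std::vector<float>& dst, size_t OC, size_t IC,
                           size_t FH, size_t FW) {
    dst.resize(src.size());
    for (size_t oc = 0; oc < OC; ++oc)
        for (size_t ic = 0; ic < IC; ++ic)
            for (size_t fh = 0; fh < FH; ++fh)
                for (size_t fw = 0; fw < FW; ++fw) {
                    const size_t s = ((oc * IC + ic) * FH + fh) * FW + fw;
                    const size_t d = ((oc * IC + ic) * FH + (FH - 1 - fh)) * FW +
                                     (FW - 1 - fw);
                    dst[d] = src[s];
                }
}

int main() {
    std::vector<float> src = {1, 2, 3, 4}, dst;  // one 2x2 filter
    flip_filter_reference(src, dst, 1, 1, 2, 2);
    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);  // 4 3 2 1
}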
get workspace bundle for matmul algo +SmallVector matmul_get_workspace_bundle(const ForwardSizeArgs& args); - struct CUDNNForwardDescs { - TensorDesc src_desc, dst_desc; - FilterDesc filter_desc; - ConvDesc conv_desc; - void set(const TensorLayout &src, - const CanonizedFilterMeta &filter, - const TensorLayout &dst, - const param::Convolution ¶m) - { - src_desc.set(src, param.format); - filter_desc.set(filter); - dst_desc.set(dst, param.format); - conv_desc.set(src.dtype, param, filter.group); - } - }; +struct CUDNNForwardDescs { + TensorDesc src_desc, dst_desc; + FilterDesc filter_desc; + ConvDesc conv_desc; + void set( + const TensorLayout& src, const CanonizedFilterMeta& filter, + const TensorLayout& dst, const param::Convolution& param) { + src_desc.set(src, param.format); + filter_desc.set(filter); + dst_desc.set(dst, param.format); + conv_desc.set(src.dtype, param, filter.group); + } +}; - struct CUDNNBwdDataDescs { - TensorDesc diff_desc, grad_desc; - FilterDesc filter_desc; - ConvDesc conv_desc; - void set(const CanonizedFilterMeta &filter, - const TensorLayout &diff, - const TensorLayout &grad, - const param::Convolution ¶m) - { - filter_desc.set(filter); - diff_desc.set(diff, param.format); - grad_desc.set(grad, param.format); - conv_desc.set(filter.dtype, param, filter.group); - } - }; +struct CUDNNBwdDataDescs { + TensorDesc diff_desc, grad_desc; + FilterDesc filter_desc; + ConvDesc conv_desc; + void set( + const CanonizedFilterMeta& filter, const TensorLayout& diff, + const TensorLayout& grad, const param::Convolution& param) { + filter_desc.set(filter); + diff_desc.set(diff, param.format); + grad_desc.set(grad, param.format); + conv_desc.set(filter.dtype, param, filter.group); + } +}; - struct CUDNNBwdFilterDescs { - TensorDesc diff_desc, src_desc; - FilterDesc grad_desc; - ConvDesc conv_desc; - void set(const TensorLayout &src, - const TensorLayout &diff, - const CanonizedFilterMeta &grad, - const param::Convolution ¶m) - { - src_desc.set(src, param.format); - diff_desc.set(diff, param.format); - grad_desc.set(grad); - conv_desc.set(src.dtype, param, grad.group); - } - }; +struct CUDNNBwdFilterDescs { + TensorDesc diff_desc, src_desc; + FilterDesc grad_desc; + ConvDesc conv_desc; + void set( + const TensorLayout& src, const TensorLayout& diff, + const CanonizedFilterMeta& grad, const param::Convolution& param) { + src_desc.set(src, param.format); + diff_desc.set(diff, param.format); + grad_desc.set(grad); + conv_desc.set(src.dtype, param, grad.group); + } +}; - /*! - * \brief flip conv filter - * - * Flip conv filter pointed by \p raw_ptr, store result in workspace, and - * change \p raw_ptr to workspace. - */ - void flip_filter(const ForwardSizeArgs &args, - const Workspace &workspace, void *&raw_ptr); +/*! + * \brief flip conv filter + * + * Flip conv filter pointed by \p raw_ptr, store result in workspace, and + * change \p raw_ptr to workspace. + */ +void flip_filter( + const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); -} // namespace convolution -} // namespace cuda -} // namespace megdnn +} // namespace convolution +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/im2col.cu b/dnn/src/cuda/convolution/im2col.cu index c2b7ec7f..e1369490 100644 --- a/dnn/src/cuda/convolution/im2col.cu +++ b/dnn/src/cuda/convolution/im2col.cu @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #include "./im2col.cuh" -#include "src/cuda/utils.cuh" #include "megdnn/dtype.h" +#include "src/cuda/utils.cuh" using namespace megdnn; using namespace cuda; @@ -18,15 +18,10 @@ using namespace cuda; namespace { template -__global__ void im2col_kernel(const T *im, T *col, - uint32_t N, uint32_t INP_BS, - uint32_t IC, uint32_t IH, uint32_t IW, - uint32_t FH, uint32_t FW, - uint32_t OH, uint32_t OW, - uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW, - uint32_t DH, uint32_t DW) -{ +__global__ void im2col_kernel( + const T* im, T* col, uint32_t N, uint32_t INP_BS, uint32_t IC, uint32_t IH, + uint32_t IW, uint32_t FH, uint32_t FW, uint32_t OH, uint32_t OW, uint32_t PH, + uint32_t PW, uint32_t SH, uint32_t SW, uint32_t DH, uint32_t DW) { uint32_t n = threadIdx.x + blockIdx.y * blockDim.x; uint32_t ow = threadIdx.y + blockIdx.z * blockDim.y; uint32_t oh = blockIdx.x % OH; @@ -34,24 +29,20 @@ __global__ void im2col_kernel(const T *im, T *col, uint32_t fh = blockIdx.x / OH / FW % FH; uint32_t ic = blockIdx.x / OH / FW / FH; if (n < N && ow < OW) { - uint32_t didx = blockIdx.x * OW*N + ow*N + n; - uint32_t ih = -PH + oh*SH + fh*DH; - uint32_t iw = -PW + ow*SW + fw*DW; - col[didx] = (ih < IH && iw < IW ? - im[n*INP_BS + ic*IH*IW + ih*IW + iw] : T(0.0f)); + uint32_t didx = blockIdx.x * OW * N + ow * N + n; + uint32_t ih = -PH + oh * SH + fh * DH; + uint32_t iw = -PW + ow * SW + fw * DW; + col[didx] = + (ih < IH && iw < IW ? im[n * INP_BS + ic * IH * IW + ih * IW + iw] + : T(0.0f)); } } template -__global__ void col2im_kernel(const T *col, T *im, - uint32_t N, uint32_t INP_BS, - uint32_t IC, uint32_t IH, uint32_t IW, - uint32_t FH, uint32_t FW, - uint32_t OH, uint32_t OW, - uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW, - uint32_t DH, uint32_t DW) -{ +__global__ void col2im_kernel( + const T* col, T* im, uint32_t N, uint32_t INP_BS, uint32_t IC, uint32_t IH, + uint32_t IW, uint32_t FH, uint32_t FW, uint32_t OH, uint32_t OW, uint32_t PH, + uint32_t PW, uint32_t SH, uint32_t SW, uint32_t DH, uint32_t DW) { uint32_t iw = threadIdx.x + blockIdx.y * blockDim.x; uint32_t ih = threadIdx.y + blockIdx.z * blockDim.y; uint32_t ic = blockIdx.x % IC; @@ -61,98 +52,70 @@ __global__ void col2im_kernel(const T *col, T *im, // ih = -ph + oh*sh + fh*dh // ih + ph - fh*dh == oh*sh for (uint32_t fh = 0; fh < FH; ++fh) { - uint32_t anchorh = ih + PH - fh*DH; - if (anchorh < OH*SH && anchorh % SH == 0) { + uint32_t anchorh = ih + PH - fh * DH; + if (anchorh < OH * SH && anchorh % SH == 0) { uint32_t oh = anchorh / SH; for (uint32_t fw = 0; fw < FW; ++fw) { - uint32_t anchorw = iw + PW - fw*DW; - if (anchorw < OW*SW && anchorw % SW == 0) { + uint32_t anchorw = iw + PW - fw * DW; + if (anchorw < OW * SW && anchorw % SW == 0) { uint32_t ow = anchorw / SW; - res += col[ic*FH*FW*OH*OW*N + - fh*FW*OH*OW*N + - fw*OH*OW*N + - oh*OW*N + - ow*N + - n]; + res += + col[ic * FH * FW * OH * OW * N + fh * FW * OH * OW * N + + fw * OH * OW * N + oh * OW * N + ow * N + n]; } } } } - im[n*INP_BS + ic*IH*IW + ih*IW + iw] = res; + im[n * INP_BS + ic * IH * IW + ih * IW + iw] = res; } } -} // anonymous namespace +} // anonymous namespace template -void convolution::im2col(const T *im, T *col, - size_t N, size_t INP_BS, - size_t IC, size_t IH, size_t IW, - size_t FH, size_t FW, - size_t OH, size_t OW, - size_t PH, size_t PW, - size_t SH, size_t SW, - size_t DH, size_t DW, - cudaStream_t stream) -{ +void convolution::im2col( + const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, 
size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, cudaStream_t stream) { dim3 threads(NR_THREADS_X, NR_THREADS_Y); // dim3 blocks(DIVUP(N, NR_THREADS_X), DIVUP(OW, NR_THREADS_Y), IC*FH*FW*OH); // IC*FH*FW*OH can be larger than 65536; shuffling blocks dimensions to // put IC*FH*FW*OH to the first dimension. - dim3 blocks(IC*FH*FW*OH, DIVUP(N, NR_THREADS_X), DIVUP(OW, NR_THREADS_Y)); - im2col_kernel<<>>(im, col, - N, INP_BS, - IC, IH, IW, FH, FW, OH, OW, - PH, PW, SH, SW, DH, DW); + dim3 blocks(IC * FH * FW * OH, DIVUP(N, NR_THREADS_X), DIVUP(OW, NR_THREADS_Y)); + im2col_kernel<<>>( + im, col, N, INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, DW); after_kernel_launch(); } template -void convolution::col2im(const T *col, T *im, - size_t N, size_t INP_BS, - size_t IC, size_t IH, size_t IW, - size_t FH, size_t FW, - size_t OH, size_t OW, - size_t PH, size_t PW, - size_t SH, size_t SW, - size_t DH, size_t DW, - cudaStream_t stream) -{ +void convolution::col2im( + const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, cudaStream_t stream) { dim3 threads(NR_THREADS_X, NR_THREADS_Y); // (x, y, z) is shuffled to (y, z, x) to bypass CUDA launch shape limitation. // dim3 blocks(DIVUP(IW, NR_THREADS_X), DIVUP(IH, NR_THREADS_Y), N*IC); - dim3 blocks(N*IC, DIVUP(IW, NR_THREADS_X), DIVUP(IH, NR_THREADS_Y)); - col2im_kernel<<>>(col, im, - N, INP_BS, - IC, IH, IW, FH, FW, OH, OW, - PH, PW, SH, SW, DH, DW); + dim3 blocks(N * IC, DIVUP(IW, NR_THREADS_X), DIVUP(IH, NR_THREADS_Y)); + col2im_kernel<<>>( + col, im, N, INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, DW); after_kernel_launch(); } - namespace megdnn { namespace cuda { namespace convolution { -#define DO_INST(T) \ -template void im2col(const T *im, T *col, \ - size_t N, size_t INP_BS, \ - size_t IC, size_t IH, size_t IW, \ - size_t FH, size_t FW, \ - size_t OH, size_t OW, \ - size_t PH, size_t PW, \ - size_t SH, size_t SW, \ - size_t DH, size_t DW, \ - cudaStream_t stream); \ -template void col2im(const T *col, T *im, \ - size_t N, size_t INP_BS, \ - size_t IC, size_t IH, size_t IW, \ - size_t FH, size_t FW, \ - size_t OH, size_t OW, \ - size_t PH, size_t PW, \ - size_t SH, size_t SW, \ - size_t DH, size_t DW, \ - cudaStream_t stream); +#define DO_INST(T) \ + template void im2col( \ + const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, \ + size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, \ + size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, \ + cudaStream_t stream); \ + template void col2im( \ + const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, \ + size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, \ + size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, \ + cudaStream_t stream); #define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype) @@ -161,8 +124,8 @@ MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST); #undef DO_INST #undef INST -} // namespace convolution -} // namespace cuda -} // namespace megdnn +} // namespace convolution +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/im2col.cuh b/dnn/src/cuda/convolution/im2col.cuh index 9a830fbd..1dfef03d 100644 --- a/dnn/src/cuda/convolution/im2col.cuh +++ b/dnn/src/cuda/convolution/im2col.cuh @@ -10,8 +10,8 @@ */ #pragma once -#include #include +#include namespace megdnn { namespace cuda { @@ 
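As a cross-check for the index math in im2col_kernel above, and for the column layout documented in im2col.cuh as (ic*fh*fw, oh*ow*n), here is a CPU reference with the same layout; signed arithmetic replaces the kernel's unsigned-wraparound bounds trick, and zeros are written for padded positions (illustration only, not the library's code):

#include <cstdint>
#include <cstdio>
#include <vector>

void im2col_reference(const std::vector<float>& im, std::vector<float>& col,
                      uint32_t N, uint32_t INP_BS, uint32_t IC, uint32_t IH,
                      uint32_t IW, uint32_t FH, uint32_t FW, uint32_t OH,
                      uint32_t OW, uint32_t PH, uint32_t PW, uint32_t SH,
                      uint32_t SW, uint32_t DH, uint32_t DW) {
    col.assign(size_t(IC) * FH * FW * OH * OW * N, 0.f);
    for (uint32_t ic = 0; ic < IC; ++ic)
        for (uint32_t fh = 0; fh < FH; ++fh)
            for (uint32_t fw = 0; fw < FW; ++fw)
                for (uint32_t oh = 0; oh < OH; ++oh)
                    for (uint32_t ow = 0; ow < OW; ++ow)
                        for (uint32_t n = 0; n < N; ++n) {
                            // Same mapping as the kernel: row index encodes
                            // (ic, fh, fw, oh); columns run over (ow, n).
                            const int64_t ih =
                                    int64_t(oh) * SH + int64_t(fh) * DH - PH;
                            const int64_t iw =
                                    int64_t(ow) * SW + int64_t(fw) * DW - PW;
                            const size_t row =
                                    ((size_t(ic) * FH + fh) * FW + fw) * OH + oh;
                            const size_t didx = (row * OW + ow) * N + n;
                            if (ih >= 0 && ih < IH && iw >= 0 && iw < IW)
                                col[didx] = im[size_t(n) * INP_BS +
                                               size_t(ic) * IH * IW +
                                               ih * IW + iw];
                        }
}

int main() {
    // 1x1x3x3 image, 2x2 filter, stride 1, no padding -> col is (4, 4).
    std::vector<float> im = {1, 2, 3, 4, 5, 6, 7, 8, 9}, col;
    im2col_reference(im, col, 1, 9, 1, 3, 3, 2, 2, 2, 2, 0, 0, 1, 1, 1, 1);
    for (float v : col) std::printf("%g ", v);
    std::printf("\n");
}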
-19,29 +19,21 @@ namespace convolution { //! col is of shape (ic*fh*fw, oh*ow*n) template -void im2col(const T *im, T *col, - size_t N, size_t INP_BS, - size_t IC, size_t IH, size_t IW, - size_t FH, size_t FW, - size_t OH, size_t OW, - size_t PH, size_t PW, - size_t SH, size_t SW, - size_t DH, size_t DW, // dilation +void im2col( + const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, // dilation cudaStream_t stream); template -void col2im(const T *col, T *im, - size_t N, size_t INP_BS, - size_t IC, size_t IH, size_t IW, - size_t FH, size_t FW, - size_t OH, size_t OW, - size_t PH, size_t PW, - size_t SH, size_t SW, - size_t DH, size_t DW, // dilation +void col2im( + const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, size_t IW, + size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, // dilation cudaStream_t stream); -} // namespace dilated_convolution -} // namespace cuda -} // namespace megdnn +} // namespace convolution +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp index 3ae1d210..c39563ca 100644 --- a/dnn/src/cuda/convolution/opr_impl.cpp +++ b/dnn/src/cuda/convolution/opr_impl.cpp @@ -13,11 +13,11 @@ #include "src/cuda/convolution/opr_impl.h" #include "megdnn/dtype.h" #include "src/common/algo_chooser.h" -#include "src/cuda/convolution/helper.h" -#include "src/cuda/convolution/forward/algos.h" +#include "src/cuda/conv_bias/opr_impl.h" #include "src/cuda/convolution/backward_data/algo.h" #include "src/cuda/convolution/backward_filter/algo.h" -#include "src/cuda/conv_bias/opr_impl.h" +#include "src/cuda/convolution/forward/algos.h" +#include "src/cuda/convolution/helper.h" #include "src/cuda/utils.h" @@ -26,17 +26,15 @@ using namespace cuda; using namespace convolution; #define TO_STRING2(v) #v -#define TO_STRING(v) TO_STRING2(v) +#define TO_STRING(v) TO_STRING2(v) #define CUDNN_VERSION_STR \ TO_STRING(CUDNN_MAJOR) \ "." TO_STRING(CUDNN_MINOR) "." 
TO_STRING(CUDNN_PATCHLEVEL) /* ============== ConvolutionForwardImpl ============== */ -ConvolutionForwardImpl::Algorithm* -ConvolutionForwardImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, +ConvolutionForwardImpl::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args{this, src, filter, dst}; MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); @@ -45,37 +43,34 @@ ConvolutionForwardImpl::get_algorithm_heuristic( return &sm_algo_pack.algo_default; } -std::vector -ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +std::vector ConvolutionForwardImpl:: + get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, dst}; return megdnn::get_all_algorithms(args); } -std::vector -ConvolutionForwardImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +std::vector ConvolutionForwardImpl:: + get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, dst}; return megdnn::get_all_algorithms_safe(args); } size_t ConvolutionForwardImpl::get_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, const PreprocessedFilter* preprocessed_filter) { MEGDNN_MARK_USED_VAR(preprocessed_filter); return get_dnn_workspace(this, src, filter, dst); } -void ConvolutionForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - const PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) { - check_exec(src.layout, filter.layout, dst.layout, workspace.size, - preprocessed_filter); +void ConvolutionForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) { + check_exec( + src.layout, filter.layout, dst.layout, workspace.size, preprocessed_filter); AlgoBase::ExecArgs args(this, src, filter, dst, workspace); auto&& algo = get_algorithm(this, src.layout, filter.layout, dst.layout); algo->exec(args); @@ -87,38 +82,37 @@ const char* ConvolutionForwardImpl::get_algorithm_set_name() const { /* ============== ConvolutionBackwardDataImpl ============== */ -void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void ConvolutionBackwardDataImpl::exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(filter.layout, diff.layout, grad.layout, workspace.size); AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); auto algo = get_algorithm(this, filter.layout, diff.layout, grad.layout); algo->exec(args); } -std::vector -ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector ConvolutionBackwardDataImpl:: + get_all_algorithms( + 
const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { return megdnn::get_all_algorithms( {this, filter, diff, grad}); } -std::vector -ConvolutionBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector ConvolutionBackwardDataImpl:: + get_all_algorithms_safe( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { return megdnn::get_all_algorithms_safe( {this, filter, diff, grad}); } -ConvolutionBackwardDataImpl::Algorithm* -ConvolutionBackwardDataImpl::get_algorithm_heuristic( - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl:: + get_algorithm_heuristic( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, filter, diff, grad); if (args.filter_meta.group > 1 && @@ -128,16 +122,14 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( return &sm_algo_pack.chanwise; } - if (args.filter_layout->dtype.enumv() == - DTypeTrait::enumv) { + if (args.filter_layout->dtype.enumv() == DTypeTrait::enumv) { return megdnn::get_algo_match_attribute( sm_algo_pack.int8_algos, args, workspace_limit_in_bytes, "cuda conv bwd_data", positive_attr, negative_attr); } - auto get_cudnn_algo = - [this, &args, workspace_limit_in_bytes, positive_attr, - negative_attr]() -> ConvolutionBackwardDataImpl::AlgoBase* { + auto get_cudnn_algo = [this, &args, workspace_limit_in_bytes, positive_attr, + negative_attr]() -> ConvolutionBackwardDataImpl::AlgoBase* { auto cudnn_handle = cuda::cudnn_handle(this->handle()); CUDNNBwdDataDescs desc; args.init_desc(desc); @@ -161,10 +153,9 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( continue; } AlgoBase* conv_bd_data_algo = reinterpret_cast( - sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); + sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); if (conv_bd_data_algo->is_available_attribute( - args, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return conv_bd_data_algo; } } @@ -195,8 +186,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( return &sm_algo_pack.group; } - if (args.filter_layout->dtype.enumv() != - DTypeTrait::enumv) { + if (args.filter_layout->dtype.enumv() != DTypeTrait::enumv) { return megdnn::get_algo_match_attribute( sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, "cuda conv bwd_data", positive_attr, negative_attr); @@ -219,38 +209,37 @@ const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { /* ============== ConvolutionBackwardFilterImpl ============== */ -void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void ConvolutionBackwardFilterImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(src.layout, diff.layout, grad.layout, workspace.size); AlgoBase::ExecArgs args(this, src, diff, grad, workspace); auto algo = get_algorithm(this, src.layout, diff.layout, grad.layout); algo->exec(args); } -std::vector 
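The cuDNN perf query itself sits in the elided part of this hunk; the visible loop applies a simple policy: walk the returned entries (already ranked by cuDNN), skip ones that failed or lack the requested attributes, and take the first whose workspace fits the limit. A stripped-down sketch of that policy with a stand-in perf struct (not cuDNN's actual type):

#include <cstddef>
#include <cstdio>
#include <vector>

struct AlgoPerf {
    int algo;           // backend algorithm id (stand-in for the cuDNN enum)
    bool ok;            // the perf query reported success for this entry
    bool reproducible;  // stand-in for the attribute check in the loop above
    size_t workspace;   // bytes of scratch memory required
};

// Entries are assumed pre-sorted from fastest to slowest; return the first
// usable one, or -1 if nothing fits the workspace limit.
int pick_algo(const std::vector<AlgoPerf>& perf, size_t workspace_limit,
              bool require_reproducible) {
    for (const AlgoPerf& p : perf) {
        if (!p.ok) continue;
        if (require_reproducible && !p.reproducible) continue;
        if (p.workspace <= workspace_limit) return p.algo;
    }
    return -1;
}

int main() {
    std::vector<AlgoPerf> perf = {{0, true, false, 1 << 20},
                                  {1, true, true, 1 << 16}};
    std::printf("picked algo %d\n", pick_algo(perf, 1 << 18, true));  // 1
}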
-ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector ConvolutionBackwardFilterImpl:: + get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) { return megdnn::get_all_algorithms( {this, src, diff, grad}); } -std::vector -ConvolutionBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector ConvolutionBackwardFilterImpl:: + get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) { return megdnn::get_all_algorithms_safe( {this, src, diff, grad}); } -ConvolutionBackwardFilterImpl::Algorithm* -ConvolutionBackwardFilterImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +ConvolutionBackwardFilterImpl::Algorithm* ConvolutionBackwardFilterImpl:: + get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, src, diff, grad); if (args.grad_filter_meta.group > 1 && @@ -298,8 +287,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( AlgoBase* conv_bd_filter_algo = reinterpret_cast( sm_algo_pack.cudnn_from_enum(algo_perf[i].algo)); if (conv_bd_filter_algo->is_available_attribute( - args, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return conv_bd_filter_algo; } } @@ -342,8 +330,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( } size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad) { + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { return get_dnn_workspace(this, src, diff, grad); } diff --git a/dnn/src/cuda/convolution/opr_impl.h b/dnn/src/cuda/convolution/opr_impl.h index 8579b498..4d22ae36 100644 --- a/dnn/src/cuda/convolution/opr_impl.h +++ b/dnn/src/cuda/convolution/opr_impl.h @@ -20,10 +20,10 @@ namespace cuda { class ConvolutionForwardImpl : public ConvolutionForward { public: using ConvolutionForward::ConvolutionForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - const PreprocessedFilter* preprocessed_filter, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + const PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) override; size_t get_workspace_in_bytes( const TensorLayout& src, const TensorLayout& filter, @@ -32,18 +32,16 @@ public: const char* get_algorithm_set_name() const override; SmallVector deduce_preprocessed_filter_layout( - const TensorLayout&, const TensorLayout&, - const TensorLayout&) override { + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { return {}; } - size_t get_preprocess_workspace_in_bytes(const TensorLayout&, - const TensorLayout&, - const TensorLayout&) override { + size_t get_preprocess_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { return 0; } - void exec_preprocess(const TensorLayout&, _megdnn_tensor_in, - const 
TensorLayout&, PreprocessedFilter*, - _megdnn_workspace) override { + void exec_preprocess( + const TensorLayout&, _megdnn_tensor_in, const TensorLayout&, + PreprocessedFilter*, _megdnn_workspace) override { megdnn_throw("cuda exec_preprocess has not implemeted yet"); } @@ -76,22 +74,22 @@ private: class ConvolutionBackwardDataImpl : public ConvolutionBackwardData { public: using ConvolutionBackwardData::ConvolutionBackwardData; - void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; AlgorithmInfo get_algorithm_info_heuristic( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { - return get_algorithm_heuristic(filter, diff, grad, - workspace_limit_in_bytes, positive_attr, - negative_attr) + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { + return get_algorithm_heuristic( + filter, diff, grad, workspace_limit_in_bytes, positive_attr, + negative_attr) ->info(); } - size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) override; + size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -132,19 +130,19 @@ private: class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { public: using ConvolutionBackwardFilter::ConvolutionBackwardFilter; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; AlgorithmInfo get_algorithm_info_heuristic( const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { - return get_algorithm_heuristic(filter, diff, grad, - workspace_limit_in_bytes, positive_attr, - negative_attr) + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { + return get_algorithm_heuristic( + filter, diff, grad, workspace_limit_in_bytes, positive_attr, + negative_attr) ->info(); } @@ -172,9 +170,8 @@ protected: const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) override; private: diff --git a/dnn/src/cuda/convolution3d/backward_data/algo.cpp b/dnn/src/cuda/convolution3d/backward_data/algo.cpp index 62107f0f..b448c622 100644 --- a/dnn/src/cuda/convolution3d/backward_data/algo.cpp +++ 
b/dnn/src/cuda/convolution3d/backward_data/algo.cpp @@ -18,10 +18,10 @@ using namespace cuda; Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() { non_cudnn_algos.push_back(&chanwise); - all_algos.push_back(&chanwise); // prefer chanwise + all_algos.push_back(&chanwise); // prefer chanwise fill_cudnn_algos(); - for (auto &&i: cudnn) { + for (auto&& i : cudnn) { all_algos.push_back(&i); } all_algos.push_back(&group); @@ -33,15 +33,14 @@ Convolution3DBackwardDataImpl::AlgoPack::AlgoPack() { MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardDataImpl) -Convolution3DBackwardDataImpl::AlgoCUDNN* -Convolution3DBackwardDataImpl::AlgoPack::cudnn_from_enum( - cudnnConvolutionBwdDataAlgo_t algo) { - for (auto &&i: cudnn) { +Convolution3DBackwardDataImpl::AlgoCUDNN* Convolution3DBackwardDataImpl::AlgoPack:: + cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo) { + for (auto&& i : cudnn) { if (i.cudnn_enum() == algo) return &i; } - megdnn_throw(ssprintf("can not find cudnn bwd_data algorithm %d", - static_cast(algo))); + megdnn_throw(ssprintf( + "can not find cudnn bwd_data algorithm %d", static_cast(algo))); } Convolution3DBackwardDataImpl::AlgoPack Convolution3DBackwardDataImpl::sm_algo_pack; @@ -49,47 +48,42 @@ Convolution3DBackwardDataImpl::AlgoPack Convolution3DBackwardDataImpl::sm_algo_p Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( const Convolution3DBackwardDataImpl* o, const TensorLayout& filter, const TensorLayout& diff, const TensorLayout& grad) - : SizeArgs(o, filter, o->make_canonized_filter_meta(grad.ndim, filter), - diff, grad) {} + : SizeArgs( + o, filter, o->make_canonized_filter_meta(grad.ndim, filter), diff, + grad) {} Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( - const Convolution3DBackwardDataImpl *o, const TensorLayout& filter, - const CanonizedFilterMeta &filter_meta, const TensorLayout &diff, - const TensorLayout &grad): - handle{concrete_handle(o->handle())}, - filter_meta{filter_meta}, - diff_layout{&diff}, - grad_layout{&grad}, - filter_layout{&filter}, - opr{o} -{ -} + const Convolution3DBackwardDataImpl* o, const TensorLayout& filter, + const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, + const TensorLayout& grad) + : handle{concrete_handle(o->handle())}, + filter_meta{filter_meta}, + diff_layout{&diff}, + grad_layout{&grad}, + filter_layout{&filter}, + opr{o} {} Convolution3DBackwardDataImpl::AlgoBase::ExecArgs::ExecArgs( - const Convolution3DBackwardDataImpl *opr, - _megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace): - SizeArgs(opr, filter.layout, diff.layout, grad.layout), - filter_tensor{&filter}, diff_tensor{&diff}, grad_tensor{&grad}, - workspace{workspace} -{ -} + const Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) + : SizeArgs(opr, filter.layout, diff.layout, grad.layout), + filter_tensor{&filter}, + diff_tensor{&diff}, + grad_tensor{&grad}, + workspace{workspace} {} std::string Convolution3DBackwardDataImpl::AlgoBase::SizeArgs::to_string() const { - auto &&fm = filter_meta; + auto&& fm = filter_meta; MEGDNN_MARK_USED_VAR(fm); return ssprintf( "filter=%u{%u,%u,%u,%u,%u}, diff=%s, grad=%s, " "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, " "dtype=%s,%s", - fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], - fm.spatial[2], diff_layout->to_string().c_str(), - grad_layout->to_string().c_str(), fm.padding[0], fm.padding[1], - fm.padding[2], 
fm.stride[0], fm.stride[1], fm.stride[2], - fm.dilation[0], fm.dilation[1], fm.dilation[2], !fm.should_flip, - diff_layout->dtype.name(), grad_layout->dtype.name()); + fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], + diff_layout->to_string().c_str(), grad_layout->to_string().c_str(), + fm.padding[0], fm.padding[1], fm.padding[2], fm.stride[0], fm.stride[1], + fm.stride[2], fm.dilation[0], fm.dilation[1], fm.dilation[2], + !fm.should_flip, diff_layout->dtype.name(), grad_layout->dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/backward_data/algo.h b/dnn/src/cuda/convolution3d/backward_data/algo.h index 41eb05da..82fc6f8f 100644 --- a/dnn/src/cuda/convolution3d/backward_data/algo.h +++ b/dnn/src/cuda/convolution3d/backward_data/algo.h @@ -13,9 +13,9 @@ #pragma once #include -#include "src/cuda/convolution3d/helper.h" #include "src/common/algo_base.h" #include "src/common/metahelper.h" +#include "src/cuda/convolution3d/helper.h" namespace megdnn { namespace cuda { @@ -49,13 +49,13 @@ public: void init_desc(convolution3d::CUDNNBwdDataDescs& desc) const { desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); } - SizeArgs(const Convolution3DBackwardDataImpl* opr, - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad); - SizeArgs(const Convolution3DBackwardDataImpl* opr, - const TensorLayout& filter, - const CanonizedFilterMeta& filter_meta, - const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + const Convolution3DBackwardDataImpl* opr, const TensorLayout& filter, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + const Convolution3DBackwardDataImpl* opr, const TensorLayout& filter, + const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, + const TensorLayout& grad); convolution3d::ForwardSizeArgs as_fwd_args() const { return {handle, grad_layout, filter_layout, @@ -66,9 +66,10 @@ public: const TensorND *filter_tensor, *diff_tensor, *grad_tensor; Workspace workspace; - ExecArgs(const Convolution3DBackwardDataImpl* opr, - _megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace); + ExecArgs( + const Convolution3DBackwardDataImpl* opr, _megdnn_tensor_in filter, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -83,16 +84,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd data algo %s: " - "required workspace %zu bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "conv bwd data algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); return *this; } @@ -104,10 +104,10 @@ class Convolution3DBackwardDataImpl::AlgoCUDNN final : public AlgoBase { CudnnAlgoPack::Attr m_attr; public: - AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) - : m_cudnn_enum(cudnn_enum) { - 
megdnn_assert(CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) != - CudnnAlgoPack::conv3d_bwd_data_algos().end()); + AlgoCUDNN(cudnnConvolutionBwdDataAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { + megdnn_assert( + CudnnAlgoPack::conv3d_bwd_data_algos().find(cudnn_enum) != + CudnnAlgoPack::conv3d_bwd_data_algos().end()); m_attr = CudnnAlgoPack::conv3d_bwd_data_algos().at(cudnn_enum); } @@ -148,29 +148,21 @@ public: const char* name() const override { return "CHANNEL_WISE"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } }; //! implement group conv by another algo -class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final - : public AlgoBase { +class Convolution3DBackwardDataImpl::AlgoGroupConvGeneral final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; - const char* name() const override { - return "CUDA:GROUP_CONV3D_BACKWARD_DATA"; - } + const char* name() const override { return "CUDA:GROUP_CONV3D_BACKWARD_DATA"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) private: diff --git a/dnn/src/cuda/convolution3d/backward_data/chanwise.cpp b/dnn/src/cuda/convolution3d/backward_data/chanwise.cpp index c4d0c584..601403a1 100644 --- a/dnn/src/cuda/convolution3d/backward_data/chanwise.cpp +++ b/dnn/src/cuda/convolution3d/backward_data/chanwise.cpp @@ -10,49 +10,42 @@ */ #include "./algo.h" -#include "src/cuda/utils.h" #include "src/cuda/convolution3d/chanwise/kern.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution3d; bool Convolution3DBackwardDataImpl::AlgoChanwise::is_available( - const SizeArgs &args) const { - if (!args.grad_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + const SizeArgs& args) const { + if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } - auto &&fm = args.filter_meta; + auto&& fm = args.filter_meta; return args.filter_meta.format == Param::Format::NCDHW && - args.diff_layout->dtype.category() == DTypeCategory::FLOAT && - fm.spatial_ndim == 3 && fm.icpg == 1 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.dilation[2] == 1 && - !fm.should_flip; + args.diff_layout->dtype.category() == DTypeCategory::FLOAT && + fm.spatial_ndim == 3 && fm.icpg == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.dilation[2] == 1 && !fm.should_flip; } size_t Convolution3DBackwardDataImpl::AlgoChanwise::get_workspace_in_bytes( - const SizeArgs &) const { + const SizeArgs&) const { return 0; } -void Convolution3DBackwardDataImpl::AlgoChanwise::exec( - const ExecArgs &args) const { +void Convolution3DBackwardDataImpl::AlgoChanwise::exec(const ExecArgs& args) const { auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args()); auto stream = cuda_stream(args.handle); switch (args.diff_layout->dtype.enumv()) { -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - { \ - using ctype = 
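The cb()/MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT pair that AlgoChanwise::exec relies on is an X-macro dispatch: the FOREACH macro stamps out one switch case per floating-point dtype, each case binding ctype and forwarding to the typed chanwise runner. A self-contained miniature of the same idiom, with every name invented for illustration (this is not the megdnn macro):

#include <cstdio>

enum class DT { F32, F64 };

// stand-in for MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT: invokes cb once per dtype
#define FOREACH_FLOAT_DTYPE(cb) cb(DT::F32, float) cb(DT::F64, double)

void dispatch_demo(DT which) {
    switch (which) {
#define cb(_tag, _ct)                                                         \
    case _tag: {                                                              \
        using ctype = _ct;                                                    \
        std::printf("would launch the %zu-byte kernel\n", sizeof(ctype));     \
        return;                                                               \
    }
        FOREACH_FLOAT_DTYPE(cb)
#undef cb
    }
}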
DTypeTrait<_dt>::ctype; \ - return chanwise::run_bwd_data( \ - args.grad_tensor->ptr(), \ - args.diff_tensor->ptr(), \ - args.filter_tensor->ptr(), \ - kparam, stream); \ - } - MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + return chanwise::run_bwd_data( \ + args.grad_tensor->ptr(), args.diff_tensor->ptr(), \ + args.filter_tensor->ptr(), kparam, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb default: break; @@ -60,4 +53,3 @@ void Convolution3DBackwardDataImpl::AlgoChanwise::exec( megdnn_assert_internal(0); } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp b/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp index c7fcafdf..e7a8813e 100644 --- a/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/backward_data/cudnn.cpp @@ -11,16 +11,16 @@ #include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/cudnn_wrapper.h" #include "src/cuda/convolution3d/helper.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution3d; bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available( - const SizeArgs &args) const { + const SizeArgs& args) const { CUDNNBwdDataDescs D; if (!is_cudnn_supported(args.as_fwd_args())) @@ -29,53 +29,37 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - args.handle->cudnn_handle(), - D.filter_desc.desc, - D.diff_desc.desc, - D.conv_desc.desc, - D.grad_desc.desc, - m_cudnn_enum, - &workspace_size); + args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); return status == CUDNN_STATUS_SUCCESS; } size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes( - const SizeArgs &args) const { + const SizeArgs& args) const { CUDNNBwdDataDescs D; args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( - args.handle->cudnn_handle(), - D.filter_desc.desc, - D.diff_desc.desc, - D.conv_desc.desc, - D.grad_desc.desc, - m_cudnn_enum, - &workspace_size); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, + args.handle->cudnn_handle(), D.filter_desc.desc, D.diff_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv bwd_data get workspace failed: %s; info: %s", cudnnGetErrorString(status), args.to_string().c_str()); return workspace_size; } -void Convolution3DBackwardDataImpl::AlgoCUDNN::exec( - const ExecArgs &args) const { +void Convolution3DBackwardDataImpl::AlgoCUDNN::exec(const ExecArgs& args) const { CUDNNBwdDataDescs D; args.init_desc(D); float alpha = 1.0f, beta = 0.0f; - auto status = cudnnConvolutionBackwardData(args.handle->cudnn_handle(), - &alpha, - D.filter_desc.desc, args.filter_tensor->raw_ptr, - D.diff_desc.desc, args.diff_tensor->raw_ptr, - D.conv_desc.desc, - m_cudnn_enum, - args.workspace.raw_ptr, - args.workspace.size, - &beta, - D.grad_desc.desc, - args.grad_tensor->raw_ptr); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv bwd_data failed: %s; info: %s", + auto status = cudnnConvolutionBackwardData( + args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, + args.filter_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, + D.conv_desc.desc, 
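For context on the AlgoCUDNN path being reformatted here: the flow is descriptor setup, a workspace-size query, then the launch, with the same cudnnConvolutionBwdDataAlgo_t threaded through both calls. A minimal standalone sketch of that sequence against the public cuDNN API, assuming the descriptors are already configured the way CUDNNBwdDataDescs::set() does, with error handling reduced to asserts:

#include <cassert>
#include <cudnn.h>

void run_cudnn_bwd_data(
        cudnnHandle_t handle, cudnnFilterDescriptor_t w_desc, const void* w,
        cudnnTensorDescriptor_t dy_desc, const void* dy,
        cudnnConvolutionDescriptor_t conv_desc, cudnnConvolutionBwdDataAlgo_t algo,
        cudnnTensorDescriptor_t dx_desc, void* dx, void* workspace,
        size_t workspace_size) {
    size_t required = 0;
    cudnnStatus_t st = cudnnGetConvolutionBackwardDataWorkspaceSize(
            handle, w_desc, dy_desc, conv_desc, dx_desc, algo, &required);
    assert(st == CUDNN_STATUS_SUCCESS && required <= workspace_size);
    const float alpha = 1.f, beta = 0.f;
    st = cudnnConvolutionBackwardData(
            handle, &alpha, w_desc, w, dy_desc, dy, conv_desc, algo, workspace,
            workspace_size, &beta, dx_desc, dx);
    assert(st == CUDNN_STATUS_SUCCESS);
}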
m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, + &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", cudnnGetErrorString(status), args.to_string().c_str()); } diff --git a/dnn/src/cuda/convolution3d/backward_data/group_conv.cpp b/dnn/src/cuda/convolution3d/backward_data/group_conv.cpp index d76da0f0..ebaaf590 100644 --- a/dnn/src/cuda/convolution3d/backward_data/group_conv.cpp +++ b/dnn/src/cuda/convolution3d/backward_data/group_conv.cpp @@ -16,8 +16,8 @@ using namespace cuda; using namespace convolution3d; namespace { -std::pair -sub_opr_config(const Convolution3DBackwardDataImpl::AlgoBase::SizeArgs& args) { +std::pair sub_opr_config( + const Convolution3DBackwardDataImpl::AlgoBase::SizeArgs& args) { TensorLayout filter_pg = *args.filter_layout; TensorLayout diff_pg = *args.diff_layout; TensorLayout grad_pg = *args.grad_layout; @@ -37,8 +37,8 @@ sub_opr_config(const Convolution3DBackwardDataImpl::AlgoBase::SizeArgs& args) { return ret; } -std::pair> -prepare_sub_opr(const Convolution3DBackwardDataImpl::AlgoBase::SizeArgs& args) { +std::pair> prepare_sub_opr( + const Convolution3DBackwardDataImpl::AlgoBase::SizeArgs& args) { auto conv3d_backdata_opr = args.handle->create_operator(); set_execution_policy( @@ -50,9 +50,9 @@ prepare_sub_opr(const Convolution3DBackwardDataImpl::AlgoBase::SizeArgs& args) { } } // namespace -std::vector -Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector Convolution3DBackwardDataImpl::AlgoGroupConvGeneral:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { AlgoBase::SizeArgs args{ static_cast(opr), layouts[0], layouts[1], layouts[2]}; @@ -60,12 +60,11 @@ Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::get_subopr_list( std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA, param_str, - config.first}}; + return {{Algorithm::OprType::CONVOLUTION3D_BACKWARD_DATA, param_str, config.first}}; } bool Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::is_available( - const SizeArgs &args) const { + const SizeArgs& args) const { if (args.filter_meta.group <= 1) return false; if (args.filter_meta.format != Param::Format::NCDHW) { @@ -79,17 +78,15 @@ bool Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::is_available( config.first[0], config.first[1], config.first[2]); } -WorkspaceBundle -Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_bundle( - void* ptr, const SizeArgs& args) const { +WorkspaceBundle Convolution3DBackwardDataImpl::AlgoGroupConvGeneral:: + get_workspace_bundle(void* ptr, const SizeArgs& args) const { auto config = prepare_sub_opr(args); size_t sizes = config.second->get_workspace_in_bytes( config.first[0], config.first[1], config.first[2]); return {ptr, {sizes}}; } -size_t -Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( +size_t Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( const SizeArgs& args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -107,12 +104,13 @@ void Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::exec( auto grp = args.filter_meta.group; auto&& fm = args.filter_meta; - auto strd_flt = (fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * - fm.spatial[2] * tfilter.layout.dtype.size()), - strd_diff = 
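The strd_flt/strd_diff/strd_grad values computed in AlgoGroupConvGeneral::exec are byte strides that let a single-group sub-operator be reused for the whole grouped convolution: each iteration runs the group-1 convolution on the current slice and then advances the raw tensor pointers by one group's worth of data. A generic sketch of that loop (the types and the run_single_group callback are placeholders, not megdnn types):

#include <cstddef>
#include <cstdint>

struct RawTensor { void* raw_ptr; };  // placeholder for the tensor views advanced below

template <typename RunSingleGroup>
void run_grouped(RawTensor flt, RawTensor diff, RawTensor grad,
                 size_t strd_flt, size_t strd_diff, size_t strd_grad,
                 uint32_t groups, RunSingleGroup&& run_single_group) {
    for (uint32_t g = 0; g < groups; ++g) {
        run_single_group(flt, diff, grad);  // group-1 convolution on the g-th slice
        flt.raw_ptr = static_cast<uint8_t*>(flt.raw_ptr) + strd_flt;
        diff.raw_ptr = static_cast<uint8_t*>(diff.raw_ptr) + strd_diff;
        grad.raw_ptr = static_cast<uint8_t*>(grad.raw_ptr) + strd_grad;
    }
}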
(tdiff.layout.stride[c_pos] * fm.ocpg * - tdiff.layout.dtype.size()), - strd_grad = (tgrad.layout.stride[c_pos] * fm.icpg * - tgrad.layout.dtype.size()); + auto strd_flt = + (fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * + fm.spatial[2] * tfilter.layout.dtype.size()), + strd_diff = + (tdiff.layout.stride[c_pos] * fm.ocpg * tdiff.layout.dtype.size()), + strd_grad = + (tgrad.layout.stride[c_pos] * fm.icpg * tgrad.layout.dtype.size()); for (uint32_t g = 0; g < grp; ++g) { config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); @@ -124,4 +122,3 @@ void Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::exec( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/backward_filter/algo.cpp b/dnn/src/cuda/convolution3d/backward_filter/algo.cpp index 9c07655a..e97c4f0e 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/algo.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/algo.cpp @@ -18,10 +18,10 @@ using namespace cuda; Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { non_cudnn_algos.push_back(&chanwise); non_cudnn_algos.push_back(&inplace_matmul); - all_algos.push_back(&chanwise); // prefer chanwise + all_algos.push_back(&chanwise); // prefer chanwise fill_cudnn_algos(); - for (auto &&i: cudnn) { + for (auto&& i : cudnn) { all_algos.push_back(&i); } @@ -35,27 +35,22 @@ Convolution3DBackwardFilterImpl::AlgoPack::AlgoPack() { MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DBackwardFilterImpl) -Convolution3DBackwardFilterImpl::AlgoCUDNN* -Convolution3DBackwardFilterImpl::AlgoPack::cudnn_from_enum( - cudnnConvolutionBwdFilterAlgo_t algo) { - for (auto &&i: cudnn) { +Convolution3DBackwardFilterImpl::AlgoCUDNN* Convolution3DBackwardFilterImpl::AlgoPack:: + cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo) { + for (auto&& i : cudnn) { if (i.cudnn_enum() == algo) return &i; } - megdnn_throw(ssprintf("can not find cudnn bwd_filter algorithm %d", - static_cast(algo))); + megdnn_throw(ssprintf( + "can not find cudnn bwd_filter algorithm %d", static_cast(algo))); } -Convolution3DBackwardFilterImpl::AlgoPack -Convolution3DBackwardFilterImpl::sm_algo_pack; +Convolution3DBackwardFilterImpl::AlgoPack Convolution3DBackwardFilterImpl::sm_algo_pack; Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( - const Convolution3DBackwardFilterImpl *o, - const TensorLayout &src, const TensorLayout &diff, - const TensorLayout &grad): - SizeArgs(o, src, diff, grad, o->make_canonized_filter_meta(src.ndim, grad)) -{ -} + const Convolution3DBackwardFilterImpl* o, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad) + : SizeArgs(o, src, diff, grad, o->make_canonized_filter_meta(src.ndim, grad)) {} Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( const Convolution3DBackwardFilterImpl* o, const TensorLayout& src, @@ -69,31 +64,26 @@ Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( opr{o} {} Convolution3DBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( - const Convolution3DBackwardFilterImpl *opr, - _megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace): - SizeArgs(opr, src.layout, diff.layout, grad.layout), - src_tensor{&src}, diff_tensor{&diff}, grad_tensor{&grad}, - workspace{workspace} -{ -} + const Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) + : SizeArgs(opr, src.layout, diff.layout, grad.layout), + src_tensor{&src}, + diff_tensor{&diff}, + grad_tensor{&grad}, + 
workspace{workspace} {} -std::string -Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { - auto &&fm = grad_filter_meta; +std::string Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { + auto&& fm = grad_filter_meta; MEGDNN_MARK_USED_VAR(fm); return ssprintf( "src=%s diff=%s grad_filter=%u{%u,%u,%u,%u,%u}, " "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, " "dtype=%s,%s", - src_layout->to_string().c_str(), diff_layout->to_string().c_str(), - fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], - fm.spatial[2], fm.padding[0], fm.padding[1], fm.padding[2], - fm.stride[0], fm.stride[1], fm.stride[2], fm.dilation[0], - fm.dilation[1], fm.dilation[2], !fm.should_flip, - src_layout->dtype.name(), diff_layout->dtype.name()); + src_layout->to_string().c_str(), diff_layout->to_string().c_str(), fm.group, + fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], fm.spatial[2], + fm.padding[0], fm.padding[1], fm.padding[2], fm.stride[0], fm.stride[1], + fm.stride[2], fm.dilation[0], fm.dilation[1], fm.dilation[2], + !fm.should_flip, src_layout->dtype.name(), diff_layout->dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/backward_filter/algo.h b/dnn/src/cuda/convolution3d/backward_filter/algo.h index e1f5b3ba..89f4de6b 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/algo.h +++ b/dnn/src/cuda/convolution3d/backward_filter/algo.h @@ -13,9 +13,9 @@ #pragma once #include -#include "src/cuda/convolution3d/helper.h" #include "src/common/algo_base.h" #include "src/common/metahelper.h" +#include "src/cuda/convolution3d/helper.h" namespace megdnn { namespace cuda { @@ -44,13 +44,13 @@ public: void init_desc(convolution3d::CUDNNBwdFilterDescs& desc) const { desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); } - SizeArgs(const Convolution3DBackwardFilterImpl* opr, - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad); - SizeArgs(const Convolution3DBackwardFilterImpl* opr, - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, - const CanonizedFilterMeta& grad_meta); + SizeArgs( + const Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + const Convolution3DBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad, + const CanonizedFilterMeta& grad_meta); convolution3d::ForwardSizeArgs as_fwd_args() const { return {handle, src_layout, grad_layout, @@ -61,9 +61,10 @@ public: const TensorND *src_tensor, *diff_tensor, *grad_tensor; Workspace workspace; - ExecArgs(const Convolution3DBackwardFilterImpl* opr, - _megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace); + ExecArgs( + const Convolution3DBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -78,16 +79,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& 
check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "conv bwd filter algo %s: " - "required workspace %zu bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "conv bwd filter algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); return *this; } @@ -99,10 +99,10 @@ class Convolution3DBackwardFilterImpl::AlgoCUDNN final : public AlgoBase { CudnnAlgoPack::Attr m_attr; public: - AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) - : m_cudnn_enum(cudnn_enum) { - megdnn_assert(CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) != - CudnnAlgoPack::conv3d_bwd_flt_algos().end()); + AlgoCUDNN(cudnnConvolutionBwdFilterAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { + megdnn_assert( + CudnnAlgoPack::conv3d_bwd_flt_algos().find(cudnn_enum) != + CudnnAlgoPack::conv3d_bwd_flt_algos().end()); m_attr = CudnnAlgoPack::conv3d_bwd_flt_algos().at(cudnn_enum); } @@ -135,17 +135,14 @@ public: } }; -class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final - : public AlgoBase { +class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; const char* name() const override { return "INPLACE_MATMUL"; } - AlgoAttribute attribute() const override { - return static_cast(0); - } + AlgoAttribute attribute() const override { return static_cast(0); } MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) }; @@ -156,30 +153,22 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return "CHANNEL_WISE"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) }; //! 
implement group conv by another algo -class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final - : public AlgoBase { +class Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; - const char* name() const override { - return "CUDA:GROUP_CONV3D_BACKWARD_FILTER"; - } + const char* name() const override { return "CUDA:GROUP_CONV3D_BACKWARD_FILTER"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) private: diff --git a/dnn/src/cuda/convolution3d/backward_filter/chanwise.cpp b/dnn/src/cuda/convolution3d/backward_filter/chanwise.cpp index 06831344..993ab7f1 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/chanwise.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/chanwise.cpp @@ -10,49 +10,42 @@ */ #include "./algo.h" -#include "src/cuda/utils.h" #include "src/cuda/convolution3d/chanwise/kern.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution3d; bool Convolution3DBackwardFilterImpl::AlgoChanwise::is_available( - const SizeArgs &args) const { - if (!args.src_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } - auto &&fm = args.grad_filter_meta; + auto&& fm = args.grad_filter_meta; return fm.format == Param::Format::NCDHW && - args.diff_layout->dtype.category() == DTypeCategory::FLOAT && - fm.spatial_ndim == 3 && fm.icpg == 1 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.dilation[2] == 1 && - !fm.should_flip; + args.diff_layout->dtype.category() == DTypeCategory::FLOAT && + fm.spatial_ndim == 3 && fm.icpg == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.dilation[2] == 1 && !fm.should_flip; } size_t Convolution3DBackwardFilterImpl::AlgoChanwise::get_workspace_in_bytes( - const SizeArgs &) const { + const SizeArgs&) const { return 0; } -void Convolution3DBackwardFilterImpl::AlgoChanwise::exec( - const ExecArgs &args) const { +void Convolution3DBackwardFilterImpl::AlgoChanwise::exec(const ExecArgs& args) const { auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args()); auto stream = cuda_stream(args.handle); switch (args.diff_layout->dtype.enumv()) { -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - { \ - using ctype = DTypeTrait<_dt>::ctype; \ - return chanwise::run_bwd_filter( \ - args.grad_tensor->ptr(), \ - args.src_tensor->ptr(), \ - args.diff_tensor->ptr(), \ - kparam, stream); \ - } - MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + return chanwise::run_bwd_filter( \ + args.grad_tensor->ptr(), args.src_tensor->ptr(), \ + args.diff_tensor->ptr(), kparam, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb default: break; @@ -61,4 +54,3 @@ void Convolution3DBackwardFilterImpl::AlgoChanwise::exec( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp 
b/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp index a0afe7c2..8fa38c96 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/cudnn.cpp @@ -43,26 +43,25 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes( auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv bwd_filter get workspace failed: %s; info: %s", - cudnnGetErrorString(status), args.to_string().c_str()); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, + "conv bwd_filter get workspace failed: %s; info: %s", + cudnnGetErrorString(status), args.to_string().c_str()); return workspace_size; } -void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec( - const ExecArgs& args) const { +void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec(const ExecArgs& args) const { CUDNNBwdFilterDescs D; args.init_desc(D); float alpha = 1.0f, beta = 0.0f; auto status = cudnnConvolutionBackwardFilter( args.handle->cudnn_handle(), &alpha, D.src_desc.desc, - args.src_tensor->raw_ptr, D.diff_desc.desc, - args.diff_tensor->raw_ptr, D.conv_desc.desc, m_cudnn_enum, - args.workspace.raw_ptr, args.workspace.size, &beta, - D.grad_desc.desc, args.grad_tensor->raw_ptr); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv bwd_data failed: %s; info: %s", - cudnnGetErrorString(status), args.to_string().c_str()); + args.src_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, + D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, + &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", + cudnnGetErrorString(status), args.to_string().c_str()); } void Convolution3DBackwardFilterImpl::AlgoPack::fill_cudnn_algos() { diff --git a/dnn/src/cuda/convolution3d/backward_filter/group_conv.cpp b/dnn/src/cuda/convolution3d/backward_filter/group_conv.cpp index 3875d922..b67b5533 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/group_conv.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/group_conv.cpp @@ -16,10 +16,8 @@ using namespace cuda; using namespace convolution3d; namespace { -std::pair -sub_opr_config( +std::pair sub_opr_config( const Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs& args) { - TensorLayout grad_pg = *args.grad_layout; TensorLayout src_pg = *args.src_layout; TensorLayout diff_pg = *args.diff_layout; @@ -40,12 +38,10 @@ sub_opr_config( } std::pair> -prepare_sub_opr( - const Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs& args) { +prepare_sub_opr(const Convolution3DBackwardFilterImpl::AlgoBase::SizeArgs& args) { auto conv3d_backfilter_opr = args.handle->create_operator(); - set_execution_policy( + set_execution_policy( args.opr, conv3d_backfilter_opr.get()); auto&& config = sub_opr_config(args); conv3d_backfilter_opr->param() = config.second; @@ -54,17 +50,18 @@ prepare_sub_opr( } } // namespace -std::vector -Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { +std::vector Convolution3DBackwardFilterImpl:: + AlgoGroupConvGeneral::get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { AlgoBase::SizeArgs args{ - static_cast(opr), - layouts[0], layouts[1], layouts[2]}; + static_cast(opr), layouts[0], + layouts[1], 
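Both AlgoGroupConvGeneral::get_subopr_list implementations end by serializing the sub-operator's Param with Algorithm::serialize_write_pod so the returned record can act as a lookup key. In spirit this is byte-wise serialization of a trivially copyable struct; a generic sketch of that kind of helper (assumed behavior, not the megdnn implementation):

#include <string>
#include <type_traits>

template <typename POD>
void serialize_write_pod_sketch(const POD& value, std::string& out) {
    static_assert(
            std::is_trivially_copyable<POD>::value,
            "raw-byte serialization only makes sense for trivially copyable params");
    out.append(reinterpret_cast<const char*>(&value), sizeof(POD));
}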
layouts[2]}; auto&& config = sub_opr_config(args); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER, param_str, + return { + {Algorithm::OprType::CONVOLUTION3D_BACKWARD_FILTER, param_str, config.first}}; } @@ -83,17 +80,15 @@ bool Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::is_available( config.first[0], config.first[1], config.first[2]); } -WorkspaceBundle -Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::get_workspace_bundle( - void* ptr, const SizeArgs& args) const { +WorkspaceBundle Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral:: + get_workspace_bundle(void* ptr, const SizeArgs& args) const { auto config = prepare_sub_opr(args); size_t sizes = config.second->get_workspace_in_bytes( config.first[0], config.first[1], config.first[2]); return {ptr, {sizes}}; } -size_t -Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( +size_t Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( const SizeArgs& args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -111,12 +106,13 @@ void Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::exec( auto grp = args.grad_filter_meta.group; auto&& fm = args.grad_filter_meta; - auto strd_src = (tsrc.layout.stride[c_pos] * fm.icpg * - tsrc.layout.dtype.size()), - strd_diff = (tdiff.layout.stride[c_pos] * fm.ocpg * - tdiff.layout.dtype.size()), - strd_grad = (fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * - fm.spatial[2] * tgrad.layout.dtype.size()); + auto strd_src = + (tsrc.layout.stride[c_pos] * fm.icpg * tsrc.layout.dtype.size()), + strd_diff = + (tdiff.layout.stride[c_pos] * fm.ocpg * tdiff.layout.dtype.size()), + strd_grad = + (fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * + fm.spatial[2] * tgrad.layout.dtype.size()); for (uint32_t g = 0; g < grp; ++g) { config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); @@ -128,4 +124,3 @@ void Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::exec( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul.cpp b/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul.cpp index 63add1d7..583d3832 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul.cpp +++ b/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul.cpp @@ -16,57 +16,38 @@ using namespace megdnn; using namespace cuda; bool Convolution3DBackwardFilterImpl::AlgoInplaceMatmul::is_available( - const SizeArgs &args) const { - if (!args.src_layout->is_contiguous() || - !args.diff_layout->is_contiguous()) { + const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.diff_layout->is_contiguous()) { return false; } - auto &&fm = args.grad_filter_meta; + auto&& fm = args.grad_filter_meta; return args.grad_filter_meta.format == Param::Format::NCDHW && - args.src_layout->dtype == dtype::Float32() && - fm.group == 1 && fm.spatial_ndim == 3; + args.src_layout->dtype == dtype::Float32() && fm.group == 1 && + fm.spatial_ndim == 3; } size_t Convolution3DBackwardFilterImpl::AlgoInplaceMatmul::get_workspace_in_bytes( - const SizeArgs &) const { + const SizeArgs&) const { return 0; } void Convolution3DBackwardFilterImpl::AlgoInplaceMatmul::exec( - const ExecArgs &args) const { - auto &&fm = args.grad_filter_meta; - size_t N = args.src_layout->shape[0], - IC = fm.icpg, - ID = args.src_layout->shape[2], - IH = args.src_layout->shape[3], - IW = 
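The shape bundle unpacked by AlgoInplaceMatmul::exec (N, IC, ID/IH/IW, OC, OD/OH/OW, FD/FH/FW plus dilation, padding and stride) is tied together by the standard 3D convolution geometry that the kernel below assumes. For reference, the textbook relation written as a small host-side helper (a sketch, not a megdnn function):

#include <cstddef>

// Output extent along one spatial axis:
//   effective filter size under dilation: (flt - 1) * dil + 1
//   out = (in + 2 * pad - effective_flt) / stride + 1
inline size_t conv_out_extent(size_t in, size_t pad, size_t flt, size_t dil, size_t stride) {
    const size_t flt_eff = (flt - 1) * dil + 1;
    return (in + 2 * pad - flt_eff) / stride + 1;
}
// e.g. OD == conv_out_extent(ID, PD, FD, DD, SD), and likewise for H and W.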
args.src_layout->shape[4], - OC = fm.ocpg, - OD = args.diff_layout->shape[2], - OH = args.diff_layout->shape[3], - OW = args.diff_layout->shape[4], - FD = fm.spatial[0], - FH = fm.spatial[1], - FW = fm.spatial[2], - DD = fm.dilation[0], - DH = fm.dilation[1], - DW = fm.dilation[2]; + const ExecArgs& args) const { + auto&& fm = args.grad_filter_meta; + size_t N = args.src_layout->shape[0], IC = fm.icpg, ID = args.src_layout->shape[2], + IH = args.src_layout->shape[3], IW = args.src_layout->shape[4], OC = fm.ocpg, + OD = args.diff_layout->shape[2], OH = args.diff_layout->shape[3], + OW = args.diff_layout->shape[4], FD = fm.spatial[0], FH = fm.spatial[1], + FW = fm.spatial[2], DD = fm.dilation[0], DH = fm.dilation[1], + DW = fm.dilation[2]; auto stream = args.handle->stream(); convolution3d::exec_inplace_matmul_bwd_filter( - args.diff_tensor->ptr(), - args.src_tensor->ptr(), - args.grad_tensor->ptr(), - N, - args.src_layout->stride[0], - args.diff_layout->stride[0], - IC, ID, IH, IW, - OC, OD, OH, OW, - FD, FH, FW, - fm.padding[0], fm.padding[1], fm.padding[2], - fm.stride[0], fm.stride[1], fm.stride[2], - DD, DH, DW, - !fm.should_flip, stream); + args.diff_tensor->ptr(), args.src_tensor->ptr(), + args.grad_tensor->ptr(), N, args.src_layout->stride[0], + args.diff_layout->stride[0], IC, ID, IH, IW, OC, OD, OH, OW, FD, FH, FW, + fm.padding[0], fm.padding[1], fm.padding[2], fm.stride[0], fm.stride[1], + fm.stride[2], DD, DH, DW, !fm.should_flip, stream); } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cu b/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cu index 709153d6..9188ff7b 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cu +++ b/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cu @@ -8,10 +8,10 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
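The inplace_matmul_impl.cu hunks that follow are pure reformatting, but the BufferFetcherTextureHost they touch deserves a gloss: it wraps a plain linear device buffer in a cudaTextureObject_t so the kernel can read through the texture path, and silently falls back to raw pointer loads when creation fails. A condensed sketch of that creation step using only the CUDA runtime API (names shortened):

#include <cstddef>
#include <cuda_runtime.h>

// Returns true and fills `tex` if the linear buffer `p` (n floats) could be
// wrapped in a texture object; mirrors the fallback logic in this file.
bool make_linear_float_texture(float* p, size_t n, cudaTextureObject_t* tex) {
    cudaResourceDesc res{};
    res.resType = cudaResourceTypeLinear;
    res.res.linear.devPtr = p;
    res.res.linear.sizeInBytes = n * sizeof(float);
    res.res.linear.desc =
            cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat);
    cudaTextureDesc td{};
    if (cudaCreateTextureObject(tex, &res, &td, nullptr) != cudaSuccess) {
        cudaGetLastError();  // clear the sticky error; caller falls back to raw loads
        return false;
    }
    return true;
}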
*/ +#include +#include #include "./inplace_matmul_impl.cuh" #include "src/cuda/utils.cuh" -#include -#include using namespace megdnn; using namespace cuda; @@ -26,22 +26,18 @@ struct BufferFetcherTexture { }; struct BufferFetcherRaw { - const float *ptr; + const float* ptr; - __device__ __forceinline__ float get(uint32_t offset) { - return ptr[offset]; - } + __device__ __forceinline__ float get(uint32_t offset) { return ptr[offset]; } }; struct BufferFetcherTextureHost { bool init_succ; BufferFetcherTexture val; - BufferFetcherTextureHost(float *p, const size_t n); + BufferFetcherTextureHost(float* p, const size_t n); - ~BufferFetcherTextureHost() { - reset(); - } + ~BufferFetcherTextureHost() { reset(); } void reset() { if (init_succ) { @@ -51,36 +47,34 @@ struct BufferFetcherTextureHost { } }; -BufferFetcherTextureHost::BufferFetcherTextureHost(float *p, const size_t n) { +BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n) { init_succ = false; cudaTextureObject_t tex_obj; cudaResourceDesc res_desc; memset(&res_desc, 0, sizeof(cudaResourceDesc)); res_desc.resType = cudaResourceTypeLinear; - res_desc.res.linear.devPtr = static_cast(p); - res_desc.res.linear.sizeInBytes = n*sizeof(float); - res_desc.res.linear.desc = cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); - cudaTextureDesc tex_desc; + res_desc.res.linear.devPtr = static_cast(p); + res_desc.res.linear.sizeInBytes = n * sizeof(float); + res_desc.res.linear.desc = + cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); + cudaTextureDesc tex_desc; memset(&tex_desc, 0, sizeof(cudaTextureDesc)); if (cudaCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL) == cudaSuccess) { val.tex = tex_obj; init_succ = true; } else { - cudaGetLastError(); // reset error + cudaGetLastError(); // reset error } } -template +template struct KernelPtr { - typedef void(*type)(BufferFetcher, BufferFetcher, float*, - uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t); + typedef void (*type)( + BufferFetcher, BufferFetcher, float*, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t); }; //! 1 -> 0xffffffff, 0 -> 0x00000000 @@ -94,7 +88,7 @@ union FloatAndU32 { }; //! 
\p mask must be either all 1 or 0 bits -template +template __device__ __forceinline__ float visit_with_mask( BufferFetcher buf, uint32_t offset, uint32_t mask) { FloatAndU32 f; @@ -103,22 +97,20 @@ __device__ __forceinline__ float visit_with_mask( return f.f; } -__device__ __forceinline__ uint32_t with_dilation( - const uint32_t origin, const uint32_t D) { +__device__ __forceinline__ uint32_t +with_dilation(const uint32_t origin, const uint32_t D) { return origin * D; } template -__global__ void conv_kernel(BufferFetcher diff, BufferFetcher src, - float *grad, - const uint32_t N, const uint32_t INP_BS, const uint32_t OUT_BS, - const uint32_t IC, const uint32_t ID, const uint32_t IH, const uint32_t IW, - const uint32_t OC, const uint32_t OD, const uint32_t OH, const uint32_t OW, - const uint32_t FD, const uint32_t FH, const uint32_t FW, - const uint32_t SD, const uint32_t SH, const uint32_t SW, - const uint32_t PD, const uint32_t PH, const uint32_t PW, - const uint32_t DD, const uint32_t DH, const uint32_t DW) -{ +__global__ void conv_kernel( + BufferFetcher diff, BufferFetcher src, float* grad, const uint32_t N, + const uint32_t INP_BS, const uint32_t OUT_BS, const uint32_t IC, + const uint32_t ID, const uint32_t IH, const uint32_t IW, const uint32_t OC, + const uint32_t OD, const uint32_t OH, const uint32_t OW, const uint32_t FD, + const uint32_t FH, const uint32_t FW, const uint32_t SD, const uint32_t SH, + const uint32_t SW, const uint32_t PD, const uint32_t PH, const uint32_t PW, + const uint32_t DD, const uint32_t DH, const uint32_t DW) { const uint32_t BM = BY < BX ? BY : BX; uint32_t n = blockIdx.z; @@ -127,33 +119,33 @@ __global__ void conv_kernel(BufferFetcher diff, BufferFetcher src, const uint32_t tidy = threadIdx.y; const uint32_t posx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t posy = blockIdx.y * blockDim.y + threadIdx.y; - const uint32_t posx2 = posx<<2; - const uint32_t posy2 = posy<<2; - + const uint32_t posx2 = posx << 2; + const uint32_t posy2 = posy << 2; + const uint32_t heightA = OC; - const uint32_t widthA = OD*OH*OW; + const uint32_t widthA = OD * OH * OW; const uint32_t heightB = widthA; - const uint32_t widthB = IC*FD*FH*FW; - - uint32_t ic0 = (posx2+0) / FW / FH / FD; - uint32_t fd0 = (posx2+0) / FW / FH % FD; - uint32_t fh0 = (posx2+0) / FW % FH; - uint32_t fw0 = (posx2+0) % FW; - - uint32_t ic1 = (posx2+1) / FW / FH / FD; - uint32_t fd1 = (posx2+1) / FW / FH % FD; - uint32_t fh1 = (posx2+1) / FW % FH; - uint32_t fw1 = (posx2+1) % FW; - - uint32_t ic2 = (posx2+2) / FW / FH / FD; - uint32_t fd2 = (posx2+2) / FW / FH % FD; - uint32_t fh2 = (posx2+2) / FW % FH; - uint32_t fw2 = (posx2+2) % FW; - - uint32_t ic3 = (posx2+3) / FW / FH / FD; - uint32_t fd3 = (posx2+3) / FW / FH % FD; - uint32_t fh3 = (posx2+3) / FW % FH; - uint32_t fw3 = (posx2+3) % FW; + const uint32_t widthB = IC * FD * FH * FW; + + uint32_t ic0 = (posx2 + 0) / FW / FH / FD; + uint32_t fd0 = (posx2 + 0) / FW / FH % FD; + uint32_t fh0 = (posx2 + 0) / FW % FH; + uint32_t fw0 = (posx2 + 0) % FW; + + uint32_t ic1 = (posx2 + 1) / FW / FH / FD; + uint32_t fd1 = (posx2 + 1) / FW / FH % FD; + uint32_t fh1 = (posx2 + 1) / FW % FH; + uint32_t fw1 = (posx2 + 1) % FW; + + uint32_t ic2 = (posx2 + 2) / FW / FH / FD; + uint32_t fd2 = (posx2 + 2) / FW / FH % FD; + uint32_t fh2 = (posx2 + 2) / FW % FH; + uint32_t fw2 = (posx2 + 2) % FW; + + uint32_t ic3 = (posx2 + 3) / FW / FH / FD; + uint32_t fd3 = (posx2 + 3) / FW / FH % FD; + uint32_t fh3 = (posx2 + 3) / FW % FH; + uint32_t fw3 = (posx2 + 3) % FW; if 
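The posx2-based arithmetic above decodes a flat column index of the implicit im2col matrix into (ic, fd, fh, fw), with FW as the fastest-varying dimension. The same decode written once as a helper, purely for clarity (the kernel keeps it unrolled by hand for four consecutive columns):

#include <cstdint>

__host__ __device__ inline void decode_filter_index(
        uint32_t x, uint32_t FD, uint32_t FH, uint32_t FW,
        uint32_t& ic, uint32_t& fd, uint32_t& fh, uint32_t& fw) {
    fw = x % FW;
    fh = x / FW % FH;
    fd = x / FW / FH % FD;
    ic = x / FW / FH / FD;
}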
(!is_xcorr) { fd0 = FD - fd0 - 1; @@ -174,7 +166,7 @@ __global__ void conv_kernel(BufferFetcher diff, BufferFetcher src, const uint32_t fd1d = with_dilation(fd1, DD); const uint32_t fd2d = with_dilation(fd2, DD); const uint32_t fd3d = with_dilation(fd3, DD); - + const uint32_t fh0d = with_dilation(fh0, DH); const uint32_t fh1d = with_dilation(fh1, DH); const uint32_t fh2d = with_dilation(fh2, DH); @@ -185,69 +177,67 @@ __global__ void conv_kernel(BufferFetcher diff, BufferFetcher src, const uint32_t fw2d = with_dilation(fw2, DW); const uint32_t fw3d = with_dilation(fw3, DW); - const uint32_t fp0 = ic0 * ID*IH*IW + fd0d * IH*IW + fh0d * IW + fw0d; - const uint32_t fp1 = ic1 * ID*IH*IW + fd1d * IH*IW + fh1d * IW + fw1d; - const uint32_t fp2 = ic2 * ID*IH*IW + fd2d * IH*IW + fh2d * IW + fw2d; - const uint32_t fp3 = ic3 * ID*IH*IW + fd3d * IH*IW + fh3d * IW + fw3d; + const uint32_t fp0 = ic0 * ID * IH * IW + fd0d * IH * IW + fh0d * IW + fw0d; + const uint32_t fp1 = ic1 * ID * IH * IW + fd1d * IH * IW + fh1d * IW + fw1d; + const uint32_t fp2 = ic2 * ID * IH * IW + fd2d * IH * IW + fh2d * IW + fw2d; + const uint32_t fp3 = ic3 * ID * IH * IW + fd3d * IH * IW + fh3d * IW + fw3d; - const uint32_t OP = OH*OW; + const uint32_t OP = OH * OW; __shared__ float4 localA[BY][BM]; __shared__ float4 localB[BM][BX]; uint32_t i = 0u; uint32_t offsetA = n * OUT_BS + posy2 * widthA + tidx; - uint32_t offsetB = n * INP_BS - PD*IH*IW - PH*IW - PW; - - float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, - sum1 = {0.0f, 0.0f, 0.0f, 0.0f}, - sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, - sum3 = {0.0f, 0.0f, 0.0f, 0.0f}; - - uint32_t od = tidy / (OW*OH); + uint32_t offsetB = n * INP_BS - PD * IH * IW - PH * IW - PW; + + float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, sum1 = {0.0f, 0.0f, 0.0f, 0.0f}, + sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, sum3 = {0.0f, 0.0f, 0.0f, 0.0f}; + + uint32_t od = tidy / (OW * OH); uint32_t oh = tidy / (OW) % OH; uint32_t ow = tidy % OW; - uint32_t odm = tidy % (OW*OH); + uint32_t odm = tidy % (OW * OH); - const uint32_t ods = BM / (OW*OH); + const uint32_t ods = BM / (OW * OH); const uint32_t ohs = BM / (OW) % OH; const uint32_t ows = BM % OW; - const uint32_t odms = BM % (OW*OH); + const uint32_t odms = BM % (OW * OH); for (; i < widthA; i += BM, offsetA += BM) { // load localA if (tidx < BM) { - localA[tidy][tidx].x = diff.get(offsetA + 0*widthA); - localA[tidy][tidx].y = diff.get(offsetA + 1*widthA); - localA[tidy][tidx].z = diff.get(offsetA + 2*widthA); - localA[tidy][tidx].w = diff.get(offsetA + 3*widthA); + localA[tidy][tidx].x = diff.get(offsetA + 0 * widthA); + localA[tidy][tidx].y = diff.get(offsetA + 1 * widthA); + localA[tidy][tidx].z = diff.get(offsetA + 2 * widthA); + localA[tidy][tidx].w = diff.get(offsetA + 3 * widthA); } if (tidy < BM) { - uint32_t tmp = offsetB + od*SD*IH*IW + oh*SH*IW + ow*SW, - ok = bool_as_mask(tidy+i < heightB), + uint32_t tmp = offsetB + od * SD * IH * IW + oh * SH * IW + ow * SW, + ok = bool_as_mask(tidy + i < heightB), p0 = bool_as_mask( - fd0d+od*SD >= PD && fd0d+od*SD < ID+PD && - fh0d+oh*SH >= PH && fh0d+oh*SH < IH+PH && - fw0d+ow*SW >= PW && fw0d+ow*SW < IW+PW), + fd0d + od * SD >= PD && fd0d + od * SD < ID + PD && + fh0d + oh * SH >= PH && fh0d + oh * SH < IH + PH && + fw0d + ow * SW >= PW && fw0d + ow * SW < IW + PW), p1 = bool_as_mask( - fd1d+od*SD >= PD && fd1d+od*SD < ID+PD && - fh1d+oh*SH >= PH && fh1d+oh*SH < IH+PH && - fw1d+ow*SW >= PW && fw1d+ow*SW < IW+PW), + fd1d + od * SD >= PD && fd1d + od * SD < ID + PD && + fh1d + oh * SH >= PH && fh1d + oh * SH < IH + PH && + fw1d + ow 
* SW >= PW && fw1d + ow * SW < IW + PW), p2 = bool_as_mask( - fd2d+od*SD >= PD && fd2d+od*SD < ID+PD && - fh2d+oh*SH >= PH && fh2d+oh*SH < IH+PH && - fw2d+ow*SW >= PW && fw2d+ow*SW < IW+PW), + fd2d + od * SD >= PD && fd2d + od * SD < ID + PD && + fh2d + oh * SH >= PH && fh2d + oh * SH < IH + PH && + fw2d + ow * SW >= PW && fw2d + ow * SW < IW + PW), p3 = bool_as_mask( - fd3d+od*SD >= PD && fd3d+od*SD < ID+PD && - fh3d+oh*SH >= PH && fh3d+oh*SH < IH+PH && - fw3d+ow*SW >= PW && fw3d+ow*SW < IW+PW); - - localB[tidy][tidx].x = visit_with_mask(src, tmp+fp0, ok & p0); - localB[tidy][tidx].y = visit_with_mask(src, tmp+fp1, ok & p1); - localB[tidy][tidx].z = visit_with_mask(src, tmp+fp2, ok & p2); - localB[tidy][tidx].w = visit_with_mask(src, tmp+fp3, ok & p3); + fd3d + od * SD >= PD && fd3d + od * SD < ID + PD && + fh3d + oh * SH >= PH && fh3d + oh * SH < IH + PH && + fw3d + ow * SW >= PW && fw3d + ow * SW < IW + PW); + + localB[tidy][tidx].x = visit_with_mask(src, tmp + fp0, ok & p0); + localB[tidy][tidx].y = visit_with_mask(src, tmp + fp1, ok & p1); + localB[tidy][tidx].z = visit_with_mask(src, tmp + fp2, ok & p2); + localB[tidy][tidx].w = visit_with_mask(src, tmp + fp3, ok & p3); } - __syncthreads(); + __syncthreads(); for (uint32_t j = 0u; j < BM; ++j) { float4 tmpA = localA[tidy][j]; float4 tmpB = localB[j][tidx]; @@ -267,7 +257,6 @@ __global__ void conv_kernel(BufferFetcher diff, BufferFetcher src, sum3.y += tmpA.w * tmpB.y; sum3.z += tmpA.w * tmpB.z; sum3.w += tmpA.w * tmpB.w; - } oh += ohs; ow += ows; @@ -281,58 +270,69 @@ __global__ void conv_kernel(BufferFetcher diff, BufferFetcher src, odm -= (odm >= OP) * OP; __syncthreads(); } - + // widthB == IC*FD*FH*FW, heightA == OC const uint32_t grad_idx = posy2 * widthB + posx2; - bool y0 = (posy2+0 < heightA); - bool y1 = (posy2+1 < heightA); - bool y2 = (posy2+2 < heightA); - bool y3 = (posy2+3 < heightA); - bool x0 = (posx2+0 < widthB); - bool x1 = (posx2+1 < widthB); - bool x2 = (posx2+2 < widthB); - bool x3 = (posx2+3 < widthB); + bool y0 = (posy2 + 0 < heightA); + bool y1 = (posy2 + 1 < heightA); + bool y2 = (posy2 + 2 < heightA); + bool y3 = (posy2 + 3 < heightA); + bool x0 = (posx2 + 0 < widthB); + bool x1 = (posx2 + 1 < widthB); + bool x2 = (posx2 + 2 < widthB); + bool x3 = (posx2 + 3 < widthB); if (y0) { - if (x0) atomicAdd(&grad[grad_idx + 0*widthB + 0], sum0.x); - if (x1) atomicAdd(&grad[grad_idx + 0*widthB + 1], sum0.y); - if (x2) atomicAdd(&grad[grad_idx + 0*widthB + 2], sum0.z); - if (x3) atomicAdd(&grad[grad_idx + 0*widthB + 3], sum0.w); + if (x0) + atomicAdd(&grad[grad_idx + 0 * widthB + 0], sum0.x); + if (x1) + atomicAdd(&grad[grad_idx + 0 * widthB + 1], sum0.y); + if (x2) + atomicAdd(&grad[grad_idx + 0 * widthB + 2], sum0.z); + if (x3) + atomicAdd(&grad[grad_idx + 0 * widthB + 3], sum0.w); } if (y1) { - if (x0) atomicAdd(&grad[grad_idx + 1*widthB + 0], sum1.x); - if (x1) atomicAdd(&grad[grad_idx + 1*widthB + 1], sum1.y); - if (x2) atomicAdd(&grad[grad_idx + 1*widthB + 2], sum1.z); - if (x3) atomicAdd(&grad[grad_idx + 1*widthB + 3], sum1.w); + if (x0) + atomicAdd(&grad[grad_idx + 1 * widthB + 0], sum1.x); + if (x1) + atomicAdd(&grad[grad_idx + 1 * widthB + 1], sum1.y); + if (x2) + atomicAdd(&grad[grad_idx + 1 * widthB + 2], sum1.z); + if (x3) + atomicAdd(&grad[grad_idx + 1 * widthB + 3], sum1.w); } if (y2) { - if (x0) atomicAdd(&grad[grad_idx + 2*widthB + 0], sum2.x); - if (x1) atomicAdd(&grad[grad_idx + 2*widthB + 1], sum2.y); - if (x2) atomicAdd(&grad[grad_idx + 2*widthB + 2], sum2.z); - if (x3) atomicAdd(&grad[grad_idx + 
2*widthB + 3], sum2.w); - } + if (x0) + atomicAdd(&grad[grad_idx + 2 * widthB + 0], sum2.x); + if (x1) + atomicAdd(&grad[grad_idx + 2 * widthB + 1], sum2.y); + if (x2) + atomicAdd(&grad[grad_idx + 2 * widthB + 2], sum2.z); + if (x3) + atomicAdd(&grad[grad_idx + 2 * widthB + 3], sum2.w); + } if (y3) { - if (x0) atomicAdd(&grad[grad_idx + 3*widthB + 0], sum3.x); - if (x1) atomicAdd(&grad[grad_idx + 3*widthB + 1], sum3.y); - if (x2) atomicAdd(&grad[grad_idx + 3*widthB + 2], sum3.z); - if (x3) atomicAdd(&grad[grad_idx + 3*widthB + 3], sum3.w); + if (x0) + atomicAdd(&grad[grad_idx + 3 * widthB + 0], sum3.x); + if (x1) + atomicAdd(&grad[grad_idx + 3 * widthB + 1], sum3.y); + if (x2) + atomicAdd(&grad[grad_idx + 3 * widthB + 2], sum3.z); + if (x3) + atomicAdd(&grad[grad_idx + 3 * widthB + 3], sum3.w); } } -} // anonymous namespace +} // anonymous namespace void convolution3d::exec_inplace_matmul_bwd_filter( - const float *diff, const float *src, float *grad, - size_t N, size_t INP_BS, size_t OUT_BS, - size_t IC, size_t ID, size_t IH, size_t IW, - size_t OC, size_t OD, size_t OH, size_t OW, - size_t FD, size_t FH, size_t FW, - size_t PD, size_t PH, size_t PW, - size_t SD, size_t SH, size_t SW, - size_t DD, size_t DH, size_t DW, - bool is_xcorr, - cudaStream_t stream) { - BufferFetcherTextureHost diff_tex(const_cast(diff), OC*OD*OH*OW*N), - src_tex(const_cast(src), N * INP_BS); + const float* diff, const float* src, float* grad, size_t N, size_t INP_BS, + size_t OUT_BS, size_t IC, size_t ID, size_t IH, size_t IW, size_t OC, size_t OD, + size_t OH, size_t OW, size_t FD, size_t FH, size_t FW, size_t PD, size_t PH, + size_t PW, size_t SD, size_t SH, size_t SW, size_t DD, size_t DH, size_t DW, + bool is_xcorr, cudaStream_t stream) { + BufferFetcherTextureHost diff_tex(const_cast(diff), OC * OD * OH * OW * N), + src_tex(const_cast(src), N * INP_BS); BufferFetcherRaw diff_buf, src_buf; src_buf.ptr = src; diff_buf.ptr = diff; @@ -341,74 +341,86 @@ void convolution3d::exec_inplace_matmul_bwd_filter( diff_tex.reset(); } int m = OC; - int n = IC*FD*FH*FW; + int n = IC * FD * FH * FW; int BY = 1; int BX = 1; if (m <= 64) { - while (BY < 16 && (BY<<2) < m) BY <<= 1; + while (BY < 16 && (BY << 2) < m) + BY <<= 1; BX = 256 / BY; } else if (n <= 64) { - while (BX < 16 && (BX<<2) < n) BX <<= 1; + while (BX < 16 && (BX << 2) < n) + BX <<= 1; BY = 256 / BX; } else { BX = BY = 16; } cudaMemset(grad, 0, OC * IC * FD * FH * FW * sizeof(float)); - dim3 blocks(DIVUP(n, 4*BX), DIVUP(m, 4*BY), N); + dim3 blocks(DIVUP(n, 4 * BX), DIVUP(m, 4 * BY), N); dim3 threads(BX, BY); -#define DISPATCH_BX_BY(BX, BY) do { \ - if (diff_tex.init_succ) { \ - KernelPtr::type kptr; \ - if (is_xcorr) { \ - kptr = conv_kernel; \ - } else { \ - kptr = conv_kernel; \ - } \ - kptr<<>>( \ - diff_tex.val, src_tex.val, grad, \ - N, INP_BS, OUT_BS, \ - IC, ID, IH, IW, \ - OC, OD, OH, OW, \ - FD, FH, FW, \ - SD, SH, SW, \ - PD, PH, PW, \ - DD, DH, DW); \ - } else { \ - KernelPtr::type kptr; \ - if (is_xcorr) { \ - kptr = conv_kernel; \ - } else { \ - kptr = conv_kernel; \ - } \ - kptr<<>>( \ - diff_buf, src_buf, grad, \ - N, INP_BS, OUT_BS, \ - IC, ID, IH, IW, \ - OC, OD, OH, OW, \ - FD, FH, FW, \ - SD, SH, SW, \ - PD, PH, PW, \ - DD, DH, DW); \ - } \ -} while (0) -#define DISPATCH_BX(BX) do { \ - DISPATCH_BX_BY(BX, 256/BX); \ -} while (0) -#define DISPATCH() do { \ - switch (BX) { \ - case 1: DISPATCH_BX(1); break; \ - case 2: DISPATCH_BX(2); break; \ - case 4: DISPATCH_BX(4); break; \ - case 8: DISPATCH_BX(8); break; \ - case 16: DISPATCH_BX(16); 
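Context for the cudaMemset/atomicAdd pair in this kernel: each (BX, BY) thread block owns a tile of the OC x (IC*FD*FH*FW) gradient matrix, but the batch index lives in blockIdx.z, so per-sample partial sums from different blocks land on the same filter elements and must be accumulated in global memory. Zeroing the buffer first and adding with atomicAdd is the standard recipe; in miniature (hypothetical kernel, float accumulation only):

#include <cuda_runtime.h>

// Each y-block holds the partial result of one sample; all of them fold into grad.
__global__ void accumulate_partials(const float* partial, float* grad, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n)
        atomicAdd(&grad[idx], partial[blockIdx.y * n + idx]);
}

// Host side, exactly like the cudaMemset(grad, 0, ...) above:
//   cudaMemset(grad, 0, n * sizeof(float));
//   accumulate_partials<<<dim3(blocks_x, num_samples), threads>>>(partial, grad, n);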
break; \ - case 32: DISPATCH_BX(32); break; \ - case 64: DISPATCH_BX(64); break; \ - case 128: DISPATCH_BX(128); break; \ - case 256: DISPATCH_BX(256); break; \ - default: \ - report_error("no usable kernel"); \ - } \ -} while (0) +#define DISPATCH_BX_BY(BX, BY) \ + do { \ + if (diff_tex.init_succ) { \ + KernelPtr::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel; \ + } else { \ + kptr = conv_kernel; \ + } \ + kptr<<>>( \ + diff_tex.val, src_tex.val, grad, N, INP_BS, OUT_BS, IC, ID, IH, \ + IW, OC, OD, OH, OW, FD, FH, FW, SD, SH, SW, PD, PH, PW, DD, DH, \ + DW); \ + } else { \ + KernelPtr::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel; \ + } else { \ + kptr = conv_kernel; \ + } \ + kptr<<>>( \ + diff_buf, src_buf, grad, N, INP_BS, OUT_BS, IC, ID, IH, IW, OC, \ + OD, OH, OW, FD, FH, FW, SD, SH, SW, PD, PH, PW, DD, DH, DW); \ + } \ + } while (0) +#define DISPATCH_BX(BX) \ + do { \ + DISPATCH_BX_BY(BX, 256 / BX); \ + } while (0) +#define DISPATCH() \ + do { \ + switch (BX) { \ + case 1: \ + DISPATCH_BX(1); \ + break; \ + case 2: \ + DISPATCH_BX(2); \ + break; \ + case 4: \ + DISPATCH_BX(4); \ + break; \ + case 8: \ + DISPATCH_BX(8); \ + break; \ + case 16: \ + DISPATCH_BX(16); \ + break; \ + case 32: \ + DISPATCH_BX(32); \ + break; \ + case 64: \ + DISPATCH_BX(64); \ + break; \ + case 128: \ + DISPATCH_BX(128); \ + break; \ + case 256: \ + DISPATCH_BX(256); \ + break; \ + default: \ + report_error("no usable kernel"); \ + } \ + } while (0) DISPATCH(); #undef DISPATCH #undef DISPATCH_BX @@ -417,4 +429,3 @@ void convolution3d::exec_inplace_matmul_bwd_filter( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cuh b/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cuh index 97038206..9bdccdd1 100644 --- a/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cuh +++ b/dnn/src/cuda/convolution3d/backward_filter/inplace_matmul_impl.cuh @@ -11,27 +11,22 @@ #pragma once #include -#include #include +#include namespace megdnn { namespace cuda { namespace convolution3d { void exec_inplace_matmul_bwd_filter( - const float *diff, const float *src, float *grad, - size_t N, size_t INP_BS, size_t OUT_BS, - size_t IC, size_t ID, size_t IH, size_t IW, - size_t OC, size_t OD, size_t OH, size_t OW, - size_t FD, size_t FH, size_t FW, - size_t PD, size_t PH, size_t PW, - size_t SD, size_t SH, size_t SW, - size_t DD, size_t DH, size_t DW, - bool is_xcorr, - cudaStream_t stream); + const float* diff, const float* src, float* grad, size_t N, size_t INP_BS, + size_t OUT_BS, size_t IC, size_t ID, size_t IH, size_t IW, size_t OC, size_t OD, + size_t OH, size_t OW, size_t FD, size_t FH, size_t FW, size_t PD, size_t PH, + size_t PW, size_t SD, size_t SH, size_t SW, size_t DD, size_t DH, size_t DW, + bool is_xcorr, cudaStream_t stream); -} // namespace convolution -} // namespace cuda -} // namespace megdnn +} // namespace convolution3d +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/chanwise/bwd_data.cu b/dnn/src/cuda/convolution3d/chanwise/bwd_data.cu index f5f6d37b..7fcbf3b6 100644 --- a/dnn/src/cuda/convolution3d/chanwise/bwd_data.cu +++ b/dnn/src/cuda/convolution3d/chanwise/bwd_data.cu @@ -19,34 +19,26 @@ using namespace chanwise; namespace { -template +template < + typename T, int CHL_MUL_SET, int FD_SET, int FH_SET, int FW_SET, int SD_SET, + int SH_SET, int SW_SET> __global__ void kern_bwd_data( - T *src_grad, const T *dst_grad, const T *flt_tot, Param 
param) { - + T* src_grad, const T* dst_grad, const T* flt_tot, Param param) { extern __shared__ uint8_t flt_storage[]; - T * const flt = reinterpret_cast(flt_storage); - - const uint32_t - N = param.batch, IC = param.src_chl, ic = blockIdx.x, - ID = param.src_d, IH = param.src_h, IW = param.src_w, - CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul, - FD = FD_SET ? FD_SET : param.flt_d, - FH = FH_SET ? FH_SET : param.flt_h, - FW = FW_SET ? FW_SET : param.flt_w, - FSIZE = FD * FH * FW, - PD = param.pad_d, - PH = param.pad_h, - PW = param.pad_w, - SD = SD_SET ? SD_SET : param.stride_d, - SH = SH_SET ? SH_SET : param.stride_h, - SW = SW_SET ? SW_SET : param.stride_w, - OD = param.out_d, - OH = param.out_h, - OW = param.out_w, - TOT_OUT = N * ID * IH * IW; + T* const flt = reinterpret_cast(flt_storage); + + const uint32_t N = param.batch, IC = param.src_chl, ic = blockIdx.x, + ID = param.src_d, IH = param.src_h, IW = param.src_w, + CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul, + FD = FD_SET ? FD_SET : param.flt_d, + FH = FH_SET ? FH_SET : param.flt_h, + FW = FW_SET ? FW_SET : param.flt_w, FSIZE = FD * FH * FW, + PD = param.pad_d, PH = param.pad_h, PW = param.pad_w, + SD = SD_SET ? SD_SET : param.stride_d, + SH = SH_SET ? SH_SET : param.stride_h, + SW = SW_SET ? SW_SET : param.stride_w, OD = param.out_d, + OH = param.out_h, OW = param.out_w, TOT_OUT = N * ID * IH * IW; block_memcpy(flt, flt_tot + ic * FSIZE * CHL_MUL, FSIZE * CHL_MUL); dst_grad += ic * CHL_MUL * OD * OH * OW; @@ -60,7 +52,7 @@ __global__ void kern_bwd_data( out_idx = div_mod(out_idx, IH, ih); out_idx = div_mod(out_idx, ID, id); n = out_idx; - const T *dst_grad_base = dst_grad + n * (IC * CHL_MUL * OD * OH * OW); + const T* dst_grad_base = dst_grad + n * (IC * CHL_MUL * OD * OH * OW); T sum(0); @@ -70,30 +62,28 @@ __global__ void kern_bwd_data( odmax = min((id + PD) / SD, OD - 1), ohmax = min((ih + PH) / SH, OH - 1), owmax = min((iw + PW) / SW, OW - 1); - if (SD_SET == 1 && SH_SET == 1 && SW_SET == 1 && - FD_SET && FH_SET && FW_SET) { + if (SD_SET == 1 && SH_SET == 1 && SW_SET == 1 && FD_SET && FH_SET && FW_SET) { #pragma unroll - for (uint32_t dod = 0; dod < FD; ++ dod) { + for (uint32_t dod = 0; dod < FD; ++dod) { uint32_t od = odmin + dod; if (od <= odmax) { uint32_t fd = id - od * SD + PD; #pragma unroll - for (uint32_t doh = 0; doh < FH; ++ doh) { + for (uint32_t doh = 0; doh < FH; ++doh) { uint32_t oh = ohmin + doh; if (oh <= ohmax) { uint32_t fh = ih - oh * SH + PH; #pragma unroll - for (uint32_t dow = 0; dow < FW; ++ dow) { + for (uint32_t dow = 0; dow < FW; ++dow) { uint32_t ow = owmin + dow; if (ow <= owmax) { uint32_t fw = iw - ow * SW + PW; - const T *pd = dst_grad_base + - od * OH * OW + oh * OW + ow; - const T *pf = flt + - fd * FH * FW + fh * FW + fw; + const T* pd = + dst_grad_base + od * OH * OW + oh * OW + ow; + const T* pf = flt + fd * FH * FW + fh * FW + fw; #pragma unroll for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; - ++ chl_mul) { + ++chl_mul) { sum += *pd * *pf; pd += OD * OH * OW; pf += FSIZE; @@ -102,21 +92,19 @@ __global__ void kern_bwd_data( } } } - } + } } } else { - for (uint32_t od = odmin; od <= odmax; ++ od) { + for (uint32_t od = odmin; od <= odmax; ++od) { uint32_t fd = id - od * SD + PD; - for (uint32_t oh = ohmin; oh <= ohmax; ++ oh) { + for (uint32_t oh = ohmin; oh <= ohmax; ++oh) { uint32_t fh = ih - oh * SH + PH; - for (uint32_t ow = owmin; ow <= owmax; ++ ow) { + for (uint32_t ow = owmin; ow <= owmax; ++ow) { uint32_t fw = iw - ow * SW + PW; - const T *pd = dst_grad_base + - od * 
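A note on the window bounds used by this backward-data kernel: with dilation pinned to 1 by is_available, an input element at depth id only receives contributions from outputs od whose filter tap fd = id - od * SD + PD lies in [0, FD). Solving the two inequalities gives the bounds seen here (odmax matches the code; odmin is outside the visible context but presumably follows the same derivation):

//   fd >= 0      <=>  od <= (id + PD) / SD
//                       -> odmax = min((id + PD) / SD, OD - 1)
//   fd <= FD - 1 <=>  od >= ceil((id + PD - FD + 1) / SD)
//                       -> odmin = (max(int(id + PD - FD + 1), 0) + SD - 1) / SD
// The same bounds are applied independently along H (oh) and W (ow).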
OH * OW + oh * OW + ow; - const T *pf = flt + - fd * FH * FW + fh * FW + fw; + const T* pd = dst_grad_base + od * OH * OW + oh * OW + ow; + const T* pf = flt + fd * FH * FW + fh * FW + fw; #pragma unroll - for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++ chl_mul) { + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { sum += *pd * *pf; pd += OD * OH * OW; pf += FSIZE; @@ -125,70 +113,64 @@ __global__ void kern_bwd_data( } } } - src_grad[n * IC * ID * IH * IW + - id * IH * IW + ih * IW + iw] = sum; + src_grad[n * IC * ID * IH * IW + id * IH * IW + ih * IW + iw] = sum; } } -template +template class KernDispatch { - public: - typedef void (*kern_ptr_t)(T*, const T*, const T*, Param); - - static kern_ptr_t dispatch( - int chl_mul, - int fd, int fh, int fw, - int sd, int sh, int sw) { - if (chl_mul == 1) { - if (fd == 2 && fh == 2 && fw == 2) - return d1<1, 2, 2, 2>(sd, sh, sw); - if (fd == 3 && fh == 3 && fw == 3) - return d1<1, 3, 3, 3>(sd, sh, sw); - } - return d1<0, 0, 0, 0>(sd, sh, sw); +public: + typedef void (*kern_ptr_t)(T*, const T*, const T*, Param); + + static kern_ptr_t dispatch( + int chl_mul, int fd, int fh, int fw, int sd, int sh, int sw) { + if (chl_mul == 1) { + if (fd == 2 && fh == 2 && fw == 2) + return d1<1, 2, 2, 2>(sd, sh, sw); + if (fd == 3 && fh == 3 && fw == 3) + return d1<1, 3, 3, 3>(sd, sh, sw); } + return d1<0, 0, 0, 0>(sd, sh, sw); + } - private: - template - static kern_ptr_t d1(int sd, int sh, int sw) { - if (sd == 1 && sh == 1 && sw == 1) - return kern_bwd_data; - if (sd == 1 && sh == 1 && sw == 2) - return kern_bwd_data; - if (sd == 1 && sh == 2 && sw == 1) - return kern_bwd_data; - if (sd == 1 && sh == 2 && sw == 2) - return kern_bwd_data; - if (sd == 2 && sh == 1 && sw == 1) - return kern_bwd_data; - if (sd == 2 && sh == 1 && sw == 2) - return kern_bwd_data; - if (sd == 2 && sh == 2 && sw == 1) - return kern_bwd_data; - if (sd == 2 && sh == 2 && sw == 2) - return kern_bwd_data; - return kern_bwd_data; - } +private: + template + static kern_ptr_t d1(int sd, int sh, int sw) { + if (sd == 1 && sh == 1 && sw == 1) + return kern_bwd_data; + if (sd == 1 && sh == 1 && sw == 2) + return kern_bwd_data; + if (sd == 1 && sh == 2 && sw == 1) + return kern_bwd_data; + if (sd == 1 && sh == 2 && sw == 2) + return kern_bwd_data; + if (sd == 2 && sh == 1 && sw == 1) + return kern_bwd_data; + if (sd == 2 && sh == 1 && sw == 2) + return kern_bwd_data; + if (sd == 2 && sh == 2 && sw == 1) + return kern_bwd_data; + if (sd == 2 && sh == 2 && sw == 2) + return kern_bwd_data; + return kern_bwd_data; + } }; -} // anonymous namespace +} // anonymous namespace -template -void chanwise::run_bwd_data(T *src_grad, const T *dst_grad, const T *flt, - const Param ¶m, cudaStream_t stream) { +template +void chanwise::run_bwd_data( + T* src_grad, const T* dst_grad, const T* flt, const Param& param, + cudaStream_t stream) { typename KernDispatch::kern_ptr_t kern = KernDispatch::dispatch( - param.chl_mul, - param.flt_d, param.flt_h, param.flt_w, - param.stride_d, param.stride_h, param.stride_w); + param.chl_mul, param.flt_d, param.flt_h, param.flt_w, param.stride_d, + param.stride_h, param.stride_w); int nr_thread = query_blocksize_for_kernel(kern), nr_out_dimx = param.src_d * param.src_h * param.src_w * param.batch; - dim3 nr_block( - param.src_chl, - std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); - uint32_t shared = param.chl_mul * param.flt_d * - param.flt_h * param.flt_w * sizeof(T); - kern <<< nr_block, nr_thread, shared, stream >>> ( - src_grad, dst_grad, flt, param); + dim3 
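The KernDispatch/d1 ladder is a runtime-to-compile-time bridge: common (chl_mul, FD/FH/FW, SD/SH/SW) combinations get fully specialized instantiations of kern_bwd_data, so the #pragma unroll loops see constant trip counts, and everything else falls back to the generic <0, 0, 0, 0> instantiation that reads the sizes from Param at run time. The same idiom in self-contained form (toy kernel, invented names):

#include <cstdio>

// FW_SET == 0 means "read the extent from the runtime argument".
template <int FW_SET>
__global__ void toy_kernel(const float* in, float* out, int fw_runtime) {
    const int FW = FW_SET ? FW_SET : fw_runtime;
    float s = 0.f;
#pragma unroll
    for (int i = 0; i < FW; ++i)  // fully unrolled only when FW_SET != 0
        s += in[threadIdx.x * FW + i];
    out[threadIdx.x] = s;
}

using kern_ptr_t = void (*)(const float*, float*, int);

kern_ptr_t dispatch(int fw) {
    switch (fw) {
        case 2: return toy_kernel<2>;
        case 3: return toy_kernel<3>;
        default: return toy_kernel<0>;  // generic fallback
    }
}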
nr_block(param.src_chl, std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); + uint32_t shared = + param.chl_mul * param.flt_d * param.flt_h * param.flt_w * sizeof(T); + kern<<>>(src_grad, dst_grad, flt, param); after_kernel_launch(); } @@ -197,8 +179,9 @@ namespace cuda { namespace convolution3d { namespace chanwise { -#define DO_INST(_ct) template void run_bwd_data( \ - _ct*, const _ct*, const _ct*, const Param&, cudaStream_t); +#define DO_INST(_ct) \ + template void run_bwd_data( \ + _ct*, const _ct*, const _ct*, const Param&, cudaStream_t); #define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) @@ -206,10 +189,9 @@ MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) #undef INST #undef DO_INST -} // namespace chanwise -} // namespace convolution3d -} // namespace cuda -} // namespace megdnn +} // namespace chanwise +} // namespace convolution3d +} // namespace cuda +} // namespace megdnn // vim: syntax=cuda.doxygen - diff --git a/dnn/src/cuda/convolution3d/chanwise/bwd_filter.cu b/dnn/src/cuda/convolution3d/chanwise/bwd_filter.cu index be9e5ab5..16918082 100644 --- a/dnn/src/cuda/convolution3d/chanwise/bwd_filter.cu +++ b/dnn/src/cuda/convolution3d/chanwise/bwd_filter.cu @@ -21,23 +21,19 @@ using namespace chanwise; namespace { -template +template __global__ void kern_bwd_filter( - T *flt_grad, const T *src, const T *dst_grad, Param param) { - - const uint32_t - N = param.batch, IC = param.src_chl, - ID = param.src_d, IH = param.src_h, IW = param.src_w, - CHL_MUL = param.chl_mul, - FD = param.flt_d, FH = param.flt_h, FW = param.flt_w, - PD = param.pad_d, PH = param.pad_h, PW = param.pad_w, - SD = param.stride_d, SH = param.stride_h, SW = param.stride_w, - OD = param.out_d, OH = param.out_h, OW = param.out_w, - SRC_BATCH_STRIDE = IC * ID * IH * IW, - DST_BATCH_STRIDE = IC * CHL_MUL * OD * OH * OW, - BLKDIM_X = blockDim.x / nr_thpf, - THREADID_X = threadIdx.x / nr_thpf, - OUT_IDX = blockIdx.x * BLKDIM_X + THREADID_X; + T* flt_grad, const T* src, const T* dst_grad, Param param) { + const uint32_t N = param.batch, IC = param.src_chl, ID = param.src_d, + IH = param.src_h, IW = param.src_w, CHL_MUL = param.chl_mul, + FD = param.flt_d, FH = param.flt_h, FW = param.flt_w, + PD = param.pad_d, PH = param.pad_h, PW = param.pad_w, + SD = param.stride_d, SH = param.stride_h, SW = param.stride_w, + OD = param.out_d, OH = param.out_h, OW = param.out_w, + SRC_BATCH_STRIDE = IC * ID * IH * IW, + DST_BATCH_STRIDE = IC * CHL_MUL * OD * OH * OW, + BLKDIM_X = blockDim.x / nr_thpf, THREADID_X = threadIdx.x / nr_thpf, + OUT_IDX = blockIdx.x * BLKDIM_X + THREADID_X; uint32_t ic, chl_mul, fd, fh, fw; { @@ -54,21 +50,18 @@ __global__ void kern_bwd_filter( src += ic * ID * IH * IW; dst_grad += (ic * CHL_MUL + chl_mul) * OD * OH * OW; - const uint32_t - od_lo = max(int32_t(PD - fd + SD - 1), 0) / SD, - od_hi = min((ID - 1 + PD - fd) / SD + 1, OD), - oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, - oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), - ow_lo = max(int32_t(PW - fw + SW - 1), 0) / SW, - ow_hi = min((IW - 1 + PW - fw) / SW + 1, OW), - oblk_d = od_hi - od_lo, - oblk_h = oh_hi - oh_lo, - oblk_w = ow_hi - ow_lo, - oblk_tot = oblk_d * oblk_h * oblk_w * ((N + BATCH_UNROLL - 1) / BATCH_UNROLL), - tid = threadIdx.x % nr_thpf; - - if (ID + PD < fd + 1 || od_lo >= od_hi || - IH + PH < fh + 1 || oh_lo >= oh_hi || + const uint32_t od_lo = max(int32_t(PD - fd + SD - 1), 0) / SD, + od_hi = min((ID - 1 + PD - fd) / SD + 1, OD), + oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, + oh_hi = 
min((IH - 1 + PH - fh) / SH + 1, OH), + ow_lo = max(int32_t(PW - fw + SW - 1), 0) / SW, + ow_hi = min((IW - 1 + PW - fw) / SW + 1, OW), oblk_d = od_hi - od_lo, + oblk_h = oh_hi - oh_lo, oblk_w = ow_hi - ow_lo, + oblk_tot = oblk_d * oblk_h * oblk_w * + ((N + BATCH_UNROLL - 1) / BATCH_UNROLL), + tid = threadIdx.x % nr_thpf; + + if (ID + PD < fd + 1 || od_lo >= od_hi || IH + PH < fh + 1 || oh_lo >= oh_hi || IW + PW < fw + 1 || ow_lo >= ow_hi) { if (!tid) flt_grad[OUT_IDX] = 0; @@ -78,17 +71,16 @@ __global__ void kern_bwd_filter( T sum(0); for (uint32_t oblk_idx = tid; oblk_idx < oblk_tot; oblk_idx += nr_thpf) { uint32_t n, oh, ow, od; - n = div_mod(div_mod(div_mod(oblk_idx, oblk_w, ow), oblk_h, oh), oblk_d, od) * BATCH_UNROLL; + n = div_mod(div_mod(div_mod(oblk_idx, oblk_w, ow), oblk_h, oh), oblk_d, od) * + BATCH_UNROLL; od += od_lo; oh += oh_lo; ow += ow_lo; - uint32_t id = od * SD - PD + fd, - ih = oh * SH - PH + fh, - iw = ow * SW - PW + fw, + uint32_t id = od * SD - PD + fd, ih = oh * SH - PH + fh, iw = ow * SW - PW + fw, soff = id * IH * IW + ih * IW + iw + n * SRC_BATCH_STRIDE, doff = od * OH * OW + oh * OW + ow + n * DST_BATCH_STRIDE; #pragma unroll - for (uint32_t i = 0; i < BATCH_UNROLL; ++ i) { + for (uint32_t i = 0; i < BATCH_UNROLL; ++i) { if (!i || n + i < N) { sum += src[soff] * dst_grad[doff]; } @@ -102,7 +94,7 @@ __global__ void kern_bwd_filter( } else { // reduce all sums in a block extern __shared__ uint8_t shared_storage[]; - volatile T *thread_sum = reinterpret_cast(shared_storage); + volatile T* thread_sum = reinterpret_cast(shared_storage); thread_sum += THREADID_X * nr_thpf; thread_sum[tid] = sum; #pragma unroll @@ -111,8 +103,7 @@ __global__ void kern_bwd_filter( if (i >= WARP_SIZE) { __syncthreads(); } - T v0 = thread_sum[tid], - v1 = v0 + thread_sum[tid + i]; + T v0 = thread_sum[tid], v1 = v0 + thread_sum[tid + i]; thread_sum[tid] = cond ? 
v1 : v0; } @@ -121,59 +112,55 @@ __global__ void kern_bwd_filter( } } -} // anonymous namespace +} // anonymous namespace -template +template void convolution3d::chanwise::run_bwd_filter( - T *filter_grad, const T *src, const T *dst_grad, - const Param ¶m, cudaStream_t stream) { + T* filter_grad, const T* src, const T* dst_grad, const Param& param, + cudaStream_t stream) { void (*kern)(T*, const T*, const T*, Param) = NULL; - uint32_t - nr_thread = query_blocksize_for_kernel(kern_bwd_filter), - nr_thpf = std::min(nr_thread, - std::max( - 1, - param.out_d * param.out_h * param.out_w * param.batch / - (BATCH_UNROLL * 16))); + uint32_t nr_thread = query_blocksize_for_kernel(kern_bwd_filter), + nr_thpf = std::min( + nr_thread, std::max( + 1, param.out_d * param.out_h * param.out_w * + param.batch / (BATCH_UNROLL * 16))); // find nearest power-of-2 of nr_thpf do { -#define CK(_n) \ - if(nr_thpf >= _n) { \ - kern = kern_bwd_filter; \ - nr_thpf = _n; \ - break; \ - } - CK(1<<10); - CK(1<<9); - CK(1<<8); - CK(1<<7); - CK(1<<6); - CK(1<<5); - CK(1<<4); - CK(1<<3); - CK(1<<2); - CK(1<<1); - CK(1<<0); +#define CK(_n) \ + if (nr_thpf >= _n) { \ + kern = kern_bwd_filter; \ + nr_thpf = _n; \ + break; \ + } + CK(1 << 10); + CK(1 << 9); + CK(1 << 8); + CK(1 << 7); + CK(1 << 6); + CK(1 << 5); + CK(1 << 4); + CK(1 << 3); + CK(1 << 2); + CK(1 << 1); + CK(1 << 0); #undef CK - } while(0); + } while (0); megdnn_assert(kern); nr_thread = query_blocksize_for_kernel(kern); uint32_t nr_flt_per_blk = nr_thread / nr_thpf; while (nr_flt_per_blk * nr_thpf % WARP_SIZE) - -- nr_flt_per_blk; + --nr_flt_per_blk; megdnn_assert(nr_flt_per_blk); int nr_block = DIVUP( - param.flt_d * param.flt_h * param.flt_w * - param.src_chl * param.chl_mul, + param.flt_d * param.flt_h * param.flt_w * param.src_chl * param.chl_mul, nr_flt_per_blk); nr_thread = nr_flt_per_blk * nr_thpf; uint32_t shared = nr_thread * 2 * sizeof(T); - kern <<< nr_block, nr_thread, shared, stream >>> ( - filter_grad, src, dst_grad, param); + kern<<>>(filter_grad, src, dst_grad, param); after_kernel_launch(); } @@ -182,8 +169,9 @@ namespace cuda { namespace convolution3d { namespace chanwise { -#define DO_INST(_ct) template void run_bwd_filter( \ - _ct*, const _ct*, const _ct*, const Param&, cudaStream_t); +#define DO_INST(_ct) \ + template void run_bwd_filter( \ + _ct*, const _ct*, const _ct*, const Param&, cudaStream_t); #define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) @@ -191,11 +179,9 @@ MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) #undef INST #undef DO_INST -} // namespace chanwise -} // namespace convolution3d -} // namespace cuda -} // namespace megdnn - +} // namespace chanwise +} // namespace convolution3d +} // namespace cuda +} // namespace megdnn // vim: syntax=cuda.doxygen - diff --git a/dnn/src/cuda/convolution3d/chanwise/fwd.cu b/dnn/src/cuda/convolution3d/chanwise/fwd.cu index 17b71a92..c61d7b8c 100644 --- a/dnn/src/cuda/convolution3d/chanwise/fwd.cu +++ b/dnn/src/cuda/convolution3d/chanwise/fwd.cu @@ -19,27 +19,23 @@ using namespace chanwise; namespace { -template -__global__ void kern_fwd( - T *dst, const T *src, const T *flt_tot, Param param) { - +template +__global__ void kern_fwd(T* dst, const T* src, const T* flt_tot, Param param) { // extern __shared__ of dt_float16 does not work extern __shared__ uint8_t flt_storage[]; - T * const flt = reinterpret_cast(flt_storage); - - const uint32_t - N = param.batch, IC = param.src_chl, ic = blockIdx.x, - ID = param.src_d, IH = param.src_h, IW = param.src_w, - 
CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul, - FD = FD_SET ? FD_SET : param.flt_d, - FH = FH_SET ? FH_SET : param.flt_h, - FW = FW_SET ? FW_SET : param.flt_w, - FSIZE = FD * FH * FW, - PD = param.pad_d, PH = param.pad_h, PW = param.pad_w, - SD = param.stride_d, SH = param.stride_h, SW = param.stride_w, - OD = param.out_d, OH = param.out_h, OW = param.out_w, - TOT_OUT = N * CHL_MUL * OD * OH * OW; + T* const flt = reinterpret_cast(flt_storage); + + const uint32_t N = param.batch, IC = param.src_chl, ic = blockIdx.x, + ID = param.src_d, IH = param.src_h, IW = param.src_w, + CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul, + FD = FD_SET ? FD_SET : param.flt_d, + FH = FH_SET ? FH_SET : param.flt_h, + FW = FW_SET ? FW_SET : param.flt_w, FSIZE = FD * FH * FW, + PD = param.pad_d, PH = param.pad_h, PW = param.pad_w, + SD = param.stride_d, SH = param.stride_h, SW = param.stride_w, + OD = param.out_d, OH = param.out_h, OW = param.out_w, + TOT_OUT = N * CHL_MUL * OD * OH * OW; block_memcpy(flt, flt_tot + ic * FSIZE * CHL_MUL, FSIZE * CHL_MUL); @@ -58,8 +54,7 @@ __global__ void kern_fwd( n = div_mod(out_idx, CHL_MUL, chl_mul); } - int id = int(od * SD) - int(PD), - ih = int(oh * SH) - int(PH), + int id = int(od * SD) - int(PD), ih = int(oh * SH) - int(PH), iw = int(ow * SW) - int(PW); const T* flt_base = flt + chl_mul * FSIZE; @@ -69,14 +64,14 @@ __global__ void kern_fwd( if (FD_SET && FH_SET && FW_SET) { #pragma unroll - for (uint32_t fd = 0; fd < FD; ++ fd) { + for (uint32_t fd = 0; fd < FD; ++fd) { // fh + ih < 0 would overflow, so we do not need to check it if (static_cast(fd + id) < ID) { #pragma unroll - for (uint32_t fh = 0; fh < FH; ++ fh) { + for (uint32_t fh = 0; fh < FH; ++fh) { if (static_cast(fh + ih) < IH) { #pragma unroll - for(uint32_t fw = 0; fw < FW; ++ fw) { + for (uint32_t fw = 0; fw < FW; ++fw) { if (static_cast(fw + iw) < IW) { sum += flt_base[fd * FH * FW + fh * FW + fw] * src_base[fd * IH * IW + fh * IW + fw]; @@ -87,29 +82,27 @@ __global__ void kern_fwd( } } } else { - int fdmax = min(int(FD), int(ID - id)), - fhmax = min(int(FH), int(IH - ih)), + int fdmax = min(int(FD), int(ID - id)), fhmax = min(int(FH), int(IH - ih)), fwmax = min(int(FW), int(IW - iw)); - for (int fd = max(0, -id); fd < fdmax; ++ fd) { - for (int fh = max(0, -ih); fh < fhmax; ++ fh) { - for (int fw = max(0, -iw); fw < fwmax; ++ fw) { - sum += flt_base[fd * FH * FW + fh * FW + fw] * + for (int fd = max(0, -id); fd < fdmax; ++fd) { + for (int fh = max(0, -ih); fh < fhmax; ++fh) { + for (int fw = max(0, -iw); fw < fwmax; ++fw) { + sum += flt_base[fd * FH * FW + fh * FW + fw] * src_base[fd * IH * IW + fh * IW + fw]; } } } } dst[((((n * IC + ic) * CHL_MUL + chl_mul) * OD + od) * OH + oh) * OW + ow] = - sum; + sum; } } -} // anonymous namespace +} // anonymous namespace -template +template void chanwise::run_fwd( - T *dst, const T *src, const T *flt, const Param ¶m, - cudaStream_t stream) { + T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream) { void (*kern)(T*, const T*, const T*, Param); if (param.chl_mul == 1) { if (param.flt_d == 2 && param.flt_h == 2 && param.flt_w == 2) { @@ -122,15 +115,14 @@ void chanwise::run_fwd( } else { kern = kern_fwd; } - + int nr_thread = query_blocksize_for_kernel(kern), nr_out_dimx = - param.out_d * param.out_h * param.out_w * param.batch * param.chl_mul; - dim3 nr_block( - param.src_chl, - std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); - uint32_t shared = param.chl_mul * param.flt_d * param.flt_h * param.flt_w * sizeof(T); - kern <<< 
nr_block, nr_thread, shared, stream >>> (dst, src, flt, param); + param.out_d * param.out_h * param.out_w * param.batch * param.chl_mul; + dim3 nr_block(param.src_chl, std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); + uint32_t shared = + param.chl_mul * param.flt_d * param.flt_h * param.flt_w * sizeof(T); + kern<<>>(dst, src, flt, param); after_kernel_launch(); } @@ -139,8 +131,8 @@ namespace cuda { namespace convolution3d { namespace chanwise { -#define DO_INST(_ct) template void run_fwd( \ - _ct*, const _ct*, const _ct*, const Param&, cudaStream_t); +#define DO_INST(_ct) \ + template void run_fwd(_ct*, const _ct*, const _ct*, const Param&, cudaStream_t); #define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) @@ -148,10 +140,9 @@ MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) #undef INST #undef DO_INST -} // namespace chanwise -} // namespace convolution3d -} // namespace cuda -} // namespace megdnn +} // namespace chanwise +} // namespace convolution3d +} // namespace cuda +} // namespace megdnn // vim: syntax=cuda.doxygen - diff --git a/dnn/src/cuda/convolution3d/chanwise/kern.cuh b/dnn/src/cuda/convolution3d/chanwise/kern.cuh index 1e186135..d0e56547 100644 --- a/dnn/src/cuda/convolution3d/chanwise/kern.cuh +++ b/dnn/src/cuda/convolution3d/chanwise/kern.cuh @@ -12,8 +12,8 @@ #include "src/cuda/utils.cuh" -#include #include +#include #if MEGDNN_CC_HOST #include "src/cuda/convolution3d/helper.h" @@ -24,60 +24,55 @@ namespace cuda { namespace convolution3d { namespace chanwise { - struct Param { - uint32_t batch, src_chl, - src_d, src_h, src_w, - chl_mul, - flt_d, flt_h, flt_w, - out_d, out_h, out_w, - pad_d, pad_h, pad_w, - stride_d, stride_h, stride_w, - dilation_d, dilation_h, dilation_w; +struct Param { + uint32_t batch, src_chl, src_d, src_h, src_w, chl_mul, flt_d, flt_h, flt_w, out_d, + out_h, out_w, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, dilation_d, + dilation_h, dilation_w; #if MEGDNN_CC_HOST - static Param from_fwd_args(const ForwardSizeArgs &args) { + static Param from_fwd_args(const ForwardSizeArgs& args) { #define U(v) static_cast(v) - auto &&src = args.src_layout->shape; - auto &&dst = args.dst_layout->shape; - auto &&fm = args.filter_meta; - size_t c_pos, hw_pos; - if (fm.format == param::Convolution3D::Format::NCDHW) { - c_pos = 1; - hw_pos = 2; - } else { //NDHWC - c_pos = 4; - hw_pos = 1; - } - return { - U(src[0]), U(src[c_pos]), - U(src[hw_pos]), U(src[hw_pos+1]), U(src[hw_pos+2]), - U(fm.ocpg), - U(fm.spatial[0]), U(fm.spatial[1]), U(fm.spatial[2]), - U(dst[hw_pos]), U(dst[hw_pos+1]), U(dst[hw_pos+2]), - U(fm.padding[0]), U(fm.padding[1]), U(fm.padding[2]), - U(fm.stride[0]), U(fm.stride[1]), U(fm.stride[2]), - U(fm.dilation[0]), U(fm.dilation[1]), U(fm.dilation[2]), - }; -#undef U + auto&& src = args.src_layout->shape; + auto&& dst = args.dst_layout->shape; + auto&& fm = args.filter_meta; + size_t c_pos, hw_pos; + if (fm.format == param::Convolution3D::Format::NCDHW) { + c_pos = 1; + hw_pos = 2; + } else { // NDHWC + c_pos = 4; + hw_pos = 1; } + return { + U(src[0]), U(src[c_pos]), U(src[hw_pos]), + U(src[hw_pos + 1]), U(src[hw_pos + 2]), U(fm.ocpg), + U(fm.spatial[0]), U(fm.spatial[1]), U(fm.spatial[2]), + U(dst[hw_pos]), U(dst[hw_pos + 1]), U(dst[hw_pos + 2]), + U(fm.padding[0]), U(fm.padding[1]), U(fm.padding[2]), + U(fm.stride[0]), U(fm.stride[1]), U(fm.stride[2]), + U(fm.dilation[0]), U(fm.dilation[1]), U(fm.dilation[2]), + }; +#undef U + } #endif - }; +}; - template - void run_fwd(T *dst, const T *src, const T 
*flt, const Param ¶m, - cudaStream_t stream); - - template - void run_bwd_data(T *src_grad, const T *dst_grad, const T *flt, - const Param ¶m, cudaStream_t stream); +template +void run_fwd( + T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream); - template - void run_bwd_filter(T *filter_grad, const T *src, const T *dst_grad, - const Param ¶m, cudaStream_t stream); +template +void run_bwd_data( + T* src_grad, const T* dst_grad, const T* flt, const Param& param, + cudaStream_t stream); -} // namespace chanwise -} // namespace convolution -} // namespace cuda -} // namespace megdnn +template +void run_bwd_filter( + T* filter_grad, const T* src, const T* dst_grad, const Param& param, + cudaStream_t stream); -// vim: ft=cpp syntax=cpp.doxygen +} // namespace chanwise +} // namespace convolution3d +} // namespace cuda +} // namespace megdnn +// vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/chanwise/kern_helper.cuh b/dnn/src/cuda/convolution3d/chanwise/kern_helper.cuh index e44d1cd0..15ff7fd4 100644 --- a/dnn/src/cuda/convolution3d/chanwise/kern_helper.cuh +++ b/dnn/src/cuda/convolution3d/chanwise/kern_helper.cuh @@ -10,9 +10,9 @@ */ #pragma once +#include "megdnn/dtype.h" #include "src/cuda/query_blocksize.cuh" #include "src/cuda/utils.cuh" -#include "megdnn/dtype.h" #include #include @@ -23,33 +23,30 @@ namespace cuda { namespace convolution3d { namespace chanwise { - /*! - * \brief return a / b and set mod to a % b - */ - __device__ __forceinline__ uint32_t div_mod( - uint32_t a, uint32_t b, uint32_t &mod) { - uint32_t ret = a / b; - mod = a - ret * b; - return ret; - } - - /*! - * \brief copy a 2D matrix by all threads in a block - * \param rs row stride - */ - template - __device__ __forceinline__ void block_memcpy( - T *dst, const T *src, uint32_t size) { - for (uint32_t i = threadIdx.x; i < size; i += blockDim.x) { - dst[i] = src[i]; - } - __syncthreads(); +/*! + * \brief return a / b and set mod to a % b + */ +__device__ __forceinline__ uint32_t div_mod(uint32_t a, uint32_t b, uint32_t& mod) { + uint32_t ret = a / b; + mod = a - ret * b; + return ret; +} + +/*! 
+ * \brief copy a 2D matrix by all threads in a block + * \param rs row stride + */ +template +__device__ __forceinline__ void block_memcpy(T* dst, const T* src, uint32_t size) { + for (uint32_t i = threadIdx.x; i < size; i += blockDim.x) { + dst[i] = src[i]; } + __syncthreads(); +} -} // namespace chanwise -} // namespace convolution3d -} // namespace cuda -} // namespace megdnn +} // namespace chanwise +} // namespace convolution3d +} // namespace cuda +} // namespace megdnn // vim: syntax=cuda.doxygen - diff --git a/dnn/src/cuda/convolution3d/forward/1x1x1.cpp b/dnn/src/cuda/convolution3d/forward/1x1x1.cpp index a7a12c3d..a5cb327b 100644 --- a/dnn/src/cuda/convolution3d/forward/1x1x1.cpp +++ b/dnn/src/cuda/convolution3d/forward/1x1x1.cpp @@ -16,47 +16,42 @@ using namespace megdnn; using namespace cuda; using namespace convolution3d; -bool Convolution3DForwardImpl::Algo1x1x1::is_available( - const SizeArgs &args) const { - auto &&fm = args.filter_meta; - const size_t MAX_WORKSPACE_SIZE = 2147483648; // 2 * 1024^3 +bool Convolution3DForwardImpl::Algo1x1x1::is_available(const SizeArgs& args) const { + auto&& fm = args.filter_meta; + const size_t MAX_WORKSPACE_SIZE = 2147483648; // 2 * 1024^3 if (get_workspace_in_bytes(args) > MAX_WORKSPACE_SIZE) { return false; } return fm.format == Param::Format::NCDHW && - (fm.dtype_enum == DTypeEnum::Float32 || - fm.dtype_enum == DTypeEnum::Float16) && - fm.spatial_ndim == 3 && fm.group == 1 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.dilation[2] == 1 && - fm.spatial[0] == 1 && fm.spatial[1] == 1 && - fm.spatial[2] == 1 && - fm.padding[0] == 0 && fm.padding[1] == 0 && - fm.padding[2] == 0 && - fm.stride[0] == 1 && fm.stride[1] == 1 && - fm.stride[2] == 1; + (fm.dtype_enum == DTypeEnum::Float32 || + fm.dtype_enum == DTypeEnum::Float16) && + fm.spatial_ndim == 3 && fm.group == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.dilation[2] == 1 && fm.spatial[0] == 1 && + fm.spatial[1] == 1 && fm.spatial[2] == 1 && fm.padding[0] == 0 && + fm.padding[1] == 0 && fm.padding[2] == 0 && fm.stride[0] == 1 && + fm.stride[1] == 1 && fm.stride[2] == 1; } void Convolution3DForwardImpl::Algo1x1x1::extract_matmul_layouts( - const SizeArgs &args, - TensorLayout &A, TensorLayout &B, TensorLayout &C) { - auto &&fm = args.filter_meta; + const SizeArgs& args, TensorLayout& A, TensorLayout& B, TensorLayout& C) { + auto&& fm = args.filter_meta; A = {{fm.ocpg, fm.icpg}, DType::from_enum(fm.dtype_enum)}; B.ndim = 2; B.shape[0] = args.src_layout->shape[1]; - B.shape[1] = args.src_layout->shape[2] * args.src_layout->shape[3] * args.src_layout->shape[4]; + B.shape[1] = args.src_layout->shape[2] * args.src_layout->shape[3] * + args.src_layout->shape[4]; B.stride[0] = args.src_layout->stride[1]; B.stride[1] = 1; B.dtype = args.src_layout->dtype; C = {{args.dst_layout->shape[1], B.shape[1]}, args.dst_layout->dtype}; } size_t Convolution3DForwardImpl::Algo1x1x1::get_workspace_in_bytes( - const SizeArgs &args) const { + const SizeArgs& args) const { TensorLayout A, B, C; extract_matmul_layouts(args, A, B, C); return args.handle->matmul_opr()->get_workspace_in_bytes(A, B, C); } -void Convolution3DForwardImpl::Algo1x1x1::exec(const ExecArgs &args) const { +void Convolution3DForwardImpl::Algo1x1x1::exec(const ExecArgs& args) const { TensorND A, B, C; extract_matmul_layouts(args, A.layout, B.layout, C.layout); A.raw_ptr = args.filter_tensor->raw_ptr; @@ -66,11 +61,10 @@ void Convolution3DForwardImpl::Algo1x1x1::exec(const ExecArgs &args) const { auto mm = 
args.handle->matmul_opr(); auto strd_B = args.src_layout->stride[0] * args.src_layout->dtype.size(), strd_C = args.dst_layout->stride[0] * args.dst_layout->dtype.size(); - for (size_t i = 0; i < batch; ++ i) { + for (size_t i = 0; i < batch; ++i) { mm->exec(A, B, C, args.workspace); incr_voidp(B.raw_ptr, strd_B); incr_voidp(C.raw_ptr, strd_C); } } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/forward/algo.cpp b/dnn/src/cuda/convolution3d/forward/algo.cpp index 3c3f7498..503fd0e9 100644 --- a/dnn/src/cuda/convolution3d/forward/algo.cpp +++ b/dnn/src/cuda/convolution3d/forward/algo.cpp @@ -23,8 +23,8 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { all_algos.push_back(&chanwise); fill_cudnn_algos(); - for (auto &&i: cudnn) { - all_algos.push_back(&i); + for (auto&& i : cudnn) { + all_algos.push_back(&i); } all_algos.push_back(&inplace_matmul); all_algos.push_back(&a1x1x1); @@ -37,15 +37,14 @@ Convolution3DForwardImpl::AlgoPack::AlgoPack() { MEGDNN_DEF_GET_ALGO_FROM_DESC(Convolution3DForwardImpl) -Convolution3DForwardImpl::AlgoCUDNN* -Convolution3DForwardImpl::AlgoPack::cudnn_from_enum( - cudnnConvolutionFwdAlgo_t algo) { - for (auto &&i: cudnn) { +Convolution3DForwardImpl::AlgoCUDNN* Convolution3DForwardImpl::AlgoPack:: + cudnn_from_enum(cudnnConvolutionFwdAlgo_t algo) { + for (auto&& i : cudnn) { if (i.cudnn_enum() == algo) return &i; } - megdnn_throw(ssprintf("can not find cudnn fwd algorithm %d", - static_cast(algo))); + megdnn_throw( + ssprintf("can not find cudnn fwd algorithm %d", static_cast(algo))); } Convolution3DForwardImpl::AlgoPack Convolution3DForwardImpl::sm_algo_pack; @@ -53,8 +52,9 @@ Convolution3DForwardImpl::AlgoPack Convolution3DForwardImpl::sm_algo_pack; Convolution3DForwardImpl::AlgoBase::SizeArgs::SizeArgs( const Convolution3DForwardImpl* o, const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) - : SizeArgs(o, src, filter, - o->make_canonized_filter_meta(src.ndim, filter), dst) {} + : SizeArgs( + o, src, filter, o->make_canonized_filter_meta(src.ndim, filter), + dst) {} Convolution3DForwardImpl::AlgoBase::SizeArgs::SizeArgs( const Convolution3DForwardImpl* o, const TensorLayout& src, @@ -69,30 +69,26 @@ Convolution3DForwardImpl::AlgoBase::SizeArgs::SizeArgs( opr{o} {} Convolution3DForwardImpl::AlgoBase::ExecArgs::ExecArgs( - const Convolution3DForwardImpl *opr, - _megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace): - SizeArgs(opr, src.layout, filter.layout, dst.layout), - src_tensor{&src}, filter_tensor{&filter}, dst_tensor{&dst}, - workspace{workspace} -{ -} + const Convolution3DForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_out dst, _megdnn_workspace workspace) + : SizeArgs(opr, src.layout, filter.layout, dst.layout), + src_tensor{&src}, + filter_tensor{&filter}, + dst_tensor{&dst}, + workspace{workspace} {} std::string Convolution3DForwardImpl::AlgoBase::SizeArgs::to_string() const { - auto &&fm = filter_meta; + auto&& fm = filter_meta; MEGDNN_MARK_USED_VAR(fm); return ssprintf( "src=%s, filter=%u{%u,%u,%u,%u,%u}, dst=%s, " "pad=%ux%ux%u, stride=%ux%ux%u, dilate=%ux%ux%u, xcorr=%d, " "dtype=%s,%s", - src_layout->to_string().c_str(), fm.group, fm.ocpg, fm.icpg, - fm.spatial[0], fm.spatial[1], fm.spatial[2], - dst_layout->to_string().c_str(), fm.padding[0], fm.padding[1], - fm.padding[2], fm.stride[0], fm.stride[1], fm.stride[2], - fm.dilation[0], fm.dilation[1], fm.dilation[2], !fm.should_flip, - src_layout->dtype.name(), 
dst_layout->dtype.name()); + src_layout->to_string().c_str(), fm.group, fm.ocpg, fm.icpg, fm.spatial[0], + fm.spatial[1], fm.spatial[2], dst_layout->to_string().c_str(), + fm.padding[0], fm.padding[1], fm.padding[2], fm.stride[0], fm.stride[1], + fm.stride[2], fm.dilation[0], fm.dilation[1], fm.dilation[2], + !fm.should_flip, src_layout->dtype.name(), dst_layout->dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/forward/algo.h b/dnn/src/cuda/convolution3d/forward/algo.h index 2fdfcbc0..f13f6532 100644 --- a/dnn/src/cuda/convolution3d/forward/algo.h +++ b/dnn/src/cuda/convolution3d/forward/algo.h @@ -14,12 +14,12 @@ #include "megdnn/oprs.h" +#include "src/common/algo_base.h" +#include "src/common/metahelper.h" #include "src/common/utils.h" #include "src/cuda/convolution3d/helper.h" #include "src/cuda/convolution3d/opr_impl.h" #include "src/cuda/handle.h" -#include "src/common/algo_base.h" -#include "src/common/metahelper.h" #include @@ -54,20 +54,22 @@ public: void init_desc(convolution3d::CUDNNForwardDescs& desc) const { desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); } - SizeArgs(const Convolution3DForwardImpl* opr, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& dst); - SizeArgs(const Convolution3DForwardImpl* opr, const TensorLayout& src, - const TensorLayout& filter, - const CanonizedFilterMeta& filter_meta, - const TensorLayout& dst); + SizeArgs( + const Convolution3DForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const TensorLayout& dst); + SizeArgs( + const Convolution3DForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const CanonizedFilterMeta& filter_meta, + const TensorLayout& dst); }; struct ExecArgs : public SizeArgs { const TensorND *src_tensor, *filter_tensor, *dst_tensor; Workspace workspace; - ExecArgs(const Convolution3DForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_out dst, - _megdnn_workspace workspace); + ExecArgs( + const Convolution3DForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -82,24 +84,22 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); megdnn_assert( req <= workspace.size, - "conv3d fwd algo %s: required workspace %zu bytes, got %zu", - name(), req, workspace.size); + "conv3d fwd algo %s: required workspace %zu bytes, got %zu", name(), + req, workspace.size); return *this; } virtual bool is_cudnn() const { return false; } }; class Convolution3DForwardImpl::Algo1x1x1 final : public AlgoBase { - static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A, - TensorLayout& B, TensorLayout& C); + static void extract_matmul_layouts( + const SizeArgs& args, TensorLayout& A, TensorLayout& B, TensorLayout& C); public: bool is_available(const SizeArgs& args) const override; @@ -108,8 +108,7 @@ public: const char* name() const 
override { return "1x1x1"; } AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) }; @@ -121,14 +120,11 @@ public: size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "CUDA:GROUP_CONV3D_FORWARD"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -140,8 +136,9 @@ class Convolution3DForwardImpl::AlgoCUDNN final : public AlgoBase { public: AlgoCUDNN(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) { - megdnn_assert(CudnnAlgoPack::conv3d_fwd_algos().find(cudnn_enum) != - CudnnAlgoPack::conv3d_fwd_algos().end()); + megdnn_assert( + CudnnAlgoPack::conv3d_fwd_algos().find(cudnn_enum) != + CudnnAlgoPack::conv3d_fwd_algos().end()); m_attr = CudnnAlgoPack::conv3d_fwd_algos().at(cudnn_enum); } @@ -173,7 +170,6 @@ public: serialize_write_pod(m_cudnn_enum, ret); return ret; } - }; class Convolution3DForwardImpl::AlgoInplaceMatmul final : public AlgoBase { @@ -183,9 +179,7 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return "INPLACE_MATMUL"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) }; @@ -196,9 +190,7 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return "CHANNEL_WISE"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) }; @@ -232,4 +224,3 @@ public: } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/forward/chanwise.cpp b/dnn/src/cuda/convolution3d/forward/chanwise.cpp index c3a28044..c6b79706 100644 --- a/dnn/src/cuda/convolution3d/forward/chanwise.cpp +++ b/dnn/src/cuda/convolution3d/forward/chanwise.cpp @@ -10,47 +10,41 @@ */ #include "./algo.h" -#include "src/cuda/utils.h" #include "src/cuda/convolution3d/chanwise/kern.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution3d; -bool Convolution3DForwardImpl::AlgoChanwise::is_available( - const SizeArgs &args) const { - if (!args.src_layout->is_contiguous() || - !args.dst_layout->is_contiguous()) { +bool Convolution3DForwardImpl::AlgoChanwise::is_available(const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { return false; } - auto &&fm = args.filter_meta; + auto&& fm = args.filter_meta; return args.filter_meta.format == Param::Format::NCDHW && - args.src_layout->dtype.category() == DTypeCategory::FLOAT && - fm.spatial_ndim == 3 && fm.icpg == 1 && - fm.dilation[0] == 1 && fm.dilation[1] == 1 && - fm.dilation[2] == 1 && !fm.should_flip; + 
args.src_layout->dtype.category() == DTypeCategory::FLOAT && + fm.spatial_ndim == 3 && fm.icpg == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.dilation[2] == 1 && !fm.should_flip; } size_t Convolution3DForwardImpl::AlgoChanwise::get_workspace_in_bytes( - const SizeArgs &) const { + const SizeArgs&) const { return 0; } -void Convolution3DForwardImpl::AlgoChanwise::exec(const ExecArgs &args) const { +void Convolution3DForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { auto kparam = chanwise::Param::from_fwd_args(args); auto stream = cuda_stream(args.handle); switch (args.src_layout->dtype.enumv()) { -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: \ - { \ - using ctype = DTypeTrait<_dt>::ctype; \ - return chanwise::run_fwd( \ - args.dst_tensor->ptr(), \ - args.src_tensor->ptr(), \ - args.filter_tensor->ptr(), \ - kparam, stream); \ - } - MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + return chanwise::run_fwd( \ + args.dst_tensor->ptr(), args.src_tensor->ptr(), \ + args.filter_tensor->ptr(), kparam, stream); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb default: break; diff --git a/dnn/src/cuda/convolution3d/forward/cudnn.cpp b/dnn/src/cuda/convolution3d/forward/cudnn.cpp index da801e31..a3ba334a 100644 --- a/dnn/src/cuda/convolution3d/forward/cudnn.cpp +++ b/dnn/src/cuda/convolution3d/forward/cudnn.cpp @@ -10,16 +10,15 @@ */ #include "./algo.h" -#include "src/cuda/utils.h" -#include "src/cuda/cudnn_wrapper.h" #include "src/cuda/convolution3d/helper.h" +#include "src/cuda/cudnn_wrapper.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace convolution3d; -bool Convolution3DForwardImpl::AlgoCUDNN::is_available( - const SizeArgs &args) const { +bool Convolution3DForwardImpl::AlgoCUDNN::is_available(const SizeArgs& args) const { CUDNNForwardDescs D; if (!is_cudnn_supported(args)) @@ -28,53 +27,37 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available( args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionForwardWorkspaceSize( - args.handle->cudnn_handle(), - D.src_desc.desc, - D.filter_desc.desc, - D.conv_desc.desc, - D.dst_desc.desc, - m_cudnn_enum, - &workspace_size); + args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, + D.conv_desc.desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); return status == CUDNN_STATUS_SUCCESS; } size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes( - const SizeArgs &args) const { + const SizeArgs& args) const { CUDNNForwardDescs D; args.init_desc(D); size_t workspace_size; auto status = cudnnGetConvolutionForwardWorkspaceSize( - args.handle->cudnn_handle(), - D.src_desc.desc, - D.filter_desc.desc, - D.conv_desc.desc, - D.dst_desc.desc, - m_cudnn_enum, - &workspace_size); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv fwd get workspace failed: %s; info: %s", - cudnnGetErrorString(status), args.to_string().c_str()); + args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, + D.conv_desc.desc, D.dst_desc.desc, m_cudnn_enum, &workspace_size); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, + "conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status), + args.to_string().c_str()); return workspace_size; } -void Convolution3DForwardImpl::AlgoCUDNN::exec( - const ExecArgs &args) const { +void Convolution3DForwardImpl::AlgoCUDNN::exec(const ExecArgs& args) const { CUDNNForwardDescs D; args.init_desc(D); float alpha = 
1.0f, beta = 0.0f; - auto status = cudnnConvolutionForward(args.handle->cudnn_handle(), - &alpha, - D.src_desc.desc, args.src_tensor->raw_ptr, - D.filter_desc.desc, args.filter_tensor->raw_ptr, - D.conv_desc.desc, - m_cudnn_enum, - args.workspace.raw_ptr, - args.workspace.size, - &beta, - D.dst_desc.desc, - args.dst_tensor->raw_ptr); - megdnn_assert(status == CUDNN_STATUS_SUCCESS, - "conv fwd failed: %s; info: %s", + auto status = cudnnConvolutionForward( + args.handle->cudnn_handle(), &alpha, D.src_desc.desc, + args.src_tensor->raw_ptr, D.filter_desc.desc, args.filter_tensor->raw_ptr, + D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, + &beta, D.dst_desc.desc, args.dst_tensor->raw_ptr); + megdnn_assert( + status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", cudnnGetErrorString(status), args.to_string().c_str()); } diff --git a/dnn/src/cuda/convolution3d/forward/group_conv.cpp b/dnn/src/cuda/convolution3d/forward/group_conv.cpp index c8d65f9e..6a85e326 100644 --- a/dnn/src/cuda/convolution3d/forward/group_conv.cpp +++ b/dnn/src/cuda/convolution3d/forward/group_conv.cpp @@ -45,8 +45,8 @@ std::pair sub_opr_config( return ret; } -std::pair> -prepare_sub_opr(const Convolution3DForwardImpl::AlgoBase::SizeArgs& args) { +std::pair> prepare_sub_opr( + const Convolution3DForwardImpl::AlgoBase::SizeArgs& args) { auto conv3d_opr = args.handle->create_operator(); set_execution_policy( args.opr, conv3d_opr.get()); @@ -57,21 +57,21 @@ prepare_sub_opr(const Convolution3DForwardImpl::AlgoBase::SizeArgs& args) { } } // namespace -std::vector -Convolution3DForwardImpl::AlgoGroupConvGeneral::get_subopr_list( - const TensorLayoutArray& layouts, const OperatorBase* opr) const { - AlgoBase::SizeArgs args{static_cast(opr), - layouts[0], layouts[1], layouts[2]}; +std::vector Convolution3DForwardImpl::AlgoGroupConvGeneral:: + get_subopr_list( + const TensorLayoutArray& layouts, const OperatorBase* opr) const { + AlgoBase::SizeArgs args{ + static_cast(opr), layouts[0], layouts[1], + layouts[2]}; auto&& config = sub_opr_config(args); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::CONVOLUTION3D_FORWARD, param_str, - config.first}}; + return {{Algorithm::OprType::CONVOLUTION3D_FORWARD, param_str, config.first}}; } bool Convolution3DForwardImpl::AlgoGroupConvGeneral::is_available( - const SizeArgs &args) const { + const SizeArgs& args) const { if (args.filter_meta.group <= 1) return false; if (args.filter_meta.format != Param::Format::NCDHW && @@ -86,8 +86,7 @@ bool Convolution3DForwardImpl::AlgoGroupConvGeneral::is_available( config.first[0], config.first[1], config.first[2]); } -WorkspaceBundle -Convolution3DForwardImpl::AlgoGroupConvGeneral::get_workspace_bundle( +WorkspaceBundle Convolution3DForwardImpl::AlgoGroupConvGeneral::get_workspace_bundle( void* ptr, const SizeArgs& args) const { auto config = prepare_sub_opr(args); size_t sizes = config.second->get_workspace_in_bytes( @@ -100,8 +99,7 @@ size_t Convolution3DForwardImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec( - const ExecArgs& args) const { +void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const { auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); { auto config = prepare_sub_opr(args); @@ -113,18 +111,17 @@ void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec( if 
(args.filter_meta.format == Param::Format::NCDHW) { c_pos = 1; } else { - megdnn_assert(args.filter_meta.format == Param::Format::NDHWC, - "invalid conv format"); + megdnn_assert( + args.filter_meta.format == Param::Format::NDHWC, + "invalid conv format"); c_pos = 4; } auto grp = args.filter_meta.group; auto&& fm = args.filter_meta; - auto strd_src = tsrc.layout.stride[c_pos] * fm.icpg * - tsrc.layout.dtype.size(), - strd_dst = tdst.layout.stride[c_pos] * fm.ocpg * - tdst.layout.dtype.size(), + auto strd_src = tsrc.layout.stride[c_pos] * fm.icpg * tsrc.layout.dtype.size(), + strd_dst = tdst.layout.stride[c_pos] * fm.ocpg * tdst.layout.dtype.size(), strd_flt = fm.icpg * fm.ocpg * fm.spatial[0] * fm.spatial[1] * fm.spatial[2] * tfilter.layout.dtype.size(); @@ -138,4 +135,3 @@ void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/forward/inplace_matmul.cpp b/dnn/src/cuda/convolution3d/forward/inplace_matmul.cpp index 672ec5f7..a3ccbd62 100644 --- a/dnn/src/cuda/convolution3d/forward/inplace_matmul.cpp +++ b/dnn/src/cuda/convolution3d/forward/inplace_matmul.cpp @@ -16,50 +16,33 @@ using namespace megdnn; using namespace cuda; bool Convolution3DForwardImpl::AlgoInplaceMatmul::is_available( - const SizeArgs &args) const { - auto &&fm = args.filter_meta; + const SizeArgs& args) const { + auto&& fm = args.filter_meta; return args.filter_meta.format == Param::Format::NCDHW && - args.src_layout->dtype == dtype::Float32() && - fm.group == 1 && fm.spatial_ndim == 3; + args.src_layout->dtype == dtype::Float32() && fm.group == 1 && + fm.spatial_ndim == 3; } size_t Convolution3DForwardImpl::AlgoInplaceMatmul::get_workspace_in_bytes( - const SizeArgs &) const { + const SizeArgs&) const { return 0; } -void Convolution3DForwardImpl::AlgoInplaceMatmul::exec( - const ExecArgs &args) const { - auto &&fm = args.filter_meta; - size_t N = args.src_layout->shape[0], - IC = fm.icpg, - ID = args.src_layout->shape[2], - IH = args.src_layout->shape[3], - IW = args.src_layout->shape[4], - OC = fm.ocpg, - OD = args.dst_layout->shape[2], - OH = args.dst_layout->shape[3], - OW = args.dst_layout->shape[4], - FD = fm.spatial[0], - FH = fm.spatial[1], - FW = fm.spatial[2], - DD = fm.dilation[0], - DH = fm.dilation[1], - DW = fm.dilation[2]; +void Convolution3DForwardImpl::AlgoInplaceMatmul::exec(const ExecArgs& args) const { + auto&& fm = args.filter_meta; + size_t N = args.src_layout->shape[0], IC = fm.icpg, ID = args.src_layout->shape[2], + IH = args.src_layout->shape[3], IW = args.src_layout->shape[4], OC = fm.ocpg, + OD = args.dst_layout->shape[2], OH = args.dst_layout->shape[3], + OW = args.dst_layout->shape[4], FD = fm.spatial[0], FH = fm.spatial[1], + FW = fm.spatial[2], DD = fm.dilation[0], DH = fm.dilation[1], + DW = fm.dilation[2]; auto stream = args.handle->stream(); convolution3d::exec_inplace_matmul_fwd( - args.src_tensor->ptr(), - args.filter_tensor->ptr(), - args.dst_tensor->ptr(), - N, args.src_layout->stride[0], args.dst_layout->stride[0], - IC, ID, IH, IW, - OC, OD, OH, OW, - FD, FH, FW, - fm.padding[0], fm.padding[1], fm.padding[2], - fm.stride[0], fm.stride[1], fm.stride[2], - DD, DH, DW, - !fm.should_flip, stream); + args.src_tensor->ptr(), args.filter_tensor->ptr(), + args.dst_tensor->ptr(), N, args.src_layout->stride[0], + args.dst_layout->stride[0], IC, ID, IH, IW, OC, OD, OH, OW, FD, FH, FW, + fm.padding[0], fm.padding[1], fm.padding[2], fm.stride[0], fm.stride[1], + fm.stride[2], DD, DH, DW, !fm.should_flip, stream); } // 
vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cu b/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cu index ed215f91..4fc44747 100644 --- a/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cu +++ b/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cu @@ -8,10 +8,10 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include +#include #include "./inplace_matmul_impl.cuh" #include "src/cuda/utils.cuh" -#include -#include using namespace megdnn; using namespace cuda; @@ -26,22 +26,18 @@ struct BufferFetcherTexture { }; struct BufferFetcherRaw { - const float *ptr; + const float* ptr; - __device__ __forceinline__ float get(uint32_t offset) { - return ptr[offset]; - } + __device__ __forceinline__ float get(uint32_t offset) { return ptr[offset]; } }; struct BufferFetcherTextureHost { bool init_succ; BufferFetcherTexture val; - BufferFetcherTextureHost(float *p, const size_t n); + BufferFetcherTextureHost(float* p, const size_t n); - ~BufferFetcherTextureHost() { - reset(); - } + ~BufferFetcherTextureHost() { reset(); } void reset() { if (init_succ) { @@ -51,36 +47,34 @@ struct BufferFetcherTextureHost { } }; -BufferFetcherTextureHost::BufferFetcherTextureHost(float *p, const size_t n) { +BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n) { init_succ = false; cudaTextureObject_t tex_obj; cudaResourceDesc res_desc; memset(&res_desc, 0, sizeof(cudaResourceDesc)); res_desc.resType = cudaResourceTypeLinear; - res_desc.res.linear.devPtr = static_cast(p); - res_desc.res.linear.sizeInBytes = n*sizeof(float); - res_desc.res.linear.desc = cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); - cudaTextureDesc tex_desc; + res_desc.res.linear.devPtr = static_cast(p); + res_desc.res.linear.sizeInBytes = n * sizeof(float); + res_desc.res.linear.desc = + cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); + cudaTextureDesc tex_desc; memset(&tex_desc, 0, sizeof(cudaTextureDesc)); if (cudaCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL) == cudaSuccess) { val.tex = tex_obj; init_succ = true; } else { - cudaGetLastError(); // reset error + cudaGetLastError(); // reset error } } -template +template struct KernelPtr { - typedef void(*type)(BufferFetcher, BufferFetcher, float*, - uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t, - uint32_t, uint32_t, uint32_t); + typedef void (*type)( + BufferFetcher, BufferFetcher, float*, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t); }; //! 1 -> 0xffffffff, 0 -> 0x00000000 @@ -94,7 +88,7 @@ union FloatAndU32 { }; //! 
\p mask must be either all 1 or 0 bits -template +template __device__ __forceinline__ float visit_with_mask( BufferFetcher buf, uint32_t offset, uint32_t mask) { FloatAndU32 f; @@ -104,16 +98,14 @@ __device__ __forceinline__ float visit_with_mask( } template -__global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, - float *dst, - const uint32_t INP_BS, const uint32_t OUT_BS, - const uint32_t IC, const uint32_t ID, const uint32_t IH, const uint32_t IW, - const uint32_t OC, const uint32_t OD, const uint32_t OH, const uint32_t OW, - const uint32_t FD, const uint32_t FH, const uint32_t FW, - const uint32_t SD, const uint32_t SH, const uint32_t SW, - const uint32_t PD, const uint32_t PH, const uint32_t PW, - const uint32_t DD, const uint32_t DH, const uint32_t DW) -{ +__global__ void conv_kernel( + BufferFetcher src, BufferFetcher filter, float* dst, const uint32_t INP_BS, + const uint32_t OUT_BS, const uint32_t IC, const uint32_t ID, const uint32_t IH, + const uint32_t IW, const uint32_t OC, const uint32_t OD, const uint32_t OH, + const uint32_t OW, const uint32_t FD, const uint32_t FH, const uint32_t FW, + const uint32_t SD, const uint32_t SH, const uint32_t SW, const uint32_t PD, + const uint32_t PH, const uint32_t PW, const uint32_t DD, const uint32_t DH, + const uint32_t DW) { const uint32_t BM = BY < BX ? BY : BX; // BY*BX == 256 // (OC) * (IC*FD*FH*FW) * (OD*OH*OW) @@ -122,32 +114,32 @@ __global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, const uint32_t tidy = threadIdx.y; const uint32_t posx = blockIdx.x * blockDim.x + threadIdx.x; const uint32_t posy = blockIdx.y * blockDim.y + threadIdx.y; - const uint32_t posx2 = posx<<2; - const uint32_t posy2 = posy<<2; + const uint32_t posx2 = posx << 2; + const uint32_t posy2 = posy << 2; const uint32_t heightA = OC; - const uint32_t widthA = IC*FD*FH*FW; + const uint32_t widthA = IC * FD * FH * FW; const uint32_t heightB = widthA; - const uint32_t widthB = OD*OH*OW; - const uint32_t od0 = (posx2+0) / OW / OH * SD; - const uint32_t oh0 = (posx2+0) / OW % OH * SH; - const uint32_t ow0 = (posx2+0) % OW * SW; + const uint32_t widthB = OD * OH * OW; + const uint32_t od0 = (posx2 + 0) / OW / OH * SD; + const uint32_t oh0 = (posx2 + 0) / OW % OH * SH; + const uint32_t ow0 = (posx2 + 0) % OW * SW; const uint32_t op0 = od0 * IH * IW + oh0 * IW + ow0; - const uint32_t od1 = (posx2+1) / OW / OH * SD; - const uint32_t oh1 = (posx2+1) / OW % OH * SH; - const uint32_t ow1 = (posx2+1) % OW * SW; + const uint32_t od1 = (posx2 + 1) / OW / OH * SD; + const uint32_t oh1 = (posx2 + 1) / OW % OH * SH; + const uint32_t ow1 = (posx2 + 1) % OW * SW; const uint32_t op1 = od1 * IH * IW + oh1 * IW + ow1; - const uint32_t od2 = (posx2+2) / OW / OH * SD; - const uint32_t oh2 = (posx2+2) / OW % OH * SH; - const uint32_t ow2 = (posx2+2) % OW * SW; + const uint32_t od2 = (posx2 + 2) / OW / OH * SD; + const uint32_t oh2 = (posx2 + 2) / OW % OH * SH; + const uint32_t ow2 = (posx2 + 2) % OW * SW; const uint32_t op2 = od2 * IH * IW + oh2 * IW + ow2; - const uint32_t od3 = (posx2+3) / OW / OH * SD; - const uint32_t oh3 = (posx2+3) / OW % OH * SH; - const uint32_t ow3 = (posx2+3) % OW * SW; + const uint32_t od3 = (posx2 + 3) / OW / OH * SD; + const uint32_t oh3 = (posx2 + 3) / OW % OH * SH; + const uint32_t ow3 = (posx2 + 3) % OW * SW; const uint32_t op3 = od3 * IH * IW + oh3 * IW + ow3; - const uint32_t FP = FD*FH*FW; + const uint32_t FP = FD * FH * FW; // OC % (BLOCK*4) == 0 // IC*FD*FH*FW % BLOCK == 0 // OD*OH*OW % (BLOCK*4) == 0 @@ -155,30 +147,28 @@ 
__global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, __shared__ float4 localB[BM][BX]; uint32_t i = 0u; uint32_t offsetA = posy2 * widthA + tidx; - uint32_t offsetB = n*INP_BS - PD*IH*IW - PH*IW - PW; - float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, - sum1 = {0.0f, 0.0f, 0.0f, 0.0f}, - sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, - sum3 = {0.0f, 0.0f, 0.0f, 0.0f}; + uint32_t offsetB = n * INP_BS - PD * IH * IW - PH * IW - PW; + float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, sum1 = {0.0f, 0.0f, 0.0f, 0.0f}, + sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, sum3 = {0.0f, 0.0f, 0.0f, 0.0f}; uint32_t fd = tidy / FW / FH % FD; uint32_t fh = tidy / FW % FH; uint32_t fw = tidy % FW; - uint32_t ic = tidy / (FD*FH*FW); - uint32_t icm = tidy % (FD*FH*FW); + uint32_t ic = tidy / (FD * FH * FW); + uint32_t icm = tidy % (FD * FH * FW); const uint32_t fds = BM / FW / FH % FD; const uint32_t fhs = BM / FW % FH; const uint32_t fws = BM % FW; - const uint32_t ics = BM / (FD*FH*FW); - const uint32_t icms = BM % (FD*FH*FW); + const uint32_t ics = BM / (FD * FH * FW); + const uint32_t icms = BM % (FD * FH * FW); for (; i < widthA; i += BM, offsetA += BM) { // load localA if (tidx < BM) { - localA[tidy][tidx].x = filter.get(offsetA + 0*widthA); - localA[tidy][tidx].y = filter.get(offsetA + 1*widthA); - localA[tidy][tidx].z = filter.get(offsetA + 2*widthA); - localA[tidy][tidx].w = filter.get(offsetA + 3*widthA); + localA[tidy][tidx].x = filter.get(offsetA + 0 * widthA); + localA[tidy][tidx].y = filter.get(offsetA + 1 * widthA); + localA[tidy][tidx].z = filter.get(offsetA + 2 * widthA); + localA[tidy][tidx].w = filter.get(offsetA + 3 * widthA); } // load localB @@ -188,39 +178,38 @@ __global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, fh2 = fh; fw2 = fw; } else { - fd2 = FD-fd-1; - fh2 = FH-fh-1; - fw2 = FW-fw-1; + fd2 = FD - fd - 1; + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; } if (tidy < BM) { - uint32_t fd2d = fd2 * DD, - fh2d = fh2 * DH, - fw2d = fw2 * DW; - uint32_t tmp = offsetB+ic*ID*IH*IW+fd2d*IH*IW+fh2d*IW+fw2d, - ok = bool_as_mask(tidy+i < heightB), + uint32_t fd2d = fd2 * DD, fh2d = fh2 * DH, fw2d = fw2 * DW; + uint32_t tmp = offsetB + ic * ID * IH * IW + fd2d * IH * IW + fh2d * IW + + fw2d, + ok = bool_as_mask(tidy + i < heightB), p0 = bool_as_mask( - fd2d+od0 >= PD && fd2d+od0 < ID+PD && - fh2d+oh0 >= PH && fh2d+oh0 < IH+PH && - fw2d+ow0 >= PW && fw2d+ow0 < IW+PW), + fd2d + od0 >= PD && fd2d + od0 < ID + PD && + fh2d + oh0 >= PH && fh2d + oh0 < IH + PH && + fw2d + ow0 >= PW && fw2d + ow0 < IW + PW), p1 = bool_as_mask( - fd2d+od1 >= PD && fd2d+od1 < ID+PD && - fh2d+oh1 >= PH && fh2d+oh1 < IH+PH && - fw2d+ow1 >= PW && fw2d+ow1 < IW+PW), + fd2d + od1 >= PD && fd2d + od1 < ID + PD && + fh2d + oh1 >= PH && fh2d + oh1 < IH + PH && + fw2d + ow1 >= PW && fw2d + ow1 < IW + PW), p2 = bool_as_mask( - fd2d+od2 >= PD && fd2d+od2 < ID+PD && - fh2d+oh2 >= PH && fh2d+oh2 < IH+PH && - fw2d+ow2 >= PW && fw2d+ow2 < IW+PW), + fd2d + od2 >= PD && fd2d + od2 < ID + PD && + fh2d + oh2 >= PH && fh2d + oh2 < IH + PH && + fw2d + ow2 >= PW && fw2d + ow2 < IW + PW), p3 = bool_as_mask( - fd2d+od3 >= PD && fd2d+od3 < ID+PD && - fh2d+oh3 >= PH && fh2d+oh3 < IH+PH && - fw2d+ow3 >= PW && fw2d+ow3 < IW+PW); - localB[tidy][tidx].x = visit_with_mask(src, tmp+op0, ok & p0); - localB[tidy][tidx].y = visit_with_mask(src, tmp+op1, ok & p1); - localB[tidy][tidx].z = visit_with_mask(src, tmp+op2, ok & p2); - localB[tidy][tidx].w = visit_with_mask(src, tmp+op3, ok & p3); + fd2d + od3 >= PD && fd2d + od3 < ID + PD && + fh2d + oh3 >= PH && fh2d + oh3 < 
IH + PH && + fw2d + ow3 >= PW && fw2d + ow3 < IW + PW); + localB[tidy][tidx].x = visit_with_mask(src, tmp + op0, ok & p0); + localB[tidy][tidx].y = visit_with_mask(src, tmp + op1, ok & p1); + localB[tidy][tidx].z = visit_with_mask(src, tmp + op2, ok & p2); + localB[tidy][tidx].w = visit_with_mask(src, tmp + op3, ok & p3); } - __syncthreads(); // die without this sync().. + __syncthreads(); // die without this sync().. for (uint32_t j = 0u; j < BM; ++j) { float4 tmpA = localA[tidy][j]; float4 tmpB = localB[j][tidx]; @@ -258,57 +247,67 @@ __global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, __syncthreads(); } - const uint32_t dst_idx = n*OUT_BS + posy2*widthB + posx2; - bool y0 = (posy2+0 < heightA); - bool y1 = (posy2+1 < heightA); - bool y2 = (posy2+2 < heightA); - bool y3 = (posy2+3 < heightA); - bool x0 = (posx2+0 < widthB); - bool x1 = (posx2+1 < widthB); - bool x2 = (posx2+2 < widthB); - bool x3 = (posx2+3 < widthB); - if (y0) { - if (x0) dst[dst_idx + 0*widthB + 0] = sum0.x; - if (x1) dst[dst_idx + 0*widthB + 1] = sum0.y; - if (x2) dst[dst_idx + 0*widthB + 2] = sum0.z; - if (x3) dst[dst_idx + 0*widthB + 3] = sum0.w; + const uint32_t dst_idx = n * OUT_BS + posy2 * widthB + posx2; + bool y0 = (posy2 + 0 < heightA); + bool y1 = (posy2 + 1 < heightA); + bool y2 = (posy2 + 2 < heightA); + bool y3 = (posy2 + 3 < heightA); + bool x0 = (posx2 + 0 < widthB); + bool x1 = (posx2 + 1 < widthB); + bool x2 = (posx2 + 2 < widthB); + bool x3 = (posx2 + 3 < widthB); + if (y0) { + if (x0) + dst[dst_idx + 0 * widthB + 0] = sum0.x; + if (x1) + dst[dst_idx + 0 * widthB + 1] = sum0.y; + if (x2) + dst[dst_idx + 0 * widthB + 2] = sum0.z; + if (x3) + dst[dst_idx + 0 * widthB + 3] = sum0.w; } if (y1) { - if (x0) dst[dst_idx + 1*widthB + 0] = sum1.x; - if (x1) dst[dst_idx + 1*widthB + 1] = sum1.y; - if (x2) dst[dst_idx + 1*widthB + 2] = sum1.z; - if (x3) dst[dst_idx + 1*widthB + 3] = sum1.w; + if (x0) + dst[dst_idx + 1 * widthB + 0] = sum1.x; + if (x1) + dst[dst_idx + 1 * widthB + 1] = sum1.y; + if (x2) + dst[dst_idx + 1 * widthB + 2] = sum1.z; + if (x3) + dst[dst_idx + 1 * widthB + 3] = sum1.w; } if (y2) { - if (x0) dst[dst_idx + 2*widthB + 0] = sum2.x; - if (x1) dst[dst_idx + 2*widthB + 1] = sum2.y; - if (x2) dst[dst_idx + 2*widthB + 2] = sum2.z; - if (x3) dst[dst_idx + 2*widthB + 3] = sum2.w; + if (x0) + dst[dst_idx + 2 * widthB + 0] = sum2.x; + if (x1) + dst[dst_idx + 2 * widthB + 1] = sum2.y; + if (x2) + dst[dst_idx + 2 * widthB + 2] = sum2.z; + if (x3) + dst[dst_idx + 2 * widthB + 3] = sum2.w; } if (y3) { - if (x0) dst[dst_idx + 3*widthB + 0] = sum3.x; - if (x1) dst[dst_idx + 3*widthB + 1] = sum3.y; - if (x2) dst[dst_idx + 3*widthB + 2] = sum3.z; - if (x3) dst[dst_idx + 3*widthB + 3] = sum3.w; + if (x0) + dst[dst_idx + 3 * widthB + 0] = sum3.x; + if (x1) + dst[dst_idx + 3 * widthB + 1] = sum3.y; + if (x2) + dst[dst_idx + 3 * widthB + 2] = sum3.z; + if (x3) + dst[dst_idx + 3 * widthB + 3] = sum3.w; } } -} // anonymous namespace +} // anonymous namespace void convolution3d::exec_inplace_matmul_fwd( - const float *src, const float *filter, float *dst, - size_t N, size_t INP_BS, size_t OUT_BS, - size_t IC, size_t ID, size_t IH, size_t IW, - size_t OC, size_t OD, size_t OH, size_t OW, - size_t FD, size_t FH, size_t FW, - size_t PD, size_t PH, size_t PW, - size_t SD, size_t SH, size_t SW, - size_t DD, size_t DH, size_t DW, - bool is_xcorr, - cudaStream_t stream) -{ - BufferFetcherTextureHost src_tex(const_cast(src), N * INP_BS), - filter_tex(const_cast(filter), OC*IC*FD*FH*FW); + const float* 
src, const float* filter, float* dst, size_t N, size_t INP_BS, + size_t OUT_BS, size_t IC, size_t ID, size_t IH, size_t IW, size_t OC, size_t OD, + size_t OH, size_t OW, size_t FD, size_t FH, size_t FW, size_t PD, size_t PH, + size_t PW, size_t SD, size_t SH, size_t SW, size_t DD, size_t DH, size_t DW, + bool is_xcorr, cudaStream_t stream) { + BufferFetcherTextureHost src_tex(const_cast(src), N * INP_BS), + filter_tex(const_cast(filter), OC * IC * FD * FH * FW); BufferFetcherRaw src_buf, filter_buf; src_buf.ptr = src; filter_buf.ptr = filter; @@ -317,73 +316,84 @@ void convolution3d::exec_inplace_matmul_fwd( filter_tex.reset(); } int m = OC; - int n = OD*OH*OW; + int n = OD * OH * OW; int BY = 1; int BX = 1; if (m <= 64) { - while (BY < 16 && (BY<<2) < m) BY <<= 1; + while (BY < 16 && (BY << 2) < m) + BY <<= 1; BX = 256 / BY; } else if (n <= 64) { - while (BX < 16 && (BX<<2) < n) BX <<= 1; + while (BX < 16 && (BX << 2) < n) + BX <<= 1; BY = 256 / BX; } else { BX = BY = 16; } - dim3 blocks(DIVUP(OD*OH*OW, 4*BX), DIVUP(OC, 4*BY), N); + dim3 blocks(DIVUP(OD * OH * OW, 4 * BX), DIVUP(OC, 4 * BY), N); dim3 threads(BX, BY); -#define DISPATCH_BX_BY(BX, BY) do { \ - if (src_tex.init_succ) { \ - KernelPtr::type kptr; \ - if (is_xcorr) { \ - kptr = conv_kernel; \ - } else { \ - kptr = conv_kernel; \ - } \ - kptr<<>>( \ - src_tex.val, filter_tex.val, dst, \ - INP_BS, OUT_BS, \ - IC, ID, IH, IW, \ - OC, OD, OH, OW, \ - FD, FH, FW, \ - SD, SH, SW, \ - PD, PH, PW, \ - DD, DH, DW); \ - } else { \ - KernelPtr::type kptr; \ - if (is_xcorr) { \ - kptr = conv_kernel; \ - } else { \ - kptr = conv_kernel; \ - } \ - kptr<<>>( \ - src_buf, filter_buf, dst, \ - INP_BS, OUT_BS, \ - IC, ID, IH, IW, \ - OC, OD, OH, OW, \ - FD, FH, FW, \ - SD, SH, SW, \ - PD, PH, PW, \ - DD, DH, DW); \ - } \ -} while (0) -#define DISPATCH_BX(BX) do { \ - DISPATCH_BX_BY(BX, 256/BX); \ -} while (0) -#define DISPATCH() do { \ - switch (BX) { \ - case 1: DISPATCH_BX(1); break; \ - case 2: DISPATCH_BX(2); break; \ - case 4: DISPATCH_BX(4); break; \ - case 8: DISPATCH_BX(8); break; \ - case 16: DISPATCH_BX(16); break; \ - case 32: DISPATCH_BX(32); break; \ - case 64: DISPATCH_BX(64); break; \ - case 128: DISPATCH_BX(128); break; \ - case 256: DISPATCH_BX(256); break; \ - default: \ - report_error("no usable kernel"); \ - } \ -} while (0) +#define DISPATCH_BX_BY(BX, BY) \ + do { \ + if (src_tex.init_succ) { \ + KernelPtr::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel; \ + } else { \ + kptr = conv_kernel; \ + } \ + kptr<<>>( \ + src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, ID, IH, IW, \ + OC, OD, OH, OW, FD, FH, FW, SD, SH, SW, PD, PH, PW, DD, DH, DW); \ + } else { \ + KernelPtr::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel; \ + } else { \ + kptr = conv_kernel; \ + } \ + kptr<<>>( \ + src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, ID, IH, IW, OC, OD, \ + OH, OW, FD, FH, FW, SD, SH, SW, PD, PH, PW, DD, DH, DW); \ + } \ + } while (0) +#define DISPATCH_BX(BX) \ + do { \ + DISPATCH_BX_BY(BX, 256 / BX); \ + } while (0) +#define DISPATCH() \ + do { \ + switch (BX) { \ + case 1: \ + DISPATCH_BX(1); \ + break; \ + case 2: \ + DISPATCH_BX(2); \ + break; \ + case 4: \ + DISPATCH_BX(4); \ + break; \ + case 8: \ + DISPATCH_BX(8); \ + break; \ + case 16: \ + DISPATCH_BX(16); \ + break; \ + case 32: \ + DISPATCH_BX(32); \ + break; \ + case 64: \ + DISPATCH_BX(64); \ + break; \ + case 128: \ + DISPATCH_BX(128); \ + break; \ + case 256: \ + DISPATCH_BX(256); \ + break; \ + default: \ + report_error("no usable kernel"); \ + } \ + } 
while (0) DISPATCH(); #undef DISPATCH #undef DISPATCH_BX @@ -392,4 +402,3 @@ void convolution3d::exec_inplace_matmul_fwd( } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cuh b/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cuh index 4578252b..7eee3451 100644 --- a/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cuh +++ b/dnn/src/cuda/convolution3d/forward/inplace_matmul_impl.cuh @@ -11,26 +11,22 @@ #pragma once #include -#include #include +#include namespace megdnn { namespace cuda { namespace convolution3d { -void exec_inplace_matmul_fwd(const float *src, const float *filter, float *dst, - size_t N, size_t INP_BS, size_t OUT_BS, - size_t IC, size_t ID, size_t IH, size_t IW, - size_t OC, size_t OD, size_t OH, size_t OW, - size_t FD, size_t FH, size_t FW, - size_t PD, size_t PH, size_t PW, - size_t SD, size_t SH, size_t SW, - size_t DD, size_t DH, size_t DW, - bool is_xcorr, - cudaStream_t stream); +void exec_inplace_matmul_fwd( + const float* src, const float* filter, float* dst, size_t N, size_t INP_BS, + size_t OUT_BS, size_t IC, size_t ID, size_t IH, size_t IW, size_t OC, size_t OD, + size_t OH, size_t OW, size_t FD, size_t FH, size_t FW, size_t PD, size_t PH, + size_t PW, size_t SD, size_t SH, size_t SW, size_t DD, size_t DH, size_t DW, + bool is_xcorr, cudaStream_t stream); -} // namespace convolution -} // namespace cuda -} // namespace megdnn +} // namespace convolution3d +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/helper.cpp b/dnn/src/cuda/convolution3d/helper.cpp index 591786ae..93e2d4fe 100644 --- a/dnn/src/cuda/convolution3d/helper.cpp +++ b/dnn/src/cuda/convolution3d/helper.cpp @@ -15,10 +15,10 @@ using namespace megdnn; using namespace cuda; using namespace convolution3d; -bool convolution3d::is_cudnn_supported(const ForwardSizeArgs &args) { +bool convolution3d::is_cudnn_supported(const ForwardSizeArgs& args) { if (args.handle->is_tegra_k1()) return false; - + if (args.src_layout->dtype.category() != DTypeCategory::FLOAT) return false; @@ -31,18 +31,19 @@ bool convolution3d::is_cudnn_supported(const ForwardSizeArgs &args) { #else fm.group == 1 #endif - && fm.spatial_ndim == 3; + && fm.spatial_ndim == 3; } -void convolution3d::flip_filter(const ForwardSizeArgs &args, - const Workspace &workspace, void *&raw_ptr) { - auto &&fm = args.filter_meta; +void convolution3d::flip_filter( + const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { + auto&& fm = args.filter_meta; megdnn_assert(fm.group == 1 && fm.spatial_ndim == 3); - auto OC = fm.ocpg, IC = fm.icpg, FD = fm.spatial[0], FH = fm.spatial[1], FW = fm.spatial[2]; + auto OC = fm.ocpg, IC = fm.icpg, FD = fm.spatial[0], FH = fm.spatial[1], + FW = fm.spatial[2]; auto dtype = DType::from_enum(fm.dtype_enum); megdnn_assert(workspace.size >= dtype.size() * OC * IC * FD * FH * FW); TensorND src{raw_ptr, {{OC, IC, FD, FH, FW}, dtype}}, - dst{workspace.raw_ptr + (FD * FH * FW - 1) * dtype.size(), src.layout}; + dst{workspace.raw_ptr + (FD * FH * FW - 1) * dtype.size(), src.layout}; dst.layout.stride[2] = -dst.layout.stride[2]; dst.layout.stride[3] = -dst.layout.stride[3]; dst.layout.stride[4] = -dst.layout.stride[4]; diff --git a/dnn/src/cuda/convolution3d/helper.h b/dnn/src/cuda/convolution3d/helper.h index bef84205..4f85f161 100644 --- a/dnn/src/cuda/convolution3d/helper.h +++ b/dnn/src/cuda/convolution3d/helper.h @@ -11,241 +11,227 @@ #pragma once #include "./opr_impl.h" 
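Note on the flip_filter helper reformatted just above: it converts a cross-correlation filter into a true convolution filter without a dedicated kernel by building a destination TensorND whose base pointer is offset to the last spatial element of each (oc, ic) slice and whose three spatial strides are negated, so a plain relayout copy lands every element at its mirrored position. A minimal host-side sketch of the equivalent index math, assuming contiguous OC x IC x FD x FH x FW storage; the names filter/flipped/flip_filter_reference are illustrative, not from the patch:

    #include <cstddef>
    // mirrors the FD x FH x FW volume of every (oc, ic) slice, which is what the
    // negative-stride relayout above achieves in one copy
    void flip_filter_reference(const float* filter, float* flipped,
                               size_t OC, size_t IC, size_t FD, size_t FH, size_t FW) {
        for (size_t o = 0; o < OC * IC; ++o)
            for (size_t d = 0; d < FD; ++d)
                for (size_t h = 0; h < FH; ++h)
                    for (size_t w = 0; w < FW; ++w)
                        flipped[((o * FD + (FD - 1 - d)) * FH + (FH - 1 - h)) * FW + (FW - 1 - w)] =
                                filter[((o * FD + d) * FH + h) * FW + w];
    }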
+#include "src/common/algo_chooser.h" +#include "src/common/utils.h" #include "src/cuda/cudnn_wrapper.h" #include "src/cuda/handle.h" -#include "src/common/utils.h" -#include "src/common/algo_chooser.h" #include "src/cuda/utils.h" namespace megdnn { namespace cuda { namespace convolution3d { - using CanonizedFilterMeta = Convolution3DForward::CanonizedFilterMeta; +using CanonizedFilterMeta = Convolution3DForward::CanonizedFilterMeta; - //! conv size descriptor in the forward view - struct ForwardSizeArgs { - HandleImpl *handle; - const TensorLayout *src_layout; - const TensorLayout *filter_layout; - CanonizedFilterMeta filter_meta; - const TensorLayout *dst_layout; - param::Convolution3D::DataType data_type; - }; +//! conv size descriptor in the forward view +struct ForwardSizeArgs { + HandleImpl* handle; + const TensorLayout* src_layout; + const TensorLayout* filter_layout; + CanonizedFilterMeta filter_meta; + const TensorLayout* dst_layout; + param::Convolution3D::DataType data_type; +}; - //! whether cudnn is supported for a filter meta - bool is_cudnn_supported(const ForwardSizeArgs &args); +//! whether cudnn is supported for a filter meta +bool is_cudnn_supported(const ForwardSizeArgs& args); - struct CUDNNForwardDescs { - Tensor3DDesc src_desc, dst_desc; - Filter3DDesc filter_desc; - Conv3DDesc conv_desc; - void set(const TensorLayout &src, - const CanonizedFilterMeta &filter, - const TensorLayout &dst, - const param::Convolution3D ¶m) - { - src_desc.set(src); - filter_desc.set(filter); - dst_desc.set(dst); - conv_desc.set(param, filter.group); - } - }; +struct CUDNNForwardDescs { + Tensor3DDesc src_desc, dst_desc; + Filter3DDesc filter_desc; + Conv3DDesc conv_desc; + void set( + const TensorLayout& src, const CanonizedFilterMeta& filter, + const TensorLayout& dst, const param::Convolution3D& param) { + src_desc.set(src); + filter_desc.set(filter); + dst_desc.set(dst); + conv_desc.set(param, filter.group); + } +}; - struct CUDNNBwdDataDescs { - Tensor3DDesc diff_desc, grad_desc; - Filter3DDesc filter_desc; - Conv3DDesc conv_desc; - void set(const CanonizedFilterMeta &filter, - const TensorLayout &diff, - const TensorLayout &grad, - const param::Convolution3D ¶m) - { - filter_desc.set(filter); - diff_desc.set(diff); - grad_desc.set(grad); - conv_desc.set(param, filter.group); - } - }; +struct CUDNNBwdDataDescs { + Tensor3DDesc diff_desc, grad_desc; + Filter3DDesc filter_desc; + Conv3DDesc conv_desc; + void set( + const CanonizedFilterMeta& filter, const TensorLayout& diff, + const TensorLayout& grad, const param::Convolution3D& param) { + filter_desc.set(filter); + diff_desc.set(diff); + grad_desc.set(grad); + conv_desc.set(param, filter.group); + } +}; - struct CUDNNBwdFilterDescs { - Tensor3DDesc diff_desc, src_desc; - Filter3DDesc grad_desc; - Conv3DDesc conv_desc; - void set(const TensorLayout &src, - const TensorLayout &diff, - const CanonizedFilterMeta &grad, - const param::Convolution3D ¶m) - { - src_desc.set(src); - diff_desc.set(diff); - grad_desc.set(grad); - conv_desc.set(param, grad.group); - } - }; +struct CUDNNBwdFilterDescs { + Tensor3DDesc diff_desc, src_desc; + Filter3DDesc grad_desc; + Conv3DDesc conv_desc; + void set( + const TensorLayout& src, const TensorLayout& diff, + const CanonizedFilterMeta& grad, const param::Convolution3D& param) { + src_desc.set(src); + diff_desc.set(diff); + grad_desc.set(grad); + conv_desc.set(param, grad.group); + } +}; - /*! 
- * \brief flip conv filter - * - * Flip conv filter pointed by \p raw_ptr, store result in workspace, and - * change \p raw_ptr to workspace. - */ - void flip_filter(const ForwardSizeArgs &args, - const Workspace &workspace, void *&raw_ptr); +/*! + * \brief flip conv filter + * + * Flip conv filter pointed by \p raw_ptr, store result in workspace, and + * change \p raw_ptr to workspace. + */ +void flip_filter( + const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); - inline bool cudnn_get_convolution_fwd_algo_helper( - cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnConvolutionDescriptor_t conv_desc, - const cudnnTensorDescriptor_t y_desc, - size_t workspace_limit_in_bytes, cudnnConvolutionFwdAlgo_t* algo, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { - MEGDNN_MARK_USED_VAR(positive_attr); - MEGDNN_MARK_USED_VAR(negative_attr); +inline bool cudnn_get_convolution_fwd_algo_helper( + cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, + const cudnnFilterDescriptor_t w_desc, + const cudnnConvolutionDescriptor_t conv_desc, + const cudnnTensorDescriptor_t y_desc, size_t workspace_limit_in_bytes, + cudnnConvolutionFwdAlgo_t* algo, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { + MEGDNN_MARK_USED_VAR(positive_attr); + MEGDNN_MARK_USED_VAR(negative_attr); #if CUDNN_MAJOR >= 7 - int algo_max_count = 0; - cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( - cudnn_handle, &algo_max_count)); - SmallVector algo_perf(algo_max_count); - int algo_count = 0; - cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7( - cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count, - &algo_count, algo_perf.data())); - for (int i = 0; i < algo_count; ++i) { - if (algo_perf[i].algo == - cudnnConvolutionFwdAlgo_t:: - CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) - continue; - size_t workspace_size = 0; - cudnn_check(cudnnGetConvolutionForwardWorkspaceSize( - cudnn_handle, x_desc, w_desc, conv_desc, y_desc, - algo_perf[i].algo, &workspace_size)); - if (workspace_size > workspace_limit_in_bytes) continue; - if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { + int algo_max_count = 0; + cudnn_check( + cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, &algo_max_count)); + SmallVector algo_perf(algo_max_count); + int algo_count = 0; + cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7( + cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count, + &algo_count, algo_perf.data())); + for (int i = 0; i < algo_count; ++i) { + if (algo_perf[i].algo == + cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) + continue; + size_t workspace_size = 0; + cudnn_check(cudnnGetConvolutionForwardWorkspaceSize( + cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_perf[i].algo, + &workspace_size)); + if (workspace_size > workspace_limit_in_bytes) + continue; + if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { + *algo = algo_perf[i].algo; + return true; + } else { + if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { *algo = algo_perf[i].algo; return true; - } else { - if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { - *algo = algo_perf[i].algo; - return true; - } } } - return false; + } + return false; #else - cudnn_check(cudnnGetConvolutionForwardAlgorithm( - cudnn_handle, x_desc, w_desc, conv_desc, y_desc, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_in_bytes, algo)); - return true; + 
cudnn_check(cudnnGetConvolutionForwardAlgorithm( + cudnn_handle, x_desc, w_desc, conv_desc, y_desc, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_limit_in_bytes, + algo)); + return true; #endif - } +} - inline bool cudnn_get_convolution_bwd_data_algo_helper( - cudnnHandle_t cudnn_handle, const cudnnFilterDescriptor_t w_desc, - const cudnnTensorDescriptor_t dy_desc, - const cudnnConvolutionDescriptor_t conv_desc, - const cudnnTensorDescriptor_t dx_desc, - size_t workspace_limit_in_bytes, - cudnnConvolutionBwdDataAlgo_t* algo, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { - MEGDNN_MARK_USED_VAR(positive_attr); - MEGDNN_MARK_USED_VAR(negative_attr); +inline bool cudnn_get_convolution_bwd_data_algo_helper( + cudnnHandle_t cudnn_handle, const cudnnFilterDescriptor_t w_desc, + const cudnnTensorDescriptor_t dy_desc, + const cudnnConvolutionDescriptor_t conv_desc, + const cudnnTensorDescriptor_t dx_desc, size_t workspace_limit_in_bytes, + cudnnConvolutionBwdDataAlgo_t* algo, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { + MEGDNN_MARK_USED_VAR(positive_attr); + MEGDNN_MARK_USED_VAR(negative_attr); #if CUDNN_MAJOR >= 7 - int algo_max_count = 0; - cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( - cudnn_handle, &algo_max_count)); - SmallVector algo_perf( - algo_max_count); - int algo_count = 0; - cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7( - cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, - algo_max_count, &algo_count, algo_perf.data())); - for (int i = 0; i < algo_count; ++i) { - if (algo_perf[i].algo == - cudnnConvolutionBwdDataAlgo_t:: - CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING) - continue; - size_t workspace_size = 0; - cudnn_check(cudnnGetConvolutionBackwardDataWorkspaceSize( - cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, - algo_perf[i].algo, &workspace_size)); - if (workspace_size > workspace_limit_in_bytes) continue; - if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { + int algo_max_count = 0; + cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( + cudnn_handle, &algo_max_count)); + SmallVector algo_perf(algo_max_count); + int algo_count = 0; + cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7( + cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, algo_max_count, + &algo_count, algo_perf.data())); + for (int i = 0; i < algo_count; ++i) { + if (algo_perf[i].algo == + cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING) + continue; + size_t workspace_size = 0; + cudnn_check(cudnnGetConvolutionBackwardDataWorkspaceSize( + cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, algo_perf[i].algo, + &workspace_size)); + if (workspace_size > workspace_limit_in_bytes) + continue; + if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { + *algo = algo_perf[i].algo; + return true; + } else { + if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { *algo = algo_perf[i].algo; return true; - } else { - if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { - *algo = algo_perf[i].algo; - return true; - } } } - return false; + } + return false; #else - cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle, - w_desc, dy_desc, conv_desc, dx_desc, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_in_bytes, - algo)); - return true; + cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm( + cudnn_handle, w_desc, dy_desc, conv_desc, dx_desc, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_limit_in_bytes, algo)); + 
return true; #endif - } +} - inline bool cudnn_get_convolution_bwd_filter_algo_helper( - cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, - const cudnnTensorDescriptor_t dy_desc, - const cudnnConvolutionDescriptor_t conv_desc, - const cudnnFilterDescriptor_t dw_desc, - size_t workspace_limit_in_bytes, - cudnnConvolutionBwdFilterAlgo_t* algo, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { - MEGDNN_MARK_USED_VAR(positive_attr); - MEGDNN_MARK_USED_VAR(negative_attr); +inline bool cudnn_get_convolution_bwd_filter_algo_helper( + cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, + const cudnnTensorDescriptor_t dy_desc, + const cudnnConvolutionDescriptor_t conv_desc, + const cudnnFilterDescriptor_t dw_desc, size_t workspace_limit_in_bytes, + cudnnConvolutionBwdFilterAlgo_t* algo, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { + MEGDNN_MARK_USED_VAR(positive_attr); + MEGDNN_MARK_USED_VAR(negative_attr); #if CUDNN_MAJOR >= 7 - int algo_max_count = 0; - cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( - cudnn_handle, &algo_max_count)); - SmallVector algo_perf( - algo_max_count); - int algo_count = 0; - cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7( - cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, - algo_max_count, &algo_count, algo_perf.data())); - for (int i = 0; i < algo_count; ++i) { - if (algo_perf[i].algo == - cudnnConvolutionBwdFilterAlgo_t:: - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) - continue; - size_t workspace_size = 0; - cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize( - cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, - algo_perf[i].algo, &workspace_size)); - if (workspace_size > workspace_limit_in_bytes) continue; - if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { + int algo_max_count = 0; + cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( + cudnn_handle, &algo_max_count)); + SmallVector algo_perf(algo_max_count); + int algo_count = 0; + cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7( + cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, algo_max_count, + &algo_count, algo_perf.data())); + for (int i = 0; i < algo_count; ++i) { + if (algo_perf[i].algo == cudnnConvolutionBwdFilterAlgo_t:: + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING) + continue; + size_t workspace_size = 0; + cudnn_check(cudnnGetConvolutionBackwardFilterWorkspaceSize( + cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, algo_perf[i].algo, + &workspace_size)); + if (workspace_size > workspace_limit_in_bytes) + continue; + if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { + *algo = algo_perf[i].algo; + return true; + } else { + if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { *algo = algo_perf[i].algo; return true; - } else { - if (algo_perf[i].determinism == CUDNN_DETERMINISTIC) { - *algo = algo_perf[i].algo; - return true; - } } } - return false; + } + return false; #else - cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm( - cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_in_bytes, algo)); - return true; + cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm( + cudnn_handle, x_desc, dy_desc, conv_desc, dw_desc, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_limit_in_bytes, algo)); + return true; #endif - } +} -} // namespace convolution3d -} // namespace cuda -} // namespace megdnn +} // namespace convolution3d +} // namespace cuda +} 
// namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convolution3d/opr_impl.cpp b/dnn/src/cuda/convolution3d/opr_impl.cpp index 1a5abf37..b64a0454 100644 --- a/dnn/src/cuda/convolution3d/opr_impl.cpp +++ b/dnn/src/cuda/convolution3d/opr_impl.cpp @@ -23,17 +23,15 @@ using namespace cuda; using namespace convolution3d; #define TO_STRING2(v) #v -#define TO_STRING(v) TO_STRING2(v) +#define TO_STRING(v) TO_STRING2(v) #define CUDNN_VERSION_STR \ TO_STRING(CUDNN_MAJOR) \ "." TO_STRING(CUDNN_MINOR) "." TO_STRING(CUDNN_PATCHLEVEL) /* ============== Convolution3DForwardImpl ============== */ -Convolution3DForwardImpl::Algorithm* -Convolution3DForwardImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, +Convolution3DForwardImpl::Algorithm* Convolution3DForwardImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, src, filter, dst); @@ -43,8 +41,7 @@ Convolution3DForwardImpl::get_algorithm_heuristic( // version is lower than v7.5.0 is still slower than our implementation // in many channel-wise cases if (sm_algo_pack.chanwise.is_available_attribute( - args, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.chanwise; } } @@ -61,9 +58,8 @@ Convolution3DForwardImpl::get_algorithm_heuristic( args, positive_attr, negative_attr, workspace_limit_in_bytes); }; - auto get_cudnn_algo = - [this, &args, workspace_limit_in_bytes, positive_attr, - negative_attr]() -> Convolution3DForwardImpl::AlgoBase* { + auto get_cudnn_algo = [this, &args, workspace_limit_in_bytes, positive_attr, + negative_attr]() -> Convolution3DForwardImpl::AlgoBase* { auto cudnn_handle = cuda::cudnn_handle(this->handle()); cudnnConvolutionFwdAlgo_t algo; CUDNNForwardDescs desc; @@ -71,8 +67,8 @@ Convolution3DForwardImpl::get_algorithm_heuristic( bool got = cudnn_get_convolution_fwd_algo_helper( cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, - desc.conv_desc.desc, desc.dst_desc.desc, - workspace_limit_in_bytes, &algo, positive_attr, negative_attr); + desc.conv_desc.desc, desc.dst_desc.desc, workspace_limit_in_bytes, + &algo, positive_attr, negative_attr); if (got) { return static_cast( megdnn::get_algo_match_attribute( @@ -101,32 +97,30 @@ Convolution3DForwardImpl::get_algorithm_heuristic( "cuda conv3d fwd", positive_attr, negative_attr); } -std::vector -Convolution3DForwardImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +std::vector Convolution3DForwardImpl:: + get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) { return megdnn::get_all_algorithms( {this, src, filter, dst}); } -std::vector -Convolution3DForwardImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +std::vector Convolution3DForwardImpl:: + get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) { return megdnn::get_all_algorithms_safe( {this, src, filter, dst}); } size_t Convolution3DForwardImpl::get_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst) { + const 
TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { return get_dnn_workspace(this, src, filter, dst); } -void Convolution3DForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void Convolution3DForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(src.layout, filter.layout, dst.layout, workspace.size); AlgoBase::ExecArgs args(this, src, filter, dst, workspace); auto algo = get_algorithm(this, src.layout, filter.layout, dst.layout); @@ -137,38 +131,37 @@ const char* Convolution3DForwardImpl::get_algorithm_set_name() const { return "CUDACONV0+CUDNN" CUDNN_VERSION_STR; } -void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void Convolution3DBackwardDataImpl::exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(filter.layout, diff.layout, grad.layout, workspace.size); AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); auto algo = get_algorithm(this, filter.layout, diff.layout, grad.layout); algo->exec(args); } -std::vector -Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector Convolution3DBackwardDataImpl:: + get_all_algorithms( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { return megdnn::get_all_algorithms( {this, filter, diff, grad}); } -std::vector -Convolution3DBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector Convolution3DBackwardDataImpl:: + get_all_algorithms_safe( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { return megdnn::get_all_algorithms_safe( {this, filter, diff, grad}); } -Convolution3DBackwardDataImpl::Algorithm* -Convolution3DBackwardDataImpl::get_algorithm_heuristic( - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +Convolution3DBackwardDataImpl::Algorithm* Convolution3DBackwardDataImpl:: + get_algorithm_heuristic( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, filter, diff, grad); if (args.filter_meta.group > 1 && @@ -186,13 +179,13 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( args.init_desc(desc); bool got = cudnn_get_convolution_bwd_data_algo_helper( cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, - desc.conv_desc.desc, desc.grad_desc.desc, - workspace_limit_in_bytes, &algo, positive_attr, negative_attr); + desc.conv_desc.desc, desc.grad_desc.desc, workspace_limit_in_bytes, + &algo, positive_attr, negative_attr); if (got) { - return static_cast(megdnn::get_algo_match_attribute< - Convolution3DBackwardDataImpl>( - sm_algo_pack.cudnn_from_enum(algo), positive_attr, - negative_attr)); + return static_cast( + megdnn::get_algo_match_attribute( + sm_algo_pack.cudnn_from_enum(algo), positive_attr, + negative_attr)); } else { return nullptr; } @@ -224,39 +217,35 @@ const char* 
Convolution3DBackwardDataImpl::get_algorithm_set_name() const { return "CUDACONV0+CUDNN" CUDNN_VERSION_STR; } -void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void Convolution3DBackwardFilterImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(src.layout, diff.layout, grad.layout, workspace.size); AlgoBase::ExecArgs args(this, src, diff, grad, workspace); - auto algo = - get_algorithm(this, src.layout, diff.layout, grad.layout); + auto algo = get_algorithm(this, src.layout, diff.layout, grad.layout); algo->exec(args); } std::vector -Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) { +Convolution3DBackwardFilterImpl::get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { return megdnn::get_all_algorithms( {this, src, diff, grad}); } std::vector -Convolution3DBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) { +Convolution3DBackwardFilterImpl::get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { return megdnn::get_all_algorithms_safe( {this, src, diff, grad}); } -Convolution3DBackwardFilterImpl::Algorithm* -Convolution3DBackwardFilterImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +Convolution3DBackwardFilterImpl::Algorithm* Convolution3DBackwardFilterImpl:: + get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, src, diff, grad); if (args.grad_filter_meta.group > 1 && @@ -274,13 +263,13 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( args.init_desc(desc); bool got = cudnn_get_convolution_bwd_filter_algo_helper( cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, - desc.conv_desc.desc, desc.grad_desc.desc, - workspace_limit_in_bytes, &algo, positive_attr, negative_attr); + desc.conv_desc.desc, desc.grad_desc.desc, workspace_limit_in_bytes, + &algo, positive_attr, negative_attr); if (got) { - return static_cast(megdnn::get_algo_match_attribute< - Convolution3DBackwardFilterImpl>( - sm_algo_pack.cudnn_from_enum(algo), positive_attr, - negative_attr)); + return static_cast( + megdnn::get_algo_match_attribute( + sm_algo_pack.cudnn_from_enum(algo), positive_attr, + negative_attr)); } else { return nullptr; } @@ -303,9 +292,8 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( } size_t Convolution3DBackwardFilterImpl::get_workspace_in_bytes( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad) { - return get_dnn_workspace(this, src, diff , grad); + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { + return get_dnn_workspace(this, src, diff, grad); } const char* Convolution3DBackwardFilterImpl::get_algorithm_set_name() const { diff --git a/dnn/src/cuda/convolution3d/opr_impl.h b/dnn/src/cuda/convolution3d/opr_impl.h index 5b208b4b..530cade8 100644 --- a/dnn/src/cuda/convolution3d/opr_impl.h +++ 
b/dnn/src/cuda/convolution3d/opr_impl.h @@ -19,11 +19,12 @@ namespace cuda { class Convolution3DForwardImpl : public Convolution3DForward { public: using Convolution3DForward::Convolution3DForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; const char* get_algorithm_set_name() const override; class AlgoBase; class AlgoCUDNN; @@ -55,11 +56,12 @@ private: class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { public: using Convolution3DBackwardData::Convolution3DBackwardData; - void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) override; + void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -93,11 +95,12 @@ private: class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { public: using Convolution3DBackwardFilter::Convolution3DBackwardFilter; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -119,9 +122,8 @@ protected: const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) override; private: diff --git a/dnn/src/cuda/convolution_helper/activation.cuh b/dnn/src/cuda/convolution_helper/activation.cuh index 17bb8ef1..51c84035 100644 --- a/dnn/src/cuda/convolution_helper/activation.cuh +++ b/dnn/src/cuda/convolution_helper/activation.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
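The three cudnn_get_convolution_*_algo_helper functions reformatted earlier share one policy: with cuDNN >= 7 they enumerate candidates through the *_v7 query, skip the FFT_TILING algorithm, drop anything whose workspace exceeds the caller's limit, and, when the REPRODUCIBLE attribute is requested, additionally require CUDNN_DETERMINISTIC; on older cuDNN they fall back to the single SPECIFY_WORKSPACE_LIMIT query. A rough caller-side sketch for the forward case, assuming the layouts and attributes are already at hand; src_layout, filter_meta, dst_layout, conv_param and the 64 MiB limit are illustrative values, not from the patch:

    CUDNNForwardDescs desc;
    desc.set(src_layout, filter_meta, dst_layout, conv_param);  // fills the four cuDNN descriptors
    cudnnConvolutionFwdAlgo_t algo;
    size_t limit = 64 << 20;  // hypothetical workspace budget in bytes
    bool got = cudnn_get_convolution_fwd_algo_helper(
            cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, desc.conv_desc.desc,
            desc.dst_desc.desc, limit, &algo, positive_attr, negative_attr);
    if (!got) {
        // no cuDNN algorithm fits the limit/attributes; the heuristic falls back elsewhere
    }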
* **************************************************************************************************/ /** @@ -42,10 +43,9 @@ namespace convolution { template struct Activation; -#define DEF_APPLY_AND_TRANSFORM(_act) \ - __device__ __forceinline__ int apply_and_transform(float4 in) { \ - return transform_float4_to_int8x4( \ - quantize(_act::apply(dequantize(in)))); \ +#define DEF_APPLY_AND_TRANSFORM(_act) \ + __device__ __forceinline__ int apply_and_transform(float4 in) { \ + return transform_float4_to_int8x4(quantize(_act::apply(dequantize(in)))); \ } template <> @@ -57,12 +57,8 @@ struct Activation { #if MEGDNN_CC_CUDA DEF_APPLY_AND_TRANSFORM( Activation); - __device__ __forceinline__ float4 dequantize(float4 in) { - return scale * in; - } - __device__ __forceinline__ float4 quantize(float4 in) { - return inv_scale * in; - } + __device__ __forceinline__ float4 dequantize(float4 in) { return scale * in; } + __device__ __forceinline__ float4 quantize(float4 in) { return inv_scale * in; } __device__ __forceinline__ static float4 apply(float4 in) { float x = in.x * fminf(fmaxf(in.x + 3.f, 0.f), 6.f) * (1.f / 6.f); float y = in.y * fminf(fmaxf(in.y + 3.f, 0.f), 6.f) * (1.f / 6.f); @@ -75,8 +71,7 @@ struct Activation { template <> struct Activation { - MEGDNN_HOST MEGDNN_DEVICE Activation(float /* scale */, - float /* inv_scale */) {} + MEGDNN_HOST MEGDNN_DEVICE Activation(float /* scale */, float /* inv_scale */) {} #if MEGDNN_CC_CUDA DEF_APPLY_AND_TRANSFORM( Activation); @@ -94,8 +89,7 @@ struct Activation { template <> struct Activation { - MEGDNN_HOST MEGDNN_DEVICE Activation(float /* scale */, - float /* inv_scale */) {} + MEGDNN_HOST MEGDNN_DEVICE Activation(float /* scale */, float /* inv_scale */) {} #if MEGDNN_CC_CUDA DEF_APPLY_AND_TRANSFORM( Activation); diff --git a/dnn/src/cuda/convolution_helper/bias_visitor.cuh b/dnn/src/cuda/convolution_helper/bias_visitor.cuh index 4b7956bb..9272ef33 100644 --- a/dnn/src/cuda/convolution_helper/bias_visitor.cuh +++ b/dnn/src/cuda/convolution_helper/bias_visitor.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. 
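The Activation<H_SWISH> specialization touched above applies hard-swish in the dequantized domain and re-quantizes before DEF_APPLY_AND_TRANSFORM packs the result back to int8x4: apply() computes x * relu6(x + 3) / 6 per lane, bracketed by scale * in and inv_scale * in. A scalar restatement of the same transform for cross-checking one lane; the function name is hypothetical and the int8 saturation done by transform_float4_to_int8x4 is left out:

    #include <math.h>
    // one quantized lane through dequantize -> hard-swish -> re-quantize
    float hswish_requant(float q, float scale, float inv_scale) {
        float x = scale * q;                                            // dequantize
        float y = x * fminf(fmaxf(x + 3.f, 0.f), 6.f) * (1.f / 6.f);    // x * relu6(x + 3) / 6
        return inv_scale * y;                                           // back to the quantized scale
    }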
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -42,24 +43,24 @@ namespace convolution { struct PerChannelBiasVisitor { const int32_t* __restrict__ bias; #if MEGDNN_CC_CUDA - __host__ __device__ __forceinline__ void move(int, int ch, int, int) { - bias += ch; - } + __host__ __device__ __forceinline__ void move(int, int ch, int, int) { bias += ch; } __host__ __device__ __forceinline__ float4 at(int, int ch, int, int) { int ix = *(bias + ch); int iy = *(bias + ch + 1); int iz = *(bias + ch + 2); int iw = *(bias + ch + 3); - return ::make_float4(static_cast(ix), static_cast(iy), - static_cast(iz), static_cast(iw)); + return ::make_float4( + static_cast(ix), static_cast(iy), static_cast(iz), + static_cast(iw)); } __host__ __device__ __forceinline__ float4 at(int, int ch, int) { int ix = *(bias + ch); int iy = *(bias + ch + 1); int iz = *(bias + ch + 2); int iw = *(bias + ch + 3); - return ::make_float4(static_cast(ix), static_cast(iy), - static_cast(iz), static_cast(iw)); + return ::make_float4( + static_cast(ix), static_cast(iy), static_cast(iz), + static_cast(iw)); } #endif }; diff --git a/dnn/src/cuda/convolution_helper/block_tile_consumer/block_consumer.cuh b/dnn/src/cuda/convolution_helper/block_tile_consumer/block_consumer.cuh index 778b81b5..0a9d1d07 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_consumer/block_consumer.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_consumer/block_consumer.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
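PerChannelBiasVisitor above keeps a raw int32 bias pointer indexed by output channel; move() advances it by the channel offset and at() widens four consecutive channels into a float4, which suggests a pack-of-four channel layout (NCHW4-style) with bias still in the int32 accumulator domain. A scalar restatement of at() for one pack, with illustrative names, assuming a plain int32_t bias array:

    #include <cuda_runtime.h>
    #include <cstdint>
    // what at(batch, ch, ...) effectively returns for the pack starting at channel ch
    __host__ __device__ inline float4 bias_pack(const int32_t* bias, int ch) {
        return ::make_float4(
                float(bias[ch]), float(bias[ch + 1]),
                float(bias[ch + 2]), float(bias[ch + 3]));
    }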
* - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** diff --git a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer.cuh b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer.cuh index e67f4ac3..9aea6a96 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -60,8 +61,8 @@ struct IConvBlockConsumer { } } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -92,35 +93,32 @@ struct IConvBlockConsumer { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { smem_storage_dtype* ker_sh_ptr = filter_gl2sh_visitor.sh_ptr( - 0, tidy * RegBlockConfig::pack_size + - j * ThreadConfig::nr_thread_y * - RegBlockConfig::pack_size); + 0, + tidy * RegBlockConfig::pack_size + + j * ThreadConfig::nr_thread_y * RegBlockConfig::pack_size); #pragma unroll for (int pack = 0; pack < RegBlockConfig::pack_size; ++pack) { - reg_filter[j * RegBlockConfig::pack_size + pack][0] = - *(ker_sh_ptr++); + reg_filter[j * RegBlockConfig::pack_size + pack][0] = *(ker_sh_ptr++); } } #pragma unroll - for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; - ++ci_inner) { + for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; ++ci_inner) { const int comp_idx = (ci_inner & 0x1); const int load_idx = 1 - comp_idx; #pragma unroll for (int i = 0; i < RegBlockConfig::reg_n; ++i) { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m; ++j) { - dot_prod(reg_src[i][comp_idx], reg_filter[j][comp_idx], - reg_acc[i][j], reg_acc[i][j]); + dot_prod( + reg_src[i][comp_idx], reg_filter[j][comp_idx], + reg_acc[i][j], reg_acc[i][j]); } } if (ci_inner < RegBlockConfig::reg_k_packed - 1) { - int32_t* data_sh_ptr = - data_gl2sh_visitor.sh_ptr(ci_inner + 1, 0); - int32_t* ker_sh_ptr = - filter_gl2sh_visitor.sh_ptr(ci_inner + 1, 0); + int32_t* data_sh_ptr = data_gl2sh_visitor.sh_ptr(ci_inner + 1, 0); + int32_t* ker_sh_ptr = filter_gl2sh_visitor.sh_ptr(ci_inner + 1, 0); if (use_wide_store) { #pragma unroll @@ -128,31 +126,28 @@ struct IConvBlockConsumer { int i2 = (i << 1); int tidx2 = (tidx << 1); reg_src[i2][load_idx] = - data_sh_ptr[tidx2 + - i2 * ThreadConfig::nr_thread_x]; + data_sh_ptr[tidx2 + i2 * ThreadConfig::nr_thread_x]; reg_src[i2 + 1][load_idx] = - data_sh_ptr[tidx2 + - i2 * ThreadConfig::nr_thread_x + 1]; + data_sh_ptr[tidx2 + i2 * ThreadConfig::nr_thread_x + 1]; } } else { #pragma unroll for (int i = 0; i < RegBlockConfig::reg_n; ++i) { reg_src[i][load_idx] = - data_sh_ptr[tidx + - i * ThreadConfig::nr_thread_x]; + data_sh_ptr[tidx + i * ThreadConfig::nr_thread_x]; } } #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { smem_storage_dtype* ker_sh_ptr_packed = - &ker_sh_ptr[(tidy + j * ThreadConfig::nr_thread_y) * - RegBlockConfig::pack_size]; + &ker_sh_ptr + [(tidy + j * ThreadConfig::nr_thread_y) * + RegBlockConfig::pack_size]; #pragma unroll - for (int pack = 0; pack < RegBlockConfig::pack_size; - ++pack) { - reg_filter[j * RegBlockConfig::pack_size + pack] - [load_idx] = *(ker_sh_ptr_packed++); + for (int pack = 0; pack < 
RegBlockConfig::pack_size; ++pack) { + reg_filter[j * RegBlockConfig::pack_size + pack][load_idx] = + *(ker_sh_ptr_packed++); } } } @@ -179,8 +174,8 @@ struct IConvBlockConsumer { } } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -191,35 +186,30 @@ struct IConvBlockConsumer { static bool constexpr use_wide_store = !(RegBlockConfig::reg_n & 0x1); #pragma unroll - for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; - ++ci_inner) { - smem_storage_dtype* data_sh_ptr = - data_gl2sh_visitor.sh_ptr(ci_inner, 0); - smem_storage_dtype* ker_sh_ptr = - filter_gl2sh_visitor.sh_ptr(ci_inner, 0); + for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; ++ci_inner) { + smem_storage_dtype* data_sh_ptr = data_gl2sh_visitor.sh_ptr(ci_inner, 0); + smem_storage_dtype* ker_sh_ptr = filter_gl2sh_visitor.sh_ptr(ci_inner, 0); if (use_wide_store) { #pragma unroll for (int i = 0; i < (RegBlockConfig::reg_n >> 1); ++i) { int i2 = (i << 1); int tidx2 = (tidx << 1); - reg_src[i2] = - data_sh_ptr[tidx2 + i2 * ThreadConfig::nr_thread_x]; + reg_src[i2] = data_sh_ptr[tidx2 + i2 * ThreadConfig::nr_thread_x]; reg_src[i2 + 1] = - data_sh_ptr[tidx2 + i2 * ThreadConfig::nr_thread_x + - 1]; + data_sh_ptr[tidx2 + i2 * ThreadConfig::nr_thread_x + 1]; } } else { #pragma unroll for (int i = 0; i < RegBlockConfig::reg_n; ++i) { - reg_src[i] = - data_sh_ptr[tidx + i * ThreadConfig::nr_thread_x]; + reg_src[i] = data_sh_ptr[tidx + i * ThreadConfig::nr_thread_x]; } } #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { smem_storage_dtype* ker_sh_ptr_packed = - &ker_sh_ptr[(tidy + j * ThreadConfig::nr_thread_y) * - RegBlockConfig::pack_size]; + &ker_sh_ptr + [(tidy + j * ThreadConfig::nr_thread_y) * + RegBlockConfig::pack_size]; #pragma unroll for (int pack = 0; pack < RegBlockConfig::pack_size; ++pack) { reg_filter[j * RegBlockConfig::pack_size + pack] = @@ -230,8 +220,7 @@ struct IConvBlockConsumer { for (int i = 0; i < RegBlockConfig::reg_n; ++i) { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m; ++j) { - dot_prod(reg_src[i], reg_filter[j], reg_acc[i][j], - reg_acc[i][j]); + dot_prod(reg_src[i], reg_filter[j], reg_acc[i][j], reg_acc[i][j]); } } } diff --git a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_coxhw.cuh b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_coxhw.cuh index e8f80fb6..5f753bbb 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_coxhw.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_coxhw.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. 
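The first IConvBlockConsumer::consume_block overload above is software-pipelined over the packed k dimension: the register tiles are double-buffered, comp_idx = ci_inner & 1 names the half being multiplied while load_idx = 1 - comp_idx receives the next k-slice from shared memory, with buffer 0 preloaded before the loop; the second overload is the plain, non-pipelined form. A toy, self-contained skeleton of that ping-pong schedule; none of these names are the patch's, and the additions stand in for dot_prod and the shared-memory reads:

    // compute on one register buffer while the other is refilled for the next step
    int pipelined_reduce(const int* slices, int K) {
        int reg[2] = {slices[0], 0};        // preload buffer 0, as the code above does
        int acc = 0;
        for (int k = 0; k < K; ++k) {
            const int comp = k & 1, load = 1 - comp;
            acc += reg[comp];               // stand-in for dot_prod on the resident slice
            if (k < K - 1)
                reg[load] = slices[k + 1];  // stand-in for the shared-memory refill
        }
        return acc;
    }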
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_coxhw.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_coxhw.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
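dot_prod, whose call sites are only reformatted in this patch, takes two int32 words that each pack four int8 lanes plus an int32 accumulator; on sm_61+ this kind of 4-way dot-product-accumulate typically maps to a single __dp4a. A hedged stand-in, not the project's actual dot_prod, with a portable fallback for older architectures:

    #include <cstdint>
    __device__ __forceinline__ void dot_prod_sketch(int a, int b, int c, int& d) {
    #if __CUDA_ARCH__ >= 610
        d = __dp4a(a, b, c);                              // hardware int8x4 dot product + accumulate
    #else
        int acc = c;
    #pragma unroll
        for (int i = 0; i < 4; ++i) {
            int8_t ai = int8_t((a >> (8 * i)) & 0xff);    // extract signed lane i
            int8_t bi = int8_t((b >> (8 * i)) & 0xff);
            acc += int(ai) * int(bi);
        }
        d = acc;
    #endif
    }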
@@ -60,8 +62,8 @@ struct IConvBlockConsumer_COxHW { } } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -90,8 +92,9 @@ struct IConvBlockConsumer_COxHW { } #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { - int out_channel = ((tidy + j * ThreadConfig::nr_thread_y) - << RegBlockConfig::pack_size_bit); + int out_channel = + ((tidy + j * ThreadConfig::nr_thread_y) + << RegBlockConfig::pack_size_bit); #pragma unroll for (int packed = 0; packed < RegBlockConfig::pack_size; ++packed) { reg_filter[j * RegBlockConfig::pack_size + packed][0] = @@ -100,20 +103,17 @@ struct IConvBlockConsumer_COxHW { } #pragma unroll - for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; - ++ci_inner) { + for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; ++ci_inner) { const int comp_idx = (ci_inner & 0x1); const int load_idx = 1 - comp_idx; if (ci_inner < RegBlockConfig::reg_k_packed - 1) { - if (use_wide_store) { #pragma unroll for (int i = 0; i < (RegBlockConfig::reg_width >> 1); ++i) { int i2 = (i << 1); int tidx2 = (tidx << 1); reg_src[i2][load_idx] = *(data_gl2sh_visitor.sh_ptr( - ci_inner + 1, - tidx2 + i2 * ThreadConfig::nr_thread_x)); + ci_inner + 1, tidx2 + i2 * ThreadConfig::nr_thread_x)); reg_src[i2 + 1][load_idx] = *(data_gl2sh_visitor.sh_ptr( ci_inner + 1, tidx2 + i2 * ThreadConfig::nr_thread_x + 1)); @@ -122,20 +122,19 @@ struct IConvBlockConsumer_COxHW { #pragma unroll for (int i = 0; i < RegBlockConfig::reg_width; ++i) { reg_src[i][load_idx] = *(data_gl2sh_visitor.sh_ptr( - ci_inner + 1, - tidx + i * ThreadConfig::nr_thread_x)); + ci_inner + 1, tidx + i * ThreadConfig::nr_thread_x)); } } #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { - int out_channel = ((tidy + j * ThreadConfig::nr_thread_y) - << RegBlockConfig::pack_size_bit); + int out_channel = + ((tidy + j * ThreadConfig::nr_thread_y) + << RegBlockConfig::pack_size_bit); #pragma unroll - for (int packed = 0; packed < RegBlockConfig::pack_size; - ++packed) { - reg_filter[j * RegBlockConfig::pack_size + packed] - [load_idx] = *(filter_gl2sh_visitor.sh_ptr( - out_channel + packed, ci_inner + 1)); + for (int packed = 0; packed < RegBlockConfig::pack_size; ++packed) { + reg_filter[j * RegBlockConfig::pack_size + packed][load_idx] = + *(filter_gl2sh_visitor.sh_ptr( + out_channel + packed, ci_inner + 1)); } } } @@ -172,8 +171,9 @@ struct IConvBlockConsumer_COxHW { // %d, %d\n", x, y, z, w); // } // } - dot_prod(reg_src[i][comp_idx], reg_filter[j][comp_idx], - reg_acc[i][j], reg_acc[i][j]); + dot_prod( + reg_src[i][comp_idx], reg_filter[j][comp_idx], + reg_acc[i][j], reg_acc[i][j]); } } } @@ -199,8 +199,8 @@ struct IConvBlockConsumer_COxHW { } } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -211,8 +211,7 @@ struct IConvBlockConsumer_COxHW { static bool const use_wide_store = !(RegBlockConfig::reg_width & 0x1); #pragma unroll - for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; - ++ci_inner) { + for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; ++ci_inner) { if (use_wide_store) { #pragma unroll for (int i = 0; i < (RegBlockConfig::reg_width >> 1); ++i) { 
@@ -221,8 +220,7 @@ struct IConvBlockConsumer_COxHW { reg_src[i2] = *(data_gl2sh_visitor.sh_ptr( ci_inner, tidx2 + i2 * ThreadConfig::nr_thread_x)); reg_src[i2 + 1] = *(data_gl2sh_visitor.sh_ptr( - ci_inner, - tidx2 + i2 * ThreadConfig::nr_thread_x + 1)); + ci_inner, tidx2 + i2 * ThreadConfig::nr_thread_x + 1)); } } else { #pragma unroll @@ -233,23 +231,21 @@ struct IConvBlockConsumer_COxHW { } #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { - int out_channel = ((tidy + j * ThreadConfig::nr_thread_y) - << RegBlockConfig::pack_size_bit); + int out_channel = + ((tidy + j * ThreadConfig::nr_thread_y) + << RegBlockConfig::pack_size_bit); #pragma unroll - for (int packed = 0; packed < RegBlockConfig::pack_size; - ++packed) { + for (int packed = 0; packed < RegBlockConfig::pack_size; ++packed) { reg_filter[j * RegBlockConfig::pack_size + packed] = - *(filter_gl2sh_visitor.sh_ptr(out_channel + - packed, - ci_inner)); + *(filter_gl2sh_visitor.sh_ptr( + out_channel + packed, ci_inner)); } } #pragma unroll for (int i = 0; i < RegBlockConfig::reg_width; ++i) { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m; ++j) { - dot_prod(reg_src[i], reg_filter[j], reg_acc[i][j], - reg_acc[i][j]); + dot_prod(reg_src[i], reg_filter[j], reg_acc[i][j], reg_acc[i][j]); } } } diff --git a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_unroll_width.cuh b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_unroll_width.cuh index f0f154b3..1c799d47 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_unroll_width.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_unroll_width.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. 
* - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_unroll_width.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_block_consumer_unroll_width.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -64,8 +66,8 @@ struct IConvBlockConsumerUnrollWidth { } } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -101,24 +103,21 @@ struct IConvBlockConsumerUnrollWidth { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { int32_t* ker_sh_ptr = filter_gl2sh_visitor.sh_ptr( - 0, tidy * RegBlockConfig::pack_size + - j * ThreadConfig::nr_thread_y * - RegBlockConfig::pack_size); + 0, + tidy * RegBlockConfig::pack_size + + j * ThreadConfig::nr_thread_y * RegBlockConfig::pack_size); #pragma unroll for (int packed = 0; packed < RegBlockConfig::pack_size; ++packed) { - reg_filter[j * RegBlockConfig::pack_size + packed][0] = - *(ker_sh_ptr++); + reg_filter[j * RegBlockConfig::pack_size + packed][0] = *(ker_sh_ptr++); } } #pragma unroll - for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; - ++ci_inner) { + for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; ++ci_inner) { const int comp_idx = (ci_inner & 0x1); const int load_idx = 1 - comp_idx; if (ci_inner < RegBlockConfig::reg_k_packed - 1) { - int32_t* ker_sh_ptr = - filter_gl2sh_visitor.sh_ptr(ci_inner + 1, 0); + int32_t* ker_sh_ptr = filter_gl2sh_visitor.sh_ptr(ci_inner + 1, 0); if (use_wide_store) { #pragma unroll @@ -127,18 +126,12 @@ struct IConvBlockConsumerUnrollWidth { for (int j = 0; j < RegBlockConfig::reg_width; ++j) { int i2 = (i << 1); int tidx2 = (tidx << 1); - reg_src[i2][j] - [load_idx] = *(data_gl2sh_visitor.sh_ptr( - ci_inner + 1, j, - tidx2 + i2 * ThreadConfig:: - nr_thread_x)); - reg_src[i2 + 1][j] - [load_idx] = *(data_gl2sh_visitor.sh_ptr( - ci_inner + 1, j, - tidx2 + - i2 * ThreadConfig:: - nr_thread_x + - 1)); + reg_src[i2][j][load_idx] = *(data_gl2sh_visitor.sh_ptr( + ci_inner + 1, j, + tidx2 + i2 * ThreadConfig::nr_thread_x)); + reg_src[i2 + 1][j][load_idx] = *(data_gl2sh_visitor.sh_ptr( + ci_inner + 1, j, + tidx2 + i2 * ThreadConfig::nr_thread_x + 1)); } } } else { @@ -146,24 +139,22 @@ struct IConvBlockConsumerUnrollWidth { for (int i = 0; i < RegBlockConfig::reg_n; ++i) { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_width; ++j) { - reg_src[i][j] - [load_idx] = *(data_gl2sh_visitor.sh_ptr( - ci_inner + 1, j, - tidx + i * ThreadConfig:: - nr_thread_x)); + reg_src[i][j][load_idx] = *(data_gl2sh_visitor.sh_ptr( + ci_inner + 1, j, + tidx + i * ThreadConfig::nr_thread_x)); } } } #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { int32_t* ker_sh_ptr_packed = - &ker_sh_ptr[(tidy + j * ThreadConfig::nr_thread_y) * - RegBlockConfig::pack_size]; + &ker_sh_ptr + [(tidy + j * ThreadConfig::nr_thread_y) * + RegBlockConfig::pack_size]; #pragma unroll - for (int packed = 0; packed < RegBlockConfig::pack_size; - ++packed) { - reg_filter[j * RegBlockConfig::pack_size + packed] - [load_idx] = *(ker_sh_ptr_packed++); + for (int packed = 0; packed < RegBlockConfig::pack_size; ++packed) { + reg_filter[j * RegBlockConfig::pack_size + packed][load_idx] = + *(ker_sh_ptr_packed++); } } } @@ -173,9 +164,9 @@ struct IConvBlockConsumerUnrollWidth { for (int j = 0; j < RegBlockConfig::reg_width; ++j) { #pragma unroll for (int k = 0; k < RegBlockConfig::reg_m; ++k) { - dot_prod(reg_src[i][j][comp_idx], - reg_filter[k][comp_idx], reg_acc[i][j][k], - reg_acc[i][j][k]); + dot_prod( + reg_src[i][j][comp_idx], reg_filter[k][comp_idx], + reg_acc[i][j][k], reg_acc[i][j][k]); } } } @@ 
-206,8 +197,8 @@ struct IConvBlockConsumerUnrollWidth { } } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -218,8 +209,7 @@ struct IConvBlockConsumerUnrollWidth { static bool const use_wide_store = !(RegBlockConfig::reg_n & 0x1); #pragma unroll - for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; - ++ci_inner) { + for (int ci_inner = 0; ci_inner < RegBlockConfig::reg_k_packed; ++ci_inner) { int32_t* ker_sh_ptr = filter_gl2sh_visitor.sh_ptr(ci_inner, 0); if (use_wide_store) { @@ -230,8 +220,7 @@ struct IConvBlockConsumerUnrollWidth { int i2 = (i << 1); int tidx2 = (tidx << 1); reg_src[i2][j] = *(data_gl2sh_visitor.sh_ptr( - ci_inner, j, - tidx2 + i2 * ThreadConfig::nr_thread_x)); + ci_inner, j, tidx2 + i2 * ThreadConfig::nr_thread_x)); reg_src[i2 + 1][j] = *(data_gl2sh_visitor.sh_ptr( ci_inner, j, tidx2 + i2 * ThreadConfig::nr_thread_x + 1)); @@ -243,19 +232,18 @@ struct IConvBlockConsumerUnrollWidth { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_width; ++j) { reg_src[i][j] = *(data_gl2sh_visitor.sh_ptr( - ci_inner, j, - tidx + i * ThreadConfig::nr_thread_x)); + ci_inner, j, tidx + i * ThreadConfig::nr_thread_x)); } } } #pragma unroll for (int j = 0; j < RegBlockConfig::reg_m_packed; ++j) { int32_t* ker_sh_ptr_packed = - &ker_sh_ptr[(tidy + j * ThreadConfig::nr_thread_y) * - RegBlockConfig::pack_size]; + &ker_sh_ptr + [(tidy + j * ThreadConfig::nr_thread_y) * + RegBlockConfig::pack_size]; #pragma unroll - for (int packed = 0; packed < RegBlockConfig::pack_size; - ++packed) { + for (int packed = 0; packed < RegBlockConfig::pack_size; ++packed) { reg_filter[j * RegBlockConfig::pack_size + packed] = *(ker_sh_ptr_packed++); } @@ -266,8 +254,9 @@ struct IConvBlockConsumerUnrollWidth { for (int j = 0; j < RegBlockConfig::reg_width; ++j) { #pragma unroll for (int k = 0; k < RegBlockConfig::reg_m; ++k) { - dot_prod(reg_src[i][j], reg_filter[k], reg_acc[i][j][k], - reg_acc[i][j][k]); + dot_prod( + reg_src[i][j], reg_filter[k], reg_acc[i][j][k], + reg_acc[i][j][k]); } } } diff --git a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer.cuh b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer.cuh index acb54f28..129b383e 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer.cuh @@ -1,36 +1,39 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. 
+ * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once #include "src/cuda/utils.cuh" @@ -38,14 +41,13 @@ namespace megdnn { namespace cuda { namespace convolution { -template +template < + typename IMMAConfig_, typename WarpTileConfig_, typename ThreadConfig_, + bool pipelined> struct IConvIMMABlockConsumer; -template -struct IConvIMMABlockConsumer { +template +struct IConvIMMABlockConsumer { using IMMAConfig = IMMAConfig_; using WarpTileConfig = WarpTileConfig_; using ThreadConfig = ThreadConfig_; @@ -69,8 +71,8 @@ struct IConvIMMABlockConsumer + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -89,13 +91,13 @@ struct IConvIMMABlockConsumer(data_sh_ptr), - IMMAConfig::wmma_k); + wmma::load_matrix_sync( + frag_src[i2][0], reinterpret_cast(data_sh_ptr), + IMMAConfig::wmma_k); wmma::load_matrix_sync( frag_src[i2 + 1][0], - reinterpret_cast(data_sh_ptr + - IMMAConfig::tile_b_sizes_int), + reinterpret_cast( + data_sh_ptr + IMMAConfig::tile_b_sizes_int), IMMAConfig::wmma_k); } } else { @@ -104,9 +106,9 @@ struct IConvIMMABlockConsumer(data_sh_ptr), - IMMAConfig::wmma_k); + wmma::load_matrix_sync( + frag_src[i][0], reinterpret_cast(data_sh_ptr), + IMMAConfig::wmma_k); } } #pragma unroll @@ -114,27 +116,24 @@ struct IConvIMMABlockConsumer(ker_sh_ptr), - IMMAConfig::wmma_k); + wmma::load_matrix_sync( + frag_filter[j][0], reinterpret_cast(ker_sh_ptr), + IMMAConfig::wmma_k); } #pragma unroll - for (int ci_inner = 0; ci_inner < WarpTileConfig::warp_tile_k; - ++ci_inner) { + for (int ci_inner = 0; ci_inner < WarpTileConfig::warp_tile_k; ++ci_inner) { const int comp_idx = (ci_inner & 0x1); const int load_idx = 1 - comp_idx; if (ci_inner < WarpTileConfig::warp_tile_k - 1) { if (use_wide_store) { #pragma unroll - for (int i = 0; i < (WarpTileConfig::warp_tile_n >> 1); - ++i) { + for (int i = 0; i < (WarpTileConfig::warp_tile_n >> 1); ++i) { int i2 = (i << 1); int warpx2 = (warpx << 1); int32_t* data_sh_ptr = data_gl2sh_visitor.sh_ptr( - ci_inner + 1, - (warpx2 + i2 * ThreadConfig::nr_warp_x) * - IMMAConfig::tile_b_sizes_int); + ci_inner + 1, (warpx2 + i2 * ThreadConfig::nr_warp_x) * + IMMAConfig::tile_b_sizes_int); wmma::load_matrix_sync( frag_src[i2][load_idx], reinterpret_cast(data_sh_ptr), @@ -142,17 +141,15 @@ struct IConvIMMABlockConsumer( - data_sh_ptr + - IMMAConfig::tile_b_sizes_int), + data_sh_ptr + IMMAConfig::tile_b_sizes_int), IMMAConfig::wmma_k); } } else { #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_n; ++i) { int32_t* data_sh_ptr = data_gl2sh_visitor.sh_ptr( - ci_inner + 1, - (warpx + i * ThreadConfig::nr_warp_x) * - IMMAConfig::tile_b_sizes_int); + ci_inner + 1, (warpx + i * ThreadConfig::nr_warp_x) * + IMMAConfig::tile_b_sizes_int); wmma::load_matrix_sync( frag_src[i][load_idx], reinterpret_cast(data_sh_ptr), @@ -162,21 +159,20 @@ struct IConvIMMABlockConsumer(ker_sh_ptr), - IMMAConfig::wmma_k); + reinterpret_cast(ker_sh_ptr), IMMAConfig::wmma_k); } } // end if use_wide_store #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_m; ++i) { #pragma unroll for (int j = 0; j < WarpTileConfig::warp_tile_n; ++j) { - wmma::mma_sync(frag_acc[i][j], frag_filter[i][comp_idx], - frag_src[j][comp_idx], frag_acc[i][j]); + wmma::mma_sync( + frag_acc[i][j], frag_filter[i][comp_idx], + frag_src[j][comp_idx], frag_acc[i][j]); } } } // end ci_inner @@ -184,10 +180,8 @@ struct IConvIMMABlockConsumer -struct IConvIMMABlockConsumer { 
+template +struct IConvIMMABlockConsumer { using IMMAConfig = IMMAConfig_; using WarpTileConfig = WarpTileConfig_; using ThreadConfig = ThreadConfig_; @@ -211,8 +205,8 @@ struct IConvIMMABlockConsumer + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -224,8 +218,7 @@ struct IConvIMMABlockConsumer> 1); ++i) { @@ -235,8 +228,7 @@ struct IConvIMMABlockConsumer(data_sh_ptr), + frag_src[i2], reinterpret_cast(data_sh_ptr), IMMAConfig::wmma_k); wmma::load_matrix_sync( frag_src[i2 + 1], @@ -260,16 +252,17 @@ struct IConvIMMABlockConsumer(ker_sh_ptr), - IMMAConfig::wmma_k); + wmma::load_matrix_sync( + frag_filter[j], reinterpret_cast(ker_sh_ptr), + IMMAConfig::wmma_k); } #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_m; ++i) { #pragma unroll for (int j = 0; j < WarpTileConfig::warp_tile_n; ++j) { - wmma::mma_sync(frag_acc[i][j], frag_filter[i], frag_src[j], - frag_acc[i][j]); + wmma::mma_sync( + frag_acc[i][j], frag_filter[i], frag_src[j], + frag_acc[i][j]); } } } // end for ci_inner diff --git a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer_unroll_width.cuh b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer_unroll_width.cuh index eb523637..48364213 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer_unroll_width.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer_unroll_width.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer_unroll_width.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_consumer/iconv_imma_block_consumer_unroll_width.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -38,8 +40,9 @@ namespace megdnn { namespace cuda { namespace convolution { -template +template < + typename Conv1dConfig_, typename IMMAConfig_, typename WarpTileConfig_, + typename ThreadConfig_> struct IConvIMMABlockConsumerUnrollWidth { using Conv1dConfig = Conv1dConfig_; using IMMAConfig = IMMAConfig_; @@ -66,8 +69,8 @@ struct IConvIMMABlockConsumerUnrollWidth { } #if __CUDA_ARCH__ >= 730 - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor data_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor) { @@ -76,38 +79,32 @@ struct IConvIMMABlockConsumerUnrollWidth { const int warpx = tidx / ThreadConfig::warp_size; const int warpy = tidy; - static bool const consecutive_width_tile = - !(WarpTileConfig::warp_tile_n & 0x1); + static bool const consecutive_width_tile = !(WarpTileConfig::warp_tile_n & 0x1); if (consecutive_width_tile) { #pragma unroll for (int i = 0; i < (WarpTileConfig::warp_tile_n >> 1); ++i) { int i2 = (i << 1); int warpx2 = (warpx << 1); int32_t* data_sh_ptr = data_gl2sh_visitor.sh_ptr( - (warpx2 + i2 * ThreadConfig::nr_warp_x) * - Conv1dConfig::sw, - 0); - wmma::load_matrix_sync(frag_src[i2][0], - reinterpret_cast(data_sh_ptr), - IMMAConfig::wmma_k); + (warpx2 + i2 * ThreadConfig::nr_warp_x) * Conv1dConfig::sw, 0); + wmma::load_matrix_sync( + frag_src[i2][0], reinterpret_cast(data_sh_ptr), + IMMAConfig::wmma_k); wmma::load_matrix_sync( frag_src[i2 + 1][0], reinterpret_cast( data_sh_ptr + - Conv1dConfig::sw * - IMMAConfig::tile_b_sizes_int), + Conv1dConfig::sw * IMMAConfig::tile_b_sizes_int), IMMAConfig::wmma_k); } } else { #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_n; ++i) { int32_t* data_sh_ptr = 
data_gl2sh_visitor.sh_ptr( - (warpx + i * ThreadConfig::nr_warp_x) * - Conv1dConfig::sw, - 0); - wmma::load_matrix_sync(frag_src[i][0], - reinterpret_cast(data_sh_ptr), - IMMAConfig::wmma_k); + (warpx + i * ThreadConfig::nr_warp_x) * Conv1dConfig::sw, 0); + wmma::load_matrix_sync( + frag_src[i][0], reinterpret_cast(data_sh_ptr), + IMMAConfig::wmma_k); } } #pragma unroll @@ -115,9 +112,9 @@ struct IConvIMMABlockConsumerUnrollWidth { int32_t* ker_sh_ptr = filter_gl2sh_visitor.sh_ptr( 0, (warpy + j * ThreadConfig::nr_warp_y) * IMMAConfig::tile_a_sizes_int); - wmma::load_matrix_sync(frag_filter[j][0], - reinterpret_cast(ker_sh_ptr), - IMMAConfig::wmma_k); + wmma::load_matrix_sync( + frag_filter[j][0], reinterpret_cast(ker_sh_ptr), + IMMAConfig::wmma_k); } #pragma unroll @@ -127,8 +124,7 @@ struct IConvIMMABlockConsumerUnrollWidth { if (kw != Conv1dConfig::fw - 1) { if (consecutive_width_tile) { #pragma unroll - for (int i = 0; i < (WarpTileConfig::warp_tile_n >> 1); - ++i) { + for (int i = 0; i < (WarpTileConfig::warp_tile_n >> 1); ++i) { int i2 = (i << 1); int warpx2 = (warpx << 1); int32_t* data_sh_ptr = data_gl2sh_visitor.sh_ptr( @@ -169,23 +165,23 @@ struct IConvIMMABlockConsumerUnrollWidth { IMMAConfig::tile_a_sizes_int); wmma::load_matrix_sync( frag_filter[j][load_idx], - reinterpret_cast(ker_sh_ptr), - IMMAConfig::wmma_k); + reinterpret_cast(ker_sh_ptr), IMMAConfig::wmma_k); } } // end if ci_inner #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_m; ++i) { #pragma unroll for (int j = 0; j < WarpTileConfig::warp_tile_n; ++j) { - wmma::mma_sync(frag_acc[i][j], frag_filter[i][comp_idx], - frag_src[j][comp_idx], frag_acc[i][j]); + wmma::mma_sync( + frag_acc[i][j], frag_filter[i][comp_idx], + frag_src[j][comp_idx], frag_acc[i][j]); } } } // end for kw } #else - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void consume_block( DataGlobal2ShareMemVisitor /* data_gl2sh_visitor */, FilterGlobal2ShareMemVisitor /* filter_gl2sh_visitor */) {} diff --git a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator.cuh b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator.cuh index dd11a519..4813ede5 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. 
+ * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** diff --git a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_basic.cuh b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_basic.cuh index 953e6ab6..d046d25f 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_basic.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_basic.cuh @@ -1,36 +1,39 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. 
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_basic.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_basic.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
*/ #pragma once @@ -63,8 +66,8 @@ struct BlockTileIteratorBasic { block_out_channel_remain = param.co - block_out_channel; } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void set_remain( DataGlobal2ShareMemVisitor& src_gl2sh_visitor, FilterGlobal2ShareMemVisitor& filter_gl2sh_visitor) { @@ -76,18 +79,16 @@ struct BlockTileIteratorBasic { __device__ __forceinline__ void set_remain( GlobalMemoryWriter& global_memory_writer) { global_memory_writer.block_batch_remain = block_batch_remain; - global_memory_writer.block_out_channel_remain = - block_out_channel_remain; + global_memory_writer.block_out_channel_remain = block_out_channel_remain; } - template + template < + typename InputLayout, typename KernLayout, typename src_dtype, + typename filter_dtype, typename Param, typename DataGlobal2ShareMemVisitor, + typename FilterGlobal2ShareMemVisitor, typename BlockConsumer> __device__ __forceinline__ void iterate_with_param( - const src_dtype* __restrict__ src, - const filter_dtype* __restrict__ filter, const Param& param, - DataGlobal2ShareMemVisitor src_gl2sh_visitor, + const src_dtype* __restrict__ src, const filter_dtype* __restrict__ filter, + const Param& param, DataGlobal2ShareMemVisitor src_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor, BlockConsumer& consumer) { InputLayout src_layout; @@ -109,18 +110,17 @@ struct BlockTileIteratorBasic { int w_end = w_base + param.fw - 1; h_end = h_end < param.hi ? h_end : param.hi - 1; w_end = w_end < param.wi ? w_end : param.wi - 1; - const int ci_blks = - (param.ci + DataTileCount::block_tile_in_channel - 1) / - DataTileCount::block_tile_in_channel; + const int ci_blks = (param.ci + DataTileCount::block_tile_in_channel - 1) / + DataTileCount::block_tile_in_channel; int kh = h_start - h_base; int kw = w_start - w_base; - src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor::copy_t*>( - g_src_ptr + src_layout.offset(0, 0, h_start, w_start)); - filter_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename FilterGlobal2ShareMemVisitor::copy_t*>( - g_filter_ptr + filter_layout.offset(0, 0, kh, kw)); + src_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_src_ptr + src_layout.offset(0, 0, h_start, w_start)); + filter_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_filter_ptr + filter_layout.offset(0, 0, kh, kw)); src_gl2sh_visitor.first_copy(); filter_gl2sh_visitor.first_copy(); @@ -136,15 +136,13 @@ struct BlockTileIteratorBasic { int kh = h_next - h_base; int kw = w_next - w_base; src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor:: - copy_t*>( + const typename DataGlobal2ShareMemVisitor::copy_t*>( g_src_ptr + src_layout.offset(0, 0, h_next, w_next)); filter_gl2sh_visitor.g_ptr = reinterpret_cast< const typename FilterGlobal2ShareMemVisitor:: copy_t*>( - g_filter_ptr + - filter_layout.offset(0, 0, kh, kw)); + g_filter_ptr + filter_layout.offset(0, 0, kh, kw)); src_gl2sh_visitor.copy(); filter_gl2sh_visitor.copy(); } @@ -155,11 +153,9 @@ struct BlockTileIteratorBasic { filter_gl2sh_visitor.copy(); } - consumer.consume_block(src_gl2sh_visitor, - filter_gl2sh_visitor); + consumer.consume_block(src_gl2sh_visitor, filter_gl2sh_visitor); - if (!(ci_outer == ci_blks - 1 && h == h_end && - w == w_end)) { + if (!(ci_outer == ci_blks - 1 && h == h_end && w == w_end)) { __syncthreads(); src_gl2sh_visitor.commit(); filter_gl2sh_visitor.commit(); diff --git 
a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_coxhw.cuh b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_coxhw.cuh index e2c0a3d7..ad3d10b7 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_coxhw.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_coxhw.cuh @@ -1,36 +1,39 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_coxhw.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_coxhw.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "src/cuda/convolution_helper/prologue.cuh" @@ -38,8 +41,9 @@ namespace megdnn { namespace cuda { namespace convolution { -template +template < + typename DataTileCount_, typename FilterTileCount_, + typename Prologue = ConvPrologue> struct BlockTileIterator_COxHW { using DataTileCount = DataTileCount_; using FilterTileCount = FilterTileCount_; @@ -59,18 +63,16 @@ struct BlockTileIterator_COxHW { template __device__ __forceinline__ void init_with_param(const Param& param) { block_batch = bidz; - block_out_height_width = - bidx * DataTileCount::block_tile_out_height_width; + block_out_height_width = bidx * DataTileCount::block_tile_out_height_width; block_out_channel = bidy * FilterTileCount::block_tile_out_channel; block_out_height = block_out_height_width / param.wo; block_out_width = block_out_height_width - block_out_height * param.wo; block_out_channel_remain = param.co - block_out_channel; - block_out_height_width_remain = - param.ho * param.wo - block_out_height_width; + block_out_height_width_remain = param.ho * param.wo - block_out_height_width; } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void set_remain( DataGlobal2ShareMemVisitor& src_gl2sh_visitor, FilterGlobal2ShareMemVisitor& filter_gl2sh_visitor) { @@ -83,25 +85,23 @@ struct BlockTileIterator_COxHW { template __device__ __forceinline__ void set_remain( GlobalMemoryWriter& global_memory_writer) { - global_memory_writer.block_out_channel_remain = - block_out_channel_remain; + global_memory_writer.block_out_channel_remain = block_out_channel_remain; global_memory_writer.block_out_height_width_remain = block_out_height_width_remain; } - template + template < + typename InputLayout, typename KernLayout, typename src_dtype, + typename filter_dtype, typename Param, typename DataGlobal2ShareMemVisitor, + typename FilterGlobal2ShareMemVisitor, typename BlockConsumer> __device__ __forceinline__ void iterate_with_param( - const src_dtype* __restrict__ src, - const filter_dtype* __restrict__ filter, const Param& param, - DataGlobal2ShareMemVisitor src_gl2sh_visitor, + const src_dtype* __restrict__ src, const filter_dtype* __restrict__ filter, + const Param& param, DataGlobal2ShareMemVisitor src_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor, BlockConsumer& consumer) { - Prologue::template prologue(src, filter, param, block_batch, - block_out_channel, block_out_height, - block_out_width); + Prologue::template prologue( + src, filter, param, block_batch, block_out_channel, block_out_height, + block_out_width); static constexpr bool precomp_offset = DataGlobal2ShareMemVisitor::precomp_offset; InputLayout src_layout; @@ -113,8 +113,8 @@ struct BlockTileIterator_COxHW { g_src_ptr = src 
+ src_layout.offset(block_batch, 0, 0, 0); } else { g_src_ptr = - src + src_layout.offset(block_batch, 0, block_out_height, - block_out_width); + src + src_layout.offset( + block_batch, 0, block_out_height, block_out_width); } const filter_dtype* __restrict__ g_filter_ptr = filter + filter_layout.offset(block_out_channel, 0, 0, 0); @@ -122,18 +122,18 @@ struct BlockTileIterator_COxHW { src_gl2sh_visitor.init_stride(src_layout); filter_gl2sh_visitor.init_stride(filter_layout); - const int ci_blks = - (param.ci + DataTileCount::block_tile_in_channel - 1) / - DataTileCount::block_tile_in_channel; + const int ci_blks = (param.ci + DataTileCount::block_tile_in_channel - 1) / + DataTileCount::block_tile_in_channel; if (precomp_offset) { src_gl2sh_visitor.offset += block_out_height_width; } - src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor::copy_t*>(g_src_ptr); - filter_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename FilterGlobal2ShareMemVisitor::copy_t*>( - g_filter_ptr); + src_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_src_ptr); + filter_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_filter_ptr); src_gl2sh_visitor.first_copy(); filter_gl2sh_visitor.first_copy(); @@ -152,15 +152,13 @@ struct BlockTileIterator_COxHW { // rewind if (precomp_offset) { src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor:: - copy_t*>(g_src_ptr); + const typename DataGlobal2ShareMemVisitor::copy_t*>( + g_src_ptr); src_gl2sh_visitor.offset += img_pixels; } filter_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename FilterGlobal2ShareMemVisitor:: - copy_t*>( - g_filter_ptr + - filter_layout.offset(0, 0, kh, kw)); + const typename FilterGlobal2ShareMemVisitor::copy_t*>( + g_filter_ptr + filter_layout.offset(0, 0, kh, kw)); src_gl2sh_visitor.copy(); filter_gl2sh_visitor.copy(); } @@ -171,8 +169,7 @@ struct BlockTileIterator_COxHW { filter_gl2sh_visitor.copy(); } - consumer.consume_block(src_gl2sh_visitor, - filter_gl2sh_visitor); + consumer.consume_block(src_gl2sh_visitor, filter_gl2sh_visitor); if (!(ci_outer == ci_blks - 1 && f == filter_pixels - 1)) { __syncthreads(); diff --git a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width.cuh b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width.cuh index 8ab17aab..53c7597f 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. 
+ * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -68,8 +70,8 @@ struct BlockTileIteratorUnrollWidth { block_out_channel_remain = param.co - block_out_channel; } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void set_remain( DataGlobal2ShareMemVisitor& src_gl2sh_visitor, FilterGlobal2ShareMemVisitor& filter_gl2sh_visitor) { @@ -81,18 +83,16 @@ struct BlockTileIteratorUnrollWidth { __device__ __forceinline__ void set_remain( GlobalMemoryWriter& global_memory_writer) { global_memory_writer.block_batch_remain = block_batch_remain; - global_memory_writer.block_out_channel_remain = - block_out_channel_remain; + global_memory_writer.block_out_channel_remain = block_out_channel_remain; } - template + template < + typename InputLayout, typename KernLayout, typename src_dtype, + typename filter_dtype, typename Param, typename DataGlobal2ShareMemVisitor, + typename FilterGlobal2ShareMemVisitor, typename BlockConsumer> __device__ __forceinline__ void iterate_with_param( - const src_dtype* __restrict__ src, - const filter_dtype* __restrict__ filter, const Param& param, - DataGlobal2ShareMemVisitor src_gl2sh_visitor, + const src_dtype* __restrict__ src, const filter_dtype* __restrict__ filter, + const Param& param, DataGlobal2ShareMemVisitor src_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor, BlockConsumer& consumer) { InputLayout src_layout; @@ -113,18 +113,17 @@ struct BlockTileIteratorUnrollWidth { h_end = h_end < param.hi ? h_end : param.hi - 1; int w_start = w_base; int w_end = w_start + param.fw - 1; - const int ci_blks = - (param.ci + DataTileCount::block_tile_in_channel - 1) / - DataTileCount::block_tile_in_channel; + const int ci_blks = (param.ci + DataTileCount::block_tile_in_channel - 1) / + DataTileCount::block_tile_in_channel; int kh = h_start - h_base; src_gl2sh_visitor.sw = param.sw; - src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor::copy_t*>( - g_src_ptr + src_layout.offset(0, 0, h_start, w_start)); - filter_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename FilterGlobal2ShareMemVisitor::copy_t*>( - g_filter_ptr + filter_layout.offset(0, 0, kh, 0)); + src_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_src_ptr + src_layout.offset(0, 0, h_start, w_start)); + filter_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_filter_ptr + filter_layout.offset(0, 0, kh, 0)); src_gl2sh_visitor.set_range(-w_start, param.wi - w_start); src_gl2sh_visitor.first_copy(); filter_gl2sh_visitor.first_copy(); @@ -141,17 +140,14 @@ struct BlockTileIteratorUnrollWidth { int kh = h_next - h_base; int kw = w_next - w_base; src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor:: - copy_t*>( + const typename DataGlobal2ShareMemVisitor::copy_t*>( g_src_ptr + src_layout.offset(0, 0, h_next, w_next)); filter_gl2sh_visitor.g_ptr = reinterpret_cast< const typename FilterGlobal2ShareMemVisitor:: copy_t*>( - g_filter_ptr + - filter_layout.offset(0, 0, kh, kw)); - src_gl2sh_visitor.set_range(-w_next, - param.wi - w_next); + g_filter_ptr + filter_layout.offset(0, 0, kh, kw)); + src_gl2sh_visitor.set_range(-w_next, param.wi - w_next); src_gl2sh_visitor.copy(); filter_gl2sh_visitor.copy(); } @@ -162,11 +158,9 @@ struct BlockTileIteratorUnrollWidth { filter_gl2sh_visitor.copy(); } - consumer.consume_block(src_gl2sh_visitor, - filter_gl2sh_visitor); + consumer.consume_block(src_gl2sh_visitor, filter_gl2sh_visitor); - if (!(ci_outer == ci_blks - 1 && h == h_end && - w == w_end)) { + if (!(ci_outer == 
ci_blks - 1 && h == h_end && w == w_end)) { __syncthreads(); src_gl2sh_visitor.commit(); filter_gl2sh_visitor.commit(); diff --git a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width_v2.cuh b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width_v2.cuh index 1904de87..d71ac497 100644 --- a/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width_v2.cuh +++ b/dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width_v2.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width_v2.cuh + * \file + * dnn/src/cuda/convolution_helper/block_tile_iterator/block_tile_iterator_unroll_width_v2.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -69,8 +71,8 @@ struct BlockTileIteratorUnrollWidthV2 { block_out_channel_remain = param.co - block_out_channel; } - template + template < + typename DataGlobal2ShareMemVisitor, typename FilterGlobal2ShareMemVisitor> __device__ __forceinline__ void set_remain( DataGlobal2ShareMemVisitor& src_gl2sh_visitor, FilterGlobal2ShareMemVisitor& filter_gl2sh_visitor) { @@ -82,18 +84,16 @@ struct BlockTileIteratorUnrollWidthV2 { __device__ __forceinline__ void set_remain( GlobalMemoryWriter& global_memory_writer) { global_memory_writer.block_batch_remain = block_batch_remain; - global_memory_writer.block_out_channel_remain = - block_out_channel_remain; + global_memory_writer.block_out_channel_remain = block_out_channel_remain; } - template + template < + typename InputLayout, typename KernLayout, typename src_dtype, + typename filter_dtype, typename Param, typename DataGlobal2ShareMemVisitor, + typename FilterGlobal2ShareMemVisitor, typename BlockConsumer> __device__ __forceinline__ void iterate_with_param( - const src_dtype* __restrict__ src, - const filter_dtype* __restrict__ filter, const Param& param, - DataGlobal2ShareMemVisitor src_gl2sh_visitor, + const src_dtype* __restrict__ src, const filter_dtype* __restrict__ filter, + const Param& param, DataGlobal2ShareMemVisitor src_gl2sh_visitor, FilterGlobal2ShareMemVisitor filter_gl2sh_visitor, BlockConsumer& consumer) { InputLayout src_layout; @@ -112,17 +112,16 @@ struct BlockTileIteratorUnrollWidthV2 { int h_end = h_base + param.fh - 1; h_end = h_end < param.hi ? 
h_end : param.hi - 1; - const int ci_blks = - (param.ci + DataTileCount::block_tile_in_channel - 1) / - DataTileCount::block_tile_in_channel; + const int ci_blks = (param.ci + DataTileCount::block_tile_in_channel - 1) / + DataTileCount::block_tile_in_channel; int kh = h_start - h_base; - src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor::copy_t*>( - g_src_ptr + src_layout.offset(0, 0, h_start, 0)); - filter_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename FilterGlobal2ShareMemVisitor::copy_t*>( - g_filter_ptr + filter_layout.offset(0, 0, kh, 0)); + src_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_src_ptr + src_layout.offset(0, 0, h_start, 0)); + filter_gl2sh_visitor.g_ptr = + reinterpret_cast( + g_filter_ptr + filter_layout.offset(0, 0, kh, 0)); src_gl2sh_visitor.set_range(-block_in_width, param.wi - block_in_width); src_gl2sh_visitor.first_copy(); filter_gl2sh_visitor.first_copy(); @@ -136,14 +135,11 @@ struct BlockTileIteratorUnrollWidthV2 { int h_next = h + 1; int kh = h_next - h_base; src_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename DataGlobal2ShareMemVisitor:: - copy_t*>( + const typename DataGlobal2ShareMemVisitor::copy_t*>( g_src_ptr + src_layout.offset(0, 0, h_next, 0)); filter_gl2sh_visitor.g_ptr = reinterpret_cast< - const typename FilterGlobal2ShareMemVisitor:: - copy_t*>( - g_filter_ptr + - filter_layout.offset(0, 0, kh, 0)); + const typename FilterGlobal2ShareMemVisitor::copy_t*>( + g_filter_ptr + filter_layout.offset(0, 0, kh, 0)); src_gl2sh_visitor.copy(); filter_gl2sh_visitor.copy(); } @@ -154,8 +150,7 @@ struct BlockTileIteratorUnrollWidthV2 { filter_gl2sh_visitor.copy(); } - consumer.consume_block(src_gl2sh_visitor, - filter_gl2sh_visitor); + consumer.consume_block(src_gl2sh_visitor, filter_gl2sh_visitor); if (!(ci_outer == ci_blks - 1 && h == h_end)) { __syncthreads(); diff --git a/dnn/src/cuda/convolution_helper/config.cuh b/dnn/src/cuda/convolution_helper/config.cuh index ea71ba7c..35218ba1 100644 --- a/dnn/src/cuda/convolution_helper/config.cuh +++ b/dnn/src/cuda/convolution_helper/config.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -52,10 +53,10 @@ struct RegBlockConfig { static int constexpr reg_m = reg_m_; static int constexpr reg_n = reg_n_; static int constexpr reg_k = reg_k_; - MEGDNN_STATIC_ASSERT(reg_m % pack_size == 0, - "reg_m must be a multiple of pack_size"); - MEGDNN_STATIC_ASSERT(reg_k % pack_size == 0, - "reg_k must be a multiple of pack_size"); + MEGDNN_STATIC_ASSERT( + reg_m % pack_size == 0, "reg_m must be a multiple of pack_size"); + MEGDNN_STATIC_ASSERT( + reg_k % pack_size == 0, "reg_k must be a multiple of pack_size"); static int constexpr reg_k_packed = reg_k / pack_size; static int constexpr reg_m_packed = reg_m / pack_size; static int constexpr reg_width = reg_width_; @@ -67,8 +68,7 @@ struct ThreadConfig { static int constexpr nr_thread_x = thread_x; static int constexpr nr_thread_y = thread_y; static int constexpr nr_threads = nr_thread_x * nr_thread_y; - static int constexpr nr_warp_x = - !(nr_thread_x & 0x1f) ? (nr_thread_x >> 5) : 0; + static int constexpr nr_warp_x = !(nr_thread_x & 0x1f) ? (nr_thread_x >> 5) : 0; static int constexpr nr_warp_y = !(nr_thread_x & 0x1f) ? nr_thread_y : 0; }; static int constexpr WARP_SIZE = ThreadConfig<1, 1>::warp_size; @@ -92,10 +92,10 @@ struct IMMAConfig { static int constexpr wmma_n_bit = wmma_n == 8 ? 3 : (wmma_n == 16 ? 4 : 5); static int constexpr wmma_m_bit = wmma_m == 8 ? 3 : (wmma_m == 16 ? 
4 : 5); #if __CUDA_ARCH__ >= 730 - using fragment_a = wmma::fragment; - using fragment_b = wmma::fragment; + using fragment_a = wmma::fragment< + wmma::matrix_a, wmma_m, wmma_n, wmma_k, int8_t, wmma::row_major>; + using fragment_b = wmma::fragment< + wmma::matrix_b, wmma_m, wmma_n, wmma_k, int8_t, wmma::col_major>; using fragment_c = wmma::fragment; #endif diff --git a/dnn/src/cuda/convolution_helper/conv_trait/conv_trait.cuh b/dnn/src/cuda/convolution_helper/conv_trait/conv_trait.cuh index f4869ff4..bebce701 100644 --- a/dnn/src/cuda/convolution_helper/conv_trait/conv_trait.cuh +++ b/dnn/src/cuda/convolution_helper/conv_trait/conv_trait.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** diff --git a/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh b/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh index f21197e1..c67bb20d 100644 --- a/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh +++ b/dnn/src/cuda/convolution_helper/conv_trait/ibatch_conv_trait.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -46,27 +47,26 @@ namespace megdnn { namespace cuda { namespace convolution { -#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( \ - _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, \ - _kern_layout, _output_layout, _conv_param) \ - using src_dtype = _src_dtype; \ - using filter_dtype = _filter_dtype; \ - using smem_storage_dtype = _smem_storage_dtype; \ - using InputLayout = _input_layout; \ - using KernLayout = _kern_layout; \ - using OutputLayout = _output_layout; \ - using Param = _conv_param; \ +#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( \ + _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, _kern_layout, \ + _output_layout, _conv_param) \ + using src_dtype = _src_dtype; \ + using filter_dtype = _filter_dtype; \ + using smem_storage_dtype = _smem_storage_dtype; \ + using InputLayout = _input_layout; \ + using KernLayout = _kern_layout; \ + using OutputLayout = _output_layout; \ + using Param = _conv_param; \ static constexpr bool check_bounds = check_bounds_; #define MEGDNN_COMMA , -template +template < + bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype, + typename RegBlockConfig_, typename ThreadConfig_> struct IBatchConvTrait_f1x1s1x1 { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using RegBlockConfig = RegBlockConfig_; using ThreadConfig = ThreadConfig_; struct DataTileCount { @@ -74,10 +74,8 @@ struct IBatchConvTrait_f1x1s1x1 { using ThreadConfig = ThreadConfig; using copy_t = src_ldg_dtype; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(src_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype); static int constexpr skew = load_width; static int constexpr block_tile_batch = RegBlockConfig::reg_n; MEGDNN_STATIC_ASSERT( @@ -87,16 +85,13 @@ struct IBatchConvTrait_f1x1s1x1 { RegBlockConfig::reg_width * ThreadConfig::nr_thread_x; static int constexpr block_tile_in_channel = RegBlockConfig::reg_k; - static int constexpr smem_load_x = - block_tile_out_height_width / load_width; - static int constexpr load_x = - smem_load_x > WARP_SIZE ? WARP_SIZE : smem_load_x; + static int constexpr smem_load_x = block_tile_out_height_width / load_width; + static int constexpr load_x = smem_load_x > WARP_SIZE ? 
WARP_SIZE : smem_load_x; static int constexpr load_y = ThreadConfig::nr_threads / load_x; static int constexpr smem_h = RegBlockConfig::reg_k_packed; static int constexpr smem_w = block_tile_out_height_width; - static int constexpr smem_stride = - smem_w % 2 == 0 ? smem_w + skew : smem_w; + static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w; static int constexpr smem_tot = smem_h * smem_stride; static int constexpr reg_h = (smem_h + load_y - 1) / load_y; @@ -111,19 +106,15 @@ struct IBatchConvTrait_f1x1s1x1 { using ThreadConfig = ThreadConfig; using copy_t = filter_ldg_dtype; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(filter_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(filter_dtype); static int constexpr skew = load_width; static int constexpr block_tile_out_channel = RegBlockConfig::reg_m * ThreadConfig::nr_thread_y; static int constexpr block_tile_in_channel = RegBlockConfig::reg_k; - static int constexpr smem_load_x = - RegBlockConfig::reg_k_packed / load_width; - static int constexpr load_x = - smem_load_x > WARP_SIZE ? WARP_SIZE : smem_load_x; + static int constexpr smem_load_x = RegBlockConfig::reg_k_packed / load_width; + static int constexpr load_x = smem_load_x > WARP_SIZE ? WARP_SIZE : smem_load_x; static int constexpr load_y = ThreadConfig::nr_threads / load_x; static int constexpr smem_h = block_tile_out_channel; @@ -139,14 +130,11 @@ struct IBatchConvTrait_f1x1s1x1 { }; using BlockTileIterator = - BlockTileIterator_COxHW; - using DataGlobal2ShareMemVisitor = - Global2ShareMemVisitor_CIxHW; + BlockTileIterator_COxHW; + using DataGlobal2ShareMemVisitor = Global2ShareMemVisitor_CIxHW< + check_bounds, false, DataTileCount, InputLayout>; using FilterGlobal2ShareMemVisitor = - Global2ShareMemVisitor_COxCI; + Global2ShareMemVisitor_COxCI; static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1; using BlockConsumer = IConvBlockConsumer_COxHW; @@ -154,14 +142,13 @@ struct IBatchConvTrait_f1x1s1x1 { IConvGlobalMemoryWriter_COxHW; }; -template +template < + bool check_bounds_, typename filter_ldg_dtype, typename RegBlockConfig_, + typename ThreadConfig_> struct IBatchConvTrait { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using RegBlockConfig = RegBlockConfig_; using ThreadConfig = ThreadConfig_; struct DataTileCount { @@ -170,8 +157,7 @@ struct IBatchConvTrait { using copy_t = int32_t; using smem_storage_dtype = smem_storage_dtype; static int constexpr load_width = 4; - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(src_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype); static int constexpr skew = load_width; static int constexpr block_tile_batch = RegBlockConfig::reg_n; MEGDNN_STATIC_ASSERT( @@ -183,14 +169,12 @@ struct IBatchConvTrait { static int constexpr smem_load_x = DIVUP(block_tile_out_height_width, load_width); - static int constexpr load_x = - smem_load_x > WARP_SIZE ? WARP_SIZE : smem_load_x; + static int constexpr load_x = smem_load_x > WARP_SIZE ? 
WARP_SIZE : smem_load_x; static int constexpr load_y = ThreadConfig::nr_threads / load_x; static int constexpr smem_h = RegBlockConfig::reg_k_packed; static int constexpr smem_w = smem_load_x * load_width; - static int constexpr smem_stride = - smem_w % 2 == 0 ? smem_w + skew : smem_w; + static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w; static int constexpr smem_tot = smem_h * smem_stride; static int constexpr reg_h = (smem_h + load_y - 1) / load_y; @@ -200,20 +184,16 @@ struct IBatchConvTrait { static bool constexpr check_bounds_h = smem_h % load_y != 0; static bool constexpr check_bounds_w = smem_load_x % load_x != 0; }; - using FilterTileCount = - typename IBatchConvTrait_f1x1s1x1::FilterTileCount; + using FilterTileCount = typename IBatchConvTrait_f1x1s1x1< + check_bounds, int, filter_ldg_dtype, RegBlockConfig, + ThreadConfig>::FilterTileCount; using BlockTileIterator = - BlockTileIterator_COxHW; - using DataGlobal2ShareMemVisitor = - Global2ShareMemVisitor_CIxHW; + BlockTileIterator_COxHW; + using DataGlobal2ShareMemVisitor = Global2ShareMemVisitor_CIxHW< + check_bounds, true, DataTileCount, InputLayout>; using FilterGlobal2ShareMemVisitor = - Global2ShareMemVisitor_COxCI; + Global2ShareMemVisitor_COxCI; static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1; using BlockConsumer = IConvBlockConsumer_COxHW; diff --git a/dnn/src/cuda/convolution_helper/conv_trait/iconv_imma_trait.cuh b/dnn/src/cuda/convolution_helper/conv_trait/iconv_imma_trait.cuh index 3fb685eb..0481e739 100644 --- a/dnn/src/cuda/convolution_helper/conv_trait/iconv_imma_trait.cuh +++ b/dnn/src/cuda/convolution_helper/conv_trait/iconv_imma_trait.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -43,27 +44,26 @@ namespace megdnn { namespace cuda { namespace convolution { -#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( \ - _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, \ - _kern_layout, _output_layout, _conv_param) \ - using src_dtype = _src_dtype; \ - using filter_dtype = _filter_dtype; \ - using smem_storage_dtype = _smem_storage_dtype; \ - using InputLayout = _input_layout; \ - using KernLayout = _kern_layout; \ - using OutputLayout = _output_layout; \ - using Param = _conv_param; \ +#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( \ + _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, _kern_layout, \ + _output_layout, _conv_param) \ + using src_dtype = _src_dtype; \ + using filter_dtype = _filter_dtype; \ + using smem_storage_dtype = _smem_storage_dtype; \ + using InputLayout = _input_layout; \ + using KernLayout = _kern_layout; \ + using OutputLayout = _output_layout; \ + using Param = _conv_param; \ static constexpr bool check_bounds = check_bounds_; #define MEGDNN_COMMA , -template +template < + bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_, + typename ThreadConfig_> struct IConvIMMATrait { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using IMMAConfig = IMMAConfig_; using WarpTileConfig = WarpTileConfig_; using ThreadConfig = ThreadConfig_; @@ -73,10 +73,8 @@ struct IConvIMMATrait { using ThreadConfig = ThreadConfig; using copy_t = int32_t; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(src_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype); static int constexpr block_tile_batch = WarpTileConfig::warp_tile_n * IMMAConfig::wmma_n * ThreadConfig::nr_warp_x; @@ -97,8 +95,7 @@ struct IConvIMMATrait { static int constexpr 
reg_h = (smem_h + load_y - 1) / load_y; static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x; - static int constexpr reg_d = - IMMAConfig::wmma_k / WarpTileConfig::pack_size; + static int constexpr reg_d = IMMAConfig::wmma_k / WarpTileConfig::pack_size; static bool constexpr check_bounds_h = smem_h % load_y != 0; static bool constexpr check_bounds_w = smem_load_x % load_x != 0; @@ -110,13 +107,11 @@ struct IConvIMMATrait { using ThreadConfig = ThreadConfig; using copy_t = int32_t; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(filter_dtype); - static int constexpr block_tile_out_channel = - WarpTileConfig::warp_tile_m * IMMAConfig::wmma_m * - ThreadConfig::nr_warp_y; + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(filter_dtype); + static int constexpr block_tile_out_channel = WarpTileConfig::warp_tile_m * + IMMAConfig::wmma_m * + ThreadConfig::nr_warp_y; static int constexpr block_tile_in_channel = WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k; @@ -134,8 +129,7 @@ struct IConvIMMATrait { static int constexpr reg_h = (smem_h + load_y - 1) / load_y; static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x; - static int constexpr reg_d = - IMMAConfig::wmma_k / WarpTileConfig::pack_size; + static int constexpr reg_d = IMMAConfig::wmma_k / WarpTileConfig::pack_size; static bool constexpr check_bounds_h = smem_h % load_y != 0; static bool constexpr check_bounds_w = smem_load_x % load_x != 0; @@ -157,62 +151,52 @@ struct IConvIMMATrait { static int constexpr smem_stride = smem_w; static int constexpr smem_tot = smem_h * smem_stride; - static int constexpr store_x = - (WarpTileConfig::warp_tile_n & 0x1) - ? IMMAConfig::wmma_n / store_width - : 2 * IMMAConfig::wmma_n / store_width; + static int constexpr store_x = (WarpTileConfig::warp_tile_n & 0x1) + ? 
IMMAConfig::wmma_n / store_width + : 2 * IMMAConfig::wmma_n / store_width; static int constexpr store_y = ThreadConfig::warp_size / store_x; }; - using BlockTileIterator = - BlockTileIteratorBasic; + using BlockTileIterator = BlockTileIteratorBasic; using DataGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_CIxN; + Global2ShareMemVisitorIMMA_CIxN; using FilterGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_CIxN; + Global2ShareMemVisitorIMMA_CIxN; static bool constexpr pipelined = WarpTileConfig::warp_tile_k > 1; - using BlockConsumer = IConvIMMABlockConsumer; - using GlobalMemoryWriter = - IConvIMMAGlobalMemoryWriter; + using BlockConsumer = + IConvIMMABlockConsumer; + using GlobalMemoryWriter = IConvIMMAGlobalMemoryWriter; }; -template +template < + bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_, + typename ThreadConfig_> struct IConvIMMATraitReorderFilter { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using IMMAConfig = IMMAConfig_; using WarpTileConfig = WarpTileConfig_; using ThreadConfig = ThreadConfig_; MEGDNN_STATIC_ASSERT( std::is_same:: - src_dtype MEGDNN_COMMA src_dtype>::value == - true, + WarpTileConfig MEGDNN_COMMA ThreadConfig>::src_dtype + MEGDNN_COMMA src_dtype>::value == true, "data type of input tensor should be int8_t"); - using DataTileCount = - typename IConvIMMATrait::DataTileCount; + using DataTileCount = typename IConvIMMATrait< + check_bounds, IMMAConfig, WarpTileConfig, ThreadConfig>::DataTileCount; struct FilterTileCount { using IMMAConfig = IMMAConfig; using WarpTileConfig = WarpTileConfig; using ThreadConfig = ThreadConfig; using copy_t = int4; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(filter_dtype); - static int constexpr block_tile_out_channel = - WarpTileConfig::warp_tile_m * IMMAConfig::wmma_m * - ThreadConfig::nr_warp_y; + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(filter_dtype); + static int constexpr block_tile_out_channel = WarpTileConfig::warp_tile_m * + IMMAConfig::wmma_m * + ThreadConfig::nr_warp_y; static int constexpr block_tile_in_channel = WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k; @@ -235,32 +219,27 @@ struct IConvIMMATraitReorderFilter { static bool constexpr check_bounds_w = smem_load_x % load_x != 0; }; - using BlockTileIterator = - BlockTileIteratorBasic; + using BlockTileIterator = BlockTileIteratorBasic; using DataGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_CIxN; + Global2ShareMemVisitorIMMA_CIxN; using FilterGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_CIxN; + Global2ShareMemVisitorIMMA_CIxN; static bool constexpr pipelined = WarpTileConfig::warp_tile_k > 1; - using BlockConsumer = IConvIMMABlockConsumer; - using GlobalMemoryStoreCount = - typename IConvIMMATrait::GlobalMemoryStoreCount; - using GlobalMemoryWriter = - IConvIMMAGlobalMemoryWriter; + using BlockConsumer = + IConvIMMABlockConsumer; + using GlobalMemoryStoreCount = typename IConvIMMATrait< + check_bounds, IMMAConfig, WarpTileConfig, + ThreadConfig>::GlobalMemoryStoreCount; + using GlobalMemoryWriter = IConvIMMAGlobalMemoryWriter; }; -template +template < + bool check_bounds_, typename 
IMMAConfig_, typename WarpTileConfig_, + typename ThreadConfig_> struct IConvIMMATraitUnrollWidth { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using IMMAConfig = IMMAConfig_; using WarpTileConfig = WarpTileConfig_; using ThreadConfig = ThreadConfig_; @@ -271,10 +250,8 @@ struct IConvIMMATraitUnrollWidth { using ThreadConfig = ThreadConfig; using copy_t = int4; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(src_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype); static int constexpr block_tile_batch = IMMAConfig::wmma_n; static int constexpr block_tile_out_width = @@ -296,8 +273,7 @@ struct IConvIMMATraitUnrollWidth { static int constexpr reg_h = (smem_h + load_y - 1) / load_y; static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x; - static int constexpr reg_d = - IMMAConfig::wmma_k / WarpTileConfig::pack_size; + static int constexpr reg_d = IMMAConfig::wmma_k / WarpTileConfig::pack_size; static bool constexpr check_bounds_h = smem_h % load_y != 0; static bool constexpr check_bounds_w = smem_load_x % load_x != 0; @@ -306,25 +282,20 @@ struct IConvIMMATraitUnrollWidth { MEGDNN_STATIC_ASSERT( std::is_same::filter_dtype - MEGDNN_COMMA filter_dtype>::value == true, + WarpTileConfig MEGDNN_COMMA ThreadConfig>:: + filter_dtype MEGDNN_COMMA filter_dtype>::value == true, "data type of filter tensor should be int8_t"); - using FilterTileCount = - typename IConvIMMATraitReorderFilter::FilterTileCount; + using FilterTileCount = typename IConvIMMATraitReorderFilter< + check_bounds, IMMAConfig, WarpTileConfig, ThreadConfig>::FilterTileCount; using BlockTileIterator = BlockTileIteratorUnrollWidth; - using DataGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_CIxWOxN; + using DataGlobal2ShareMemVisitor = Global2ShareMemVisitorIMMA_CIxWOxN< + check_bounds, DataTileCount, InputLayout>; using FilterGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_CIxN; + Global2ShareMemVisitorIMMA_CIxN; static bool constexpr pipelined = WarpTileConfig::warp_tile_k > 1; - using BlockConsumer = IConvIMMABlockConsumer; + using BlockConsumer = + IConvIMMABlockConsumer; struct GlobalMemoryStoreCount { using IMMAConfig = IMMAConfig; @@ -335,33 +306,31 @@ struct IConvIMMATraitUnrollWidth { static int constexpr consecutive_width_tile = !(WarpTileConfig::warp_tile_n & 0x1); static int constexpr smem_w = - consecutive_width_tile - ? 2 * ThreadConfig::nr_warp_x * IMMAConfig::wmma_m * - IMMAConfig::wmma_n - : ThreadConfig::nr_warp_x * IMMAConfig::wmma_m * - IMMAConfig::wmma_n; + consecutive_width_tile ? 2 * ThreadConfig::nr_warp_x * + IMMAConfig::wmma_m * IMMAConfig::wmma_n + : ThreadConfig::nr_warp_x * IMMAConfig::wmma_m * + IMMAConfig::wmma_n; static int constexpr store_width = sizeof(copy_t) / sizeof(int32_t); static int constexpr smem_stride = smem_w; static int constexpr smem_tot = smem_h * smem_stride; - static int constexpr store_x = - consecutive_width_tile ? 2 * IMMAConfig::wmma_n / store_width - : IMMAConfig::wmma_n / store_width; + static int constexpr store_x = consecutive_width_tile + ? 
2 * IMMAConfig::wmma_n / store_width + : IMMAConfig::wmma_n / store_width; static int constexpr store_y = ThreadConfig::warp_size / store_x; }; using GlobalMemoryWriter = IConvIMMAGlobalMemoryWriterUnrollWidth; }; -template +template < + bool check_bounds_, typename Conv1dConfig_, typename IMMAConfig_, + typename WarpTileConfig_, typename ThreadConfig_> struct IConvIMMATraitUnrollWidthV2 { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using Conv1dConfig = Conv1dConfig_; using IMMAConfig = IMMAConfig_; using WarpTileConfig = WarpTileConfig_; @@ -373,15 +342,14 @@ struct IConvIMMATraitUnrollWidthV2 { using ThreadConfig = ThreadConfig; using Conv1dConfig = Conv1dConfig; - MEGDNN_STATIC_ASSERT(WarpTileConfig::warp_tile_k == 1, - "kernel unrolling along width axis assumes tile k " - "in warp-level must be 1"); + MEGDNN_STATIC_ASSERT( + WarpTileConfig::warp_tile_k == 1, + "kernel unrolling along width axis assumes tile k " + "in warp-level must be 1"); using copy_t = int4; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(src_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype); static int constexpr block_tile_out_width = WarpTileConfig::warp_tile_n * ThreadConfig::nr_warp_x; @@ -398,16 +366,14 @@ struct IConvIMMATraitUnrollWidthV2 { static int constexpr load_y = ThreadConfig::nr_threads / load_x; // smem col major - static int constexpr smem_h = - WarpTileConfig::warp_tile_k * block_tile_in_width; + static int constexpr smem_h = WarpTileConfig::warp_tile_k * block_tile_in_width; static int constexpr smem_w = IMMAConfig::tile_b_sizes_int; static int constexpr smem_stride = smem_w; static int constexpr smem_tot = smem_h * smem_stride; static int constexpr reg_h = (smem_h + load_y - 1) / load_y; static int constexpr reg_w = (smem_load_x + load_x - 1) / load_x; - static int constexpr reg_d = - IMMAConfig::wmma_k / WarpTileConfig::pack_size; + static int constexpr reg_d = IMMAConfig::wmma_k / WarpTileConfig::pack_size; static bool constexpr check_bounds_h = smem_h % load_y != 0; static bool constexpr check_bounds_w = smem_load_x % load_x != 0; @@ -419,18 +385,17 @@ struct IConvIMMATraitUnrollWidthV2 { using ThreadConfig = ThreadConfig; using Conv1dConfig = Conv1dConfig; - MEGDNN_STATIC_ASSERT(WarpTileConfig::warp_tile_k == 1, - "kernel unrolling along width axis assumes tile k " - "in warp-level must be 1"); + MEGDNN_STATIC_ASSERT( + WarpTileConfig::warp_tile_k == 1, + "kernel unrolling along width axis assumes tile k " + "in warp-level must be 1"); using copy_t = int4; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(filter_dtype); - static int constexpr block_tile_out_channel = - WarpTileConfig::warp_tile_m * IMMAConfig::wmma_m * - ThreadConfig::nr_warp_y; + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(filter_dtype); + static int constexpr block_tile_out_channel = WarpTileConfig::warp_tile_m * + IMMAConfig::wmma_m * + 
ThreadConfig::nr_warp_y; static int constexpr block_tile_in_channel = WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k; @@ -455,15 +420,12 @@ struct IConvIMMATraitUnrollWidthV2 { using BlockTileIterator = BlockTileIteratorUnrollWidthV2; - using DataGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_CIxWIxN; + using DataGlobal2ShareMemVisitor = Global2ShareMemVisitorIMMA_CIxWIxN< + check_bounds, DataTileCount, InputLayout>; using FilterGlobal2ShareMemVisitor = - Global2ShareMemVisitorIMMA_FWxCO; - using BlockConsumer = - IConvIMMABlockConsumerUnrollWidth; + Global2ShareMemVisitorIMMA_FWxCO; + using BlockConsumer = IConvIMMABlockConsumerUnrollWidth< + Conv1dConfig, IMMAConfig, WarpTileConfig, ThreadConfig>; using GlobalMemoryStoreCount = typename IConvIMMATraitUnrollWidth< check_bounds, IMMAConfig, WarpTileConfig, ThreadConfig>::GlobalMemoryStoreCount; diff --git a/dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh b/dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh index f43f820d..d4c52724 100644 --- a/dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh +++ b/dnn/src/cuda/convolution_helper/conv_trait/iconv_trait.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -43,27 +44,26 @@ namespace megdnn { namespace cuda { namespace convolution { -#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( \ - _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, \ - _kern_layout, _output_layout, _conv_param) \ - using src_dtype = _src_dtype; \ - using filter_dtype = _filter_dtype; \ - using smem_storage_dtype = _smem_storage_dtype; \ - using InputLayout = _input_layout; \ - using KernLayout = _kern_layout; \ - using OutputLayout = _output_layout; \ - using Param = _conv_param; \ +#define COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( \ + _src_dtype, _filter_dtype, _smem_storage_dtype, _input_layout, _kern_layout, \ + _output_layout, _conv_param) \ + using src_dtype = _src_dtype; \ + using filter_dtype = _filter_dtype; \ + using smem_storage_dtype = _smem_storage_dtype; \ + using InputLayout = _input_layout; \ + using KernLayout = _kern_layout; \ + using OutputLayout = _output_layout; \ + using Param = _conv_param; \ static constexpr bool check_bounds = check_bounds_; #define MEGDNN_COMMA , -template +template < + bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_, + typename ThreadConfig_> struct IConvTrait { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using RegBlockConfig = RegBlockConfig_; using ThreadConfig = ThreadConfig_; struct DataTileCount { @@ -71,10 +71,8 @@ struct IConvTrait { using ThreadConfig = ThreadConfig; using copy_t = ldg_dtype; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(src_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype); static int constexpr skew = load_width; static int constexpr block_tile_batch = RegBlockConfig::reg_n * ThreadConfig::nr_thread_x; @@ -86,8 +84,7 @@ struct IConvTrait { static int constexpr smem_h = RegBlockConfig::reg_k_packed; static int constexpr smem_w = block_tile_batch; - static int constexpr smem_stride = - smem_w % 2 == 0 ? smem_w + skew : smem_w; + static int constexpr smem_stride = smem_w % 2 == 0 ? 
smem_w + skew : smem_w; static int constexpr smem_tot = smem_h * smem_stride; static int constexpr reg_h = (smem_h + load_y - 1) / load_y; @@ -102,10 +99,8 @@ struct IConvTrait { using ThreadConfig = ThreadConfig; using copy_t = ldg_dtype; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(filter_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(filter_dtype); static int constexpr skew = load_width; static int constexpr block_tile_out_channel = RegBlockConfig::reg_m * ThreadConfig::nr_thread_y; @@ -117,8 +112,7 @@ struct IConvTrait { static int constexpr smem_h = RegBlockConfig::reg_k_packed; static int constexpr smem_w = block_tile_out_channel; - static int constexpr smem_stride = - smem_w % 2 == 0 ? smem_w + skew : smem_w; + static int constexpr smem_stride = smem_w % 2 == 0 ? smem_w + skew : smem_w; static int constexpr smem_tot = smem_h * smem_stride; static int constexpr reg_h = (smem_h + load_y - 1) / load_y; @@ -128,27 +122,23 @@ struct IConvTrait { static bool constexpr check_bounds_w = smem_load_x % load_x != 0; }; - using BlockTileIterator = - BlockTileIteratorBasic; + using BlockTileIterator = BlockTileIteratorBasic; using DataGlobal2ShareMemVisitor = Global2ShareMemVisitor_CIxN; using FilterGlobal2ShareMemVisitor = Global2ShareMemVisitor_CIxN; static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1; - using BlockConsumer = - IConvBlockConsumer; - using GlobalMemoryWriter = - IConvGlobalMemoryWriter; + using BlockConsumer = IConvBlockConsumer; + using GlobalMemoryWriter = IConvGlobalMemoryWriter; }; -template +template < + bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_, + typename ThreadConfig_> struct IConvTraitUnrollWidth { - COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM(int8_t, int8_t, int32_t, - Layout, - Layout, - Layout, - ConvParam); + COMMON_DEFS_WITH_DATA_TYPE_LAYOUT_AND_PARAM( + int8_t, int8_t, int32_t, Layout, Layout, + Layout, ConvParam); using RegBlockConfig = RegBlockConfig_; using ThreadConfig = ThreadConfig_; struct DataTileCount { @@ -156,10 +146,8 @@ struct IConvTraitUnrollWidth { using ThreadConfig = ThreadConfig; using copy_t = ldg_dtype; using smem_storage_dtype = smem_storage_dtype; - static int constexpr load_width = - sizeof(copy_t) / sizeof(smem_storage_dtype); - static int constexpr ldg_load_width = - sizeof(copy_t) / sizeof(src_dtype); + static int constexpr load_width = sizeof(copy_t) / sizeof(smem_storage_dtype); + static int constexpr ldg_load_width = sizeof(copy_t) / sizeof(src_dtype); static int constexpr skew = load_width; static int constexpr block_tile_batch = RegBlockConfig::reg_n * ThreadConfig::nr_thread_x; @@ -173,8 +161,7 @@ struct IConvTraitUnrollWidth { static int constexpr smem_h = RegBlockConfig::reg_k_packed; static int constexpr smem_w = block_tile_batch; static int constexpr img_cache = RegBlockConfig::reg_width; - static int constexpr smem_stride = - smem_w % 2 == 0 ? smem_w + skew : smem_w; + static int constexpr smem_stride = smem_w % 2 == 0 ? 
smem_w + skew : smem_w; static int constexpr smem_tot = smem_h * img_cache * smem_stride; static int constexpr reg_h = (smem_h + load_y - 1) / load_y; @@ -186,25 +173,20 @@ struct IConvTraitUnrollWidth { MEGDNN_STATIC_ASSERT( std::is_same::filter_dtype - MEGDNN_COMMA filter_dtype>::value == true, + RegBlockConfig MEGDNN_COMMA ThreadConfig>:: + filter_dtype MEGDNN_COMMA filter_dtype>::value == true, "data type of filter tensor should be int8_t"); - using FilterTileCount = - typename IConvTrait::FilterTileCount; + using FilterTileCount = typename IConvTrait< + check_bounds, ldg_dtype, RegBlockConfig, ThreadConfig>::FilterTileCount; using BlockTileIterator = BlockTileIteratorUnrollWidth; using DataGlobal2ShareMemVisitor = - Global2ShareMemVisitor_CIxWOxN; + Global2ShareMemVisitor_CIxWOxN; using FilterGlobal2ShareMemVisitor = - Global2ShareMemVisitor_CIxN; + Global2ShareMemVisitor_CIxN; static bool constexpr pipelined = RegBlockConfig::reg_k_packed > 1; using BlockConsumer = - IConvBlockConsumerUnrollWidth; + IConvBlockConsumerUnrollWidth; using GlobalMemoryWriter = IConvGlobalMemoryWriterUnrollWidth; }; diff --git a/dnn/src/cuda/convolution_helper/epilogue.cuh b/dnn/src/cuda/convolution_helper/epilogue.cuh index 929457cf..ef63bae4 100644 --- a/dnn/src/cuda/convolution_helper/epilogue.cuh +++ b/dnn/src/cuda/convolution_helper/epilogue.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -50,12 +51,10 @@ struct IConvEpilogue { int width_stride; float gamma; ActivationOp act; - MEGDNN_HOST MEGDNN_DEVICE IConvEpilogue(int8_t* __restrict__ dst, - const int8_t* __restrict__ z, - int batch_stride, - int channel_stride, - int height_stride, int width_stride, - float gamma, ActivationOp act) + MEGDNN_HOST MEGDNN_DEVICE IConvEpilogue( + int8_t* __restrict__ dst, const int8_t* __restrict__ z, int batch_stride, + int channel_stride, int height_stride, int width_stride, float gamma, + ActivationOp act) : dst{dst}, z{z}, batch_stride{batch_stride}, @@ -65,18 +64,17 @@ struct IConvEpilogue { gamma{gamma}, act{act} {} #if MEGDNN_CC_CUDA - __device__ __forceinline__ void move(const int b_idx, const int ch_idx, - const int h_idx, const int w_idx) { + __device__ __forceinline__ void move( + const int b_idx, const int ch_idx, const int h_idx, const int w_idx) { size_t offset = b_idx * batch_stride + ch_idx * channel_stride + h_idx * height_stride + w_idx * width_stride; dst += offset; if (z != nullptr) z += offset; } - __device__ __forceinline__ void apply(float alpha, float4 f_conv, - float beta, float4 f_bias, - const int b_idx, const int ch_idx, - const int h_idx, const int w_idx) { + __device__ __forceinline__ void apply( + float alpha, float4 f_conv, float beta, float4 f_bias, const int b_idx, + const int ch_idx, const int h_idx, const int w_idx) { size_t idx = b_idx * batch_stride + ch_idx * channel_stride + h_idx * height_stride + w_idx * width_stride; float4 f_res = alpha * f_conv + beta * f_bias; @@ -85,29 +83,25 @@ struct IConvEpilogue { float4 f_z = transform_int8x4_to_float4(i_z); f_res = f_res + gamma * f_z; } - *(reinterpret_cast(&dst[idx])) = - act.apply_and_transform(f_res); + *(reinterpret_cast(&dst[idx])) = act.apply_and_transform(f_res); } - __device__ __forceinline__ void apply(float alpha, float4 f_conv, - float beta, float4 f_bias, - const int b_idx, const int ch_idx, - const int hw_idx) { - size_t idx = b_idx * batch_stride + ch_idx * channel_stride + - hw_idx * width_stride; + __device__ __forceinline__ void apply( + float alpha, float4 f_conv, float beta, float4 f_bias, const int b_idx, + const int 
ch_idx, const int hw_idx) { + size_t idx = + b_idx * batch_stride + ch_idx * channel_stride + hw_idx * width_stride; float4 f_res = alpha * f_conv + beta * f_bias; if (z != nullptr) { int i_z = __ldg(reinterpret_cast(&z[idx])); float4 f_z = transform_int8x4_to_float4(i_z); f_res = f_res + gamma * f_z; } - *(reinterpret_cast(&dst[idx])) = - act.apply_and_transform(f_res); + *(reinterpret_cast(&dst[idx])) = act.apply_and_transform(f_res); } - __device__ __forceinline__ void apply(float alpha, float4 f_conv_x, - float4 f_conv_y, float beta, - float4 f_bias_x, float4 f_bias_y, - const int b_idx, const int ch_idx, - const int h_idx, const int w_idx) { + __device__ __forceinline__ void apply( + float alpha, float4 f_conv_x, float4 f_conv_y, float beta, float4 f_bias_x, + float4 f_bias_y, const int b_idx, const int ch_idx, const int h_idx, + const int w_idx) { size_t idx = b_idx * batch_stride + ch_idx * channel_stride + h_idx * height_stride + w_idx * width_stride; float4 f_res_x = alpha * f_conv_x + beta * f_bias_x; @@ -123,13 +117,11 @@ struct IConvEpilogue { int iy = act.apply_and_transform(f_res_y); *(reinterpret_cast(&dst[idx])) = ::make_int2(ix, iy); } - __device__ __forceinline__ void apply(float alpha, float4 f_conv_x, - float4 f_conv_y, float beta, - float4 f_bias_x, float4 f_bias_y, - const int b_idx, const int ch_idx, - const int hw_idx) { - size_t idx = b_idx * batch_stride + ch_idx * channel_stride + - hw_idx * width_stride; + __device__ __forceinline__ void apply( + float alpha, float4 f_conv_x, float4 f_conv_y, float beta, float4 f_bias_x, + float4 f_bias_y, const int b_idx, const int ch_idx, const int hw_idx) { + size_t idx = + b_idx * batch_stride + ch_idx * channel_stride + hw_idx * width_stride; float4 f_res_x = alpha * f_conv_x + beta * f_bias_x; float4 f_res_y = alpha * f_conv_y + beta * f_bias_y; if (z != nullptr) { @@ -144,13 +136,11 @@ struct IConvEpilogue { *(reinterpret_cast(&dst[idx])) = ::make_int2(ix, iy); } - __device__ __forceinline__ void apply(float alpha, float4 f_conv_x, - float4 f_conv_y, float4 f_conv_z, - float4 f_conv_w, float beta, - float4 f_bias_x, float4 f_bias_y, - float4 f_bias_z, float4 f_bias_w, - const int b_idx, const int ch_idx, - const int h_idx, const int w_idx) { + __device__ __forceinline__ void apply( + float alpha, float4 f_conv_x, float4 f_conv_y, float4 f_conv_z, + float4 f_conv_w, float beta, float4 f_bias_x, float4 f_bias_y, + float4 f_bias_z, float4 f_bias_w, const int b_idx, const int ch_idx, + const int h_idx, const int w_idx) { size_t idx = b_idx * batch_stride + ch_idx * channel_stride + h_idx * height_stride + w_idx * width_stride; float4 f_res_x = alpha * f_conv_x + beta * f_bias_x; @@ -176,15 +166,13 @@ struct IConvEpilogue { int iw = act.apply_and_transform(f_res_w); *(reinterpret_cast(&dst[idx])) = ::make_int4(ix, iy, iz, iw); } - __device__ __forceinline__ void apply(float alpha, float4 f_conv_x, - float4 f_conv_y, float4 f_conv_z, - float4 f_conv_w, float beta, - float4 f_bias_x, float4 f_bias_y, - float4 f_bias_z, float4 f_bias_w, - const int b_idx, const int ch_idx, - const int hw_idx) { - size_t idx = b_idx * batch_stride + ch_idx * channel_stride + - hw_idx * width_stride; + __device__ __forceinline__ void apply( + float alpha, float4 f_conv_x, float4 f_conv_y, float4 f_conv_z, + float4 f_conv_w, float beta, float4 f_bias_x, float4 f_bias_y, + float4 f_bias_z, float4 f_bias_w, const int b_idx, const int ch_idx, + const int hw_idx) { + size_t idx = + b_idx * batch_stride + ch_idx * channel_stride + hw_idx * width_stride; 
float4 f_res_x = alpha * f_conv_x + beta * f_bias_x; float4 f_res_y = alpha * f_conv_y + beta * f_bias_y; float4 f_res_z = alpha * f_conv_z + beta * f_bias_z; diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor.cuh index 8212d282..e0c902ba 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
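All of the IConvEpilogue::apply() overloads above follow the same recipe: compute a strided output index, blend alpha * conv with beta * bias, optionally add gamma * z for the residual input, run the activation, and store the result as packed int8. The host-side C++ sketch below mirrors that recipe for a single float4; the clamp-and-ReLU helper is a hypothetical stand-in for ActivationOp::apply_and_transform, and plain byte loads stand in for the __ldg/int8x4 conversions.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Minimal stand-ins for the CUDA vector type and helpers used by apply().
struct f4 { float x, y, z, w; };
static f4 axpby(float a, const f4& u, float b, const f4& v) {
    return {a * u.x + b * v.x, a * u.y + b * v.y, a * u.z + b * v.z, a * u.w + b * v.w};
}
// Hypothetical activation + quantization: ReLU then saturate to int8.
static int8_t act_i8(float v) {
    v = std::max(v, 0.f);
    return static_cast<int8_t>(std::min(v, 127.f));
}

// One output position: f_res = alpha*conv + beta*bias (+ gamma*z), activated,
// then written as four int8 values starting at dst[idx].
void epilogue_apply(int8_t* dst, const int8_t* z, size_t idx, float alpha,
                    const f4& f_conv, float beta, const f4& f_bias, float gamma) {
    f4 r = axpby(alpha, f_conv, beta, f_bias);
    if (z != nullptr) {  // optional residual tensor, scaled by gamma
        r.x += gamma * z[idx + 0]; r.y += gamma * z[idx + 1];
        r.z += gamma * z[idx + 2]; r.w += gamma * z[idx + 3];
    }
    dst[idx + 0] = act_i8(r.x); dst[idx + 1] = act_i8(r.y);
    dst[idx + 2] = act_i8(r.z); dst[idx + 3] = act_i8(r.w);
}

int main() {
    int8_t dst[4] = {0};
    epilogue_apply(dst, nullptr, 0, 0.5f, f4{10.f, -20.f, 30.f, 300.f}, 0.f,
                   f4{0.f, 0.f, 0.f, 0.f}, 0.f);
    printf("%d %d %d %d\n", dst[0], dst[1], dst[2], dst[3]);  // prints: 5 0 15 127
    return 0;
}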
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -36,10 +37,12 @@ #include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixn.cuh" #include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixwoxn.cuh" #include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixn.cuh" -#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh" -#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh" #include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwixn.cuh" -//#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_small_channel.cuh" -//#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_with_img_cache.cuh" +#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh" +#include "src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh" +//#include +//"src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_small_channel.cuh" +//#include +//"src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_with_img_cache.cuh" // vim: ft=cpp syntax=cuda.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixhw.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixhw.cuh index 1ebbd777..554202fe 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixhw.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixhw.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixhw.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixhw.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
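The umbrella header above just pulls in the per-layout visitors (CIxN, CIxWOxN and the IMMA variants, plus the CIxHW specializations that follow). They are independent templates rather than subclasses, but they all expose the same handful of member functions to the kernel's main loop; the skeleton below is only an orientation sketch of that shared contract, not a class that exists in the source tree.

// Orientation sketch: the member-function contract shared by the
// Global2ShareMemVisitor_* templates in this directory (no real base class).
struct VisitorContractSketch {
    // Convert tensor strides from elements to vectorized-load units.
    void init_stride(/* Layout layout */) {}
    // Prologue: load the first k-slice straight into shared memory.
    void first_copy() {}
    // Steady state: prefetch the next k-slice into per-thread registers.
    void copy() {}
    // After the current tile has been consumed (and a __syncthreads()),
    // publish the prefetched registers into shared memory.
    void commit() {}
    // Advance the global-memory pointer to the next k-slice.
    void move_forward() {}
};

int main() { VisitorContractSketch v; v.init_stride(); v.first_copy(); return 0; }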
@@ -73,14 +75,13 @@ struct Global2ShareMemVisitorBase_CIxHW { } }; -template +template struct Global2ShareMemVisitor_CIxHW; #define DEF(_precomp_offset, _Layout) \ template \ - struct Global2ShareMemVisitor_CIxHW \ + struct Global2ShareMemVisitor_CIxHW< \ + check_bounds, _precomp_offset, TileCount_, _Layout> \ : public Global2ShareMemVisitorBase_CIxHW { \ using Base = Global2ShareMemVisitorBase_CIxHW; \ using TileCount = typename Base::TileCount; \ @@ -104,191 +105,180 @@ struct Global2ShareMemVisitor_CIxHW; const int* __restrict__ offset; \ int remain; -DEF(true, Layout) +DEF(true, Layout) - copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; - MEGDNN_STATIC_ASSERT(load_width == 4, - "load four element from src tensor per time"); +copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; +MEGDNN_STATIC_ASSERT(load_width == 4, "load four element from src tensor per time"); - __device__ Global2ShareMemVisitor_CIxHW(smem_storage_dtype* smem_, - const int* __restrict__ offset_) - : Base{smem_}, offset{offset_} {} +__device__ Global2ShareMemVisitor_CIxHW( + smem_storage_dtype* smem_, const int* __restrict__ offset_) + : Base{smem_}, offset{offset_} {} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int out_offset = w_idx * load_width; - int4 in_offset = - *reinterpret_cast(&offset[out_offset]); - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - copy_t ix = make_zero(); - copy_t iy = ix; - copy_t iz = ix; - copy_t iw = ix; - if (in_offset.x >= 0) { - ix = g_ptr[h_idx * stride + in_offset.x]; - } - if (in_offset.y >= 0) { - iy = g_ptr[h_idx * stride + in_offset.y]; - } - if (in_offset.z >= 0) { - iz = g_ptr[h_idx * stride + in_offset.z]; - } - if (in_offset.w >= 0) { - iw = g_ptr[h_idx * stride + in_offset.w]; - } - *(sh_ptr_as_copy_t(h_idx, out_offset + 0)) = ix; - *(sh_ptr_as_copy_t(h_idx, out_offset + 1)) = iy; - *(sh_ptr_as_copy_t(h_idx, out_offset + 2)) = iz; - *(sh_ptr_as_copy_t(h_idx, out_offset + 3)) = iw; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int out_offset = w_idx * load_width; + int4 in_offset = *reinterpret_cast(&offset[out_offset]); + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + copy_t ix = make_zero(); + copy_t iy = ix; + copy_t iz = ix; + copy_t iw = ix; + if (in_offset.x >= 0) { + ix = g_ptr[h_idx * stride + in_offset.x]; + } + if (in_offset.y >= 0) { + iy = g_ptr[h_idx * stride + in_offset.y]; + } + if (in_offset.z >= 0) { + iz = g_ptr[h_idx * stride + in_offset.z]; } + if (in_offset.w >= 0) { + iw = g_ptr[h_idx * stride + in_offset.w]; + } + *(sh_ptr_as_copy_t(h_idx, out_offset + 0)) = ix; + *(sh_ptr_as_copy_t(h_idx, out_offset + 1)) = iy; + *(sh_ptr_as_copy_t(h_idx, out_offset + 2)) = iz; + *(sh_ptr_as_copy_t(h_idx, out_offset + 3)) = iw; } } +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * 
TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int out_offset = w_idx * load_width; - int4 in_offset = - *reinterpret_cast(&offset[out_offset]); - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - copy_t ix = make_zero(); - copy_t iy = ix; - copy_t iz = ix; - copy_t iw = ix; - if (in_offset.x >= 0) { - ix = g_ptr[h_idx * stride + in_offset.x]; - } - if (in_offset.y >= 0) { - iy = g_ptr[h_idx * stride + in_offset.y]; - } - if (in_offset.z >= 0) { - iz = g_ptr[h_idx * stride + in_offset.z]; - } - if (in_offset.w >= 0) { - iw = g_ptr[h_idx * stride + in_offset.w]; - } - reg[i][j][0] = ix; - reg[i][j][1] = iy; - reg[i][j][2] = iz; - reg[i][j][3] = iw; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int out_offset = w_idx * load_width; + int4 in_offset = *reinterpret_cast(&offset[out_offset]); + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + copy_t ix = make_zero(); + copy_t iy = ix; + copy_t iz = ix; + copy_t iw = ix; + if (in_offset.x >= 0) { + ix = g_ptr[h_idx * stride + in_offset.x]; } + if (in_offset.y >= 0) { + iy = g_ptr[h_idx * stride + in_offset.y]; + } + if (in_offset.z >= 0) { + iz = g_ptr[h_idx * stride + in_offset.z]; + } + if (in_offset.w >= 0) { + iw = g_ptr[h_idx * stride + in_offset.w]; + } + reg[i][j][0] = ix; + reg[i][j][1] = iy; + reg[i][j][2] = iz; + reg[i][j][3] = iw; } } +} - __device__ __forceinline__ void commit() { +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int out_offset = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, out_offset + 0)) = reg[i][j][0]; - *(sh_ptr_as_copy_t(h_idx, out_offset + 1)) = reg[i][j][1]; - *(sh_ptr_as_copy_t(h_idx, out_offset + 2)) = reg[i][j][2]; - *(sh_ptr_as_copy_t(h_idx, out_offset + 3)) = reg[i][j][3]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int out_offset = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, out_offset + 0)) = reg[i][j][0]; + *(sh_ptr_as_copy_t(h_idx, out_offset + 1)) = reg[i][j][1]; + *(sh_ptr_as_copy_t(h_idx, out_offset + 2)) = reg[i][j][2]; + *(sh_ptr_as_copy_t(h_idx, out_offset + 3)) = reg[i][j][3]; } } +} }; - + DEF(false, Layout) - copy_t reg[TileCount::reg_h][TileCount::reg_w]; - __device__ Global2ShareMemVisitor_CIxHW(smem_storage_dtype* smem_) - : Base{smem_} {} +copy_t reg[TileCount::reg_h][TileCount::reg_w]; +__device__ Global2ShareMemVisitor_CIxHW(smem_storage_dtype* smem_) : Base{smem_} {} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma 
unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int spatial = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (spatial < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, spatial)) = val; - } else { - *(sh_ptr_as_copy_t(h_idx, spatial)) = - g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int spatial = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (spatial < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + *(sh_ptr_as_copy_t(h_idx, spatial)) = val; + } else { + *(sh_ptr_as_copy_t(h_idx, spatial)) = g_ptr[h_idx * stride + w_idx]; } } } +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int spatial = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (spatial < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - reg[i][j] = val; - } else { - reg[i][j] = g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int spatial = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (spatial < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + reg[i][j] = val; + } else { + reg[i][j] = g_ptr[h_idx * stride + w_idx]; } } } +} - __device__ __forceinline__ void commit() { +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; } } +} }; #undef DEF diff --git 
a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixn.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixn.cuh index fbb08684..63f0ff2c 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixn.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixn.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
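The offset-table specialization of Global2ShareMemVisitor_CIxHW above (the DEF(true, ...) branch) avoids per-element coordinate math in the inner loop: four consecutive int offsets are read as one int4, and a negative entry marks a source position outside the padded input that must be written as zero. Below is a host-side C++ sketch of that gather pattern; the offset values and row width are invented for illustration.

#include <cstdint>
#include <cstdio>
#include <vector>

using copy_t = int32_t;  // stands in for the vectorized load type

// Gather one shared-memory row: offsets[] was precomputed before the main
// loop; a negative offset means "outside the padded input, write zero".
void gather_row(copy_t* smem_row, const copy_t* g_ptr, const int* offsets, int n) {
    for (int j = 0; j < n; ++j) {
        copy_t v = 0;                 // make_zero() equivalent
        if (offsets[j] >= 0) {
            v = g_ptr[offsets[j]];    // in-bounds: plain vectorized load
        }
        smem_row[j] = v;              // zero-fill keeps the padding implicit
    }
}

int main() {
    std::vector<copy_t> src = {11, 22, 33, 44};
    int offsets[6] = {-1, 0, 1, 2, 3, -1};  // a halo of one padded column per side
    copy_t row[6];
    gather_row(row, src.data(), offsets, 6);
    for (copy_t v : row) printf("%d ", v);
    printf("\n");  // prints: 0 11 22 33 44 0
    return 0;
}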
* **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixn.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixn.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -44,104 +46,100 @@ template struct Global2ShareMemVisitor_CIxN; DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitor_CIxN, Layout) - using RegBlockConfig = typename TileCount::RegBlockConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int stride; - int remain; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; - - copy_t reg[TileCount::reg_h][TileCount::reg_w]; +using RegBlockConfig = typename TileCount::RegBlockConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int stride; +int remain; - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.channel_stride / TileCount::ldg_load_width; - } +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; + +copy_t reg[TileCount::reg_h][TileCount::reg_w]; - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.channel_stride / TileCount::ldg_load_width; +} + +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int batch = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (batch < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, batch)) = val; - } else { - *(sh_ptr_as_copy_t(h_idx, batch)) = - g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int batch = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (batch < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + *(sh_ptr_as_copy_t(h_idx, batch)) = val; + } else { + *(sh_ptr_as_copy_t(h_idx, batch)) = g_ptr[h_idx * stride + w_idx]; } } } - - __device__ __forceinline__ void copy() { +} + +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < 
TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int batch = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (batch < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - reg[i][j] = val; - } else { - reg[i][j] = g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int batch = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (batch < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + reg[i][j] = val; + } else { + reg[i][j] = g_ptr[h_idx * stride + w_idx]; } } } - - __device__ __forceinline__ void commit() { +} + +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; } } +} - __device__ __forceinline__ int32_t* sh_ptr(int y, int x) { - return &smem[y * TileCount::smem_stride + x]; - } - - __device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { - return reinterpret_cast(sh_ptr(y, x)); - } +__device__ __forceinline__ int32_t* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_stride + x]; +} - __device__ __forceinline__ void move_forward() { - g_ptr += RegBlockConfig::reg_k_packed * stride; - } +__device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { + return reinterpret_cast(sh_ptr(y, x)); +} + +__device__ __forceinline__ void move_forward() { + g_ptr += RegBlockConfig::reg_k_packed * stride; +} }; } // namespace cuda diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixwoxn.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixwoxn.cuh index 829799f9..208f76e8 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixwoxn.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixwoxn.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. 
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixwoxn.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_cixwoxn.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
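Each of these visitors splits a k-slice transfer into copy(), which reads global memory into per-thread registers, and commit(), which writes those registers into shared memory, with first_copy() covering the prologue and move_forward() stepping the global pointer by one packed k-slice; on the check_bounds path, loads past `remain` are zero-filled. The sketch below shows how a main loop typically drives such an interface. The ToyVisitor is a deliberately simplified stand-in (single-threaded, plain arrays), not the real templates above.

#include <cstdio>
#include <vector>

// Toy visitor with the same phase structure as Global2ShareMemVisitor_CIxN:
// copy() prefetches the next slice into "registers", commit() publishes it.
struct ToyVisitor {
    const int* g_ptr;        // global source, advanced by move_forward()
    int* smem;               // "shared memory" tile
    int width, remain;       // columns per slice, valid columns (tail handling)
    std::vector<int> reg;    // per-thread register staging

    void first_copy() { stage(); publish(); }
    void copy() { stage(); }
    void commit() { publish(); }
    void move_forward() { g_ptr += width; }

private:
    void stage() {
        reg.assign(width, 0);
        for (int j = 0; j < width; ++j)
            if (j < remain) reg[j] = g_ptr[j];   // zero-fill past `remain`
    }
    void publish() { for (int j = 0; j < width; ++j) smem[j] = reg[j]; }
};

int main() {
    std::vector<int> global(12);
    for (int i = 0; i < 12; ++i) global[i] = i;
    int smem[4];
    ToyVisitor v{global.data(), smem, 4, /*remain=*/3, {}};
    v.first_copy();                        // prologue slice now sits in smem
    for (int k = 1; k < 3; ++k) {          // pipelined steady state
        v.move_forward();
        v.copy();                          // prefetch slice k into registers
        /* consume the smem tile of slice k-1 here, then __syncthreads() */
        v.commit();                        // publish slice k
    }
    printf("last tile: %d %d %d %d\n", smem[0], smem[1], smem[2], smem[3]);
    return 0;  // prints: last tile: 8 9 10 0
}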
@@ -43,141 +45,129 @@ namespace convolution { template struct Global2ShareMemVisitor_CIxWOxN; -DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitor_CIxWOxN, - Layout) - using RegBlockConfig = typename TileCount::RegBlockConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int sw; - int stride; - int remain; - int img_stride; - int img_start; - int img_end; +DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitor_CIxWOxN, Layout) +using RegBlockConfig = typename TileCount::RegBlockConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int sw; +int stride; +int remain; +int img_stride; +int img_start; +int img_end; - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; - copy_t reg[TileCount::reg_h][TileCount::img_cache][TileCount::reg_w]; +copy_t reg[TileCount::reg_h][TileCount::img_cache][TileCount::reg_w]; - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.channel_stride / TileCount::ldg_load_width; - img_stride = layout.width_stride / TileCount::ldg_load_width; - } +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.channel_stride / TileCount::ldg_load_width; + img_stride = layout.width_stride / TileCount::ldg_load_width; +} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::img_cache; ++j) { - int jstride = j * sw; + for (int j = 0; j < TileCount::img_cache; ++j) { + int jstride = j * sw; #pragma unroll - for (int k = 0; k < TileCount::reg_w; ++k) { - int w_idx = gl_load_x + k * TileCount::load_x; - int batch = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (jstride >= img_start && jstride < img_end && - batch < remain) { - val = g_ptr[h_idx * stride + jstride * img_stride + - w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, j, batch)) = val; - } else { - copy_t val = make_zero(); - if (jstride >= img_start && jstride < img_end) { - val = g_ptr[h_idx * stride + jstride * img_stride + - w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, j, batch)) = val; + for (int k = 0; k < TileCount::reg_w; ++k) { + int w_idx = gl_load_x + k * TileCount::load_x; + int batch = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (jstride >= img_start && jstride < img_end && batch < remain) { + val = g_ptr[h_idx * stride + jstride * img_stride + w_idx]; + } + *(sh_ptr_as_copy_t(h_idx, j, batch)) = val; + } else { + copy_t val = make_zero(); + if (jstride >= img_start && jstride < img_end) { + val = g_ptr[h_idx * stride + jstride * img_stride + w_idx]; } + *(sh_ptr_as_copy_t(h_idx, j, batch)) = val; } } } 
} +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::img_cache; ++j) { - int jstride = j * sw; + for (int j = 0; j < TileCount::img_cache; ++j) { + int jstride = j * sw; #pragma unroll - for (int k = 0; k < TileCount::reg_w; ++k) { - int w_idx = gl_load_x + k * TileCount::load_x; - int batch = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (jstride >= img_start && jstride < img_end && - batch < remain) { - val = g_ptr[h_idx * stride + jstride * img_stride + - w_idx]; - } - reg[i][j][k] = val; - } else { - copy_t val = make_zero(); - if (jstride >= img_start && jstride < img_end) { - val = g_ptr[h_idx * stride + jstride * img_stride + - w_idx]; - } - reg[i][j][k] = val; + for (int k = 0; k < TileCount::reg_w; ++k) { + int w_idx = gl_load_x + k * TileCount::load_x; + int batch = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (jstride >= img_start && jstride < img_end && batch < remain) { + val = g_ptr[h_idx * stride + jstride * img_stride + w_idx]; } + reg[i][j][k] = val; + } else { + copy_t val = make_zero(); + if (jstride >= img_start && jstride < img_end) { + val = g_ptr[h_idx * stride + jstride * img_stride + w_idx]; + } + reg[i][j][k] = val; } } } } +} - __device__ __forceinline__ void commit() { +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::img_cache; ++j) { + for (int j = 0; j < TileCount::img_cache; ++j) { #pragma unroll - for (int k = 0; k < TileCount::reg_w; ++k) { - int w_idx = gl_load_x + k * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, j, w_idx * load_width)) = - reg[i][j][k]; - } + for (int k = 0; k < TileCount::reg_w; ++k) { + int w_idx = gl_load_x + k * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, j, w_idx * load_width)) = reg[i][j][k]; } } } +} - __device__ __forceinline__ int32_t* sh_ptr(int z, int y, int x) { - return &smem[(z * TileCount::img_cache + y) * TileCount::smem_stride + - x]; - } +__device__ __forceinline__ int32_t* sh_ptr(int z, int y, int x) { + return &smem[(z * TileCount::img_cache + y) * TileCount::smem_stride + x]; +} - __device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int z, int y, int x) { - return reinterpret_cast(sh_ptr(z, y, x)); - } +__device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int z, int y, int x) { + return reinterpret_cast(sh_ptr(z, y, x)); +} - __device__ __forceinline__ void move_forward() { - g_ptr += RegBlockConfig::reg_k_packed 
* stride; - } +__device__ __forceinline__ void move_forward() { + g_ptr += RegBlockConfig::reg_k_packed * stride; +} - __device__ __forceinline__ void set_range(const int start, const int end) { - img_start = start, img_end = end; - } +__device__ __forceinline__ void set_range(const int start, const int end) { + img_start = start, img_end = end; +} }; } // namespace convolution diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_common.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_common.cuh index 62ecf9cb..5aba8898 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_common.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_common.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
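Global2ShareMemVisitor_CIxWOxN above adds an img_cache dimension on top of the CIxN layout: cached output-width position j reads input column j * sw (sw being the convolution stride along width), and set_range() clips those columns to the [img_start, img_end) window that actually lies inside the image. The C++ sketch below models just that index-and-clip step; all sizes are toy values standing in for the template parameters.

#include <cstdio>

// Toy model of one CIxWOxN load: for channel-row h, cached width position j,
// and batch column w, the source element sits at
//   h * channel_stride + (j * sw) * width_stride + w,
// and columns outside [img_start, img_end) are zero-filled.
int load_element(const int* g_ptr, int h, int j, int w, int sw,
                 int channel_stride, int width_stride,
                 int img_start, int img_end) {
    int jstride = j * sw;
    if (jstride < img_start || jstride >= img_end)
        return 0;  // outside the valid window installed by set_range()
    return g_ptr[h * channel_stride + jstride * width_stride + w];
}

int main() {
    // One channel-row, width_stride = 2 "batches", 5 input columns.
    int src[10];
    for (int i = 0; i < 10; ++i) src[i] = 100 + i;
    // Stride-2 convolution caching 4 output positions, valid columns [0, 5).
    for (int j = 0; j < 4; ++j)
        printf("j=%d -> %d\n", j, load_element(src, 0, j, 0, /*sw=*/2,
                                               /*channel_stride=*/10,
                                               /*width_stride=*/2, 0, 5));
    return 0;  // prints 100, 104, 108, then 0 for the clipped column
}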
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_common.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_common.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_coxci.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_coxci.cuh index b5c73175..d2f0bd3a 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_coxci.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_coxci.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_coxci.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_coxci.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -44,102 +46,99 @@ template struct Global2ShareMemVisitor_COxCI; DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitor_COxCI, Layout) - using RegBlockConfig = typename TileCount::RegBlockConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int stride; - int remain; +using RegBlockConfig = typename TileCount::RegBlockConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int stride; +int remain; - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; - - copy_t reg[TileCount::reg_h][TileCount::reg_w]; +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.batch_stride / TileCount::ldg_load_width; - } +copy_t reg[TileCount::reg_h][TileCount::reg_w]; + +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.batch_stride / TileCount::ldg_load_width; +} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w 
&& - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (h_idx < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = val; - } else { - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = - g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (h_idx < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = val; + } else { + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = + g_ptr[h_idx * stride + w_idx]; } } } - - __device__ __forceinline__ void copy() { +} + +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (h_idx < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - reg[i][j] = val; - } else { - reg[i][j] = g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (h_idx < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + reg[i][j] = val; + } else { + reg[i][j] = g_ptr[h_idx * stride + w_idx]; } } } - - __device__ __forceinline__ void commit() { +} + +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; } } +} - __device__ __forceinline__ int32_t* sh_ptr(int y, int x) { - return &smem[y * TileCount::smem_stride + x]; - } - - __device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { - return reinterpret_cast<copy_t*>(sh_ptr(y, x)); - } +__device__ __forceinline__ int32_t* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_stride + x]; +} - __device__ __forceinline__ void move_forward() { - g_ptr += RegBlockConfig::reg_k_packed / load_width; - } +__device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { + return reinterpret_cast<copy_t*>(sh_ptr(y, x)); +} + +__device__ __forceinline__ void move_forward() { + g_ptr += 
RegBlockConfig::reg_k_packed / load_width; +} }; } // namespace cuda diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixn.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixn.cuh index b8e0ba64..38ff7d6a 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixn.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixn.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixn.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixn.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -44,219 +46,211 @@ namespace convolution { template struct Global2ShareMemVisitorIMMA_CIxN; -DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxN, - Layout) - using IMMAConfig = typename TileCount::IMMAConfig; - using WarpTileConfig = typename TileCount::WarpTileConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int stride; - int remain; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; - - copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; +DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxN, Layout) +using IMMAConfig = typename TileCount::IMMAConfig; +using WarpTileConfig = typename TileCount::WarpTileConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int stride; +int remain; - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.channel_stride / TileCount::ldg_load_width; - } +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; + +copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; + +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.channel_stride / TileCount::ldg_load_width; +} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int batch = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int batch = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; #pragma unroll - for (int k = 0; k < TileCount::reg_d; ++k) { - int channel = ((h_idx * TileCount::reg_d + k)); - if (check_bounds) { - copy_t val = make_zero(); - if (batch < remain) { - val = g_ptr[channel * 
stride + w_idx]; - } - *(sh_ptr(h_idx, batch * TileCount::reg_d + k)) = val; - } else { - *(sh_ptr(h_idx, batch * TileCount::reg_d + k)) = - g_ptr[channel * stride + w_idx]; + for (int k = 0; k < TileCount::reg_d; ++k) { + int channel = ((h_idx * TileCount::reg_d + k)); + if (check_bounds) { + copy_t val = make_zero(); + if (batch < remain) { + val = g_ptr[channel * stride + w_idx]; } + *(sh_ptr(h_idx, batch * TileCount::reg_d + k)) = val; + } else { + *(sh_ptr(h_idx, batch * TileCount::reg_d + k)) = + g_ptr[channel * stride + w_idx]; } } } } +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int batch = w_idx * load_width; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int batch = w_idx * load_width; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; #pragma unroll - for (int k = 0; k < TileCount::reg_d; ++k) { - int channel = (h_idx * TileCount::reg_d + k); - if (check_bounds) { - copy_t val = make_zero(); - if (batch < remain) { - val = g_ptr[channel * stride + w_idx]; - } - reg[i][j][k] = val; - } else { - reg[i][j][k] = g_ptr[channel * stride + w_idx]; + for (int k = 0; k < TileCount::reg_d; ++k) { + int channel = (h_idx * TileCount::reg_d + k); + if (check_bounds) { + copy_t val = make_zero(); + if (batch < remain) { + val = g_ptr[channel * stride + w_idx]; } + reg[i][j][k] = val; + } else { + reg[i][j][k] = g_ptr[channel * stride + w_idx]; } } } } - - __device__ __forceinline__ void commit() { +} + +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; #pragma unroll - for (int k = 0; k < TileCount::reg_d; ++k) { - *(sh_ptr(h_idx, w_idx * load_width * TileCount::reg_d + - k)) = reg[i][j][k]; - } + for (int k = 0; k < TileCount::reg_d; ++k) { + *(sh_ptr(h_idx, w_idx * load_width * TileCount::reg_d + k)) = + reg[i][j][k]; } } } - - __device__ __forceinline__ int32_t* sh_ptr(int y, int x) { - return &smem[y * TileCount::smem_stride + x]; - } - - __device__ __forceinline__ void move_forward() { - g_ptr += WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k / 4 * stride; - } -}; +} -DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxN, - Layout) - using IMMAConfig = typename TileCount::IMMAConfig; - using WarpTileConfig = 
typename TileCount::WarpTileConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int stride; - int remain; +__device__ __forceinline__ int32_t* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_stride + x]; +} - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; +__device__ __forceinline__ void move_forward() { + g_ptr += WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k / 4 * stride; +} +}; - copy_t reg[TileCount::reg_h][TileCount::reg_w]; - MEGDNN_STATIC_ASSERT(std::is_same::value == true, - "ldg data type must be int4 for this memory visitor"); +DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxN, Layout) +using IMMAConfig = typename TileCount::IMMAConfig; +using WarpTileConfig = typename TileCount::WarpTileConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int stride; +int remain; +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.channel_stride / TileCount::ldg_load_width; - } +copy_t reg[TileCount::reg_h][TileCount::reg_w]; +MEGDNN_STATIC_ASSERT( + std::is_same::value == true, + "ldg data type must be int4 for this memory visitor"); + +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.channel_stride / TileCount::ldg_load_width; +} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (w_idx < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = val; - } else { - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = - g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (w_idx < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = val; + } else { + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = + g_ptr[h_idx * stride + w_idx]; } } } +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = 
gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (w_idx < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - reg[i][j] = val; - } else { - reg[i][j] = g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (w_idx < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + reg[i][j] = val; + } else { + reg[i][j] = g_ptr[h_idx * stride + w_idx]; } } } +} - __device__ __forceinline__ void commit() { +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; } } +} - __device__ __forceinline__ int32_t* sh_ptr(int y, int x) { - return &smem[y * TileCount::smem_stride + x]; - } +__device__ __forceinline__ int32_t* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_stride + x]; +} - __device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { - return reinterpret_cast(sh_ptr(y, x)); - } +__device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { + return reinterpret_cast(sh_ptr(y, x)); +} - __device__ __forceinline__ void move_forward() { - g_ptr += WarpTileConfig::warp_tile_k * stride; - } +__device__ __forceinline__ void move_forward() { + g_ptr += WarpTileConfig::warp_tile_k * stride; +} }; #undef MEGDNN_COMMA diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwixn.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwixn.cuh index 844a7e7d..0c44f5d5 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwixn.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwixn.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. 
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwixn.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwixn.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -44,173 +46,167 @@ namespace convolution { template struct Global2ShareMemVisitorIMMA_CIxWIxN; -DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxWIxN, - Layout) - using IMMAConfig = typename TileCount::IMMAConfig; - using WarpTileConfig = typename TileCount::WarpTileConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int stride; - int remain; - int width_stride; - int width_start; - int width_end; +DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxWIxN, Layout) +using IMMAConfig = typename TileCount::IMMAConfig; +using WarpTileConfig = typename TileCount::WarpTileConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int stride; +int remain; +int width_stride; +int width_start; +int width_end; - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; - copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; - MEGDNN_STATIC_ASSERT(std::is_same::value == true, - "ldg data type must be int4 for this memory visitor"); - - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.channel_stride / TileCount::ldg_load_width; - width_stride = layout.width_stride / TileCount::ldg_load_width; - } +copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; +MEGDNN_STATIC_ASSERT( + std::is_same::value == true, + "ldg data type must be int4 for this memory visitor"); +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.channel_stride / TileCount::ldg_load_width; + width_stride = layout.width_stride / TileCount::ldg_load_width; +} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int batch = (w_idx << 2); - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (h_idx >= width_start && h_idx < width_end && - batch < remain) { - c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; - c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; - c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; - c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = - make_int4(c0.x, c1.x, c2.x, c3.x); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = - make_int4(c0.y, c1.y, c2.y, c3.y); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = - make_int4(c0.z, c1.z, c2.z, c3.z); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = - make_int4(c0.w, c1.w, c2.w, c3.w); - } else { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (h_idx >= width_start && h_idx 
< width_end) { - c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; - c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; - c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; - c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; - } - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = - make_int4(c0.x, c1.x, c2.x, c3.x); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = - make_int4(c0.y, c1.y, c2.y, c3.y); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = - make_int4(c0.z, c1.z, c2.z, c3.z); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = - make_int4(c0.w, c1.w, c2.w, c3.w); + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int batch = (w_idx << 2); + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t c0 = make_zero(); + copy_t c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (h_idx >= width_start && h_idx < width_end && batch < remain) { + c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; + c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; + c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; + c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; } + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = + make_int4(c0.x, c1.x, c2.x, c3.x); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = + make_int4(c0.y, c1.y, c2.y, c3.y); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = + make_int4(c0.z, c1.z, c2.z, c3.z); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = + make_int4(c0.w, c1.w, c2.w, c3.w); + } else { + copy_t c0 = make_zero(); + copy_t c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (h_idx >= width_start && h_idx < width_end) { + c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; + c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; + c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; + c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; + } + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = + make_int4(c0.x, c1.x, c2.x, c3.x); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = + make_int4(c0.y, c1.y, c2.y, c3.y); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = + make_int4(c0.z, c1.z, c2.z, c3.z); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = + make_int4(c0.w, c1.w, c2.w, c3.w); } } } +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - int batch = (w_idx << 2); - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (h_idx >= width_start && h_idx < width_end && - batch < remain) { - c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; - c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; - c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; - c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; - } - reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); - reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); - reg[i][j][2] = make_int4(c0.z, c1.z, 
c2.z, c3.z); - reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); - } else { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (h_idx >= width_start && h_idx < width_end) { - c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; - c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; - c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; - c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; - } - reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); - reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); - reg[i][j][2] = make_int4(c0.z, c1.z, c2.z, c3.z); - reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + int batch = (w_idx << 2); + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t c0 = make_zero(); + copy_t c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (h_idx >= width_start && h_idx < width_end && batch < remain) { + c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; + c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; + c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; + c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; + } + reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); + reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); + reg[i][j][2] = make_int4(c0.z, c1.z, c2.z, c3.z); + reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); + } else { + copy_t c0 = make_zero(); + copy_t c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (h_idx >= width_start && h_idx < width_end) { + c0 = g_ptr[0 * stride + h_idx * width_stride + w_idx]; + c1 = g_ptr[1 * stride + h_idx * width_stride + w_idx]; + c2 = g_ptr[2 * stride + h_idx * width_stride + w_idx]; + c3 = g_ptr[3 * stride + h_idx * width_stride + w_idx]; } + reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); + reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); + reg[i][j][2] = make_int4(c0.z, c1.z, c2.z, c3.z); + reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); } } } +} - __device__ __forceinline__ void commit() { +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = reg[i][j][0]; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = reg[i][j][1]; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = reg[i][j][2]; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = reg[i][j][3]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = reg[i][j][0]; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = reg[i][j][1]; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = reg[i][j][2]; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = reg[i][j][3]; } } +} - __device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { - return 
reinterpret_cast<copy_t*>(sh_ptr(y, x)); - } - __device__ __forceinline__ int32_t* sh_ptr(int y, int x) { - return &smem[y * TileCount::smem_stride + x]; - } +__device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { + return reinterpret_cast<copy_t*>(sh_ptr(y, x)); +} +__device__ __forceinline__ int32_t* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_stride + x]; +} - __device__ __forceinline__ void move_forward() { - g_ptr += WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k / 4 * stride; - } +__device__ __forceinline__ void move_forward() { + g_ptr += WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k / 4 * stride; +} - __device__ __forceinline__ void set_range(const int start, const int end) { - width_start = start, width_end = end; - } +__device__ __forceinline__ void set_range(const int start, const int end) { + width_start = start, width_end = end; +} }; #undef MEGDNN_COMMA diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh index 8833e84f..cb14ea42 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_cixwoxn.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -44,197 +46,210 @@ namespace convolution { template struct Global2ShareMemVisitorIMMA_CIxWOxN; -DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxWOxN, - Layout) - using IMMAConfig = typename TileCount::IMMAConfig; - using WarpTileConfig = typename TileCount::WarpTileConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int stride; - int remain; - int sw; - int width_stride; - int width_start; - int width_end; - - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; - - copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; - MEGDNN_STATIC_ASSERT(std::is_same::value == true, - "ldg data type must be int4 for this memory visitor"); - - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.channel_stride / TileCount::ldg_load_width; - width_stride = layout.width_stride / TileCount::ldg_load_width; - } +DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_CIxWOxN, Layout) +using IMMAConfig = typename TileCount::IMMAConfig; +using WarpTileConfig = typename TileCount::WarpTileConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int stride; +int remain; +int sw; +int width_stride; +int width_start; +int width_end; + +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; + +copy_t reg[TileCount::reg_h][TileCount::reg_w][TileCount::reg_d]; +MEGDNN_STATIC_ASSERT( + std::is_same::value == true, + "ldg data type must be int4 for this memory visitor"); - __device__ __forceinline__ void 
first_copy() { +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.channel_stride / TileCount::ldg_load_width; + width_stride = layout.width_stride / TileCount::ldg_load_width; +} + +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - int width = (w_idx >> (IMMAConfig::wmma_n_bit - 2)) * sw; - int batch = (w_idx & ((IMMAConfig::wmma_n >> 2) - 1)); - if (check_bounds) { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (width >= width_start && width < width_end && - (batch << 2) < remain) { - c0 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 0) * stride]; - c1 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 1) * stride]; - c2 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 2) * stride]; - c3 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 3) * stride]; - } - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = - make_int4(c0.x, c1.x, c2.x, c3.x); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = - make_int4(c0.y, c1.y, c2.y, c3.y); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = - make_int4(c0.z, c1.z, c2.z, c3.z); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = - make_int4(c0.w, c1.w, c2.w, c3.w); - } else { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (width >= width_start && width < width_end) { - c0 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 0) * stride]; - c1 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 1) * stride]; - c2 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 2) * stride]; - c3 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 3) * stride]; - } - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = - make_int4(c0.x, c1.x, c2.x, c3.x); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = - make_int4(c0.y, c1.y, c2.y, c3.y); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = - make_int4(c0.z, c1.z, c2.z, c3.z); - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = - make_int4(c0.w, c1.w, c2.w, c3.w); + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + int width = (w_idx >> (IMMAConfig::wmma_n_bit - 2)) * sw; + int batch = (w_idx & ((IMMAConfig::wmma_n >> 2) - 1)); + if (check_bounds) { + copy_t c0 = make_zero(); + copy_t c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (width >= width_start && width < width_end && + (batch << 2) < remain) { + c0 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 0) * stride]; + c1 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 1) * stride]; + c2 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 2) * stride]; + c3 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 3) * stride]; + } + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = + make_int4(c0.x, c1.x, c2.x, c3.x); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = + 
make_int4(c0.y, c1.y, c2.y, c3.y); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = + make_int4(c0.z, c1.z, c2.z, c3.z); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = + make_int4(c0.w, c1.w, c2.w, c3.w); + } else { + copy_t c0 = make_zero(); + copy_t c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (width >= width_start && width < width_end) { + c0 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 0) * stride]; + c1 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 1) * stride]; + c2 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 2) * stride]; + c3 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 3) * stride]; } + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = + make_int4(c0.x, c1.x, c2.x, c3.x); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = + make_int4(c0.y, c1.y, c2.y, c3.y); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = + make_int4(c0.z, c1.z, c2.z, c3.z); + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = + make_int4(c0.w, c1.w, c2.w, c3.w); } } } +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - int width = (w_idx >> (IMMAConfig::wmma_n_bit - 2)) * sw; - int batch = (w_idx & ((IMMAConfig::wmma_n >> 2) - 1)); - if (check_bounds) { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (width >= width_start && width < width_end && - (batch << 2) < remain) { - c0 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 0) * stride]; - c1 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 1) * stride]; - c2 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 2) * stride]; - c3 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 3) * stride]; - } - reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); - reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); - reg[i][j][2] = make_int4(c0.z, c1.z, c2.z, c3.z); - reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); - } else { - copy_t c0 = make_zero(); - copy_t c1 = make_zero(); - copy_t c2 = make_zero(); - copy_t c3 = make_zero(); - if (width >= width_start && width < width_end) { - c0 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 0) * stride]; - c1 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 1) * stride]; - c2 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 2) * stride]; - c3 = g_ptr[width * width_stride + batch + - ((h_idx << 2) + 3) * stride]; - } - reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); - reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); - reg[i][j][2] = make_int4(c0.z, c1.z, c2.z, c3.z); - reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + int width = (w_idx >> (IMMAConfig::wmma_n_bit - 2)) * sw; + int batch = (w_idx & ((IMMAConfig::wmma_n >> 2) - 1)); + if (check_bounds) { + copy_t c0 = make_zero(); + copy_t 
c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (width >= width_start && width < width_end && + (batch << 2) < remain) { + c0 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 0) * stride]; + c1 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 1) * stride]; + c2 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 2) * stride]; + c3 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 3) * stride]; } + reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); + reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); + reg[i][j][2] = make_int4(c0.z, c1.z, c2.z, c3.z); + reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); + } else { + copy_t c0 = make_zero(); + copy_t c1 = make_zero(); + copy_t c2 = make_zero(); + copy_t c3 = make_zero(); + if (width >= width_start && width < width_end) { + c0 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 0) * stride]; + c1 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 1) * stride]; + c2 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 2) * stride]; + c3 = + g_ptr[width * width_stride + batch + + ((h_idx << 2) + 3) * stride]; + } + reg[i][j][0] = make_int4(c0.x, c1.x, c2.x, c3.x); + reg[i][j][1] = make_int4(c0.y, c1.y, c2.y, c3.y); + reg[i][j][2] = make_int4(c0.z, c1.z, c2.z, c3.z); + reg[i][j][3] = make_int4(c0.w, c1.w, c2.w, c3.w); } } } +} - __device__ __forceinline__ void commit() { +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = reg[i][j][0]; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = reg[i][j][1]; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = reg[i][j][2]; - *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = reg[i][j][3]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4))) = reg[i][j][0]; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 4)) = reg[i][j][1]; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 8)) = reg[i][j][2]; + *(sh_ptr_as_copy_t(h_idx, (w_idx << 4) + 12)) = reg[i][j][3]; } } +} - template - __device__ __forceinline__ T* sh_ptr_as(int y, int x) { - return reinterpret_cast(sh_ptr(y, x)); - } +template +__device__ __forceinline__ T* sh_ptr_as(int y, int x) { + return reinterpret_cast(sh_ptr(y, x)); +} - __device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { - return reinterpret_cast(sh_ptr(y, x)); - } +__device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { + return reinterpret_cast(sh_ptr(y, x)); +} - __device__ __forceinline__ int32_t* sh_ptr(int y, int x) { - return &smem[y * TileCount::smem_stride + x]; - } +__device__ __forceinline__ int32_t* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_stride + x]; +} - __device__ __forceinline__ void move_forward() { - g_ptr += WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k / 4 * stride; - } +__device__ __forceinline__ void move_forward() { 
+ g_ptr += WarpTileConfig::warp_tile_k * IMMAConfig::wmma_k / 4 * stride; +} - __device__ __forceinline__ void set_range(const int start, const int end) { - width_start = start, width_end = end; - } +__device__ __forceinline__ void set_range(const int start, const int end) { + width_start = start, width_end = end; +} }; #undef MEGDNN_COMMA diff --git a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh index 2669905b..6b576a34 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_visitor/global_memory_visitor_imma_fwxco.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -44,109 +46,106 @@ namespace convolution { template struct Global2ShareMemVisitorIMMA_FWxCO; -DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_FWxCO, - Layout) - using IMMAConfig = typename TileCount::IMMAConfig; - using WarpTileConfig = typename TileCount::WarpTileConfig; - using ThreadConfig = typename TileCount::ThreadConfig; - int stride; - int remain; - int ch_stride; +DEF_GLOBAL_MEMORY_VISITOR(Global2ShareMemVisitorIMMA_FWxCO, Layout) +using IMMAConfig = typename TileCount::IMMAConfig; +using WarpTileConfig = typename TileCount::WarpTileConfig; +using ThreadConfig = typename TileCount::ThreadConfig; +int stride; +int remain; +int ch_stride; - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int tid = tidy * ThreadConfig::nr_thread_x + tidx; - const int gl_load_y = tid / TileCount::load_x; - const int gl_load_x = tid - gl_load_y * TileCount::load_x; +const int tidx = threadIdx.x; +const int tidy = threadIdx.y; +const int tid = tidy * ThreadConfig::nr_thread_x + tidx; +const int gl_load_y = tid / TileCount::load_x; +const int gl_load_x = tid - gl_load_y * TileCount::load_x; - copy_t reg[TileCount::reg_h][TileCount::reg_w]; - MEGDNN_STATIC_ASSERT(std::is_same::value == true, - "ldg data type must be int4 for this memory visitor"); +copy_t reg[TileCount::reg_h][TileCount::reg_w]; +MEGDNN_STATIC_ASSERT( + std::is_same::value == true, + "ldg data type must be int4 for this memory visitor"); - __device__ __forceinline__ void init_stride(Layout layout) { - stride = layout.width_stride / TileCount::ldg_load_width; - ch_stride = layout.channel_stride / TileCount::ldg_load_width; - } +__device__ __forceinline__ void init_stride(Layout layout) { + stride = layout.width_stride / TileCount::ldg_load_width; + ch_stride = layout.channel_stride / TileCount::ldg_load_width; +} - __device__ __forceinline__ void first_copy() { +__device__ __forceinline__ void first_copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (w_idx < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - 
*(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = val; - } else { - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = - g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (w_idx < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = val; + } else { + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = + g_ptr[h_idx * stride + w_idx]; } } } +} - __device__ __forceinline__ void copy() { +__device__ __forceinline__ void copy() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - if (check_bounds) { - copy_t val = make_zero(); - if (w_idx < remain) { - val = g_ptr[h_idx * stride + w_idx]; - } - reg[i][j] = val; - } else { - reg[i][j] = g_ptr[h_idx * stride + w_idx]; + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + if (check_bounds) { + copy_t val = make_zero(); + if (w_idx < remain) { + val = g_ptr[h_idx * stride + w_idx]; } + reg[i][j] = val; + } else { + reg[i][j] = g_ptr[h_idx * stride + w_idx]; } } } +} - __device__ __forceinline__ void commit() { +__device__ __forceinline__ void commit() { #pragma unroll - for (int i = 0; i < TileCount::reg_h; ++i) { - int h_idx = gl_load_y + i * TileCount::load_y; - if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) - continue; + for (int i = 0; i < TileCount::reg_h; ++i) { + int h_idx = gl_load_y + i * TileCount::load_y; + if (TileCount::check_bounds_h && h_idx >= TileCount::smem_h) + continue; #pragma unroll - for (int j = 0; j < TileCount::reg_w; ++j) { - int w_idx = gl_load_x + j * TileCount::load_x; - if (TileCount::check_bounds_w && - w_idx >= TileCount::smem_load_x) - continue; - *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; - } + for (int j = 0; j < TileCount::reg_w; ++j) { + int w_idx = gl_load_x + j * TileCount::load_x; + if (TileCount::check_bounds_w && w_idx >= TileCount::smem_load_x) + continue; + *(sh_ptr_as_copy_t(h_idx, w_idx * load_width)) = reg[i][j]; } } +} - __device__ __forceinline__ int32_t* sh_ptr(int y, int x) { - return &smem[y * TileCount::smem_stride + x]; - } +__device__ __forceinline__ int32_t* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_stride + x]; +} - __device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { - return reinterpret_cast(sh_ptr(y, x)); - } +__device__ __forceinline__ copy_t* sh_ptr_as_copy_t(int y, int x) { + return reinterpret_cast(sh_ptr(y, x)); +} - __device__ __forceinline__ void move_forward() { - g_ptr += WarpTileConfig::warp_tile_k * ch_stride; - } +__device__ __forceinline__ void move_forward() { + g_ptr += WarpTileConfig::warp_tile_k * ch_stride; +} }; #undef MEGDNN_COMMA diff --git a/dnn/src/cuda/convolution_helper/global_memory_writer/global_memory_writer.cuh 
b/dnn/src/cuda/convolution_helper/global_memory_writer/global_memory_writer.cuh index 84d01f91..36c80a4c 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_writer/global_memory_writer.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_writer/global_memory_writer.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** diff --git a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer.cuh b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer.cuh index dfb6003c..b495764d 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -47,16 +49,16 @@ struct IConvGlobalMemoryWriter { int block_batch_remain; int block_out_channel_remain; - __device__ __forceinline__ void init(int32_t* /* smem */, - const float alpha_, - const float beta_) { + __device__ __forceinline__ void init( + int32_t* /* smem */, const float alpha_, const float beta_) { alpha = alpha_, beta = beta_; } - template - __device__ __forceinline__ void write(BiasVisitor bias, Epilogue epilogue, - BlockConsumer block_consumer) { + template < + bool check_bounds, typename BiasVisitor, typename Epilogue, + typename BlockConsumer> + __device__ __forceinline__ void write( + BiasVisitor bias, Epilogue epilogue, BlockConsumer block_consumer) { static constexpr bool use_wide_store = !(RegBlockConfig::reg_n & 0x1); static constexpr int pack_size_bit = RegBlockConfig::pack_size_bit; @@ -69,15 +71,15 @@ struct IConvGlobalMemoryWriter { #pragma unroll for (int j = 0; j < (RegBlockConfig::reg_n >> 1); ++j) { int j2 = (j << 1); - int out_channel = ((tidy + i * ThreadConfig::nr_thread_y) - << pack_size_bit); + int out_channel = + ((tidy + i * ThreadConfig::nr_thread_y) << pack_size_bit); int batch = (tidx << 1) + j2 * ThreadConfig::nr_thread_x; int ipack = (i << pack_size_bit); - float4 f_conv0 = - make_float4(block_consumer.reg_acc[j2][ipack], - block_consumer.reg_acc[j2][ipack + 1], - block_consumer.reg_acc[j2][ipack + 2], - block_consumer.reg_acc[j2][ipack + 3]); + float4 f_conv0 = make_float4( + block_consumer.reg_acc[j2][ipack], + block_consumer.reg_acc[j2][ipack + 1], + block_consumer.reg_acc[j2][ipack + 2], + block_consumer.reg_acc[j2][ipack + 3]); float4 f_conv1 = make_float4( block_consumer.reg_acc[j2 + 1][ipack], block_consumer.reg_acc[j2 + 1][ipack + 1], @@ -86,23 +88,23 @@ struct IConvGlobalMemoryWriter { if (!check_bounds) { float4 f_bias0 = bias.at(batch, out_channel, 0, 0); float4 f_bias1 = bias.at(batch + 1, out_channel, 0, 0); - epilogue.apply(alpha, f_conv0, f_conv1, beta, f_bias0, - f_bias1, batch, out_channel, 0, 0); + epilogue.apply( + alpha, f_conv0, f_conv1, beta, f_bias0, f_bias1, batch, + out_channel, 0, 0); } else if (out_channel < block_out_channel_remain) { if (((block_batch_remain & 0x1) == 0) && batch + 2 <= block_batch_remain) { float4 f_bias0 = bias.at(batch, out_channel, 0, 0); - float4 f_bias1 = - bias.at(batch + 1, out_channel, 0, 0); - epilogue.apply(alpha, f_conv0, f_conv1, beta, - f_bias0, f_bias1, batch, out_channel, - 0, 0); + float4 f_bias1 = bias.at(batch + 1, out_channel, 0, 0); + epilogue.apply( + alpha, f_conv0, f_conv1, beta, f_bias0, f_bias1, + batch, out_channel, 0, 0); } else { -#define 
store(_i) \ - if (batch + (_i) < block_batch_remain) { \ - float4 f_bias##_i = bias.at(batch + (_i), out_channel, 0, 0); \ - epilogue.apply(alpha, f_conv##_i, beta, f_bias##_i, batch + (_i), \ - out_channel, 0, 0); \ +#define store(_i) \ + if (batch + (_i) < block_batch_remain) { \ + float4 f_bias##_i = bias.at(batch + (_i), out_channel, 0, 0); \ + epilogue.apply( \ + alpha, f_conv##_i, beta, f_bias##_i, batch + (_i), out_channel, 0, 0); \ } store(0); store(1); @@ -116,13 +118,12 @@ struct IConvGlobalMemoryWriter { for (int i = 0; i < RegBlockConfig::reg_m_packed; ++i) { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_n; ++j) { - int out_channel = ((tidy + i * ThreadConfig::nr_thread_y) - << pack_size_bit); + int out_channel = + ((tidy + i * ThreadConfig::nr_thread_y) << pack_size_bit); int batch = tidx + j * ThreadConfig::nr_thread_x; int ipack = (i << pack_size_bit); - if (check_bounds && - (out_channel >= block_out_channel_remain || - batch >= block_batch_remain)) { + if (check_bounds && (out_channel >= block_out_channel_remain || + batch >= block_batch_remain)) { } else { float4 f_conv = make_float4( block_consumer.reg_acc[j][ipack], @@ -130,8 +131,8 @@ struct IConvGlobalMemoryWriter { block_consumer.reg_acc[j][ipack + 2], block_consumer.reg_acc[j][ipack + 3]); float4 f_bias = bias.at(batch, out_channel, 0, 0); - epilogue.apply(alpha, f_conv, beta, f_bias, batch, - out_channel, 0, 0); + epilogue.apply( + alpha, f_conv, beta, f_bias, batch, out_channel, 0, 0); } } } diff --git a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_coxhw.cuh b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_coxhw.cuh index 2d364139..8cedaf10 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_coxhw.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_coxhw.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. 
* - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_coxhw.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_coxhw.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -47,18 +49,17 @@ struct IConvGlobalMemoryWriter_COxHW { int block_out_height_width_remain; int block_out_channel_remain; - __device__ __forceinline__ void init(int32_t* /* smem */, - const float alpha_, - const float beta_) { + __device__ __forceinline__ void init( + int32_t* /* smem */, const float alpha_, const float beta_) { alpha = alpha_, beta = beta_; } - template - __device__ __forceinline__ void write(BiasVisitor bias, Epilogue epilogue, - BlockConsumer block_consumer) { - static constexpr bool use_wide_store = - !(RegBlockConfig::reg_width & 0x1); + template < + bool check_bounds, typename BiasVisitor, typename Epilogue, + typename BlockConsumer> + __device__ __forceinline__ void write( + BiasVisitor bias, Epilogue epilogue, BlockConsumer block_consumer) { + static constexpr bool use_wide_store = !(RegBlockConfig::reg_width & 0x1); static constexpr int pack_size_bit = RegBlockConfig::pack_size_bit; const int tidx = threadIdx.x; @@ -70,50 +71,50 @@ struct IConvGlobalMemoryWriter_COxHW { #pragma unroll for (int j = 0; j < (RegBlockConfig::reg_width >> 1); ++j) { int j2 = (j << 1); - int out_channel = ((tidy + i * ThreadConfig::nr_thread_y) - << pack_size_bit); - int out_height_width = - (tidx << 1) + j2 * ThreadConfig::nr_thread_x; + int out_channel = + ((tidy + i * ThreadConfig::nr_thread_y) << pack_size_bit); + int out_height_width = (tidx << 1) + j2 * ThreadConfig::nr_thread_x; int ipack = (i << pack_size_bit); - float4 f_conv0 = - make_float4(block_consumer.reg_acc[j2][ipack], - block_consumer.reg_acc[j2][ipack + 1], - block_consumer.reg_acc[j2][ipack + 2], - block_consumer.reg_acc[j2][ipack + 3]); + float4 f_conv0 = make_float4( + block_consumer.reg_acc[j2][ipack], + block_consumer.reg_acc[j2][ipack + 1], + block_consumer.reg_acc[j2][ipack + 2], + block_consumer.reg_acc[j2][ipack + 3]); float4 f_conv1 = make_float4( block_consumer.reg_acc[j2 + 1][ipack], block_consumer.reg_acc[j2 + 1][ipack + 1], block_consumer.reg_acc[j2 + 1][ipack + 2], block_consumer.reg_acc[j2 + 1][ipack + 3]); -// if (threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && i == 0 && j == 0) { -// printf("acc = %f, %f, %f, %f\n", f_conv0.x, f_conv0.y, f_conv0.z, f_conv0.w); -// } - + // if (threadIdx.x == 0 && threadIdx.y == 0 && + // blockIdx.x == 0 && blockIdx.y == 0 && + // blockIdx.z == 0 && i == 0 && j == 0) { + // printf("acc = %f, %f, %f, %f\n", + // f_conv0.x, f_conv0.y, f_conv0.z, + // f_conv0.w); + // } + if (!check_bounds) { - float4 f_bias0 = - bias.at(0, out_channel, out_height_width); - float4 f_bias1 = - bias.at(0, out_channel, out_height_width + 1); - epilogue.apply(alpha, f_conv0, f_conv1, beta, f_bias0, - f_bias1, 0, out_channel, - out_height_width); + float4 f_bias0 = bias.at(0, out_channel, out_height_width); + float4 f_bias1 = bias.at(0, out_channel, out_height_width + 1); + epilogue.apply( + alpha, f_conv0, f_conv1, beta, f_bias0, f_bias1, 0, + out_channel, out_height_width); } else if (out_channel < block_out_channel_remain) { if (((block_out_height_width_remain & 0x1) == 0) && - out_height_width + 2 <= - block_out_height_width_remain) { - float4 f_bias0 = - bias.at(0, out_channel, out_height_width); - float4 f_bias1 = bias.at(0, out_channel, - out_height_width + 1); - epilogue.apply(alpha, f_conv0, f_conv1, beta, - f_bias0, f_bias1, 0, out_channel, - out_height_width); + out_height_width + 2 <= block_out_height_width_remain) { + float4 f_bias0 = bias.at(0, out_channel, out_height_width); + float4 f_bias1 = + bias.at(0, out_channel, 
out_height_width + 1); + epilogue.apply( + alpha, f_conv0, f_conv1, beta, f_bias0, f_bias1, 0, + out_channel, out_height_width); } else { -#define store(_i) \ - if (out_height_width + (_i) < block_out_height_width_remain) { \ - float4 f_bias##_i = bias.at(0, out_channel, out_height_width); \ - epilogue.apply(alpha, f_conv##_i, beta, f_bias##_i, 0, out_channel, \ - out_height_width + (_i)); \ +#define store(_i) \ + if (out_height_width + (_i) < block_out_height_width_remain) { \ + float4 f_bias##_i = bias.at(0, out_channel, out_height_width); \ + epilogue.apply( \ + alpha, f_conv##_i, beta, f_bias##_i, 0, out_channel, \ + out_height_width + (_i)); \ } store(0); store(1); @@ -127,8 +128,8 @@ struct IConvGlobalMemoryWriter_COxHW { for (int i = 0; i < RegBlockConfig::reg_m_packed; ++i) { #pragma unroll for (int j = 0; j < RegBlockConfig::reg_width; ++j) { - int out_channel = ((tidy + i * ThreadConfig::nr_thread_y) - << pack_size_bit); + int out_channel = + ((tidy + i * ThreadConfig::nr_thread_y) << pack_size_bit); int out_height_width = tidx + j * ThreadConfig::nr_thread_x; int ipack = (i << pack_size_bit); if (check_bounds && @@ -140,10 +141,10 @@ struct IConvGlobalMemoryWriter_COxHW { block_consumer.reg_acc[j][ipack + 1], block_consumer.reg_acc[j][ipack + 2], block_consumer.reg_acc[j][ipack + 3]); - float4 f_bias = - bias.at(0, out_channel, out_height_width); - epilogue.apply(alpha, f_conv, beta, f_bias, 0, - out_channel, out_height_width); + float4 f_bias = bias.at(0, out_channel, out_height_width); + epilogue.apply( + alpha, f_conv, beta, f_bias, 0, out_channel, + out_height_width); } } } diff --git a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_unroll_width.cuh b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_unroll_width.cuh index 15fdfb2a..b6892518 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_unroll_width.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_unroll_width.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_unroll_width.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_writer/iconv_global_memory_writer_unroll_width.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -47,16 +49,16 @@ struct IConvGlobalMemoryWriterUnrollWidth { int block_batch_remain; int block_out_channel_remain; - __device__ __forceinline__ void init(int32_t* /* smem */, - const float alpha_, - const float beta_) { + __device__ __forceinline__ void init( + int32_t* /* smem */, const float alpha_, const float beta_) { alpha = alpha_, beta = beta_; } - template - __device__ __forceinline__ void write(BiasVisitor bias, Epilogue epilogue, - BlockConsumer block_consumer) { + template < + bool check_bounds, typename BiasVisitor, typename Epilogue, + typename BlockConsumer> + __device__ __forceinline__ void write( + BiasVisitor bias, Epilogue epilogue, BlockConsumer block_consumer) { static constexpr bool use_wide_store = !(RegBlockConfig::reg_n & 0x1); static constexpr int pack_size_bit = RegBlockConfig::pack_size_bit; @@ -74,8 +76,7 @@ struct IConvGlobalMemoryWriterUnrollWidth { int out_channel = ((tidy + i * ThreadConfig::nr_thread_y) << pack_size_bit); - int batch = - (tidx << 1) + k2 * ThreadConfig::nr_thread_x; + int batch = (tidx << 1) + k2 * ThreadConfig::nr_thread_x; int ipack = (i << pack_size_bit); float4 f_conv0 = make_float4( block_consumer.reg_acc[k2][j][ipack], @@ -89,27 +90,24 @@ struct IConvGlobalMemoryWriterUnrollWidth { block_consumer.reg_acc[k2 + 1][j][ipack + 3]); if (!check_bounds) { float4 f_bias0 = bias.at(batch, out_channel, 0, j); - float4 f_bias1 = - bias.at(batch + 1, out_channel, 0, j); - epilogue.apply(alpha, f_conv0, f_conv1, beta, - f_bias0, f_bias1, batch, out_channel, - 0, j); + float4 f_bias1 = bias.at(batch + 1, out_channel, 0, j); + epilogue.apply( + alpha, f_conv0, f_conv1, beta, f_bias0, f_bias1, + batch, out_channel, 0, j); } else if (out_channel < block_out_channel_remain) { if (((block_batch_remain & 0x1) == 0) && batch + 2 <= block_batch_remain) { - float4 f_bias0 = - bias.at(batch, out_channel, 0, j); - float4 f_bias1 = - bias.at(batch + 1, out_channel, 0, j); - epilogue.apply(alpha, f_conv0, f_conv1, beta, - f_bias0, f_bias1, batch, - out_channel, 0, j); + float4 f_bias0 = bias.at(batch, out_channel, 0, j); + float4 f_bias1 = bias.at(batch + 1, out_channel, 0, j); + epilogue.apply( + alpha, f_conv0, f_conv1, beta, f_bias0, f_bias1, + batch, out_channel, 0, j); } else { -#define store(_i) \ - if (batch + (_i) < block_batch_remain) { \ - float4 f_bias##_i = bias.at(batch + (_i), out_channel, 0, j); \ - epilogue.apply(alpha, f_conv##_i, beta, f_bias##_i, batch + (_i), \ - out_channel, 0, j); \ +#define store(_i) \ + if (batch + (_i) < block_batch_remain) { \ + float4 f_bias##_i = bias.at(batch + (_i), out_channel, 0, j); \ + epilogue.apply( \ + alpha, f_conv##_i, beta, f_bias##_i, batch + (_i), out_channel, 0, j); \ } store(0); store(1); @@ -131,9 +129,8 @@ struct IConvGlobalMemoryWriterUnrollWidth { << pack_size_bit); int batch = tidx + k * ThreadConfig::nr_thread_x; int ipack = (i << pack_size_bit); - if (check_bounds && - (out_channel >= block_out_channel_remain || - batch >= block_batch_remain)) { + if (check_bounds && (out_channel >= block_out_channel_remain || + batch >= block_batch_remain)) { } else { float4 f_conv = make_float4( block_consumer.reg_acc[k][j][ipack], @@ -141,8 +138,9 @@ struct IConvGlobalMemoryWriterUnrollWidth { block_consumer.reg_acc[k][j][ipack + 2], block_consumer.reg_acc[k][j][ipack + 3]); float4 f_bias = bias.at(batch, out_channel, 0, j); - epilogue.apply(alpha, f_conv, beta, f_bias, batch, - out_channel, 0, j); + epilogue.apply( + alpha, f_conv, beta, f_bias, batch, out_channel, 0, + j); } } } diff --git 
a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer.cuh b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer.cuh index 0987459d..f1a82afe 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -60,16 +62,17 @@ struct IConvIMMAGlobalMemoryWriter { int block_batch_remain; int block_out_channel_remain; - __device__ __forceinline__ void init(int32_t* smem_, const float alpha_, - const float beta_) { + __device__ __forceinline__ void init( + int32_t* smem_, const float alpha_, const float beta_) { smem = smem_; alpha = alpha_, beta = beta_; } - template - __device__ __forceinline__ void write(BiasVisitor bias, Epilogue epilogue, - BlockConsumer block_consumer) { + template < + bool check_bounds, typename BiasVisitor, typename Epilogue, + typename BlockConsumer> + __device__ __forceinline__ void write( + BiasVisitor bias, Epilogue epilogue, BlockConsumer block_consumer) { #if __CUDA_ARCH__ >= 730 const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -82,27 +85,26 @@ struct IConvIMMAGlobalMemoryWriter { if (use_wide_store) { const int warpx2 = (warpx << 1); int32_t* st_sh_frag_ptr = - smem + - (warpy * ThreadConfig::nr_warp_x + warpx) * - (IMMAConfig::wmma_m * IMMAConfig::wmma_n << 1); + smem + (warpy * ThreadConfig::nr_warp_x + warpx) * + (IMMAConfig::wmma_m * IMMAConfig::wmma_n << 1); #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_m; ++i) { #pragma urnoll for (int j = 0; j < (WarpTileConfig::warp_tile_n >> 1); ++j) { int j2 = (j << 1); static int const wmma_n2 = (IMMAConfig::wmma_n << 1); - wmma::store_matrix_sync(st_sh_frag_ptr, - block_consumer.frag_acc[i][j2], - wmma_n2, wmma::mem_row_major); - wmma::store_matrix_sync(st_sh_frag_ptr + IMMAConfig::wmma_n, - block_consumer.frag_acc[i][j2 + 1], - wmma_n2, wmma::mem_row_major); + wmma::store_matrix_sync( + st_sh_frag_ptr, block_consumer.frag_acc[i][j2], wmma_n2, + wmma::mem_row_major); + wmma::store_matrix_sync( + st_sh_frag_ptr + IMMAConfig::wmma_n, + block_consumer.frag_acc[i][j2 + 1], wmma_n2, + wmma::mem_row_major); const int sh_st_y = idx_intra_warp / GlobalMemoryStoreCount::store_x; const int sh_st_x = - idx_intra_warp - - sh_st_y * GlobalMemoryStoreCount::store_x; + idx_intra_warp - sh_st_y * GlobalMemoryStoreCount::store_x; const int wmma_tile_h_base = (sh_st_y << 2); const int wmma_tile_w = sh_st_x * GlobalMemoryStoreCount::store_width; @@ -116,30 +118,30 @@ struct IConvIMMAGlobalMemoryWriter { int const b1 = b0 + 1, b2 = b0 + 2, b3 = b0 + 3; st_type lane0 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 0) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 0) * wmma_n2 + + wmma_tile_w])); st_type lane1 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 1) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 1) * wmma_n2 + + wmma_tile_w])); st_type lane2 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 2) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 2) * wmma_n2 + + wmma_tile_w])); st_type lane3 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 3) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 3) * wmma_n2 + + wmma_tile_w])); - float4 f_conv0 = ::make_float4(lane0.x, lane1.x, - lane2.x, lane3.x); - float4 f_conv1 = 
::make_float4(lane0.y, lane1.y, - lane2.y, lane3.y); - float4 f_conv2 = ::make_float4(lane0.z, lane1.z, - lane2.z, lane3.z); - float4 f_conv3 = ::make_float4(lane0.w, lane1.w, - lane2.w, lane3.w); + float4 f_conv0 = + ::make_float4(lane0.x, lane1.x, lane2.x, lane3.x); + float4 f_conv1 = + ::make_float4(lane0.y, lane1.y, lane2.y, lane3.y); + float4 f_conv2 = + ::make_float4(lane0.z, lane1.z, lane2.z, lane3.z); + float4 f_conv3 = + ::make_float4(lane0.w, lane1.w, lane2.w, lane3.w); // store to global memory if (!check_bounds) { @@ -148,9 +150,9 @@ struct IConvIMMAGlobalMemoryWriter { float4 f_bias2 = bias.at(b2, ch, 0, 0); float4 f_bias3 = bias.at(b3, ch, 0, 0); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, 0); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, 0); } else if (ch < block_out_channel_remain) { if (((block_batch_remain & 0x3) == 0) && b0 + 4 <= block_batch_remain) { @@ -159,9 +161,10 @@ struct IConvIMMAGlobalMemoryWriter { float4 f_bias2 = bias.at(b2, ch, 0, 0); float4 f_bias3 = bias.at(b3, ch, 0, 0); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, 0); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, + 0); } else { #define store(_idx) \ if (b0 + _idx < block_batch_remain) { \ @@ -178,9 +181,9 @@ struct IConvIMMAGlobalMemoryWriter { } // end j } // end i } else { - int32_t* st_sh_frag_ptr = - smem + (warpy * ThreadConfig::nr_warp_x + warpx) * - IMMAConfig::wmma_m * IMMAConfig::wmma_n; + int32_t* st_sh_frag_ptr = smem + (warpy * ThreadConfig::nr_warp_x + warpx) * + IMMAConfig::wmma_m * + IMMAConfig::wmma_n; #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_m; ++i) { @@ -192,8 +195,7 @@ struct IConvIMMAGlobalMemoryWriter { const int sh_st_y = idx_intra_warp / GlobalMemoryStoreCount::store_x; const int sh_st_x = - idx_intra_warp - - sh_st_y * GlobalMemoryStoreCount::store_x; + idx_intra_warp - sh_st_y * GlobalMemoryStoreCount::store_x; const int wmma_tile_h_base = (sh_st_y << 2); const int wmma_tile_w = sh_st_x * GlobalMemoryStoreCount::store_width; @@ -207,30 +209,30 @@ struct IConvIMMAGlobalMemoryWriter { int const b1 = b0 + 1, b2 = b0 + 2, b3 = b0 + 3; st_type lane0 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 0) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 0) * IMMAConfig::wmma_n + + wmma_tile_w])); st_type lane1 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 1) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 1) * IMMAConfig::wmma_n + + wmma_tile_w])); st_type lane2 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 2) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 2) * IMMAConfig::wmma_n + + wmma_tile_w])); st_type lane3 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 3) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 3) * IMMAConfig::wmma_n + + wmma_tile_w])); - float4 f_conv0 = ::make_float4(lane0.x, lane1.x, - lane2.x, lane3.x); - float4 f_conv1 = ::make_float4(lane0.y, lane1.y, - lane2.y, lane3.y); - float4 f_conv2 = ::make_float4(lane0.z, lane1.z, - lane2.z, lane3.z); - float4 f_conv3 = ::make_float4(lane0.w, lane1.w, - lane2.w, lane3.w); + float4 f_conv0 = + ::make_float4(lane0.x, lane1.x, 
lane2.x, lane3.x); + float4 f_conv1 = + ::make_float4(lane0.y, lane1.y, lane2.y, lane3.y); + float4 f_conv2 = + ::make_float4(lane0.z, lane1.z, lane2.z, lane3.z); + float4 f_conv3 = + ::make_float4(lane0.w, lane1.w, lane2.w, lane3.w); // store to global memory if (!check_bounds) { @@ -238,9 +240,9 @@ struct IConvIMMAGlobalMemoryWriter { float4 f_bias1 = bias.at(b1, ch, 0, 0); float4 f_bias2 = bias.at(b2, ch, 0, 0); float4 f_bias3 = bias.at(b3, ch, 0, 0); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, 0); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, 0); } else if (ch < block_out_channel_remain) { if ((block_batch_remain & 0x3) == 0 && b0 + 4 <= block_batch_remain) { @@ -248,9 +250,10 @@ struct IConvIMMAGlobalMemoryWriter { float4 f_bias1 = bias.at(b1, ch, 0, 0); float4 f_bias2 = bias.at(b2, ch, 0, 0); float4 f_bias3 = bias.at(b3, ch, 0, 0); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, 0); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, + 0); } else { store(0); store(1); diff --git a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer_unroll_width.cuh b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer_unroll_width.cuh index 8a50869c..8fec1ad1 100644 --- a/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer_unroll_width.cuh +++ b/dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer_unroll_width.cuh @@ -1,29 +1,31 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. 
* - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** - * \file dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer_unroll_width.cuh + * \file + * dnn/src/cuda/convolution_helper/global_memory_writer/iconv_imma_global_memory_writer_unroll_width.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
@@ -61,16 +63,17 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { int block_batch_remain; int block_out_channel_remain; - __device__ __forceinline__ void init(int32_t* smem_, const float alpha_, - const float beta_) { + __device__ __forceinline__ void init( + int32_t* smem_, const float alpha_, const float beta_) { smem = smem_; alpha = alpha_, beta = beta_; } - template - __device__ __forceinline__ void write(BiasVisitor bias, Epilogue epilogue, - BlockConsumer block_consumer) { + template < + bool check_bounds, typename BiasVisitor, typename Epilogue, + typename BlockConsumer> + __device__ __forceinline__ void write( + BiasVisitor bias, Epilogue epilogue, BlockConsumer block_consumer) { #if __CUDA_ARCH__ >= 730 const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -83,65 +86,63 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { if (consecutive_width_tile) { const int warpx2 = (warpx << 1); int32_t* st_sh_frag_ptr = - smem + - (warpy * ThreadConfig::nr_warp_x + warpx) * - (IMMAConfig::wmma_m * IMMAConfig::wmma_n << 1); + smem + (warpy * ThreadConfig::nr_warp_x + warpx) * + (IMMAConfig::wmma_m * IMMAConfig::wmma_n << 1); #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_m; ++i) { #pragma urnoll for (int j = 0; j < (WarpTileConfig::warp_tile_n >> 1); ++j) { int j2 = (j << 1); static int const wmma_n2 = (IMMAConfig::wmma_n << 1); - wmma::store_matrix_sync(st_sh_frag_ptr, - block_consumer.frag_acc[i][j2], - wmma_n2, wmma::mem_row_major); - wmma::store_matrix_sync(st_sh_frag_ptr + IMMAConfig::wmma_n, - block_consumer.frag_acc[i][j2 + 1], - wmma_n2, wmma::mem_row_major); + wmma::store_matrix_sync( + st_sh_frag_ptr, block_consumer.frag_acc[i][j2], wmma_n2, + wmma::mem_row_major); + wmma::store_matrix_sync( + st_sh_frag_ptr + IMMAConfig::wmma_n, + block_consumer.frag_acc[i][j2 + 1], wmma_n2, + wmma::mem_row_major); const int sh_st_y = idx_intra_warp / GlobalMemoryStoreCount::store_x; const int sh_st_x = - idx_intra_warp - - sh_st_y * GlobalMemoryStoreCount::store_x; + idx_intra_warp - sh_st_y * GlobalMemoryStoreCount::store_x; const int wmma_tile_h_base = (sh_st_y << 2); const int wmma_tile_w = sh_st_x * GlobalMemoryStoreCount::store_width; if (wmma_tile_h_base + 4 <= IMMAConfig::wmma_m) { int const b0 = wmma_tile_w & (IMMAConfig::wmma_n - 1); - int const width = - (warpx2 + j2 * ThreadConfig::nr_warp_x) + - (wmma_tile_w >> IMMAConfig::wmma_n_bit); + int const width = (warpx2 + j2 * ThreadConfig::nr_warp_x) + + (wmma_tile_w >> IMMAConfig::wmma_n_bit); int const ch = (warpy + i * ThreadConfig::nr_warp_y) * IMMAConfig::wmma_m + wmma_tile_h_base; int const b1 = b0 + 1, b2 = b0 + 2, b3 = b0 + 3; st_type lane0 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 0) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 0) * wmma_n2 + + wmma_tile_w])); st_type lane1 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 1) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 1) * wmma_n2 + + wmma_tile_w])); st_type lane2 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 2) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 2) * wmma_n2 + + wmma_tile_w])); st_type lane3 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 3) * - wmma_n2 + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 3) * wmma_n2 + + wmma_tile_w])); - float4 f_conv0 = ::make_float4(lane0.x, lane1.x, - lane2.x, lane3.x); - float4 f_conv1 = ::make_float4(lane0.y, lane1.y, - lane2.y, lane3.y); - float4 f_conv2 = 
::make_float4(lane0.z, lane1.z, - lane2.z, lane3.z); - float4 f_conv3 = ::make_float4(lane0.w, lane1.w, - lane2.w, lane3.w); + float4 f_conv0 = + ::make_float4(lane0.x, lane1.x, lane2.x, lane3.x); + float4 f_conv1 = + ::make_float4(lane0.y, lane1.y, lane2.y, lane3.y); + float4 f_conv2 = + ::make_float4(lane0.z, lane1.z, lane2.z, lane3.z); + float4 f_conv3 = + ::make_float4(lane0.w, lane1.w, lane2.w, lane3.w); // store to global memory if (!check_bounds) { @@ -150,9 +151,10 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { float4 f_bias2 = bias.at(b2, ch, 0, width); float4 f_bias3 = bias.at(b3, ch, 0, width); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, width); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, + width); } else if (ch < block_out_channel_remain) { if ((block_batch_remain & 0x3) == 0 && b0 + 4 <= block_batch_remain) { @@ -161,16 +163,15 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { float4 f_bias2 = bias.at(b2, ch, 0, width); float4 f_bias3 = bias.at(b3, ch, 0, width); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, - width); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, + width); } else { -#define store(_idx) \ - if (b0 + _idx < block_batch_remain) { \ - float4 f_bias = bias.at(b##_idx, ch, 0, width); \ - epilogue.apply(alpha, f_conv##_idx, beta, f_bias, b##_idx, ch, 0, \ - width); \ +#define store(_idx) \ + if (b0 + _idx < block_batch_remain) { \ + float4 f_bias = bias.at(b##_idx, ch, 0, width); \ + epilogue.apply(alpha, f_conv##_idx, beta, f_bias, b##_idx, ch, 0, width); \ } store(0); store(1); @@ -182,9 +183,9 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { } // end j } // end i } else { - int32_t* st_sh_frag_ptr = - smem + (warpy * ThreadConfig::nr_warp_x + warpx) * - IMMAConfig::wmma_m * IMMAConfig::wmma_n; + int32_t* st_sh_frag_ptr = smem + (warpy * ThreadConfig::nr_warp_x + warpx) * + IMMAConfig::wmma_m * + IMMAConfig::wmma_n; #pragma unroll for (int i = 0; i < WarpTileConfig::warp_tile_m; ++i) { @@ -196,8 +197,7 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { const int sh_st_y = idx_intra_warp / GlobalMemoryStoreCount::store_x; const int sh_st_x = - idx_intra_warp - - sh_st_y * GlobalMemoryStoreCount::store_x; + idx_intra_warp - sh_st_y * GlobalMemoryStoreCount::store_x; const int wmma_tile_h_base = (sh_st_y << 2); const int wmma_tile_w = sh_st_x * GlobalMemoryStoreCount::store_width; @@ -210,30 +210,30 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { int const b1 = b0 + 1, b2 = b0 + 2, b3 = b0 + 3; st_type lane0 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 0) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 0) * IMMAConfig::wmma_n + + wmma_tile_w])); st_type lane1 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 1) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 1) * IMMAConfig::wmma_n + + wmma_tile_w])); st_type lane2 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 2) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 2) * IMMAConfig::wmma_n + + wmma_tile_w])); st_type lane3 = *(reinterpret_cast( - &st_sh_frag_ptr[(wmma_tile_h_base + 3) * - IMMAConfig::wmma_n + - wmma_tile_w])); + &st_sh_frag_ptr + [(wmma_tile_h_base + 3) * 
IMMAConfig::wmma_n + + wmma_tile_w])); - float4 f_conv0 = ::make_float4(lane0.x, lane1.x, - lane2.x, lane3.x); - float4 f_conv1 = ::make_float4(lane0.y, lane1.y, - lane2.y, lane3.y); - float4 f_conv2 = ::make_float4(lane0.z, lane1.z, - lane2.z, lane3.z); - float4 f_conv3 = ::make_float4(lane0.w, lane1.w, - lane2.w, lane3.w); + float4 f_conv0 = + ::make_float4(lane0.x, lane1.x, lane2.x, lane3.x); + float4 f_conv1 = + ::make_float4(lane0.y, lane1.y, lane2.y, lane3.y); + float4 f_conv2 = + ::make_float4(lane0.z, lane1.z, lane2.z, lane3.z); + float4 f_conv3 = + ::make_float4(lane0.w, lane1.w, lane2.w, lane3.w); // store to global memory if (!check_bounds) { @@ -242,9 +242,10 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { float4 f_bias2 = bias.at(b2, ch, 0, width); float4 f_bias3 = bias.at(b3, ch, 0, width); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, width); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, + width); } else if (ch < block_out_channel_remain) { if ((block_batch_remain & 0x3) == 0 && b0 + 4 <= block_batch_remain) { @@ -253,10 +254,10 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { float4 f_bias2 = bias.at(b2, ch, 0, width); float4 f_bias3 = bias.at(b3, ch, 0, width); - epilogue.apply(alpha, f_conv0, f_conv1, f_conv2, - f_conv3, beta, f_bias0, f_bias1, - f_bias2, f_bias3, b0, ch, 0, - width); + epilogue.apply( + alpha, f_conv0, f_conv1, f_conv2, f_conv3, beta, + f_bias0, f_bias1, f_bias2, f_bias3, b0, ch, 0, + width); } else { store(0); store(1); @@ -273,7 +274,7 @@ struct IConvIMMAGlobalMemoryWriterUnrollWidth { } }; -} // namespace cuda +} // namespace convolution } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/convolution_helper/kernel.cuh b/dnn/src/cuda/convolution_helper/kernel.cuh index 38af58e3..1821500a 100644 --- a/dnn/src/cuda/convolution_helper/kernel.cuh +++ b/dnn/src/cuda/convolution_helper/kernel.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -45,9 +46,8 @@ namespace convolution { template __global__ void convolution_kernel( const typename ConvTrait::src_dtype* __restrict__ src, - const typename ConvTrait::filter_dtype* __restrict__ filter, - BiasVisitor bias, Epilogue epilogue, typename ConvTrait::Param param, - float alpha, float beta) { + const typename ConvTrait::filter_dtype* __restrict__ filter, BiasVisitor bias, + Epilogue epilogue, typename ConvTrait::Param param, float alpha, float beta) { static bool constexpr check_bounds = ConvTrait::check_bounds; using BlockTileIterator = typename ConvTrait::BlockTileIterator; BlockTileIterator block_iterator; @@ -58,8 +58,7 @@ __global__ void convolution_kernel( using DataTileCount = typename ConvTrait::DataTileCount; using FilterTileCount = typename ConvTrait::FilterTileCount; - using DataGlobal2ShareMemVisitor = - typename ConvTrait::DataGlobal2ShareMemVisitor; + using DataGlobal2ShareMemVisitor = typename ConvTrait::DataGlobal2ShareMemVisitor; using FilterGlobal2ShareMemVisitor = typename ConvTrait::FilterGlobal2ShareMemVisitor; @@ -72,16 +71,15 @@ __global__ void convolution_kernel( DataGlobal2ShareMemVisitor src_gl2sh_visitor{smem_src}; FilterGlobal2ShareMemVisitor filter_gl2sh_visitor{smem_filter}; if (check_bounds) { - block_iterator.set_remain(src_gl2sh_visitor, - filter_gl2sh_visitor); + block_iterator.set_remain(src_gl2sh_visitor, filter_gl2sh_visitor); } using BlockConsumer = typename ConvTrait::BlockConsumer; BlockConsumer block_consumer; block_consumer.init_accumulator(); - block_iterator.template iterate_with_param( + block_iterator.template iterate_with_param< + typename ConvTrait::InputLayout, typename ConvTrait::KernLayout>( src, filter, param, 
src_gl2sh_visitor, filter_gl2sh_visitor, block_consumer); @@ -91,13 +89,13 @@ __global__ void convolution_kernel( if (check_bounds) { block_iterator.set_remain(global_memory_writer); } - bias.move(block_iterator.block_batch, block_iterator.block_out_channel, - block_iterator.block_out_height, block_iterator.block_out_width); - epilogue.move(block_iterator.block_batch, block_iterator.block_out_channel, - block_iterator.block_out_height, - block_iterator.block_out_width); - global_memory_writer.template write(bias, epilogue, - block_consumer); + bias.move( + block_iterator.block_batch, block_iterator.block_out_channel, + block_iterator.block_out_height, block_iterator.block_out_width); + epilogue.move( + block_iterator.block_batch, block_iterator.block_out_channel, + block_iterator.block_out_height, block_iterator.block_out_width); + global_memory_writer.template write(bias, epilogue, block_consumer); } template @@ -116,8 +114,7 @@ __global__ void convolution_kernel_precomp_offset( using DataTileCount = typename ConvTrait::DataTileCount; using FilterTileCount = typename ConvTrait::FilterTileCount; - using DataGlobal2ShareMemVisitor = - typename ConvTrait::DataGlobal2ShareMemVisitor; + using DataGlobal2ShareMemVisitor = typename ConvTrait::DataGlobal2ShareMemVisitor; using FilterGlobal2ShareMemVisitor = typename ConvTrait::FilterGlobal2ShareMemVisitor; @@ -130,16 +127,15 @@ __global__ void convolution_kernel_precomp_offset( DataGlobal2ShareMemVisitor src_gl2sh_visitor{smem_src, offset}; FilterGlobal2ShareMemVisitor filter_gl2sh_visitor{smem_filter}; if (check_bounds) { - block_iterator.set_remain(src_gl2sh_visitor, - filter_gl2sh_visitor); + block_iterator.set_remain(src_gl2sh_visitor, filter_gl2sh_visitor); } using BlockConsumer = typename ConvTrait::BlockConsumer; BlockConsumer block_consumer; block_consumer.init_accumulator(); - block_iterator.template iterate_with_param( + block_iterator.template iterate_with_param< + typename ConvTrait::InputLayout, typename ConvTrait::KernLayout>( src, filter, param, src_gl2sh_visitor, filter_gl2sh_visitor, block_consumer); @@ -149,13 +145,13 @@ __global__ void convolution_kernel_precomp_offset( if (check_bounds) { block_iterator.set_remain(global_memory_writer); } - bias.move(block_iterator.block_batch, block_iterator.block_out_channel, - block_iterator.block_out_height, block_iterator.block_out_width); - epilogue.move(block_iterator.block_batch, block_iterator.block_out_channel, - block_iterator.block_out_height, - block_iterator.block_out_width); - global_memory_writer.template write(bias, epilogue, - block_consumer); + bias.move( + block_iterator.block_batch, block_iterator.block_out_channel, + block_iterator.block_out_height, block_iterator.block_out_width); + epilogue.move( + block_iterator.block_batch, block_iterator.block_out_channel, + block_iterator.block_out_height, block_iterator.block_out_width); + global_memory_writer.template write(bias, epilogue, block_consumer); } } // namespace convolution diff --git a/dnn/src/cuda/convolution_helper/layout.cuh b/dnn/src/cuda/convolution_helper/layout.cuh index 930ad099..6b851f08 100644 --- a/dnn/src/cuda/convolution_helper/layout.cuh +++ b/dnn/src/cuda/convolution_helper/layout.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
* - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
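A note on the make_float4 regrouping in the IConvIMMAGlobalMemoryWriterUnrollWidth hunks further up: the four st_type lanes are four consecutive rows of the accumulator tile staged in shared memory, and the four make_float4 calls transpose that 4x4 block so that each resulting float4 can be handed to the epilogue together with one bias vector per batch index b0..b3. The sketch below shows only the transpose itself, with stand-in types instead of CUDA's float4/int4; the exact mapping of rows and columns to channels and batches is fixed by IMMAConfig and is not reproduced here.

// Stand-alone illustration of the 4x4 transpose done with make_float4 above.
// vec4/make_vec4 are stand-ins for float4/::make_float4.
#include <cassert>

struct vec4 { float x, y, z, w; };
static vec4 make_vec4(float x, float y, float z, float w) { return {x, y, z, w}; }

int main() {
    // lane<i> = one row of the 4x4 accumulator block read from shared memory
    vec4 lane0 = {0, 1, 2, 3},   lane1 = {10, 11, 12, 13};
    vec4 lane2 = {20, 21, 22, 23}, lane3 = {30, 31, 32, 33};

    // f_conv<k> = the k-th column of that block, i.e. the values that are
    // stored together for batch index b0+k and paired with bias.at(b0+k, ...)
    vec4 f_conv0 = make_vec4(lane0.x, lane1.x, lane2.x, lane3.x);
    vec4 f_conv3 = make_vec4(lane0.w, lane1.w, lane2.w, lane3.w);

    assert(f_conv0.y == 10 && f_conv3.z == 23);  // rows and columns swapped
}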
* **************************************************************************************************/ /** @@ -52,19 +53,17 @@ struct Layout { int height_stride; int width_stride; - __host__ __device__ __forceinline__ void init(const int batch, - const int /* channel */, - const int height, - const int width) { + __host__ __device__ __forceinline__ void init( + const int batch, const int /* channel */, const int height, + const int width) { batch_stride = 4; channel_stride = height * width * batch * 4; height_stride = width * batch * 4; width_stride = batch * 4; } - __device__ __forceinline__ size_t offset(const int batch, const int channel, - const int height, - const int width) { + __device__ __forceinline__ size_t + offset(const int batch, const int channel, const int height, const int width) { return batch * batch_stride + (channel >> 2) * channel_stride + height * height_stride + width * width_stride; } @@ -78,19 +77,17 @@ struct Layout { int height_stride; int width_stride; - __host__ __device__ __forceinline__ void init(const int batch, - const int /* channel */, - const int height, - const int width) { + __host__ __device__ __forceinline__ void init( + const int batch, const int /* channel */, const int height, + const int width) { batch_stride = 16; channel_stride = height * width * batch * 16; height_stride = width * batch * 16; width_stride = batch * 16; } - __device__ __forceinline__ size_t offset(const int batch, const int channel, - const int height, - const int width) { + __device__ __forceinline__ size_t + offset(const int batch, const int channel, const int height, const int width) { return batch * batch_stride + (channel >> 4) * channel_stride + height * height_stride + width * width_stride; } @@ -104,19 +101,17 @@ struct Layout { int height_stride; int width_stride; - __host__ __device__ __forceinline__ void init(const int /* batch */, - const int channel, - const int height, - const int width) { + __host__ __device__ __forceinline__ void init( + const int /* batch */, const int channel, const int height, + const int width) { batch_stride = channel * height * width; channel_stride = height * width * 4; height_stride = width * 4; width_stride = 4; } - __device__ __forceinline__ size_t offset(const int batch, const int channel, - const int height, - const int width) { + __device__ __forceinline__ size_t + offset(const int batch, const int channel, const int height, const int width) { return batch * batch_stride + (channel >> 2) * channel_stride + height * height_stride + width * width_stride; } diff --git a/dnn/src/cuda/convolution_helper/parameter.cuh b/dnn/src/cuda/convolution_helper/parameter.cuh index 889bd84a..6736e13e 100644 --- a/dnn/src/cuda/convolution_helper/parameter.cuh +++ b/dnn/src/cuda/convolution_helper/parameter.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. 
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** diff --git a/dnn/src/cuda/convolution_helper/prologue.cuh b/dnn/src/cuda/convolution_helper/prologue.cuh index df7e325d..30bff496 100644 --- a/dnn/src/cuda/convolution_helper/prologue.cuh +++ b/dnn/src/cuda/convolution_helper/prologue.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. 
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
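The Layout::offset hunks in layout.cuh above all follow the same recipe: init() precomputes four strides, and offset() shifts the channel index right by the log2 of the vector width (>> 2 for the 4-wide layout, >> 4 for the 16-wide one) because the low channel bits address positions inside the vectorized element itself. A minimal host-side sketch of the 4-wide NCHW4 variant, with an illustrative shape rather than values taken from the library:

// Hypothetical stand-alone version of the NCHW4 offset rule shown above.
#include <cassert>
#include <cstddef>

// Channels are grouped in packs of 4; the 4 channels of a pack are contiguous.
static std::size_t nchw4_offset(int batch, int channel, int height, int width,
                                int C, int H, int W) {
    const std::size_t batch_stride   = std::size_t(C) * H * W;  // one full image
    const std::size_t channel_stride = std::size_t(H) * W * 4;  // one channel pack
    const std::size_t height_stride  = std::size_t(W) * 4;
    const std::size_t width_stride   = 4;
    // (channel >> 2) selects the pack; channel % 4 is resolved inside the
    // 4-wide vector that is loaded or stored as a single element.
    return batch * batch_stride + std::size_t(channel >> 2) * channel_stride +
           height * height_stride + width * width_stride;
}

int main() {
    // C=8, H=W=2: element (n=0, c=5, h=1, w=0) sits in pack 1 (channels 4..7),
    // so its vector starts at 1*16 + 1*8 + 0*4 = 24.
    assert(nchw4_offset(0, 5, 1, 0, 8, 2, 2) == 24);
}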
* **************************************************************************************************/ /** @@ -42,9 +43,8 @@ struct ConvPrologue { template static __device__ __forceinline__ void prologue( const src_dtype* __restrict__& /* src */, - const filter_dtype* __restrict__& /* filter */, - const Param& /* param */, const int /* batch */, - const int /* channel */, const int /* height */, + const filter_dtype* __restrict__& /* filter */, const Param& /* param */, + const int /* batch */, const int /* channel */, const int /* height */, const int /* width */) {} }; diff --git a/dnn/src/cuda/convpooling/conv_pooling.cuh b/dnn/src/cuda/convpooling/conv_pooling.cuh index f3d122b3..40c10a72 100644 --- a/dnn/src/cuda/convpooling/conv_pooling.cuh +++ b/dnn/src/cuda/convpooling/conv_pooling.cuh @@ -16,18 +16,14 @@ namespace megdnn { namespace cuda { namespace conv_pool { -template +template < + int kern_h, int kern_w, int pool_shape_h, int pool_shape_w, class Nonlin, + class Pooler, class IdxGetter> __global__ void kern_xcorr_smallkern_pool( - float *input, - const float *filter, - float *output, - const float *output_bias, - cudaTextureObject_t m_tex, - int IC, int IH, int IW, - int OH, int OW); + float* input, const float* filter, float* output, const float* output_bias, + cudaTextureObject_t m_tex, int IC, int IH, int IW, int OH, int OW); -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/conv_pooling.h b/dnn/src/cuda/convpooling/conv_pooling.h index 0192b082..413a89e1 100644 --- a/dnn/src/cuda/convpooling/conv_pooling.h +++ b/dnn/src/cuda/convpooling/conv_pooling.h @@ -17,46 +17,25 @@ namespace conv_pool { #define NR_PXL_PER_THREAD 4 #define NR_THREAD_PER_BLOCK 192 -#define MAX_SHARED_MEM_SIZE 32768 //32 * 1024 -#define MAX_TEX_OBJ_SIZE 134217728 //2^27 +#define MAX_SHARED_MEM_SIZE 32768 // 32 * 1024 +#define MAX_TEX_OBJ_SIZE 134217728 // 2^27 #define HEIGHT_EQUALS_WITH_WEIGHT -enum PoolModeCu { - AVERAGE = 0, - MAX = 1 -}; +enum PoolModeCu { AVERAGE = 0, MAX = 1 }; -enum ConvModeCu { - CROSS_CORRELATION = 0, - CONVOLUTION = 1 -}; +enum ConvModeCu { CROSS_CORRELATION = 0, CONVOLUTION = 1 }; -enum NonlineModeCu{ - IDENTITY = 0, - RELU = 1, - SIGMOID = 2 -}; +enum NonlineModeCu { IDENTITY = 0, RELU = 1, SIGMOID = 2 }; void start_gpu_xcorr_pool_with_texture_obj( - cudaStream_t stream, - float *input, - const float *kernel, - float *output, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t /*PH*/, size_t /*PW*/, - size_t /*SH*/, size_t /*SW*/, - size_t pool_shape_h, - size_t pool_shape_w, - PoolModeCu poolMode, - ConvModeCu convMode, - NonlineModeCu nonlineMode, - const float *bias); - -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn + cudaStream_t stream, float* input, const float* kernel, float* output, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t /*PH*/, size_t /*PW*/, size_t /*SH*/, size_t /*SW*/, + size_t pool_shape_h, size_t pool_shape_w, PoolModeCu poolMode, + ConvModeCu convMode, NonlineModeCu nonlineMode, const float* bias); + +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/conv_pooling_tex.cu b/dnn/src/cuda/convpooling/conv_pooling_tex.cu index 76b16d35..dae5429f 100644 --- 
a/dnn/src/cuda/convpooling/conv_pooling_tex.cu +++ b/dnn/src/cuda/convpooling/conv_pooling_tex.cu @@ -19,14 +19,13 @@ namespace conv_pool { #define NR_PXL_PER_THREAD 4 #define NR_THREAD_PER_BLOCK 192 -#define MAX_SHARED_MEM_SIZE 32768 //32 * 1024 -#define MAX_TEX_OBJ_SIZE 134217728 //2^27 +#define MAX_SHARED_MEM_SIZE 32768 // 32 * 1024 +#define MAX_TEX_OBJ_SIZE 134217728 // 2^27 #define HEIGHT_EQUALS_WITH_WEIGHT - - __host__ void create_cuda_tex(float *input, cudaTextureObject_t& tex, - size_t N, size_t IC, size_t IH, size_t IW) { - +__host__ void create_cuda_tex( + float* input, cudaTextureObject_t& tex, size_t N, size_t IC, size_t IH, + size_t IW) { struct cudaResourceDesc res_desc; memset(&res_desc, 0, sizeof(res_desc)); res_desc.resType = cudaResourceTypeLinear; @@ -40,195 +39,247 @@ namespace conv_pool { tex_desc.addressMode[1] = cudaAddressModeClamp; tex_desc.addressMode[2] = cudaAddressModeClamp; tex_desc.readMode = cudaReadModeElementType; - CUDA_CHKERR(cudaCreateTextureObject( - &tex, &res_desc, &tex_desc, NULL)); - + CUDA_CHKERR(cudaCreateTextureObject(&tex, &res_desc, &tex_desc, NULL)); } void start_gpu_xcorr_pool_with_texture_obj( - cudaStream_t stream, - float *input, - const float *kernel, - float *output, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t /*PH*/, size_t /*PW*/, - size_t /*SH*/, size_t /*SW*/, - size_t pool_shape_h, - size_t pool_shape_w, - PoolModeCu poolMode, - ConvModeCu convMode, - NonlineModeCu nonlineMode, - const float *bias) { - - int nr_batch = N, nr_oc = OC, - output_area2d = OH * OW, - kern_h = FH, kern_w = FW, - nr_thread_per_block = std::min(NR_THREAD_PER_BLOCK, - align_to_warp(output_area2d)), - oplane_nr_split = std::max(1, - output_area2d / (nr_thread_per_block * NR_PXL_PER_THREAD)), + cudaStream_t stream, float* input, const float* kernel, float* output, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t /*PH*/, size_t /*PW*/, size_t /*SH*/, size_t /*SW*/, + size_t pool_shape_h, size_t pool_shape_w, PoolModeCu poolMode, + ConvModeCu convMode, NonlineModeCu nonlineMode, const float* bias) { + int nr_batch = N, nr_oc = OC, output_area2d = OH * OW, kern_h = FH, kern_w = FW, + nr_thread_per_block = + std::min(NR_THREAD_PER_BLOCK, align_to_warp(output_area2d)), + oplane_nr_split = + std::max(1, output_area2d / (nr_thread_per_block * NR_PXL_PER_THREAD)), share_size = kern_h * kern_w * IC * sizeof(float); - megdnn_assert(share_size < MAX_SHARED_MEM_SIZE, "kernel too large: " + megdnn_assert( + share_size < MAX_SHARED_MEM_SIZE, + "kernel too large: " "total %d bytes per output channel allowed, got %d", MAX_SHARED_MEM_SIZE, share_size); - void (*f) (float *input, - const float *filter, - float *output, - const float *output_bias, - cudaTextureObject_t m_tex, - int IC, int IH, int IW, - int OH, int OW) = NULL; - -#define DISPATCH_POOLMODE(nonlin, kh, kw, ph, pw, convMode) \ - do { \ - switch (poolMode) { \ - case AVERAGE: \ - f = kern_xcorr_smallkern_pool; \ - break; \ - case MAX: \ - f = kern_xcorr_smallkern_pool; \ - break; \ - } \ - } while(0) - -#define DISPATCH_CONVMODE(nonlin, kh, kw, ph, pw) \ - do { \ - switch (convMode) { \ - case CONVOLUTION: DISPATCH_POOLMODE \ - (nonlin, kh, kw, ph, pw, IdxGetterConvolution); break; \ - case CROSS_CORRELATION: DISPATCH_POOLMODE\ - (nonlin, kh, kw, ph, pw, IdxGetterCorrRel); break; \ - } \ - } while(0) + void (*f)( + float* input, const float* filter, float* output, const float* 
output_bias, + cudaTextureObject_t m_tex, int IC, int IH, int IW, int OH, int OW) = NULL; + +#define DISPATCH_POOLMODE(nonlin, kh, kw, ph, pw, convMode) \ + do { \ + switch (poolMode) { \ + case AVERAGE: \ + f = kern_xcorr_smallkern_pool< \ + kh, kw, ph, pw, nonlin, MeanPooler, convMode>; \ + break; \ + case MAX: \ + f = kern_xcorr_smallkern_pool< \ + kh, kw, ph, pw, nonlin, MaxPooler, convMode>; \ + break; \ + } \ + } while (0) + +#define DISPATCH_CONVMODE(nonlin, kh, kw, ph, pw) \ + do { \ + switch (convMode) { \ + case CONVOLUTION: \ + DISPATCH_POOLMODE(nonlin, kh, kw, ph, pw, IdxGetterConvolution); \ + break; \ + case CROSS_CORRELATION: \ + DISPATCH_POOLMODE(nonlin, kh, kw, ph, pw, IdxGetterCorrRel); \ + break; \ + } \ + } while (0) #ifdef HEIGHT_EQUALS_WITH_WEIGHT -#define DISPATCH_POOLSHAPE(nonlin, kh, kw) \ - do { \ - switch (pool_shape_h) { \ - case 1: DISPATCH_CONVMODE(nonlin, kh, kw, 1, 1); break; \ - case 2: DISPATCH_CONVMODE(nonlin, kh, kw, 2, 2); break; \ - case 3: DISPATCH_CONVMODE(nonlin, kh, kw, 3, 3); break; \ - case 4: DISPATCH_CONVMODE(nonlin, kh, kw, 4, 4); break; \ - } \ - } while(0) - -#define DISPATCH_KERN_H(nonlin) \ - do { \ - switch(kern_h) { \ - case 1: DISPATCH_POOLSHAPE(nonlin, 1, 1); break;\ - case 2: DISPATCH_POOLSHAPE(nonlin, 2, 2); break;\ - case 3: DISPATCH_POOLSHAPE(nonlin, 3, 3); break;\ - case 4: DISPATCH_POOLSHAPE(nonlin, 4, 4); break;\ - case 5: DISPATCH_POOLSHAPE(nonlin, 5, 5); break;\ - case 6: DISPATCH_POOLSHAPE(nonlin, 6, 6); break;\ - case 7: DISPATCH_POOLSHAPE(nonlin, 7, 7); break;\ - } \ - } while(0) - -#else //HEIGHT_EQUALS_WITH_WEIGHT - -#define DISPATCH_POOLSHAPE_W(nonlin, kh, kw, ph) \ - do { \ - switch (pool_shape_w) { \ - case 1: DISPATCH_CONVMODE(nonlin, kh, kw, ph, 1); break; \ - case 2: DISPATCH_CONVMODE(nonlin, kh, kw, ph, 2); break; \ - case 3: DISPATCH_CONVMODE(nonlin, kh, kw, ph, 3); break; \ - case 4: DISPATCH_CONVMODE(nonlin, kh, kw, ph, 4); break; \ - } \ - } while(0) - -#define DISPATCH_POOLSHAPE_H(nonlin, kern_h, kern_w) \ - do { \ - switch (pool_shape_h) { \ - case 1: DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 1); break; \ - case 2: DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 2); break; \ - case 3: DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 3); break; \ - case 4: DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 4); break; \ - } \ - } while(0) - -#define DISPATCH_KERN_W(nonlin, kern_h) \ - do { \ - switch(kern_w) { \ - case 1: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 1); break;\ - case 2: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 2); break;\ - case 3: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 3); break;\ - case 4: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 4); break;\ - case 5: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 5); break;\ - case 6: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 6); break;\ - case 7: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 7); break;\ - case 8: DISPATCH_POOLSHAPE_H(nonlin, kern_h, 8); break;\ - } \ - } while(0) - -#define DISPATCH_KERN_H(nonlin) \ - do { \ - switch(kern_h) { \ - case 1: DISPATCH_KERN_W(nonlin, 1); break;\ - case 2: DISPATCH_KERN_W(nonlin, 2); break;\ - case 3: DISPATCH_KERN_W(nonlin, 3); break;\ - case 4: DISPATCH_KERN_W(nonlin, 4); break;\ - case 5: DISPATCH_KERN_W(nonlin, 5); break;\ - case 6: DISPATCH_KERN_W(nonlin, 6); break;\ - case 7: DISPATCH_KERN_W(nonlin, 7); break;\ - case 8: DISPATCH_KERN_W(nonlin, 8); break;\ - } \ - } while(0) - -#endif //HEIGHT_EQUALS_WITH_WEIGHT - switch(nonlineMode) { +#define DISPATCH_POOLSHAPE(nonlin, kh, kw) \ + do { \ + switch (pool_shape_h) { \ + case 1: \ + DISPATCH_CONVMODE(nonlin, 
kh, kw, 1, 1); \ + break; \ + case 2: \ + DISPATCH_CONVMODE(nonlin, kh, kw, 2, 2); \ + break; \ + case 3: \ + DISPATCH_CONVMODE(nonlin, kh, kw, 3, 3); \ + break; \ + case 4: \ + DISPATCH_CONVMODE(nonlin, kh, kw, 4, 4); \ + break; \ + } \ + } while (0) + +#define DISPATCH_KERN_H(nonlin) \ + do { \ + switch (kern_h) { \ + case 1: \ + DISPATCH_POOLSHAPE(nonlin, 1, 1); \ + break; \ + case 2: \ + DISPATCH_POOLSHAPE(nonlin, 2, 2); \ + break; \ + case 3: \ + DISPATCH_POOLSHAPE(nonlin, 3, 3); \ + break; \ + case 4: \ + DISPATCH_POOLSHAPE(nonlin, 4, 4); \ + break; \ + case 5: \ + DISPATCH_POOLSHAPE(nonlin, 5, 5); \ + break; \ + case 6: \ + DISPATCH_POOLSHAPE(nonlin, 6, 6); \ + break; \ + case 7: \ + DISPATCH_POOLSHAPE(nonlin, 7, 7); \ + break; \ + } \ + } while (0) + +#else // HEIGHT_EQUALS_WITH_WEIGHT + +#define DISPATCH_POOLSHAPE_W(nonlin, kh, kw, ph) \ + do { \ + switch (pool_shape_w) { \ + case 1: \ + DISPATCH_CONVMODE(nonlin, kh, kw, ph, 1); \ + break; \ + case 2: \ + DISPATCH_CONVMODE(nonlin, kh, kw, ph, 2); \ + break; \ + case 3: \ + DISPATCH_CONVMODE(nonlin, kh, kw, ph, 3); \ + break; \ + case 4: \ + DISPATCH_CONVMODE(nonlin, kh, kw, ph, 4); \ + break; \ + } \ + } while (0) + +#define DISPATCH_POOLSHAPE_H(nonlin, kern_h, kern_w) \ + do { \ + switch (pool_shape_h) { \ + case 1: \ + DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 1); \ + break; \ + case 2: \ + DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 2); \ + break; \ + case 3: \ + DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 3); \ + break; \ + case 4: \ + DISPATCH_POOLSHAPE_W(nonlin, kern_h, kern_w, 4); \ + break; \ + } \ + } while (0) + +#define DISPATCH_KERN_W(nonlin, kern_h) \ + do { \ + switch (kern_w) { \ + case 1: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 1); \ + break; \ + case 2: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 2); \ + break; \ + case 3: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 3); \ + break; \ + case 4: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 4); \ + break; \ + case 5: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 5); \ + break; \ + case 6: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 6); \ + break; \ + case 7: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 7); \ + break; \ + case 8: \ + DISPATCH_POOLSHAPE_H(nonlin, kern_h, 8); \ + break; \ + } \ + } while (0) + +#define DISPATCH_KERN_H(nonlin) \ + do { \ + switch (kern_h) { \ + case 1: \ + DISPATCH_KERN_W(nonlin, 1); \ + break; \ + case 2: \ + DISPATCH_KERN_W(nonlin, 2); \ + break; \ + case 3: \ + DISPATCH_KERN_W(nonlin, 3); \ + break; \ + case 4: \ + DISPATCH_KERN_W(nonlin, 4); \ + break; \ + case 5: \ + DISPATCH_KERN_W(nonlin, 5); \ + break; \ + case 6: \ + DISPATCH_KERN_W(nonlin, 6); \ + break; \ + case 7: \ + DISPATCH_KERN_W(nonlin, 7); \ + break; \ + case 8: \ + DISPATCH_KERN_W(nonlin, 8); \ + break; \ + } \ + } while (0) + +#endif // HEIGHT_EQUALS_WITH_WEIGHT + switch (nonlineMode) { case IDENTITY: DISPATCH_KERN_H(Identity); - break; + break; case RELU: DISPATCH_KERN_H(Relu); - break; + break; case SIGMOID: DISPATCH_KERN_H(Sigmoid); - break; + break; } - megdnn_assert(f, "Start_gpu_xcorr_pool: unsupported conv-pooling configuration. \ + megdnn_assert( + f, + "Start_gpu_xcorr_pool: unsupported conv-pooling configuration. \ pool_shape_h %zu, pool_shape_w %zu, kern_h %d, kern_w %d\n", - pool_shape_h, pool_shape_w, kern_h, kern_w); + pool_shape_h, pool_shape_w, kern_h, kern_w); cudaTextureObject_t m_tex = 0; size_t input_size = N * IC * IH * IW; // Case 1: Size of input data is less than // the limit of cudaTextureObject_t. 
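The block of DISPATCH_* macros reformatted above is, at heart, a table lookup: each runtime parameter (nonlinearity, kernel height/width, pool shape, conv mode, pool mode) is switched on exactly once, and every leaf case assigns f a pointer to a fully specialized instantiation of kern_xcorr_smallkern_pool, so the kernel that is finally launched carries all of those values as compile-time template arguments. A stripped-down, stand-alone version of the idea, covering only two of the axes and using hypothetical names:

// Sketch of the "switch once, bind a specialized kernel pointer" pattern.
#include <cstdio>

struct MeanPooler { static const char* name() { return "mean"; } };
struct MaxPooler  { static const char* name() { return "max"; } };

template <int kern_h, int kern_w, class Pooler>
void kernel_stub() {  // stands in for the real __global__ kernel
    std::printf("%dx%d %s kernel\n", kern_h, kern_w, Pooler::name());
}

enum PoolMode { AVERAGE = 0, MAX = 1 };
using KernPtr = void (*)();

static KernPtr dispatch(int kern_size, PoolMode mode) {
    KernPtr f = nullptr;
#define DISPATCH_POOL(kh, kw)                       \
    do {                                            \
        if (mode == AVERAGE)                        \
            f = kernel_stub<kh, kw, MeanPooler>;    \
        else                                        \
            f = kernel_stub<kh, kw, MaxPooler>;     \
    } while (0)
    switch (kern_size) {
        case 3: DISPATCH_POOL(3, 3); break;
        case 5: DISPATCH_POOL(5, 5); break;
    }
#undef DISPATCH_POOL
    return f;  // nullptr means "unsupported configuration", as asserted above
}

int main() {
    if (KernPtr f = dispatch(3, MAX)) f();  // prints "3x3 max kernel"
}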
- if(input_size < MAX_TEX_OBJ_SIZE) { - dim3 grid_dim(nr_batch, nr_oc, oplane_nr_split), - block_dim(nr_thread_per_block); - create_cuda_tex(input, m_tex, N, IC, IH, IW); + if (input_size < MAX_TEX_OBJ_SIZE) { + dim3 grid_dim(nr_batch, nr_oc, oplane_nr_split), block_dim(nr_thread_per_block); + create_cuda_tex(input, m_tex, N, IC, IH, IW); f<<>>( - input, kernel, output, bias, m_tex, - IC, IH, IW, OH, OW); + input, kernel, output, bias, m_tex, IC, IH, IW, OH, OW); } // Case 2: Size of input data reached // the limit of cudaTextureObject_t (2^27 Bytes). else { - size_t input_stride = IC * IH * IW, - output_stride = OC * OH * OW; + size_t input_stride = IC * IH * IW, output_stride = OC * OH * OW; int batch_size = MAX_TEX_OBJ_SIZE / input_stride; - float *input_base = input; - float *output_base = output; - for(; nr_batch > 0; nr_batch -= batch_size) { + float* input_base = input; + float* output_base = output; + for (; nr_batch > 0; nr_batch -= batch_size) { int cur_batch = nr_batch < batch_size ? nr_batch : batch_size; dim3 grid_dim(cur_batch, nr_oc, oplane_nr_split), - block_dim(nr_thread_per_block); - create_cuda_tex(input_base, m_tex, N, IC, IH, IW); + block_dim(nr_thread_per_block); + create_cuda_tex(input_base, m_tex, N, IC, IH, IW); f<<>>( - input_base, kernel, output_base, bias, m_tex, - IC, IH, IW, OH, OW); + input_base, kernel, output_base, bias, m_tex, IC, IH, IW, OH, OW); input_base += batch_size * input_stride; output_base += batch_size * output_stride; @@ -239,12 +290,12 @@ void start_gpu_xcorr_pool_with_texture_obj( CUDA_CHKERR(cudaDestroyTextureObject(m_tex)); m_tex = 0; - //texinput.destory(); + // texinput.destory(); } -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn #undef CUDA_CHKERR #undef CUDA_CHK_KERN_ERR #undef NR_PXL_PER_THREAD diff --git a/dnn/src/cuda/convpooling/conv_pooling_utils.cuh b/dnn/src/cuda/convpooling/conv_pooling_utils.cuh index 82955c01..b2feed1b 100644 --- a/dnn/src/cuda/convpooling/conv_pooling_utils.cuh +++ b/dnn/src/cuda/convpooling/conv_pooling_utils.cuh @@ -9,25 +9,25 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once -#include "src/cuda/utils.cuh" -#include -#include #include +#include +#include +#include "src/cuda/utils.cuh" //#include "./helper.cuh" - namespace megdnn { namespace cuda { namespace conv_pool { -#define CUDA_CHKERR(call) \ - do { \ - cudaError_t code = (call); \ - megdnn_assert(code == cudaSuccess, "cuda err %d: %s (call %s at %s:%s:%d)", \ - int(code), cudaGetErrorString(code), # call, \ - __FILE__, __func__, __LINE__); \ - } while(0) +#define CUDA_CHKERR(call) \ + do { \ + cudaError_t code = (call); \ + megdnn_assert( \ + code == cudaSuccess, "cuda err %d: %s (call %s at %s:%s:%d)", \ + int(code), cudaGetErrorString(code), #call, __FILE__, __func__, \ + __LINE__); \ + } while (0) #define CUDA_CHK_KERN_ERR CUDA_CHKERR(cudaDeviceSynchronize()); @@ -40,28 +40,24 @@ static inline int __host__ align_to_warp(int n) { // --- Nonline --- struct Relu { - static __device__ float apply(float x) { - return x > 0 ? x : 0; - } + static __device__ float apply(float x) { return x > 0 ? 
x : 0; } }; struct Sigmoid { static __device__ float apply(float x) { - float exp_value = exp((double) -x); + float exp_value = exp((double)-x); return 1 / (1 + exp_value); } }; struct Identity { - static __device__ float apply(float x) { - return x; - } + static __device__ float apply(float x) { return x; } }; // --- Static Reduce --- -template +template struct StaticReduce { - static __device__ float apply(const float *val) { + static __device__ float apply(const float* val) { const int half = size / 2; return Op::apply( StaticReduce::apply(val), @@ -69,124 +65,107 @@ struct StaticReduce { } }; -template +template struct StaticReduce<1, Op> { - static __device__ float apply(const float *val) { - return val[0]; - } + static __device__ float apply(const float* val) { return val[0]; } }; -template +template struct StaticReduce<2, Op> { - static __device__ float apply(const float *val) { + static __device__ float apply(const float* val) { return Op::apply(val[0], val[1]); } }; struct OpAdd { - static __device__ float apply(float a, float b) { - return a + b; - } + static __device__ float apply(float a, float b) { return a + b; } }; struct OpMax { - static __device__ float apply(float a, float b) { - return max(a, b); - } + static __device__ float apply(float a, float b) { return max(a, b); } }; struct IdxGetterConvolution { static inline __device__ int apply(int kern, int i, int p) { return kern - i - 1 + p; } - }; struct IdxGetterCorrRel { - static inline __device__ int apply(int kern, int i, int p) { - return i - p; - } + static inline __device__ int apply(int kern, int i, int p) { return i - p; } }; - // --- Pooling --- struct MeanPooler { - template - static __device__ float apply(const float *val) { + template + static __device__ float apply(const float* val) { const int size = pool_shape_h * pool_shape_w; return StaticReduce::apply(val) / size; } }; struct MaxPooler { - template - static __device__ float apply(const float *val) { + template + static __device__ float apply(const float* val) { return StaticReduce::apply(val); } }; - - // --- Reader --- +// --- Reader --- class Tex1DReader { cudaTextureObject_t m_tex; int m_base_offset, m_chl_stride, m_row_stride, m_row_offset; - //size_t batch_, chal_, height_, weight_; - - public: - // Set attributes of texture Object - /*__device__ void init(cudaTextureObject_t& tex, - size_t batch, size_t chal, size_t height, size_t weight) { - batch_ = batch; - chal_ = chal; - height_ = height; - weight_ = weight; - m_chl_stride = height * weight; - m_row_stride = weight; - } - - __device__ void set_pos(cudaTextureObject_t& tex, - // Current position - size_t n, size_t c, size_t h, size_t w) { - m_tex = tex; - m_base_offset = ((n * chal_ + c) * height_ + h) * weight_ + w; - } - */ - __device__ void set_pos(cudaTextureObject_t& tex, + // size_t batch_, chal_, height_, weight_; + +public: + // Set attributes of texture Object + /*__device__ void init(cudaTextureObject_t& tex, + size_t batch, size_t chal, size_t height, size_t weight) { + batch_ = batch; + chal_ = chal; + height_ = height; + weight_ = weight; + m_chl_stride = height * weight; + m_row_stride = weight; + } + + __device__ void set_pos(cudaTextureObject_t& tex, + // Current position + size_t n, size_t c, size_t h, size_t w) { + m_tex = tex; + m_base_offset = ((n * chal_ + c) * height_ + h) * weight_ + w; + } + */ + __device__ void set_pos( + cudaTextureObject_t& tex, // Current position int chal, int height, int weight, int n, int c, int h, int w) { - m_chl_stride = height * weight; - 
m_row_stride = weight; - m_tex = tex; - m_base_offset = ((n * chal + c) * height + h) * weight + w; - } - - __device__ void reset_row() { - m_row_offset = m_base_offset; - } - - __device__ void next_row() { - m_row_offset += m_row_stride; - } - - __device__ void next_channel() { - m_base_offset += m_chl_stride; - } - - __device__ float get(int /*dr*/, int dc) { - return tex1Dfetch(m_tex, dc + m_row_offset); - } - - __device__ float get(int idx) { - return tex1Dfetch(m_tex, idx + m_base_offset); - } -}; + m_chl_stride = height * weight; + m_row_stride = weight; + m_tex = tex; + m_base_offset = ((n * chal + c) * height + h) * weight + w; + } + + __device__ void reset_row() { m_row_offset = m_base_offset; } - extern __host__ void create_cuda_tex(float *input, cudaTextureObject_t& tex, - size_t N, size_t IC, size_t IH, size_t IW); + __device__ void next_row() { m_row_offset += m_row_stride; } + __device__ void next_channel() { m_base_offset += m_chl_stride; } + + __device__ float get(int /*dr*/, int dc) { + return tex1Dfetch(m_tex, dc + m_row_offset); + } + + __device__ float get(int idx) { + return tex1Dfetch(m_tex, idx + m_base_offset); + } +}; +extern __host__ void create_cuda_tex( + float* input, cudaTextureObject_t& tex, size_t N, size_t IC, size_t IH, + size_t IW); -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl.h b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl.h index 90dc5fb2..aa05fc75 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl.h +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl.h @@ -15,35 +15,30 @@ namespace megdnn { namespace cuda { namespace conv_pool { -typedef void (*kern_corr_pointer) (float *input, - const float *filter, - float *output, - const float *output_bias, - cudaTextureObject_t m_tex, - int IC, int IH, int IW, - int OH, int OW); +typedef void (*kern_corr_pointer)( + float* input, const float* filter, float* output, const float* output_bias, + cudaTextureObject_t m_tex, int IC, int IH, int IW, int OH, int OW); #include "./kern_corr_func_macro.inc" -#define DISPATCH_POOLMODE(nonlin, kern_size, pool_size, idx_getter) \ - KERN_CORR_DEFINE(nonlin, kern_size, kern_size, pool_size, pool_size, \ - idx_getter, MeanPooler) \ - KERN_CORR_DEFINE(nonlin, kern_size, kern_size, pool_size, pool_size, \ - idx_getter, MaxPooler) \ +#define DISPATCH_POOLMODE(nonlin, kern_size, pool_size, idx_getter) \ + KERN_CORR_DEFINE( \ + nonlin, kern_size, kern_size, pool_size, pool_size, idx_getter, \ + MeanPooler) \ + KERN_CORR_DEFINE( \ + nonlin, kern_size, kern_size, pool_size, pool_size, idx_getter, MaxPooler) - -#define DISPATCH_CONVMODE(nonlin, kern_size, pool_size) \ - DISPATCH_POOLMODE(nonlin, kern_size, pool_size, IdxGetterConvolution) \ - DISPATCH_POOLMODE(nonlin, kern_size, pool_size, IdxGetterCorrRel) \ +#define DISPATCH_CONVMODE(nonlin, kern_size, pool_size) \ + DISPATCH_POOLMODE(nonlin, kern_size, pool_size, IdxGetterConvolution) \ + DISPATCH_POOLMODE(nonlin, kern_size, pool_size, IdxGetterCorrRel) #define DISPATCH_POOLSHAPE(nonlin, kern_size) \ - DISPATCH_CONVMODE(nonlin, kern_size, 1) \ - DISPATCH_CONVMODE(nonlin, kern_size, 2) \ - DISPATCH_CONVMODE(nonlin, kern_size, 3) \ + DISPATCH_CONVMODE(nonlin, kern_size, 1) \ + DISPATCH_CONVMODE(nonlin, kern_size, 2) \ + DISPATCH_CONVMODE(nonlin, kern_size, 3) \ DISPATCH_CONVMODE(nonlin, kern_size, 4) - -} // namespace conv_pool -} // 
namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize1.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize1.cu index 37028210..bc6d7563 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize1.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize1.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Identity, 1) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize2.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize2.cu index 7c77397e..30a7f3f8 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize2.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize2.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Identity, 2) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize3.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize3.cu index eb72747e..cec61b33 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize3.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize3.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Identity, 3) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize4.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize4.cu index 4d9e0773..23c589e8 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize4.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize4.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
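For reference while reading these per-kernel-size files: the MeanPooler and MaxPooler shown earlier in conv_pooling_utils.cuh reduce the pooling window through StaticReduce, a template recursion that is fully unrolled because the window size is a compile-time constant. The sketch below assumes the conventional <size, Op> parameterization and drops the __device__ qualifiers; it is a host-only reconstruction for illustration, not a verbatim copy of the header.

// Simplified host-side StaticReduce, assuming <int size, class Op> parameters.
#include <cassert>

struct OpAdd { static float apply(float a, float b) { return a + b; } };
struct OpMax { static float apply(float a, float b) { return a > b ? a : b; } };

template <int size, class Op>
struct StaticReduce {
    static float apply(const float* val) {
        const int half = size / 2;
        // split the window in two halves and combine them with Op
        return Op::apply(StaticReduce<half, Op>::apply(val),
                         StaticReduce<size - half, Op>::apply(val + half));
    }
};
template <class Op>
struct StaticReduce<1, Op> {
    static float apply(const float* val) { return val[0]; }
};
template <class Op>
struct StaticReduce<2, Op> {
    static float apply(const float* val) { return Op::apply(val[0], val[1]); }
};

int main() {
    const float window[4] = {1.f, 4.f, 2.f, 3.f};            // a 2x2 pooling window
    assert(StaticReduce<4, OpAdd>::apply(window) == 10.f);    // mean pooling sums, then / 4
    assert(StaticReduce<4, OpMax>::apply(window) == 4.f);     // max pooling
}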
*/ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Identity, 4) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize5.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize5.cu index df5b23ef..480b702d 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize5.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize5.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Identity, 5) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize6.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize6.cu index ec4fbb7d..b9f3050b 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize6.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize6.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Identity, 6) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize7.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize7.cu index 4ad72904..aaee7291 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize7.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_identity_ksize7.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Identity, 7) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize1.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize1.cu index f610b03b..4020225a 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize1.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize1.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Relu, 1) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize2.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize2.cu index 32831048..a257e9e7 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize2.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize2.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Relu, 2) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize3.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize3.cu index aaa379cc..3f9b9f4f 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize3.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize3.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Relu, 3) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize4.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize4.cu index cedc0cfb..e0f1ad96 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize4.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize4.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Relu, 4) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize5.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize5.cu index 0a6c8d4e..e70438d7 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize5.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize5.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Relu, 5) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize6.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize6.cu index f3c314de..3269909a 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize6.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize6.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Relu, 6) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize7.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize7.cu index 41418fff..5dad991b 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize7.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_relu_ksize7.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Relu, 7) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize1.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize1.cu index 637b9ba5..c3713821 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize1.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize1.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Sigmoid, 1) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize2.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize2.cu index 68160a27..25b9b585 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize2.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize2.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Sigmoid, 2) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize3.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize3.cu index 4c056d2f..f9112024 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize3.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize3.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Sigmoid, 3) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize4.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize4.cu index 20e04d17..6d61c9c4 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize4.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize4.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Sigmoid, 4) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize5.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize5.cu index 96c02bc6..b98a30ce 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize5.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize5.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Sigmoid, 5) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize6.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize6.cu index e65e7b8b..cec3f441 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize6.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize6.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Sigmoid, 6) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize7.cu b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize7.cu index 4227167c..312efc48 100644 --- a/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize7.cu +++ b/dnn/src/cuda/convpooling/kernel_impl/kernel_impl_sigmoid_ksize7.cu @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./kernel_impl.h" #include "../conv_pooling_utils.cuh" +#include "./kernel_impl.h" namespace megdnn { namespace cuda { @@ -17,7 +17,7 @@ namespace conv_pool { DISPATCH_POOLSHAPE(Sigmoid, 7) -} // namespace conv_pool -} // namespace cuda -} // namespace megdnn +} // namespace conv_pool +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/opr_impl.cpp b/dnn/src/cuda/convpooling/opr_impl.cpp index 7a62336e..1273dcad 100644 --- a/dnn/src/cuda/convpooling/opr_impl.cpp +++ b/dnn/src/cuda/convpooling/opr_impl.cpp @@ -10,60 +10,61 @@ */ #include "src/cuda/convpooling/opr_impl.h" #include "src/cuda/convpooling/conv_pooling.h" -#include "src/cuda/utils.h" #include "src/cuda/handle.h" +#include "src/cuda/utils.h" namespace megdnn { namespace cuda { using namespace conv_pool; -void get_dest_shape(size_t ih, size_t iw, size_t fh, size_t fw, - size_t sh, size_t sw, size_t ph, size_t pw, - size_t &oh, size_t &ow, bool is_floor = true) -{ - megdnn_assert(ih+2*ph >= fh, "input height=%zu, padding height=%zu, " - "filter height=%zu", ih, ph, fh); - megdnn_assert(iw+2*pw >= fw, "input width=%zu, padding width=%zu, " - "filter width=%zu", iw, pw, fw); +void get_dest_shape( + size_t ih, size_t iw, size_t fh, size_t fw, size_t sh, size_t sw, size_t ph, + size_t pw, size_t& oh, size_t& ow, bool is_floor = true) { + megdnn_assert( + ih + 2 * ph >= fh, + "input height=%zu, padding height=%zu, " + "filter height=%zu", + ih, ph, fh); + megdnn_assert( + iw + 2 * pw >= fw, + "input width=%zu, padding width=%zu, " + "filter width=%zu", + iw, pw, fw); megdnn_assert(sh && sw, "invalid stride setting: (%zu, %zu)", sh, sw); if (is_floor) { - oh = (ih+2*ph-fh)/sh + 1; - ow = (iw+2*pw-fw)/sw + 1; + oh = (ih + 2 * ph - fh) / sh + 1; + ow = (iw + 2 * pw - fw) / sw + 1; } else { - oh = (ih+2*ph-fh+sh-1)/sh + 1; - ow = (iw+2*pw-fw+sw-1)/sw + 1; + oh = (ih + 2 * ph - fh + sh - 1) / sh + 1; + ow = (iw + 2 * pw - fw + sw - 1) / sw + 1; } } -ConvPoolingForwardImpl::ConvPoolingForwardImpl(Handle *handle): - ConvPoolingForward(handle) { +ConvPoolingForwardImpl::ConvPoolingForwardImpl(Handle* handle) + : ConvPoolingForward(handle) { return; } -size_t ConvPoolingForwardImpl::get_workspace_in_bytes(const TensorLayout & /*src*/, - const TensorLayout & /*filter*/, - const TensorLayout & /*bias*/, - const TensorLayout & /*dst*/) { +size_t ConvPoolingForwardImpl::get_workspace_in_bytes( + const TensorLayout& /*src*/, const TensorLayout& /*filter*/, + const TensorLayout& /*bias*/, const TensorLayout& /*dst*/) { return 0; } void ConvPoolingForwardImpl::deduce_layout( - const TensorLayout & srcl, - const TensorLayout & filterl, - 
const TensorLayout & /*bias*/, - TensorLayout & dstl) { - + const TensorLayout& srcl, const TensorLayout& filterl, + const TensorLayout& /*bias*/, TensorLayout& dstl) { megdnn_assert_contiguous(srcl); megdnn_assert_contiguous(filterl); - auto &src = srcl.shape; - auto &filter = filterl.shape; - //auto &wsp = workspace.shape; - //wsp = TensorShape({0, 0, 0, 0}); - //megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); - //megdnn_assert(filter.ndim == 4_z, "%s", errmsg_c); + auto& src = srcl.shape; + auto& filter = filterl.shape; + // auto &wsp = workspace.shape; + // wsp = TensorShape({0, 0, 0, 0}); + // megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); + // megdnn_assert(filter.ndim == 4_z, "%s", errmsg_c); megdnn_assert(srcl.ndim == 4_z, "%s", "src.ndim != 4"); megdnn_assert(filterl.ndim == 4_z, "%s", "filter.ndim != 4"); - size_t n = src[0]; + size_t n = src[0]; size_t ic = src[1]; size_t ih = src[2]; size_t iw = src[3]; @@ -83,25 +84,21 @@ void ConvPoolingForwardImpl::deduce_layout( size_t poolw = this->param().pool_shape_w; size_t conv_oh, conv_ow, oh, ow; // Shape of the output of convoluation. - get_dest_shape(ih, iw, fh, fw, conv_sh, conv_sw, - conv_ph, conv_pw, conv_oh, conv_ow); + get_dest_shape( + ih, iw, fh, fw, conv_sh, conv_sw, conv_ph, conv_pw, conv_oh, conv_ow); // Shape of the output of pooling. - get_dest_shape(conv_oh, conv_ow, poolh, poolw, - pool_sh, pool_sw, pool_ph, pool_pw, oh, ow); + get_dest_shape( + conv_oh, conv_ow, poolh, poolw, pool_sh, pool_sw, pool_ph, pool_pw, oh, ow); dstl = TensorLayout(TensorShape{n, oc, oh, ow}, srcl.dtype); - //workspace = Workspace(NULL, 0); - //workspace.gen_default_stride(); + // workspace = Workspace(NULL, 0); + // workspace.gen_default_stride(); } -void ConvPoolingForwardImpl::check_layout ( - const TensorLayout & src, - const TensorLayout & filter, - const TensorLayout & bias, - TensorLayout & dst, - size_t /* workspace_limit_in_bytes */ - ) { - +void ConvPoolingForwardImpl::check_layout( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& bias, + TensorLayout& dst, size_t /* workspace_limit_in_bytes */ +) { TensorLayout dst_expected; deduce_layout(src, filter, bias, dst_expected); megdnn_assert_eq_layout(dst_expected, dst); @@ -110,14 +107,13 @@ void ConvPoolingForwardImpl::check_layout ( megdnn_assert(dst.shape[1] == filter.shape[0]); } -void ConvPoolingForwardImpl::exec(const _megdnn_in TensorND src, - const _megdnn_in TensorND filter, - const _megdnn_in TensorND bias, - _megdnn_out TensorND dst, - _megdnn_out Workspace workspace) { - check_layout(src.layout, filter.layout, bias.layout, dst.layout, workspace.size); - auto stream = cuda_stream(this->handle()); - size_t N = src.layout.shape[0]; +void ConvPoolingForwardImpl::exec( + const _megdnn_in TensorND src, const _megdnn_in TensorND filter, + const _megdnn_in TensorND bias, _megdnn_out TensorND dst, + _megdnn_out Workspace workspace) { + check_layout(src.layout, filter.layout, bias.layout, dst.layout, workspace.size); + auto stream = cuda_stream(this->handle()); + size_t N = src.layout.shape[0]; size_t IC = src.layout.shape[1]; size_t IH = src.layout.shape[2]; size_t IW = src.layout.shape[3]; @@ -134,83 +130,80 @@ void ConvPoolingForwardImpl::exec(const _megdnn_in TensorND src, size_t POOL_H = this->param().pool_shape_h; size_t POOL_W = this->param().pool_shape_w; - PoolModeCu poolMode; - switch(this->param().poolMode) { + PoolModeCu poolMode; + switch (this->param().poolMode) { case Param::PoolMode::AVERAGE: poolMode = AVERAGE; - break; + break; case 
Param::PoolMode::MAX: poolMode = MAX; - break; + break; default: poolMode = AVERAGE; } ConvModeCu convMode; - switch(this->param().convMode) { + switch (this->param().convMode) { case Param::ConvMode::CROSS_CORRELATION: convMode = CROSS_CORRELATION; - break; + break; case Param::ConvMode::CONVOLUTION: convMode = CONVOLUTION; - break; + break; default: convMode = CROSS_CORRELATION; } NonlineModeCu nonlineMode; - switch(this->param().nonlineMode) { + switch (this->param().nonlineMode) { case Param::NonlineMode::IDENTITY: nonlineMode = IDENTITY; - break; + break; case Param::NonlineMode::RELU: nonlineMode = RELU; - break; + break; case Param::NonlineMode::SIGMOID: nonlineMode = SIGMOID; - break; + break; default: nonlineMode = IDENTITY; } float *src_ptr = static_cast(src.raw_ptr), - *filter_ptr = static_cast(filter.raw_ptr), - *bias_ptr = static_cast(bias.raw_ptr), - *dst_ptr = static_cast(dst.raw_ptr); - - switch (this->param().method) { - case Param::Method::WITH_SHARED_MEM: - // This method is out-of-date. - /* + *filter_ptr = static_cast(filter.raw_ptr), + *bias_ptr = static_cast(bias.raw_ptr), + *dst_ptr = static_cast(dst.raw_ptr); + + switch (this->param().method) { + case Param::Method::WITH_SHARED_MEM: + // This method is out-of-date. + /* start_gpu_xcorr_pool_with_shared_mem(stream, src_ptr, filter_ptr, dst_ptr, - N, IC, IH, IW, OC, OH, OW, - FH, FW, CONV_PH, CONV_PW, CONV_SH, CONV_SW, - this->param().pool_shape_w, - poolMode, - this->param().relu, - bias_ptr); - - break; + N, IC, IH, IW, OC, OH, OW, + FH, FW, CONV_PH, CONV_PW, CONV_SH, CONV_SW, + this->param().pool_shape_w, + poolMode, + this->param().relu, + bias_ptr); + + break; */ - case Param::Method::WITH_TEXTURE_OBJ: - start_gpu_xcorr_pool_with_texture_obj(stream, src_ptr, filter_ptr, dst_ptr, - N, IC, IH, IW, OC, OH, OW, - FH, FW, CONV_PH, CONV_PW, CONV_SH, CONV_SW, - POOL_H, POOL_W, - poolMode, convMode, nonlineMode, bias_ptr); - break; - - default: - start_gpu_xcorr_pool_with_texture_obj(stream, src_ptr, filter_ptr, dst_ptr, - N, IC, IH, IW, OC, OH, OW, - FH, FW, CONV_PH, CONV_PW, CONV_SH, CONV_SW, - POOL_H, POOL_W, - poolMode, convMode, nonlineMode, bias_ptr); - } -} + case Param::Method::WITH_TEXTURE_OBJ: + start_gpu_xcorr_pool_with_texture_obj( + stream, src_ptr, filter_ptr, dst_ptr, N, IC, IH, IW, OC, OH, OW, FH, + FW, CONV_PH, CONV_PW, CONV_SH, CONV_SW, POOL_H, POOL_W, poolMode, + convMode, nonlineMode, bias_ptr); + break; + default: + start_gpu_xcorr_pool_with_texture_obj( + stream, src_ptr, filter_ptr, dst_ptr, N, IC, IH, IW, OC, OH, OW, FH, + FW, CONV_PH, CONV_PW, CONV_SH, CONV_SW, POOL_H, POOL_W, poolMode, + convMode, nonlineMode, bias_ptr); + } +} -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/convpooling/opr_impl.h b/dnn/src/cuda/convpooling/opr_impl.h index 2590fe55..032df7e6 100644 --- a/dnn/src/cuda/convpooling/opr_impl.h +++ b/dnn/src/cuda/convpooling/opr_impl.h @@ -21,44 +21,38 @@ void start_gpu_xcorr_pool_with_shared_mem( cudaStream_t stream, float *input, const float *kernel, - float *output, + float *output, size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, size_t PW, size_t SH, size_t SW, - size_t pool_shape, + size_t pool_shape, PoolModeCu poolMode = AVERAGE, bool relu = true, const float *bias = NULL); */ -class ConvPoolingForwardImpl final: public ConvPoolingForward { - public: - ConvPoolingForwardImpl(Handle *handle); - void exec( const _megdnn_in 
TensorND src, - const _megdnn_in TensorND filter, - const _megdnn_in TensorND bias, - _megdnn_out TensorND dst, - _megdnn_out Workspace workspace) override; - void deduce_layout( - const TensorLayout & src, - const TensorLayout & filter, - const TensorLayout & bias, - TensorLayout & dst) override; - void check_layout( - const TensorLayout & src, - const TensorLayout & filter, - const TensorLayout & bias, - TensorLayout & dst, - size_t workspace_limit_in_bytes) override; - size_t get_workspace_in_bytes(const TensorLayout & src, - const TensorLayout & filter, - const TensorLayout & bias, - const TensorLayout & dst) override; +class ConvPoolingForwardImpl final : public ConvPoolingForward { +public: + ConvPoolingForwardImpl(Handle* handle); + void exec( + const _megdnn_in TensorND src, const _megdnn_in TensorND filter, + const _megdnn_in TensorND bias, _megdnn_out TensorND dst, + _megdnn_out Workspace workspace) override; + void deduce_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, TensorLayout& dst) override; + void check_layout( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, TensorLayout& dst, + size_t workspace_limit_in_bytes) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& bias, const TensorLayout& dst) override; }; -} // namespace cuda -} // namespace megdnn -// vim: syntax=cpp.doxygen \ No newline at end of file +} // namespace cuda +} // namespace megdnn + // vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/cuda/correlation/correlation_cuda.cu b/dnn/src/cuda/correlation/correlation_cuda.cu index 22f6aa34..bb0cedd7 100644 --- a/dnn/src/cuda/correlation/correlation_cuda.cu +++ b/dnn/src/cuda/correlation/correlation_cuda.cu @@ -27,14 +27,12 @@ namespace correlation { vtid += blockDim.x * gridDim.x) template -__global__ void forward_kernel(const int nthreads, const T* data1, - const T* data2, T* dst, const int bchannels, - const int bheight, const int bwidth, - const int tchannels, const int theight, - const int twidth, const int kernel_size, - const int max_displacement, const int stride1, - const int stride2, const int pad_size, - const bool is_multiply) { +__global__ void forward_kernel( + const int nthreads, const T* data1, const T* data2, T* dst, const int bchannels, + const int bheight, const int bwidth, const int tchannels, const int theight, + const int twidth, const int kernel_size, const int max_displacement, + const int stride1, const int stride2, const int pad_size, + const bool is_multiply) { CUDA_KERNEL_LOOP(idx, nthreads) { int kernel_radius = (kernel_size - 1) / 2; int neighborhood_grid_radius = max_displacement / stride2; @@ -50,10 +48,8 @@ __global__ void forward_kernel(const int nthreads, const T* data1, int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; // get offset of center in image2 - int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * - stride2; - int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * - stride2; + int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * stride2; + int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * stride2; int x2 = x1 + s2o; int y2 = y1 + s2p; @@ -70,20 +66,16 @@ __global__ void forward_kernel(const int nthreads, const T* data1, for (int channel = 0; channel < bchannels; channel++) { T tmp1 = T(0.f); T tmp2 = T(0.f); - if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && - in_y1 < 
bheight) { + if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && in_y1 < bheight) { int idx1 = - ((n * bchannels + channel) * bheight + in_y1) * - bwidth + + ((n * bchannels + channel) * bheight + in_y1) * bwidth + in_x1; tmp1 = data1[idx1]; } - if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && - in_y2 < bheight) { + if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && in_y2 < bheight) { int idx2 = - ((n * bchannels + channel) * bheight + in_y2) * - bwidth + + ((n * bchannels + channel) * bheight + in_y2) * bwidth + in_x2; tmp2 = data2[idx2]; } @@ -104,11 +96,11 @@ __global__ void forward_kernel(const int nthreads, const T* data1, template __global__ void backward_kernel_data1( - const int nthreads, const T* diff, const T* data1, const T* data2, - T* grad1, const int bchannels, const int bheight, const int bwidth, - const int tchannels, const int theight, const int twidth, - const int kernel_size, const int max_displacement, const int stride1, - const int stride2, const int pad_size, const bool is_multiply) { + const int nthreads, const T* diff, const T* data1, const T* data2, T* grad1, + const int bchannels, const int bheight, const int bwidth, const int tchannels, + const int theight, const int twidth, const int kernel_size, + const int max_displacement, const int stride1, const int stride2, + const int pad_size, const bool is_multiply) { CUDA_KERNEL_LOOP(idx, nthreads) { int kernel_radius = (kernel_size - 1) / 2; int neighborhood_grid_radius = max_displacement / stride2; @@ -130,57 +122,51 @@ __global__ void backward_kernel_data1( // we show cal the x_min,y_min,x_max,y_max of diff for grad1(x,y) // for diff_x_min, diff_y_min, x,y at the position of right-down // ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 - int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + - round_off_s1 - 1) / - stride1 + + int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + round_off_s1 - + 1) / stride1 + 1 - round_off; - int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + - round_off_s1 - 1) / - stride1 + + int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + round_off_s1 - + 1) / stride1 + 1 - round_off; // floor (l - max_displacement + pad_size) / stride1 - int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - - round_off; - int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - - round_off; + int xmax = + (x + pad_size - max_displacement + round_off_s1) / stride1 - round_off; + int ymax = + (y + pad_size - max_displacement + round_off_s1) / stride1 - round_off; T sum = T(0.f); - if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && - (ymin <= theight - 1)) { + if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && (ymin <= theight - 1)) { xmin = max(0, xmin); xmax = min(twidth - 1, xmax); ymin = max(0, ymin); ymax = min(theight - 1, ymax); - for (int p = -neighborhood_grid_radius; - p <= neighborhood_grid_radius; p++) { - for (int o = -neighborhood_grid_radius; - o <= neighborhood_grid_radius; o++) { + for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; + p++) { + for (int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; + o++) { // Get bottom1 data: int s2o = stride2 * o; int s2p = stride2 * p; int x2 = x + s2o, y2 = y + s2p; - int idx2 = - ((n * bchannels + c) * bheight + y2) * bwidth + x2; + int idx2 = ((n * bchannels + c) * bheight + y2) * bwidth + x2; T tmp2 = T(0.f); if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { tmp2 = data2[idx2]; } - int op = (p + 
neighborhood_grid_radius) * - neighborhood_grid_width + + int op = (p + neighborhood_grid_radius) * neighborhood_grid_width + (o + neighborhood_grid_radius); int diff_channels_offset = (n * tchannels + op); for (int diff_y = ymin; diff_y <= ymax; diff_y++) { for (int diff_x = xmin; diff_x <= xmax; diff_x++) { int idxtopdiff = - (diff_channels_offset * theight + diff_y) * - twidth + + (diff_channels_offset * theight + diff_y) * twidth + diff_x; if (is_multiply) { @@ -203,11 +189,11 @@ __global__ void backward_kernel_data1( template __global__ void backward_kernel_data2( - const int nthreads, const T* diff, const T* data1, const T* data2, - T* grad2, const int bchannels, const int bheight, const int bwidth, - const int tchannels, const int theight, const int twidth, - const int kernel_size, const int max_displacement, const int stride1, - const int stride2, const int pad_size, const bool is_multiply) { + const int nthreads, const T* diff, const T* data1, const T* data2, T* grad2, + const int bchannels, const int bheight, const int bwidth, const int tchannels, + const int theight, const int twidth, const int kernel_size, + const int max_displacement, const int stride1, const int stride2, + const int pad_size, const bool is_multiply) { CUDA_KERNEL_LOOP(idx, nthreads) { int kernel_radius = (kernel_size - 1) / 2; int neighborhood_grid_radius = max_displacement / stride2; @@ -222,10 +208,9 @@ __global__ void backward_kernel_data2( T sum = T(0.f); - for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; - p++) { - for (int o = -neighborhood_grid_radius; - o <= neighborhood_grid_radius; o++) { + for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { + for (int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; + o++) { int s2o = o * stride2; int s2p = p * stride2; @@ -235,19 +220,17 @@ __global__ void backward_kernel_data2( const int round_off = ROUND_OFF; const int round_off_s1 = stride1 * round_off; - int xmin = (x1 + pad_size - 2 * kernel_radius - - max_displacement + round_off_s1 - 1) / + int xmin = (x1 + pad_size - 2 * kernel_radius - max_displacement + + round_off_s1 - 1) / stride1 + 1 - round_off; - int ymin = (y1 + pad_size - 2 * kernel_radius - - max_displacement + round_off_s1 - 1) / + int ymin = (y1 + pad_size - 2 * kernel_radius - max_displacement + + round_off_s1 - 1) / stride1 + 1 - round_off; - int xmax = (x1 + pad_size - max_displacement + round_off_s1) / - stride1 - + int xmax = (x1 + pad_size - max_displacement + round_off_s1) / stride1 - round_off; - int ymax = (y1 + pad_size - max_displacement + round_off_s1) / - stride1 - + int ymax = (y1 + pad_size - max_displacement + round_off_s1) / stride1 - round_off; if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && @@ -258,22 +241,19 @@ __global__ void backward_kernel_data2( ymin = max(0, ymin); ymax = min(theight - 1, ymax); - int idx1 = - ((n * bchannels + c) * bheight + y1) * bwidth + x1; + int idx1 = ((n * bchannels + c) * bheight + y1) * bwidth + x1; T tmp1 = T(0.f); if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { tmp1 = data1[idx1]; } - int op = (p + neighborhood_grid_radius) * - neighborhood_grid_width + + int op = (p + neighborhood_grid_radius) * neighborhood_grid_width + (o + neighborhood_grid_radius); int diff_channels_offset = (n * tchannels + op); for (int diff_y = ymin; diff_y <= ymax; diff_y++) { for (int diff_x = xmin; diff_x <= xmax; diff_x++) { int idxtopdiff = - (diff_channels_offset * theight + diff_y) * - twidth + + (diff_channels_offset * theight + 
diff_y) * twidth + diff_x; if (is_multiply) { @@ -295,76 +275,71 @@ __global__ void backward_kernel_data2( } template -void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, - const int bchannels, const int bheight, const int bwidth, - const int tchannels, const int theight, const int twidth, - const int kernel_size, const int max_displacement, - const int stride1, const int stride2, const int pad_size, - const bool is_multiply, cudaStream_t stream) { +void forward_proxy( + const int nthreads, const T* data1, const T* data2, T* dst, const int bchannels, + const int bheight, const int bwidth, const int tchannels, const int theight, + const int twidth, const int kernel_size, const int max_displacement, + const int stride1, const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream) { int threads_block = query_blocksize_for_kernel(forward_kernel); - forward_kernel - <<>>( - nthreads, data1, data2, dst, bchannels, bheight, bwidth, - tchannels, theight, twidth, kernel_size, max_displacement, - stride1, stride2, pad_size, is_multiply); + forward_kernel<<>>( + nthreads, data1, data2, dst, bchannels, bheight, bwidth, tchannels, theight, + twidth, kernel_size, max_displacement, stride1, stride2, pad_size, + is_multiply); after_kernel_launch(); } template -void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, - const T* data2, T* grad1, const int bchannels, - const int bheight, const int bwidth, - const int tchannels, const int theight, - const int twidth, const int kernel_size, - const int max_displacement, const int stride1, - const int stride2, const int pad_size, - const bool is_multiply, cudaStream_t stream) { +void backward_proxy_data1( + const int nthreads, const T* diff, const T* data1, const T* data2, T* grad1, + const int bchannels, const int bheight, const int bwidth, const int tchannels, + const int theight, const int twidth, const int kernel_size, + const int max_displacement, const int stride1, const int stride2, + const int pad_size, const bool is_multiply, cudaStream_t stream) { int threads_block = query_blocksize_for_kernel(backward_kernel_data1); backward_kernel_data1 <<>>( - nthreads, diff, data1, data2, grad1, bchannels, bheight, - bwidth, tchannels, theight, twidth, kernel_size, - max_displacement, stride1, stride2, pad_size, is_multiply); + nthreads, diff, data1, data2, grad1, bchannels, bheight, bwidth, + tchannels, theight, twidth, kernel_size, max_displacement, stride1, + stride2, pad_size, is_multiply); after_kernel_launch(); } template -void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, - const T* data2, T* grad2, const int bchannels, - const int bheight, const int bwidth, - const int tchannels, const int theight, - const int twidth, const int kernel_size, - const int max_displacement, const int stride1, - const int stride2, const int pad_size, - const bool is_multiply, cudaStream_t stream) { +void backward_proxy_data2( + const int nthreads, const T* diff, const T* data1, const T* data2, T* grad2, + const int bchannels, const int bheight, const int bwidth, const int tchannels, + const int theight, const int twidth, const int kernel_size, + const int max_displacement, const int stride1, const int stride2, + const int pad_size, const bool is_multiply, cudaStream_t stream) { int threads_block = query_blocksize_for_kernel(backward_kernel_data2); backward_kernel_data2 <<>>( - nthreads, diff, data1, data2, grad2, bchannels, bheight, - bwidth, tchannels, theight, twidth, 
kernel_size, - max_displacement, stride1, stride2, pad_size, is_multiply); + nthreads, diff, data1, data2, grad2, bchannels, bheight, bwidth, + tchannels, theight, twidth, kernel_size, max_displacement, stride1, + stride2, pad_size, is_multiply); after_kernel_launch(); } -#define INST(T) \ - template void forward_proxy( \ - const int, const T*, const T*, T* dst, const int, const int, \ - const int, const int, const int, const int, const int, const int, \ - const int, const int, const int, const bool, cudaStream_t); \ - template void backward_proxy_data1( \ - const int, const T*, const T*, const T*, T*, const int, const int, \ - const int, const int, const int, const int, const int, const int, \ - const int, const int, const int, const bool, cudaStream_t); \ - template void backward_proxy_data2( \ - const int, const T*, const T*, const T*, T*, const int, const int, \ - const int, const int, const int, const int, const int, const int, \ +#define INST(T) \ + template void forward_proxy( \ + const int, const T*, const T*, T* dst, const int, const int, const int, \ + const int, const int, const int, const int, const int, const int, \ + const int, const int, const bool, cudaStream_t); \ + template void backward_proxy_data1( \ + const int, const T*, const T*, const T*, T*, const int, const int, \ + const int, const int, const int, const int, const int, const int, \ + const int, const int, const int, const bool, cudaStream_t); \ + template void backward_proxy_data2( \ + const int, const T*, const T*, const T*, T*, const int, const int, \ + const int, const int, const int, const int, const int, const int, \ const int, const int, const int, const bool, cudaStream_t); INST(dt_float32) INST(dt_float16) INST(dt_bfloat16) #undef INST -} // namespace roi_align +} // namespace correlation } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/correlation/correlation_cuda.cuh b/dnn/src/cuda/correlation/correlation_cuda.cuh index 3562abd0..b261daa8 100644 --- a/dnn/src/cuda/correlation/correlation_cuda.cuh +++ b/dnn/src/cuda/correlation/correlation_cuda.cuh @@ -17,32 +17,28 @@ namespace cuda { namespace correlation { template -void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, - const int bchannels, const int bheight, const int bwidth, - const int tchannels, const int theight, const int twidth, - const int kernel_size, const int max_displacement, - const int stride1, const int stride2, const int pad_size, - const bool is_multiply, cudaStream_t stream); +void forward_proxy( + const int nthreads, const T* data1, const T* data2, T* dst, const int bchannels, + const int bheight, const int bwidth, const int tchannels, const int theight, + const int twidth, const int kernel_size, const int max_displacement, + const int stride1, const int stride2, const int pad_size, + const bool is_multiply, cudaStream_t stream); template -void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, - const T* data2, T* grad1, const int bchannels, - const int bheight, const int bwidth, - const int tchannels, const int theight, - const int twidth, const int kernel_size, - const int max_displacement, const int stride1, - const int stride2, const int pad_size, - const bool is_multiply, cudaStream_t stream); +void backward_proxy_data1( + const int nthreads, const T* diff, const T* data1, const T* data2, T* grad1, + const int bchannels, const int bheight, const int bwidth, const int tchannels, + const int theight, const int twidth, const int kernel_size, + const int 
max_displacement, const int stride1, const int stride2, + const int pad_size, const bool is_multiply, cudaStream_t stream); template -void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, - const T* data2, T* grad2, const int bchannels, - const int bheight, const int bwidth, - const int tchannels, const int theight, - const int twidth, const int kernel_size, - const int max_displacement, const int stride1, - const int stride2, const int pad_size, - const bool is_multiply, cudaStream_t stream); +void backward_proxy_data2( + const int nthreads, const T* diff, const T* data1, const T* data2, T* grad2, + const int bchannels, const int bheight, const int bwidth, const int tchannels, + const int theight, const int twidth, const int kernel_size, + const int max_displacement, const int stride1, const int stride2, + const int pad_size, const bool is_multiply, cudaStream_t stream); } // namespace correlation } // namespace cuda diff --git a/dnn/src/cuda/correlation/opr_impl.cpp b/dnn/src/cuda/correlation/opr_impl.cpp index 99a23fb1..dd66cc26 100644 --- a/dnn/src/cuda/correlation/opr_impl.cpp +++ b/dnn/src/cuda/correlation/opr_impl.cpp @@ -17,10 +17,9 @@ namespace megdnn { namespace cuda { -void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, - _megdnn_tensor_in data2, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void CorrelationForwardImpl::exec( + _megdnn_tensor_in data1, _megdnn_tensor_in data2, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(data1.layout, data2.layout, dst.layout, workspace.size); auto p = param(); auto stream = cuda_stream(handle()); @@ -38,25 +37,22 @@ void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, int bheight = data1.layout[2], bwidth = data1.layout[3]; using namespace ::megdnn::cuda::correlation; -#define cb(DType) \ - if (data1.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - forward_proxy(nthreads, data1.ptr(), data2.ptr(), \ - dst.ptr(), bchannels, bheight, bwidth, tchannels, \ - theight, twidth, kernel_size, max_displacement, \ - stride1, stride2, pad_size, is_multiply, stream); \ +#define cb(DType) \ + if (data1.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + forward_proxy( \ + nthreads, data1.ptr(), data2.ptr(), dst.ptr(), bchannels, \ + bheight, bwidth, tchannels, theight, twidth, kernel_size, \ + max_displacement, stride1, stride2, pad_size, is_multiply, stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb } -void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, - _megdnn_tensor_in data1, - _megdnn_tensor_in data2, - _megdnn_tensor_out grad1, - _megdnn_workspace workspace) { - check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, - workspace.size); +void CorrelationBackwardData1Impl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad1, _megdnn_workspace workspace) { + check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, workspace.size); auto stream = cuda_stream(handle()); int nthreads = grad1.layout.total_nr_elems(); @@ -74,26 +70,23 @@ void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, using namespace ::megdnn::cuda::correlation; -#define cb(DType) \ - if (diff.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - backward_proxy_data1(nthreads, diff.ptr(), data1.ptr(), \ - data2.ptr(), grad1.ptr(), bchannels, \ - bheight, bwidth, tchannels, theight, twidth, \ - kernel_size, max_displacement, stride1, \ - 
stride2, pad_size, is_multiply, stream); \ +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + backward_proxy_data1( \ + nthreads, diff.ptr(), data1.ptr(), data2.ptr(), \ + grad1.ptr(), bchannels, bheight, bwidth, tchannels, theight, \ + twidth, kernel_size, max_displacement, stride1, stride2, pad_size, \ + is_multiply, stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb } -void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, - _megdnn_tensor_in data1, - _megdnn_tensor_in data2, - _megdnn_tensor_out grad2, - _megdnn_workspace workspace) { - check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, - workspace.size); +void CorrelationBackwardData2Impl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad2, _megdnn_workspace workspace) { + check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, workspace.size); auto p = param(); auto stream = cuda_stream(handle()); int nthreads = grad2.layout.total_nr_elems(); @@ -111,14 +104,14 @@ void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, using namespace ::megdnn::cuda::correlation; -#define cb(DType) \ - if (diff.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - backward_proxy_data2(nthreads, diff.ptr(), data1.ptr(), \ - data2.ptr(), grad2.ptr(), bchannels, \ - bheight, bwidth, tchannels, theight, twidth, \ - kernel_size, max_displacement, stride1, \ - stride2, pad_size, is_multiply, stream); \ +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + backward_proxy_data2( \ + nthreads, diff.ptr(), data1.ptr(), data2.ptr(), \ + grad2.ptr(), bchannels, bheight, bwidth, tchannels, theight, \ + twidth, kernel_size, max_displacement, stride1, stride2, pad_size, \ + is_multiply, stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb diff --git a/dnn/src/cuda/correlation/opr_impl.h b/dnn/src/cuda/correlation/opr_impl.h index 0fc31c48..d9e50d26 100644 --- a/dnn/src/cuda/correlation/opr_impl.h +++ b/dnn/src/cuda/correlation/opr_impl.h @@ -20,11 +20,12 @@ namespace cuda { class CorrelationForwardImpl final : public CorrelationForward { public: using CorrelationForward::CorrelationForward; - void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& data1, - const TensorLayout& data2, - const TensorLayout& dst) override { + void exec( + _megdnn_tensor_in data1, _megdnn_tensor_in data2, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& data1, const TensorLayout& data2, + const TensorLayout& dst) override { return 0; } }; @@ -32,12 +33,12 @@ public: class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { public: using CorrelationBackwardData1::CorrelationBackwardData1; - void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, - _megdnn_tensor_in data2, _megdnn_tensor_out grad1, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&, - const TensorLayout&) override { + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad1, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const 
TensorLayout&) override { return 0; } }; @@ -45,12 +46,12 @@ public: class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { public: using CorrelationBackwardData2::CorrelationBackwardData2; - void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, - _megdnn_tensor_in data2, _megdnn_tensor_out grad2, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&, - const TensorLayout&) override { + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data1, _megdnn_tensor_in data2, + _megdnn_tensor_out grad2, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { return 0; } }; diff --git a/dnn/src/cuda/cuda_shfl_compat.cuh b/dnn/src/cuda/cuda_shfl_compat.cuh index 898cecea..e85c5148 100644 --- a/dnn/src/cuda/cuda_shfl_compat.cuh +++ b/dnn/src/cuda/cuda_shfl_compat.cuh @@ -11,10 +11,10 @@ #pragma once #if __CUDACC_VER_MAJOR__ >= 9 -#define __shfl(x, y, z) __shfl_sync(0xffffffffu, x, y, z) -#define __shfl_up(x, y, z) __shfl_up_sync(0xffffffffu, x, y, z) +#define __shfl(x, y, z) __shfl_sync(0xffffffffu, x, y, z) +#define __shfl_up(x, y, z) __shfl_up_sync(0xffffffffu, x, y, z) #define __shfl_down(x, y, z) __shfl_down_sync(0xffffffffu, x, y, z) -#define __shfl_xor(x, y, z) __shfl_xor_sync(0xffffffffu, x, y, z) +#define __shfl_xor(x, y, z) __shfl_xor_sync(0xffffffffu, x, y, z) #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/cudnn_wrapper.cpp b/dnn/src/cuda/cudnn_wrapper.cpp index 339e24fa..fced80c0 100644 --- a/dnn/src/cuda/cudnn_wrapper.cpp +++ b/dnn/src/cuda/cudnn_wrapper.cpp @@ -17,8 +17,8 @@ namespace { using namespace megdnn; -cudnnDataType_t to_cudnn_dtype(DType type, - const param::Convolution::Format format = {}) { +cudnnDataType_t to_cudnn_dtype( + DType type, const param::Convolution::Format format = {}) { switch (type.enumv()) { case DTypeEnum::Float32: return CUDNN_DATA_FLOAT; @@ -54,12 +54,11 @@ cudnnDataType_t to_cudnn_dtype(DType type, #endif default: #if CUDNN_MAJOR >= 6 - megdnn_throw("dtype must be float16/float32/int8/int32"); + megdnn_throw("dtype must be float16/float32/int8/int32"); #else - megdnn_throw("dtype must be float16/float32"); + megdnn_throw("dtype must be float16/float32"); #endif } - } cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format) { @@ -83,8 +82,7 @@ cudnnTensorFormat_t to_cudnn_format(const param::Convolution::Format format) { namespace megdnn { namespace cuda { -cudnnDataType_t get_compute_type_fp16( - param::Convolution::ComputeMode comp_mode) { +cudnnDataType_t get_compute_type_fp16(param::Convolution::ComputeMode comp_mode) { using Param = param::Convolution; cudnnDataType_t compute_type; if (comp_mode == Param::ComputeMode::DEFAULT) { @@ -119,8 +117,8 @@ TensorDesc::~TensorDesc() { cudnn_check(cudnnDestroyTensorDescriptor(desc)); } -void TensorDesc::set(const TensorLayout& layout, - const param::Convolution::Format format) { +void TensorDesc::set( + const TensorLayout& layout, const param::Convolution::Format format) { // Layout can be not contiguous; group conv needs it. 
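For context on the cuda_shfl_compat.cuh hunk above (whose only change is macro alignment): since CUDA 9 the legacy warp-shuffle intrinsics are deprecated in favour of the *_sync forms, so these macros forward the old names with a full-warp mask of 0xffffffffu. Call sites keep the pre-CUDA-9 spelling, e.g. a warp-level sum reduction (a sketch, not code from this diff):

    __device__ float warp_reduce_sum(float v) {
        // With the compat header, __shfl_down(v, offset, 32) expands to
        // __shfl_down_sync(0xffffffffu, v, offset, 32) when __CUDACC_VER_MAJOR__ >= 9.
        for (int offset = 16; offset > 0; offset >>= 1)
            v += __shfl_down(v, offset, 32);
        return v;
    }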
// megdnn_assert_contiguous(layout); if (format == param::Convolution::Format::NCHW4 || @@ -143,25 +141,22 @@ void TensorDesc::set(const TensorLayout& layout, if (format == param::Convolution::Format::NCHW4) { megdnn_assert(layout.is_physical_contiguous()); cudnn_check(cudnnSetTensor4dDescriptor( - desc, to_cudnn_format(format), - to_cudnn_dtype(layout.dtype, format), layout.shape[0], - layout.shape[c_pos] * 4, layout.shape[spatial_pos + 0], + desc, to_cudnn_format(format), to_cudnn_dtype(layout.dtype, format), + layout.shape[0], layout.shape[c_pos] * 4, layout.shape[spatial_pos + 0], layout.shape[spatial_pos + 1])); } else if (format == param::Convolution::Format::NCHW32) { megdnn_assert(layout.is_physical_contiguous()); cudnn_check(cudnnSetTensor4dDescriptor( - desc, to_cudnn_format(format), - to_cudnn_dtype(layout.dtype, format), layout.shape[0], - layout.shape[c_pos] * 32, layout.shape[spatial_pos + 0], - layout.shape[spatial_pos + 1])); + desc, to_cudnn_format(format), to_cudnn_dtype(layout.dtype, format), + layout.shape[0], layout.shape[c_pos] * 32, + layout.shape[spatial_pos + 0], layout.shape[spatial_pos + 1])); } else { cudnn_check(cudnnSetTensor4dDescriptorEx( desc, to_cudnn_dtype(layout.dtype), layout.shape[0], layout.shape[c_pos], layout.shape[spatial_pos + 0], - layout.shape[spatial_pos + 1], layout.stride[0], - layout.stride[c_pos], layout.stride[spatial_pos + 0], - layout.stride[spatial_pos + 1])); + layout.shape[spatial_pos + 1], layout.stride[0], layout.stride[c_pos], + layout.stride[spatial_pos + 0], layout.stride[spatial_pos + 1])); } } @@ -175,11 +170,12 @@ std::string TensorDesc::to_string() { int c_stride; int h_stride; int w_stride; - cudnn_check(cudnnGetTensor4dDescriptor(desc, &data_type, &n, &c, &h, &w, - &n_stride, &c_stride, &h_stride, - &w_stride)); - return ssprintf("", data_type, n, c, h, - w, n_stride, c_stride, h_stride, w_stride); + cudnn_check(cudnnGetTensor4dDescriptor( + desc, &data_type, &n, &c, &h, &w, &n_stride, &c_stride, &h_stride, + &w_stride)); + return ssprintf( + "", data_type, n, c, h, w, n_stride, + c_stride, h_stride, w_stride); } template @@ -200,16 +196,14 @@ std::string FilterDesc::to_string() { int c; int h; int w; - cudnn_check(cudnnGetFilter4dDescriptor(desc, &data_type, &format, &k, &c, - &h, &w)); - return ssprintf("", data_type,format, k, c, h, - w); + cudnn_check(cudnnGetFilter4dDescriptor(desc, &data_type, &format, &k, &c, &h, &w)); + return ssprintf( + "", data_type, format, k, c, h, w); } template void FilterDesc::set( - const typename ConvolutionBase::CanonizedFilterMeta& - filter_meta) { + const typename ConvolutionBase::CanonizedFilterMeta& filter_meta) { megdnn_assert(filter_meta.spatial_ndim == 2); #if CUDNN_VERSION < 7500 megdnn_assert(filter_meta.dilation[0] == 1 && filter_meta.dilation[1] == 1); @@ -247,8 +241,8 @@ ConvDesc::~ConvDesc() { cudnn_check(cudnnDestroyConvolutionDescriptor(desc)); } -void ConvDesc::set(DType data_type, const param::Convolution& param, - const size_t nr_group) { +void ConvDesc::set( + DType data_type, const param::Convolution& param, const size_t nr_group) { using Param = param::Convolution; cudnnConvolutionMode_t mode; switch (param.mode) { @@ -270,8 +264,9 @@ void ConvDesc::set(DType data_type, const param::Convolution& param, auto comp_mode = param.compute_mode; compute_type = get_compute_type_fp16(comp_mode); #if CUDNN_MAJOR >= 7 - } else if (data_type.category() == DTypeCategory::INT || - data_type.category() == DTypeCategory::QUANTIZED) { + } else if ( + data_type.category() == 
DTypeCategory::INT || + data_type.category() == DTypeCategory::QUANTIZED) { compute_type = CUDNN_DATA_INT32; #endif } else { @@ -304,27 +299,29 @@ LRNDesc::~LRNDesc() { void LRNDesc::set(const param::LRN& param) { megdnn_assert(param.n & 1, "n is %u", param.n); - megdnn_assert(param.n >= CUDNN_LRN_MIN_N, "n is %u, CUDNN_LRN_MIN_N is %d", - param.n, CUDNN_LRN_MIN_N); - megdnn_assert(param.n <= CUDNN_LRN_MAX_N, "n is %u, CUDNN_LRN_MAX_N is %d", - param.n, CUDNN_LRN_MAX_N); - megdnn_assert(param.k >= CUDNN_LRN_MIN_K, "k is %f, CUDNN_LRN_MIN_K is %lf", - param.k, CUDNN_LRN_MIN_K); - megdnn_assert(param.beta >= CUDNN_LRN_MIN_BETA, - "beta is %f, CUDNN_LRN_MIN_BETA is %lf", param.beta, - CUDNN_LRN_MIN_BETA); + megdnn_assert( + param.n >= CUDNN_LRN_MIN_N, "n is %u, CUDNN_LRN_MIN_N is %d", param.n, + CUDNN_LRN_MIN_N); + megdnn_assert( + param.n <= CUDNN_LRN_MAX_N, "n is %u, CUDNN_LRN_MAX_N is %d", param.n, + CUDNN_LRN_MAX_N); + megdnn_assert( + param.k >= CUDNN_LRN_MIN_K, "k is %f, CUDNN_LRN_MIN_K is %lf", param.k, + CUDNN_LRN_MIN_K); + megdnn_assert( + param.beta >= CUDNN_LRN_MIN_BETA, "beta is %f, CUDNN_LRN_MIN_BETA is %lf", + param.beta, CUDNN_LRN_MIN_BETA); // Note that alpha is divided by n in the cudnn implementation, // so we have to multiply alpha by n ahead of time. - cudnn_check(cudnnSetLRNDescriptor(desc, param.n, param.alpha * param.n, - param.beta, param.k)); + cudnn_check(cudnnSetLRNDescriptor( + desc, param.n, param.alpha * param.n, param.beta, param.k)); } BNParamDesc::BNParamDesc() { cudnn_check(cudnnCreateTensorDescriptor(&desc)); } -void BNParamDesc::set(const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode) { +void BNParamDesc::set(const cudnnTensorDescriptor_t xDesc, cudnnBatchNormMode_t mode) { cudnn_check(cudnnDeriveBNTensorDescriptor(desc, xDesc, mode)); } @@ -353,18 +350,18 @@ void Tensor3DDesc::set(const TensorLayout& layout, bool is_ndhwc) { c_pos = 1; spatial_pos = 2; } - const int dimA[] = {sc(layout.shape[0]), sc(layout.shape[c_pos]), - sc(layout.shape[spatial_pos + 0]), - sc(layout.shape[spatial_pos + 1]), - sc(layout.shape[spatial_pos + 2])}; + const int dimA[] = { + sc(layout.shape[0]), sc(layout.shape[c_pos]), + sc(layout.shape[spatial_pos + 0]), sc(layout.shape[spatial_pos + 1]), + sc(layout.shape[spatial_pos + 2])}; - const int strideA[] = {sc(layout.stride[0]), sc(layout.stride[c_pos]), - sc(layout.stride[spatial_pos + 0]), - sc(layout.stride[spatial_pos + 1]), - sc(layout.stride[spatial_pos + 2])}; + const int strideA[] = { + sc(layout.stride[0]), sc(layout.stride[c_pos]), + sc(layout.stride[spatial_pos + 0]), sc(layout.stride[spatial_pos + 1]), + sc(layout.stride[spatial_pos + 2])}; - cudnn_check(cudnnSetTensorNdDescriptor(desc, to_cudnn_dtype(layout.dtype), - 5, dimA, strideA)); + cudnn_check(cudnnSetTensorNdDescriptor( + desc, to_cudnn_dtype(layout.dtype), 5, dimA, strideA)); } Filter3DDesc::Filter3DDesc() { @@ -375,8 +372,7 @@ Filter3DDesc::~Filter3DDesc() { cudnn_check(cudnnDestroyFilterDescriptor(desc)); } -void Filter3DDesc::set( - const Convolution3DBase::CanonizedFilterMeta& filter_meta) { +void Filter3DDesc::set(const Convolution3DBase::CanonizedFilterMeta& filter_meta) { megdnn_assert(filter_meta.spatial_ndim == 3); #if CUDNN_MAJOR <= 6 megdnn_assert(filter_meta.group == 1); @@ -385,8 +381,7 @@ void Filter3DDesc::set( // cuDNN version 6 or below filter_meta.group always is 1. // So it is compatible for all cuDNN versions. 
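One detail worth keeping in mind in the LRNDesc::set hunk above: as its comment notes, cuDNN divides lrnAlpha by lrnN internally, so the wrapper compensates by passing param.alpha * param.n. With hypothetical parameter values, the call it ends up issuing looks like:

    // Example values only: param.n = 5, param.alpha = 1e-4f, param.beta = 0.75f, param.k = 2.f.
    // Handing cuDNN alpha * n = 5e-4 means its internal alpha / n recovers the 1e-4
    // per-element scale that the MegDNN parameter specifies.
    cudnn_check(cudnnSetLRNDescriptor(desc,
                                      /*lrnN=*/5,
                                      /*lrnAlpha=*/5 * 1e-4,
                                      /*lrnBeta=*/0.75,
                                      /*lrnK=*/2.0));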
const int filterDimA[] = { - sc(filter_meta.ocpg * - filter_meta.group), // cudnn 6 group always be 1 + sc(filter_meta.ocpg * filter_meta.group), // cudnn 6 group always be 1 sc(filter_meta.icpg), sc(filter_meta.spatial[0]), sc(filter_meta.spatial[1]), sc(filter_meta.spatial[2])}; @@ -428,10 +423,10 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { #endif const int padA[] = {sc(param.pad_d), sc(param.pad_h), sc(param.pad_w)}, - filterStrideA[] = {sc(param.stride_d), sc(param.stride_h), - sc(param.stride_w)}, - dilationA[] = {sc(param.dilate_d), sc(param.dilate_h), - sc(param.dilate_w)}; + filterStrideA[] = + {sc(param.stride_d), sc(param.stride_h), sc(param.stride_w)}, + dilationA[] = { + sc(param.dilate_d), sc(param.dilate_h), sc(param.dilate_w)}; // not use true half // in CUDNN_MAJOR < 6, all elements in dilA shoule be 1 cudnn_check(cudnnSetConvolutionNdDescriptor( @@ -441,9 +436,9 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { ////////////////////////// CudnnAlgoPack ////////////////////////// #define V1(v) #v -#define V(v) V1(v) +#define V(v) V1(v) #define DEF_NAME(NAME) \ - #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) +#NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) #define DEF_ALGO(NAME, PROD1, PROD2) \ { \ NAME, { DEF_NAME(NAME), PROD1, PROD2 } \ @@ -455,88 +450,83 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { const std::unordered_map CudnnAlgoPack::conv_bwd_data_algos() { - static const std::unordered_map - algos = - { DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false), + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false), #if CUDNN_VERSION == 8004 - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, true), #else - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false), #endif - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true, true), - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true), #if CUDNN_MAJOR >= 5 - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true, true), #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true, false), + DEF_ALGO( + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true, false), #endif #endif - }; + }; return algos; } const std::unordered_map CudnnAlgoPack::conv_bwd_flt_algos() { - static const std::unordered_map - algos = { - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false), - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false), - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true, true), - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false), + static const std::unordered_map< + cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> + algos = + { DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false), #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, - true, false), + 
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, true, false), #if CUDNN_MAJOR >= 6 - DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true, - true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true, true), #endif #endif - }; + }; return algos; } -const std::unordered_map -CudnnAlgoPack::conv_fwd_algos() { - static const std::unordered_map - algos = - { DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false), +const std::unordered_map CudnnAlgoPack:: + conv_fwd_algos() { + static const std::unordered_map + algos = { + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false), #if CUDNN_VERSION == 8004 - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true), #else - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false), #endif - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true, false), - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true, false), - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true, true), - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true, true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true), #if CUDNN_MAJOR >= 5 - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true, false), #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true, false), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true, false), #endif #endif - }; + }; return algos; } const std::unordered_map CudnnAlgoPack::conv3d_bwd_data_algos() { - static const std::unordered_map + static const std::unordered_map algos = { DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false), DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false), - DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, - true), + DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true), }; return algos; @@ -546,8 +536,8 @@ const std::unordered_map CudnnAlgoPack::conv3d_bwd_flt_algos() { #pragma message \ "fp16 dilated conv with odd size filter, only algo_1 works, need focus on doc" - static const std::unordered_map + static const std::unordered_map< + cudnnConvolutionBwdFilterAlgo_t, CudnnAlgoPack::Attr> algos = { DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false), DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false), @@ -557,18 +547,15 @@ CudnnAlgoPack::conv3d_bwd_flt_algos() { return algos; } -const std::unordered_map -CudnnAlgoPack::conv3d_fwd_algos() { - static const std::unordered_map +const std::unordered_map CudnnAlgoPack:: + conv3d_fwd_algos() { + static const std::unordered_map algos = { DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false), #if CUDNN_VERSION == 8004 - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, - true), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true), #else - DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, - false), + DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false), #endif DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true), }; diff --git a/dnn/src/cuda/cudnn_wrapper.h b/dnn/src/cuda/cudnn_wrapper.h index 33fdb6cd..ef0ab5ab 100644 --- 
a/dnn/src/cuda/cudnn_wrapper.h +++ b/dnn/src/cuda/cudnn_wrapper.h @@ -21,81 +21,78 @@ namespace cuda { /*! * \brief get compute_type of convolution operations */ -cudnnDataType_t get_compute_type_fp16( - param::Convolution::ComputeMode comp_mode); +cudnnDataType_t get_compute_type_fp16(param::Convolution::ComputeMode comp_mode); class TensorDesc { - public: - TensorDesc(); - //! default layout is nchw - void set(const TensorLayout& layout, const param::Convolution::Format = - param::Convolution::Format::NCHW); - std::string to_string(); - ~TensorDesc(); - cudnnTensorDescriptor_t desc; +public: + TensorDesc(); + //! default layout is nchw + void set( + const TensorLayout& layout, + const param::Convolution::Format = param::Convolution::Format::NCHW); + std::string to_string(); + ~TensorDesc(); + cudnnTensorDescriptor_t desc; }; template class FilterDesc { - public: - FilterDesc(); - void set(const typename ConvolutionBase::CanonizedFilterMeta &meta); - std::string to_string(); - ~FilterDesc(); - cudnnFilterDescriptor_t desc; +public: + FilterDesc(); + void set(const typename ConvolutionBase::CanonizedFilterMeta& meta); + std::string to_string(); + ~FilterDesc(); + cudnnFilterDescriptor_t desc; }; class ConvDesc { - public: - ConvDesc(); - void set(DType data_type, const param::Convolution& param, - const size_t nr_group); - ~ConvDesc(); - cudnnConvolutionDescriptor_t desc; +public: + ConvDesc(); + void set(DType data_type, const param::Convolution& param, const size_t nr_group); + ~ConvDesc(); + cudnnConvolutionDescriptor_t desc; }; class LRNDesc { - public: - LRNDesc(); - void set(const param::LRN ¶m); - ~LRNDesc(); - cudnnLRNDescriptor_t desc; +public: + LRNDesc(); + void set(const param::LRN& param); + ~LRNDesc(); + cudnnLRNDescriptor_t desc; }; - class BNParamDesc { - public: - BNParamDesc(); - void set(const cudnnTensorDescriptor_t xDesc, - cudnnBatchNormMode_t mode); - ~BNParamDesc(); - cudnnTensorDescriptor_t desc; +public: + BNParamDesc(); + void set(const cudnnTensorDescriptor_t xDesc, cudnnBatchNormMode_t mode); + ~BNParamDesc(); + cudnnTensorDescriptor_t desc; }; // the classes below is used to deal with 3d situations class Tensor3DDesc { - public: - Tensor3DDesc(); - //! default layout is NCDHW - void set(const TensorLayout &layout, bool is_ndhwc = false); - ~Tensor3DDesc(); - cudnnTensorDescriptor_t desc; +public: + Tensor3DDesc(); + //! 
default layout is NCDHW + void set(const TensorLayout& layout, bool is_ndhwc = false); + ~Tensor3DDesc(); + cudnnTensorDescriptor_t desc; }; class Filter3DDesc { - public: - Filter3DDesc(); - void set(const Convolution3DBase::CanonizedFilterMeta &meta); - ~Filter3DDesc(); - cudnnFilterDescriptor_t desc; +public: + Filter3DDesc(); + void set(const Convolution3DBase::CanonizedFilterMeta& meta); + ~Filter3DDesc(); + cudnnFilterDescriptor_t desc; }; class Conv3DDesc { - public: - Conv3DDesc(); - void set(const param::Convolution3D ¶m, const size_t nr_group); - ~Conv3DDesc(); - cudnnConvolutionDescriptor_t desc; +public: + Conv3DDesc(); + void set(const param::Convolution3D& param, const size_t nr_group); + ~Conv3DDesc(); + cudnnConvolutionDescriptor_t desc; }; class CudnnAlgoPack { @@ -113,8 +110,7 @@ public: static const std::unordered_map conv_bwd_flt_algos(); - static const std::unordered_map - conv_fwd_algos(); + static const std::unordered_map conv_fwd_algos(); static const std::unordered_map conv3d_bwd_data_algos(); @@ -122,9 +118,7 @@ public: static const std::unordered_map conv3d_bwd_flt_algos(); - static const std::unordered_map - conv3d_fwd_algos(); - + static const std::unordered_map conv3d_fwd_algos(); }; } // namespace cuda diff --git a/dnn/src/cuda/cumsum/kern.cuh b/dnn/src/cuda/cumsum/kern.cuh index 00363378..bef3a01d 100644 --- a/dnn/src/cuda/cumsum/kern.cuh +++ b/dnn/src/cuda/cumsum/kern.cuh @@ -53,8 +53,9 @@ struct SumOp { * The buffer in *op* and *dst* should not have identical memory addresses. */ template -void run_kern(T* dst, void* workspace, uint32_t workspace_size, uint32_t A, - uint32_t B, uint32_t C, const Op& op, cudaStream_t stream); +void run_kern( + T* dst, void* workspace, uint32_t workspace_size, uint32_t A, uint32_t B, + uint32_t C, const Op& op, cudaStream_t stream); /*! * \brief get required workspace size for cumsum, in bytes @@ -63,8 +64,7 @@ void run_kern(T* dst, void* workspace, uint32_t workspace_size, uint32_t A, * Note: cuda device must be set to the computing device before calling this * function. 
*/ -uint32_t get_workspace_in_bytes(uint32_t A, uint32_t B, uint32_t C, - uint32_t item_size); +uint32_t get_workspace_in_bytes(uint32_t A, uint32_t B, uint32_t C, uint32_t item_size); } // namespace cumsum } // namespace cuda diff --git a/dnn/src/cuda/cumsum/kern_impl.cu b/dnn/src/cuda/cumsum/kern_impl.cu index b40774e4..d16a6c28 100644 --- a/dnn/src/cuda/cumsum/kern_impl.cu +++ b/dnn/src/cuda/cumsum/kern_impl.cu @@ -33,25 +33,23 @@ uint32_t get_workspace_elems_for_cub_1d_with_dtype_reverse(uint32_t nr_item) { ScanOp scan_op; size_t wk_size0 = 0, wk_size1 = 0; - cuda_check(cub::DeviceScan::ExclusiveScan(NULL, wk_size0, inp_iter, - out_iter, scan_op, 0, nr_item)); - cuda_check(cub::DeviceScan::InclusiveScan(NULL, wk_size1, inp_iter, - out_iter, scan_op, nr_item)); + cuda_check(cub::DeviceScan::ExclusiveScan( + NULL, wk_size0, inp_iter, out_iter, scan_op, 0, nr_item)); + cuda_check(cub::DeviceScan::InclusiveScan( + NULL, wk_size1, inp_iter, out_iter, scan_op, nr_item)); return std::max(wk_size0, wk_size1); } template uint32_t get_workspace_elems_for_cub_1d_with_dtype(uint32_t nr_item) { - return std::max(get_workspace_elems_for_cub_1d_with_dtype_reverse( - nr_item), - get_workspace_elems_for_cub_1d_with_dtype_reverse( - nr_item)); + return std::max( + get_workspace_elems_for_cub_1d_with_dtype_reverse(nr_item), + get_workspace_elems_for_cub_1d_with_dtype_reverse(nr_item)); } } // namespace -uint32_t cumsum::get_workspace_bytes_for_cub_1d(uint32_t nr_item, - uint32_t item_size) { +uint32_t cumsum::get_workspace_bytes_for_cub_1d(uint32_t nr_item, uint32_t item_size) { switch (item_size) { #define CASE(size, type) \ case size: \ @@ -66,8 +64,8 @@ uint32_t cumsum::get_workspace_bytes_for_cub_1d(uint32_t nr_item, } } -uint32_t cumsum::get_workspace_in_bytes(uint32_t A, uint32_t B, uint32_t C, - uint32_t item_size) { +uint32_t cumsum::get_workspace_in_bytes( + uint32_t A, uint32_t B, uint32_t C, uint32_t item_size) { if (A == 1 && C == 1) { return get_workspace_bytes_for_cub_1d(B, item_size); } @@ -82,8 +80,8 @@ uint32_t cumsum::get_workspace_in_bytes(uint32_t A, uint32_t B, uint32_t C, return res * item_size; } -void cumsum::get_BX_BY(uint32_t /* A */, uint32_t /* B */, uint32_t C, - uint32_t& BX, uint32_t& BY) { +void cumsum::get_BX_BY( + uint32_t /* A */, uint32_t /* B */, uint32_t C, uint32_t& BX, uint32_t& BY) { BX = 1; while (BX < C && BX * 2 <= 32) BX *= 2; diff --git a/dnn/src/cuda/cumsum/opr_impl.cpp b/dnn/src/cuda/cumsum/opr_impl.cpp index bf2e0416..a98eb458 100644 --- a/dnn/src/cuda/cumsum/opr_impl.cpp +++ b/dnn/src/cuda/cumsum/opr_impl.cpp @@ -25,9 +25,9 @@ namespace { * \brief compute cumsum reduction on (A, B, C) tensor to (A, 1, C) */ template -void dispatch(T* dst, T* workspace, size_t workspace_size, size_t A, size_t B, - size_t C, bool exclusive, bool reverse, const Op& op, - cudaStream_t stream) { +void dispatch( + T* dst, T* workspace, size_t workspace_size, size_t A, size_t B, size_t C, + bool exclusive, bool reverse, const Op& op, cudaStream_t stream) { #define IF(exclusive_v, reverse_v) \ if (exclusive == exclusive_v && reverse == reverse_v) { \ run_kern( \ @@ -44,28 +44,27 @@ void dispatch(T* dst, T* workspace, size_t workspace_size, size_t A, size_t B, } // anonymous namespace -void CumsumForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { +void CumsumForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); size_t A, B, C; 
reduce::get_ABC(src.layout, A, B, C, param().axis); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (src.layout.dtype == DType()) { \ - using ctype = DTypeTrait::ctype; \ - dispatch>( \ - dst.ptr(), workspace.ptr(), workspace.size, A, \ - B, C, param().exclusive, param().reverse, src.ptr(), \ - stream); \ - return; \ +#define cb(DType) \ + if (src.layout.dtype == DType()) { \ + using ctype = DTypeTrait::ctype; \ + dispatch>( \ + dst.ptr(), workspace.ptr(), workspace.size, A, B, C, \ + param().exclusive, param().reverse, src.ptr(), stream); \ + return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb megdnn_assert_internal(false); } -size_t CumsumForwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout&) { +size_t CumsumForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout&) { size_t A, B, C; reduce::get_ABC(src, A, B, C, param().axis); cuda_check(cudaSetDevice(concrete_handle(handle())->device_id())); diff --git a/dnn/src/cuda/cumsum/opr_impl.h b/dnn/src/cuda/cumsum/opr_impl.h index 26b49fbb..c6f26bb9 100644 --- a/dnn/src/cuda/cumsum/opr_impl.h +++ b/dnn/src/cuda/cumsum/opr_impl.h @@ -14,16 +14,16 @@ namespace megdnn { namespace cuda { -class CumsumForwardImpl: public CumsumForward { - public: - using CumsumForward::CumsumForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &dst) override; +class CumsumForwardImpl : public CumsumForward { +public: + using CumsumForward::CumsumForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/cutlass/convolution_operation.h b/dnn/src/cuda/cutlass/convolution_operation.h index c2405b8b..b46b1cc5 100644 --- a/dnn/src/cuda/cutlass/convolution_operation.h +++ b/dnn/src/cuda/cutlass/convolution_operation.h @@ -73,15 +73,14 @@ public: m_description.tile_description.threadblock_stages = Operator::kStages; - m_description.tile_description.warp_count = - make_Coord(Operator::ConvolutionKernel::WarpCount::kM, - Operator::ConvolutionKernel::WarpCount::kN, - Operator::ConvolutionKernel::WarpCount::kK); + m_description.tile_description.warp_count = make_Coord( + Operator::ConvolutionKernel::WarpCount::kM, + Operator::ConvolutionKernel::WarpCount::kN, + Operator::ConvolutionKernel::WarpCount::kK); - m_description.tile_description.math_instruction.instruction_shape = - make_Coord(Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); + m_description.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, Operator::InstructionShape::kN, + Operator::InstructionShape::kK); m_description.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; @@ -100,13 +99,12 @@ public: ArchMap::kMax; - m_description.src = make_TensorDescription( - Operator::kAlignmentSrc); - m_description.filter = - make_TensorDescription( - Operator::kAlignmentFilter); - m_description.dst = make_TensorDescription( - Operator::kAlignmentDst); + m_description.src = + make_TensorDescription(Operator::kAlignmentSrc); + m_description.filter = make_TensorDescription( + Operator::kAlignmentFilter); + 
m_description.dst = + make_TensorDescription(Operator::kAlignmentDst); m_description.bias = make_TensorDescription( Operator::kAlignmentDst); @@ -116,18 +114,15 @@ public: m_description.epilogue_type = Operator::EpilogueOutputOp::kType; m_description.epilogue_count = Operator::EpilogueOutputOp::kCount; - m_description.threadblock_swizzle = ThreadblockSwizzleMap< - typename Operator::ThreadblockSwizzle>::kId; + m_description.threadblock_swizzle = + ThreadblockSwizzleMap::kId; - m_description.special_optimization = - Operator::kSpecialOpt; + m_description.special_optimization = Operator::kSpecialOpt; m_description.gemm_mode = Operator::kGemmMode; m_description.without_shared_load = Operator::kWithoutSharedLoad; } - virtual OperationDescription const& description() const { - return m_description; - } + virtual OperationDescription const& description() const { return m_description; } protected: ConvolutionDescription m_description; @@ -141,8 +136,8 @@ template struct init_epilogue_param_; template -struct init_epilogue_param_ { +struct init_epilogue_param_< + EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombination> { using ElementCompute = typename EpilogueOp::ElementCompute; typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { return {*static_cast(conv_args->alpha), @@ -180,8 +175,7 @@ struct init_epilogue_param_< template struct init_epilogue_param_< - EpilogueOp, - epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp> { + EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp> { using ElementCompute = typename EpilogueOp::ElementCompute; typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { return {*static_cast(conv_args->alpha), @@ -209,8 +203,7 @@ struct init_epilogue_param_< template struct init_epilogue_param_< - EpilogueOp, - epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp> { + EpilogueOp, epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp> { using ElementCompute = typename EpilogueOp::ElementCompute; typename EpilogueOp::Params get(ConvolutionArguments const* conv_args) { return {*static_cast(conv_args->alpha), @@ -250,9 +243,9 @@ public: ConvolutionOperation(char const* name = "unknown_gemm") : ConvolutionOperationBase(name) {} - virtual Status run(void const* arguments_ptr, - void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + virtual Status run( + void const* arguments_ptr, void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const { cutlass::conv::Operator conv_op = this->m_description.conv_op; ConvolutionArguments const* conv_args = reinterpret_cast(arguments_ptr); @@ -263,14 +256,12 @@ public: args.ref_src = { static_cast(const_cast(conv_args->src)), LayoutSrc::packed(implicit_gemm_tensor_a_extent(conv_op, ps))}; - args.ref_filter = {static_cast( - const_cast(conv_args->filter)), - LayoutFilter::packed( - implicit_gemm_tensor_b_extent(conv_op, ps))}; + args.ref_filter = { + static_cast(const_cast(conv_args->filter)), + LayoutFilter::packed(implicit_gemm_tensor_b_extent(conv_op, ps))}; args.ref_bias = { static_cast(const_cast(conv_args->bias)), - LayoutBias::packed( - implicit_gemm_tensor_bias_extent(conv_op, ps))}; + LayoutBias::packed(implicit_gemm_tensor_bias_extent(conv_op, ps))}; args.ref_z = { static_cast(const_cast(conv_args->z)), LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; @@ -278,14 +269,12 @@ public: static_cast(conv_args->dst), LayoutDst::packed(implicit_gemm_tensor_c_extent(conv_op, ps))}; - args.output_op = - 
init_epilogue_param().get( - conv_args); + args.output_op = init_epilogue_param().get( + conv_args); if (conv_args->extra_param) { - args.extra_param = - *reinterpret_cast( - conv_args->extra_param); + args.extra_param = *reinterpret_cast( + conv_args->extra_param); } Operator op; diff --git a/dnn/src/cuda/cutlass/gemm_operation.h b/dnn/src/cuda/cutlass/gemm_operation.h index 20c5b9a0..565d0b67 100644 --- a/dnn/src/cuda/cutlass/gemm_operation.h +++ b/dnn/src/cuda/cutlass/gemm_operation.h @@ -94,15 +94,14 @@ public: m_description.tile_description.threadblock_stages = Operator::kStages; - m_description.tile_description.warp_count = - make_Coord(Operator::GemmKernel::WarpCount::kM, - Operator::GemmKernel::WarpCount::kN, - Operator::GemmKernel::WarpCount::kK); + m_description.tile_description.warp_count = make_Coord( + Operator::GemmKernel::WarpCount::kM, + Operator::GemmKernel::WarpCount::kN, + Operator::GemmKernel::WarpCount::kK); - m_description.tile_description.math_instruction.instruction_shape = - make_Coord(Operator::InstructionShape::kM, - Operator::InstructionShape::kN, - Operator::InstructionShape::kK); + m_description.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, Operator::InstructionShape::kN, + Operator::InstructionShape::kK); m_description.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; @@ -121,12 +120,12 @@ public: ArchMap::kMax; - m_description.A = make_TensorDescription( - Operator::kAlignmentA); - m_description.B = make_TensorDescription( - Operator::kAlignmentB); - m_description.C = make_TensorDescription( - Operator::kAlignmentC); + m_description.A = + make_TensorDescription(Operator::kAlignmentA); + m_description.B = + make_TensorDescription(Operator::kAlignmentB); + m_description.C = + make_TensorDescription(Operator::kAlignmentC); m_description.stages = Operator::kStages; @@ -134,9 +133,7 @@ public: m_description.split_k_mode = mode(); } - virtual OperationDescription const& description() const { - return m_description; - } + virtual OperationDescription const& description() const { return m_description; } protected: GemmDescription m_description; @@ -162,26 +159,23 @@ public: GemmOperation(char const* name = "unknown_gemm") : GemmOperationBase(name) {} - virtual Status run(void const* arguments_ptr, - void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const { + virtual Status run( + void const* arguments_ptr, void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const { GemmArguments const* gemm_args = reinterpret_cast(arguments_ptr); OperatorArguments args; args.problem_size = gemm_args->problem_size; - args.ref_A = {static_cast(gemm_args->A), - int(gemm_args->lda)}; - args.ref_B = {static_cast(gemm_args->B), - int(gemm_args->ldb)}; - args.ref_C = {static_cast(gemm_args->C), - int(gemm_args->ldc)}; - args.ref_D = {static_cast(gemm_args->D), - int(gemm_args->ldd)}; + args.ref_A = {static_cast(gemm_args->A), int(gemm_args->lda)}; + args.ref_B = {static_cast(gemm_args->B), int(gemm_args->ldb)}; + args.ref_C = {static_cast(gemm_args->C), int(gemm_args->ldc)}; + args.ref_D = {static_cast(gemm_args->D), int(gemm_args->ldd)}; args.split_k_slices = gemm_args->split_k_slices; - args.epilogue = {*static_cast(gemm_args->alpha), - *static_cast(gemm_args->beta)}; + args.epilogue = { + *static_cast(gemm_args->alpha), + *static_cast(gemm_args->beta)}; Operator op; Status status = op.initialize(args, device_workspace); diff --git a/dnn/src/cuda/cutlass/initialize_all.cu 
b/dnn/src/cuda/cutlass/initialize_all.cu index e314724f..137f78fa 100644 --- a/dnn/src/cuda/cutlass/initialize_all.cu +++ b/dnn/src/cuda/cutlass/initialize_all.cu @@ -49,8 +49,7 @@ namespace library { #define CUTLASS_ARCH_MMA_SM75_SUPPORTED 1 #endif -#if __CUDACC_VER_MAJOR__ > 9 || \ - (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) void initialize_all_gemm_simt_operations(Manifest& manifest); void initialize_all_conv2d_simt_operations(Manifest& manifest); diff --git a/dnn/src/cuda/cutlass/library.h b/dnn/src/cuda/cutlass/library.h index 4d7a5b05..dbc841e6 100644 --- a/dnn/src/cuda/cutlass/library.h +++ b/dnn/src/cuda/cutlass/library.h @@ -183,13 +183,7 @@ enum class ScalarPointerMode { kHost, kDevice, kInvalid }; enum class SplitKMode { kNone, kSerial, kParallel, kParallelSerial, kInvalid }; /// Indicates the classificaition of the math instruction -enum class OpcodeClassID { - kSimt, - kTensorOp, - kWmmaTensorOp, - kSparseTensorOp, - kInvalid -}; +enum class OpcodeClassID { kSimt, kTensorOp, kWmmaTensorOp, kSparseTensorOp, kInvalid }; enum class ArchTagID { kSm50, @@ -292,8 +286,7 @@ struct MathInstructionDescription { // MathInstructionDescription( - cutlass::gemm::GemmCoord instruction_shape = - cutlass::gemm::GemmCoord(), + cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), NumericTypeID element_accumulator = NumericTypeID::kInvalid, OpcodeClassID opcode_class = OpcodeClassID::kInvalid, MathOperationID math_operation = MathOperationID::kMultiplyAdd) @@ -344,14 +337,11 @@ struct TileDescription { // TileDescription( - cutlass::gemm::GemmCoord threadblock_shape = - cutlass::gemm::GemmCoord(), + cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), int threadblock_stages = 0, cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), - MathInstructionDescription math_instruction = - MathInstructionDescription(), - int minimum_compute_capability = 0, - int maximum_compute_capability = 0) + MathInstructionDescription math_instruction = MathInstructionDescription(), + int minimum_compute_capability = 0, int maximum_compute_capability = 0) : threadblock_shape(threadblock_shape), threadblock_stages(threadblock_stages), warp_count(warp_count), @@ -365,15 +355,12 @@ struct TileDescription { (threadblock_stages == rhs.threadblock_stages) && (warp_count == rhs.warp_count) && (math_instruction == rhs.math_instruction) && - (minimum_compute_capability == - rhs.minimum_compute_capability) && + (minimum_compute_capability == rhs.minimum_compute_capability) && (maximum_compute_capability == rhs.maximum_compute_capability)); } // Inequality operator - inline bool operator!=(TileDescription const& rhs) const { - return !(*this == rhs); - } + inline bool operator!=(TileDescription const& rhs) const { return !(*this == rhs); } }; /// High-level description of an operation @@ -394,8 +381,7 @@ struct OperationDescription { // Methods // OperationDescription( - char const* name = "unknown", - OperationKind kind = OperationKind::kInvalid, + char const* name = "unknown", OperationKind kind = OperationKind::kInvalid, TileDescription const& tile_description = TileDescription()) : name(name), kind(kind), tile_description(tile_description) {} }; @@ -421,10 +407,10 @@ struct TensorDescription { // Methods // - TensorDescription(NumericTypeID element = NumericTypeID::kInvalid, - LayoutTypeID layout = LayoutTypeID::kInvalid, - int alignment = 1, int log_extent_range = 24, 
- int log_stride_range = 24) + TensorDescription( + NumericTypeID element = NumericTypeID::kInvalid, + LayoutTypeID layout = LayoutTypeID::kInvalid, int alignment = 1, + int log_extent_range = 24, int log_stride_range = 24) : element(element), layout(layout), alignment(alignment), @@ -533,8 +519,9 @@ public: virtual OperationDescription const& description() const = 0; - virtual Status run(void const* arguments, void* device_workspace = nullptr, - cudaStream_t stream = nullptr) const = 0; + virtual Status run( + void const* arguments, void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/dnn/src/cuda/cutlass/library_internal.h b/dnn/src/cuda/cutlass/library_internal.h index b12a0a52..bd698fae 100644 --- a/dnn/src/cuda/cutlass/library_internal.h +++ b/dnn/src/cuda/cutlass/library_internal.h @@ -196,8 +196,7 @@ struct MathOperationMap { template <> struct MathOperationMap { - static MathOperationID const kId = - MathOperationID::kMultiplyAddGaussianComplex; + static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; }; template <> @@ -489,37 +488,30 @@ template struct ThreadblockSwizzleMap; template -struct ThreadblockSwizzleMap< - gemm::threadblock::GemmIdentityThreadblockSwizzle> { +struct ThreadblockSwizzleMap> { static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmIdentity; }; template <> -struct ThreadblockSwizzleMap< - gemm::threadblock::GemmHorizontalThreadblockSwizzle> { - static ThreadblockSwizzleID const kId = - ThreadblockSwizzleID::kGemmHorizontal; +struct ThreadblockSwizzleMap { + static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmHorizontal; }; template <> -struct ThreadblockSwizzleMap< - gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle> { - static ThreadblockSwizzleID const kId = - ThreadblockSwizzleID::kGemmBatchedIdentity; +struct ThreadblockSwizzleMap { + static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmBatchedIdentity; }; template struct ThreadblockSwizzleMap< gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle> { - static ThreadblockSwizzleID const kId = - ThreadblockSwizzleID::kGemmSplitKIdentity; + static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmSplitKIdentity; }; template <> struct ThreadblockSwizzleMap< gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle> { - static ThreadblockSwizzleID const kId = - ThreadblockSwizzleID::kGemmSplitKHorizontal; + static ThreadblockSwizzleID const kId = ThreadblockSwizzleID::kGemmSplitKHorizontal; }; template <> @@ -587,8 +579,7 @@ TensorDescription make_TensorDescription(int alignment = 1) { desc.element = NumericTypeMap::kId; desc.layout = LayoutMap::kId; desc.alignment = alignment; - desc.log_extent_range = - int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; + desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; return desc; diff --git a/dnn/src/cuda/cutlass/manifest.h b/dnn/src/cuda/cutlass/manifest.h index 396074f3..5d0b2310 100644 --- a/dnn/src/cuda/cutlass/manifest.h +++ b/dnn/src/cuda/cutlass/manifest.h @@ -75,8 +75,7 @@ private: OperationVector operations_; public: - Manifest(Provider provider = library::Provider::kCUTLASS) - : provider_(provider) {} + Manifest(Provider provider = library::Provider::kCUTLASS) : provider_(provider) {} /// Top-level initialization Status 
initialize(); diff --git a/dnn/src/cuda/cutlass/operation_table.cpp b/dnn/src/cuda/cutlass/operation_table.cpp index 2da08ea3..290d2aef 100644 --- a/dnn/src/cuda/cutlass/operation_table.cpp +++ b/dnn/src/cuda/cutlass/operation_table.cpp @@ -36,8 +36,8 @@ * implied. */ -#include "src/common/utils.h" #include "src/cuda/cutlass/operation_table.h" +#include "src/common/utils.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,8 +86,7 @@ GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { ///////////////////////////////////////////////////////////////////////////////////////////////// -ConvolutionKey get_convolution_key_from_desc( - const ConvolutionDescription& desc) { +ConvolutionKey get_convolution_key_from_desc(const ConvolutionDescription& desc) { ConvolutionKey key; key.conv_op = desc.conv_op; @@ -139,8 +138,8 @@ void OperationTable::append(Manifest const& manifest) { // insert all gemm operations into operation table if (desc.kind == OperationKind::kGemm) { - GemmKey key = get_gemm_key_from_desc( - static_cast(desc)); + GemmKey key = + get_gemm_key_from_desc(static_cast(desc)); gemm_operations[key].push_back(operation.get()); } @@ -158,8 +157,8 @@ void OperationTable::append(Manifest const& manifest) { Operation const* OperationTable::find_op(GemmKey const& key) const { if (gemm_operations.count(key)) { auto const& ops = gemm_operations.at(key); - megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", - ops.size()); + megdnn_assert( + ops.size() == 1, "exactly one kernel expected, got %zu", ops.size()); return ops[0]; } return nullptr; @@ -170,8 +169,8 @@ Operation const* OperationTable::find_op(GemmKey const& key) const { Operation const* OperationTable::find_op(ConvolutionKey const& key) const { if (convolution_operations.count(key) > 0) { auto const& ops = convolution_operations.at(key); - megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", - ops.size()); + megdnn_assert( + ops.size() == 1, "exactly one kernel expected, got %zu", ops.size()); return ops[0]; } return nullptr; diff --git a/dnn/src/cuda/cutlass/operation_table.h b/dnn/src/cuda/cutlass/operation_table.h index 1fef2ff7..3e38bf95 100644 --- a/dnn/src/cuda/cutlass/operation_table.h +++ b/dnn/src/cuda/cutlass/operation_table.h @@ -100,7 +100,7 @@ struct GemmKey { return (element_A == rhs.element_A) && (layout_A == rhs.layout_A) && (element_B == rhs.element_B) && (layout_B == rhs.layout_B) && (element_C == rhs.element_C) && (layout_C == rhs.layout_C) && - (element_accumulator == rhs.element_accumulator) && + (element_accumulator == rhs.element_accumulator) && (threadblock_shape_m == rhs.threadblock_shape_m) && (threadblock_shape_n == rhs.threadblock_shape_n) && (threadblock_shape_k == rhs.threadblock_shape_k) && @@ -111,8 +111,7 @@ struct GemmKey { (instruction_shape_n == rhs.instruction_shape_n) && (instruction_shape_k == rhs.instruction_shape_k) && (stages == rhs.stages) && (alignment_A == rhs.alignment_A) && - (alignment_B == rhs.alignment_B) && - (split_k_mode == rhs.split_k_mode); + (alignment_B == rhs.alignment_B) && (split_k_mode == rhs.split_k_mode); } inline bool operator!=(GemmKey const& rhs) const { return !(*this == rhs); } @@ -136,13 +135,13 @@ struct GemmKey { "\n layout_B: " + to_string(layout_B) + "\n element_C: " + to_string(element_C) + "\n layout_C: " + to_string(layout_C) + - "\n element_accumulator: " + to_string(element_accumulator) + + "\n element_accumulator: " + to_string(element_accumulator) + "\n 
threadblock_shape: " + threadblock_shape_str + "\n warp_shape: " + warp_shape_str + "\n instruction_shape: " + instruction_shape_str + "\n stages: " + std::to_string(stages) + - "\n alignment_A: " + std::to_string(alignment_A) + - "\n alignment_B: " + std::to_string(alignment_B) + + "\n alignment_A: " + std::to_string(alignment_A) + + "\n alignment_B: " + std::to_string(alignment_B) + "\n split_k_mode: " + to_string(split_k_mode) + "\n}"; } }; @@ -156,14 +155,10 @@ struct GemmKeyHasher { .update(&key.layout_B, sizeof(key.layout_B)) .update(&key.element_C, sizeof(key.element_C)) .update(&key.layout_C, sizeof(key.layout_C)) - .update(&key.element_accumulator, - sizeof(key.element_accumulator)) - .update(&key.threadblock_shape_m, - sizeof(key.threadblock_shape_m)) - .update(&key.threadblock_shape_n, - sizeof(key.threadblock_shape_n)) - .update(&key.threadblock_shape_k, - sizeof(key.threadblock_shape_k)) + .update(&key.element_accumulator, sizeof(key.element_accumulator)) + .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m)) + .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n)) + .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k)) .update(&key.warp_shape_m, sizeof(key.warp_shape_m)) .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) @@ -176,8 +171,7 @@ struct GemmKeyHasher { }; using GemmOperationMap = - std::unordered_map, - GemmKeyHasher>; + std::unordered_map, GemmKeyHasher>; ///////////////////////////////////////////////////////////////////////////////////////////////// // Data Structures for ConvolutionOperationMap @@ -219,10 +213,8 @@ struct ConvolutionKey { (layout_src == rhs.layout_src) && (element_filter == rhs.element_filter) && (layout_filter == rhs.layout_filter) && - (element_dst == rhs.element_dst) && - (layout_dst == rhs.layout_dst) && - (element_bias == rhs.element_bias) && - (layout_bias == rhs.layout_bias) && + (element_dst == rhs.element_dst) && (layout_dst == rhs.layout_dst) && + (element_bias == rhs.element_bias) && (layout_bias == rhs.layout_bias) && (convolution_type == rhs.convolution_type) && (threadblock_shape_m == rhs.threadblock_shape_m) && (threadblock_shape_n == rhs.threadblock_shape_n) && @@ -238,9 +230,7 @@ struct ConvolutionKey { (without_shared_load == rhs.without_shared_load); } - inline bool operator!=(ConvolutionKey const& rhs) const { - return !(*this == rhs); - } + inline bool operator!=(ConvolutionKey const& rhs) const { return !(*this == rhs); } inline std::string str() const { auto tuple_to_str = [](int m, int n, int k) -> std::string { @@ -270,10 +260,8 @@ struct ConvolutionKey { "\n instruction_shape: " + instruction_shape_str + "\n epilogue_type: " + to_string(epilogue_type) + "\n stages: " + std::to_string(stages) + - "\n special_optimization: " + - to_string(special_optimization) + - "\n without_shared_load: " + to_string(without_shared_load) + - "\n}"; + "\n special_optimization: " + to_string(special_optimization) + + "\n without_shared_load: " + to_string(without_shared_load) + "\n}"; } }; @@ -291,34 +279,25 @@ struct ConvolutionKeyHasher { .update(&key.element_bias, sizeof(key.element_bias)) .update(&key.layout_bias, sizeof(key.layout_bias)) .update(&key.convolution_type, sizeof(key.convolution_type)) - .update(&key.threadblock_shape_m, - sizeof(key.threadblock_shape_m)) - .update(&key.threadblock_shape_n, - sizeof(key.threadblock_shape_n)) - .update(&key.threadblock_shape_k, - sizeof(key.threadblock_shape_k)) + 
.update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m)) + .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n)) + .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k)) .update(&key.warp_shape_m, sizeof(key.warp_shape_m)) .update(&key.warp_shape_n, sizeof(key.warp_shape_n)) .update(&key.warp_shape_k, sizeof(key.warp_shape_k)) - .update(&key.instruction_shape_m, - sizeof(key.instruction_shape_m)) - .update(&key.instruction_shape_n, - sizeof(key.instruction_shape_n)) - .update(&key.instruction_shape_k, - sizeof(key.instruction_shape_k)) + .update(&key.instruction_shape_m, sizeof(key.instruction_shape_m)) + .update(&key.instruction_shape_n, sizeof(key.instruction_shape_n)) + .update(&key.instruction_shape_k, sizeof(key.instruction_shape_k)) .update(&key.epilogue_type, sizeof(key.epilogue_type)) .update(&key.stages, sizeof(key.stages)) - .update(&key.special_optimization, - sizeof(key.special_optimization)) - .update(&key.without_shared_load, - sizeof(key.without_shared_load)) + .update(&key.special_optimization, sizeof(key.special_optimization)) + .update(&key.without_shared_load, sizeof(key.without_shared_load)) .digest(); } }; -using ConvolutionOperationMap = - std::unordered_map, - ConvolutionKeyHasher>; +using ConvolutionOperationMap = std::unordered_map< + ConvolutionKey, std::vector, ConvolutionKeyHasher>; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/dnn/src/cuda/cutlass/util.cu b/dnn/src/cuda/cutlass/util.cu index 0506826c..a309a66a 100644 --- a/dnn/src/cuda/cutlass/util.cu +++ b/dnn/src/cuda/cutlass/util.cu @@ -36,8 +36,7 @@ * implied. */ -#if __CUDACC_VER_MAJOR__ > 9 || \ - (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) #include #include @@ -95,8 +94,7 @@ char const* to_string(Provider provider, bool pretty) { template <> Provider from_string(std::string const& str) { for (auto const& possible : Provider_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -115,8 +113,7 @@ static struct { {"spgemm", "", GemmKind::kSparse}, {"universal", "", GemmKind::kUniversal}, {"planar_complex", "", GemmKind::kPlanarComplex}, - {"planar_complex_array", "", - GemmKind::kPlanarComplexArray}, + {"planar_complex_array", "", GemmKind::kPlanarComplexArray}, }; /// Converts a GemmKind enumerant to a string @@ -167,8 +164,7 @@ char const* to_string(OperationKind enumerant, bool pretty) { template <> OperationKind from_string(std::string const& str) { for (auto const& possible : OperationKind_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -186,8 +182,7 @@ static struct { {"success", "Success", Status::kSuccess}, {"misaligned_operand", "Error: misaligned operand", Status::kErrorMisalignedOperand}, - {"invalid_problem", "Error: invalid problem", - Status::kErrorInvalidProblem}, + {"invalid_problem", "Error: invalid problem", Status::kErrorInvalidProblem}, {"not_supported", "Error: not supported", Status::kErrorNotSupported}, {"internal", "Error: internal", Status::kErrorInternal}}; @@ -210,8 +205,7 @@ char const* to_string(Status status, bool pretty) { template <> Status 
from_string(std::string const& str) { for (auto const& possible : Status_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -284,8 +278,7 @@ char const* to_string(NumericTypeID type, bool pretty) { template <> NumericTypeID from_string(std::string const& str) { for (auto const& possible : NumericTypeID_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -497,48 +490,49 @@ bool is_float_type(NumericTypeID type) { static struct { LayoutTypeID layout; char const* alias; -} layout_aliases[] = {{LayoutTypeID::kUnknown, "unknown"}, - {LayoutTypeID::kRowMajor, "row"}, - {LayoutTypeID::kRowMajor, "t"}, - {LayoutTypeID::kColumnMajor, "column"}, - {LayoutTypeID::kColumnMajor, "col"}, - {LayoutTypeID::kColumnMajor, "n"}, - - {LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, - {LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, - - {LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, - {LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, - - {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, - {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, - - {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, - {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, - - {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, - {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, - - {LayoutTypeID::kTensorNCHW, "nchw"}, - {LayoutTypeID::kTensorNCDHW, "ncdhw"}, - {LayoutTypeID::kTensorNHWC, "nhwc"}, - {LayoutTypeID::kTensorNDHWC, "ndhwc"}, - {LayoutTypeID::kTensorNC4HW4, "nc4hw4"}, - {LayoutTypeID::kTensorNC8HW8, "nc8hw8"}, - {LayoutTypeID::kTensorNC16HW16, "nc16hw16"}, - {LayoutTypeID::kTensorNC32HW32, "nc32hw32"}, - {LayoutTypeID::kTensorNC64HW64, "nc64hw64"}, - {LayoutTypeID::kTensorC4RSK4, "c4rsk4"}, - {LayoutTypeID::kTensorC8RSK8, "c8rsk8"}, - {LayoutTypeID::kTensorC16RSK16, "c16rsk16"}, - {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, - {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, - {LayoutTypeID::kTensorK4RSC4, "k4rsc4"}, - {LayoutTypeID::kTensorCK4RS4, "ck4rs4"}, - {LayoutTypeID::kTensorCK8RS8, "ck8rs8"}, - {LayoutTypeID::kTensorCK16RS16, "ck16rs16"}, - {LayoutTypeID::kUnknown, "*"}, - {LayoutTypeID::kInvalid, nullptr}}; +} layout_aliases[] = { + {LayoutTypeID::kUnknown, "unknown"}, + {LayoutTypeID::kRowMajor, "row"}, + {LayoutTypeID::kRowMajor, "t"}, + {LayoutTypeID::kColumnMajor, "column"}, + {LayoutTypeID::kColumnMajor, "col"}, + {LayoutTypeID::kColumnMajor, "n"}, + + {LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, + {LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, + + {LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, + {LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, + + {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, + {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, + + {LayoutTypeID::kColumnMajorInterleavedK32, "nk32"}, + {LayoutTypeID::kRowMajorInterleavedK32, "tk32"}, + + {LayoutTypeID::kColumnMajorInterleavedK64, "nk64"}, + {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, + + {LayoutTypeID::kTensorNCHW, "nchw"}, + {LayoutTypeID::kTensorNCDHW, "ncdhw"}, + {LayoutTypeID::kTensorNHWC, "nhwc"}, + {LayoutTypeID::kTensorNDHWC, "ndhwc"}, + {LayoutTypeID::kTensorNC4HW4, "nc4hw4"}, + {LayoutTypeID::kTensorNC8HW8, "nc8hw8"}, + {LayoutTypeID::kTensorNC16HW16, "nc16hw16"}, + {LayoutTypeID::kTensorNC32HW32, "nc32hw32"}, 
+ {LayoutTypeID::kTensorNC64HW64, "nc64hw64"}, + {LayoutTypeID::kTensorC4RSK4, "c4rsk4"}, + {LayoutTypeID::kTensorC8RSK8, "c8rsk8"}, + {LayoutTypeID::kTensorC16RSK16, "c16rsk16"}, + {LayoutTypeID::kTensorC32RSK32, "c32rsk32"}, + {LayoutTypeID::kTensorC64RSK64, "c64rsk64"}, + {LayoutTypeID::kTensorK4RSC4, "k4rsc4"}, + {LayoutTypeID::kTensorCK4RS4, "ck4rs4"}, + {LayoutTypeID::kTensorCK8RS8, "ck8rs8"}, + {LayoutTypeID::kTensorCK16RS16, "ck16rs16"}, + {LayoutTypeID::kUnknown, "*"}, + {LayoutTypeID::kInvalid, nullptr}}; /// Converts a LayoutTypeID enumerant to a string char const* to_string(LayoutTypeID layout, bool pretty) { @@ -640,8 +634,7 @@ char const* to_string(OpcodeClassID type, bool pretty) { template <> OpcodeClassID from_string(std::string const& str) { for (auto const& possible : OpcodeClassID_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -655,8 +648,9 @@ static struct { char const* text; char const* pretty; ComplexTransform enumerant; -} ComplexTransform_enumerants[] = {{"n", "none", ComplexTransform::kNone}, - {"c", "conj", ComplexTransform::kConjugate}}; +} ComplexTransform_enumerants[] = { + {"n", "none", ComplexTransform::kNone}, + {"c", "conj", ComplexTransform::kConjugate}}; /// Converts a ComplexTransform enumerant to a string char const* to_string(ComplexTransform type, bool pretty) { @@ -677,8 +671,7 @@ char const* to_string(ComplexTransform type, bool pretty) { template <> ComplexTransform from_string(std::string const& str) { for (auto const& possible : ComplexTransform_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -715,8 +708,7 @@ char const* to_string(SplitKMode type, bool pretty) { template <> SplitKMode from_string(std::string const& str) { for (auto const& possible : SplitKMode_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -754,8 +746,7 @@ char const* to_string(ConvModeID type, bool pretty) { template <> ConvModeID from_string(std::string const& str) { for (auto const& possible : ConvModeID_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -794,8 +785,7 @@ char const* to_string(IteratorAlgorithmID type, bool pretty) { template <> IteratorAlgorithmID from_string(std::string const& str) { for (auto const& possible : IteratorAlgorithmID_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -835,8 +825,7 @@ char const* to_string(ConvKind type, bool pretty) { template <> ConvKind from_string(std::string const& str) { for (auto const& possible : ConvKind_enumerants) { - if ((str.compare(possible.text) == 0) || - (str.compare(possible.pretty) == 0)) { + if ((str.compare(possible.text) == 0) || (str.compare(possible.pretty) == 0)) { return possible.enumerant; } } @@ -848,8 +837,8 @@ ConvKind from_string(std::string const& str) { /// Lexical cast a string to a 
byte array. Returns true if cast is successful or /// false if invalid. -bool lexical_cast(std::vector& bytes, NumericTypeID type, - std::string const& str) { +bool lexical_cast( + std::vector& bytes, NumericTypeID type, std::string const& str) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; @@ -893,14 +882,12 @@ bool lexical_cast(std::vector& bytes, NumericTypeID type, case NumericTypeID::kBF16: { float tmp; ss >> tmp; - *reinterpret_cast(bytes.data()) = - static_cast(tmp); + *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; case NumericTypeID::kTF32: { float tmp; ss >> tmp; - *reinterpret_cast(bytes.data()) = - static_cast(tmp); + *reinterpret_cast(bytes.data()) = static_cast(tmp); } break; case NumericTypeID::kF32: { ss >> *reinterpret_cast(bytes.data()); @@ -920,8 +907,7 @@ bool lexical_cast(std::vector& bytes, NumericTypeID type, std::complex tmp; ss >> tmp; cutlass::complex* x = - reinterpret_cast*>( - bytes.data()); + reinterpret_cast*>(bytes.data()); x->real() = static_cast(std::real(tmp)); x->imag() = static_cast(std::imag(tmp)); } break; @@ -932,8 +918,7 @@ bool lexical_cast(std::vector& bytes, NumericTypeID type, std::complex tmp; ss >> tmp; cutlass::complex* x = - reinterpret_cast*>( - bytes.data()); + reinterpret_cast*>(bytes.data()); x->real() = static_cast(std::real(tmp)); x->imag() = static_cast(std::imag(tmp)); } break; @@ -1015,8 +1000,7 @@ std::string lexical_cast(std::vector& bytes, NumericTypeID type) { } break; case NumericTypeID::kCF16: { cutlass::complex const* x = - reinterpret_cast const*>( - bytes.data()); + reinterpret_cast const*>(bytes.data()); ss << float(x->real()); @@ -1026,8 +1010,7 @@ std::string lexical_cast(std::vector& bytes, NumericTypeID type) { } break; case NumericTypeID::kCBF16: { cutlass::complex const* x = - reinterpret_cast const*>( - bytes.data()); + reinterpret_cast const*>(bytes.data()); ss << float(x->real()); @@ -1037,8 +1020,7 @@ std::string lexical_cast(std::vector& bytes, NumericTypeID type) { } break; case NumericTypeID::kCF32: { cutlass::complex const* x = - reinterpret_cast const*>( - bytes.data()); + reinterpret_cast const*>(bytes.data()); ss << x->real(); @@ -1048,8 +1030,7 @@ std::string lexical_cast(std::vector& bytes, NumericTypeID type) { } break; case NumericTypeID::kCTF32: { cutlass::complex const* x = - reinterpret_cast const*>( - bytes.data()); + reinterpret_cast const*>(bytes.data()); ss << float(x->real()); @@ -1059,8 +1040,7 @@ std::string lexical_cast(std::vector& bytes, NumericTypeID type) { } break; case NumericTypeID::kCF64: { cutlass::complex const* x = - reinterpret_cast const*>( - bytes.data()); + reinterpret_cast const*>(bytes.data()); ss << x->real(); @@ -1077,8 +1057,7 @@ std::string lexical_cast(std::vector& bytes, NumericTypeID type) { /// Casts from a signed int64 to the destination type. Returns true if /// successful. 
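The `lexical_cast` overload reshaped in the hunk above parses a decimal string into the raw byte pattern of whatever `NumericTypeID` is requested, sizing the buffer from `sizeof_bits(type) / 8` and returning false for an unknown type. A minimal usage sketch, assuming the byte buffer is `std::vector<uint8_t>` (its element type is elided in this diff) and that the helper lives in the `cutlass::library` namespace as declared in `src/cuda/cutlass/util.h`:

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

#include "src/cuda/cutlass/util.h"  // include path follows the style used elsewhere in this patch

void lexical_cast_demo() {
    using namespace cutlass::library;  // namespace assumed; only "namespace library" is visible here
    std::vector<uint8_t> bytes(4, 0);  // pre-sized defensively; the helper may size it itself
    if (lexical_cast(bytes, NumericTypeID::kF32, "0.25")) {
        float v = 0.f;
        std::memcpy(&v, bytes.data(), sizeof(v));  // reinterpret the 4 raw bytes as a float
        std::printf("kF32 \"0.25\" -> %f (%zu bytes)\n", v, bytes.size());
    }
}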
-bool cast_from_int64(std::vector& bytes, NumericTypeID type, - int64_t src) { +bool cast_from_int64(std::vector& bytes, NumericTypeID type, int64_t src) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; @@ -1088,39 +1067,31 @@ bool cast_from_int64(std::vector& bytes, NumericTypeID type, switch (type) { case NumericTypeID::kU8: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU16: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU32: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU64: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS16: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS32: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS64: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kF16: { - *reinterpret_cast(bytes.data()) = - static_cast(float(src)); + *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kBF16: { *reinterpret_cast(bytes.data()) = @@ -1159,8 +1130,7 @@ bool cast_from_int64(std::vector& bytes, NumericTypeID type, /// Casts from an unsigned int64 to the destination type. Returns true if /// successful. 
-bool cast_from_uint64(std::vector& bytes, NumericTypeID type, - uint64_t src) { +bool cast_from_uint64(std::vector& bytes, NumericTypeID type, uint64_t src) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; @@ -1170,39 +1140,31 @@ bool cast_from_uint64(std::vector& bytes, NumericTypeID type, switch (type) { case NumericTypeID::kU8: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU16: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU32: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU64: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS16: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS32: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS64: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kF16: { - *reinterpret_cast(bytes.data()) = - static_cast(float(src)); + *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kBF16: { *reinterpret_cast(bytes.data()) = @@ -1241,8 +1203,7 @@ bool cast_from_uint64(std::vector& bytes, NumericTypeID type, /// Lexical cast a string to a byte array. Returns true if cast is successful or /// false if invalid. 
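`cast_from_int64` and `cast_from_uint64` above go in the opposite direction of `lexical_cast`: they take an already-parsed integer and write it into the byte pattern of the requested element type, returning false when the type has no known size. A sketch under the same assumptions (byte buffer as `std::vector<uint8_t>`, `cutlass::library` namespace), e.g. to splat a host-side scalar into a type-erased argument buffer:

#include <cstdint>
#include <vector>

#include "src/cuda/cutlass/util.h"

// Encode the scalar 7 as a signed 16-bit element (kS16 case in the switch above).
bool encode_s16_seven(std::vector<uint8_t>& bytes) {
    using namespace cutlass::library;                       // namespace assumed
    bytes.assign(sizeof(int16_t), 0);                       // defensive pre-size
    return cast_from_int64(bytes, NumericTypeID::kS16, 7);  // writes the int16_t bit pattern of 7
}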
-bool cast_from_double(std::vector& bytes, NumericTypeID type, - double src) { +bool cast_from_double(std::vector& bytes, NumericTypeID type, double src) { int size_bytes = sizeof_bits(type) / 8; if (!size_bytes) { return false; @@ -1252,39 +1213,31 @@ bool cast_from_double(std::vector& bytes, NumericTypeID type, switch (type) { case NumericTypeID::kU8: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU16: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU32: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kU64: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS8: { *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS16: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS32: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kS64: { - *reinterpret_cast(bytes.data()) = - static_cast(src); + *reinterpret_cast(bytes.data()) = static_cast(src); } break; case NumericTypeID::kF16: { - *reinterpret_cast(bytes.data()) = - static_cast(float(src)); + *reinterpret_cast(bytes.data()) = static_cast(float(src)); } break; case NumericTypeID::kBF16: { *reinterpret_cast(bytes.data()) = @@ -1308,8 +1261,7 @@ bool cast_from_double(std::vector& bytes, NumericTypeID type, } break; case NumericTypeID::kCBF16: { cutlass::complex* x = - reinterpret_cast*>( - bytes.data()); + reinterpret_cast*>(bytes.data()); x->real() = static_cast(bfloat16_t(src)); x->imag() = static_cast(bfloat16_t(0)); } break; @@ -1367,8 +1319,7 @@ static struct { conv::ConvType enumerant; } ConvType_enumerants[] = { {"convolution", "Convolution", conv::ConvType::kConvolution}, - {"batch_convolution", "BatchConvolution", - conv::ConvType::kBatchConvolution}, + {"batch_convolution", "BatchConvolution", conv::ConvType::kBatchConvolution}, {"local", "Local", conv::ConvType::kLocal}, {"local_share", "LocalShare", conv::ConvType::kLocalShare}, }; @@ -1395,14 +1346,10 @@ static struct { char const* pretty; ArchTagID enumerant; } ArchTagID_enumerants[] = { - {"sm_50", "Sm50", ArchTagID::kSm50}, - {"sm_60", "Sm60", ArchTagID::kSm60}, - {"sm_61", "Sm61", ArchTagID::kSm61}, - {"sm_70", "Sm70", ArchTagID::kSm70}, - {"sm_72", "Sm72", ArchTagID::kSm72}, - {"sm_75", "Sm75", ArchTagID::kSm75}, - {"sm_80", "Sm80", ArchTagID::kSm80}, - {"sm_86", "Sm86", ArchTagID::kSm86}, + {"sm_50", "Sm50", ArchTagID::kSm50}, {"sm_60", "Sm60", ArchTagID::kSm60}, + {"sm_61", "Sm61", ArchTagID::kSm61}, {"sm_70", "Sm70", ArchTagID::kSm70}, + {"sm_72", "Sm72", ArchTagID::kSm72}, {"sm_75", "Sm75", ArchTagID::kSm75}, + {"sm_80", "Sm80", ArchTagID::kSm80}, {"sm_86", "Sm86", ArchTagID::kSm86}, }; /// Converts an ArchTagID enumerant to a string @@ -1438,8 +1385,7 @@ static struct { epilogue::EpilogueType::kBiasAddLinearCombinationHSwishClamp}, {"bias_add_linear_combination_relu", "BiasAddLinearCombinationRelu", epilogue::EpilogueType::kBiasAddLinearCombinationRelu}, - {"bias_add_linear_combination_relu_clamp", - "BiasAddLinearCombinationReluClamp", + {"bias_add_linear_combination_relu_clamp", 
"BiasAddLinearCombinationReluClamp", epilogue::EpilogueType::kBiasAddLinearCombinationReluClamp}, {"conversion", "Conversion", epilogue::EpilogueType::kConversion}, {"linear_combination", "LinearCombination", @@ -1486,8 +1432,7 @@ static struct { ThreadblockSwizzleID::kGemmSplitKIdentity}, {"gemm_split_k_horizontal", "GemmSplitKHorizontalThreadblockSwizzle", ThreadblockSwizzleID::kGemmSplitKHorizontal}, - {"gemv_batched_strided_default", - "GemvBatchedStridedThreadblockDefaultSwizzle", + {"gemv_batched_strided_default", "GemvBatchedStridedThreadblockDefaultSwizzle", ThreadblockSwizzleID::kGemvBatchedStridedDefault}, {"gemv_batched_strided_reduction", "GemvBatchedStridedThreadblockReductionSwizzle", diff --git a/dnn/src/cuda/cutlass/util.h b/dnn/src/cuda/cutlass/util.h index da36393e..65834d69 100644 --- a/dnn/src/cuda/cutlass/util.h +++ b/dnn/src/cuda/cutlass/util.h @@ -161,8 +161,8 @@ std::string lexical_cast(int64_t int_value); /// Lexical cast a string to a byte array. Returns true if cast is successful or /// false if invalid. -bool lexical_cast(std::vector& bytes, NumericTypeID type, - std::string const& str); +bool lexical_cast( + std::vector& bytes, NumericTypeID type, std::string const& str); /// Lexical cast TO a string FROM a byte array. Returns true if cast is /// successful or false if invalid. @@ -170,18 +170,15 @@ std::string lexical_cast(std::vector& bytes, NumericTypeID type); /// Casts from a signed int64 to the destination type. Returns true if /// successful. -bool cast_from_int64(std::vector& bytes, NumericTypeID type, - int64_t src); +bool cast_from_int64(std::vector& bytes, NumericTypeID type, int64_t src); /// Casts from an unsigned int64 to the destination type. Returns true if /// successful. -bool cast_from_uint64(std::vector& bytes, NumericTypeID type, - uint64_t src); +bool cast_from_uint64(std::vector& bytes, NumericTypeID type, uint64_t src); /// Casts from a real value represented as a double to the destination type. /// Returns true if successful. 
-bool cast_from_double(std::vector& bytes, NumericTypeID type, - double src); +bool cast_from_double(std::vector& bytes, NumericTypeID type, double src); ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -198,8 +195,7 @@ char const* to_string(ArchTagID tag, bool pretty = false); char const* to_string(epilogue::EpilogueType type, bool pretty = false); /// Converts a ThreadblockSwizzleID enumerant to a string -char const* to_string(ThreadblockSwizzleID threadblock_swizzle, - bool pretty = false); +char const* to_string(ThreadblockSwizzleID threadblock_swizzle, bool pretty = false); /// Converts a bool value to a string char const* to_string(bool val, bool pretty = false); @@ -208,8 +204,7 @@ char const* to_string(bool val, bool pretty = false); char const* to_string(MathOperationID math_op, bool pretty = false); /// Converts a SpecialOptimizeDesc enumerant to a string -char const* to_string(conv::SpecialOptimizeDesc special_opt, - bool pretty = false); +char const* to_string(conv::SpecialOptimizeDesc special_opt, bool pretty = false); /// Converts an ImplicitGemmMode enumerant to a string char const* to_string(conv::ImplicitGemmMode mode, bool pretty = false); diff --git a/dnn/src/cuda/cv/kernel_common.cuh b/dnn/src/cuda/cv/kernel_common.cuh index 07bcf0cd..032fe4d8 100644 --- a/dnn/src/cuda/cv/kernel_common.cuh +++ b/dnn/src/cuda/cv/kernel_common.cuh @@ -181,8 +181,7 @@ __device__ inline int border_interpolate(int p, int len) { template __device__ void interpolate_coefs(float x, float* coeffs); template <> -__device__ inline void interpolate_coefs(float x, - float* coeffs) {} +__device__ inline void interpolate_coefs(float x, float* coeffs) {} template <> __device__ inline void interpolate_coefs(float x, float* coeffs) { interpolate_linear_coefs(x, coeffs); @@ -192,8 +191,7 @@ __device__ inline void interpolate_coefs(float x, float* coeffs) { megdnn::resize::interpolate_cubic(x, coeffs); } template <> -__device__ inline void interpolate_coefs(float x, - float* coeffs) { +__device__ inline void interpolate_coefs(float x, float* coeffs) { interpolate_lanczos4_coefs(x, coeffs); } diff --git a/dnn/src/cuda/cvt_color/cvt_color.cu b/dnn/src/cuda/cvt_color/cvt_color.cu index 012c13ce..25b99b9f 100644 --- a/dnn/src/cuda/cvt_color/cvt_color.cu +++ b/dnn/src/cuda/cvt_color/cvt_color.cu @@ -77,13 +77,12 @@ using namespace megcv; #define THREADS_X 256 #define THREADS_Y 1 -#define U8_PROCESS_PER_THREADS_X 4 +#define U8_PROCESS_PER_THREADS_X 4 #define F32_PROCESS_PER_THREADS_X 1 -__global__ void cvt_rgb2gray_8u_kernel(const uchar* src, uchar* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_rgb2gray_8u_kernel( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; if (t < (rows * cols) / U8_PROCESS_PER_THREADS_X) { @@ -95,17 +94,17 @@ __global__ void cvt_rgb2gray_8u_kernel(const uchar* src, uchar* dst, uchar temp_src[12]; *((uint3*)temp_src) = *((uint3*)src); - temp_des[0] = (temp_src[0] * 4899 + temp_src[1] * 9617 + - temp_src[2] * 1868 + (1 << 13)) >> + temp_des[0] = (temp_src[0] * 4899 + temp_src[1] * 9617 + temp_src[2] * 1868 + + (1 << 13)) >> 14; - temp_des[1] = (temp_src[3] * 4899 + temp_src[4] * 9617 + - temp_src[5] * 1868 + (1 << 13)) >> + temp_des[1] = (temp_src[3] * 4899 + temp_src[4] * 9617 + temp_src[5] * 1868 + + (1 << 13)) >> 14; - temp_des[2] = (temp_src[6] * 
4899 + temp_src[7] * 9617 + - temp_src[8] * 1868 + (1 << 13)) >> + temp_des[2] = (temp_src[6] * 4899 + temp_src[7] * 9617 + temp_src[8] * 1868 + + (1 << 13)) >> 14; - temp_des[3] = (temp_src[9] * 4899 + temp_src[10] * 9617 + - temp_src[11] * 1868 + (1 << 13)) >> + temp_des[3] = (temp_src[9] * 4899 + temp_src[10] * 9617 + temp_src[11] * 1868 + + (1 << 13)) >> 14; *((uint32_t*)dst) = *((uint32_t*)temp_des); @@ -117,17 +116,15 @@ __global__ void cvt_rgb2gray_8u_kernel(const uchar* src, uchar* dst, dst += 1 * offset; for (int i = 0; i < rest; i++, src += 3, dst += 1) - dst[0] = (src[0] * 4899 + src[1] * 9617 + src[2] * 1868 + - (1 << 13)) >> + dst[0] = (src[0] * 4899 + src[1] * 9617 + src[2] * 1868 + (1 << 13)) >> 14; } } } -__global__ void cvt_rgb2gray_32f_kernel(const float* src, float* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_rgb2gray_32f_kernel( + const float* src, float* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; if (t < rows * cols) { @@ -138,18 +135,15 @@ __global__ void cvt_rgb2gray_32f_kernel(const float* src, float* dst, float temp_src[3], temp_dst; *((float3*)temp_src) = *((float3*)src); - temp_dst = temp_src[0] * 0.299f + temp_src[1] * 0.587f + - temp_src[2] * 0.114f; + temp_dst = temp_src[0] * 0.299f + temp_src[1] * 0.587f + temp_src[2] * 0.114f; dst[0] = temp_dst; } } - -__global__ void cvt_bgr2gray_8u_kernel(const uchar* src, uchar* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_bgr2gray_8u_kernel( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; if (t < (rows * cols) / U8_PROCESS_PER_THREADS_X) { @@ -161,17 +155,17 @@ __global__ void cvt_bgr2gray_8u_kernel(const uchar* src, uchar* dst, uchar temp_src[12]; *((uint3*)temp_src) = *((uint3*)src); - temp_des[0] = (temp_src[0] * 1868 + temp_src[1] * 9617 + - temp_src[2] * 4899 + (1 << 13)) >> + temp_des[0] = (temp_src[0] * 1868 + temp_src[1] * 9617 + temp_src[2] * 4899 + + (1 << 13)) >> 14; - temp_des[1] = (temp_src[3] * 1868 + temp_src[4] * 9617 + - temp_src[5] * 4899 + (1 << 13)) >> + temp_des[1] = (temp_src[3] * 1868 + temp_src[4] * 9617 + temp_src[5] * 4899 + + (1 << 13)) >> 14; - temp_des[2] = (temp_src[6] * 1868 + temp_src[7] * 9617 + - temp_src[8] * 4899 + (1 << 13)) >> + temp_des[2] = (temp_src[6] * 1868 + temp_src[7] * 9617 + temp_src[8] * 4899 + + (1 << 13)) >> 14; - temp_des[3] = (temp_src[9] * 1868 + temp_src[10] * 9617 + - temp_src[11] * 4899 + (1 << 13)) >> + temp_des[3] = (temp_src[9] * 1868 + temp_src[10] * 9617 + temp_src[11] * 4899 + + (1 << 13)) >> 14; *((uint32_t*)dst) = *((uint32_t*)temp_des); @@ -183,17 +177,15 @@ __global__ void cvt_bgr2gray_8u_kernel(const uchar* src, uchar* dst, dst += 1 * offset; for (int i = 0; i < rest; i++, src += 3, dst += 1) - dst[0] = (src[0] * 1868 + src[1] * 9617 + src[2] * 4899 + - (1 << 13)) >> + dst[0] = (src[0] * 1868 + src[1] * 9617 + src[2] * 4899 + (1 << 13)) >> 14; } } } -__global__ void cvt_bgr2gray_32f_kernel(const float* src, float* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_bgr2gray_32f_kernel( + const float* src, float* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * 
blockDim.x + threadIdx.x; if (t < rows * cols) { @@ -204,18 +196,15 @@ __global__ void cvt_bgr2gray_32f_kernel(const float* src, float* dst, float temp_src[3], temp_dst; *((float3*)temp_src) = *((float3*)src); - temp_dst = temp_src[0] * 0.114f + temp_src[1] * 0.587f + - temp_src[2] * 0.299f; + temp_dst = temp_src[0] * 0.114f + temp_src[1] * 0.587f + temp_src[2] * 0.299f; dst[0] = temp_dst; } } - -__global__ void cvt_gray2rgb_8u_kernel(const uchar* src, uchar* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_gray2rgb_8u_kernel( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; if (t < (rows * cols) / U8_PROCESS_PER_THREADS_X) { @@ -258,10 +247,9 @@ __global__ void cvt_gray2rgb_8u_kernel(const uchar* src, uchar* dst, } } -__global__ void cvt_gray2rgb_32f_kernel(const float* src, float* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_gray2rgb_32f_kernel( + const float* src, float* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; if (t < rows * cols) { @@ -281,10 +269,9 @@ __global__ void cvt_gray2rgb_32f_kernel(const float* src, float* dst, #define descale(x, n) (((x) + (1 << ((n)-1))) >> (n)) -__global__ void cvt_rgb2yuv_8u_kernel(const uchar* src, uchar* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_rgb2yuv_8u_kernel( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; const int yuv_shift = 14; @@ -300,7 +287,8 @@ __global__ void cvt_rgb2yuv_8u_kernel(const uchar* src, uchar* dst, *((uint3*)temp_src) = *((uint3*)src); int p = 0; - int y = descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + + int y = + descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + temp_src[2 + p] * coef[2], yuv_shift); int cr = descale((temp_src[0 + p] - y) * coef[3] + delta, yuv_shift); @@ -310,9 +298,10 @@ __global__ void cvt_rgb2yuv_8u_kernel(const uchar* src, uchar* dst, temp_dst[2 + p] = saturate(cb, 0, 255); p += 3; - y = descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + - temp_src[2 + p] * coef[2], - yuv_shift); + y = + descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + + temp_src[2 + p] * coef[2], + yuv_shift); cr = descale((temp_src[0 + p] - y) * coef[3] + delta, yuv_shift); cb = descale((temp_src[2 + p] - y) * coef[4] + delta, yuv_shift); temp_dst[0 + p] = saturate(y, 0, 255); @@ -320,9 +309,10 @@ __global__ void cvt_rgb2yuv_8u_kernel(const uchar* src, uchar* dst, temp_dst[2 + p] = saturate(cb, 0, 255); p += 3; - y = descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + - temp_src[2 + p] * coef[2], - yuv_shift); + y = + descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + + temp_src[2 + p] * coef[2], + yuv_shift); cr = descale((temp_src[0 + p] - y) * coef[3] + delta, yuv_shift); cb = descale((temp_src[2 + p] - y) * coef[4] + delta, yuv_shift); temp_dst[0 + p] = saturate(y, 0, 255); @@ -330,9 +320,10 @@ __global__ void cvt_rgb2yuv_8u_kernel(const uchar* src, uchar* dst, temp_dst[2 + p] = saturate(cb, 0, 255); p += 3; - y = descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + - temp_src[2 + p] * 
coef[2], - yuv_shift); + y = + descale(temp_src[0 + p] * coef[0] + temp_src[1 + p] * coef[1] + + temp_src[2 + p] * coef[2], + yuv_shift); cr = descale((temp_src[0 + p] - y) * coef[3] + delta, yuv_shift); cb = descale((temp_src[2 + p] - y) * coef[4] + delta, yuv_shift); temp_dst[0 + p] = saturate(y, 0, 255); @@ -351,13 +342,12 @@ __global__ void cvt_rgb2yuv_8u_kernel(const uchar* src, uchar* dst, uchar temp_src[3], temp_dst[3]; *((uchar3*)temp_src) = *((uchar3*)src); - int Y = descale(temp_src[0] * coef[0] + temp_src[1] * coef[1] + + int Y = + descale(temp_src[0] * coef[0] + temp_src[1] * coef[1] + temp_src[2] * coef[2], yuv_shift); - int Cr = - descale((temp_src[0] - Y) * coef[3] + delta, yuv_shift); - int Cb = - descale((temp_src[2] - Y) * coef[4] + delta, yuv_shift); + int Cr = descale((temp_src[0] - Y) * coef[3] + delta, yuv_shift); + int Cb = descale((temp_src[2] - Y) * coef[4] + delta, yuv_shift); temp_dst[0] = saturate(Y, 0, 255); temp_dst[1] = saturate(Cr, 0, 255); @@ -369,10 +359,9 @@ __global__ void cvt_rgb2yuv_8u_kernel(const uchar* src, uchar* dst, } } -__global__ void cvt_rgb2yuv_32f_kernel(const float* src, float* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_rgb2yuv_32f_kernel( + const float* src, float* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; const float coef[] = {0.114f, 0.587f, 0.299f, 0.492f, 0.877f}; @@ -386,8 +375,7 @@ __global__ void cvt_rgb2yuv_32f_kernel(const float* src, float* dst, float temp_src[3], temp_dst[3]; *((float3*)temp_src) = *((float3*)src); - float Y = temp_src[0] * coef[0] + temp_src[1] * coef[1] + - temp_src[2] * coef[2]; + float Y = temp_src[0] * coef[0] + temp_src[1] * coef[1] + temp_src[2] * coef[2]; temp_dst[0] = Y; temp_dst[1] = (temp_src[0] - Y) * coef[3] + delta; temp_dst[2] = (temp_src[2] - Y) * coef[4] + delta; @@ -396,10 +384,9 @@ __global__ void cvt_rgb2yuv_32f_kernel(const float* src, float* dst, } } -__global__ void cvt_yuv2rgb_8u_kernel(const uchar* src, uchar* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_yuv2rgb_8u_kernel( + const uchar* src, uchar* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; const int yuv_shift = 14; @@ -417,10 +404,9 @@ __global__ void cvt_yuv2rgb_8u_kernel(const uchar* src, uchar* dst, int p = 0; int R = temp_src[0 + p] + descale((temp_src[1 + p] - delta) * coef[0], yuv_shift); - int G = temp_src[0 + p] + - descale((temp_src[2 + p] - delta) * coef[2] + - (temp_src[1 + p] - delta) * coef[1], - yuv_shift); + int G = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[2] + + (temp_src[1 + p] - delta) * coef[1], + yuv_shift); int B = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[3], yuv_shift); @@ -429,42 +415,33 @@ __global__ void cvt_yuv2rgb_8u_kernel(const uchar* src, uchar* dst, temp_dst[2 + p] = saturate(B, 0, 255); p += 3; - R = temp_src[0 + p] + - descale((temp_src[1 + p] - delta) * coef[0], yuv_shift); - G = temp_src[0 + p] + - descale((temp_src[2 + p] - delta) * coef[2] + - (temp_src[1 + p] - delta) * coef[1], - yuv_shift); - B = temp_src[0 + p] + - descale((temp_src[2 + p] - delta) * coef[3], yuv_shift); + R = temp_src[0 + p] + descale((temp_src[1 + p] - delta) * coef[0], yuv_shift); + G = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[2] 
+ + (temp_src[1 + p] - delta) * coef[1], + yuv_shift); + B = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[3], yuv_shift); temp_dst[0 + p] = saturate(R, 0, 255); temp_dst[1 + p] = saturate(G, 0, 255); temp_dst[2 + p] = saturate(B, 0, 255); p += 3; - R = temp_src[0 + p] + - descale((temp_src[1 + p] - delta) * coef[0], yuv_shift); - G = temp_src[0 + p] + - descale((temp_src[2 + p] - delta) * coef[2] + - (temp_src[1 + p] - delta) * coef[1], - yuv_shift); - B = temp_src[0 + p] + - descale((temp_src[2 + p] - delta) * coef[3], yuv_shift); + R = temp_src[0 + p] + descale((temp_src[1 + p] - delta) * coef[0], yuv_shift); + G = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[2] + + (temp_src[1 + p] - delta) * coef[1], + yuv_shift); + B = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[3], yuv_shift); temp_dst[0 + p] = saturate(R, 0, 255); temp_dst[1 + p] = saturate(G, 0, 255); temp_dst[2 + p] = saturate(B, 0, 255); p += 3; - R = temp_src[0 + p] + - descale((temp_src[1 + p] - delta) * coef[0], yuv_shift); - G = temp_src[0 + p] + - descale((temp_src[2 + p] - delta) * coef[2] + - (temp_src[1 + p] - delta) * coef[1], - yuv_shift); - B = temp_src[0 + p] + - descale((temp_src[2 + p] - delta) * coef[3], yuv_shift); + R = temp_src[0 + p] + descale((temp_src[1 + p] - delta) * coef[0], yuv_shift); + G = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[2] + + (temp_src[1 + p] - delta) * coef[1], + yuv_shift); + B = temp_src[0 + p] + descale((temp_src[2 + p] - delta) * coef[3], yuv_shift); temp_dst[0 + p] = saturate(R, 0, 255); temp_dst[1 + p] = saturate(G, 0, 255); @@ -482,9 +459,8 @@ __global__ void cvt_yuv2rgb_8u_kernel(const uchar* src, uchar* dst, uchar Y = src[0], Cr = src[1], Cb = src[2]; int R = Y + descale((Cr - delta) * coef[0], yuv_shift); - int G = Y + - descale((Cb - delta) * coef[2] + (Cr - delta) * coef[1], - yuv_shift); + int G = Y + descale((Cb - delta) * coef[2] + (Cr - delta) * coef[1], + yuv_shift); int B = Y + descale((Cb - delta) * coef[3], yuv_shift); dst[0] = saturate(R, 0, 255); @@ -495,10 +471,9 @@ __global__ void cvt_yuv2rgb_8u_kernel(const uchar* src, uchar* dst, } } -__global__ void cvt_yuv2rgb_32f_kernel(const float* src, float* dst, - const size_t rows, const size_t cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_yuv2rgb_32f_kernel( + const float* src, float* dst, const size_t rows, const size_t cols, + const size_t src_step, const size_t dst_step) { size_t t = blockIdx.x * blockDim.x + threadIdx.x; const float coef[] = {2.032f, -0.395f, -0.581f, 1.140f}; @@ -524,11 +499,9 @@ __global__ void cvt_yuv2rgb_32f_kernel(const float* src, float* dst, } // convert planar or semi-planar YUV to gray. 
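The integer weights in the *_8u kernels above are the BT.601 grayscale coefficients in Q14 fixed point: 4899, 9617 and 1868 are round(0.299 * 2^14), round(0.587 * 2^14) and round(0.114 * 2^14), and the (1 << 13) term rounds the sum before the shift by 14. A minimal host-side sketch of that arithmetic (only the formula comes from the kernels; the pixel values and the comparison against the float path are illustrative):

// Host-side sketch of the Q14 fixed-point grayscale conversion used by the
// *_8u kernels; example values are made up, only the weights are from the patch.
#include <cstdint>
#include <cstdio>

static uint8_t rgb2gray_q14(uint8_t r, uint8_t g, uint8_t b) {
    // 4899/9617/1868 == round(0.299/0.587/0.114 * 2^14); (1 << 13) rounds
    // before the final shift by 14. The three weights sum to exactly 16384.
    return static_cast<uint8_t>((r * 4899 + g * 9617 + b * 1868 + (1 << 13)) >> 14);
}

static float rgb2gray_f32(float r, float g, float b) {
    return r * 0.299f + g * 0.587f + b * 0.114f;  // same weights, float path
}

int main() {
    const uint8_t r = 200, g = 120, b = 30;  // arbitrary example pixel
    std::printf("fixed point: %u\n", rgb2gray_q14(r, g, b));   // 134
    std::printf("float      : %.2f\n", rgb2gray_f32(r, g, b)); // 133.66
    // The two paths agree to within one gray level; the BGR kernels simply
    // swap the 4899 and 1868 weights, as in cvt_bgr2gray_8u_kernel above.
    return 0;
}
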
data type: uint8 -__global__ void cvt_yuv2gray_psp_8u_kernel(const uchar* src, uchar* dst, - const size_t dst_rows, - const size_t dst_cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_yuv2gray_psp_8u_kernel( + const uchar* src, uchar* dst, const size_t dst_rows, const size_t dst_cols, + const size_t src_step, const size_t dst_step) { int c = (blockIdx.x * blockDim.x + threadIdx.x) * U8_PROCESS_PER_THREADS_X; int r = blockIdx.y * blockDim.y + threadIdx.y; src += r * src_step + c; @@ -544,11 +517,9 @@ __global__ void cvt_yuv2gray_psp_8u_kernel(const uchar* src, uchar* dst, // is_rgb: convert to RGB if true, otherwise convert to BGR // is_nv12: decode src as YUV_NV12 if true, YUV_NV21 otherwise template -__global__ void cvt_yuv2rgbbgr_sp_8u_kernel(const uchar* src, uchar* dst, - const size_t dst_rows, - const size_t dst_cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_yuv2rgbbgr_sp_8u_kernel( + const uchar* src, uchar* dst, const size_t dst_rows, const size_t dst_cols, + const size_t src_step, const size_t dst_step) { int c = (blockIdx.x * blockDim.x + threadIdx.x) * 2; int r = (blockIdx.y * blockDim.y + threadIdx.y) * 2; if (c >= dst_cols || r >= dst_rows) @@ -617,11 +588,9 @@ __global__ void cvt_yuv2rgbbgr_sp_8u_kernel(const uchar* src, uchar* dst, // is_rgb: convert to RGB if true, otherwise convert to BGR // is_nv12: decode src as YUV_NV12 if true, YUV_NV21 otherwise template -__global__ void cvt_yuv2rgbbgr_p_8u_kernel(const uchar* src, uchar* dst, - const size_t dst_rows, - const size_t dst_cols, - const size_t src_step, - const size_t dst_step) { +__global__ void cvt_yuv2rgbbgr_p_8u_kernel( + const uchar* src, uchar* dst, const size_t dst_rows, const size_t dst_cols, + const size_t src_step, const size_t dst_step) { int c = (blockIdx.x * blockDim.x + threadIdx.x) * 2; int r = (blockIdx.y * blockDim.y + threadIdx.y) * 2; if (c >= dst_cols || r >= dst_rows) @@ -686,13 +655,12 @@ __global__ void cvt_yuv2rgbbgr_p_8u_kernel(const uchar* src, uchar* dst, #undef SET_COLOR } -#define CALL_CVT_OPR_8U_KERNEL(_func) \ - { \ - dim3 THREADS(THREADS_X); \ - dim3 BLOCKS(DIVUP(src_cols* src_rows, \ - THREADS_X* U8_PROCESS_PER_THREADS_X)); \ - cvt_##_func##_8u_kernel<<>>( \ - src, dst, src_rows, src_cols, src_step, dst_step); \ +#define CALL_CVT_OPR_8U_KERNEL(_func) \ + { \ + dim3 THREADS(THREADS_X); \ + dim3 BLOCKS(DIVUP(src_cols* src_rows, THREADS_X* U8_PROCESS_PER_THREADS_X)); \ + cvt_##_func##_8u_kernel<<>>( \ + src, dst, src_rows, src_cols, src_step, dst_step); \ } #define CALL_CVT_OPR_32F_KERNEL(_func) \ @@ -704,48 +672,42 @@ __global__ void cvt_yuv2rgbbgr_p_8u_kernel(const uchar* src, uchar* dst, } // convert planar or semi-planar YUV to gray, data tyoe: uint8 -#define CALL_CVT_YUV2GRAY_PSP_OPR_8U_KERNEL \ - { \ - dim3 THREADS(THREADS_X, 1); \ - dim3 BLOCKS(DIVUP(dst_cols, THREADS_X* U8_PROCESS_PER_THREADS_X), \ - dst_rows); \ - cvt_yuv2gray_psp_8u_kernel<<>>( \ - src, dst, dst_rows, dst_cols, src_step, dst_step); \ +#define CALL_CVT_YUV2GRAY_PSP_OPR_8U_KERNEL \ + { \ + dim3 THREADS(THREADS_X, 1); \ + dim3 BLOCKS(DIVUP(dst_cols, THREADS_X* U8_PROCESS_PER_THREADS_X), dst_rows); \ + cvt_yuv2gray_psp_8u_kernel<<>>( \ + src, dst, dst_rows, dst_cols, src_step, dst_step); \ } // convert semi-planar YUV to RGB or BGR. 
data type: uint8 // is_rgb: convert to RGB if true, otherwise convert to BGR // is_nv12: decode src as YUV_NV12 if true, YUV_NV21 otherwise -#define CALL_CVT_YUV2RGBBGR_SP_OPR_8U_KERNEL(is_rgb, is_nv12) \ - { \ - dim3 THREADS(THREADS_X, THREADS_Y); \ - dim3 BLOCKS(DIVUP(dst_cols / 2, THREADS_X), \ - DIVUP(dst_rows / 2, THREADS_Y)); \ - cvt_yuv2rgbbgr_sp_8u_kernel \ - <<>>(src, dst, dst_rows, dst_cols, \ - src_step, dst_step); \ +#define CALL_CVT_YUV2RGBBGR_SP_OPR_8U_KERNEL(is_rgb, is_nv12) \ + { \ + dim3 THREADS(THREADS_X, THREADS_Y); \ + dim3 BLOCKS(DIVUP(dst_cols / 2, THREADS_X), DIVUP(dst_rows / 2, THREADS_Y)); \ + cvt_yuv2rgbbgr_sp_8u_kernel<<>>( \ + src, dst, dst_rows, dst_cols, src_step, dst_step); \ } // convert planar YUV to RGB or BGR. data type: uint8 // is_rgb: convert to RGB if true, otherwise convert to BGR // is_yu12: decode src as YUV_YU12 if true, YUV_YV12 otherwise -#define CALL_CVT_YUV2RGBBGR_P_OPR_8U_KERNEL(is_rgb, is_yu12) \ - { \ - dim3 THREADS(THREADS_X, THREADS_Y); \ - dim3 BLOCKS(DIVUP(dst_cols / 2, THREADS_X), \ - DIVUP(dst_rows / 2, THREADS_Y)); \ - cvt_yuv2rgbbgr_p_8u_kernel \ - <<>>(src, dst, dst_rows, dst_cols, \ - src_step, dst_step); \ +#define CALL_CVT_YUV2RGBBGR_P_OPR_8U_KERNEL(is_rgb, is_yu12) \ + { \ + dim3 THREADS(THREADS_X, THREADS_Y); \ + dim3 BLOCKS(DIVUP(dst_cols / 2, THREADS_X), DIVUP(dst_rows / 2, THREADS_Y)); \ + cvt_yuv2rgbbgr_p_8u_kernel<<>>( \ + src, dst, dst_rows, dst_cols, src_step, dst_step); \ } using namespace param_enumv; -void cvt_color_8u_proxy(const uchar* src, uchar* dst, const size_t src_rows, - const size_t src_cols, const size_t src_step, - const size_t dst_rows, const size_t dst_cols, - const size_t dst_step, const uint32_t mode, - cudaStream_t stream) { +void cvt_color_8u_proxy( + const uchar* src, uchar* dst, const size_t src_rows, const size_t src_cols, + const size_t src_step, const size_t dst_rows, const size_t dst_cols, + const size_t dst_step, const uint32_t mode, cudaStream_t stream) { switch (mode) { case CvtColor::Mode::RGB2GRAY: CALL_CVT_OPR_8U_KERNEL(rgb2gray) @@ -798,11 +760,10 @@ void cvt_color_8u_proxy(const uchar* src, uchar* dst, const size_t src_rows, } } -void cvt_color_32f_proxy(const float* src, float* dst, const size_t src_rows, - const size_t src_cols, const size_t src_step, - const size_t dst_rows, const size_t dst_cols, - const size_t dst_step, const uint32_t mode, - cudaStream_t stream) { +void cvt_color_32f_proxy( + const float* src, float* dst, const size_t src_rows, const size_t src_cols, + const size_t src_step, const size_t dst_rows, const size_t dst_cols, + const size_t dst_step, const uint32_t mode, cudaStream_t stream) { MEGDNN_MARK_USED_VAR(dst_rows); MEGDNN_MARK_USED_VAR(dst_cols); switch (mode) { diff --git a/dnn/src/cuda/cvt_color/cvt_color.cuh b/dnn/src/cuda/cvt_color/cvt_color.cuh index 8b61180b..9ff4724d 100644 --- a/dnn/src/cuda/cvt_color/cvt_color.cuh +++ b/dnn/src/cuda/cvt_color/cvt_color.cuh @@ -69,17 +69,15 @@ namespace cvt_color { typedef unsigned char uchar; -void cvt_color_8u_proxy(const uchar* src, uchar* dst, const size_t src_rows, - const size_t src_cols, const size_t src_step, - const size_t dst_rows, const size_t dst_cols, - const size_t dst_step, const uint32_t mode, - cudaStream_t stream); +void cvt_color_8u_proxy( + const uchar* src, uchar* dst, const size_t src_rows, const size_t src_cols, + const size_t src_step, const size_t dst_rows, const size_t dst_cols, + const size_t dst_step, const uint32_t mode, cudaStream_t stream); -void cvt_color_32f_proxy(const float* src, 
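The CALL_CVT_YUV2RGBBGR_* macros above launch one thread per 2x2 output block (hence the dst_cols / 2 and dst_rows / 2 in the grid), because in YUV420 semi-planar and planar formats four luma samples share one chroma pair. The kernel bodies are elided from the hunks, so what follows is only a sketch of the conventional NV12/NV21 addressing this implies; the helper name and the assumption that both planes use src_step as their row stride are mine, not the patch's:

// Sketch of standard YUV420 semi-planar (NV12/NV21) chroma addressing for the
// 2x2 luma block whose top-left corner is (r, c); r and c are assumed even.
#include <cstddef>
#include <cstdint>

struct ChromaPair { uint8_t cb, cr; };

// src points at the luma plane; the interleaved chroma plane is assumed to
// start rows * src_step bytes later and to share the same row stride.
inline ChromaPair load_chroma_sp(
        const uint8_t* src, size_t rows, size_t src_step, size_t r, size_t c,
        bool is_nv12) {
    const uint8_t* chroma =
            src + rows * src_step + (r / 2) * src_step + (c & ~size_t(1));
    // NV12 interleaves U (Cb) then V (Cr); NV21 stores them the other way round.
    return is_nv12 ? ChromaPair{chroma[0], chroma[1]}
                   : ChromaPair{chroma[1], chroma[0]};
}
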
float* dst, const size_t src_rows, - const size_t src_cols, const size_t src_step, - const size_t dst_rows, const size_t dst_cols, - const size_t dst_step, const uint32_t mode, - cudaStream_t stream); +void cvt_color_32f_proxy( + const float* src, float* dst, const size_t src_rows, const size_t src_cols, + const size_t src_step, const size_t dst_rows, const size_t dst_cols, + const size_t dst_step, const uint32_t mode, cudaStream_t stream); } // namespace cvt_color } // namespace cuda diff --git a/dnn/src/cuda/cvt_color/opr_impl.cpp b/dnn/src/cuda/cvt_color/opr_impl.cpp index d9df4f48..4d5f6c36 100644 --- a/dnn/src/cuda/cvt_color/opr_impl.cpp +++ b/dnn/src/cuda/cvt_color/opr_impl.cpp @@ -14,8 +14,8 @@ #include "src/cuda/utils.h" #include "src/common/cv/common.h" -#include "src/common/cv/helper.h" #include "src/common/cv/cvt_color.h" +#include "src/common/cv/helper.h" #include @@ -25,35 +25,34 @@ namespace cuda { using namespace megcv; using namespace cvt_color; - -void CvtColorImpl::cvt_color_exec_8u(_megdnn_tensor_in src_tensor, - _megdnn_tensor_in dst_tensor) { +void CvtColorImpl::cvt_color_exec_8u( + _megdnn_tensor_in src_tensor, _megdnn_tensor_in dst_tensor) { auto stream = cuda_stream(this->handle()); for (size_t i = 0; i < src_tensor.layout.shape[0]; ++i) { Mat src = TensorND2Mat(src_tensor, i); Mat dst = TensorND2Mat(dst_tensor, i); - cvt_color_8u_proxy(src.ptr(), dst.ptr(), src.rows(), src.cols(), - src.step(), dst.rows(), dst.cols(), dst.step(), - static_cast(param().mode), stream); + cvt_color_8u_proxy( + src.ptr(), dst.ptr(), src.rows(), src.cols(), src.step(), dst.rows(), + dst.cols(), dst.step(), static_cast(param().mode), stream); } } -void CvtColorImpl::cvt_color_exec_32f(_megdnn_tensor_in src_tensor, - _megdnn_tensor_in dst_tensor) { +void CvtColorImpl::cvt_color_exec_32f( + _megdnn_tensor_in src_tensor, _megdnn_tensor_in dst_tensor) { auto stream = cuda_stream(this->handle()); for (size_t i = 0; i < src_tensor.layout.shape[0]; ++i) { Mat src = TensorND2Mat(src_tensor, i); Mat dst = TensorND2Mat(dst_tensor, i); - cvt_color_32f_proxy(src.ptr(), dst.ptr(), src.rows(), src.cols(), - src.step(), dst.rows(), dst.cols(), dst.step(), - static_cast(param().mode), stream); + cvt_color_32f_proxy( + src.ptr(), dst.ptr(), src.rows(), src.cols(), src.step(), dst.rows(), + dst.cols(), dst.step(), static_cast(param().mode), stream); } } -void CvtColorImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { +void CvtColorImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { using namespace megcv; check_exec(src.layout, dst.layout, workspace.size); diff --git a/dnn/src/cuda/cvt_color/opr_impl.h b/dnn/src/cuda/cvt_color/opr_impl.h index a6f2a04f..aa529779 100644 --- a/dnn/src/cuda/cvt_color/opr_impl.h +++ b/dnn/src/cuda/cvt_color/opr_impl.h @@ -22,11 +22,11 @@ private: public: using CvtColor::CvtColor; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, - const TensorLayout&) override { + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { return 0; } }; diff --git a/dnn/src/cuda/dct/dct_channel_select.cu b/dnn/src/cuda/dct/dct_channel_select.cu index d8bd8b3e..9c2f4fc9 100644 --- a/dnn/src/cuda/dct/dct_channel_select.cu +++ b/dnn/src/cuda/dct/dct_channel_select.cu @@ -60,17 +60,16 @@ inline 
__device__ void load_row(float (&row_cache)[8], const uint8_t* src) { row_cache[7] = (float)(((uchar4*)&(row.y))->w); } -inline __device__ void fast_dct_1d_internel(float& src0, float& src1, - float& src2, float& src3, - float& src4, float& src5, - float& src6, float& src7) { +inline __device__ void fast_dct_1d_internel( + float& src0, float& src1, float& src2, float& src3, float& src4, float& src5, + float& src6, float& src7) { constexpr float rsqrt_8 = 0.3535533905932737f; //!< rsqrt_8 = sqrt(1 / 8) - constexpr float a = 1.387039845322148f; //!< a = sqrt2 * cos(pi * 1 / 16) - constexpr float b = 1.306562964876377f; //!< b = sqrt2 * cos(pi * 2 / 16) - constexpr float c = 1.175875602419359f; //!< c = sqrt2 * cos(pi * 3 / 16) - constexpr float d = 0.785694958387102f; //!< d = sqrt2 * cos(pi * 5 / 16) - constexpr float e = 0.541196100146197f; //!< e = sqrt2 * cos(pi * 6 / 16) - constexpr float f = 0.275899379282943f; //!< f = sqrt2 * cos(pi * 7 / 16) + constexpr float a = 1.387039845322148f; //!< a = sqrt2 * cos(pi * 1 / 16) + constexpr float b = 1.306562964876377f; //!< b = sqrt2 * cos(pi * 2 / 16) + constexpr float c = 1.175875602419359f; //!< c = sqrt2 * cos(pi * 3 / 16) + constexpr float d = 0.785694958387102f; //!< d = sqrt2 * cos(pi * 5 / 16) + constexpr float e = 0.541196100146197f; //!< e = sqrt2 * cos(pi * 6 / 16) + constexpr float f = 0.275899379282943f; //!< f = sqrt2 * cos(pi * 7 / 16) const float add_0_7 = src0 + src7; const float add_1_6 = src1 + src6; @@ -98,28 +97,25 @@ inline __device__ void fast_dct_1d_internel(float& src0, float& src1, } inline __device__ void fast_dct_1d(float (&src)[8]) { - fast_dct_1d_internel(src[0], src[1], src[2], src[3], src[4], src[5], src[6], - src[7]); + fast_dct_1d_internel( + src[0], src[1], src[2], src[3], src[4], src[5], src[6], src[7]); } inline __device__ void fast_dct_1d_col(float (&src)[8][8], const int col) { - fast_dct_1d_internel(src[0][col], src[1][col], src[2][col], src[3][col], - src[4][col], src[5][col], src[6][col], src[7][col]); + fast_dct_1d_internel( + src[0][col], src[1][col], src[2][col], src[3][col], src[4][col], + src[5][col], src[6][col], src[7][col]); } -enum class MaskType { - NO_MASK = 0, - USER_DEFINE_MASK = 1, - FIX_32_MASK = 2, - MASK_END -}; -template +enum class MaskType { NO_MASK = 0, USER_DEFINE_MASK = 1, FIX_32_MASK = 2, MASK_END }; +template < + const int dct_block, const int block_oh, const int block_ow, uint32_t format, + MaskType mask_type, typename DstDtype, typename T2> struct StoreMask; -template -struct StoreMask { +template +struct StoreMask< + dct_block, block_oh, block_ow, DctLayoutFormat::NCHW, + MaskType::USER_DEFINE_MASK, float, T2> { static inline __device__ void func( const float (&thread_cache)[dct_block][dct_block], float* dst_tid, const int oc_stride, int channel_idx, const int* mask_offset, @@ -136,13 +132,12 @@ struct StoreMask 0"); + set_async_error_info( + error_info, error_tracker, "nchw sub mask len must > 0"); } for (int store_channel_idx = 0; store_channel_idx < nr_store_channel; ++store_channel_idx) { - const int index = - mask_val[store_channel_offset + store_channel_idx]; + const int index = mask_val[store_channel_offset + store_channel_idx]; dst_tid[store_channel_idx * oc_stride] = shared[index / dct_block][index % dct_block][threadIdx.y] [threadIdx.x]; @@ -150,10 +145,10 @@ struct StoreMask -struct StoreMask { +template +struct StoreMask< + dct_block, block_oh, block_ow, DctLayoutFormat::NCHW4, + MaskType::USER_DEFINE_MASK, int8_t, T2> { static inline __device__ void func( const 
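The constants in fast_dct_1d_internel, rsqrt_8 = sqrt(1/8) and a through f = sqrt(2) * cos(k*pi/16), are the scale factors one would expect from an orthonormal 8-point DCT-II butterfly. Under that assumption (which the constants suggest but the hunk does not state outright), a naive reference transform for cross-checking the fast path on the host could look like this; dct8_reference is a hypothetical helper, not part of the patch:

// Naive orthonormal 8-point DCT-II; O(N^2), for validation only, not the
// butterfly network implemented by fast_dct_1d_internel.
#include <cmath>

void dct8_reference(const float (&src)[8], float (&dst)[8]) {
    const float pi = 3.14159265358979323846f;
    for (int k = 0; k < 8; ++k) {
        float acc = 0.f;
        for (int n = 0; n < 8; ++n)
            acc += src[n] * std::cos(pi * (2 * n + 1) * k / 16.f);
        // Orthonormal scaling: sqrt(1/8) for k == 0, sqrt(2/8) == 0.5 otherwise.
        dst[k] = acc * (k == 0 ? 0.35355339059327373f : 0.5f);
    }
}

Running dct8_reference on each row and then on each column of an 8x8 tile mirrors the fast_dct_1d / fast_dct_1d_col sequence in kern_dct, so it can be used as an oracle in a host-side unit test.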
float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid, const int oc_stride, int channel_idx, const int* mask_offset, @@ -173,47 +168,40 @@ struct StoreMask 0"); + set_async_error_info( + error_info, error_tracker, + "nchw4 sub_mask_len mod 4 should be 0 and " + "sub_mask_len must > 0"); } for (int store_channel_idx = 0; store_channel_idx < nr_store_channel; store_channel_idx += channel_block) { - const int index0 = - mask_val[store_channel_offset + store_channel_idx]; - const int index1 = - mask_val[store_channel_offset + store_channel_idx + 1]; - const int index2 = - mask_val[store_channel_offset + store_channel_idx + 2]; - const int index3 = - mask_val[store_channel_offset + store_channel_idx + 3]; + const int index0 = mask_val[store_channel_offset + store_channel_idx]; + const int index1 = mask_val[store_channel_offset + store_channel_idx + 1]; + const int index2 = mask_val[store_channel_offset + store_channel_idx + 2]; + const int index3 = mask_val[store_channel_offset + store_channel_idx + 3]; const int store_c4_idx = store_channel_idx / channel_block; *(char4*)(&dst_tid[store_c4_idx * channel_block * oc_stride]) = { - quant_param.func( - shared[index0 / dct_block][index0 % dct_block] - [threadIdx.y][threadIdx.x]), - quant_param.func( - shared[index1 / dct_block][index1 % dct_block] - [threadIdx.y][threadIdx.x]), - quant_param.func( - shared[index2 / dct_block][index2 % dct_block] - [threadIdx.y][threadIdx.x]), - quant_param.func( - shared[index3 / dct_block][index3 % dct_block] - [threadIdx.y][threadIdx.x])}; + quant_param.func(shared[index0 / dct_block][index0 % dct_block] + [threadIdx.y][threadIdx.x]), + quant_param.func(shared[index1 / dct_block][index1 % dct_block] + [threadIdx.y][threadIdx.x]), + quant_param.func(shared[index2 / dct_block][index2 % dct_block] + [threadIdx.y][threadIdx.x]), + quant_param.func(shared[index3 / dct_block][index3 % dct_block] + [threadIdx.y][threadIdx.x])}; } } }; -template -struct StoreMask { +template < + const int dct_block, const int block_oh, const int block_ow, uint32_t format, + typename DstDtype, typename T2> +struct StoreMask< + dct_block, block_oh, block_ow, format, MaskType::NO_MASK, DstDtype, T2> { static inline __device__ void func( - const float (&thread_cache)[dct_block][dct_block], - DstDtype* dst_tid, const int oc_stride, int channel_idx, - const int* mask_offset, const int* mask_val, - CudaPostProcess& quant_param, + const float (&thread_cache)[dct_block][dct_block], DstDtype* dst_tid, + const int oc_stride, int channel_idx, const int* mask_offset, + const int* mask_val, CudaPostProcess& quant_param, megcore::AsyncErrorInfo* error_info, void* error_tracker) { constexpr int channel_block = ChannelBlockHelper::channel_block; #pragma unroll @@ -229,10 +217,10 @@ struct StoreMask -struct StoreMask { +template +struct StoreMask< + dct_block, block_oh, block_ow, DctLayoutFormat::NCHW, MaskType::FIX_32_MASK, + float, T2> { static inline __device__ void func( const float (&thread_cache)[dct_block][dct_block], float* dst_tid, const int oc_stride, int channel_idx, const int* mask_offset, @@ -265,25 +253,21 @@ struct StoreMask -struct StoreMask { +template +struct StoreMask< + dct_block, block_oh, block_ow, DctLayoutFormat::NCHW4, MaskType::FIX_32_MASK, + int8_t, T2> { static inline __device__ void func( const float (&thread_cache)[dct_block][dct_block], int8_t* dst_tid, const int oc_stride, int channel_idx, const int* mask_offset, const int* mask_val, CudaPostProcess& quant_param, megcore::AsyncErrorInfo* error_info, void* 
error_tracker) { -#define STORE(store_index, index0, index1, index2, index3) \ - *(char4*)(&dst_tid[store_index * oc_stride]) = { \ - quant_param.func( \ - thread_cache[index0 / dct_block][index0 % dct_block]), \ - quant_param.func( \ - thread_cache[index1 / dct_block][index1 % dct_block]), \ - quant_param.func( \ - thread_cache[index2 / dct_block][index2 % dct_block]), \ - quant_param.func( \ - thread_cache[index3 / dct_block][index3 % dct_block])} +#define STORE(store_index, index0, index1, index2, index3) \ + *(char4*)(&dst_tid[store_index * oc_stride]) = { \ + quant_param.func(thread_cache[index0 / dct_block][index0 % dct_block]), \ + quant_param.func(thread_cache[index1 / dct_block][index1 % dct_block]), \ + quant_param.func(thread_cache[index2 / dct_block][index2 % dct_block]), \ + quant_param.func(thread_cache[index3 / dct_block][index3 % dct_block])} STORE(0, 0, 1, 8, 16); STORE(4, 9, 2, 3, 10); @@ -295,16 +279,14 @@ struct StoreMask -__global__ void kern_dct(const uint8_t* src, DstDtype* dst, const int n, - const int c, const int h, const int w, const int oh, - const int ow, const int oc_stride, const int oc, - const int* mask_offset, const int* mask_val, - CudaPostProcess quant_param, - megcore::AsyncErrorInfo* error_info, - void* error_tracker) { +template < + const int dct_block, MaskType mask_type, const int ker_block_h, + const int ker_block_w, uint32_t format, typename DstDtype, typename T2> +__global__ void kern_dct( + const uint8_t* src, DstDtype* dst, const int n, const int c, const int h, + const int w, const int oh, const int ow, const int oc_stride, const int oc, + const int* mask_offset, const int* mask_val, CudaPostProcess quant_param, + megcore::AsyncErrorInfo* error_info, void* error_tracker) { constexpr int block_oh = ker_block_h / dct_block; constexpr int block_ow = ker_block_w / dct_block; const int channel_stride = h * w; @@ -312,15 +294,13 @@ __global__ void kern_dct(const uint8_t* src, DstDtype* dst, const int n, const int oh_idx = blockIdx.y * block_oh + threadIdx.y; const int ow_idx = blockIdx.x * block_ow + threadIdx.x; float thread_cache[dct_block][dct_block]; - const uint8_t* src_tid = - src + blockIdx.z * channel_stride + - (blockIdx.y * ker_block_h + threadIdx.y * dct_block) * w + - (blockIdx.x * ker_block_w + threadIdx.x * dct_block); + const uint8_t* src_tid = src + blockIdx.z * channel_stride + + (blockIdx.y * ker_block_h + threadIdx.y * dct_block) * w + + (blockIdx.x * ker_block_w + threadIdx.x * dct_block); const int inner_channel_offset = (oh_idx * ow + ow_idx) * ChannelBlockHelper::channel_block; - DstDtype* dst_tid = - dst + blockIdx.z * channel_stride + inner_channel_offset; + DstDtype* dst_tid = dst + blockIdx.z * channel_stride + inner_channel_offset; if (mask_type != MaskType::NO_MASK) { const int batch_idx = blockIdx.z / c; const int batch_stride = oc_stride * oc; @@ -331,8 +311,8 @@ __global__ void kern_dct(const uint8_t* src, DstDtype* dst, const int n, } else { out_channel_offset = mask_offset[oc_idx]; } - dst_tid = dst + batch_idx * batch_stride + - out_channel_offset * oc_stride + inner_channel_offset; + dst_tid = dst + batch_idx * batch_stride + out_channel_offset * oc_stride + + inner_channel_offset; } if (oh_idx < oh && ow_idx < ow) { @@ -365,61 +345,58 @@ __global__ void kern_dct(const uint8_t* src, DstDtype* dst, const int n, fast_dct_1d_col(thread_cache, 6); fast_dct_1d_col(thread_cache, 7); - StoreMask::func(thread_cache, dst_tid, oc_stride, oc_idx, - mask_offset, mask_val, quant_param, error_info, - error_tracker); + 
StoreMask::func( + thread_cache, dst_tid, oc_stride, oc_idx, mask_offset, mask_val, + quant_param, error_info, error_tracker); } } } // namespace template -void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n, - const int c, const int h, const int w, const int oc, - bool fix_32_mask, const int* mask_offset, - const int* mask_val, cudaStream_t stream, - megcore::AsyncErrorInfo* error_info, void* error_tracker, - float scale) { +void call_kern_dct( + const uint8_t* d_src, DstDtype* d_dst, const int n, const int c, const int h, + const int w, const int oc, bool fix_32_mask, const int* mask_offset, + const int* mask_val, cudaStream_t stream, megcore::AsyncErrorInfo* error_info, + void* error_tracker, float scale) { constexpr int ker_block_h = 32; constexpr int ker_block_w = 256; const int oh = h / dct_block; const int ow = w / dct_block; const int oc_stride = oh * ow; const dim3 block_dim(DIVUP(w, ker_block_w), DIVUP(h, ker_block_h), n * c); - const dim3 thread_dim(DIVUP(ker_block_w, dct_block), - DIVUP(ker_block_h, dct_block)); + const dim3 thread_dim(DIVUP(ker_block_w, dct_block), DIVUP(ker_block_h, dct_block)); auto cuda_dtype_param = CudaPostProcess(scale); if (fix_32_mask) { - kern_dct<<>>( - d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, - mask_val, cuda_dtype_param, error_info, error_tracker); + kern_dct + <<>>( + d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, + mask_val, cuda_dtype_param, error_info, error_tracker); } else if (mask_offset && mask_val) { - kern_dct<<>>( - d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, - mask_val, cuda_dtype_param, error_info, error_tracker); + kern_dct< + dct_block, MaskType::USER_DEFINE_MASK, ker_block_h, ker_block_w, format> + <<>>( + d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, + mask_val, cuda_dtype_param, error_info, error_tracker); } else { kern_dct <<>>( - d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, - mask_offset, mask_val, cuda_dtype_param, error_info, - error_tracker); + d_src, d_dst, n, c, h, w, oh, ow, oc_stride, oc, mask_offset, + mask_val, cuda_dtype_param, error_info, error_tracker); } } template void call_kern_dct<8, DctLayoutFormat::NCHW, float>( - const uint8_t* d_src, float* d_dst, const int n, const int c, - const int h, const int w, const int oc, bool fix_32_mask, - const int* mask_offset, const int* mask_val, cudaStream_t stream, - megcore::AsyncErrorInfo* error_info, void* error_tracker, float scale); + const uint8_t* d_src, float* d_dst, const int n, const int c, const int h, + const int w, const int oc, bool fix_32_mask, const int* mask_offset, + const int* mask_val, cudaStream_t stream, megcore::AsyncErrorInfo* error_info, + void* error_tracker, float scale); template void call_kern_dct<8, DctLayoutFormat::NCHW4, int8_t>( - const uint8_t* d_src, int8_t* d_dst, const int n, const int c, - const int h, const int w, const int oc, bool fix_32_mask, - const int* mask_offset, const int* mask_val, cudaStream_t stream, - megcore::AsyncErrorInfo* error_info, void* error_tracker, float scale); + const uint8_t* d_src, int8_t* d_dst, const int n, const int c, const int h, + const int w, const int oc, bool fix_32_mask, const int* mask_offset, + const int* mask_val, cudaStream_t stream, megcore::AsyncErrorInfo* error_info, + void* error_tracker, float scale); } // namespace dct diff --git a/dnn/src/cuda/dct/dct_channel_select.cuh b/dnn/src/cuda/dct/dct_channel_select.cuh index 451dbff0..822314d5 100644 --- a/dnn/src/cuda/dct/dct_channel_select.cuh +++ 
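call_kern_dct above fixes ker_block_h = 32 and ker_block_w = 256 with dct_block = 8, so every thread handles one 8x8 tile and each block runs 32 x 4 threads. A small host-side sketch of that grid arithmetic for a made-up 1x3x720x1280 input (the sizes are illustrative; DIVUP mirrors the macro already used in the file):

// Launch-geometry arithmetic of call_kern_dct for a hypothetical input.
#include <cstdio>

#define DIVUP(x, y) (((x) + (y)-1) / (y))

int main() {
    const int n = 1, c = 3, h = 720, w = 1280;
    const int dct_block = 8, ker_block_h = 32, ker_block_w = 256;

    const int oh = h / dct_block, ow = w / dct_block;  // 90 x 160 DCT tiles
    const int grid_x = DIVUP(w, ker_block_w);          // 5
    const int grid_y = DIVUP(h, ker_block_h);          // 23
    const int grid_z = n * c;                          // 3, one z-slice per channel
    const int thr_x = DIVUP(ker_block_w, dct_block);   // 32
    const int thr_y = DIVUP(ker_block_h, dct_block);   // 4

    std::printf("tiles per channel: %d x %d (oc_stride = %d)\n", oh, ow, oh * ow);
    std::printf("grid = (%d, %d, %d), threads = (%d, %d): one 8x8 tile per thread\n",
                grid_x, grid_y, grid_z, thr_x, thr_y);
    return 0;
}
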
b/dnn/src/cuda/dct/dct_channel_select.cuh @@ -24,12 +24,11 @@ namespace dct { using DctLayoutFormat = megdnn::param_enumv::DctChannelSelect::Format; template -void call_kern_dct(const uint8_t* d_src, DstDtype* d_dst, const int n, - const int c, const int h, const int w, const int oc, - bool fix_32_mask, const int* mask_offset, - const int* mask_val, cudaStream_t stream, - megcore::AsyncErrorInfo* error_info, void* error_tracker, - float scale = 1.f); +void call_kern_dct( + const uint8_t* d_src, DstDtype* d_dst, const int n, const int c, const int h, + const int w, const int oc, bool fix_32_mask, const int* mask_offset, + const int* mask_val, cudaStream_t stream, megcore::AsyncErrorInfo* error_info, + void* error_tracker, float scale = 1.f); } // namespace dct } // namespace cuda diff --git a/dnn/src/cuda/dct/opr_impl.cpp b/dnn/src/cuda/dct/opr_impl.cpp index bd81953c..09e141e8 100644 --- a/dnn/src/cuda/dct/opr_impl.cpp +++ b/dnn/src/cuda/dct/opr_impl.cpp @@ -9,37 +9,35 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ +#include "src/cuda/dct/opr_impl.h" #include "src/common/utils.h" #include "src/cuda/dct/dct_channel_select.cuh" -#include "src/cuda/dct/opr_impl.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" namespace megdnn { namespace cuda { -void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in mask_offset, - _megdnn_tensor_in mask_val, - _megdnn_tensor_out dst, - _megdnn_workspace /*workspace*/) { +void DctChannelSelectForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in mask_offset, + _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, + _megdnn_workspace /*workspace*/) { auto stream = cuda_stream(this->handle()); const int in = src.layout.shape[0]; const int ic = src.layout.shape[1]; const int ih = src.layout.shape[2]; const int iw = src.layout.shape[3]; int oc = dst.layout.shape[1]; - const bool with_fix_32_mask = - param().fastImpl == Param::FastImpl::FIX_32_MASK; + const bool with_fix_32_mask = param().fastImpl == Param::FastImpl::FIX_32_MASK; if (param().format == Param::Format::NCHW4) { - megdnn_assert(dst.layout.ndim == 5 && dst.layout.shape[4] == 4, - "dst must be nchw4"); + megdnn_assert( + dst.layout.ndim == 5 && dst.layout.shape[4] == 4, "dst must be nchw4"); oc = oc * 4; } - megdnn_assert(!with_fix_32_mask || (with_fix_32_mask && oc == 32), - "only support specify mask"); + megdnn_assert( + !with_fix_32_mask || (with_fix_32_mask && oc == 32), + "only support specify mask"); megdnn_assert(param().dct_block_size == 8, "only support dct block = 8"); - auto error_info = - concrete_handle(this->handle())->megcore_context().error_info; + auto error_info = concrete_handle(this->handle())->megcore_context().error_info; constexpr int dct_block = 8; const int* mask_offset_ptr = nullptr; const int* mask_val_ptr = nullptr; @@ -48,21 +46,21 @@ void DctChannelSelectForwardImpl::exec(_megdnn_tensor_in src, mask_val_ptr = mask_val.ptr(); } if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { - megdnn_assert(param().format == Param::Format::NCHW, - "fp32 only support nchw"); + megdnn_assert(param().format == Param::Format::NCHW, "fp32 only support nchw"); dct::call_kern_dct( src.ptr(), dst.ptr(), in, ic, ih, iw, oc, - with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, - error_info, m_error_tracker); + with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, error_info, + m_error_tracker); } else { - megdnn_assert(dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8, - "only support fp32 and 
qs8"); - megdnn_assert(param().format == Param::Format::NCHW4, - "qint8 only support nchw4"); + megdnn_assert( + dst.layout.dtype.enumv() == DTypeEnum::QuantizedS8, + "only support fp32 and qs8"); + megdnn_assert( + param().format == Param::Format::NCHW4, "qint8 only support nchw4"); dct::call_kern_dct( src.ptr(), (int8_t*)dst.raw_ptr, in, ic, ih, iw, oc, - with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, - error_info, m_error_tracker, + with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, error_info, + m_error_tracker, dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale); } } diff --git a/dnn/src/cuda/dct/opr_impl.h b/dnn/src/cuda/dct/opr_impl.h index 38991a34..e41a2489 100644 --- a/dnn/src/cuda/dct/opr_impl.h +++ b/dnn/src/cuda/dct/opr_impl.h @@ -19,19 +19,17 @@ class DctChannelSelectForwardImpl : public DctChannelSelectForward { public: using DctChannelSelectForward::DctChannelSelectForward; void* m_error_tracker = nullptr; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in mask_offset, - _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in mask_offset, + _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& /*src*/, - const TensorLayout& /*mask_offset*/, - const TensorLayout& /*mask_val*/, - const TensorLayout& /*dst*/) override { + size_t get_workspace_in_bytes( + const TensorLayout& /*src*/, const TensorLayout& /*mask_offset*/, + const TensorLayout& /*mask_val*/, const TensorLayout& /*dst*/) override { return 0; }; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; } // namespace cuda diff --git a/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp b/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp index 1efb0c24..0250176f 100644 --- a/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp +++ b/dnn/src/cuda/deformable_conv/bwd_data/algo.cpp @@ -33,9 +33,9 @@ OprImpl::AlgoBase::SizeArgs::SizeArgs( const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad) - : SizeArgs(o, im, - o->make_canonized_filter_meta(im.ndim, filter, offset), - offset, mask, out_grad, im_grad, offset_grad, mask_grad) {} + : SizeArgs( + o, im, o->make_canonized_filter_meta(im.ndim, filter, offset), offset, + mask, out_grad, im_grad, offset_grad, mask_grad) {} OprImpl::AlgoBase::SizeArgs::SizeArgs( OprImpl* o, const TensorLayout& im, const CanonizedFilterMeta& filter, @@ -55,13 +55,13 @@ OprImpl::AlgoBase::SizeArgs::SizeArgs( OprImpl::AlgoBase::ExecArgs::ExecArgs( OprImpl* opr, _megdnn_tensor_in im, _megdnn_tensor_in filter, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, - _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, - _megdnn_workspace ws) - : SizeArgs(opr, im.layout, filter.layout, offset.layout, mask.layout, - out_grad.layout, im_grad.layout, offset_grad.layout, - mask_grad.layout), + _megdnn_tensor_in offset, _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, + _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad, + _megdnn_tensor_out mask_grad, _megdnn_workspace ws) + : SizeArgs( + opr, im.layout, filter.layout, offset.layout, mask.layout, + out_grad.layout, im_grad.layout, offset_grad.layout, + 
mask_grad.layout), im_tensor(im), filter_tensor(filter), offset_tensor(offset), @@ -81,13 +81,11 @@ std::string OprImpl::AlgoBase::SizeArgs::to_string() const { "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, " "dtype=%s,%s", megdnn_layout_msg(im_layout).c_str(), fm.group, fm.ocpg, fm.icpg, - fm.spatial[0], fm.spatial[1], - megdnn_layout_msg(offset_layout).c_str(), + fm.spatial[0], fm.spatial[1], megdnn_layout_msg(offset_layout).c_str(), megdnn_layout_msg(mask_layout).c_str(), - megdnn_layout_msg(out_grad_layout).c_str(), fm.padding[0], - fm.padding[1], fm.stride[0], fm.stride[1], fm.dilation[0], - fm.dilation[1], !fm.should_flip, im_layout.dtype.name(), - out_grad_layout.dtype.name()); + megdnn_layout_msg(out_grad_layout).c_str(), fm.padding[0], fm.padding[1], + fm.stride[0], fm.stride[1], fm.dilation[0], fm.dilation[1], !fm.should_flip, + im_layout.dtype.name(), out_grad_layout.dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/deformable_conv/bwd_data/algo.h b/dnn/src/cuda/deformable_conv/bwd_data/algo.h index da3596f8..d132c612 100644 --- a/dnn/src/cuda/deformable_conv/bwd_data/algo.h +++ b/dnn/src/cuda/deformable_conv/bwd_data/algo.h @@ -49,17 +49,19 @@ public: std::string to_string() const; - SizeArgs(DeformableConvBackwardDataImpl* opr, const TensorLayout& im, - const TensorLayout& filter, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& out_grad, - const TensorLayout& im_grad, const TensorLayout& offset_grad, - const TensorLayout& mask_grad); - - SizeArgs(DeformableConvBackwardDataImpl* opr, const TensorLayout& im, - const CanonizedFilterMeta& filter, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& out_grad, - const TensorLayout& im_grad, const TensorLayout& offset_grad, - const TensorLayout& mask_grad); + SizeArgs( + DeformableConvBackwardDataImpl* opr, const TensorLayout& im, + const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& im_grad, const TensorLayout& offset_grad, + const TensorLayout& mask_grad); + + SizeArgs( + DeformableConvBackwardDataImpl* opr, const TensorLayout& im, + const CanonizedFilterMeta& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& im_grad, const TensorLayout& offset_grad, + const TensorLayout& mask_grad); }; struct ExecArgs : public SizeArgs { const TensorND im_tensor, filter_tensor, offset_tensor, mask_tensor, @@ -67,11 +69,12 @@ public: TensorND im_grad_tensor, offset_grad_tensor, mask_grad_tensor; Workspace workspace; - ExecArgs(DeformableConvBackwardDataImpl* opr, _megdnn_tensor_in im, - _megdnn_tensor_in filter, _megdnn_tensor_in offset, - _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, - _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad, - _megdnn_tensor_out mask_grad, _megdnn_workspace workspace); + ExecArgs( + DeformableConvBackwardDataImpl* opr, _megdnn_tensor_in im, + _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, + _megdnn_tensor_out im_grad, _megdnn_tensor_out offset_grad, + _megdnn_tensor_out mask_grad, _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -86,11 +89,9 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - 
!contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); megdnn_assert( req <= workspace.size, @@ -110,13 +111,10 @@ public: size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) diff --git a/dnn/src/cuda/deformable_conv/bwd_data/algo_matmul.cpp b/dnn/src/cuda/deformable_conv/bwd_data/algo_matmul.cpp index 32579cd9..ddaa1961 100644 --- a/dnn/src/cuda/deformable_conv/bwd_data/algo_matmul.cpp +++ b/dnn/src/cuda/deformable_conv/bwd_data/algo_matmul.cpp @@ -11,10 +11,10 @@ #include "src/cuda/utils.h" +#include "src/common/algo_base.h" #include "src/cuda/deformable_conv/bwd_data/algo.h" #include "src/cuda/deformable_conv/kimpl/deformable_conv.cuh" #include "src/cuda/deformable_conv/opr_impl.h" -#include "src/common/algo_base.h" using namespace megdnn; using namespace cuda; @@ -23,10 +23,9 @@ using Algo = DeformableConvBackwardDataImpl::AlgoMatmul; using OprParam = DeformableConvBase::Param; namespace { -deformable_conv::Param create_param(const Algo::SizeArgs& args, - const OprParam& opr_param, - cublasHandle_t handle, - cudaStream_t stream) { +deformable_conv::Param create_param( + const Algo::SizeArgs& args, const OprParam& opr_param, cublasHandle_t handle, + cudaStream_t stream) { deformable_conv::Param p; auto&& fm = args.filter_meta; @@ -61,14 +60,12 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args, std::pair sub_opr_config( const DeformableConvForwardImpl::CanonizedFilterMeta& fm, - const TensorLayout& im, - const TensorLayout& out_grad) { + const TensorLayout& im, const TensorLayout& out_grad) { auto&& dt = im.dtype; - size_t batch_sz = im[0], OH = out_grad[2], - OW = out_grad[3], FH = fm.spatial[0], FW = fm.spatial[1]; + size_t batch_sz = im[0], OH = out_grad[2], OW = out_grad[3], FH = fm.spatial[0], + FW = fm.spatial[1]; - size_t M = fm.icpg * FH * FW, K = fm.ocpg, N = batch_sz * OH * OW, - batch = fm.group; + size_t M = fm.icpg * FH * FW, K = fm.ocpg, N = batch_sz * OH * OW, batch = fm.group; TensorLayout al = {{batch, K, M}, dt}; TensorLayout bl = {{batch, K, N}, dt}; TensorLayout cl = {{batch, M, N}, dt}; @@ -80,15 +77,14 @@ std::pair sub_opr_config( return {{al, bl, cl}, param}; } -std::pair> -prepare_sub_opr( +std::pair> prepare_sub_opr( const DeformableConvBackwardDataImpl::AlgoBase::SizeArgs& args) { auto bmatmul_opr = args.handle->create_operator(); set_execution_policy( args.opr, bmatmul_opr.get()); - auto&& config = sub_opr_config(args.filter_meta, args.im_layout, - args.out_grad_layout); + auto&& config = + sub_opr_config(args.filter_meta, args.im_layout, args.out_grad_layout); bmatmul_opr->param() = config.second; return {config.first, std::move(bmatmul_opr)}; @@ -106,8 +102,7 @@ std::vector Algo::get_subopr_list( std::string param_str; 
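sub_opr_config above reduces the deformable-conv backward-data pass to one batched GEMM per group, with M = icpg * FH * FW, K = ocpg and N = batch_sz * OH * OW. A sketch of that shape arithmetic with made-up sizes (the numbers are illustrative; the layout triple {batch, K, M} / {batch, K, N} / {batch, M, N} is taken from the hunk):

// Batched-matmul shapes produced by sub_opr_config (backward-data path),
// evaluated for a hypothetical group=2, icpg=8, ocpg=16, 3x3 filter,
// batch 4, 32x32 output.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t group = 2, icpg = 8, ocpg = 16, FH = 3, FW = 3;
    const size_t batch_sz = 4, OH = 32, OW = 32;

    const size_t M = icpg * FH * FW;      // 72
    const size_t K = ocpg;                // 16
    const size_t N = batch_sz * OH * OW;  // 4096
    const size_t batch = group;           // one GEMM per group

    // Layouts handed to BatchedMatrixMulForward, as in sub_opr_config:
    // A {batch, K, M}, B {batch, K, N}, C {batch, M, N}.
    std::printf("A = [%zu, %zu, %zu]  B = [%zu, %zu, %zu]  C = [%zu, %zu, %zu]\n",
                batch, K, M, batch, K, N, batch, M, N);
    return 0;
}

The backward-filter path further below uses the transposed arrangement, M = ocpg, K = OH * OW * batch_sz, N = icpg * FH * FW, so the same BatchedMatrixMul sub-operator serves both gradient computations.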
Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, - config.first}}; + return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, config.first}}; } bool Algo::is_available(const SizeArgs&) const { @@ -139,8 +134,7 @@ void Algo::exec(const ExecArgs& args) const { auto&& opr = args.opr; auto&& handle = concrete_handle(opr->handle()); auto&& param = opr->param(); - auto p = create_param(args, param, handle->cublas_handle(), - handle->stream()); + auto p = create_param(args, param, handle->cublas_handle(), handle->stream()); auto bundle = get_bundle(args); bundle.set(args.workspace.raw_ptr); @@ -161,10 +155,10 @@ void Algo::exec(const ExecArgs& args) const { // clear out grad { size_t im_sz = p.batch_sz * p.IC * p.IH * p.IW * sizeof(float); - size_t offset_sz = p.batch_sz * 2 * p.deformable_group * p.FH * p.FW * - p.OH * p.OW * sizeof(float); - size_t mask_sz = p.batch_sz * p.deformable_group * p.FH * p.FW * p.OH * - p.OW * sizeof(float); + size_t offset_sz = p.batch_sz * 2 * p.deformable_group * p.FH * p.FW * p.OH * + p.OW * sizeof(float); + size_t mask_sz = p.batch_sz * p.deformable_group * p.FH * p.FW * p.OH * p.OW * + sizeof(float); cudaMemsetAsync(dev_im_grad, 0, im_sz, p.stream); cudaMemsetAsync(dev_offset_grad, 0, offset_sz, p.stream); @@ -195,13 +189,12 @@ void Algo::exec(const ExecArgs& args) const { size_t bmm_ws_size = bundle.get_size(0); config.second->exec( - A, B, C, - Workspace(static_cast(bmm_ws), bmm_ws_size)); + A, B, C, Workspace(static_cast(bmm_ws), bmm_ws_size)); } col2im(result_ws, dev_offset, dev_mask, dev_im_grad, p); // col [IC, FH * FW, N, OH * OW] - col2im_coord(dev_im, result_ws, dev_offset, dev_mask, dev_offset_grad, - dev_mask_grad, p); + col2im_coord( + dev_im, result_ws, dev_offset, dev_mask, dev_offset_grad, dev_mask_grad, p); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp b/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp index 0f4b2d3b..865d46b3 100644 --- a/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp +++ b/dnn/src/cuda/deformable_conv/bwd_flt/algo.cpp @@ -28,15 +28,13 @@ MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvBackwardFilterImpl) OprImpl::AlgoPack OprImpl::sm_algo_pack; -OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& out_grad, - const TensorLayout& filter_grad) +OprImpl::AlgoBase::SizeArgs::SizeArgs( + OprImpl* o, const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& filter_grad) : SizeArgs( o, im, offset, mask, out_grad, - o->make_canonized_filter_meta(im.ndim, filter_grad, offset)) { -} + o->make_canonized_filter_meta(im.ndim, filter_grad, offset)) {} OprImpl::AlgoBase::SizeArgs::SizeArgs( OprImpl* o, const TensorLayout& im, const TensorLayout& offset, @@ -50,14 +48,13 @@ OprImpl::AlgoBase::SizeArgs::SizeArgs( out_grad_layout(out_grad), filter_grad_meta(filter_grad_meta) {} -OprImpl::AlgoBase::ExecArgs::ExecArgs(OprImpl* opr, _megdnn_tensor_in im, - _megdnn_tensor_in offset, - _megdnn_tensor_in mask, - _megdnn_tensor_in out_grad, - _megdnn_tensor_out filter_grad, - _megdnn_workspace ws) - : SizeArgs(opr, im.layout, offset.layout, mask.layout, out_grad.layout, - filter_grad.layout), +OprImpl::AlgoBase::ExecArgs::ExecArgs( + OprImpl* opr, _megdnn_tensor_in im, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, + 
_megdnn_tensor_out filter_grad, _megdnn_workspace ws) + : SizeArgs( + opr, im.layout, offset.layout, mask.layout, out_grad.layout, + filter_grad.layout), im_tensor(im), offset_tensor(offset), mask_tensor(mask), @@ -68,18 +65,18 @@ OprImpl::AlgoBase::ExecArgs::ExecArgs(OprImpl* opr, _megdnn_tensor_in im, std::string OprImpl::AlgoBase::SizeArgs::to_string() const { auto&& fm = filter_grad_meta; MEGDNN_MARK_USED_VAR(fm); - return ssprintf("im=%s, offset=%s, mask=%s, dst_grad=%s, " - "filter_grad=%u{%u,%u,%u,%u}," - "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, " - "dtype=%s,%s", - megdnn_layout_msg(im_layout).c_str(), - megdnn_layout_msg(offset_layout).c_str(), - megdnn_layout_msg(mask_layout).c_str(), - megdnn_layout_msg(out_grad_layout).c_str(), fm.group, - fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], - fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], - fm.dilation[0], fm.dilation[1], !fm.should_flip, - im_layout.dtype.name(), out_grad_layout.dtype.name()); + return ssprintf( + "im=%s, offset=%s, mask=%s, dst_grad=%s, " + "filter_grad=%u{%u,%u,%u,%u}," + "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, " + "dtype=%s,%s", + megdnn_layout_msg(im_layout).c_str(), + megdnn_layout_msg(offset_layout).c_str(), + megdnn_layout_msg(mask_layout).c_str(), + megdnn_layout_msg(out_grad_layout).c_str(), fm.group, fm.ocpg, fm.icpg, + fm.spatial[0], fm.spatial[1], fm.padding[0], fm.padding[1], fm.stride[0], + fm.stride[1], fm.dilation[0], fm.dilation[1], !fm.should_flip, + im_layout.dtype.name(), out_grad_layout.dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/deformable_conv/bwd_flt/algo.h b/dnn/src/cuda/deformable_conv/bwd_flt/algo.h index 7a4b9125..0e9bacc2 100644 --- a/dnn/src/cuda/deformable_conv/bwd_flt/algo.h +++ b/dnn/src/cuda/deformable_conv/bwd_flt/algo.h @@ -47,24 +47,27 @@ public: std::string to_string() const; - SizeArgs(DeformableConvBackwardFilterImpl* opr, const TensorLayout& im, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& filter_grad); - - SizeArgs(DeformableConvBackwardFilterImpl* opr, const TensorLayout& im, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, - const CanonizedFilterMeta& filter_grad_meta); + SizeArgs( + DeformableConvBackwardFilterImpl* opr, const TensorLayout& im, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& filter_grad); + + SizeArgs( + DeformableConvBackwardFilterImpl* opr, const TensorLayout& im, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, + const CanonizedFilterMeta& filter_grad_meta); }; struct ExecArgs : public SizeArgs { const TensorND im_tensor, offset_tensor, mask_tensor, out_grad_tensor; TensorND filter_grad_tensor; Workspace workspace; - ExecArgs(DeformableConvBackwardFilterImpl* opr, _megdnn_tensor_in im, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, - _megdnn_workspace workspace); + ExecArgs( + DeformableConvBackwardFilterImpl* opr, _megdnn_tensor_in im, + _megdnn_tensor_in offset, _megdnn_tensor_in mask, + _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -79,16 +82,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = 
std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "deformable_conv bwd_flt algo %s: required workspace %zu " - "bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "deformable_conv bwd_flt algo %s: required workspace %zu " + "bytes, got %zu", + name(), req, workspace.size); return *this; } }; @@ -102,13 +104,10 @@ public: size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) @@ -116,6 +115,7 @@ public: class DeformableConvBackwardFilterImpl::AlgoPack : NonCopyableObj { AlgoBase::Mapper m_all_algos_map; + public: AlgoPack(); diff --git a/dnn/src/cuda/deformable_conv/bwd_flt/algo_matmul.cpp b/dnn/src/cuda/deformable_conv/bwd_flt/algo_matmul.cpp index 7d8d3035..795a7088 100644 --- a/dnn/src/cuda/deformable_conv/bwd_flt/algo_matmul.cpp +++ b/dnn/src/cuda/deformable_conv/bwd_flt/algo_matmul.cpp @@ -12,10 +12,10 @@ #include "src/cuda/utils.h" +#include "src/common/algo_base.h" #include "src/cuda/deformable_conv/bwd_flt/algo.h" #include "src/cuda/deformable_conv/kimpl/deformable_conv.cuh" #include "src/cuda/deformable_conv/opr_impl.h" -#include "src/common/algo_base.h" using namespace megdnn; using namespace cuda; @@ -24,10 +24,9 @@ using Algo = DeformableConvBackwardFilterImpl::AlgoMatmul; using OprParam = DeformableConvBase::Param; namespace { -deformable_conv::Param create_param(const Algo::SizeArgs& args, - const OprParam& opr_param, - cublasHandle_t handle, - cudaStream_t stream) { +deformable_conv::Param create_param( + const Algo::SizeArgs& args, const OprParam& opr_param, cublasHandle_t handle, + cudaStream_t stream) { deformable_conv::Param p; auto&& fm = args.filter_grad_meta; @@ -64,11 +63,10 @@ std::pair sub_opr_config( const DeformableConvBackwardFilterImpl::CanonizedFilterMeta& fm, const TensorLayout& im, const TensorLayout& out_grad) { auto&& dt = im.dtype; - size_t batch_sz = im[0], OH = out_grad[2], OW = out_grad[3], - FH = fm.spatial[0], FW = fm.spatial[1]; + size_t batch_sz = im[0], OH = out_grad[2], OW = out_grad[3], FH = fm.spatial[0], + FW = fm.spatial[1]; - size_t M = fm.ocpg, K = OH * OW * batch_sz, N = fm.icpg * FH * FW, - batch = fm.group; + size_t M = fm.ocpg, K = OH * OW * batch_sz, N = fm.icpg * FH * FW, batch = fm.group; TensorLayout al = {{batch, M, K}, dt}; TensorLayout bl = {{batch, N, K}, dt}; TensorLayout cl = {{batch, M, N}, dt}; @@ -80,15 +78,14 @@ std::pair sub_opr_config( return {{al, bl, cl}, param}; } -std::pair> -prepare_sub_opr( +std::pair> prepare_sub_opr( const DeformableConvBackwardFilterImpl::AlgoBase::SizeArgs& args) { auto bmatmul_opr = args.handle->create_operator(); - 
set_execution_policy(args.opr, bmatmul_opr.get()); + set_execution_policy( + args.opr, bmatmul_opr.get()); - auto&& config = sub_opr_config(args.filter_grad_meta, args.im_layout, - args.out_grad_layout); + auto&& config = + sub_opr_config(args.filter_grad_meta, args.im_layout, args.out_grad_layout); bmatmul_opr->param() = config.second; return {config.first, std::move(bmatmul_opr)}; @@ -106,8 +103,7 @@ std::vector Algo::get_subopr_list( std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, - config.first}}; + return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, config.first}}; } bool Algo::is_available(const SizeArgs&) const { @@ -140,8 +136,7 @@ void Algo::exec(const ExecArgs& args) const { auto&& param = opr->param(); auto&& handle = concrete_handle(opr->handle()); - auto p = create_param(args, param, handle->cublas_handle(), - handle->stream()); + auto p = create_param(args, param, handle->cublas_handle(), handle->stream()); auto bundle = get_bundle(args); bundle.set(args.workspace.raw_ptr); @@ -178,7 +173,6 @@ void Algo::exec(const ExecArgs& args) const { size_t bmm_ws_size = bundle.get_size(2); config.second->exec( - A, B, C, - Workspace(static_cast(bmm_ws), bmm_ws_size)); + A, B, C, Workspace(static_cast(bmm_ws), bmm_ws_size)); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/deformable_conv/fwd/algo.cpp b/dnn/src/cuda/deformable_conv/fwd/algo.cpp index f277f572..dcb7b7f8 100644 --- a/dnn/src/cuda/deformable_conv/fwd/algo.cpp +++ b/dnn/src/cuda/deformable_conv/fwd/algo.cpp @@ -32,20 +32,16 @@ MEGDNN_DEF_GET_ALGO_FROM_DESC(DeformableConvForwardImpl) OprImpl::AlgoPack OprImpl::sm_algo_pack; -OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst) - : SizeArgs(o, im, - o->make_canonized_filter_meta(im.ndim, filter, offset), - offset, mask, dst) {} +OprImpl::AlgoBase::SizeArgs::SizeArgs( + OprImpl* o, const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& dst) + : SizeArgs( + o, im, o->make_canonized_filter_meta(im.ndim, filter, offset), offset, + mask, dst) {} -OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, - const CanonizedFilterMeta& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst) +OprImpl::AlgoBase::SizeArgs::SizeArgs( + OprImpl* o, const TensorLayout& im, const CanonizedFilterMeta& filter, + const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& dst) : opr(o), handle(concrete_handle(o->handle())), im_layout(im), @@ -54,14 +50,13 @@ OprImpl::AlgoBase::SizeArgs::SizeArgs(OprImpl* o, const TensorLayout& im, mask_layout(mask), dst_layout(dst) {} -OprImpl::AlgoBase::ExecArgs::ExecArgs(OprImpl* opr, _megdnn_tensor_in im, - _megdnn_tensor_in filter, - _megdnn_tensor_in offset, - _megdnn_tensor_in mask, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) - : SizeArgs(opr, im.layout, filter.layout, offset.layout, mask.layout, - dst.layout), +OprImpl::AlgoBase::ExecArgs::ExecArgs( + OprImpl* opr, _megdnn_tensor_in im, _megdnn_tensor_in filter, + _megdnn_tensor_in offset, _megdnn_tensor_in mask, _megdnn_tensor_out dst, + _megdnn_workspace workspace) + : SizeArgs( + opr, im.layout, filter.layout, offset.layout, mask.layout, + dst.layout), im_tensor(im), filter_tensor(filter), 
offset_tensor(offset), @@ -75,12 +70,12 @@ std::string OprImpl::AlgoBase::SizeArgs::to_string() const { return ssprintf( "im=%s, filter=%u{%u,%u,%u,%u}, offset=%s, mask=%s, dst=%s, " "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s", - im_layout.to_string().c_str(), fm.group, fm.ocpg, fm.icpg, - fm.spatial[0], fm.spatial[1], offset_layout.to_string().c_str(), + im_layout.to_string().c_str(), fm.group, fm.ocpg, fm.icpg, fm.spatial[0], + fm.spatial[1], offset_layout.to_string().c_str(), mask_layout.to_string().c_str(), dst_layout.to_string().c_str(), - fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], - fm.dilation[0], fm.dilation[1], !fm.should_flip, - im_layout.dtype.name(), dst_layout.dtype.name()); + fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], fm.dilation[0], + fm.dilation[1], !fm.should_flip, im_layout.dtype.name(), + dst_layout.dtype.name()); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/deformable_conv/fwd/algo.h b/dnn/src/cuda/deformable_conv/fwd/algo.h index 4390711a..ce5871f6 100644 --- a/dnn/src/cuda/deformable_conv/fwd/algo.h +++ b/dnn/src/cuda/deformable_conv/fwd/algo.h @@ -44,22 +44,25 @@ public: const TensorLayout& dst_layout; std::string to_string() const; - SizeArgs(DeformableConvForwardImpl* opr, const TensorLayout& im, - const TensorLayout& filter, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& dst); - SizeArgs(DeformableConvForwardImpl* opr, const TensorLayout& im, - const CanonizedFilterMeta& filter, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& dst); + SizeArgs( + DeformableConvForwardImpl* opr, const TensorLayout& im, + const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst); + SizeArgs( + DeformableConvForwardImpl* opr, const TensorLayout& im, + const CanonizedFilterMeta& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst); }; struct ExecArgs : public SizeArgs { const TensorND &im_tensor, filter_tensor, offset_tensor, mask_tensor, dst_tensor; Workspace workspace; - ExecArgs(DeformableConvForwardImpl* opr, _megdnn_tensor_in im, - _megdnn_tensor_in filter, _megdnn_tensor_in offset, - _megdnn_tensor_in mask, _megdnn_tensor_out dst, - _megdnn_workspace workspace); + ExecArgs( + DeformableConvForwardImpl* opr, _megdnn_tensor_in im, + _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_out dst, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -74,16 +77,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "deformable_conv fwd algo %s: required workspace %zu " - "bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "deformable_conv fwd algo %s: required workspace %zu " + "bytes, got %zu", + name(), req, workspace.size); return *this; } }; @@ -97,13 +99,10 @@ public: size_t 
get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) @@ -111,6 +110,7 @@ public: class DeformableConvForwardImpl::AlgoPack : NonCopyableObj { AlgoBase::Mapper m_all_algos_map; + public: AlgoPack(); AlgoMatmul algo_matmul; diff --git a/dnn/src/cuda/deformable_conv/fwd/algo_matmul.cpp b/dnn/src/cuda/deformable_conv/fwd/algo_matmul.cpp index ce32a28f..f3ff5552 100644 --- a/dnn/src/cuda/deformable_conv/fwd/algo_matmul.cpp +++ b/dnn/src/cuda/deformable_conv/fwd/algo_matmul.cpp @@ -11,10 +11,10 @@ #include "src/cuda/handle.h" +#include "src/common/algo_base.h" #include "src/cuda/batched_matrix_mul/algo.h" #include "src/cuda/deformable_conv/fwd/algo.h" #include "src/cuda/deformable_conv/kimpl/deformable_conv.cuh" -#include "src/common/algo_base.h" using namespace megdnn; using namespace cuda; @@ -23,10 +23,9 @@ using Algo = DeformableConvForwardImpl::AlgoMatmul; using OprParam = DeformableConvBase::Param; namespace { -deformable_conv::Param create_param(const Algo::SizeArgs& args, - const OprParam& opr_param, - cublasHandle_t handle, - cudaStream_t stream) { +deformable_conv::Param create_param( + const Algo::SizeArgs& args, const OprParam& opr_param, cublasHandle_t handle, + cudaStream_t stream) { deformable_conv::Param p; auto&& fm = args.filter_meta; @@ -61,14 +60,12 @@ deformable_conv::Param create_param(const Algo::SizeArgs& args, std::pair sub_opr_config( const DeformableConvForwardImpl::CanonizedFilterMeta& fm, - const TensorLayout& im, - const TensorLayout& dst) { + const TensorLayout& im, const TensorLayout& dst) { auto&& dt = im.dtype; - size_t batch_sz = im[0], OH = dst[2], - OW = dst[3], FH = fm.spatial[0], FW = fm.spatial[1]; + size_t batch_sz = im[0], OH = dst[2], OW = dst[3], FH = fm.spatial[0], + FW = fm.spatial[1]; - size_t M = fm.ocpg, N = OH * OW * batch_sz, K = fm.icpg * FH * FW, - batch = fm.group; + size_t M = fm.ocpg, N = OH * OW * batch_sz, K = fm.icpg * FH * FW, batch = fm.group; TensorLayout al = {{batch, M, K}, dt}; TensorLayout bl = {{batch, K, N}, dt}; TensorLayout cl = {{batch, M, N}, dt}; @@ -79,14 +76,13 @@ std::pair sub_opr_config( return {{al, bl, cl}, param}; } -std::pair> -prepare_sub_opr(const DeformableConvForwardImpl::AlgoBase::SizeArgs& args) { +std::pair> prepare_sub_opr( + const DeformableConvForwardImpl::AlgoBase::SizeArgs& args) { auto bmatmul_opr = args.handle->create_operator(); set_execution_policy( args.opr, bmatmul_opr.get()); - auto&& config = - sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout); + auto&& config = sub_opr_config(args.filter_meta, args.im_layout, args.dst_layout); bmatmul_opr->param() = config.second; return {config.first, std::move(bmatmul_opr)}; @@ -104,8 +100,7 @@ std::vector Algo::get_subopr_list( std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); - return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, - config.first}}; + return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str, config.first}}; } bool Algo::is_available(const SizeArgs&) const { @@ -115,8 +110,8 @@ bool 
Algo::is_available(const SizeArgs&) const { WorkspaceBundle Algo::get_bundle(const SizeArgs& args) { auto&& fm = args.filter_meta; size_t batch_sz = args.im_layout[0], IC = fm.group * fm.icpg, - OC = args.dst_layout[1], OH = args.dst_layout[2], - OW = args.dst_layout[3], FH = fm.spatial[0], FW = fm.spatial[1]; + OC = args.dst_layout[1], OH = args.dst_layout[2], OW = args.dst_layout[3], + FH = fm.spatial[0], FW = fm.spatial[1]; auto config = prepare_sub_opr(args); @@ -137,8 +132,7 @@ void Algo::exec(const ExecArgs& args) const { auto&& param = opr->param(); auto&& handle = concrete_handle(opr->handle()); - auto p = create_param(args, param, handle->cublas_handle(), - handle->stream()); + auto p = create_param(args, param, handle->cublas_handle(), handle->stream()); const float* dev_im = args.im_tensor.ptr(); float* dev_filter = args.filter_tensor.ptr(); @@ -153,8 +147,8 @@ void Algo::exec(const ExecArgs& args) const { void* bmm_ws = bundle.get(1); void* result_ws = bundle.get(2); // im2col - deformable_conv::im2col(dev_im, dev_offset, dev_mask, - static_cast(col_ws), p); + deformable_conv::im2col( + dev_im, dev_offset, dev_mask, static_cast(col_ws), p); auto config = prepare_sub_opr(args); @@ -165,8 +159,7 @@ void Algo::exec(const ExecArgs& args) const { size_t bmm_ws_size = bundle.get_size(1); config.second->exec( - A, B, C, - Workspace(static_cast(bmm_ws), bmm_ws_size)); + A, B, C, Workspace(static_cast(bmm_ws), bmm_ws_size)); // relayout auto&& dt = args.im_layout.dtype; size_t dim0 = p.OC, dim1 = p.batch_sz, dim2 = p.OH * p.OW; diff --git a/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cu b/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cu index bea70680..81eaa077 100644 --- a/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cu +++ b/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cu @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "src/cuda/query_blocksize.cuh" #include "src/cuda/deformable_conv/kimpl/deformable_conv.cuh" +#include "src/cuda/query_blocksize.cuh" using namespace megdnn; using namespace cuda; @@ -18,9 +18,9 @@ using namespace deformable_conv; namespace { -__device__ float dmcn_im2col_bilinear(const float* bottom_data, - const int data_width, const int height, - const int width, float h, float w) { +__device__ float dmcn_im2col_bilinear( + const float* bottom_data, const int data_width, const int height, + const int width, float h, float w) { int h_low = floor(h); int w_low = floor(w); int h_high = h_low + 1; @@ -49,11 +49,10 @@ __device__ float dmcn_im2col_bilinear(const float* bottom_data, return val; } -__device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, - const int h, const int w, - const int height, const int width) { - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || - argmax_w >= width) { +__device__ float dmcn_get_gradient_weight( + float argmax_h, float argmax_w, const int h, const int w, const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { return 0; } @@ -74,13 +73,10 @@ __device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, return weight; } -__device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, - const int height, const int width, - const float* im_data, - const int data_width, - const int bp_dir) { - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || - argmax_w >= width) { +__device__ float dmcn_get_coordinate_weight( + float argmax_h, float argmax_w, const int height, const int width, + const float* im_data, const int data_width, const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) { return 0; } @@ -122,8 +118,8 @@ __device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, return weight; } -__global__ void deformable_im2col(Param p, const float* im, const float* offset, - const float* mask, float* col) { +__global__ void deformable_im2col( + Param p, const float* im, const float* offset, const float* mask, float* col) { size_t n = blockIdx.y; const size_t N = p.batch_sz; const size_t loops = p.IC * p.OH * p.OW; @@ -146,17 +142,13 @@ __global__ void deformable_im2col(Param p, const float* im, const float* offset, const float* im_ptr = &im[ic * p.IH * p.IW]; const float* offset_ptr = &offset[(dg * 2 * p.FH * p.FW * p.OH + oh) * p.OW + ow]; - const float* mask_ptr = - &mask[(dg * p.FH * p.FW * p.OH + oh) * p.OW + ow]; - float* col_ptr = - &col[((((ic * p.FH * p.FW) * N + n) * p.OH + oh) * p.OW + ow)]; + const float* mask_ptr = &mask[(dg * p.FH * p.FW * p.OH + oh) * p.OW + ow]; + float* col_ptr = &col[((((ic * p.FH * p.FW) * N + n) * p.OH + oh) * p.OW + ow)]; for (int i = 0; i < p.FH; ++i) for (int j = 0; j < p.FW; ++j) { - const float off_h = - offset_ptr[(2 * (i * p.FW + j)) * p.OH * p.OW]; - const float off_w = - offset_ptr[(2 * (i * p.FW + j) + 1) * p.OH * p.OW]; + const float off_h = offset_ptr[(2 * (i * p.FW + j)) * p.OH * p.OW]; + const float off_w = offset_ptr[(2 * (i * p.FW + j) + 1) * p.OH * p.OW]; const float m = mask_ptr[(i * p.FW + j) * p.OH * p.OW]; float val = 0.f; @@ -169,9 +161,8 @@ __global__ void deformable_im2col(Param p, const float* im, const float* offset, } } -__global__ void deformable_col2im(Param p, const float* col, - const float* offset, const float* mask, - float* im) { +__global__ void deformable_col2im( + Param p, const float* 
col, const float* offset, const float* mask, float* im) { size_t dg = blockIdx.y % p.deformable_group; size_t n = blockIdx.y / p.deformable_group; const size_t loops = p.FH * p.FW * p.OH * p.OW; @@ -194,8 +185,7 @@ __global__ void deformable_col2im(Param p, const float* col, const float* mask_ptr = &mask[dg * p.FH * p.FW * p.OH * p.OW]; const int off_h_idx = ((2 * (fh * p.FW + fw)) * p.OH + oh) * p.OW + ow; - const int off_w_idx = - ((2 * (fh * p.FW + fw) + 1) * p.OH + oh) * p.OW + ow; + const int off_w_idx = ((2 * (fh * p.FW + fw) + 1) * p.OH + oh) * p.OW + ow; const int mask_idx = ((fh * p.FW + fw) * p.OH + oh) * p.OW + ow; const float off_h = offset_ptr[off_h_idx]; @@ -209,8 +199,7 @@ __global__ void deformable_col2im(Param p, const float* col, const int iw = ow * p.SW - p.PW; const int col_idx = - (((((ic * p.FH) + fh) * p.FW + fw) * N + n) * p.OH + oh) * - p.OW + + (((((ic * p.FH) + fh) * p.FW + fw) * N + n) * p.OH + oh) * p.OW + ow; const float top_grad = col[col_idx] * m; @@ -219,16 +208,13 @@ __global__ void deformable_col2im(Param p, const float* col, const int h_hat = (int)h, w_hat = (int)w; #pragma unroll - for (int dy = -2; dy <= 2; - dy++) { // use 0-1 is better, same for dx + for (int dy = -2; dy <= 2; dy++) { // use 0-1 is better, same for dx #pragma unroll for (int dx = -2; dx <= 2; dx++) { - if (h_hat + dy >= 0 && h_hat + dy < p.IH && - w_hat + dx >= 0 && w_hat + dx < p.IW && - abs(h - (h_hat + dy)) < 1 && + if (h_hat + dy >= 0 && h_hat + dy < p.IH && w_hat + dx >= 0 && + w_hat + dx < p.IW && abs(h - (h_hat + dy)) < 1 && abs(w - (w_hat + dx)) < 1) { - int bottom_pos = - (ic * p.IH + h_hat + dy) * p.IW + w_hat + dx; + int bottom_pos = (ic * p.IH + h_hat + dy) * p.IW + w_hat + dx; float weight = dmcn_get_gradient_weight( h, w, h_hat + dy, w_hat + dx, p.IH, p.IW); atomicAdd(&im[bottom_pos], weight * top_grad); @@ -239,9 +225,9 @@ __global__ void deformable_col2im(Param p, const float* col, } } -__global__ void deformable_col2coord(Param p, const float* im, const float* col, - const float* offset, const float* mask, - float* offset_grad, float* mask_grad) { +__global__ void deformable_col2coord( + Param p, const float* im, const float* col, const float* offset, + const float* mask, float* offset_grad, float* mask_grad) { size_t n = blockIdx.y; const size_t N = p.batch_sz; const size_t loops = p.deformable_group * p.FH * p.FW * 2 * p.OH * p.OW; @@ -263,8 +249,7 @@ __global__ void deformable_col2coord(Param p, const float* im, const float* col, const int oh = (idx / 2 / p.OW) % p.OH; const int fw = (idx / 2 / p.OW / p.OH) % p.FW; const int fh = (idx / 2 / p.OW / p.OH / p.FW) % p.FH; - const int dg = - (idx / 2 / p.OW / p.OH / p.FW / p.FH) % p.deformable_group; + const int dg = (idx / 2 / p.OW / p.OH / p.FW / p.FH) % p.deformable_group; const int ih = oh * p.SH - p.PH; const int iw = ow * p.SW - p.PW; @@ -272,14 +257,11 @@ __global__ void deformable_col2coord(Param p, const float* im, const float* col, const float* offset_ptr = &offset[dg * 2 * p.FH * p.FW * p.OH * p.OW]; const float* mask_ptr = &mask[dg * p.FH * p.FW * p.OH * p.OW]; - float* offset_grad_ptr = - &offset_grad[dg * 2 * p.FH * p.FW * p.OH * p.OW]; + float* offset_grad_ptr = &offset_grad[dg * 2 * p.FH * p.FW * p.OH * p.OW]; float* mask_grad_ptr = &mask_grad[dg * p.FH * p.FW * p.OH * p.OW]; - const int offset_h_idx = - ((2 * (fh * p.FW + fw)) * p.OH + oh) * p.OW + ow; - const int offset_w_idx = - ((2 * (fh * p.FW + fw) + 1) * p.OH + oh) * p.OW + ow; + const int offset_h_idx = ((2 * (fh * p.FW + fw)) * p.OH + oh) * 
p.OW + ow; + const int offset_w_idx = ((2 * (fh * p.FW + fw) + 1) * p.OH + oh) * p.OW + ow; const int mask_idx = ((fh * p.FW + fw) * p.OH + oh) * p.OW + ow; const int offset_grad_idx = (hw == 0) ? offset_h_idx : offset_w_idx; @@ -295,25 +277,23 @@ __global__ void deformable_col2coord(Param p, const float* im, const float* col, for (int ic = ic_l; ic < ic_r; ++ic) { const float* im_ptr = &im[ic * p.IH * p.IW]; const int col_idx = - (((((ic * p.FH + fh) * p.FW + fw) * N + n) * p.OH + oh) * - p.OW + + (((((ic * p.FH + fh) * p.FW + fw) * N + n) * p.OH + oh) * p.OW + ow); const float col_grad = col[col_idx]; if (h <= -1 || w <= -1 || h >= p.IH || w >= p.IW) { h = w = -2; } else if (hw % 2 == 0) { - mval += col_grad * - dmcn_im2col_bilinear(im_ptr, p.IW, p.IH, p.IW, h, w); + mval += col_grad * dmcn_im2col_bilinear(im_ptr, p.IW, p.IH, p.IW, h, w); } const float top_grad = col_grad * m; - const float weight = dmcn_get_coordinate_weight(h, w, p.IH, p.IW, - im_ptr, p.IW, hw); + const float weight = + dmcn_get_coordinate_weight(h, w, p.IH, p.IW, im_ptr, p.IW, hw); val += weight * top_grad; } offset_grad_ptr[offset_grad_idx] = val; - if (hw % 2 ==0) { + if (hw % 2 == 0) { mask_grad_ptr[mask_idx] = mval; } } @@ -325,36 +305,38 @@ namespace megdnn { namespace cuda { namespace deformable_conv { -void im2col(const float* dev_im, const float* dev_offset, const float* dev_mask, - float* dev_col, const Param& p) { +void im2col( + const float* dev_im, const float* dev_offset, const float* dev_mask, + float* dev_col, const Param& p) { dim3 grid; size_t loops = p.IC * p.OH * p.OW; int nr_thds = query_blocksize_for_kernel(deformable_im2col); grid.x = DIVUP(loops, nr_thds), grid.y = p.batch_sz; - deformable_im2col<<>>(p, dev_im, dev_offset, - dev_mask, dev_col); + deformable_im2col<<>>( + p, dev_im, dev_offset, dev_mask, dev_col); after_kernel_launch(); } -void col2im(const float* dev_col, const float* dev_offset, - const float* dev_mask, float* dev_im_grad, const Param& p) { +void col2im( + const float* dev_col, const float* dev_offset, const float* dev_mask, + float* dev_im_grad, const Param& p) { dim3 grid; size_t loops = p.FH * p.FW * p.OH * p.OW; int nr_thds = query_blocksize_for_kernel(deformable_col2im); grid.x = DIVUP(loops, nr_thds), grid.y = p.batch_sz * p.deformable_group; - deformable_col2im<<>>(p, dev_col, dev_offset, - dev_mask, dev_im_grad); + deformable_col2im<<>>( + p, dev_col, dev_offset, dev_mask, dev_im_grad); after_kernel_launch(); } -void col2im_coord(const float* dev_im, const float* dev_col, - const float* dev_offset, const float* dev_mask, - float* dev_offset_grad, float* dev_mask_grad, - const Param& p) { +void col2im_coord( + const float* dev_im, const float* dev_col, const float* dev_offset, + const float* dev_mask, float* dev_offset_grad, float* dev_mask_grad, + const Param& p) { dim3 grid; size_t loops = 2 * p.FH * p.FW * p.OH * p.OW * p.deformable_group; int nr_thds = query_blocksize_for_kernel(deformable_col2coord); @@ -363,8 +345,7 @@ void col2im_coord(const float* dev_im, const float* dev_col, grid.y = p.batch_sz; deformable_col2coord<<>>( - p, dev_im, dev_col, dev_offset, dev_mask, dev_offset_grad, - dev_mask_grad); + p, dev_im, dev_col, dev_offset, dev_mask, dev_offset_grad, dev_mask_grad); after_kernel_launch(); } diff --git a/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cuh b/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cuh index faeedf9e..1e2b3d9f 100644 --- a/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cuh +++ 
b/dnn/src/cuda/deformable_conv/kimpl/deformable_conv.cuh @@ -36,15 +36,18 @@ struct Param { cublasHandle_t handle; }; -void im2col(const float* dev_im, const float* dev_offset, const float* dev_mask, - float* dev_col, const Param& p); - -void col2im(const float* dev_col, const float* dev_offset, - const float* dev_mask, float* dev_im_grad, const Param& p); - -void col2im_coord(const float* dev_im, const float* dev_col, - const float* dev_offset, const float* dev_mask, - float* dev_offset_grad, float* mask_grad, const Param& p); +void im2col( + const float* dev_im, const float* dev_offset, const float* dev_mask, + float* dev_col, const Param& p); + +void col2im( + const float* dev_col, const float* dev_offset, const float* dev_mask, + float* dev_im_grad, const Param& p); + +void col2im_coord( + const float* dev_im, const float* dev_col, const float* dev_offset, + const float* dev_mask, float* dev_offset_grad, float* mask_grad, + const Param& p); } // namespace deformable_conv } // namespace cuda diff --git a/dnn/src/cuda/deformable_conv/opr_impl.cpp b/dnn/src/cuda/deformable_conv/opr_impl.cpp index 909f1e05..b594d088 100644 --- a/dnn/src/cuda/deformable_conv/opr_impl.cpp +++ b/dnn/src/cuda/deformable_conv/opr_impl.cpp @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/cuda/deformable_conv/fwd/algo.h" -#include "src/cuda/deformable_conv/bwd_flt/algo.h" #include "src/cuda/deformable_conv/bwd_data/algo.h" +#include "src/cuda/deformable_conv/bwd_flt/algo.h" +#include "src/cuda/deformable_conv/fwd/algo.h" #include "src/common/algo_chooser.h" #include "src/common/utils.h" @@ -31,19 +31,16 @@ using AlgoBwdData = BwdData::Algorithm; /* ============== Fwd Implementation ============== */ -size_t Fwd::get_workspace_in_bytes(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst) { +size_t Fwd::get_workspace_in_bytes( + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst) { return get_dnn_workspace(this, im, filter, offset, mask, dst); } -std::vector Fwd::get_all_algorithms(const TensorLayout& /* im */, - const TensorLayout& /* filter */, - const TensorLayout& /* offset */, - const TensorLayout& /* mask */, - const TensorLayout& /* dst */) { +std::vector Fwd::get_all_algorithms( + const TensorLayout& /* im */, const TensorLayout& /* filter */, + const TensorLayout& /* offset */, const TensorLayout& /* mask */, + const TensorLayout& /* dst */) { std::vector algos; for (auto i : sm_algo_pack.all_algos) @@ -51,63 +48,56 @@ std::vector Fwd::get_all_algorithms(const TensorLayout& /* im */, return algos; } -std::vector Fwd::get_all_algorithms_safe(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst) { - auto ret_safe = Fwd::get_all_algorithms(im,filter,offset,mask,dst); +std::vector Fwd::get_all_algorithms_safe( + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst) { + auto ret_safe = Fwd::get_all_algorithms(im, filter, offset, mask, dst); megdnn_assert(!ret_safe.empty(), "no usable deformable_conv fwd algorithm"); return ret_safe; } -AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst, - 
size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +AlgoFwd* Fwd::get_algorithm_heuristic( + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { auto fm = make_canonized_filter_meta(im.ndim, filter, offset); - return get_algorithm_heuristic(im, fm, offset, mask, dst, - workspace_limit_in_bytes, positive_attr, - negative_attr); + return get_algorithm_heuristic( + im, fm, offset, mask, dst, workspace_limit_in_bytes, positive_attr, + negative_attr); } -AlgoFwd* Fwd::get_algorithm_heuristic(const TensorLayout& im, - const CanonizedFilterMeta& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +AlgoFwd* Fwd::get_algorithm_heuristic( + const TensorLayout& im, const CanonizedFilterMeta& filter, + const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, im, filter, offset, mask, dst); if (sm_algo_pack.algo_matmul.is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.algo_matmul; } - megdnn_throw( - ssprintf("no deformable conv fwd algorithm without attribute(%s) " - "with attribute(%s) , args(%s) and " - "workspace limit (%zu bytes)", - Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); + megdnn_throw(ssprintf( + "no deformable conv fwd algorithm without attribute(%s) " + "with attribute(%s) , args(%s) and " + "workspace limit (%zu bytes)", + Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); } const char* Fwd::get_algorithm_set_name() const { return "DEFORMABLE_CONV_FWD_CUDA"; }; -void Fwd::exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_out out, _megdnn_workspace workspace) { - check_exec(im.layout, filter.layout, offset.layout, mask.layout, out.layout, - workspace.size); - auto algo = get_algorithm(this, im.layout, filter.layout, offset.layout, - mask.layout, out.layout); +void Fwd::exec( + _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_out out, _megdnn_workspace workspace) { + check_exec( + im.layout, filter.layout, offset.layout, mask.layout, out.layout, + workspace.size); + auto algo = get_algorithm( + this, im.layout, filter.layout, offset.layout, mask.layout, out.layout); AlgoBase::ExecArgs args(this, im, filter, offset, mask, out, workspace); algo->exec(args); @@ -115,54 +105,53 @@ void Fwd::exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, /* ============== BwdFlt Implementation ============== */ -std::vector BwdFlt::get_all_algorithms(const TensorLayout& /* im */, - const TensorLayout& /* offset */, const TensorLayout& /* mask */, - const TensorLayout& /* out_grad */, const TensorLayout& /* filter_grad */) { +std::vector BwdFlt::get_all_algorithms( + const TensorLayout& /* im */, const TensorLayout& /* offset */, + const 
TensorLayout& /* mask */, const TensorLayout& /* out_grad */, + const TensorLayout& /* filter_grad */) { std::vector algos; for (auto i : sm_algo_pack.all_algos) algos.push_back(static_cast(i)); return algos; } -std::vector BwdFlt::get_all_algorithms_safe(const TensorLayout& im, - const TensorLayout& offset, const TensorLayout& mask, +std::vector BwdFlt::get_all_algorithms_safe( + const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& filter_grad) { - auto ret_safe = BwdFlt::get_all_algorithms(im,offset,mask,out_grad,filter_grad); + auto ret_safe = BwdFlt::get_all_algorithms(im, offset, mask, out_grad, filter_grad); megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd filter algorithm"); return ret_safe; } AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( - const TensorLayout& im, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& out_grad, - const TensorLayout& filter_grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, + const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& filter_grad, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { auto fm = make_canonized_filter_meta(im.ndim, filter_grad, offset); - return get_algorithm_heuristic(im, offset, mask, out_grad, fm, - workspace_limit_in_bytes, positive_attr, - negative_attr); + return get_algorithm_heuristic( + im, offset, mask, out_grad, fm, workspace_limit_in_bytes, positive_attr, + negative_attr); } AlgoBwdFlt* BwdFlt::get_algorithm_heuristic( - const TensorLayout& im, const TensorLayout& offset, - const TensorLayout& mask, const TensorLayout& out_grad, - const CanonizedFilterMeta& filter_grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, + const TensorLayout& im, const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const CanonizedFilterMeta& filter_grad, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, im, offset, mask, out_grad, filter_grad); if (sm_algo_pack.algo_matmul.is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.algo_matmul; } - megdnn_throw( - ssprintf("no deformable conv bwd filter algorithm without " - "attribute(%s) with " - "attribute(%s), args(%s) and " - "workspace limit (%zu bytes)", - Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); + megdnn_throw(ssprintf( + "no deformable conv bwd filter algorithm without " + "attribute(%s) with " + "attribute(%s), args(%s) and " + "workspace limit (%zu bytes)", + Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); } size_t BwdFlt::get_workspace_in_bytes( @@ -175,15 +164,17 @@ const char* BwdFlt::get_algorithm_set_name() const { return "DEFORMABLE_CONV_BWD_FILTER_CUDA"; }; -void BwdFlt::exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, - _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, - _megdnn_tensor_out filter_grad, _megdnn_workspace workspace) { - check_exec(im.layout, offset.layout, mask.layout, out_grad.layout, - filter_grad.layout, workspace.size); - AlgoBase::ExecArgs args(this, im, offset, mask, 
out_grad, filter_grad, - workspace); - auto algo = get_algorithm(this, im.layout, offset.layout, mask.layout, - out_grad.layout, filter_grad.layout); +void BwdFlt::exec( + _megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask, + _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, + _megdnn_workspace workspace) { + check_exec( + im.layout, offset.layout, mask.layout, out_grad.layout, filter_grad.layout, + workspace.size); + AlgoBase::ExecArgs args(this, im, offset, mask, out_grad, filter_grad, workspace); + auto algo = get_algorithm( + this, im.layout, offset.layout, mask.layout, out_grad.layout, + filter_grad.layout); algo->exec(args); } @@ -191,29 +182,31 @@ void BwdFlt::exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, std::vector BwdData::get_all_algorithms( const TensorLayout& /* im */, const TensorLayout& /* filter */, - const TensorLayout& /* offset */, const TensorLayout& /* mask */, const TensorLayout& /* out_grad */, - const TensorLayout& /* im_grad */, const TensorLayout& /* offset_grad */, const TensorLayout& /* mask_grad */) { + const TensorLayout& /* offset */, const TensorLayout& /* mask */, + const TensorLayout& /* out_grad */, const TensorLayout& /* im_grad */, + const TensorLayout& /* offset_grad */, const TensorLayout& /* mask_grad */) { std::vector algos; for (auto i : sm_algo_pack.all_algos) algos.push_back(static_cast(i)); return algos; } std::vector BwdData::get_all_algorithms_safe( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, - const TensorLayout& im_grad, const TensorLayout& offset_grad, const TensorLayout& mask_grad ) { - auto ret_safe = BwdData::get_all_algorithms(im,filter,offset,mask,out_grad,im_grad,offset_grad,mask_grad); + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& im_grad, const TensorLayout& offset_grad, + const TensorLayout& mask_grad) { + auto ret_safe = BwdData::get_all_algorithms( + im, filter, offset, mask, out_grad, im_grad, offset_grad, mask_grad); megdnn_assert(!ret_safe.empty(), "no usable deformable_conv bwd data algorithm"); return ret_safe; } AlgoBwdData* BwdData::get_algorithm_heuristic( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, const TensorLayout& mask_grad, - size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& im_grad, const TensorLayout& offset_grad, + const TensorLayout& mask_grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { auto fm = make_canonized_filter_meta(im.ndim, filter, offset); return get_algorithm_heuristic( im, fm, offset, mask, out_grad, im_grad, offset_grad, mask_grad, @@ -227,48 +220,49 @@ AlgoBwdData* BwdData::get_algorithm_heuristic( const TensorLayout& offset_grad, const TensorLayout& mask_grad, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { - AlgoBase::SizeArgs args(this, im, filter, offset, mask, out_grad, im_grad, - offset_grad, mask_grad); + AlgoBase::SizeArgs args( + this, im, filter, 
offset, mask, out_grad, im_grad, offset_grad, mask_grad); if (sm_algo_pack.algo_matmul.is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.algo_matmul; } - megdnn_throw( - ssprintf("no deformable conv bwd data algorithm without " - "attribute(%s) with attribute(%s), " - "args(%s) and " - "workspace limit (%zu bytes)", - Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); + megdnn_throw(ssprintf( + "no deformable conv bwd data algorithm without " + "attribute(%s) with attribute(%s), " + "args(%s) and " + "workspace limit (%zu bytes)", + Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); } size_t BwdData::get_workspace_in_bytes( - const TensorLayout& im, const TensorLayout& filter, - const TensorLayout& offset, const TensorLayout& mask, - const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, const TensorLayout& mask_grad) { - return get_dnn_workspace(this, im, filter, offset, mask, out_grad, im_grad, - offset_grad, mask_grad); + const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& im_grad, const TensorLayout& offset_grad, + const TensorLayout& mask_grad) { + return get_dnn_workspace( + this, im, filter, offset, mask, out_grad, im_grad, offset_grad, mask_grad); } const char* BwdData::get_algorithm_set_name() const { return "DEFORMABLE_CONV2_BWD_DATA_CUDA"; }; -void BwdData::exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, - _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, - _megdnn_workspace workspace) { - check_exec(im.layout, filter.layout, offset.layout, mask.layout, - out_grad.layout, im_grad.layout, offset_grad.layout, - mask_grad.layout, workspace.size); - AlgoBase::ExecArgs args(this, im, filter, offset, mask, out_grad, im_grad, - offset_grad, mask_grad, workspace); - auto algo = get_algorithm(this, im.layout, filter.layout, offset.layout, - mask.layout, out_grad.layout, im_grad.layout, - offset_grad.layout, mask_grad.layout); +void BwdData::exec( + _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, + _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, + _megdnn_workspace workspace) { + check_exec( + im.layout, filter.layout, offset.layout, mask.layout, out_grad.layout, + im_grad.layout, offset_grad.layout, mask_grad.layout, workspace.size); + AlgoBase::ExecArgs args( + this, im, filter, offset, mask, out_grad, im_grad, offset_grad, mask_grad, + workspace); + auto algo = get_algorithm( + this, im.layout, filter.layout, offset.layout, mask.layout, out_grad.layout, + im_grad.layout, offset_grad.layout, mask_grad.layout); algo->exec(args); } diff --git a/dnn/src/cuda/deformable_conv/opr_impl.h b/dnn/src/cuda/deformable_conv/opr_impl.h index 29b0589b..298be335 100644 --- a/dnn/src/cuda/deformable_conv/opr_impl.h +++ b/dnn/src/cuda/deformable_conv/opr_impl.h @@ -20,24 +20,21 @@ class DeformableConvForwardImpl : public DeformableConvForward { public: using DeformableConvForward::DeformableConvForward; - void exec(_megdnn_tensor_in im, 
_megdnn_tensor_in filter, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - - size_t get_workspace_in_bytes(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst) override; - - Algorithm* get_algorithm_heuristic(const TensorLayout& im, - const CanonizedFilterMeta& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr); + void exec( + _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& dst) override; + + Algorithm* get_algorithm_heuristic( + const TensorLayout& im, const CanonizedFilterMeta& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& dst, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr); const char* get_algorithm_set_name() const override; @@ -74,25 +71,21 @@ class DeformableConvBackwardFilterImpl : public DeformableConvBackwardFilter { public: using DeformableConvBackwardFilter::DeformableConvBackwardFilter; - void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, - _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, - _megdnn_tensor_out filter_grad, - _megdnn_workspace workspace) override; - - Algorithm* get_algorithm_heuristic(const TensorLayout& im, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& out_grad, - const CanonizedFilterMeta& filter_grad, - size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr); - - size_t get_workspace_in_bytes(const TensorLayout& im, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& out_grad, - const TensorLayout& filter_grad) override; + void exec( + _megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask, + _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, + _megdnn_workspace workspace) override; + + Algorithm* get_algorithm_heuristic( + const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const CanonizedFilterMeta& filter_grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr); + + size_t get_workspace_in_bytes( + const TensorLayout& im, const TensorLayout& offset, + const TensorLayout& mask, const TensorLayout& out_grad, + const TensorLayout& filter_grad) override; const char* get_algorithm_set_name() const override; @@ -129,11 +122,11 @@ class DeformableConvBackwardDataImpl : public DeformableConvBackwardData { public: using DeformableConvBackwardData::DeformableConvBackwardData; - void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, - _megdnn_tensor_in offset, _megdnn_tensor_in mask, - _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, - _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset, + _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, + _megdnn_tensor_out im_grad, 
_megdnn_tensor_out offset_grad, + _megdnn_tensor_out mask_grad, _megdnn_workspace workspace) override; Algorithm* get_algorithm_heuristic( const TensorLayout& im, const CanonizedFilterMeta& filter, @@ -143,14 +136,11 @@ public: size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr); - size_t get_workspace_in_bytes(const TensorLayout& im, - const TensorLayout& filter, - const TensorLayout& offset, - const TensorLayout& mask, - const TensorLayout& out_grad, - const TensorLayout& im_grad, - const TensorLayout& offset_grad, - const TensorLayout& mask_grad) override; + size_t get_workspace_in_bytes( + const TensorLayout& im, const TensorLayout& filter, + const TensorLayout& offset, const TensorLayout& mask, + const TensorLayout& out_grad, const TensorLayout& im_grad, + const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; const char* get_algorithm_set_name() const override; @@ -167,15 +157,13 @@ protected: const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, - const TensorLayout& mask_grad) override; - + const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; + std::vector get_all_algorithms_safe( const TensorLayout& im, const TensorLayout& filter, const TensorLayout& offset, const TensorLayout& mask, const TensorLayout& out_grad, const TensorLayout& im_grad, - const TensorLayout& offset_grad, - const TensorLayout& mask_grad) override; + const TensorLayout& offset_grad, const TensorLayout& mask_grad) override; Algorithm* get_algorithm_heuristic( const TensorLayout& im, const TensorLayout& filter, diff --git a/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cu b/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cu index 20e53727..20dba466 100644 --- a/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cu +++ b/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cu @@ -16,8 +16,8 @@ namespace { using Param = megdnn::cuda::deformable_ps_roi_pooling::Param; -__device__ float bilinear_interp(const float* data, const int IH, const int IW, - const float h, const float w) { +__device__ float bilinear_interp( + const float* data, const int IH, const int IW, const float h, const float w) { int h1 = floor(h), h2 = ceil(h); int w1 = floor(w), w2 = ceil(w); float dist_h = (float)(h - h1); @@ -27,16 +27,14 @@ __device__ float bilinear_interp(const float* data, const int IH, const int IW, float value21 = data[h1 * IW + w2]; float value22 = data[h2 * IW + w2]; float value = (1 - dist_w) * (1 - dist_h) * value11 + - (1 - dist_w) * dist_h * value12 + - dist_w * (1 - dist_h) * value21 + dist_w * dist_h * value22; + (1 - dist_w) * dist_h * value12 + dist_w * (1 - dist_h) * value21 + + dist_w * dist_h * value22; return value; } -__global__ void DeformablePSROIPoolForwardKern(Param p, const float* data, - const float* rois, - const float* trans, - float* out_data, - float* out_count) { +__global__ void DeformablePSROIPoolForwardKern( + Param p, const float* data, const float* rois, const float* trans, + float* out_data, float* out_count) { const int loops = p.nr_bbox * p.IC * p.pool_h * p.pool_w; const int icpcls = p.IC / p.nr_cls; @@ -51,10 +49,8 @@ __global__ void DeformablePSROIPoolForwardKern(Param p, const float* data, float roi_w_l = static_cast(round(rois_ptr[1])) * p.scale - 0.5; float roi_h_l = static_cast(round(rois_ptr[2])) * p.scale - 0.5; - float roi_w_r = - 
static_cast(round(rois_ptr[3]) + 1.) * p.scale - 0.5; - float roi_h_r = - static_cast(round(rois_ptr[4]) + 1.) * p.scale - 0.5; + float roi_w_r = static_cast(round(rois_ptr[3]) + 1.) * p.scale - 0.5; + float roi_h_r = static_cast(round(rois_ptr[4]) + 1.) * p.scale - 0.5; // Force too small ROIs to be 1x1 float roi_w = max(roi_w_r - roi_w_l, 0.1); // avoid 0 @@ -76,13 +72,12 @@ __global__ void DeformablePSROIPoolForwardKern(Param p, const float* data, if (!p.no_trans) { int part_h = floor(static_cast(ph) / p.pool_h * p.part_sz); int part_w = floor(static_cast(pw) / p.pool_w * p.part_sz); - int x_idx = (((n * p.nr_cls + cls_id) * 2) * p.part_sz + part_h) * + int x_idx = + (((n * p.nr_cls + cls_id) * 2) * p.part_sz + part_h) * p.part_sz + + part_w; + int y_idx = (((n * p.nr_cls + cls_id) * 2 + 1) * p.part_sz + part_h) * p.part_sz + part_w; - int y_idx = - (((n * p.nr_cls + cls_id) * 2 + 1) * p.part_sz + part_h) * - p.part_sz + - part_w; trans_x = trans[x_idx] * static_cast(p.trans_std); trans_y = trans[y_idx] * static_cast(p.trans_std); } @@ -90,8 +85,7 @@ __global__ void DeformablePSROIPoolForwardKern(Param p, const float* data, wstart += trans_x * roi_w; hstart += trans_y * roi_h; - const float* data_ptr = - data + (roi_batch_idx * p.IC + ic) * p.IH * p.IW; + const float* data_ptr = data + (roi_batch_idx * p.IC + ic) * p.IH * p.IW; for (int ih = 0; ih < p.sample_per_part; ih++) { for (int iw = 0; iw < p.sample_per_part; iw++) { @@ -130,10 +124,8 @@ __global__ void DeformablePSROIPoolBackwardAccKern( float roi_w_l = static_cast(round(rois_ptr[1])) * p.scale - 0.5; float roi_h_l = static_cast(round(rois_ptr[2])) * p.scale - 0.5; - float roi_w_r = - static_cast(round(rois_ptr[3]) + 1.) * p.scale - 0.5; - float roi_h_r = - static_cast(round(rois_ptr[4]) + 1.) * p.scale - 0.5; + float roi_w_r = static_cast(round(rois_ptr[3]) + 1.) * p.scale - 0.5; + float roi_h_r = static_cast(round(rois_ptr[4]) + 1.) 
* p.scale - 0.5; // Force too small ROIs to be 1x1 float roi_w = max(roi_w_r - roi_w_l, 0.1); // avoid 0 @@ -154,13 +146,12 @@ __global__ void DeformablePSROIPoolBackwardAccKern( if (!p.no_trans) { part_h = floor(static_cast(ph) / p.pool_h * p.part_sz); part_w = floor(static_cast(pw) / p.pool_w * p.part_sz); - int x_idx = (((n * p.nr_cls + cls_id) * 2) * p.part_sz + part_h) * + int x_idx = + (((n * p.nr_cls + cls_id) * 2) * p.part_sz + part_h) * p.part_sz + + part_w; + int y_idx = (((n * p.nr_cls + cls_id) * 2 + 1) * p.part_sz + part_h) * p.part_sz + part_w; - int y_idx = - (((n * p.nr_cls + cls_id) * 2 + 1) * p.part_sz + part_h) * - p.part_sz + - part_w; trans_x = trans[x_idx] * static_cast(p.trans_std); trans_y = trans[y_idx] * static_cast(p.trans_std); } @@ -212,22 +203,20 @@ __global__ void DeformablePSROIPoolBackwardAccKern( float U10 = data_ptr[y0 * p.IW + x1]; float U11 = data_ptr[y1 * p.IW + x1]; - float diff_x = (U11 * dist_y + U10 * (1 - dist_y) - - U01 * dist_y - U00 * (1 - dist_y)) * + float diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - + U00 * (1 - dist_y)) * p.trans_std * diff_val; - float diff_y = (U11 * dist_x + U01 * (1 - dist_x) - - U10 * dist_x - U00 * (1 - dist_x)) * + float diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - + U00 * (1 - dist_x)) * p.trans_std * diff_val; diff_x *= roi_w, diff_y *= roi_h; - int diff_x_idx = - (((n * p.nr_cls + cls_id) * 2) * p.part_sz + part_h) * - p.part_sz + - part_w; + int diff_x_idx = (((n * p.nr_cls + cls_id) * 2) * p.part_sz + part_h) * + p.part_sz + + part_w; int diff_y_idx = - (((n * p.nr_cls + cls_id) * 2 + 1) * p.part_sz + - part_h) * + (((n * p.nr_cls + cls_id) * 2 + 1) * p.part_sz + part_h) * p.part_sz + part_w; @@ -243,9 +232,9 @@ namespace megdnn { namespace cuda { namespace deformable_ps_roi_pooling { -void DeformablePSROIPoolForward(const TensorND& data, const TensorND& rois, - const TensorND& trans, const TensorND& out_data, - const TensorND& out_count, Param& p) { +void DeformablePSROIPoolForward( + const TensorND& data, const TensorND& rois, const TensorND& trans, + const TensorND& out_data, const TensorND& out_count, Param& p) { const int loops = p.nr_bbox * p.IC * p.pool_h * p.pool_w; int nr_thds = query_blocksize_for_kernel(DeformablePSROIPoolForwardKern); const int blks = DIVUP(loops, nr_thds); @@ -270,15 +259,12 @@ void DeformablePSROIPoolForward(const TensorND& data, const TensorND& rois, after_kernel_launch(); } -void DeformablePSROIPoolBackwardAcc(const TensorND& data, const TensorND& rois, - const TensorND& trans, - const TensorND& out_diff, - const TensorND& out_count, - const TensorND& data_diff, - const TensorND& trans_diff, Param& p) { +void DeformablePSROIPoolBackwardAcc( + const TensorND& data, const TensorND& rois, const TensorND& trans, + const TensorND& out_diff, const TensorND& out_count, const TensorND& data_diff, + const TensorND& trans_diff, Param& p) { const int loops = p.nr_bbox * p.IC * p.pool_h * p.pool_w; - int nr_thds = - query_blocksize_for_kernel(DeformablePSROIPoolBackwardAccKern); + int nr_thds = query_blocksize_for_kernel(DeformablePSROIPoolBackwardAccKern); const int blks = DIVUP(loops, nr_thds); const float* data_ptr = data.ptr(); diff --git a/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cuh b/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cuh index 248c49e9..1810cb55 100644 --- a/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cuh +++ b/dnn/src/cuda/deformable_ps_roi_pooling/kimpl/kern.cuh @@ -31,16 +31,14 @@ struct Param { cudaStream_t stream; }; 
-void DeformablePSROIPoolForward(const TensorND& data, const TensorND& rois, - const TensorND& trans, const TensorND& out_data, - const TensorND& out_count, Param& p); - -void DeformablePSROIPoolBackwardAcc(const TensorND& data, const TensorND& rois, - const TensorND& trans, - const TensorND& out_diff, - const TensorND& out_count, - const TensorND& data_diff, - const TensorND& trans_diff, Param& p); +void DeformablePSROIPoolForward( + const TensorND& data, const TensorND& rois, const TensorND& trans, + const TensorND& out_data, const TensorND& out_count, Param& p); + +void DeformablePSROIPoolBackwardAcc( + const TensorND& data, const TensorND& rois, const TensorND& trans, + const TensorND& out_diff, const TensorND& out_count, const TensorND& data_diff, + const TensorND& trans_diff, Param& p); } // namespace deformable_ps_roi_pooling } // namespace cuda diff --git a/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.cpp b/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.cpp index c11d1c8d..bfda497a 100644 --- a/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.cpp +++ b/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.cpp @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/cuda/deformable_ps_roi_pooling/kimpl/kern.cuh" #include "src/cuda/deformable_ps_roi_pooling/opr_impl.h" +#include "src/cuda/deformable_ps_roi_pooling/kimpl/kern.cuh" #include "src/cuda/utils.h" using namespace megdnn; @@ -18,9 +18,9 @@ using KernParam = deformable_ps_roi_pooling::Param; namespace { -void create_param(const DeformablePSROIPoolingBase* opr, - const TensorLayout& data, const TensorLayout& rois, - const TensorLayout& trans, KernParam& p) { +void create_param( + const DeformablePSROIPoolingBase* opr, const TensorLayout& data, + const TensorLayout& rois, const TensorLayout& trans, KernParam& p) { auto&& param = opr->param(); auto&& handle = concrete_handle(opr->handle()); @@ -44,16 +44,15 @@ void create_param(const DeformablePSROIPoolingBase* opr, namespace megdnn { namespace cuda { -void DeformablePSROIPoolingForwardImpl::exec(_megdnn_tensor_in data, - _megdnn_tensor_in rois, - _megdnn_tensor_in trans, - _megdnn_tensor_out out_data, - _megdnn_tensor_out out_count, - _megdnn_workspace workspace) { +void DeformablePSROIPoolingForwardImpl::exec( + _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans, + _megdnn_tensor_out out_data, _megdnn_tensor_out out_count, + _megdnn_workspace workspace) { KernParam p; - check_exec(data.layout, rois.layout, trans.layout, out_data.layout, - out_count.layout, workspace.size); + check_exec( + data.layout, rois.layout, trans.layout, out_data.layout, out_count.layout, + workspace.size); create_param(this, data.layout, rois.layout, trans.layout, p); deformable_ps_roi_pooling::DeformablePSROIPoolForward( @@ -67,9 +66,9 @@ void DeformablePSROIPoolingBackwardImpl::exec( _megdnn_workspace workspace) { KernParam p; - check_exec(data.layout, rois.layout, trans.layout, out_diff.layout, - out_count.layout, data_diff.layout, trans_diff.layout, - workspace.size); + check_exec( + data.layout, rois.layout, trans.layout, out_diff.layout, out_count.layout, + data_diff.layout, trans_diff.layout, workspace.size); create_param(this, data.layout, rois.layout, trans.layout, p); deformable_ps_roi_pooling::DeformablePSROIPoolBackwardAcc( data, rois, trans, out_diff, out_count, data_diff, trans_diff, p); diff --git 
a/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.h b/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.h index 0a1858c2..da47aca8 100644 --- a/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.h +++ b/dnn/src/cuda/deformable_ps_roi_pooling/opr_impl.h @@ -14,8 +14,7 @@ namespace megdnn { namespace cuda { -class DeformablePSROIPoolingForwardImpl final - : public DeformablePSROIPoolingForward { +class DeformablePSROIPoolingForwardImpl final : public DeformablePSROIPoolingForward { public: using DeformablePSROIPoolingForward::DeformablePSROIPoolingForward; @@ -26,32 +25,29 @@ public: return 0ULL; }; - void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, - _megdnn_tensor_in trans, _megdnn_tensor_out out_data, - _megdnn_tensor_out out_count, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans, + _megdnn_tensor_out out_data, _megdnn_tensor_out out_count, + _megdnn_workspace workspace) override; }; -class DeformablePSROIPoolingBackwardImpl final - : public DeformablePSROIPoolingBackward { +class DeformablePSROIPoolingBackwardImpl final : public DeformablePSROIPoolingBackward { public: using DeformablePSROIPoolingBackward::DeformablePSROIPoolingBackward; - size_t get_workspace_in_bytes(const TensorLayout& /* data */, - const TensorLayout& /* rois */, - const TensorLayout& /* trans */, - const TensorLayout& /* out_diff */, - const TensorLayout& /* out_count */, - const TensorLayout& /* data_diff */, - const TensorLayout& /* trans_diff */) override { + size_t get_workspace_in_bytes( + const TensorLayout& /* data */, const TensorLayout& /* rois */, + const TensorLayout& /* trans */, const TensorLayout& /* out_diff */, + const TensorLayout& /* out_count */, const TensorLayout& /* data_diff */, + const TensorLayout& /* trans_diff */) override { return 0ULL; }; - void exec(_megdnn_tensor_in data, _megdnn_tensor_in rois, - _megdnn_tensor_in trans, _megdnn_tensor_in out_diff, - _megdnn_tensor_in out_count, _megdnn_tensor_out data_diff, - _megdnn_tensor_out trans_diff, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in data, _megdnn_tensor_in rois, _megdnn_tensor_in trans, + _megdnn_tensor_in out_diff, _megdnn_tensor_in out_count, + _megdnn_tensor_out data_diff, _megdnn_tensor_out trans_diff, + _megdnn_workspace workspace) override; }; } // namespace cuda diff --git a/dnn/src/cuda/dot/dot.cu b/dnn/src/cuda/dot/dot.cu index a4913f23..27b21bce 100644 --- a/dnn/src/cuda/dot/dot.cu +++ b/dnn/src/cuda/dot/dot.cu @@ -10,27 +10,32 @@ */ #include "src/cuda/dot/dot.cuh" -#include "src/cuda/utils.cuh" #include "src/cuda/cub/util_ptx.cuh" +#include "src/cuda/utils.cuh" namespace { using namespace megdnn; using namespace cuda; -template __global__ void kernel(const T *a, const T *b, - dt_float32 *c, - uint32_t n, int32_t strideA, int32_t strideB) -{ +template +__global__ void kernel( + const T* a, const T* b, dt_float32* c, uint32_t n, int32_t strideA, + int32_t strideB) { uint32_t tid = threadIdx.x; uint32_t gid = threadIdx.x + blockIdx.x * blockDim.x; volatile __shared__ dt_float32 sdata[256]; - sdata[tid] = (gid < n ? - dt_float32(a[gid*strideA]) * dt_float32(b[gid*strideB]) - : 0); + sdata[tid] = + (gid < n ? 
dt_float32(a[gid * strideA]) * dt_float32(b[gid * strideB]) : 0); + __syncthreads(); + if (tid < 128) { + sdata[tid] += sdata[tid + 128]; + } + __syncthreads(); + if (tid < 64) { + sdata[tid] += sdata[tid + 64]; + } __syncthreads(); - if (tid < 128) { sdata[tid] += sdata[tid + 128]; } __syncthreads(); - if (tid < 64) { sdata[tid] += sdata[tid + 64]; } __syncthreads(); if (tid < 32) { sdata[tid] += sdata[tid + 32]; cub::WARP_SYNC(0xffffffff); @@ -53,39 +58,36 @@ template __global__ void kernel(const T *a, const T *b, atomicAdd(c, sdata[0]); } -template __global__ void cvt_kernel(const dt_float32 *src, T *dst) -{ +template +__global__ void cvt_kernel(const dt_float32* src, T* dst) { dst[0] = T(src[0]); } -} // anonymous namespace +} // anonymous namespace namespace megdnn { namespace cuda { namespace dot { -template void run(const T *a, const T *b, T *c, float *workspace, - uint32_t n, int32_t strideA, int32_t strideB, - cudaStream_t stream) -{ +template +void run( + const T* a, const T* b, T* c, float* workspace, uint32_t n, int32_t strideA, + int32_t strideB, cudaStream_t stream) { cuda_check(cudaMemsetAsync(workspace, 0, sizeof(dt_float32), stream)); // each block add 256 entries uint32_t blocks = DIVUP(n, 256); uint32_t threads = 256; - kernel<<>>(a, b, - workspace, - n, strideA, strideB); + kernel<<>>(a, b, workspace, n, strideA, strideB); cvt_kernel<<<1, 1, 0, stream>>>(workspace, c); after_kernel_launch(); } -template void run(const dt_float16 *a, const dt_float16 *b, - dt_float16 *c, dt_float32 *workspace, - uint32_t n, int32_t strideA, int32_t strideB, - cudaStream_t stream); +template void run( + const dt_float16* a, const dt_float16* b, dt_float16* c, dt_float32* workspace, + uint32_t n, int32_t strideA, int32_t strideB, cudaStream_t stream); -} // namespace dot -} // namespace cuda -} // namespace megdnn +} // namespace dot +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/dot/dot.cuh b/dnn/src/cuda/dot/dot.cuh index 579d1eb5..745793a3 100644 --- a/dnn/src/cuda/dot/dot.cuh +++ b/dnn/src/cuda/dot/dot.cuh @@ -16,14 +16,13 @@ namespace megdnn { namespace cuda { namespace dot { -template void run(const T *a, const T *b, T *c, - float *workspace, - uint32_t n, - int32_t strideA, int32_t strideB, - cudaStream_t stream); +template +void run( + const T* a, const T* b, T* c, float* workspace, uint32_t n, int32_t strideA, + int32_t strideB, cudaStream_t stream); -} // namespace dot -} // namespace cuda -} // namespace megdnn +} // namespace dot +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/dot/opr_impl.cpp b/dnn/src/cuda/dot/opr_impl.cpp index 6eddb1a9..755a2e42 100644 --- a/dnn/src/cuda/dot/opr_impl.cpp +++ b/dnn/src/cuda/dot/opr_impl.cpp @@ -10,37 +10,32 @@ */ #include "src/cuda/dot/opr_impl.h" -#include "src/cuda/utils.h" #include "src/cuda/dot/dot.cuh" +#include "src/cuda/utils.h" namespace megdnn { namespace cuda { -void DotForwardImpl::exec(_megdnn_tensor_in A, - _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace) -{ +void DotForwardImpl::exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) { check_exec(A.layout, B.layout, C.layout, workspace.size); megdnn_assert(A.layout.dtype.category() == DTypeCategory::FLOAT); auto handle = cublas_handle(this->handle()); if (A.layout.dtype == dtype::Float32()) { - cublas_check(cublasSdot(handle, A.layout.total_nr_elems(), - A.ptr(), A.layout.stride[0], - B.ptr(), 
B.layout.stride[0], - C.ptr())); + cublas_check(cublasSdot( + handle, A.layout.total_nr_elems(), A.ptr(), + A.layout.stride[0], B.ptr(), B.layout.stride[0], + C.ptr())); } else { megdnn_assert_internal(A.layout.dtype == dtype::Float16()); - dot::run(A.ptr(), - B.ptr(), - C.ptr(), - workspace.ptr(), - A.layout.total_nr_elems(), - A.layout.stride[0], B.layout.stride[0], - cuda_stream(this->handle())); + dot::run( + A.ptr(), B.ptr(), C.ptr(), + workspace.ptr(), A.layout.total_nr_elems(), + A.layout.stride[0], B.layout.stride[0], cuda_stream(this->handle())); } } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/dot/opr_impl.h b/dnn/src/cuda/dot/opr_impl.h index 11bf2d55..6b0645df 100644 --- a/dnn/src/cuda/dot/opr_impl.h +++ b/dnn/src/cuda/dot/opr_impl.h @@ -16,21 +16,19 @@ namespace megdnn { namespace cuda { -class DotForwardImpl final: public DotForward { - public: - using DotForward::DotForward; - void exec(_megdnn_tensor_in A, - _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override { - return sizeof(float); - } +class DotForwardImpl final : public DotForward { +public: + using DotForward::DotForward; + void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return sizeof(float); + } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/elemwise/kern_impl.inl b/dnn/src/cuda/elemwise/kern_impl.inl index b3057b8a..aad47a85 100644 --- a/dnn/src/cuda/elemwise/kern_impl.inl +++ b/dnn/src/cuda/elemwise/kern_impl.inl @@ -20,18 +20,16 @@ namespace megdnn { namespace cuda { -#define cb(_mode) \ - typedef ElemwiseKern< \ - megcorePlatformCUDA, \ - param_enumv::Elemwise::Mode::_mode, KERN_IMPL_CTYPE> \ - KernImpl##_mode; \ - typedef ElemArithKernWrapper \ - Wrapper##_mode; \ - INST_RUN_ELEMWISE(Wrapper##_mode, KERN_IMPL_CTYPE, KERN_IMPL_ARITY); \ +#define cb(_mode) \ + typedef ElemwiseKern< \ + megcorePlatformCUDA, param_enumv::Elemwise::Mode::_mode, KERN_IMPL_CTYPE> \ + KernImpl##_mode; \ + typedef ElemArithKernWrapper Wrapper##_mode; \ + INST_RUN_ELEMWISE(Wrapper##_mode, KERN_IMPL_CTYPE, KERN_IMPL_ARITY); KERN_IMPL_MODE(cb) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/elemwise/kern_wrapper.cuh b/dnn/src/cuda/elemwise/kern_wrapper.cuh index a1f97572..b79dce70 100644 --- a/dnn/src/cuda/elemwise/kern_wrapper.cuh +++ b/dnn/src/cuda/elemwise/kern_wrapper.cuh @@ -17,145 +17,134 @@ namespace megdnn { namespace cuda { - template - struct ElemArithKernWrapper; +template +struct ElemArithKernWrapper; - template - struct ElemArithKernWrapper< - 1, KernImpl, - typename std::enable_if< - !std::is_same::value && - !std::is_same::value && - !std::is_same::value>::type> { - typedef typename KernImpl::ctype ctype; - ctype* dst; +template +struct ElemArithKernWrapper< + 1, KernImpl, + typename std::enable_if< + !std::is_same::value && + !std::is_same::value && + !std::is_same::value>::type> { + typedef typename KernImpl::ctype ctype; + ctype* dst; #if MEGDNN_CC_CUDA - __device__ void operator()(uint32_t idx, ctype x) { - dst[idx] = 
KernImpl::apply(x); - } + __device__ void operator()(uint32_t idx, ctype x) { dst[idx] = KernImpl::apply(x); } #endif - }; - template - struct ElemArithKernWrapper< - 2, KernImpl, - typename std::enable_if< - !std::is_same::value && - !std::is_same::value && - !std::is_same::value>::type> { - typedef typename KernImpl::ctype ctype; - ctype* dst; +}; +template +struct ElemArithKernWrapper< + 2, KernImpl, + typename std::enable_if< + !std::is_same::value && + !std::is_same::value && + !std::is_same::value>::type> { + typedef typename KernImpl::ctype ctype; + ctype* dst; #if MEGDNN_CC_CUDA - __device__ void operator()(uint32_t idx, ctype x, ctype y) { - dst[idx] = KernImpl::apply(x, y); - } + __device__ void operator()(uint32_t idx, ctype x, ctype y) { + dst[idx] = KernImpl::apply(x, y); + } #endif - }; - template - struct ElemArithKernWrapper< - 3, KernImpl, - typename std::enable_if< - !std::is_same::value && - !std::is_same::value && - !std::is_same::value>::type> { - typedef typename KernImpl::ctype ctype; - ctype* dst; +}; +template +struct ElemArithKernWrapper< + 3, KernImpl, + typename std::enable_if< + !std::is_same::value && + !std::is_same::value && + !std::is_same::value>::type> { + typedef typename KernImpl::ctype ctype; + ctype* dst; #if MEGDNN_CC_CUDA - __device__ void operator()(uint32_t idx, ctype x, ctype y, ctype z) { - dst[idx] = KernImpl::apply(x, y, z); - } + __device__ void operator()(uint32_t idx, ctype x, ctype y, ctype z) { + dst[idx] = KernImpl::apply(x, y, z); + } #endif - }; +}; - template - struct ElemArithKernWrapper< - 1, KernImpl, - typename std::enable_if< - std::is_same::value || - std::is_same::value || - std::is_same::value>::type> { - typedef typename KernImpl::ctype ctype; - using VectTypeTrait = elemwise_intl::VectTypeTrait; - typedef typename VectTypeTrait::vect_type vect_type; - ctype* dst; +template +struct ElemArithKernWrapper< + 1, KernImpl, + typename std::enable_if< + std::is_same::value || + std::is_same::value || + std::is_same::value>::type> { + typedef typename KernImpl::ctype ctype; + using VectTypeTrait = elemwise_intl::VectTypeTrait; + typedef typename VectTypeTrait::vect_type vect_type; + ctype* dst; #if MEGDNN_CC_CUDA - __device__ __forceinline__ void operator()(uint32_t idx, ctype x) { - dst[idx] = KernImpl::apply(x); - } - __device__ __forceinline__ void operator()(uint32_t idx, vect_type x) { - ctype a = KernImpl::apply(x.x); - ctype b = KernImpl::apply(x.y); - ctype g = KernImpl::apply(x.z); - ctype r = KernImpl::apply(x.w); - *(vect_type*)(&dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); - } + __device__ __forceinline__ void operator()(uint32_t idx, ctype x) { + dst[idx] = KernImpl::apply(x); + } + __device__ __forceinline__ void operator()(uint32_t idx, vect_type x) { + ctype a = KernImpl::apply(x.x); + ctype b = KernImpl::apply(x.y); + ctype g = KernImpl::apply(x.z); + ctype r = KernImpl::apply(x.w); + *(vect_type*)(&dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); + } #endif - }; +}; - template - struct ElemArithKernWrapper< - 2, KernImpl, - typename std::enable_if< - std::is_same::value || - std::is_same::value || - std::is_same::value>::type> { - typedef typename KernImpl::ctype ctype; - using VectTypeTrait = elemwise_intl::VectTypeTrait; - typedef typename VectTypeTrait::vect_type vect_type; - ctype* dst; +template +struct ElemArithKernWrapper< + 2, KernImpl, + typename std::enable_if< + std::is_same::value || + std::is_same::value || + std::is_same::value>::type> { + typedef typename KernImpl::ctype ctype; + using 
VectTypeTrait = elemwise_intl::VectTypeTrait; + typedef typename VectTypeTrait::vect_type vect_type; + ctype* dst; #if MEGDNN_CC_CUDA - __device__ __forceinline__ void operator()(uint32_t idx, ctype x, - ctype y) { - dst[idx] = KernImpl::apply(x, y); - } - __device__ __forceinline__ void operator()(uint32_t idx, vect_type x, - vect_type y) { - ctype a = KernImpl::apply(x.x, y.x); - ctype b = KernImpl::apply(x.y, y.y); - ctype g = KernImpl::apply(x.z, y.z); - ctype r = KernImpl::apply(x.w, y.w); - *(vect_type*)(&dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); - } + __device__ __forceinline__ void operator()(uint32_t idx, ctype x, ctype y) { + dst[idx] = KernImpl::apply(x, y); + } + __device__ __forceinline__ void operator()(uint32_t idx, vect_type x, vect_type y) { + ctype a = KernImpl::apply(x.x, y.x); + ctype b = KernImpl::apply(x.y, y.y); + ctype g = KernImpl::apply(x.z, y.z); + ctype r = KernImpl::apply(x.w, y.w); + *(vect_type*)(&dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); + } #endif - }; +}; - template - struct ElemArithKernWrapper< - 3, KernImpl, - typename std::enable_if< - std::is_same::value || - std::is_same::value || - std::is_same::value>::type> { - typedef typename KernImpl::ctype ctype; - using VectTypeTrait = elemwise_intl::VectTypeTrait; - typedef typename VectTypeTrait::vect_type vect_type; - ctype* dst; +template +struct ElemArithKernWrapper< + 3, KernImpl, + typename std::enable_if< + std::is_same::value || + std::is_same::value || + std::is_same::value>::type> { + typedef typename KernImpl::ctype ctype; + using VectTypeTrait = elemwise_intl::VectTypeTrait; + typedef typename VectTypeTrait::vect_type vect_type; + ctype* dst; #if MEGDNN_CC_CUDA - __device__ __forceinline__ void operator()(uint32_t idx, ctype x, - ctype y, ctype z) { - dst[idx] = KernImpl::apply(x, y, z); - } - __device__ __forceinline__ void operator()(uint32_t idx, vect_type x, - vect_type y, vect_type z) { - ctype a = KernImpl::apply(x.x, y.x, z.x); - ctype b = KernImpl::apply(x.y, y.y, z.y); - ctype g = KernImpl::apply(x.z, y.z, z.z); - ctype r = KernImpl::apply(x.w, y.w, z.w); - *(vect_type*)(&dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); - } + __device__ __forceinline__ void operator()( + uint32_t idx, ctype x, ctype y, ctype z) { + dst[idx] = KernImpl::apply(x, y, z); + } + __device__ __forceinline__ void operator()( + uint32_t idx, vect_type x, vect_type y, vect_type z) { + ctype a = KernImpl::apply(x.x, y.x, z.x); + ctype b = KernImpl::apply(x.y, y.y, z.y); + ctype g = KernImpl::apply(x.z, y.z, z.z); + ctype r = KernImpl::apply(x.w, y.w, z.w); + *(vect_type*)(&dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); + } #endif - }; +}; } // namespace cuda } // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu index c6a0644e..e567e703 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float16.cu index d4f8ac33..0ed0893c 100644 --- 
a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float32.cu index 4b8c7696..d7050449 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int16.cu index fe2bb209..3ed0e0ac 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int32.cu index 062685a7..2edaa77f 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int8.cu index bd883a99..2e179e01 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_uint8.cu index 185c733b..94ea4515 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu index 0cfb7e0e..307d8728 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE 
dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_float16.cu index 665208ea..657943df 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_float32.cu index 6bd3fd01..42811046 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_int16.cu index 6d0a1d42..1af52940 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_int32.cu index b7468e36..f8c40d38 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_int8.cu index 9af9fc33..a0e78590 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ABS_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/ABS_dt_uint8.cu index c197ee12..52683d2f 100644 --- a/dnn/src/cuda/elemwise/kimpl/ABS_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/ABS_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu index 912d2a62..0f2310b5 100644 --- a/dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float16.cu index 9a072a73..5a24fd83 100644 --- a/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float32.cu index c8382465..e9b40455 100644 --- a/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ACOS_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu index 11892b40..5a6d2945 100644 --- a/dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ADD_dt_float16.cu index d1097cee..701b4761 100644 --- a/dnn/src/cuda/elemwise/kimpl/ADD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ADD_dt_float32.cu index 04e414d8..83f23173 100644 --- a/dnn/src/cuda/elemwise/kimpl/ADD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/ADD_dt_int16.cu index 2692639b..9b1492a9 100644 --- a/dnn/src/cuda/elemwise/kimpl/ADD_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_int32.cu 
b/dnn/src/cuda/elemwise/kimpl/ADD_dt_int32.cu index 2a8b63ab..1a17b22e 100644 --- a/dnn/src/cuda/elemwise/kimpl/ADD_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/ADD_dt_int8.cu index a9ff809f..30d321cc 100644 --- a/dnn/src/cuda/elemwise/kimpl/ADD_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ADD_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/ADD_dt_uint8.cu index fd4c23d0..b39b02c5 100644 --- a/dnn/src/cuda/elemwise/kimpl/ADD_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/ADD_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu index 7ca91d7a..4e1f3792 100644 --- a/dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu +++ b/dnn/src/cuda/elemwise/kimpl/AND_dt_bool.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(AND, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bool +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu index 011994da..4c34a0f6 100644 --- a/dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float16.cu index 20b2a7c8..14c6656c 100644 --- a/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float32.cu index a7852fa9..339cc2b3 100644 --- a/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ASIN_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 
+#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu index 84ece010..d14ed04e 100644 --- a/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float16.cu index e30a5931..84c5ecc8 100644 --- a/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float32.cu index 7024dbaa..12bc6c6a 100644 --- a/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ATAN2_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu index 9b4995bc..803a5b84 100644 --- a/dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float16.cu index e5051bb2..d98a8fa0 100644 --- a/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float32.cu index c3f91b79..f8097d03 100644 --- a/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/CEIL_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu index e2aafdd6..e67deade 100644 --- 
a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float16.cu index 6025e7e0..1d2558f3 100644 --- a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float32.cu index 90d61a5f..04631683 100644 --- a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int16.cu index 81bd6fe1..d54bd542 100644 --- a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int32.cu index 63d9211a..8e8865b9 100644 --- a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int8.cu index cb8b92d3..048ee372 100644 --- a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_uint8.cu index fd1b9437..952361bf 100644 --- a/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_uint8.cu @@ -1,5 +1,5 @@ // 
generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu index 467b85a0..4455a84a 100644 --- a/dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/COS_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/COS_dt_float16.cu index c3b061ed..81649217 100644 --- a/dnn/src/cuda/elemwise/kimpl/COS_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/COS_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/COS_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/COS_dt_float32.cu index 89b9f12c..d5b275d1 100644 --- a/dnn/src/cuda/elemwise/kimpl/COS_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/COS_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu index 66caac75..53cb005e 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu index 437b4d84..1259d242 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_bool.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bool +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_float16.cu index 2492fcb8..b8eaf23d 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git 
a/dnn/src/cuda/elemwise/kimpl/EQ_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_float32.cu index 3dbdaf9d..d5801782 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_int16.cu index 1887146f..ada45fff 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_int32.cu index 2518d6ff..fcb4a91f 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_int8.cu index d0ca968f..b062c771 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EQ_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/EQ_dt_uint8.cu index 6c62949c..5993ba47 100644 --- a/dnn/src/cuda/elemwise/kimpl/EQ_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/EQ_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu index 66899960..5e6eea87 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float16.cu index 98315e0e..93be9a58 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) -#define KERN_IMPL_ARITY 1 -#define 
KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float32.cu index e337f0c6..b733b75a 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu index 05218cf3..5cbd20f6 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float16.cu index 2f0894cc..f591dde7 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float32.cu index 9dd164d5..45fb61f9 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFC_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu index 153f54ab..a711b9b4 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float16.cu index 37b4a3f4..2f378896 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float32.cu 
b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float32.cu index a022e82c..7d615c60 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERFINV_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu index 29c76242..e0821627 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERF_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ERF_dt_float16.cu index 2156e847..5c9c1167 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERF_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERF_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ERF_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ERF_dt_float32.cu index 3b86ad21..bc529c29 100644 --- a/dnn/src/cuda/elemwise/kimpl/ERF_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ERF_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu index 36eb1dbb..46b7380e 100644 --- a/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float16.cu index daaed095..8b066b88 100644 --- a/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float32.cu index 8acc8cd2..1b90614c 100644 --- a/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/EXPM1_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py 
#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu index 576831bb..29450d25 100644 --- a/dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/EXP_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/EXP_dt_float16.cu index 57e07652..245ebe67 100644 --- a/dnn/src/cuda/elemwise/kimpl/EXP_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/EXP_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/EXP_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/EXP_dt_float32.cu index cbf23a51..4d83a8a6 100644 --- a/dnn/src/cuda/elemwise/kimpl/EXP_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/EXP_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu index 3f4de682..2f2f684f 100644 --- a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float16.cu index 68034e3f..232123ec 100644 --- a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float32.cu index 16614d4d..35863974 100644 --- a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE 
dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu index 1e41b596..8e839f5e 100644 --- a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float16.cu index 128142cf..a3a38da2 100644 --- a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float32.cu index 7c67ca34..5662127f 100644 --- a/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu index df89d8cf..7d67e397 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float16.cu index 102a4455..de2b7e73 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float32.cu index c22574b6..5e7ae3e6 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" 
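// --------------------------------------------------------------------------
// Illustrative sketch (not a hunk of this patch): every kimpl/*.cu file
// touched in this diff is generated by gen_elemwise_kern_impls.py and has the
// same shape -- it defines KERN_IMPL_MODE / KERN_IMPL_ARITY / KERN_IMPL_CTYPE
// and then includes ../kern_impl.inl, whose cb(_mode) macro (reformatted
// earlier in this diff) instantiates the kernel wrapper.  Assuming
// MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) expands to cb(ADD), a file such as
// kimpl/ADD_dt_float32.cu preprocesses to roughly:
//
//     typedef ElemwiseKern<
//             megcorePlatformCUDA, param_enumv::Elemwise::Mode::ADD, dt_float32>
//             KernImplADD;                                 // scalar functor
//     typedef ElemArithKernWrapper<2, KernImplADD> WrapperADD;  // arity-2 wrapper
//     INST_RUN_ELEMWISE(WrapperADD, dt_float32, 2);        // explicit instantiation
//
// The arity-2 wrapper's operator()(uint32_t idx, ctype x, ctype y) stores
// KernImpl::apply(x, y) into dst[idx]; the dt_int8 / dt_uint8 / dt_bool
// specializations shown in kern_wrapper.cuh above additionally provide a
// vect_type overload that applies the functor to the four packed lanes
// (.x, .y, .z, .w) per call.
// --------------------------------------------------------------------------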
diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int16.cu index 0c5eadea..c4d59b48 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int32.cu index 23408ae3..7c286e4c 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int8.cu index aa6005ea..6aec2475 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_uint8.cu index 5aa2fa74..e354384b 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu index cad5adc1..fe2304f2 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float16.cu index aa434531..adfe8c74 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float32.cu index b64b99c7..16c92309 100644 --- a/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FLOOR_dt_float32.cu @@ -1,5 +1,5 @@ // 
generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu index c4a621e5..44606a13 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float16.cu index 255dca30..50da7da3 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float32.cu index c183462b..ad092051 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu index 86c3d9f6..3668668a 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float16.cu index f1541b7a..8c7d5549 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float32.cu index a9aa59ae..fa251100 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float32.cu +++ 
b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int16.cu index 86038f27..37d7fa79 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int32.cu index 6f1a21b7..fb139e25 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int8.cu index dd2771dd..291376c8 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_uint8.cu index 229d7b69..c7a0b8ae 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu index 1ff3d0cb..4681c64c 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float16.cu index 7bd8b0f5..e7c0a3ae 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if 
!MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float32.cu index 48656fc4..20d07290 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu index 5928c23c..1bfef7cb 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float16.cu index 86ea8f2a..bc5ec2df 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float32.cu index 349b33ea..34561f8c 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu index 9f213537..3571124b 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float16.cu index 5716afe2..81596768 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float16.cu @@ -1,7 +1,7 @@ // 
generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float32.cu index 7e4134cb..2c2e57d4 100644 --- a/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu index 32a7d825..657e4d73 100644 --- a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu index f3481781..7ef5ef4c 100644 --- a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu index 6bff61fe..13f908ac 100644 --- a/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/GELU_GRAD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu index 862473da..c47979c7 100644 --- a/dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/GELU_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu index 2df85208..9e7b02e1 100644 --- a/dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu index 5896ef8b..14cc711c 100644 --- a/dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/GELU_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu index c9d0642a..aac278f9 100644 --- a/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float16.cu index 4e03c1e1..3b6c2f9e 100644 --- a/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float32.cu index 8fbfc156..f7a3dad6 100644 --- a/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu index 42962ab1..780dfd94 100644 --- a/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float16.cu index a97d4aaf..bc10d85d 100644 --- a/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 
+#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float32.cu index 6f42839c..cc401177 100644 --- a/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu index 99002d87..c67b986d 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu index 2e66f6a0..8f44e7f4 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_bool.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bool +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float16.cu index 786c2feb..c89bd5bd 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float32.cu index 3d1f4970..9c13493d 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int16.cu index 33f503a9..36add32e 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int32.cu index c7e04327..33375ab4 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int32.cu @@ -1,5 +1,5 @@ // generated by 
gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int8.cu index 7c7bebcd..4980c1f9 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_uint8.cu index ef977f91..a925b56b 100644 --- a/dnn/src/cuda/elemwise/kimpl/LEQ_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/LEQ_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu index 61049398..a1d1c0c0 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float16.cu index 2f95257b..2d93d196 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float32.cu index 7fe27d28..bcfdb514 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG1P_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu index 118b75b5..68ebc213 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git 
a/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float16.cu index b9eb2b37..e11ab5ef 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float32.cu index c5ea7054..9b6e10a3 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu index ba72b22e..3da8616b 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/LOG_dt_float16.cu index cda065e6..7697f22a 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LOG_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/LOG_dt_float32.cu index 56b1cfd6..e85f7bd0 100644 --- a/dnn/src/cuda/elemwise/kimpl/LOG_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/LOG_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu index b1751e5b..89e29e42 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu index d66053e6..d25c6d7a 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu +++ 
b/dnn/src/cuda/elemwise/kimpl/LT_dt_bool.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bool +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_float16.cu index 2bd4bb7f..c90327c3 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_float32.cu index bfd1c942..9f0a55ab 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_int16.cu index 484f8cfe..525a688c 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_int32.cu index d44e5041..0ed47b7e 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_int8.cu index 1ae62018..77ef3e62 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/LT_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/LT_dt_uint8.cu index a18d0913..debac858 100644 --- a/dnn/src/cuda/elemwise/kimpl/LT_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/LT_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu index 0ded7168..d271099c 100644 --- 
a/dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_float16.cu index 580efc07..54fe068e 100644 --- a/dnn/src/cuda/elemwise/kimpl/MAX_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_float32.cu index fc13cb74..3bc662fb 100644 --- a/dnn/src/cuda/elemwise/kimpl/MAX_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_int16.cu index b49743e1..f9be2946 100644 --- a/dnn/src/cuda/elemwise/kimpl/MAX_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_int32.cu index c9649f9c..e6dc68b9 100644 --- a/dnn/src/cuda/elemwise/kimpl/MAX_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_int8.cu index e0e24df0..5310e909 100644 --- a/dnn/src/cuda/elemwise/kimpl/MAX_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MAX_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/MAX_dt_uint8.cu index bf1a78a3..ef4aff04 100644 --- a/dnn/src/cuda/elemwise/kimpl/MAX_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/MAX_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu index 9a487057..42f6a008 100644 --- a/dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MIN_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_float16.cu index 26c8df53..f6ec5efe 100644 --- a/dnn/src/cuda/elemwise/kimpl/MIN_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MIN_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_float32.cu index d3a40eff..2a1b9ba7 100644 --- a/dnn/src/cuda/elemwise/kimpl/MIN_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MIN_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_int16.cu index 787b8d21..6d1372ec 100644 --- a/dnn/src/cuda/elemwise/kimpl/MIN_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MIN_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_int32.cu index a7621fdb..c11808e1 100644 --- a/dnn/src/cuda/elemwise/kimpl/MIN_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MIN_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_int8.cu index 598a3f06..4f3e6fa9 100644 --- a/dnn/src/cuda/elemwise/kimpl/MIN_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MIN_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/MIN_dt_uint8.cu index 393347fb..9c13be8e 100644 --- a/dnn/src/cuda/elemwise/kimpl/MIN_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/MIN_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE 
dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu index 465c5019..d859c5ef 100644 --- a/dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_float16.cu index 0f5d6e14..0824033a 100644 --- a/dnn/src/cuda/elemwise/kimpl/MOD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_float32.cu index 38a18d02..db43fa67 100644 --- a/dnn/src/cuda/elemwise/kimpl/MOD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_int16.cu index 736a4c1a..27bfffd7 100644 --- a/dnn/src/cuda/elemwise/kimpl/MOD_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_int32.cu index f4999db2..13237aa5 100644 --- a/dnn/src/cuda/elemwise/kimpl/MOD_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_int8.cu index af16999c..d7ca71bf 100644 --- a/dnn/src/cuda/elemwise/kimpl/MOD_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MOD_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/MOD_dt_uint8.cu index 65841790..1b053c41 100644 --- a/dnn/src/cuda/elemwise/kimpl/MOD_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/MOD_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define 
KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu index 434b5797..9669033b 100644 --- a/dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_float16.cu index 8100f209..84febe84 100644 --- a/dnn/src/cuda/elemwise/kimpl/MUL_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MUL_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_float32.cu index 73293900..6923e702 100644 --- a/dnn/src/cuda/elemwise/kimpl/MUL_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MUL_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_int16.cu index 8df90a7e..e1ee2a54 100644 --- a/dnn/src/cuda/elemwise/kimpl/MUL_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/MUL_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_int32.cu index 96f7da3d..4e42803d 100644 --- a/dnn/src/cuda/elemwise/kimpl/MUL_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/MUL_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_int8.cu index 5a90184e..13477ce7 100644 --- a/dnn/src/cuda/elemwise/kimpl/MUL_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/MUL_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/MUL_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/MUL_dt_uint8.cu index 334814b5..e9e75e0e 100644 --- a/dnn/src/cuda/elemwise/kimpl/MUL_dt_uint8.cu +++ 
b/dnn/src/cuda/elemwise/kimpl/MUL_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu index ef8c4e4d..ec82c156 100644 --- a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float16.cu index 1ef8ed1d..eb85fefd 100644 --- a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float32.cu index 290a1a03..d91eae61 100644 --- a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int16.cu index ea506d31..2bf48384 100644 --- a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int32.cu index 6d21f1e5..8b8917ec 100644 --- a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int8.cu index 74dba711..ae880fd3 100644 --- a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 #include 
"../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_uint8.cu index 927f0fa1..4b99a25e 100644 --- a/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/NEGATE_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu index bcc61f4a..d5a9578a 100644 --- a/dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu +++ b/dnn/src/cuda/elemwise/kimpl/NOT_dt_bool.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NOT, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bool +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bool #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu index 4f1364b4..2b853e83 100644 --- a/dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu +++ b/dnn/src/cuda/elemwise/kimpl/OR_dt_bool.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(OR, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bool +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu index c97c703a..f4d54d99 100644 --- a/dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/POW_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/POW_dt_float16.cu index d4ba6730..7596f061 100644 --- a/dnn/src/cuda/elemwise/kimpl/POW_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/POW_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/POW_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/POW_dt_float32.cu index e9fb788e..6bc2f587 100644 --- a/dnn/src/cuda/elemwise/kimpl/POW_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/POW_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu index 0e6f2080..3eee45d4 100644 --- a/dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_float16.cu index e5393775..67796a3b 100644 --- a/dnn/src/cuda/elemwise/kimpl/RELU_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_float32.cu index d18e37c8..fb35a28a 100644 --- a/dnn/src/cuda/elemwise/kimpl/RELU_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_int16.cu index 3eb24ed4..86be8f18 100644 --- a/dnn/src/cuda/elemwise/kimpl/RELU_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_int32.cu index 8c11a2e3..ccddb8ec 100644 --- a/dnn/src/cuda/elemwise/kimpl/RELU_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_int8.cu index 9330078e..f0d6b6bd 100644 --- a/dnn/src/cuda/elemwise/kimpl/RELU_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RELU_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/RELU_dt_uint8.cu index 470bd051..5b51f82f 100644 --- a/dnn/src/cuda/elemwise/kimpl/RELU_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/RELU_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int16.cu index 0f21d7cb..50c6b788 100644 --- a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int16.cu +++ 
b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int32.cu index 2f125239..4e61d2a5 100644 --- a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int8.cu index e2229ac1..7544cb18 100644 --- a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_uint8.cu index 89e247eb..68d321e6 100644 --- a/dnn/src/cuda/elemwise/kimpl/RMULH_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/RMULH_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu index 786c6b80..eb33fd05 100644 --- a/dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float16.cu index 0e24f548..2b7fa401 100644 --- a/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float32.cu index 9660812d..b1ec559c 100644 --- a/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/ROUND_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise/kimpl/SHL_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SHL_dt_int16.cu index 1ec354f7..f5fb588e 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHL_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHL_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SHL_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SHL_dt_int32.cu index c62bcc4f..4cb47d0c 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHL_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHL_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SHL_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SHL_dt_int8.cu index 906d29f3..2089d12b 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHL_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHL_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SHL_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SHL_dt_uint8.cu index 50dae36e..a36efa6d 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHL_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHL_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SHR_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SHR_dt_int16.cu index d9ecc70c..8cc887fe 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHR_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHR_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SHR_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SHR_dt_int32.cu index 583a1554..9733945a 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHR_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHR_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SHR_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SHR_dt_int8.cu index 6a9bfba6..7b13adbc 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHR_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHR_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise/kimpl/SHR_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SHR_dt_uint8.cu index cff0b17b..00330adf 100644 --- a/dnn/src/cuda/elemwise/kimpl/SHR_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SHR_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu index 9ea6f76d..aabcdbca 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float16.cu index 4b89026a..4c71d31c 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float32.cu index cd70a27d..9f67693c 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int16.cu index 65b55d7b..141a15c7 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int32.cu index 21bde467..025d181c 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int8.cu index 3584305f..251b0c22 100644 --- 
a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_uint8.cu index d339eea4..f809a1f5 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu index 9b1dab64..8a0fdd06 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float16.cu index baae5803..638e7b22 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float32.cu index 4b4b1d8f..395b7ad9 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu index eb3018bc..27c504e4 100644 --- a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu index e9960d0d..b50af0c2 100644 --- a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by 
gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu index 004a7e5c..b53debbc 100644 --- a/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SILU_GRAD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu index 27009b26..a315d0e7 100644 --- a/dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SILU_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu index 1fd1dd0c..5826f5c8 100644 --- a/dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu index c66df4bb..ec38040a 100644 --- a/dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SILU_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu index 685effd7..8c46337a 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIN_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SIN_dt_float16.cu index fdabffd0..eb14fd25 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIN_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIN_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define 
KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SIN_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SIN_dt_float32.cu index 2f1ea67c..8a426a5a 100644 --- a/dnn/src/cuda/elemwise/kimpl/SIN_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SIN_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu index a54b0ca7..d86a8861 100644 --- a/dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_float16.cu index 129bd04f..fb95252e 100644 --- a/dnn/src/cuda/elemwise/kimpl/SUB_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_float32.cu index 1b0aec6a..f40558af 100644 --- a/dnn/src/cuda/elemwise/kimpl/SUB_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_int16.cu index 957627f1..892eeae9 100644 --- a/dnn/src/cuda/elemwise/kimpl/SUB_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_int32.cu index e41c6bcf..1a1b512c 100644 --- a/dnn/src/cuda/elemwise/kimpl/SUB_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_int8.cu index 4a0890e4..367a7362 100644 --- a/dnn/src/cuda/elemwise/kimpl/SUB_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SUB_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SUB_dt_uint8.cu index 33a54a6a..6a6f483e 100644 --- a/dnn/src/cuda/elemwise/kimpl/SUB_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SUB_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu index c09c4097..ad56a1f5 100644 --- a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float16.cu index 7fe80c4c..7fd53793 100644 --- a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float32.cu index 9a759078..275f92c9 100644 --- a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int16.cu index 0d2892f4..44ab9857 100644 --- a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int32.cu index c7f4b26c..07643d09 100644 --- a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int32.cu +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int8.cu index 1d4df389..0b86129d 100644 --- a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_uint8.cu index 7c83a5c2..5a9ac11f 100644 --- a/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu index 59b75e81..45cb05c3 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float16.cu index 5be50c8c..9b778019 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float32.cu index 0e259719..11d8d5d4 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int16.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int16.cu index 4efd5978..50bdfdc1 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int16.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int16.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int32.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int32.cu index 69202693..830a9757 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int32.cu +++ 
b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int8.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int8.cu index 448aaf29..b4e1e18d 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int8.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_int8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_int8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_uint8.cu b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_uint8.cu index e1fc7756..09e105ec 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_uint8.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_uint8.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_uint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu index 6cb6ccca..5382d9d1 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/TANH_dt_float16.cu index 3c807b09..072dca9c 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/TANH_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/TANH_dt_float32.cu index 89184efd..4baf7d58 100644 --- a/dnn/src/cuda/elemwise/kimpl/TANH_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/TANH_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu index 62f886e7..b881c230 100644 --- a/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu +++ b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bfloat16 +#define 
KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bfloat16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float16.cu b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float16.cu index 7e4779c4..d311d7e0 100644 --- a/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float16.cu +++ b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float16.cu @@ -1,7 +1,7 @@ // generated by gen_elemwise_kern_impls.py #if !MEGDNN_DISABLE_FLOAT16 #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float16 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 #include "../kern_impl.inl" #endif diff --git a/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float32.cu b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float32.cu index 6792bbe3..492ce830 100644 --- a/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float32.cu +++ b/dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_float32.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_float32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu b/dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu index 548a3134..26dab7a8 100644 --- a/dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu +++ b/dnn/src/cuda/elemwise/kimpl/XOR_dt_bool.cu @@ -1,5 +1,5 @@ // generated by gen_elemwise_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(XOR, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_CTYPE dt_bool +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_bool #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise/opr_impl.cpp b/dnn/src/cuda/elemwise/opr_impl.cpp index dfc6eaf2..6d660239 100644 --- a/dnn/src/cuda/elemwise/opr_impl.cpp +++ b/dnn/src/cuda/elemwise/opr_impl.cpp @@ -18,55 +18,55 @@ namespace megdnn { namespace cuda { -#define on_arity_dispatched_cb_dtype(_dt) \ - if (m_dst->layout.dtype == _dt()) { \ - using dtrait = DTypeTrait<_dt>; \ - using ctype = dtrait::ctype; \ - auto stream = cuda_stream(handle()); \ +#define on_arity_dispatched_cb_dtype(_dt) \ + if (m_dst->layout.dtype == _dt()) { \ + using dtrait = DTypeTrait<_dt>; \ + using ctype = dtrait::ctype; \ + auto stream = cuda_stream(handle()); \ return ModeDispatcher::run( \ - src, stream, m_param.mode, m_dst->ptr()); \ + src, stream, m_param.mode, m_dst->ptr()); \ } -#define _cb_dispatch_mode(_m) case Mode::_m: do { \ - using KernImpl = ElemwiseKern< \ - megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, ctype>; \ - using Wrapper = ElemArithKernWrapper; \ - Wrapper wrapper; \ - wrapper.dst = static_cast(dst); \ - return run_elemwise(src, stream, wrapper); \ -} while(0); +#define _cb_dispatch_mode(_m) \ + case Mode::_m: \ + do { \ + using KernImpl = ElemwiseKern< \ + megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, ctype>; \ + using Wrapper = ElemArithKernWrapper; \ + Wrapper wrapper; \ + wrapper.dst = static_cast(dst); \ + return run_elemwise(src, stream, wrapper); \ + } while (0); -#define IMPL_MODE_DISPATCHER(_arity, _dtype_cat) \ -template \ -struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> { \ - static constexpr int arity = _arity; \ - static void run(const ElemwiseOpParamN &src, \ - cudaStream_t stream, Mode mode, void *dst) { \ - switch (mode) { \ - FOREACH(_cb_dispatch_mode) \ - default: \ - megdnn_throw("bad mode"); \ - } \ - } \ -} +#define IMPL_MODE_DISPATCHER(_arity, 
_dtype_cat) \ + template \ + struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> { \ + static constexpr int arity = _arity; \ + static void run( \ + const ElemwiseOpParamN& src, cudaStream_t stream, Mode mode, \ + void* dst) { \ + switch (mode) { \ + FOREACH(_cb_dispatch_mode) \ + default: \ + megdnn_throw("bad mode"); \ + } \ + } \ + } #include "src/common/elemwise/opr_impl_body.inl" -template -void ElemwiseForwardImpl::impl_fuse_mul_add3( - const ElemwiseOpParamN<3> ¶m) { +template +void ElemwiseForwardImpl::impl_fuse_mul_add3(const ElemwiseOpParamN<3>& param) { kern_fuse_mul_add3( m_dst->ptr(), param, cuda_stream(handle())); } -template -void ElemwiseForwardImpl::impl_fuse_mul_add4( - const ElemwiseOpParamN<4> ¶m) { +template +void ElemwiseForwardImpl::impl_fuse_mul_add4(const ElemwiseOpParamN<4>& param) { kern_fuse_mul_add4(m_dst->ptr(), param, cuda_stream(handle())); } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/elemwise/opr_impl.h b/dnn/src/cuda/elemwise/opr_impl.h index e45e232a..f9b79260 100644 --- a/dnn/src/cuda/elemwise/opr_impl.h +++ b/dnn/src/cuda/elemwise/opr_impl.h @@ -16,12 +16,11 @@ namespace megdnn { namespace cuda { - class ElemwiseForwardImpl final: public ElemwiseForwardImplHelper { +class ElemwiseForwardImpl final : public ElemwiseForwardImplHelper { #include "src/common/elemwise/opr_impl_class_def.inl" - }; +}; -} -} +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/elemwise/special_kerns.cuh b/dnn/src/cuda/elemwise/special_kerns.cuh index 2059c27d..1e012de5 100644 --- a/dnn/src/cuda/elemwise/special_kerns.cuh +++ b/dnn/src/cuda/elemwise/special_kerns.cuh @@ -16,16 +16,15 @@ namespace megdnn { namespace cuda { - template - void kern_fuse_mul_add3(ctype *dest, - const ElemwiseOpParamN<3> ¶m, cudaStream_t stream); +template +void kern_fuse_mul_add3( + ctype* dest, const ElemwiseOpParamN<3>& param, cudaStream_t stream); - template - void kern_fuse_mul_add4(ctype *dest, - const ElemwiseOpParamN<4> ¶m, cudaStream_t stream); +template +void kern_fuse_mul_add4( + ctype* dest, const ElemwiseOpParamN<4>& param, cudaStream_t stream); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/elemwise/special_kerns.inl b/dnn/src/cuda/elemwise/special_kerns.inl index da776d7d..c5cffcdc 100644 --- a/dnn/src/cuda/elemwise/special_kerns.inl +++ b/dnn/src/cuda/elemwise/special_kerns.inl @@ -15,238 +15,227 @@ namespace megdnn { namespace cuda { namespace elemwise_intl { - template - struct FuseMulAdd3Op { - typedef ctype* __restrict bufptr_t; - bufptr_t m_dst, m_src2; - - __device__ __forceinline__ void operator()(uint32_t idx, int off0, - int /* off1 */, ctype x, - ctype y) { - m_dst[idx] = x * y + m_src2[c_is_scalar ? 
0 : off0]; - } - }; - - template - struct FuseMulAdd3Op::value || - std::is_same::value>::type> { - typedef ctype* __restrict bufptr_t; - typedef typename VectTypeTrait::vect_type vect_type; - bufptr_t m_dst, m_src2; - __device__ __forceinline__ void operator()(uint32_t idx, int off0, int, - ctype x, ctype y) { - m_dst[idx] = x * y + m_src2[0]; - } - __device__ __forceinline__ void operator()(int32_t idx, int off0, int, - vect_type x, vect_type y) { - ctype a = x.x * y.x + m_src2[0]; - ctype b = x.y * y.y + m_src2[0]; - ctype g = x.z * y.z + m_src2[0]; - ctype r = x.w * y.w + m_src2[0]; - *(vect_type*)(&m_dst[idx]) = - VectTypeTrait::make_vector(a, b, g, r); - } - }; - - template - struct FuseMulAdd3Op::value || - std::is_same::value>::type> { - typedef ctype* __restrict bufptr_t; - typedef typename VectTypeTrait::vect_type vect_type; - bufptr_t m_dst, m_src2; - __device__ __forceinline__ void operator()(uint32_t idx, int off0, int, - ctype x, ctype y) { - m_dst[idx] = x * y + m_src2[off0]; - } - __device__ __forceinline__ void operator()(int32_t idx, int off0, int, - vect_type x, vect_type y) { - vect_type z = *(vect_type*)(&m_src2[off0]); - ctype a = x.x * y.x + z.x; - ctype b = x.y * y.y + z.y; - ctype g = x.z * y.z + z.z; - ctype r = x.w * y.w + z.w; - *(vect_type*)(&m_dst[idx]) = - VectTypeTrait::make_vector(a, b, g, r); - } - }; +template +struct FuseMulAdd3Op { + typedef ctype* __restrict bufptr_t; + bufptr_t m_dst, m_src2; + + __device__ __forceinline__ void operator()( + uint32_t idx, int off0, int /* off1 */, ctype x, ctype y) { + m_dst[idx] = x * y + m_src2[c_is_scalar ? 0 : off0]; + } +}; + +template +struct FuseMulAdd3Op< + ctype, true, + typename std::enable_if< + std::is_same::value || + std::is_same::value>::type> { + typedef ctype* __restrict bufptr_t; + typedef typename VectTypeTrait::vect_type vect_type; + bufptr_t m_dst, m_src2; + __device__ __forceinline__ void operator()( + uint32_t idx, int off0, int, ctype x, ctype y) { + m_dst[idx] = x * y + m_src2[0]; + } + __device__ __forceinline__ void operator()( + int32_t idx, int off0, int, vect_type x, vect_type y) { + ctype a = x.x * y.x + m_src2[0]; + ctype b = x.y * y.y + m_src2[0]; + ctype g = x.z * y.z + m_src2[0]; + ctype r = x.w * y.w + m_src2[0]; + *(vect_type*)(&m_dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); + } +}; + +template +struct FuseMulAdd3Op< + ctype, false, + typename std::enable_if< + std::is_same::value || + std::is_same::value>::type> { + typedef ctype* __restrict bufptr_t; + typedef typename VectTypeTrait::vect_type vect_type; + bufptr_t m_dst, m_src2; + __device__ __forceinline__ void operator()( + uint32_t idx, int off0, int, ctype x, ctype y) { + m_dst[idx] = x * y + m_src2[off0]; + } + __device__ __forceinline__ void operator()( + int32_t idx, int off0, int, vect_type x, vect_type y) { + vect_type z = *(vect_type*)(&m_src2[off0]); + ctype a = x.x * y.x + z.x; + ctype b = x.y * y.y + z.y; + ctype g = x.z * y.z + z.z; + ctype r = x.w * y.w + z.w; + *(vect_type*)(&m_dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); + } +}; - template - struct FuseMulAdd4Op { - typedef ctype* __restrict bufptr_t; - bufptr_t m_dst, m_src2, m_src3; +template +struct FuseMulAdd4Op { + typedef ctype* __restrict bufptr_t; + bufptr_t m_dst, m_src2, m_src3; - __device__ __forceinline__ void operator()(uint32_t idx, int off0, int off1, - ctype src0, ctype src1) { - m_dst[idx] = src0 * src1 + m_src2[off0] * m_src3[off1]; - } - }; - - template - struct FuseMulAdd4Op::value || - std::is_same::value>::type> { - typedef 
ctype* __restrict bufptr_t; - typedef typename VectTypeTrait::vect_type vect_type; - bufptr_t m_dst, m_src2, m_src3; - __device__ __forceinline__ void operator()(uint32_t idx, int off0, - int off1, ctype x, ctype y) { - m_dst[idx] = x * y + m_src2[off0] * m_src3[off1]; - } - __device__ __forceinline__ void operator()(uint32_t idx, int off0, - int off1, vect_type x, - vect_type y) { - vect_type z = *(vect_type*)(&m_src2[off0]); - vect_type w = *(vect_type*)(&m_src3[off1]); - ctype a = x.x * y.x + z.x * w.x; - ctype b = x.y * y.y + z.y * w.y; - ctype g = x.z * y.z + z.z * w.z; - ctype r = x.w * y.w + z.w * w.w; - *(vect_type*)(&m_dst[idx]) = - VectTypeTrait::make_vector(a, b, g, r); - } - }; - - //! wrap an op so the special OpCaller can be selected by template matching - template - class FuseOpWrapper { - const Op& m_op; - - public: - FuseOpWrapper(const Op& op) : m_op(op) {} - - operator const Op&() const { return m_op; } - }; - - template - struct OpCallerBinary, PVis0, PVis1> { - Op op; - PVis0 par0; - PVis1 par1; - MEGDNN_STATIC_ASSERT(PVis0::packed_size == PVis1::packed_size, - "vector size mismatch"); - static const uint32_t packed_size = PVis0::packed_size; - - __device__ __forceinline__ void thread_init(uint32_t idx) { - idx = idx * packed_size; - par0.thread_init(idx); - par1.thread_init(idx); - } + __device__ __forceinline__ void operator()( + uint32_t idx, int off0, int off1, ctype src0, ctype src1) { + m_dst[idx] = src0 * src1 + m_src2[off0] * m_src3[off1]; + } +}; + +template +struct FuseMulAdd4Op< + ctype, typename std::enable_if< + std::is_same::value || + std::is_same::value>::type> { + typedef ctype* __restrict bufptr_t; + typedef typename VectTypeTrait::vect_type vect_type; + bufptr_t m_dst, m_src2, m_src3; + __device__ __forceinline__ void operator()( + uint32_t idx, int off0, int off1, ctype x, ctype y) { + m_dst[idx] = x * y + m_src2[off0] * m_src3[off1]; + } + __device__ __forceinline__ void operator()( + uint32_t idx, int off0, int off1, vect_type x, vect_type y) { + vect_type z = *(vect_type*)(&m_src2[off0]); + vect_type w = *(vect_type*)(&m_src3[off1]); + ctype a = x.x * y.x + z.x * w.x; + ctype b = x.y * y.y + z.y * w.y; + ctype g = x.z * y.z + z.z * w.z; + ctype r = x.w * y.w + z.w * w.w; + *(vect_type*)(&m_dst[idx]) = VectTypeTrait::make_vector(a, b, g, r); + } +}; + +//! 
wrap an op so the special OpCaller can be selected by template matching +template +class FuseOpWrapper { + const Op& m_op; + +public: + FuseOpWrapper(const Op& op) : m_op(op) {} + + operator const Op&() const { return m_op; } +}; + +template +struct OpCallerBinary, PVis0, PVis1> { + Op op; + PVis0 par0; + PVis1 par1; + MEGDNN_STATIC_ASSERT( + PVis0::packed_size == PVis1::packed_size, "vector size mismatch"); + static const uint32_t packed_size = PVis0::packed_size; + + __device__ __forceinline__ void thread_init(uint32_t idx) { + idx = idx * packed_size; + par0.thread_init(idx); + par1.thread_init(idx); + } - __device__ __forceinline__ void on(uint32_t idx) { - idx = idx * packed_size; - op(idx, par0.offset(idx), par1.offset(idx), par0.at(idx), - par1.at(idx)); - } + __device__ __forceinline__ void on(uint32_t idx) { + idx = idx * packed_size; + op(idx, par0.offset(idx), par1.offset(idx), par0.at(idx), par1.at(idx)); + } - __device__ __forceinline__ void on(uint32_t idx, uint32_t remain) { - idx = idx * packed_size; - if (remain >= packed_size) { - op(idx, par0.offset(idx), par1.offset(idx), par0.at(idx), - par1.at(idx)); - } else { - auto ptr0 = par0.ptr(); - auto ptr1 = par1.ptr(); - for (int i = 0; i < remain; i++) { - op(idx + i, par0.offset(idx + i), par1.offset(idx + i), - ptr0[par0.offset(idx + i)], ptr1[par1.offset(idx + i)]); - } + __device__ __forceinline__ void on(uint32_t idx, uint32_t remain) { + idx = idx * packed_size; + if (remain >= packed_size) { + op(idx, par0.offset(idx), par1.offset(idx), par0.at(idx), par1.at(idx)); + } else { + auto ptr0 = par0.ptr(); + auto ptr1 = par1.ptr(); + for (int i = 0; i < remain; i++) { + op(idx + i, par0.offset(idx + i), par1.offset(idx + i), + ptr0[par0.offset(idx + i)], ptr1[par1.offset(idx + i)]); } } + } - __device__ __forceinline__ void next() { - par0.next(); - par1.next(); - } - }; - - template - struct OpCallerUniform, 2, PVis> { - Op op; - PVis par[2]; - static const uint32_t packed_size = PVis::packed_size; - - __device__ __forceinline__ void thread_init(uint32_t idx) { - idx = idx * packed_size; - par[0].thread_init(idx); - par[1].thread_init(idx); - } + __device__ __forceinline__ void next() { + par0.next(); + par1.next(); + } +}; + +template +struct OpCallerUniform, 2, PVis> { + Op op; + PVis par[2]; + static const uint32_t packed_size = PVis::packed_size; + + __device__ __forceinline__ void thread_init(uint32_t idx) { + idx = idx * packed_size; + par[0].thread_init(idx); + par[1].thread_init(idx); + } - __device__ __forceinline__ void on(uint32_t idx) { - idx = idx * packed_size; + __device__ __forceinline__ void on(uint32_t idx) { + idx = idx * packed_size; + op(idx, par[0].offset(idx), par[1].offset(idx), par[0].at(idx), par[1].at(idx)); + } + + __device__ __forceinline__ void on(uint32_t idx, uint32_t remain) { + idx = idx * packed_size; + if (remain >= packed_size) { op(idx, par[0].offset(idx), par[1].offset(idx), par[0].at(idx), par[1].at(idx)); - } - - __device__ __forceinline__ void on(uint32_t idx, uint32_t remain) { - idx = idx * packed_size; - if (remain >= packed_size) { - op(idx, par[0].offset(idx), par[1].offset(idx), par[0].at(idx), - par[1].at(idx)); - } else { - auto ptr0 = par[0].ptr(); - auto ptr1 = par[1].ptr(); - for (int i = 0; i < remain; i++) { - op(idx + i, par[0].offset(idx + i), par[1].offset(idx + i), - ptr0[par[0].offset(idx + i)], - ptr1[par[1].offset(idx + i)]); - } + } else { + auto ptr0 = par[0].ptr(); + auto ptr1 = par[1].ptr(); + for (int i = 0; i < remain; i++) { + op(idx + i, 
par[0].offset(idx + i), par[1].offset(idx + i), + ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)]); } } + } - __device__ __forceinline__ void next() { - par[0].next(); - par[1].next(); - } - }; + __device__ __forceinline__ void next() { + par[0].next(); + par[1].next(); + } +}; } // namespace elemwise_intl namespace { - template - void run_fuse_elemwise(Op& op, const ElemwiseOpParamN& param, - cudaStream_t stream) { - param.assert_initialized(); - ElemwiseOpParamN<2> p2 = *static_cast*>( - static_cast(¶m)); - elemwise_intl::UserOpInvoker, ctype, 2>( - p2, stream, op); - } +template +void run_fuse_elemwise( + Op& op, const ElemwiseOpParamN& param, cudaStream_t stream) { + param.assert_initialized(); + ElemwiseOpParamN<2> p2 = + *static_cast*>(static_cast(¶m)); + elemwise_intl::UserOpInvoker, ctype, 2>( + p2, stream, op); +} } // anonymous namespace - template - void kern_fuse_mul_add3(ctype* dest, const ElemwiseOpParamN<3>& param, - cudaStream_t stream) { - elemwise_intl::FuseMulAdd3Op op; - op.m_dst = dest; - op.m_src2 = param[2].ptr(); - run_fuse_elemwise(op, param, stream); - } - - template - void kern_fuse_mul_add4(ctype* dest, const ElemwiseOpParamN<4>& param, - cudaStream_t stream) { - elemwise_intl::FuseMulAdd4Op op; - op.m_dst = dest; - op.m_src2 = param[2].ptr(); - op.m_src3 = param[3].ptr(); - run_fuse_elemwise(op, param, stream); - } - -#define INST(_dt) \ - template void kern_fuse_mul_add3(DTypeTrait<_dt>::ctype*, \ - const ElemwiseOpParamN<3>&, \ - cudaStream_t); \ - template void kern_fuse_mul_add3(DTypeTrait<_dt>::ctype*, \ - const ElemwiseOpParamN<3>&, \ - cudaStream_t); \ - template void kern_fuse_mul_add4(DTypeTrait<_dt>::ctype*, \ - const ElemwiseOpParamN<4>&, \ - cudaStream_t); +template +void kern_fuse_mul_add3( + ctype* dest, const ElemwiseOpParamN<3>& param, cudaStream_t stream) { + elemwise_intl::FuseMulAdd3Op op; + op.m_dst = dest; + op.m_src2 = param[2].ptr(); + run_fuse_elemwise(op, param, stream); +} + +template +void kern_fuse_mul_add4( + ctype* dest, const ElemwiseOpParamN<4>& param, cudaStream_t stream) { + elemwise_intl::FuseMulAdd4Op op; + op.m_dst = dest; + op.m_src2 = param[2].ptr(); + op.m_src3 = param[3].ptr(); + run_fuse_elemwise(op, param, stream); +} + +#define INST(_dt) \ + template void kern_fuse_mul_add3( \ + DTypeTrait<_dt>::ctype*, const ElemwiseOpParamN<3>&, cudaStream_t); \ + template void kern_fuse_mul_add3( \ + DTypeTrait<_dt>::ctype*, const ElemwiseOpParamN<3>&, cudaStream_t); \ + template void kern_fuse_mul_add4( \ + DTypeTrait<_dt>::ctype*, const ElemwiseOpParamN<4>&, cudaStream_t); // vim: ft=cuda syntax=cpp.doxygen - diff --git a/dnn/src/cuda/elemwise_helper.cpp b/dnn/src/cuda/elemwise_helper.cpp index 6c23323e..a8c35fe9 100644 --- a/dnn/src/cuda/elemwise_helper.cpp +++ b/dnn/src/cuda/elemwise_helper.cpp @@ -21,8 +21,7 @@ #include #define _cb_check_ndim(n) megdnn::TensorShape::MAX_NDIM == n || -static_assert(MEGDNN_FOREACH_TENSOR_NDIM(_cb_check_ndim) false, - "bad foreach ndim"); +static_assert(MEGDNN_FOREACH_TENSOR_NDIM(_cb_check_ndim) false, "bad foreach ndim"); #undef _cb_check_ndim namespace megdnn { @@ -54,10 +53,8 @@ void ParamVisitorBase::host_init( #pragma GCC diagnostic pop template -void ParamVisitorBase<3, ctype, BCAST_101>::host_init(const TensorND& rv, - int grid_size, - int block_size, - int packed_size) { +void ParamVisitorBase<3, ctype, BCAST_101>::host_init( + const TensorND& rv, int grid_size, int block_size, int packed_size) { uint32_t shape2, shape1; int stride1; if (rv.layout.ndim == 3) { @@ -77,52 +74,43 
@@ void ParamVisitorBase<3, ctype, BCAST_101>::host_init(const TensorND& rv, } template -void ParamVisitorBase<2, ctype, BCAST_10>::host_init(const TensorND& rv, - int grid_size, - int block_size, - int packed_size) { +void ParamVisitorBase<2, ctype, BCAST_10>::host_init( + const TensorND& rv, int grid_size, int block_size, int packed_size) { megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]); m_ptr = rv.ptr(); m_stride1 = rv.layout.stride[1]; - m_shape1.host_init(packed_size * grid_size * block_size, - rv.layout.shape[1]); + m_shape1.host_init(packed_size * grid_size * block_size, rv.layout.shape[1]); } template -void ParamVisitorBase<2, ctype, BCAST_01>::host_init(const TensorND& rv, - int grid_size, - int block_size, - int packed_size) { +void ParamVisitorBase<2, ctype, BCAST_01>::host_init( + const TensorND& rv, int grid_size, int block_size, int packed_size) { megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[1]); m_ptr = rv.ptr(); m_stride0 = rv.layout.stride[0]; - m_shape1.host_init(packed_size * grid_size * block_size, - rv.layout.shape[1]); + m_shape1.host_init(packed_size * grid_size * block_size, rv.layout.shape[1]); } template -void ParamVisitorBase<1, ctype, BCAST_FULL>::host_init(const TensorND& rv, - int /*grid_size*/, - int /*block_size*/, - int /*packed_size*/) { +void ParamVisitorBase<1, ctype, BCAST_FULL>::host_init( + const TensorND& rv, int /*grid_size*/, int /*block_size*/, + int /*packed_size*/) { megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]); m_ptr = rv.ptr(); } template -void ParamVectVisitor<4, ctype, BCAST_1010>::host_init(const TensorND& rv, - int grid_size, - int block_size) { - megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0] && - !rv.layout.stride[2]); +void ParamVectVisitor<4, ctype, BCAST_1010>::host_init( + const TensorND& rv, int grid_size, int block_size) { + megdnn_assert( + rv.layout.ndim == NDIM && !rv.layout.stride[0] && !rv.layout.stride[2]); m_ptr = rv.ptr(); m_stride1 = rv.layout.stride[1]; m_stride3 = rv.layout.stride[3]; uint32_t shape1 = rv.layout.shape[1]; uint32_t shape2 = rv.layout.shape[2]; uint32_t shape3 = rv.layout.shape[3]; - m_shape123.host_init(packed_size * grid_size * block_size, shape2 * shape3, - shape1); + m_shape123.host_init(packed_size * grid_size * block_size, shape2 * shape3, shape1); m_shape3.host_init(packed_size * grid_size * block_size, shape3); } @@ -271,39 +259,37 @@ void ParamElemVisitor4bitBase::host_init( m_is_physical_contiguous = rv.layout.is_physical_contiguous(); } -#define ndim_cb(_ndim) \ - template class ParamElemVisitor4bitBase<_ndim, BCAST_OTHER>; +#define ndim_cb(_ndim) template class ParamElemVisitor4bitBase<_ndim, BCAST_OTHER>; MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) #undef ndim_cb } // namespace elemwise_intl -void elemwise_intl::get_launch_spec(const void* kern, size_t size, - int* grid_size, int* block_size) { - safe_size_in_kern(size); - auto config = query_launch_config_for_kernel(kern); - *block_size = config.block_size; - int a = size / (config.block_size * 2), - b = (size - 1) / (config.block_size * 3) + 1; - if (current_device_prop().major <= 3) { - // for Kepler, less blocks (more work per thread) is faster - *grid_size = b; - } else { - *grid_size = std::max(a, b); - } - if (!*grid_size) { - *block_size = std::min(std::max(size / 64, 1) * 32, 1024); - *grid_size = std::max(size / *block_size, 1); - } - // because we unroll 3 times in the kernel - megdnn_assert(static_cast(*block_size) * *grid_size * 3 >= - size); +void elemwise_intl::get_launch_spec( 
+ const void* kern, size_t size, int* grid_size, int* block_size) { + safe_size_in_kern(size); + auto config = query_launch_config_for_kernel(kern); + *block_size = config.block_size; + int a = size / (config.block_size * 2), + b = (size - 1) / (config.block_size * 3) + 1; + if (current_device_prop().major <= 3) { + // for Kepler, less blocks (more work per thread) is faster + *grid_size = b; + } else { + *grid_size = std::max(a, b); } - - void elemwise_intl::on_bad_ndim(int ndim) { - megdnn_throw(ssprintf("invalid ndim: %d", ndim)); - MEGDNN_MARK_USED_VAR(ndim); + if (!*grid_size) { + *block_size = std::min(std::max(size / 64, 1) * 32, 1024); + *grid_size = std::max(size / *block_size, 1); } + // because we unroll 3 times in the kernel + megdnn_assert(static_cast(*block_size) * *grid_size * 3 >= size); +} + +void elemwise_intl::on_bad_ndim(int ndim) { + megdnn_throw(ssprintf("invalid ndim: %d", ndim)); + MEGDNN_MARK_USED_VAR(ndim); +} } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/elemwise_helper.cuh b/dnn/src/cuda/elemwise_helper.cuh index 77c55536..c71b5f2e 100644 --- a/dnn/src/cuda/elemwise_helper.cuh +++ b/dnn/src/cuda/elemwise_helper.cuh @@ -14,9 +14,9 @@ #include "src/common/elemwise_helper.cuh" #include "src/cuda/int_fastdiv.cuh" +#include "src/cuda/integer_subbyte_utils.cuh" #include "src/cuda/query_blocksize.cuh" #include "src/cuda/utils.cuh" -#include "src/cuda/integer_subbyte_utils.cuh" /* * please note that all arithmetics on GPU are 32-bit for best performance; this @@ -35,8 +35,7 @@ namespace elemwise_intl { * \param kern kernel function address * \param size total size of elements */ -void get_launch_spec(const void* kern, size_t size, int* grid_size, - int* block_size); +void get_launch_spec(const void* kern, size_t size, int* grid_size, int* block_size); MEGDNN_NORETURN void on_bad_ndim(int ndim); @@ -44,14 +43,7 @@ MEGDNN_NORETURN void on_bad_ndim(int ndim); * \brief broadcast type * BCAST_x[0]x[1]...: x[i] == !stride[i] */ -enum BcastType { - BCAST_OTHER, - BCAST_1010, - BCAST_101, - BCAST_10, - BCAST_01, - BCAST_FULL -}; +enum BcastType { BCAST_OTHER, BCAST_1010, BCAST_101, BCAST_10, BCAST_01, BCAST_FULL }; /*! 
* \brief read and write type trait for byte width integer type @@ -63,8 +55,8 @@ struct ATTR_ALIGNED(8) half4 { dt_float16 x, y, z, w; }; -__device__ __forceinline__ half4 make_half4(dt_float16 x, dt_float16 y, - dt_float16 z, dt_float16 w) { +__device__ __forceinline__ half4 +make_half4(dt_float16 x, dt_float16 y, dt_float16 z, dt_float16 w) { half4 t; t.x = x, t.y = y, t.z = z, t.w = w; return t; @@ -74,26 +66,23 @@ struct ATTR_ALIGNED(8) bhalf4 { dt_bfloat16 x, y, z, w; }; -__device__ __forceinline__ bhalf4 make_bhalf4(dt_bfloat16 x, dt_bfloat16 y, - dt_bfloat16 z, dt_bfloat16 w) { +__device__ __forceinline__ bhalf4 +make_bhalf4(dt_bfloat16 x, dt_bfloat16 y, dt_bfloat16 z, dt_bfloat16 w) { bhalf4 t; t.x = x, t.y = y, t.z = z, t.w = w; return t; } -#define INST(_ctype, _vect_type) \ - template <> \ - class VectTypeTrait<_ctype> { \ - public: \ - using vect_type = _vect_type; \ - static const size_t packed_size = sizeof(_vect_type) / sizeof(_ctype); \ - static __device__ __forceinline__ vect_type make_vector(_ctype x, \ - _ctype y, \ - _ctype z, \ - _ctype w) { \ - return make_##_vect_type(as_raw(x), as_raw(y), as_raw(z), \ - as_raw(w)); \ - } \ +#define INST(_ctype, _vect_type) \ + template <> \ + class VectTypeTrait<_ctype> { \ + public: \ + using vect_type = _vect_type; \ + static const size_t packed_size = sizeof(_vect_type) / sizeof(_ctype); \ + static __device__ __forceinline__ vect_type \ + make_vector(_ctype x, _ctype y, _ctype z, _ctype w) { \ + return make_##_vect_type(as_raw(x), as_raw(y), as_raw(z), as_raw(w)); \ + } \ } #define as_raw(x) x INST(dt_int8, char4); @@ -124,21 +113,21 @@ struct uint4bx2 { uint8_t x; }; -#define INST(_ctype, _Storage, _vect_type) \ - template <> \ - class VectTypeTrait<_ctype> { \ - public: \ - using Storage = _Storage; \ - static const Storage kMask = 0xf; \ - static const Storage kBits = 4; \ - using vect_type = _vect_type; \ - static const size_t packed_size = 2; \ - static __device__ __forceinline__ vect_type make_vector(Storage x, \ - Storage y) { \ - vect_type t; \ - t.x = (x & kMask) | (y << kBits); \ - return t; \ - } \ +#define INST(_ctype, _Storage, _vect_type) \ + template <> \ + class VectTypeTrait<_ctype> { \ + public: \ + using Storage = _Storage; \ + static const Storage kMask = 0xf; \ + static const Storage kBits = 4; \ + using vect_type = _vect_type; \ + static const size_t packed_size = 2; \ + static __device__ __forceinline__ vect_type \ + make_vector(Storage x, Storage y) { \ + vect_type t; \ + t.x = (x & kMask) | (y << kBits); \ + return t; \ + } \ } INST(dt_qint4, int8_t, int4bx2); INST(dt_quint4, uint8_t, uint4bx2); @@ -218,8 +207,7 @@ protected: public: static const int NDIM = ndim; - void host_init(const TensorND& rv, int grid_size, int block_size, - int packed_size); + void host_init(const TensorND& rv, int grid_size, int block_size, int packed_size); #if MEGDNN_CC_CUDA devfunc void thread_init(uint32_t) {} @@ -272,8 +260,7 @@ protected: public: static const int NDIM = 3; - void host_init(const TensorND& rv, int grid_size, int block_size, - int packed_size); + void host_init(const TensorND& rv, int grid_size, int block_size, int packed_size); #if MEGDNN_CC_CUDA devfunc void thread_init(uint32_t idx) { m_shape12.device_init(idx); } @@ -314,8 +301,7 @@ protected: public: static const int NDIM = 2; - void host_init(const TensorND& rv, int grid_size, int block_size, - int packed_size); + void host_init(const TensorND& rv, int grid_size, int block_size, int packed_size); #if MEGDNN_CC_CUDA devfunc void thread_init(uint32_t idx) 
{ m_shape1.device_init(idx); } @@ -356,8 +342,7 @@ protected: public: static const int NDIM = 2; - void host_init(const TensorND& rv, int grid_size, int block_size, - int packed_size); + void host_init(const TensorND& rv, int grid_size, int block_size, int packed_size); #if MEGDNN_CC_CUDA devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); } @@ -392,8 +377,7 @@ public: static const int NDIM = 1; PARAM_ELEM_VISITOR_COMMON_HOST - void host_init(const TensorND& rv, int grid_size, int block_size, - int packed_size); + void host_init(const TensorND& rv, int grid_size, int block_size, int packed_size); #if MEGDNN_CC_CUDA devfunc void thread_init(uint32_t) {} @@ -461,25 +445,23 @@ INST_PARAM_VECT_VISITOR; #define _brdcast_mask BCAST_101 INST_PARAM_VECT_VISITOR; #undef _brdcast_mask -#define INST_DT_IBYTE(ctype) \ - template \ - class ParamVectVisitor \ - : public ParamVisitorBase { \ - public: \ - using Super = ParamVisitorBase; \ - using rwtype = typename VectTypeTrait::vect_type; \ - static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ - void host_init(const TensorND& rv, int grid_size, int block_size) { \ - ParamVisitorBase::host_init( \ - rv, grid_size, block_size, packed_size); \ - } \ - DEVICE_WRAPPER(rwtype vect_scalar; \ - devfunc rwtype & at(uint32_t /* idx */) { \ - ctype v = Super::m_ptr[0]; \ - vect_scalar = VectTypeTrait::make_vector( \ - v, v, v, v); \ - return vect_scalar; \ - }) \ +#define INST_DT_IBYTE(ctype) \ + template \ + class ParamVectVisitor \ + : public ParamVisitorBase { \ + public: \ + using Super = ParamVisitorBase; \ + using rwtype = typename VectTypeTrait::vect_type; \ + static const int packed_size = sizeof(rwtype) / sizeof(ctype); \ + void host_init(const TensorND& rv, int grid_size, int block_size) { \ + ParamVisitorBase::host_init( \ + rv, grid_size, block_size, packed_size); \ + } \ + DEVICE_WRAPPER(rwtype vect_scalar; devfunc rwtype & at(uint32_t /* idx */) { \ + ctype v = Super::m_ptr[0]; \ + vect_scalar = VectTypeTrait::make_vector(v, v, v, v); \ + return vect_scalar; \ + }) \ } INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_uint8); @@ -564,8 +546,7 @@ public: devfunc void next() {} - devfunc void get_shape_from_access(uint32_t access_idx, - int (&shape_idx)[ndim]) { + devfunc void get_shape_from_access(uint32_t access_idx, int (&shape_idx)[ndim]) { #pragma unroll for (int i = ndim - 1; i >= 1; --i) { Uint32Fastdiv& align_shp = m_align_shape_highdim[i - 1]; @@ -802,8 +783,7 @@ struct OpCallerUniform { auto ptr0 = par[0].ptr(); auto ptr1 = par[1].ptr(); for (int i = 0; i < remain; i++) { - op(idx + i, ptr0[par[0].offset(idx + i)], - ptr1[par[1].offset(idx + i)]); + op(idx + i, ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)]); } } } @@ -841,8 +821,8 @@ struct OpCallerUniform { auto ptr1 = par[1].ptr(); auto ptr2 = par[2].ptr(); for (int i = 0; i < remain; i++) { - op(idx + i, ptr0[par[0].offset(idx + i)], - ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)]); + op(idx + i, ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)], + ptr2[par[2].offset(idx + i)]); } } } @@ -877,17 +857,15 @@ struct OpCallerUniform { devfunc void on(uint32_t idx, uint32_t remain) { idx = idx * packed_size; if (remain >= packed_size) { - op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), - par[3].at(idx)); + op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx)); } else { auto ptr0 = par[0].ptr(); auto ptr1 = par[1].ptr(); auto ptr2 = par[2].ptr(); auto ptr3 = par[3].ptr(); for (int i = 0; i < remain; i++) { - op(idx 
+ i, ptr0[par[0].offset(idx + i)], - ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], - ptr3[par[3].offset(idx + i)]); + op(idx + i, ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)], + ptr2[par[2].offset(idx + i)], ptr3[par[3].offset(idx + i)]); } } } @@ -925,8 +903,8 @@ struct OpCallerUniform { devfunc void on(uint32_t idx, uint32_t remain) { idx = idx * packed_size; if (remain >= packed_size) { - op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), - par[3].at(idx), par[4].at(idx)); + op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), + par[4].at(idx)); } else { auto ptr0 = par[0].ptr(); auto ptr1 = par[1].ptr(); @@ -934,9 +912,9 @@ struct OpCallerUniform { auto ptr3 = par[3].ptr(); auto ptr4 = par[4].ptr(); for (int i = 0; i < remain; i++) { - op(idx + i, ptr0[par[0].offset(idx + i)], - ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], - ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)]); + op(idx + i, ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)], + ptr2[par[2].offset(idx + i)], ptr3[par[3].offset(idx + i)], + ptr4[par[4].offset(idx + i)]); } } } @@ -950,7 +928,6 @@ struct OpCallerUniform { } }; - //! specialization for arity == 6 template struct OpCallerUniform { @@ -977,8 +954,8 @@ struct OpCallerUniform { devfunc void on(uint32_t idx, uint32_t remain) { idx = idx * packed_size; if (remain >= packed_size) { - op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), - par[3].at(idx), par[4].at(idx), par[5].at(idx)); + op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), + par[4].at(idx), par[5].at(idx)); } else { auto ptr0 = par[0].ptr(); auto ptr1 = par[1].ptr(); @@ -987,10 +964,9 @@ struct OpCallerUniform { auto ptr4 = par[4].ptr(); auto ptr5 = par[5].ptr(); for (int i = 0; i < remain; i++) { - op(idx + i, ptr0[par[0].offset(idx + i)], - ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], - ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], - ptr5[par[5].offset(idx + i)]); + op(idx + i, ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)], + ptr2[par[2].offset(idx + i)], ptr3[par[3].offset(idx + i)], + ptr4[par[4].offset(idx + i)], ptr5[par[5].offset(idx + i)]); } } } @@ -1005,7 +981,6 @@ struct OpCallerUniform { } }; - //! 
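// The OpCallerUniform::on() overloads above and below (arity 2 through 7) all
// share one shape: if a full vector of packed_size elements remains, the packed
// at() accessors are used; otherwise the tail is handled element by element
// through ptr()/offset(). A condensed sketch for arity == 2, with Op and PVis
// as placeholders for the functor and param-visitor types used in this header:
template <class Op, class PVis>
__device__ void on_arity2_sketch(Op& op, PVis (&par)[2], uint32_t idx, uint32_t remain) {
    const uint32_t packed_size = PVis::packed_size;
    idx = idx * packed_size;
    if (remain >= packed_size) {
        // a whole vector is available: use the packed accessors
        op(idx, par[0].at(idx), par[1].at(idx));
    } else {
        // partial tail: fall back to scalar loads via raw pointer + offset
        auto ptr0 = par[0].ptr();
        auto ptr1 = par[1].ptr();
        for (uint32_t i = 0; i < remain; i++) {
            op(idx + i, ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)]);
        }
    }
}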
specialization for arity == 7 template struct OpCallerUniform { @@ -1033,8 +1008,8 @@ struct OpCallerUniform { devfunc void on(uint32_t idx, uint32_t remain) { idx = idx * packed_size; if (remain >= packed_size) { - op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), - par[3].at(idx), par[4].at(idx), par[5].at(idx), par[6].at(idx)); + op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), + par[4].at(idx), par[5].at(idx), par[6].at(idx)); } else { auto ptr0 = par[0].ptr(); auto ptr1 = par[1].ptr(); @@ -1044,10 +1019,10 @@ struct OpCallerUniform { auto ptr5 = par[5].ptr(); auto ptr6 = par[6].ptr(); for (int i = 0; i < remain; i++) { - op(idx + i, ptr0[par[0].offset(idx + i)], - ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], - ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], - ptr5[par[5].offset(idx + i)], ptr6[par[6].offset(idx + i)]); + op(idx + i, ptr0[par[0].offset(idx + i)], ptr1[par[1].offset(idx + i)], + ptr2[par[2].offset(idx + i)], ptr3[par[3].offset(idx + i)], + ptr4[par[4].offset(idx + i)], ptr5[par[5].offset(idx + i)], + ptr6[par[6].offset(idx + i)]); } } } @@ -1072,8 +1047,8 @@ struct OpCallerBinary { Op op; PVis0 par0; PVis1 par1; - MEGDNN_STATIC_ASSERT(PVis0::packed_size == PVis1::packed_size, - "vector size mismatch") + MEGDNN_STATIC_ASSERT( + PVis0::packed_size == PVis1::packed_size, "vector size mismatch") static const uint32_t packed_size = PVis0::packed_size; @@ -1118,8 +1093,7 @@ __global__ void cuda_kern(OpCaller op_caller, uint32_t size) { } template -__global__ void cuda_kern(OpCallerUniform op_caller, - uint32_t size) { +__global__ void cuda_kern(OpCallerUniform op_caller, uint32_t size) { constexpr uint32_t packed_size = PVis::packed_size; const uint32_t size_packed = DIVUP(size, packed_size); uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x, @@ -1163,14 +1137,13 @@ class UserOpInvokerToSameNdim { template void dispatch1() { - typedef OpCallerUniform> + typedef OpCallerUniform> Caller; size_t size = m_param.size; int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; @@ -1181,8 +1154,8 @@ class UserOpInvokerToSameNdim { } public: - UserOpInvokerToSameNdim(const ElemwiseOpParamN& param, - cudaStream_t stream, const Op& op) + UserOpInvokerToSameNdim( + const ElemwiseOpParamN& param, cudaStream_t stream, const Op& op) : m_param(param), m_stream(stream), m_op(op) { dispatch0(); } @@ -1191,12 +1164,9 @@ public: template class UserOpInvokerToSameNdimIByteHelper { public: - UserOpInvokerToSameNdimIByteHelper(const ElemwiseOpParamN& param, - cudaStream_t stream, const Op& op) - : m_rw_size(param.size), - m_param(param), - m_stream(stream), - m_op(op) { + UserOpInvokerToSameNdimIByteHelper( + const ElemwiseOpParamN& param, cudaStream_t stream, const Op& op) + : m_rw_size(param.size), m_param(param), m_stream(stream), m_op(op) { if (!try_vect_load_store_contiguous() && !try_vect_load_store()) { dispatch0(); } @@ -1238,8 +1208,8 @@ private: size_t size = m_rw_size; int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; @@ -1256,8 +1226,8 @@ private: size_t size = m_rw_size; int grid_size, block_size; void (*fptr)(Caller, uint32_t) = 
cuda_kern; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; for (int i = 0; i < arity; ++i) @@ -1273,8 +1243,8 @@ private: size_t size = m_rw_size; int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; for (int i = 0; i < arity; ++i) @@ -1321,8 +1291,9 @@ private: using Super = UserOpInvokerToSameNdimIByteHelper; \ \ public: \ - UserOpInvokerToSameNdim(const ElemwiseOpParamN& param, \ - cudaStream_t stream, const Op& op) \ + UserOpInvokerToSameNdim( \ + const ElemwiseOpParamN& param, cudaStream_t stream, \ + const Op& op) \ : Super{param, stream, op} {} \ } INST_DT_IBYTE(dt_int8); @@ -1336,8 +1307,8 @@ INST_DT_IBYTE(dt_bool); template class UserOpInvoker : public UserOpInvokerToSameNdim { public: - UserOpInvoker(const ElemwiseOpParamN& param, cudaStream_t stream, - const Op& op) + UserOpInvoker( + const ElemwiseOpParamN& param, cudaStream_t stream, const Op& op) : UserOpInvokerToSameNdim(param, stream, op) {} }; @@ -1345,16 +1316,15 @@ public: template class UserOpInvoker { public: - UserOpInvoker(const ElemwiseOpParamN<0>& param, cudaStream_t stream, - const Op& op) { + UserOpInvoker(const ElemwiseOpParamN<0>& param, cudaStream_t stream, const Op& op) { size_t size = param.size; typedef OpCallerNull Caller; Caller caller; caller.op = op; int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); (*fptr)<<>>(caller, size); after_kernel_launch(); } @@ -1414,8 +1384,7 @@ class UserOpInvoker { #define cb_header(ndim) void dispatch1_##ndim() #define cb_dispatch(ndim, brdcast_mask) \ dispatch2>() - DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, - m_param[0].layout.stride) + DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, m_param[0].layout.stride) #undef cb_header #undef cb_dispatch @@ -1436,8 +1405,7 @@ class UserOpInvoker { void dispatch3_##ndim() #define cb_dispatch(ndim, brdcast_mask) \ do_run>() - DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, - m_param[1].layout.stride) + DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, m_param[1].layout.stride) #undef cb_header #undef cb_dispatch @@ -1449,8 +1417,8 @@ class UserOpInvoker { int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern; size_t size = m_param.size; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; caller.par0.host_init(m_param[0], grid_size, block_size); @@ -1460,8 +1428,7 @@ class UserOpInvoker { } public: - UserOpInvoker(const ElemwiseOpParamN<2>& param, cudaStream_t stream, - const Op& op) + UserOpInvoker(const ElemwiseOpParamN<2>& param, cudaStream_t stream, const Op& op) : m_param(param), m_stream(stream), m_op(op) { m_invoked = false; dispatch0(); @@ -1469,29 +1436,27 @@ public: } }; -#define INST_DT_TYPE(ctype) \ - template \ - class UserOpInvoker \ - : public UserOpInvokerToSameNdim { \ - public: \ - UserOpInvoker(const ElemwiseOpParamN<2>& param, cudaStream_t stream, \ - const Op& op) \ - : UserOpInvokerToSameNdim(param, stream, op) {} \ 
+#define INST_DT_TYPE(ctype) \ + template \ + class UserOpInvoker : public UserOpInvokerToSameNdim { \ + public: \ + UserOpInvoker( \ + const ElemwiseOpParamN<2>& param, cudaStream_t stream, const Op& op) \ + : UserOpInvokerToSameNdim(param, stream, op) {} \ } INST_DT_TYPE(dt_qint4); INST_DT_TYPE(dt_quint4); #undef INST_DT_TYPE -#define DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, \ - _stride) \ - DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \ - _cb_header(4) { \ - const ptrdiff_t* stride = _stride; \ - if (!stride[0] && stride[1] && !stride[2] && stride[3]) { \ - return _cb_dispatch(4, BCAST_1010); \ - } \ - _cb_dispatch(4, BCAST_OTHER); \ +#define DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \ + DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \ + _cb_header(4) { \ + const ptrdiff_t* stride = _stride; \ + if (!stride[0] && stride[1] && !stride[2] && stride[3]) { \ + return _cb_dispatch(4, BCAST_1010); \ + } \ + _cb_dispatch(4, BCAST_OTHER); \ } template @@ -1563,8 +1528,8 @@ private: size_t size = m_rw_size; int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; @@ -1583,16 +1548,15 @@ private: #define cb_header(ndim) void dispatch1_##ndim() #define cb_dispatch(ndim, brdcast_mask) \ dispatch2>() - DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, - m_param[0].layout.stride) + DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, m_param[0].layout.stride) #undef cb_header #undef cb_dispatch #define cb_header(ndim) void dispatch1_vect_##ndim() #define cb_dispatch(ndim, brdcast_mask) \ dispatch2_vect>() - DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, - m_param[0].layout.stride) + DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS( + cb_header, cb_dispatch, m_param[0].layout.stride) #undef cb_header #undef cb_dispatch @@ -1627,8 +1591,7 @@ private: void dispatch3_##ndim() #define cb_dispatch(ndim, brdcast_mask) \ do_run>() - DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, - m_param[1].layout.stride) + DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, m_param[1].layout.stride) #undef cb_header #undef cb_dispatch @@ -1637,8 +1600,8 @@ private: void dispatch3_vect_##ndim() #define cb_dispatch(ndim, brdcast_mask) \ do_run>() - DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch, - m_param[1].layout.stride) + DEFINE_VECT_BRDCAST_DISPATCH_RECEIVERS( + cb_header, cb_dispatch, m_param[1].layout.stride) #undef cb_header #undef cb_dispatch @@ -1650,8 +1613,8 @@ private: int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern; size_t size = m_rw_size; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; caller.par0.host_init(m_param[0], grid_size, block_size); @@ -1661,12 +1624,9 @@ private: } public: - UserOpInvokerBinaryIByteHelper(const ElemwiseOpParamN<2>& param, - cudaStream_t stream, const Op& op) - : m_rw_size(param.size), - m_param(param), - m_stream(stream), - m_op(op) { + UserOpInvokerBinaryIByteHelper( + const ElemwiseOpParamN<2>& param, cudaStream_t stream, const Op& op) + : m_rw_size(param.size), m_param(param), m_stream(stream), m_op(op) { m_invoked = false; if (!try_vect_load_store_contiguous() && 
!try_vect_load_store()) { dispatch0(); @@ -1675,16 +1635,16 @@ public: } }; -#define INST_DT_IBYTE(ctype) \ - template \ - class UserOpInvoker \ - : public UserOpInvokerBinaryIByteHelper { \ - using Super = UserOpInvokerBinaryIByteHelper; \ - \ - public: \ - UserOpInvoker(const ElemwiseOpParamN<2>& param, cudaStream_t stream, \ - const Op& op) \ - : Super{param, stream, op} {} \ +#define INST_DT_IBYTE(ctype) \ + template \ + class UserOpInvoker \ + : public UserOpInvokerBinaryIByteHelper { \ + using Super = UserOpInvokerBinaryIByteHelper; \ + \ + public: \ + UserOpInvoker( \ + const ElemwiseOpParamN<2>& param, cudaStream_t stream, const Op& op) \ + : Super{param, stream, op} {} \ } INST_DT_IBYTE(dt_int8); INST_DT_IBYTE(dt_uint8); @@ -1718,13 +1678,13 @@ INST_DT_IBYTE(dt_bool); * should be implemented */ template -void run_elemwise(const ElemwiseOpParamN& param, cudaStream_t stream, - const Op& op = Op()); +void run_elemwise( + const ElemwiseOpParamN& param, cudaStream_t stream, const Op& op = Op()); #if MEGDNN_CC_CUDA template -void run_elemwise(const ElemwiseOpParamN& param, cudaStream_t stream, - const Op& op) { +void run_elemwise( + const ElemwiseOpParamN& param, cudaStream_t stream, const Op& op) { param.assert_initialized(); elemwise_intl::UserOpInvoker(param, stream, op); } diff --git a/dnn/src/cuda/elemwise_helper_q4.cuh b/dnn/src/cuda/elemwise_helper_q4.cuh index af21f725..1bf3f1df 100644 --- a/dnn/src/cuda/elemwise_helper_q4.cuh +++ b/dnn/src/cuda/elemwise_helper_q4.cuh @@ -24,14 +24,16 @@ namespace cuda { template struct IsNotTypeQ4 { - static constexpr bool value = !(std::is_same::value || - std::is_same::value); + static constexpr bool value = + !(std::is_same::value || + std::is_same::value); }; template struct IsTypeQ4 { - static constexpr bool value = (std::is_same::value || - std::is_same::value); + static constexpr bool value = + (std::is_same::value || + std::is_same::value); }; //! internals for element-wise @@ -111,9 +113,7 @@ struct OpCallerToQ4 { PVisSrc par_src[1]; PVisDst par_dst[1]; - devfunc void on(uint32_t access_idx) { - op(access_idx, par_src[0].at(access_idx)); - } + devfunc void on(uint32_t access_idx) { op(access_idx, par_src[0].at(access_idx)); } }; //! specialization for arity == 2 template @@ -161,8 +161,7 @@ __global__ void cuda_kern_q4(OpCaller op_caller, uint32_t size) { /* f{{{ UserOpInvoker specializations */ //! 
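// A sketch of what the IsTypeQ4 / IsNotTypeQ4 helpers above presumably expand
// to once their elided template arguments are restored: they test whether a
// ctype is one of the 4-bit quantized dtypes. The *_Sketch names are
// hypothetical; that the comparisons are against dt_qint4 / dt_quint4 is an
// assumption based on the trait names and the enable_if specializations they gate.
#include <type_traits>

template <typename ctype>
struct IsTypeQ4Sketch {
    static constexpr bool value = std::is_same<ctype, dt_qint4>::value ||
                                  std::is_same<ctype, dt_quint4>::value;
};

template <typename ctype>
struct IsNotTypeQ4Sketch {
    static constexpr bool value = !IsTypeQ4Sketch<ctype>::value;
};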
run op by promoting all params to same ndim -template +template class UserOpInvokerQ4 { const ElemwiseOpParamN& m_src_param; const ElemwiseOpParamN<1>& m_dst_param; @@ -186,16 +185,16 @@ class UserOpInvokerQ4 { BetweenQ4, ParamVectVisitor, ParamElemVisitor>::type; - typedef OpCallerToQ4, - BetweenQ4> + typedef OpCallerToQ4< + Op, arity, PVisSrc, ParamVectVisitor, + BetweenQ4> Caller; size_t size = m_dst_param[0].layout.access_bytes(); int grid_size, block_size; void (*fptr)(Caller, uint32_t) = cuda_kern_q4; - get_launch_spec(reinterpret_cast(fptr), size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), size, &grid_size, &block_size); Caller caller; caller.op = m_op; @@ -207,9 +206,9 @@ class UserOpInvokerQ4 { } public: - UserOpInvokerQ4(const ElemwiseOpParamN& src_param, - const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, - const Op& op) + UserOpInvokerQ4( + const ElemwiseOpParamN& src_param, + const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, const Op& op) : m_src_param(src_param), m_dst_param(dst_param), m_stream(stream), @@ -224,30 +223,30 @@ public: } // namespace elemwise_intl template -void run_elemwise(const ElemwiseOpParamN& src_param, - const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, - const Op& op = Op()); +void run_elemwise( + const ElemwiseOpParamN& src_param, const ElemwiseOpParamN<1>& dst_param, + cudaStream_t stream, const Op& op = Op()); #if MEGDNN_CC_CUDA template -void run_elemwise(const ElemwiseOpParamN& src_param, - const ElemwiseOpParamN<1>& dst_param, cudaStream_t stream, - const Op& op) { +void run_elemwise( + const ElemwiseOpParamN& src_param, const ElemwiseOpParamN<1>& dst_param, + cudaStream_t stream, const Op& op) { src_param.assert_initialized(); dst_param.assert_initialized(); // TODO: Maybe 2bit? 
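// In the 4-bit path above, the launch size is m_dst_param[0].layout.access_bytes()
// and packed_size is 2: every byte stores two elements. A minimal sketch of that
// packing, mirroring the (x & kMask) | (y << kBits) expression in VectTypeTrait
// and the unpack_integer_4bits() calls used later in kern_ops.cuh (signed case;
// the *_sketch names are hypothetical):
__device__ __forceinline__ uint8_t pack_int4x2_sketch(int8_t lo, int8_t hi) {
    // keep the low nibble of each value; the second element goes into the high bits
    return static_cast<uint8_t>((lo & 0xf) | ((hi & 0xf) << 4));
}

__device__ __forceinline__ int8_t unpack_int4_sketch(uint8_t packed, int which) {
    // extract the requested nibble (which == 0: low, which == 1: high) ...
    uint8_t v = (packed >> (which * 4)) & 0xf;
    // ... and sign-extend it from 4 to 8 bits
    return static_cast<int8_t>((v ^ 0x8) - 0x8);
}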
megdnn_assert(dst_param[0].layout.dtype.is_low_bit()); megdnn_assert(dst_param[0].layout.is_contiguous()); - elemwise_intl::UserOpInvokerQ4::value>( + elemwise_intl::UserOpInvokerQ4< + Op, src_ctype, dst_ctype, arity, IsTypeQ4::value>( src_param, dst_param, stream, op); } -#define INST_RUN_ELEMWISE_LOWBIT(Op, src_ctype, dst_ctype, arity) \ - template void run_elemwise( \ - const ElemwiseOpParamN&, const ElemwiseOpParamN<1>&, \ - cudaStream_t, const Op&) +#define INST_RUN_ELEMWISE_LOWBIT(Op, src_ctype, dst_ctype, arity) \ + template void run_elemwise( \ + const ElemwiseOpParamN&, const ElemwiseOpParamN<1>&, cudaStream_t, \ + const Op&) #endif } // namespace cuda diff --git a/dnn/src/cuda/elemwise_multi_type/kern.cu b/dnn/src/cuda/elemwise_multi_type/kern.cu index d64ba074..0bbfba78 100644 --- a/dnn/src/cuda/elemwise_multi_type/kern.cu +++ b/dnn/src/cuda/elemwise_multi_type/kern.cu @@ -24,8 +24,8 @@ void elemwise_multi_type::fma3_int16x32x32x32_1c1( typedef Fma3Int16x32x32x32Bcast101Op Caller; void (*fptr)(Caller, uint32_t) = cuda_kern; int grid_size, block_size; - get_launch_spec(reinterpret_cast(fptr), param.size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), param.size, &grid_size, &block_size); Caller caller; caller.a.host_init(param[0], grid_size, block_size); @@ -43,8 +43,8 @@ void elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar( typedef RoundShrSaturateIXxBcastScalarOp Caller; void (*fptr)(Caller, uint32_t) = cuda_kern; int grid_size, block_size; - get_launch_spec(reinterpret_cast(fptr), param.size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), param.size, &grid_size, &block_size); Caller caller; caller.a.host_init(param[0], grid_size, block_size); @@ -55,18 +55,16 @@ void elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar( after_kernel_launch(); } -#define INST(stype) \ - template void \ - elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar( \ +#define INST(stype) \ + template void elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar( \ const ElemwiseOpParamN<2>& param, dt_int8*, cudaStream_t) INST(int32_t); INST(int16_t); INST(int8_t); #undef INST -#define INST(stype) \ - template void \ - elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar( \ +#define INST(stype) \ + template void elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar( \ const ElemwiseOpParamN<2>& param, dt_int16*, cudaStream_t) INST(int32_t); INST(int16_t); @@ -78,8 +76,8 @@ void elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11( typedef FuseAddRmulhRoundingShrBcastScalarOp Caller; void (*fptr)(Caller, uint32_t) = cuda_kern; int grid_size, block_size; - get_launch_spec(reinterpret_cast(fptr), param.size, &grid_size, - &block_size); + get_launch_spec( + reinterpret_cast(fptr), param.size, &grid_size, &block_size); Caller caller; caller.x.host_init(param[0], grid_size, block_size); diff --git a/dnn/src/cuda/elemwise_multi_type/kern.cuh b/dnn/src/cuda/elemwise_multi_type/kern.cuh index 89170ee4..a1657ccc 100644 --- a/dnn/src/cuda/elemwise_multi_type/kern.cuh +++ b/dnn/src/cuda/elemwise_multi_type/kern.cuh @@ -11,26 +11,26 @@ #pragma once #include "include/megdnn/thin/small_vector.h" +#include "src/common/elemwise/kern_defs.cuh" #include "src/common/elemwise_helper.cuh" #include "src/cuda/utils.cuh" -#include "src/common/elemwise/kern_defs.cuh" namespace megdnn { namespace cuda { namespace elemwise_multi_type { //! 
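// The launchers in kern.cu above (fma3_int16x32x32x32_1c1,
// round_shr_saturate_iXxi8xiX_scalar, ...) all follow the same host-side
// skeleton, condensed below. launch_sketch is a hypothetical name; Caller
// stands for one of the op-caller structs and is assumed to hold a single
// visitor member `a`, while cuda_kern, get_launch_spec, ElemwiseOpParamN and
// after_kernel_launch are the helpers declared in the headers above.
template <typename Caller>
void launch_sketch(const ElemwiseOpParamN<1>& param, cudaStream_t stream) {
    void (*fptr)(Caller, uint32_t) = cuda_kern;  // instantiated for this Caller
    int grid_size, block_size;
    get_launch_spec(
            reinterpret_cast<const void*>(fptr), param.size, &grid_size, &block_size);
    Caller caller;
    // the visitor precomputes strides / fast divisors for the chosen launch shape
    caller.a.host_init(param[0], grid_size, block_size);
    (*fptr)<<<grid_size, block_size, 0, stream>>>(caller, param.size);
    after_kernel_launch();  // error-check hook
}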
a * b + c, where a is [s0, s1, s2] and b, c both [1, s1, 1] -void fma3_int16x32x32x32_1c1(const ElemwiseOpParamN<3>& param, dt_int32* dst, - cudaStream_t stream); +void fma3_int16x32x32x32_1c1( + const ElemwiseOpParamN<3>& param, dt_int32* dst, cudaStream_t stream); //! a * b + c, where a is [m, n] and b, c both [1, n]; m can be 1 template -void fma3_iXxf32xf32xi8_bcast_1x(const stype* a, const float* b, const float* c, - dt_int8* dst, uint32_t m, uint32_t n, - cudaStream_t stream); +void fma3_iXxf32xf32xi8_bcast_1x( + const stype* a, const float* b, const float* c, dt_int8* dst, uint32_t m, + uint32_t n, cudaStream_t stream); template -void round_shr_saturate_iXxi8xiX_scalar(const ElemwiseOpParamN<2>& param, - dst_ctype* dst, cudaStream_t stream); +void round_shr_saturate_iXxi8xiX_scalar( + const ElemwiseOpParamN<2>& param, dst_ctype* dst, cudaStream_t stream); template void fuse_add_rmulh_round_shr_saturate_bcast_1c11( diff --git a/dnn/src/cuda/elemwise_multi_type/kern_iXxf32xf32xi8.cu b/dnn/src/cuda/elemwise_multi_type/kern_iXxf32xf32xi8.cu index 03138ce9..146992f0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kern_iXxf32xf32xi8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kern_iXxf32xf32xi8.cu @@ -24,8 +24,8 @@ struct __builtin_align__(sizeof(T) * 4) Packed4 { }; template -__global__ void kern_1d(const stype* x, const float* k, const float* b, - dtype* y, uint32_t n) { +__global__ void kern_1d( + const stype* x, const float* k, const float* b, dtype* y, uint32_t n) { elemwise_multi_type::Fma3iXxf32xf32xiYOp op; uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; if (i < n) { @@ -34,8 +34,9 @@ __global__ void kern_1d(const stype* x, const float* k, const float* b, } template -void invoke_kern_1d(const stype* x, const float* k, const float* b, dtype* y, - uint32_t n, cudaStream_t stream) { +void invoke_kern_1d( + const stype* x, const float* k, const float* b, dtype* y, uint32_t n, + cudaStream_t stream) { dim3 threads = NR_THREADS; dim3 blocks = DIVUP(n, NR_THREADS); kern_1d<<>>(x, k, b, y, n); @@ -43,8 +44,9 @@ void invoke_kern_1d(const stype* x, const float* k, const float* b, dtype* y, } template -__global__ void kern_2d_fallback(const stype* x, const float* k, const float* b, - dtype* y, uint32_t m, uint32_t n) { +__global__ void kern_2d_fallback( + const stype* x, const float* k, const float* b, dtype* y, uint32_t m, + uint32_t n) { uint32_t i = threadIdx.y + blockIdx.y * blockDim.y; uint32_t j = threadIdx.x + blockIdx.x * blockDim.x; elemwise_multi_type::Fma3iXxf32xf32xiYOp op; @@ -54,10 +56,9 @@ __global__ void kern_2d_fallback(const stype* x, const float* k, const float* b, } template -__global__ void kern_2d_mul4(const stype* __restrict x, - const float* __restrict k, - const float* __restrict b, dtype* y_, uint32_t m, - uint32_t n) { +__global__ void kern_2d_mul4( + const stype* __restrict x, const float* __restrict k, const float* __restrict b, + dtype* y_, uint32_t m, uint32_t n) { uint32_t i = threadIdx.y + blockIdx.y * blockDim.y; uint32_t j = threadIdx.x + blockIdx.x * blockDim.x; elemwise_multi_type::Fma3iXxf32xf32xiYOp op; @@ -85,8 +86,9 @@ __global__ void kern_2d_mul4(const stype* __restrict x, } template -void invoke_kern_2d(const stype* x, const float* k, const float* b, dtype* y, - uint32_t m, uint32_t n, cudaStream_t stream) { +void invoke_kern_2d( + const stype* x, const float* k, const float* b, dtype* y, uint32_t m, + uint32_t n, cudaStream_t stream) { if (n % 4 == 0 && is_same::value) { dim3 threads(NR_THREADS_X, NR_THREADS_Y); dim3 blocks(DIVUP(n / 4, 
NR_THREADS_X), DIVUP(m, NR_THREADS_Y)); @@ -118,11 +120,10 @@ void cuda::elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( } } -#define INST(stype) \ - template void \ - cuda::elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ - const stype*, const float*, const float*, dt_int8*, uint32_t, \ - uint32_t, cudaStream_t) +#define INST(stype) \ + template void cuda::elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ + const stype*, const float*, const float*, dt_int8*, uint32_t, uint32_t, \ + cudaStream_t) #define cb(t) INST(DTypeTrait::ctype); MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) #undef cb diff --git a/dnn/src/cuda/elemwise_multi_type/kern_impl.inl b/dnn/src/cuda/elemwise_multi_type/kern_impl.inl index ccec972a..2d5c410e 100644 --- a/dnn/src/cuda/elemwise_multi_type/kern_impl.inl +++ b/dnn/src/cuda/elemwise_multi_type/kern_impl.inl @@ -20,13 +20,12 @@ namespace megdnn { namespace cuda { -#define cb(_m) \ - typedef ElemwiseKern \ - KernImpl; \ - typedef kern_ops_quantized::QuantizedMultiTypeOp< \ - KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \ - Op; \ +#define cb(_m) \ + typedef ElemwiseKern \ + KernImpl; \ + typedef kern_ops_quantized::QuantizedMultiTypeOp< \ + KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \ + Op; \ INST_RUN_ELEMWISE(Op, KERN_IMPL_STYPE, KERN_IMPL_ARITY); KERN_IMPL_MODE(cb) diff --git a/dnn/src/cuda/elemwise_multi_type/kern_impl_q4.inl b/dnn/src/cuda/elemwise_multi_type/kern_impl_q4.inl index f73f119b..01d7cf90 100644 --- a/dnn/src/cuda/elemwise_multi_type/kern_impl_q4.inl +++ b/dnn/src/cuda/elemwise_multi_type/kern_impl_q4.inl @@ -21,15 +21,13 @@ namespace megdnn { namespace cuda { -#define cb(_m) \ - typedef ElemwiseKern \ - KernImpl; \ - typedef kern_ops_quantized::QuantizedMultiTypeOp< \ - KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \ - Op; \ - INST_RUN_ELEMWISE_LOWBIT(Op, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, \ - KERN_IMPL_ARITY); +#define cb(_m) \ + typedef ElemwiseKern \ + KernImpl; \ + typedef kern_ops_quantized::QuantizedMultiTypeOp< \ + KERN_IMPL_ARITY, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KernImpl> \ + Op; \ + INST_RUN_ELEMWISE_LOWBIT(Op, KERN_IMPL_STYPE, KERN_IMPL_DTYPE, KERN_IMPL_ARITY); KERN_IMPL_MODE(cb) diff --git a/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh b/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh index 9386f4c7..8df56ccb 100644 --- a/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh +++ b/dnn/src/cuda/elemwise_multi_type/kern_ops.cuh @@ -14,8 +14,8 @@ #include "src/cuda/elemwise_helper.cuh" #include "src/cuda/elemwise_helper_q4.cuh" #include "src/cuda/elemwise_multi_type/kern.cuh" -#include "src/cuda/utils.cuh" #include "src/cuda/integer_subbyte_utils.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { @@ -63,8 +63,7 @@ struct RoundShrSaturateIXxBcastScalarOp { } __device__ __forceinline__ void on(uint32_t idx) { - stype result = - rounding_shift_right_away_from_zero(a.at(idx), b.at(idx)); + stype result = rounding_shift_right_away_from_zero(a.at(idx), b.at(idx)); result = result < INT8_MAX ? result : INT8_MAX; result = result > INT8_MIN ? 
result : INT8_MIN; dst[idx] = static_cast(result); @@ -123,29 +122,29 @@ struct FuseAddRmulhRoundingShrBcastScalarOp { namespace kern_ops_quantized { -template +template < + int arity, typename ctype_src, typename ctype_dst, typename KernImpl, + typename enable = void> struct QuantizedMultiTypeOp; template struct QuantizedMultiTypeOp< 1, ctype_src, ctype_dst, KernImpl, - typename std::enable_if<(std::is_same::value || - std::is_same::value || - std::is_same::value) && - IsNotTypeQ4::value>::type> { + typename std::enable_if< + (std::is_same::value || + std::is_same::value || + std::is_same::value) && + IsNotTypeQ4::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a; - typedef typename elemwise_intl::VectTypeTrait::vect_type - src_vect_type; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - ctype_dst* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, ctype_dst* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; } @@ -164,11 +163,9 @@ struct QuantizedMultiTypeOp< __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a) { ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w); - ctype_dst x = apply(a_x), y = apply(a_y), z = apply(a_z), - w = apply(a_w); + ctype_dst x = apply(a_x), y = apply(a_y), z = apply(a_z), w = apply(a_w); *(dst_vect_type*)(&dst[idx]) = - elemwise_intl::VectTypeTrait::make_vector(x, y, z, - w); + elemwise_intl::VectTypeTrait::make_vector(x, y, z, w); } #endif }; @@ -176,22 +173,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 2, ctype_src, ctype_dst, KernImpl, - typename std::enable_if<(std::is_same::value || - std::is_same::value || - std::is_same::value) && - IsNotTypeQ4::value>::type> { + typename std::enable_if< + (std::is_same::value || + std::is_same::value || + std::is_same::value) && + IsNotTypeQ4::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b; - typedef typename elemwise_intl::VectTypeTrait::vect_type - src_vect_type; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - ctype_dst* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, ctype_dst* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; param_b = src_params[1]; @@ -205,21 +201,19 @@ struct QuantizedMultiTypeOp< return dst_param.quantize(rv); } - __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, - ctype_src b) { + __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, ctype_src b) { dst[idx] = dst_param.quantize( KernImpl::apply(param_a.dequantize(a), param_b.dequantize(b))); } - __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, - src_vect_type b) { - ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w), b_x(b.x), b_y(b.y), - b_z(b.z), b_w(b.w); + __device__ __forceinline__ void operator()( + uint32_t idx, src_vect_type a, src_vect_type b) { + ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w), b_x(b.x), b_y(b.y), 
b_z(b.z), + b_w(b.w); ctype_dst x = apply(a_x, b_x), y = apply(a_y, b_y), z = apply(a_z, b_z), w = apply(a_w, b_w); *(dst_vect_type*)(&dst[idx]) = - elemwise_intl::VectTypeTrait::make_vector(x, y, z, - w); + elemwise_intl::VectTypeTrait::make_vector(x, y, z, w); } #endif }; @@ -227,22 +221,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 3, ctype_src, ctype_dst, KernImpl, - typename std::enable_if<(std::is_same::value || - std::is_same::value || - std::is_same::value) && - IsNotTypeQ4::value>::type> { + typename std::enable_if< + (std::is_same::value || + std::is_same::value || + std::is_same::value) && + IsNotTypeQ4::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b, param_c; - typedef typename elemwise_intl::VectTypeTrait::vect_type - src_vect_type; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - ctype_dst* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, ctype_dst* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; param_b = src_params[1]; @@ -251,31 +244,28 @@ struct QuantizedMultiTypeOp< #endif #if MEGDNN_CC_CUDA - __device__ __forceinline__ ctype_dst apply(ctype_src v1, ctype_src v2, - ctype_src v3) { + __device__ __forceinline__ ctype_dst + apply(ctype_src v1, ctype_src v2, ctype_src v3) { float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2), fv3 = param_c.dequantize(v3); float rv = KernImpl::apply(fv1, fv2, fv3); return dst_param.quantize(rv); } - __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, - ctype_src b, ctype_src c) { - dst[idx] = dst_param.quantize(KernImpl::apply(param_a.dequantize(a), - param_b.dequantize(b), - param_c.dequantize(c))); + __device__ __forceinline__ void operator()( + uint32_t idx, ctype_src a, ctype_src b, ctype_src c) { + dst[idx] = dst_param.quantize(KernImpl::apply( + param_a.dequantize(a), param_b.dequantize(b), param_c.dequantize(c))); } - __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, - src_vect_type b, - src_vect_type c) { - ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w), b_x(b.x), b_y(b.y), - b_z(b.z), b_w(b.w), c_x(c.x), c_y(c.y), c_z(c.z), c_w(c.w); + __device__ __forceinline__ void operator()( + uint32_t idx, src_vect_type a, src_vect_type b, src_vect_type c) { + ctype_src a_x(a.x), a_y(a.y), a_z(a.z), a_w(a.w), b_x(b.x), b_y(b.y), b_z(b.z), + b_w(b.w), c_x(c.x), c_y(c.y), c_z(c.z), c_w(c.w); ctype_dst x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y), z = apply(a_z, b_z, c_z), w = apply(a_w, b_w, c_w); *(dst_vect_type*)(&dst[idx]) = - elemwise_intl::VectTypeTrait::make_vector(x, y, z, - w); + elemwise_intl::VectTypeTrait::make_vector(x, y, z, w); } #endif }; @@ -283,16 +273,16 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 1, ctype_src, ctype_dst, KernImpl, - typename std::enable_if::value && - IsNotTypeQ4::value>::type> { + typename std::enable_if< + IsTypeQ4::value && IsNotTypeQ4::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - ctype_dst* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, ctype_dst* dst, + const CudaDTypeParam& dst_param) : 
dst{dst}, dst_param{dst_param} { param_a = src_params[0]; } @@ -314,16 +304,16 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 2, ctype_src, ctype_dst, KernImpl, - typename std::enable_if::value && - IsNotTypeQ4::value>::type> { + typename std::enable_if< + IsTypeQ4::value && IsNotTypeQ4::value>::type> { ctype_dst* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - ctype_dst* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, ctype_dst* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; param_b = src_params[1]; @@ -337,8 +327,7 @@ struct QuantizedMultiTypeOp< return dst_param.quantize(rv); } - __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, - ctype_src b) { + __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a, ctype_src b) { dst[idx] = dst_param.quantize( KernImpl::apply(param_a.dequantize(a), param_b.dequantize(b))); } @@ -348,26 +337,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 1, ctype_src, ctype_dst, KernImpl, - typename std::enable_if::value && - IsTypeQ4::value>::type> { - using src_storage = - typename elemwise_intl::VectTypeTrait::Storage; - using dst_storage = - typename elemwise_intl::VectTypeTrait::Storage; + typename std::enable_if< + IsTypeQ4::value && IsTypeQ4::value>::type> { + using src_storage = typename elemwise_intl::VectTypeTrait::Storage; + using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a; - static constexpr bool src_signedness = - std::is_same::value; - typedef typename elemwise_intl::VectTypeTrait::vect_type - src_vect_type; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + static constexpr bool src_signedness = std::is_same::value; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - dst_storage* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, dst_storage* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; } @@ -395,22 +379,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 1, ctype_src, ctype_dst, KernImpl, - typename std::enable_if<(std::is_same::value || - std::is_same::value || - std::is_same::value) && - IsTypeQ4::value>::type> { - using dst_storage = - typename elemwise_intl::VectTypeTrait::Storage; + typename std::enable_if< + (std::is_same::value || + std::is_same::value || + std::is_same::value) && + IsTypeQ4::value>::type> { + using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - dst_storage* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, dst_storage* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; } @@ -423,8 +406,8 @@ struct QuantizedMultiTypeOp< return dst_param.quantize(rv).as_storage(); } - __device__ __forceinline__ void operator()(uint32_t idx, 
ctype_src a_x, - ctype_src a_y) { + __device__ __forceinline__ void operator()( + uint32_t idx, ctype_src a_x, ctype_src a_y) { dst_storage x = apply(a_x), y = apply(a_y); *(dst_vect_type*)(&dst[idx]) = elemwise_intl::VectTypeTrait::make_vector(x, y); @@ -435,26 +418,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 2, ctype_src, ctype_dst, KernImpl, - typename std::enable_if::value && - IsTypeQ4::value>::type> { - using src_storage = - typename elemwise_intl::VectTypeTrait::Storage; - using dst_storage = - typename elemwise_intl::VectTypeTrait::Storage; + typename std::enable_if< + IsTypeQ4::value && IsTypeQ4::value>::type> { + using src_storage = typename elemwise_intl::VectTypeTrait::Storage; + using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b; - static constexpr bool src_signedness = - std::is_same::value; - typedef typename elemwise_intl::VectTypeTrait::vect_type - src_vect_type; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + static constexpr bool src_signedness = std::is_same::value; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - dst_storage* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, dst_storage* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; param_b = src_params[1]; @@ -462,15 +440,14 @@ struct QuantizedMultiTypeOp< #endif #if MEGDNN_CC_CUDA - __device__ __forceinline__ dst_storage apply(src_storage v1, - src_storage v2) { + __device__ __forceinline__ dst_storage apply(src_storage v1, src_storage v2) { float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2); float rv = KernImpl::apply(fv1, fv2); return dst_param.quantize(rv).as_storage(); } - __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, - src_vect_type b) { + __device__ __forceinline__ void operator()( + uint32_t idx, src_vect_type a, src_vect_type b) { src_storage a_x = src_storage( integer_subbyte::unpack_integer_4bits(a.x, 0)); src_storage a_y = src_storage( @@ -491,22 +468,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 2, ctype_src, ctype_dst, KernImpl, - typename std::enable_if<(std::is_same::value || - std::is_same::value || - std::is_same::value) && - IsTypeQ4::value>::type> { - using dst_storage = - typename elemwise_intl::VectTypeTrait::Storage; + typename std::enable_if< + (std::is_same::value || + std::is_same::value || + std::is_same::value) && + IsTypeQ4::value>::type> { + using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - dst_storage* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, dst_storage* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; param_b = src_params[1]; @@ -520,9 +496,8 @@ struct QuantizedMultiTypeOp< return dst_param.quantize(rv).as_storage(); } - __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, - ctype_src b_x, ctype_src a_y, 
- ctype_src b_y) { + __device__ __forceinline__ void operator()( + uint32_t idx, ctype_src a_x, ctype_src b_x, ctype_src a_y, ctype_src b_y) { dst_storage x = apply(a_x, b_x), y = apply(a_y, b_y); *(dst_vect_type*)(&dst[idx]) = @@ -534,26 +509,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 3, ctype_src, ctype_dst, KernImpl, - typename std::enable_if::value && - IsTypeQ4::value>::type> { - using src_storage = - typename elemwise_intl::VectTypeTrait::Storage; - using dst_storage = - typename elemwise_intl::VectTypeTrait::Storage; + typename std::enable_if< + IsTypeQ4::value && IsTypeQ4::value>::type> { + using src_storage = typename elemwise_intl::VectTypeTrait::Storage; + using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b, param_c; - static constexpr bool src_signedness = - std::is_same::value; - typedef typename elemwise_intl::VectTypeTrait::vect_type - src_vect_type; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + static constexpr bool src_signedness = std::is_same::value; + typedef typename elemwise_intl::VectTypeTrait::vect_type src_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - dst_storage* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, dst_storage* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; param_b = src_params[1]; @@ -562,17 +532,16 @@ struct QuantizedMultiTypeOp< #endif #if MEGDNN_CC_CUDA - __device__ __forceinline__ dst_storage apply(src_storage v1, src_storage v2, - src_storage v3) { + __device__ __forceinline__ dst_storage + apply(src_storage v1, src_storage v2, src_storage v3) { float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2), fv3 = param_c.dequantize(v3); float rv = KernImpl::apply(fv1, fv2, fv3); return dst_param.quantize(rv).as_storage(); } - __device__ __forceinline__ void operator()(uint32_t idx, src_vect_type a, - src_vect_type b, - src_vect_type c) { + __device__ __forceinline__ void operator()( + uint32_t idx, src_vect_type a, src_vect_type b, src_vect_type c) { src_storage a_x = src_storage( integer_subbyte::unpack_integer_4bits(a.x, 0)); src_storage a_y = src_storage( @@ -597,22 +566,21 @@ struct QuantizedMultiTypeOp< template struct QuantizedMultiTypeOp< 3, ctype_src, ctype_dst, KernImpl, - typename std::enable_if<(std::is_same::value || - std::is_same::value || - std::is_same::value) && - IsTypeQ4::value>::type> { - using dst_storage = - typename elemwise_intl::VectTypeTrait::Storage; + typename std::enable_if< + (std::is_same::value || + std::is_same::value || + std::is_same::value) && + IsTypeQ4::value>::type> { + using dst_storage = typename elemwise_intl::VectTypeTrait::Storage; dst_storage* dst; CudaDTypeParam dst_param; CudaDTypeParam param_a, param_b, param_c; - typedef typename elemwise_intl::VectTypeTrait::vect_type - dst_vect_type; + typedef typename elemwise_intl::VectTypeTrait::vect_type dst_vect_type; #if !MEGDNN_CC_CUDA QuantizedMultiTypeOp( - const SmallVector>& src_params, - dst_storage* dst, const CudaDTypeParam& dst_param) + const SmallVector>& src_params, dst_storage* dst, + const CudaDTypeParam& dst_param) : dst{dst}, dst_param{dst_param} { param_a = src_params[0]; param_b = src_params[1]; @@ -621,18 +589,17 @@ struct QuantizedMultiTypeOp< #endif #if MEGDNN_CC_CUDA - __device__ 
__forceinline__ dst_storage apply(ctype_src v1, ctype_src v2, - ctype_src v3) { + __device__ __forceinline__ dst_storage + apply(ctype_src v1, ctype_src v2, ctype_src v3) { float fv1 = param_a.dequantize(v1), fv2 = param_b.dequantize(v2), fv3 = param_c.dequantize(v3); float rv = KernImpl::apply(fv1, fv2, fv3); return dst_param.quantize(rv).as_storage(); } - __device__ __forceinline__ void operator()(uint32_t idx, ctype_src a_x, - ctype_src b_x, ctype_src c_x, - ctype_src a_y, ctype_src b_y, - ctype_src c_y) { + __device__ __forceinline__ void operator()( + uint32_t idx, ctype_src a_x, ctype_src b_x, ctype_src c_x, ctype_src a_y, + ctype_src b_y, ctype_src c_y) { dst_storage x = apply(a_x, b_x, c_x), y = apply(a_y, b_y, c_y); *(dst_vect_type*)(&dst[idx]) = diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_GRAD_dt_qint8_dt_qint8.cu index c5a54505..04e93ee8 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_GRAD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_GRAD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint4_dt_qint4.cu index d4211387..17edb6fb 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint8_dt_qint8.cu index eaf9025e..a6b29b9f 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_quint4_dt_quint4.cu index 2bc037a9..9255e661 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ABS_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ACOS_dt_qint8_dt_qint8.cu 
b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOS_dt_qint8_dt_qint8.cu index 4405606d..a4a42013 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ACOS_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ACOS_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint4.cu index a979c47f..5975c4b3 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint8.cu index b50632a9..baaf92ea 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_quint4.cu index cf3b9103..155650fd 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint32.cu index 7a2bb2f4..e5b69138 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint4.cu index 8f3f5dd4..ffe83a84 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint4.cu +++ 
b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint32.cu index b9ccf32d..baaa6173 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint8.cu index dc0be493..ec064824 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_qint32.cu index 323dd929..9a3bacaa 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_quint4.cu index d063e249..51008353 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ADD_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ASIN_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ASIN_dt_qint8_dt_qint8.cu index 2e849193..33717911 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ASIN_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ASIN_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) 
MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ATAN2_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ATAN2_dt_qint8_dt_qint8.cu index 1ed9241e..02809ed0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ATAN2_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ATAN2_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint4_dt_qint4.cu index d844ef6e..45e83443 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint8_dt_qint8.cu index 6e875dfd..7da55d1d 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_quint4_dt_quint4.cu index 18d6c377..b6c9f76c 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/CEIL_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint4_dt_qint4.cu index 625774e0..50a8b2e3 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE 
dt_qint4 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint8_dt_qint8.cu index 8b863b4c..23bb2651 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_quint4_dt_quint4.cu index 9ecba9c7..27b34450 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COND_LEQ_MOV_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/COS_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/COS_dt_qint8_dt_qint8.cu index 43beeed2..05a4cc98 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/COS_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/COS_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint4_dt_qint4.cu index 4186b04b..c0856884 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint8_dt_qint8.cu index 0fb375a3..253f8ead 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include 
"../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_quint4_dt_quint4.cu index 64a314ad..fbea263c 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EQ_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ERFCINV_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ERFCINV_dt_qint8_dt_qint8.cu index be0a1b4e..262f53b6 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ERFCINV_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ERFCINV_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ERFC_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ERFC_dt_qint8_dt_qint8.cu index 090093e5..1264f32f 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ERFC_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ERFC_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ERFINV_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ERFINV_dt_qint8_dt_qint8.cu index 0857b6ca..25b77e58 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ERFINV_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ERFINV_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ERF_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ERF_dt_qint8_dt_qint8.cu index aba969ae..18dd3340 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ERF_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ERF_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EXPM1_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EXPM1_dt_qint8_dt_qint8.cu 
index 661079a7..e3c30bc0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/EXPM1_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EXPM1_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/EXP_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/EXP_dt_qint8_dt_qint8.cu index 7238fbdf..8b951844 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/EXP_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/EXP_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_GRAD_dt_qint8_dt_qint8.cu index 4a2f5e4d..5f5a4527 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_GRAD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_GRAD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint4.cu index be7a40fd..69494283 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint8.cu index 59152e79..f2ca7fa6 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_quint4.cu index df42e885..76069013 100644 --- 
a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint32.cu index a463825e..f322eee2 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint4.cu index 1e558b17..1e22b0d1 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint32.cu index 18d9feaa..3b1d1562 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint8.cu index caf308d9..dd32a5ca 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_qint32.cu index 8503017f..0fbae6ce 100644 --- 
a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_quint4.cu index cc94da23..c45499ff 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FAST_TANH_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_DIV_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_DIV_dt_qint8_dt_qint8.cu index f733311d..da17b9a0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_DIV_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_DIV_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint4_dt_qint4.cu index 2fc92120..c3ef3d10 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint8_dt_qint8.cu index 237a65bb..525a1593 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_quint4_dt_quint4.cu index f16f555e..670abc1e 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_quint4_dt_quint4.cu +++ 
b/dnn/src/cuda/elemwise_multi_type/kimpl/FLOOR_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint4.cu index 8a68e058..615d6c05 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint8.cu index 50e17237..9012db49 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_quint4.cu index 5f118c81..42faac70 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint32.cu index 1cdd04fc..0fafaa84 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint4.cu 
b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint4.cu index 43afad0a..83ae6fdb 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint32.cu index dc875db2..d48ef392 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint8.cu index 1a1c84bb..3c281d9c 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_qint32.cu index 3df745a6..ad8e0793 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_quint4.cu index b32ae7dd..974882f1 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_H_SWISH_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define 
KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint4.cu index 6a72e8fe..acc71ad5 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint8.cu index 6673c738..51e7a1f8 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_quint4.cu index 19f4984e..91eb38fa 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint32.cu index bcc6c93f..18b18cf0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint4.cu index 1bbe4b65..6981ff20 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 
+#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint32.cu index 36c33235..fc16aa31 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint8.cu index 0ee3bb5a..af6e0882 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_qint32.cu index 0ed56d05..2c1ae8db 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_quint4.cu index 5c626d4c..f0618dde 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_RELU_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint4.cu index af01bb60..11ccaf54 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define 
KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint8.cu index 211fa3ed..565f4bc7 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_quint4.cu index 1a989ece..94b42db5 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint32.cu index 47192b8e..b88720f4 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint4.cu index 37f77a1a..05598c77 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint32.cu index 7c167af4..e2e7284b 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint32.cu @@ -1,6 
+1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint8.cu index d8ff17ae..a53fd990 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_qint32.cu index 1a975d24..9578c1cd 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_quint4.cu index b3baefab..3b9ec4b8 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_SIGMOID_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint4.cu index 6dd395e0..04c8ccb4 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint8.cu index addb6872..66cd04d5 100644 --- 
a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_quint4.cu index 98142053..e7269a41 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint32.cu index 6c86b0a9..fff30206 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint4.cu index 4340a282..9167c03f 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint32.cu index 2b369e12..ce4a6404 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint8.cu 
b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint8.cu index 593dfa54..6c8f61c0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_qint32.cu index 484641d2..65e9e467 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_quint4.cu index 66066de0..f45b7e51 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_ADD_TANH_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint4_dt_qint4.cu index 0314a2ff..5bb3f69e 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint8_dt_qint8.cu index ea5c80b5..192176b9 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_quint4_dt_quint4.cu index 2b16977d..7412be7c 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/FUSE_MUL_ADD3_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) -#define KERN_IMPL_ARITY 3 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu index a200df94..d1fb8693 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_GRAD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu index 8cc57ad7..115975e7 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/GELU_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(GELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_GRAD_dt_qint8_dt_qint8.cu index f7c397c0..dd68332d 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_GRAD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_GRAD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint4.cu index 13c9dc2c..f5329cb8 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git 
a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint8.cu index fddd0fa1..13939292 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_quint4.cu index 889fac16..b8f11208 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint32.cu index d8d61570..6da6e8bd 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint4.cu index 9d1a414f..9a035c13 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint32.cu index 6b5de5de..6f71c0b9 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint8.cu 
b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint8.cu index 4df2fe61..1cbab2e1 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_qint32.cu index e0641f28..8f972020 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_quint4.cu index 0052c5a8..fe649f4a 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/H_SWISH_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint4_dt_qint4.cu index 71d2f8d7..b525f8e7 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint8_dt_qint8.cu index 9abb22c8..72f21db7 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_quint4_dt_quint4.cu index 0e255830..861e0c93 100644 --- 
a/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LEQ_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LOG1P_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LOG1P_dt_qint8_dt_qint8.cu index 66ecb426..6a98561f 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LOG1P_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LOG1P_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_SUM_EXP_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_SUM_EXP_dt_qint8_dt_qint8.cu index 56325848..e6761aca 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_SUM_EXP_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_SUM_EXP_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_dt_qint8_dt_qint8.cu index 195cc341..a1d9cfdf 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LOG_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint4_dt_qint4.cu index 196cce34..503d868b 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint8_dt_qint8.cu index be46f227..01f745a9 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated 
by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_quint4_dt_quint4.cu index 6832349d..048c5599 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/LT_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint4_dt_qint4.cu index 68cdbd8a..baf5dd22 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint8_dt_qint8.cu index dd9e7000..30a861f4 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_quint4_dt_quint4.cu index 6fc236b1..f661f72c 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MAX_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint4_dt_qint4.cu index 1fd5c646..385d07de 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE 
dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint8_dt_qint8.cu index 31cf5052..93c97706 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_quint4_dt_quint4.cu index a82e56d7..43294de8 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MIN_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MOD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MOD_dt_qint8_dt_qint8.cu index 47e2912c..666eeae7 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MOD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MOD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint4_dt_qint4.cu index 29c30984..3e5c270d 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint8_dt_qint8.cu index 92d7fa64..67acc5dd 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git 
a/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_quint4_dt_quint4.cu index 9ecf4ad1..5029cff8 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/MUL_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint4_dt_qint4.cu index 353ad75b..41e1802b 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint8_dt_qint8.cu index d3a0474b..4d9651ff 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_quint4_dt_quint4.cu index cbdb016e..1552fa5b 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/NEGATE_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/POW_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/POW_dt_qint8_dt_qint8.cu index 8fb12675..3929944d 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/POW_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/POW_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint4.cu index 
d3e611fa..df9a2256 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint8.cu index 9497ac7e..a152aca4 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_quint4.cu index c8a08155..22b634d3 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint32.cu index 67199f82..43bb8d7f 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint4.cu index a8e0ad31..24a01dbb 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint32.cu index 1336e904..b9b4f9f2 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint32.cu +++ 
b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint8.cu index 50a94f1a..264d78a3 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_qint32.cu index 10a6bf76..a09dc7d4 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_quint4.cu index 106d0042..18d702e0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/RELU_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint4_dt_qint4.cu index fd2b7584..5440e607 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint8_dt_qint8.cu index 94bfc75a..fce07d48 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define 
KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_quint4_dt_quint4.cu index d2096352..ccde3cd6 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/ROUND_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_GRAD_dt_qint8_dt_qint8.cu index 18fed731..11da90e0 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_GRAD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_GRAD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint4.cu index f809012a..0f3303a9 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint8.cu index 8c61c9ee..a53e3058 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_quint4.cu index 8540d498..59dd30b5 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define 
KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint32.cu index a4b8c4e5..c487bfc1 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint4.cu index 08f23b90..b0da2767 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint32.cu index e2f341d5..356c3207 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint8.cu index f10a6b3a..15c54773 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_qint32.cu index 6788a3ed..3d35ae92 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 
+#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_quint4.cu index 4a043116..91ee7c20 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIGMOID_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu index b6b198b8..b610ad66 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_GRAD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu index f2127860..617fdea5 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SILU_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SILU, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SIN_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SIN_dt_qint8_dt_qint8.cu index 9bbc4212..81aebfcd 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SIN_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SIN_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint4_dt_qint4.cu index c9013eda..95ef3642 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git 
a/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint8_dt_qint8.cu index bb8142fb..4ec384d3 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_quint4_dt_quint4.cu index c7633d51..69a1673e 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SUB_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint4_dt_qint4.cu index 528a26da..6bb548df 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint8_dt_qint8.cu index 04f2ee6a..3f6c1997 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_quint4_dt_quint4.cu index b6ac9ffd..ee94ffe3 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/SWITCH_GT0_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_GRAD_dt_qint8_dt_qint8.cu 
b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_GRAD_dt_qint8_dt_qint8.cu index 21902963..0da6b34e 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_GRAD_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_GRAD_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint4.cu index b09d8b67..3325f058 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint8.cu index 076e4cbf..35b93374 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_quint4.cu index b1663ff2..1ab5ec20 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint32_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint32 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint32 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint32.cu index 506b9ace..63148d64 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint4.cu index 54c62553..ec6f9ff2 100644 --- 
a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint4_dt_qint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint4 -#define KERN_IMPL_DTYPE dt_qint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint4 +#define KERN_IMPL_DTYPE dt_qint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint32.cu index 34ed375c..613e6110 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint8.cu index 12ce0edd..e2f2275f 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_qint8_dt_qint8.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_qint32.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_qint32.cu index 96c6c801..313580c8 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_qint32.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_qint32.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_qint32 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_qint32 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_quint4.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_quint4.cu index 208e4564..a1b27c72 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_quint4.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TANH_dt_quint4_dt_quint4.cu @@ -1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) -#define KERN_IMPL_ARITY 1 -#define KERN_IMPL_STYPE dt_quint4 -#define KERN_IMPL_DTYPE dt_quint4 +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_STYPE dt_quint4 +#define KERN_IMPL_DTYPE dt_quint4 #include "../kern_impl_q4.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/kimpl/TRUE_DIV_dt_qint8_dt_qint8.cu b/dnn/src/cuda/elemwise_multi_type/kimpl/TRUE_DIV_dt_qint8_dt_qint8.cu index 823c69b2..f1104277 100644 --- a/dnn/src/cuda/elemwise_multi_type/kimpl/TRUE_DIV_dt_qint8_dt_qint8.cu +++ b/dnn/src/cuda/elemwise_multi_type/kimpl/TRUE_DIV_dt_qint8_dt_qint8.cu @@ 
-1,6 +1,6 @@ // generated by gen_elemwise_multi_type_kern_impls.py #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) -#define KERN_IMPL_ARITY 2 -#define KERN_IMPL_STYPE dt_qint8 -#define KERN_IMPL_DTYPE dt_qint8 +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_STYPE dt_qint8 +#define KERN_IMPL_DTYPE dt_qint8 #include "../kern_impl.inl" diff --git a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp index 5346d27b..d9d7f8d8 100644 --- a/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp +++ b/dnn/src/cuda/elemwise_multi_type/opr_impl.cpp @@ -25,8 +25,7 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( BroadcastChannelInfo binfo0, binfo1; if (is_vector(param[0].layout) && is_broadcasted_channel_like(param[1].layout, binfo0) && - is_broadcasted_channel_like(param[2].layout, binfo1) && - binfo0 == binfo1) { + is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) { elemwise_multi_type::fma3_int16x32x32x32_1c1( param, dst, cuda_stream(this->handle())); return; @@ -39,15 +38,14 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( Broadcast1xInfo binfo0, binfo1; auto p1 = param[1].ptr(), p2 = param[2].ptr(); auto stream = cuda_stream(this->handle()); - if (is_vector(param[0].layout) && - is_broadcasted_1x(param[1].layout, binfo0) && + if (is_vector(param[0].layout) && is_broadcasted_1x(param[1].layout, binfo0) && is_broadcasted_1x(param[2].layout, binfo1) && binfo0 == binfo1) { switch (param[0].layout.dtype.enumv()) { -#define cb(t) \ - case DTypeTrait::enumv: \ - elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ - param[0].ptr::ctype>(), p1, p2, dst, binfo0.x, \ - binfo0.y, stream); \ +#define cb(t) \ + case DTypeTrait::enumv: \ + elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ + param[0].ptr::ctype>(), p1, p2, dst, binfo0.x, binfo0.y, \ + stream); \ return; MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) #undef cb @@ -96,8 +94,8 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( is_broadcasted_scalar(param[3].layout) && is_broadcasted_scalar(param[4].layout) && is_broadcasted_scalar(param[5].layout)) { - elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11< - dt_int16>(param, dst, stream); + elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11( + param, dst, stream); return; } megdnn_throw( @@ -117,8 +115,8 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( is_broadcasted_scalar(param[3].layout) && is_broadcasted_scalar(param[4].layout) && is_broadcasted_scalar(param[5].layout)) { - elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11< - dt_int32>(param, dst, stream); + elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11( + param, dst, stream); return; } megdnn_throw( @@ -159,38 +157,36 @@ namespace { template struct ModeDispatcher; -#define _cb_dispatch_mode(_m) \ - case param::Elemwise::Mode::_m: \ - do { \ - using KernImpl = \ - ElemwiseKern; \ - using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ - arity, src_ctype, dst_ctype, KernImpl>; \ - dst_ctype* dst = dst_tensor.ptr(); \ - Op op(src_params, dst, dst_param); \ - return run_elemwise(param, stream, op); \ +#define _cb_dispatch_mode(_m) \ + case param::Elemwise::Mode::_m: \ + do { \ + using KernImpl = ElemwiseKern< \ + megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, float>; \ + using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ + arity, src_ctype, dst_ctype, KernImpl>; \ + dst_ctype* dst = dst_tensor.ptr(); \ + Op 
op(src_params, dst, dst_param); \ + return run_elemwise(param, stream, op); \ } while (0); -#define IMPL_MODE_DISPATCHER(_arity, _src_ctype, _dst_ctype) \ - template <> \ - struct ModeDispatcher<_arity, _src_ctype, _dst_ctype> { \ - static constexpr int arity = _arity; \ - using src_ctype = _src_ctype; \ - using dst_ctype = _dst_ctype; \ - static void run( \ - const ElemwiseOpParamN<_arity>& param, \ - const TensorND& dst_tensor, \ - const SmallVector>& src_params, \ - const CudaDTypeParam<_dst_ctype>& dst_param, \ - param::Elemwise::Mode mode, cudaStream_t stream) { \ - megdnn_assert(src_params.size() == _arity); \ - switch (mode) { \ - FOREACH(_cb_dispatch_mode) \ - default: \ - megdnn_throw("bad mode"); \ - } \ - } \ +#define IMPL_MODE_DISPATCHER(_arity, _src_ctype, _dst_ctype) \ + template <> \ + struct ModeDispatcher<_arity, _src_ctype, _dst_ctype> { \ + static constexpr int arity = _arity; \ + using src_ctype = _src_ctype; \ + using dst_ctype = _dst_ctype; \ + static void run( \ + const ElemwiseOpParamN<_arity>& param, const TensorND& dst_tensor, \ + const SmallVector>& src_params, \ + const CudaDTypeParam<_dst_ctype>& dst_param, \ + param::Elemwise::Mode mode, cudaStream_t stream) { \ + megdnn_assert(src_params.size() == _arity); \ + switch (mode) { \ + FOREACH(_cb_dispatch_mode) \ + default: \ + megdnn_throw("bad mode"); \ + } \ + } \ } #define FOREACH MEGDNN_FOREACH_ELEMWISE_MODE_UNARY_FLOAT @@ -231,24 +227,22 @@ IMPL_MODE_DISPATCHER(2, dt_quint4, dt_qint32); #undef _cb_dispatch_mode -#define _cb_dispatch_mode(_m) \ - case param::Elemwise::Mode::_m: \ - do { \ - using KernImpl = \ - ElemwiseKern; \ - using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ - arity, src_ctype, dst_ctype, KernImpl>; \ - using dst_storage = typename VectTypeTrait::Storage; \ - dst_storage* dst = \ - reinterpret_cast(dst_tensor.raw_ptr); \ - Op op(src_params, dst, dst_param); \ - ElemwiseOpParamN<1> param_dst; \ - param_dst[0] = dst_tensor; \ - param_dst.init_from_given_tensor(); \ - run_elemwise(param, param_dst, \ - stream, op); \ - return; \ +#define _cb_dispatch_mode(_m) \ + case param::Elemwise::Mode::_m: \ + do { \ + using KernImpl = ElemwiseKern< \ + megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, float>; \ + using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ + arity, src_ctype, dst_ctype, KernImpl>; \ + using dst_storage = typename VectTypeTrait::Storage; \ + dst_storage* dst = reinterpret_cast(dst_tensor.raw_ptr); \ + Op op(src_params, dst, dst_param); \ + ElemwiseOpParamN<1> param_dst; \ + param_dst[0] = dst_tensor; \ + param_dst.init_from_given_tensor(); \ + run_elemwise( \ + param, param_dst, stream, op); \ + return; \ } while (0); #define FOREACH(cb) \ @@ -313,8 +307,9 @@ IMPL_MODE_DISPATCHER(2, dt_qint32, dt_quint4); #undef IMPL_MODE_DISPATCHER template -void dispatch_src_ctype(const ElemwiseOpParamN<1>&, const TensorND& dst_tensor, - Elemwise::Mode, cudaStream_t); +void dispatch_src_ctype( + const ElemwiseOpParamN<1>&, const TensorND& dst_tensor, Elemwise::Mode, + cudaStream_t); #define DISPATCH(_dt) \ case DTypeTrait<_dt>::enumv: { \ @@ -326,9 +321,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<1>&, const TensorND& dst_tensor, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint8 ctype_src; switch 
(dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS8); @@ -341,9 +336,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint32 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS8); @@ -357,9 +352,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint4 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS4); @@ -372,9 +367,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_quint4 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::Quantized4Asymm); @@ -388,25 +383,24 @@ void dispatch_src_ctype(const ElemwiseOpParamN<1>& param, #undef DISPATCH -#define DISPATCH(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - auto param_a = param[0].layout.dtype.param(); \ - auto param_b = param[1].layout.dtype.param(); \ - auto dst_param = dst_tensor.layout.dtype.param<_dt>(); \ - ModeDispatcher<2, ctype_src, typename DTypeTrait<_dt>::ctype>::run( \ - param, dst_tensor, {param_a, param_b}, dst_param, mode, \ - stream); \ - break; \ +#define DISPATCH(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + auto param_a = param[0].layout.dtype.param(); \ + auto param_b = param[1].layout.dtype.param(); \ + auto dst_param = dst_tensor.layout.dtype.param<_dt>(); \ + ModeDispatcher<2, ctype_src, typename DTypeTrait<_dt>::ctype>::run( \ + param, dst_tensor, {param_a, param_b}, dst_param, mode, stream); \ + break; \ } template -void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, - const TensorND& dst_tensor, Elemwise::Mode mode, - cudaStream_t stream); +void dispatch_src_ctype( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream); template <> -void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint8 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS8); @@ -419,9 +413,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint32 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS8); @@ -435,9 +429,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, } template <> -void 
dispatch_src_ctype(const ElemwiseOpParamN<2>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint4 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS4); @@ -450,9 +444,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_quint4 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::Quantized4Asymm); @@ -466,26 +460,26 @@ void dispatch_src_ctype(const ElemwiseOpParamN<2>& param, #undef DISPATCH -#define DISPATCH(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - auto param_a = param[0].layout.dtype.param(); \ - auto param_b = param[1].layout.dtype.param(); \ - auto param_c = param[2].layout.dtype.param(); \ - auto dst_param = dst_tensor.layout.dtype.param<_dt>(); \ - ModeDispatcher<3, ctype_src, typename DTypeTrait<_dt>::ctype>::run( \ - param, dst_tensor, {param_a, param_b, param_c}, dst_param, \ - mode, stream); \ - break; \ +#define DISPATCH(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + auto param_a = param[0].layout.dtype.param(); \ + auto param_b = param[1].layout.dtype.param(); \ + auto param_c = param[2].layout.dtype.param(); \ + auto dst_param = dst_tensor.layout.dtype.param<_dt>(); \ + ModeDispatcher<3, ctype_src, typename DTypeTrait<_dt>::ctype>::run( \ + param, dst_tensor, {param_a, param_b, param_c}, dst_param, mode, \ + stream); \ + break; \ } template -void dispatch_src_ctype(const ElemwiseOpParamN<3>& param, - const TensorND& dst_tensor, Elemwise::Mode mode, - cudaStream_t stream); +void dispatch_src_ctype( + const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream); template <> -void dispatch_src_ctype(const ElemwiseOpParamN<3>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint8 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS8); @@ -497,9 +491,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<3>& param, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<3>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_qint4 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::QuantizedS4); @@ -511,9 +505,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<3>& param, } template <> -void dispatch_src_ctype(const ElemwiseOpParamN<3>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode, cudaStream_t stream) { +void dispatch_src_ctype( + const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor, + Elemwise::Mode mode, cudaStream_t stream) { typedef dt_quint4 ctype_src; switch (dst_tensor.layout.dtype.enumv()) { DISPATCH(dtype::Quantized4Asymm); @@ -528,9 +522,9 @@ void dispatch_src_ctype(const ElemwiseOpParamN<3>& param, } // namespace -void ElemwiseMultiTypeImpl::on_quantized_mode(const 
ElemwiseOpParamN<1>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode) { +void ElemwiseMultiTypeImpl::on_quantized_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor, + Elemwise::Mode mode) { megdnn_assert( param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 || param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 || @@ -540,11 +534,11 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<1>& param, param[0].layout.dtype.name()); auto stream = cuda_stream(this->handle()); switch (param[0].layout.dtype.enumv()) { -#define DISPATCH(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - dispatch_src_ctype::ctype>(param, dst_tensor, \ - mode, stream); \ - break; \ +#define DISPATCH(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + dispatch_src_ctype::ctype>( \ + param, dst_tensor, mode, stream); \ + break; \ } DISPATCH(dtype::QuantizedS8); @@ -553,19 +547,18 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<1>& param, DISPATCH(dtype::Quantized4Asymm); default: - megdnn_throw( - ssprintf("Unsupported input dtype %s for ElemwiseMultiType", - param[0].layout.dtype.name())); + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); } #undef DISPATCH } -void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode) { - megdnn_assert(param[0].layout.dtype.enumv() == - param[1].layout.dtype.enumv()); +void ElemwiseMultiTypeImpl::on_quantized_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor, + Elemwise::Mode mode) { + megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv()); megdnn_assert( param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 || param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS32 || @@ -575,11 +568,11 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, param[0].layout.dtype.name()); auto stream = cuda_stream(this->handle()); switch (param[0].layout.dtype.enumv()) { -#define DISPATCH(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - dispatch_src_ctype::ctype>(param, dst_tensor, \ - mode, stream); \ - break; \ +#define DISPATCH(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + dispatch_src_ctype::ctype>( \ + param, dst_tensor, mode, stream); \ + break; \ } DISPATCH(dtype::QuantizedS8); @@ -588,21 +581,19 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, DISPATCH(dtype::Quantized4Asymm); default: - megdnn_throw( - ssprintf("Unsupported input dtype %s for ElemwiseMultiType", - param[0].layout.dtype.name())); + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); } #undef DISPATCH } -void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, - const TensorND& dst_tensor, - Elemwise::Mode mode) { - megdnn_assert(param[0].layout.dtype.enumv() == - param[1].layout.dtype.enumv()); - megdnn_assert(param[0].layout.dtype.enumv() == - param[2].layout.dtype.enumv()); +void ElemwiseMultiTypeImpl::on_quantized_mode( + const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor, + Elemwise::Mode mode) { + megdnn_assert(param[0].layout.dtype.enumv() == param[1].layout.dtype.enumv()); + megdnn_assert(param[0].layout.dtype.enumv() == param[2].layout.dtype.enumv()); megdnn_assert( param[0].layout.dtype.enumv() == DTypeEnum::QuantizedS8 || @@ -612,11 +603,11 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, 
param[0].layout.dtype.name()); auto stream = cuda_stream(this->handle()); switch (param[0].layout.dtype.enumv()) { -#define DISPATCH(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - dispatch_src_ctype::ctype>(param, dst_tensor, \ - mode, stream); \ - break; \ +#define DISPATCH(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + dispatch_src_ctype::ctype>( \ + param, dst_tensor, mode, stream); \ + break; \ } DISPATCH(dtype::QuantizedS8); @@ -624,9 +615,9 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, DISPATCH(dtype::Quantized4Asymm); default: - megdnn_throw( - ssprintf("Unsupported input dtype %s for ElemwiseMultiType", - param[0].layout.dtype.name())); + megdnn_throw(ssprintf( + "Unsupported input dtype %s for ElemwiseMultiType", + param[0].layout.dtype.name())); } #undef DISPATCH diff --git a/dnn/src/cuda/elemwise_multi_type/opr_impl.h b/dnn/src/cuda/elemwise_multi_type/opr_impl.h index 249e14be..2d711c77 100644 --- a/dnn/src/cuda/elemwise_multi_type/opr_impl.h +++ b/dnn/src/cuda/elemwise_multi_type/opr_impl.h @@ -17,14 +17,14 @@ namespace megdnn { namespace cuda { class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper { - void on_fuse_mul_add3_int16x32x32x32(const ElemwiseOpParamN<3>& param, - dt_int32* dst) override; + void on_fuse_mul_add3_int16x32x32x32( + const ElemwiseOpParamN<3>& param, dt_int32* dst) override; - void on_fuse_mul_add3_iXxf32xf32xi8(const ElemwiseOpParamN<3>& param, - dt_int8* dst) override; + void on_fuse_mul_add3_iXxf32xf32xi8( + const ElemwiseOpParamN<3>& param, dt_int8* dst) override; - void on_round_shr_saturate_iXxi8xi8(const ElemwiseOpParamN<2>& param, - dt_int8* dst) override; + void on_round_shr_saturate_iXxi8xi8( + const ElemwiseOpParamN<2>& param, dt_int8* dst) override; void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( const ElemwiseOpParamN<6>& param, dt_int8* dst) override; @@ -32,17 +32,20 @@ class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper { void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( const ElemwiseOpParamN<6>& param, dt_int8* dst) override; - void on_round_shr_saturate_iXxi8xi16(const ElemwiseOpParamN<2>& param, - dt_int16* dst) override; - - void on_quantized_mode(const ElemwiseOpParamN<1>& param, - const TensorND& dst, Elemwise::Mode mode) override; - - void on_quantized_mode(const ElemwiseOpParamN<2>& param, - const TensorND& dst, Elemwise::Mode mode) override; - - void on_quantized_mode(const ElemwiseOpParamN<3>& param, - const TensorND& dst, Elemwise::Mode mode) override; + void on_round_shr_saturate_iXxi8xi16( + const ElemwiseOpParamN<2>& param, dt_int16* dst) override; + + void on_quantized_mode( + const ElemwiseOpParamN<1>& param, const TensorND& dst, + Elemwise::Mode mode) override; + + void on_quantized_mode( + const ElemwiseOpParamN<2>& param, const TensorND& dst, + Elemwise::Mode mode) override; + + void on_quantized_mode( + const ElemwiseOpParamN<3>& param, const TensorND& dst, + Elemwise::Mode mode) override; public: using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper; diff --git a/dnn/src/cuda/error_info.cuh b/dnn/src/cuda/error_info.cuh index d95b0a25..3513548a 100644 --- a/dnn/src/cuda/error_info.cuh +++ b/dnn/src/cuda/error_info.cuh @@ -15,7 +15,6 @@ #include "megcore_cdefs.h" #include "megdnn/arch.h" - typedef megcore::AsyncErrorInfo AsyncErrorInfo; #if MEGDNN_CC_CUDA // we can not put this function into anonymous namespace, since it would cause @@ -24,9 +23,9 @@ typedef megcore::AsyncErrorInfo AsyncErrorInfo; namespace { #endif -__device__ 
void set_async_error_info(AsyncErrorInfo* info, void* tracker,
-                                     const char* msg, int arg0 = 0,
-                                     int arg1 = 0, int arg2 = 0, int arg3 = 0)
+__device__ void set_async_error_info(
+        AsyncErrorInfo* info, void* tracker, const char* msg, int arg0 = 0,
+        int arg1 = 0, int arg2 = 0, int arg3 = 0)
 #if MEGDNN_CC_CUDA
 {
     if (info && !atomicAdd(&info->nr_error, 1)) {
@@ -45,7 +44,7 @@ __device__ void set_async_error_info(AsyncErrorInfo* info, void* tracker,
     }
 }
 #else
-;
+        ;
 #endif
 
 #if MEGDNN_CC_CUDA
diff --git a/dnn/src/cuda/eye/eye.cu b/dnn/src/cuda/eye/eye.cu
index e3c36ee4..70667143 100644
--- a/dnn/src/cuda/eye/eye.cu
+++ b/dnn/src/cuda/eye/eye.cu
@@ -8,43 +8,39 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#include "src/cuda/eye/eye.cuh"
 #include "megdnn/dtype.h"
+#include "src/cuda/eye/eye.cuh"
 #include "src/cuda/utils.cuh"
 
 namespace {
 
 template
-__global__ void kernel(T *dst, uint32_t m, uint32_t n, int k)
-{
+__global__ void kernel(T* dst, uint32_t m, uint32_t n, int k) {
     int32_t i = threadIdx.x + blockIdx.x * blockDim.x;
     int32_t x = i % n;
     int32_t y = i / n;
-    if (i < m*n) {
-        dst[i] = (y+k == x);
+    if (i < m * n) {
+        dst[i] = (y + k == x);
     }
 }
 
-} // anonymous namespace
+} // anonymous namespace
 
 namespace megdnn {
 namespace cuda {
 namespace eye {
 
 template
-void exec_internal(T *dst, size_t m, size_t n, int k, cudaStream_t stream)
-{
-    kernel<<>>(
-            dst, m, n, k);
+void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) {
+    kernel<<>>(dst, m, n, k);
     after_kernel_launch();
 }
 
-#define INST(T) template void exec_internal(T *, \
-        size_t, size_t, int, cudaStream_t);
+#define INST(T) template void exec_internal(T*, size_t, size_t, int, cudaStream_t);
 #define cb(DType) INST(typename DTypeTrait::ctype)
 MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
 
-} // namespace eye
-} // namespace cuda
-} // namespace megdnn
+} // namespace eye
+} // namespace cuda
+} // namespace megdnn
 
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/src/cuda/eye/eye.cuh b/dnn/src/cuda/eye/eye.cuh
index b6efa8d8..93b79dad 100644
--- a/dnn/src/cuda/eye/eye.cuh
+++ b/dnn/src/cuda/eye/eye.cuh
@@ -9,17 +9,17 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ #pragma once -#include #include +#include namespace megdnn { namespace cuda { namespace eye { template -void exec_internal(T *dst, size_t m, size_t n, int k, cudaStream_t stream); +void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream); -} // namespace eye -} // namespace cuda -} // namespace megdnn +} // namespace eye +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/eye/opr_impl.cpp b/dnn/src/cuda/eye/opr_impl.cpp index 83b016ca..fd23b0ee 100644 --- a/dnn/src/cuda/eye/opr_impl.cpp +++ b/dnn/src/cuda/eye/opr_impl.cpp @@ -16,21 +16,19 @@ namespace megdnn { namespace cuda { -void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) -{ +void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(dst.layout, workspace.size); -#define cb(DType) \ - if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ - using ctype = typename DTypeTrait::ctype; \ - eye::exec_internal(dst.ptr(), \ - dst.layout.shape[0], dst.layout.shape[1], \ - param().k, \ - cuda_stream(handle())); \ +#define cb(DType) \ + if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + eye::exec_internal( \ + dst.ptr(), dst.layout.shape[0], dst.layout.shape[1], param().k, \ + cuda_stream(handle())); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/eye/opr_impl.h b/dnn/src/cuda/eye/opr_impl.h index 163edaac..18da12cd 100644 --- a/dnn/src/cuda/eye/opr_impl.h +++ b/dnn/src/cuda/eye/opr_impl.h @@ -14,16 +14,13 @@ namespace megdnn { namespace cuda { -class EyeImpl final: public Eye { - public: - using Eye::Eye; - void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &) override { - return 0; - } +class EyeImpl final : public Eye { +public: + using Eye::Eye; + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/dnn/src/cuda/fake_quant/kern.cu b/dnn/src/cuda/fake_quant/kern.cu index 3321a4e9..7dcea797 100644 --- a/dnn/src/cuda/fake_quant/kern.cu +++ b/dnn/src/cuda/fake_quant/kern.cu @@ -15,15 +15,18 @@ namespace megdnn { namespace cuda { -#define cb(_dtype) \ - INST_RUN_ELEMWISE(FakeQuantKernOp::ctype>, \ - DTypeTrait<_dtype>::ctype, 2); \ - INST_RUN_ELEMWISE(FakeQuantBwdKernOp::ctype>, \ - DTypeTrait<_dtype>::ctype, 2); \ - INST_RUN_ELEMWISE(FakeQuantKernOpNonContig::ctype>, \ - DTypeTrait<_dtype>::ctype, 4); \ - INST_RUN_ELEMWISE(FakeQuantBwdKernOpNonContig::ctype>, \ - DTypeTrait<_dtype>::ctype, 5); +#define cb(_dtype) \ + INST_RUN_ELEMWISE( \ + FakeQuantKernOp::ctype>, DTypeTrait<_dtype>::ctype, 2); \ + INST_RUN_ELEMWISE( \ + FakeQuantBwdKernOp::ctype>, DTypeTrait<_dtype>::ctype, \ + 2); \ + INST_RUN_ELEMWISE( \ + FakeQuantKernOpNonContig::ctype>, \ + DTypeTrait<_dtype>::ctype, 4); \ + INST_RUN_ELEMWISE( \ + FakeQuantBwdKernOpNonContig::ctype>, \ + DTypeTrait<_dtype>::ctype, 5); cb(megdnn::dtype::Float32) } // namespace cuda diff --git a/dnn/src/cuda/fake_quant/kern.cuh b/dnn/src/cuda/fake_quant/kern.cuh index 5d7fcb6d..efc95258 
100644 --- a/dnn/src/cuda/fake_quant/kern.cuh +++ b/dnn/src/cuda/fake_quant/kern.cuh @@ -35,8 +35,9 @@ struct FakeQuantKernOp { } #if MEGDNN_CC_HOST - FakeQuantKernOp(const TensorND& input, const TensorND& output, - const FakeQuant::Param& param) + FakeQuantKernOp( + const TensorND& input, const TensorND& output, + const FakeQuant::Param& param) : input{input.ptr()}, output{output.ptr()}, qmin(param.qmin), @@ -57,8 +58,9 @@ struct FakeQuantBwdKernOp { } #if MEGDNN_CC_HOST - FakeQuantBwdKernOp(const TensorND& diff, const TensorND& input, - const TensorND& grad, const FakeQuant::Param& param) + FakeQuantBwdKernOp( + const TensorND& diff, const TensorND& input, const TensorND& grad, + const FakeQuant::Param& param) : diff{diff.ptr()}, input{input.ptr()}, grad{grad.ptr()}, @@ -72,8 +74,8 @@ struct FakeQuantKernOpNonContig { ctype qmin; ctype qmax; - __device__ void operator()(uint32_t, ctype& output, ctype input, - ctype scale, ctype zero_point) { + __device__ void operator()( + uint32_t, ctype& output, ctype input, ctype scale, ctype zero_point) { ctype x = round(input / scale) + zero_point; x = fmaxf(fminf(x, qmax), qmin); output = (x - zero_point) * scale; @@ -90,8 +92,9 @@ struct FakeQuantBwdKernOpNonContig { ctype qmin; ctype qmax; - __device__ void operator()(uint32_t, ctype& grad, ctype diff, ctype input, - ctype scale, ctype zero_point) { + __device__ void operator()( + uint32_t, ctype& grad, ctype diff, ctype input, ctype scale, + ctype zero_point) { ctype x = round(input / scale) + zero_point; grad = x <= qmax && x >= qmin ? diff : 0.0; } diff --git a/dnn/src/cuda/fake_quant/opr_impl.cpp b/dnn/src/cuda/fake_quant/opr_impl.cpp index 39d37abf..adc2b6e5 100644 --- a/dnn/src/cuda/fake_quant/opr_impl.cpp +++ b/dnn/src/cuda/fake_quant/opr_impl.cpp @@ -16,13 +16,12 @@ namespace megdnn { namespace cuda { -void FakeQuantForwardImpl::exec(_megdnn_tensor_in input, - _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_out output, - _megdnn_workspace workspace) { - check_exec(input.layout, scale.layout, zero_point.layout, output.layout, - workspace.size); +void FakeQuantForwardImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, + _megdnn_tensor_out output, _megdnn_workspace workspace) { + check_exec( + input.layout, scale.layout, zero_point.layout, output.layout, + workspace.size); if (!input.layout.is_contiguous() || !output.layout.is_contiguous()) { return exec_noncontig(input, scale, zero_point, output); @@ -35,21 +34,20 @@ void FakeQuantForwardImpl::exec(_megdnn_tensor_in input, ele_param.init_from_given_tensor(); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (input.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - run_elemwise, T, 2>(ele_param, stream, \ - {input, output, m_param}); \ - return; \ +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 2>( \ + ele_param, stream, {input, output, m_param}); \ + return; \ } cb(megdnn::dtype::Float32) #undef cb } -void FakeQuantForwardImpl::exec_noncontig(_megdnn_tensor_in input, - _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_out output) { +void FakeQuantForwardImpl::exec_noncontig( + _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, + _megdnn_tensor_out output) { ElemwiseOpParamN<4> ele_param; ele_param[0] = output; ele_param[1] = input; @@ -60,25 +58,23 @@ void FakeQuantForwardImpl::exec_noncontig(_megdnn_tensor_in 
input, ele_param.init_from_given_tensor(); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (input.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - run_elemwise, T, 4>(ele_param, stream, \ - {m_param}); \ - return; \ +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 4>(ele_param, stream, {m_param}); \ + return; \ } cb(megdnn::dtype::Float32) #undef cb } -void FakeQuantBackwardImpl::exec(_megdnn_tensor_in diff, - _megdnn_tensor_in input, - _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { - check_exec(diff.layout, input.layout, scale.layout, zero_point.layout, - grad.layout, workspace.size); +void FakeQuantBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + check_exec( + diff.layout, input.layout, scale.layout, zero_point.layout, grad.layout, + workspace.size); if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() || !grad.layout.is_contiguous()) { @@ -103,11 +99,9 @@ void FakeQuantBackwardImpl::exec(_megdnn_tensor_in diff, #undef cb } -void FakeQuantBackwardImpl::exec_noncontig(_megdnn_tensor_in diff, - _megdnn_tensor_in input, - _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_out grad) { +void FakeQuantBackwardImpl::exec_noncontig( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out grad) { ElemwiseOpParamN<5> ele_param; ele_param[0] = grad; ele_param[1] = diff; @@ -119,12 +113,12 @@ void FakeQuantBackwardImpl::exec_noncontig(_megdnn_tensor_in diff, ele_param.init_from_given_tensor(); auto m_param = param(); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (grad.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - run_elemwise, T, 5>(ele_param, stream, \ - {m_param}); \ - return; \ +#define cb(DType) \ + if (grad.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 5>( \ + ele_param, stream, {m_param}); \ + return; \ } cb(megdnn::dtype::Float32) #undef cb diff --git a/dnn/src/cuda/fake_quant/opr_impl.h b/dnn/src/cuda/fake_quant/opr_impl.h index a8e63bfa..fe8dd371 100644 --- a/dnn/src/cuda/fake_quant/opr_impl.h +++ b/dnn/src/cuda/fake_quant/opr_impl.h @@ -18,37 +18,39 @@ namespace cuda { class FakeQuantForwardImpl : public FakeQuantForward { public: using FakeQuantForward::FakeQuantForward; - void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, _megdnn_tensor_out output, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&, - const TensorLayout&) override { + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out output, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { return 0; } private: - void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_out output); + void exec_noncontig( + _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out output); }; class FakeQuantBackwardImpl : public 
FakeQuantBackward { public: using FakeQuantBackward::FakeQuantBackward; - void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&, const TensorLayout&, - const TensorLayout&) override { + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&) override { return 0; } private: - void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, - _megdnn_tensor_out grad); + void exec_noncontig( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_out grad); }; } // namespace cuda diff --git a/dnn/src/cuda/fill/kern.cu b/dnn/src/cuda/fill/kern.cu index 1daf50a9..bd0bf41e 100644 --- a/dnn/src/cuda/fill/kern.cu +++ b/dnn/src/cuda/fill/kern.cu @@ -8,38 +8,37 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/cuda/fill/kern.cuh" #include "megdnn/dtype.h" +#include "src/cuda/fill/kern.cuh" #include "src/cuda/utils.cuh" namespace { template -__global__ void kernel(T *dst, T value, uint32_t size) { +__global__ void kernel(T* dst, T value, uint32_t size) { int32_t i = threadIdx.x + blockIdx.x * blockDim.x; if (i < size) { dst[i] = value; } } -} // anonymous namespace +} // anonymous namespace namespace megdnn { namespace cuda { namespace fill { template -void exec_internal(T *dst, T value, size_t size, cudaStream_t stream) { +void exec_internal(T* dst, T value, size_t size, cudaStream_t stream) { kernel<<>>(dst, value, size); after_kernel_launch(); } -#define INST(T) template void exec_internal(T *, \ - T, size_t, cudaStream_t); +#define INST(T) template void exec_internal(T*, T, size_t, cudaStream_t); #define cb(DType) INST(typename DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) -} // namespace fill -} // namespace cuda -} // namespace megdnn +} // namespace fill +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/fill/kern.cuh b/dnn/src/cuda/fill/kern.cuh index a79f9356..6ce3f3fb 100644 --- a/dnn/src/cuda/fill/kern.cuh +++ b/dnn/src/cuda/fill/kern.cuh @@ -9,17 +9,17 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once -#include #include +#include namespace megdnn { namespace cuda { namespace fill { template -void exec_internal(T *dst, T value, size_t size, cudaStream_t stream); +void exec_internal(T* dst, T value, size_t size, cudaStream_t stream); -} // namespace fill -} // namespace cuda -} // namespace megdnn +} // namespace fill +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/fill/opr_impl.cpp b/dnn/src/cuda/fill/opr_impl.cpp index f9966a2c..9d6a8219 100644 --- a/dnn/src/cuda/fill/opr_impl.cpp +++ b/dnn/src/cuda/fill/opr_impl.cpp @@ -9,8 +9,8 @@ * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "src/cuda/fill/kern.cuh" #include "src/cuda/fill/opr_impl.h" +#include "src/cuda/fill/kern.cuh" #include "src/cuda/utils.h" @@ -21,11 +21,11 @@ void FillImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(dst.layout, workspace.size); auto stream = cuda_stream(handle()); auto size = dst.layout.total_nr_elems(); -#define cb(DType) \ - if (dst.layout.dtype == DType()) { \ - using ctype = typename DTypeTrait::ctype; \ - fill::exec_internal(dst.ptr(), \ - static_cast(param().value), size, stream); \ +#define cb(DType) \ + if (dst.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + fill::exec_internal( \ + dst.ptr(), static_cast(param().value), size, stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb diff --git a/dnn/src/cuda/fill/opr_impl.h b/dnn/src/cuda/fill/opr_impl.h index d29eae24..01366e60 100644 --- a/dnn/src/cuda/fill/opr_impl.h +++ b/dnn/src/cuda/fill/opr_impl.h @@ -19,13 +19,10 @@ class FillImpl final : public Fill { public: using Fill::Fill; void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &) override { - return 0; - } + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } }; } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/flip/flip.cu b/dnn/src/cuda/flip/flip.cu index 9ca2819d..2a316509 100644 --- a/dnn/src/cuda/flip/flip.cu +++ b/dnn/src/cuda/flip/flip.cu @@ -24,53 +24,54 @@ namespace { #define rep(i, n) for (size_t i = 0; i < (n); ++i) template -__global__ void flip_kern(const T *src, T *dst, size_t N, size_t H, size_t W, - size_t stride1, size_t stride2, size_t stride3) { +__global__ void flip_kern( + const T* src, T* dst, size_t N, size_t H, size_t W, size_t stride1, + size_t stride2, size_t stride3) { __shared__ T cache[BX][BY][IC]; int ow = blockIdx.x * blockDim.x + threadIdx.x; int oh = blockIdx.y * blockDim.y + threadIdx.y; if (ow < W && oh < H) { - int iw = horizontal ? W - ow - 1 : ow; int ih = vertical ? 
H - oh - 1 : oh; #pragma unroll rep(c, IC) { cache[threadIdx.y][threadIdx.x][c] = - src[blockIdx.z * stride1 + ih * stride2 + iw * stride3 + c]; + src[blockIdx.z * stride1 + ih * stride2 + iw * stride3 + c]; } __syncthreads(); #pragma unroll rep(c, IC) { dst[blockIdx.z * stride1 + oh * stride2 + ow * stride3 + c] = - cache[threadIdx.y][threadIdx.x][c]; + cache[threadIdx.y][threadIdx.x][c]; } } } #undef rep -} // anonymous namespace +} // anonymous namespace namespace flip { template -void flip(const T *src, T *dst, size_t N, size_t H, size_t W, size_t IC, - size_t stride1, size_t stride2, size_t stride3, cudaStream_t stream) { +void flip( + const T* src, T* dst, size_t N, size_t H, size_t W, size_t IC, size_t stride1, + size_t stride2, size_t stride3, cudaStream_t stream) { dim3 threads(BX, BY); dim3 blocks(DIVUP(W, BX), DIVUP(H, BY), N); megdnn_assert(IC == 1 || IC == 3); if (IC == 1) flip_kern<<>>( - src, dst, N, H, W, stride1, stride2, stride3); + src, dst, N, H, W, stride1, stride2, stride3); else flip_kern<<>>( - src, dst, N, H, W, stride1, stride2, stride3); + src, dst, N, H, W, stride1, stride2, stride3); after_kernel_launch(); } -#define INST(T, vertical, horizontal) \ - template void flip( \ - const T *src, T *dst, size_t N, size_t H, size_t W, size_t IC, \ - size_t stride1, size_t stride2, size_t stride3, cudaStream_t); +#define INST(T, vertical, horizontal) \ + template void flip( \ + const T* src, T* dst, size_t N, size_t H, size_t W, size_t IC, \ + size_t stride1, size_t stride2, size_t stride3, cudaStream_t); #define cb(DType) \ INST(typename DTypeTrait::ctype, true, true) \ diff --git a/dnn/src/cuda/flip/flip.cuh b/dnn/src/cuda/flip/flip.cuh index f19fab7b..b13feda2 100644 --- a/dnn/src/cuda/flip/flip.cuh +++ b/dnn/src/cuda/flip/flip.cuh @@ -9,16 +9,17 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once -#include #include +#include namespace megdnn { namespace cuda { namespace flip { template -void flip(const T *src, T *dst, size_t N, size_t H, size_t W, size_t IC, - size_t stride1, size_t stride2, size_t stride3, cudaStream_t stream); +void flip( + const T* src, T* dst, size_t N, size_t H, size_t W, size_t IC, size_t stride1, + size_t stride2, size_t stride3, cudaStream_t stream); } // namespace flip } // namespace cuda diff --git a/dnn/src/cuda/flip/opr_impl.cpp b/dnn/src/cuda/flip/opr_impl.cpp index 4de71de6..ea5a1367 100644 --- a/dnn/src/cuda/flip/opr_impl.cpp +++ b/dnn/src/cuda/flip/opr_impl.cpp @@ -9,13 +9,13 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#include "./flip.cuh" #include "./opr_impl.h" +#include "./flip.cuh" +#include +#include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" -#include "src/common/utils.h" -#include namespace megdnn { namespace cuda { @@ -23,51 +23,50 @@ namespace cuda { namespace flip_intl { template -void flip_exec(const ctype *src, ctype *dst, size_t N, size_t IH, size_t IW, - size_t IC, size_t stride1, size_t stride2, size_t stride3, - bool vertical, bool horizontal, - cudaStream_t stream) { +void flip_exec( + const ctype* src, ctype* dst, size_t N, size_t IH, size_t IW, size_t IC, + size_t stride1, size_t stride2, size_t stride3, bool vertical, bool horizontal, + cudaStream_t stream) { if (vertical) { if (horizontal) { - flip::flip(src, dst, N, IH, IW, IC, stride1, - stride2, stride3, stream); + flip::flip( + src, dst, N, IH, IW, IC, stride1, stride2, stride3, stream); } else { - flip::flip(src, dst, N, IH, IW, IC, stride1, - stride2, stride3, stream); + flip::flip( + src, dst, N, IH, IW, IC, stride1, stride2, stride3, stream); } } else { if (horizontal) { - flip::flip(src, dst, N, IH, IW, IC, stride1, - stride2, stride3, stream); + flip::flip( + src, dst, N, IH, IW, IC, stride1, stride2, stride3, stream); } else { - flip::flip(src, dst, N, IH, IW, IC, stride1, - stride2, stride3, stream); + flip::flip( + src, dst, N, IH, IW, IC, stride1, stride2, stride3, stream); } } } } // namespace flip_intl -void FlipImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { +void FlipImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); auto stream = cuda_stream(handle()); //! src layout is the same as dst layout size_t N = src.layout.shape[0]; size_t batch_size = 0; -#define cb(DType) \ - if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ - using ctype = typename DTypeTrait::ctype; \ - ctype* src_ptr = src.ptr() + curr_batch * src.layout.stride[0]; \ - ctype* dst_ptr = dst.ptr() + curr_batch * src.layout.stride[0]; \ - batch_size = std::min(N - curr_batch, max_batch); \ - flip_intl::flip_exec(src_ptr, dst_ptr, batch_size, \ - src.layout.shape[1], src.layout.shape[2], \ - src.layout.shape[3], src.layout.stride[0], \ - src.layout.stride[1], \ - src.layout.stride[2], param().vertical, \ - param().horizontal, stream); \ +#define cb(DType) \ + if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + ctype* src_ptr = src.ptr() + curr_batch * src.layout.stride[0]; \ + ctype* dst_ptr = dst.ptr() + curr_batch * src.layout.stride[0]; \ + batch_size = std::min(N - curr_batch, max_batch); \ + flip_intl::flip_exec( \ + src_ptr, dst_ptr, batch_size, src.layout.shape[1], \ + src.layout.shape[2], src.layout.shape[3], src.layout.stride[0], \ + src.layout.stride[1], src.layout.stride[2], param().vertical, \ + param().horizontal, stream); \ } size_t curr_batch = 0; diff --git a/dnn/src/cuda/flip/opr_impl.h b/dnn/src/cuda/flip/opr_impl.h index 39e5fb24..d6730ca9 100644 --- a/dnn/src/cuda/flip/opr_impl.h +++ b/dnn/src/cuda/flip/opr_impl.h @@ -15,14 +15,14 @@ namespace megdnn { namespace cuda { class FlipImpl : public Flip { - public: +public: using Flip::Flip; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &) 
override { + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { return 0; } }; diff --git a/dnn/src/cuda/fp16_help.cuh b/dnn/src/cuda/fp16_help.cuh index 9d39c01c..f85a82df 100644 --- a/dnn/src/cuda/fp16_help.cuh +++ b/dnn/src/cuda/fp16_help.cuh @@ -17,20 +17,17 @@ namespace megdnn { namespace cuda { -__device__ __forceinline__ float fma(const float a, const float b, - const float c) { +__device__ __forceinline__ float fma(const float a, const float b, const float c) { return a * b + c; } -__device__ __forceinline__ float2 fma2(const float2 a, const float2 b, - const float2 c) { +__device__ __forceinline__ float2 fma2(const float2 a, const float2 b, const float2 c) { return {a.x * b.x + c.x, a.y * b.y + c.y}; } #if CUDA_VERSION >= 9000 -__device__ __forceinline__ __half fma(const __half a, const __half b, - const __half c) { +__device__ __forceinline__ __half fma(const __half a, const __half b, const __half c) { #if __CUDA_ARCH__ >= 530 return __hfma(a, b, c); #else @@ -38,21 +35,19 @@ __device__ __forceinline__ __half fma(const __half a, const __half b, #endif } -__device__ __forceinline__ __half2 fma2(const __half2 a, const __half2 b, - const __half2 c) { +__device__ __forceinline__ __half2 +fma2(const __half2 a, const __half2 b, const __half2 c) { #if __CUDA_ARCH__ >= 530 return __hfma2(a, b, c); #else - return {__float2half(__half2float(a.x) * __half2float(b.x) + - __half2float(c.x)), - __float2half(__half2float(a.y) * __half2float(b.y) + - __half2float(c.y))}; + return {__float2half(__half2float(a.x) * __half2float(b.x) + __half2float(c.x)), + __float2half(__half2float(a.y) * __half2float(b.y) + __half2float(c.y))}; #endif } #endif // CUDA_VERSION >= 9000 -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/gaussian_blur/gaussian_blur.cu b/dnn/src/cuda/gaussian_blur/gaussian_blur.cu index 5f934105..ed202922 100644 --- a/dnn/src/cuda/gaussian_blur/gaussian_blur.cu +++ b/dnn/src/cuda/gaussian_blur/gaussian_blur.cu @@ -74,38 +74,33 @@ static const uint8_t BITS = 8; #define rep(i, n) for (size_t i = 0; i < (n); ++i) template -__global__ void prepare_kernel(uint8_t* kernel_ptr, size_t kernel_height, - size_t kernel_width, double sigma_x, - double sigma_y); +__global__ void prepare_kernel( + uint8_t* kernel_ptr, size_t kernel_height, size_t kernel_width, double sigma_x, + double sigma_y); template <> -__global__ void prepare_kernel(uint8_t* _kernel_ptr, - size_t kernel_height, size_t kernel_width, - double sigma_x, double sigma_y) { +__global__ void prepare_kernel( + uint8_t* _kernel_ptr, size_t kernel_height, size_t kernel_width, double sigma_x, + double sigma_y) { float* kernel_ptr = reinterpret_cast(_kernel_ptr); const int kSmallGaussianSize = 7; const float small_gaussian_table[4][kSmallGaussianSize] = { {1.f}, {0.25f, 0.5f, 0.25f}, {0.0625f, 0.25f, 0.375f, 0.25f, 0.0625f}, - {0.03125f, 0.109375f, 0.21875f, 0.28125f, 0.21875f, 0.109375f, - 0.03125f}}; + {0.03125f, 0.109375f, 0.21875f, 0.28125f, 0.21875f, 0.109375f, 0.03125f}}; - const float* fixed_kernel_x = - (kernel_width % 2 == 1 && kernel_width <= kSmallGaussianSize && - sigma_x <= 0) - ? small_gaussian_table[kernel_width >> 1] - : NULL; - const float* fixed_kernel_y = - (kernel_height % 2 == 1 && kernel_height <= kSmallGaussianSize && - sigma_y <= 0) - ? small_gaussian_table[kernel_height >> 1] - : NULL; - sigma_x = - sigma_x > 0 ? 
sigma_x : ((kernel_width - 1) * 0.5 - 1) * 0.3 + 0.8; + const float* fixed_kernel_x = (kernel_width % 2 == 1 && + kernel_width <= kSmallGaussianSize && sigma_x <= 0) + ? small_gaussian_table[kernel_width >> 1] + : NULL; + const float* fixed_kernel_y = (kernel_height % 2 == 1 && + kernel_height <= kSmallGaussianSize && sigma_y <= 0) + ? small_gaussian_table[kernel_height >> 1] + : NULL; + sigma_x = sigma_x > 0 ? sigma_x : ((kernel_width - 1) * 0.5 - 1) * 0.3 + 0.8; double scale_2x = -0.5 / (sigma_x * sigma_x); - sigma_y = - sigma_y > 0 ? sigma_y : ((kernel_height - 1) * 0.5 - 1) * 0.3 + 0.8; + sigma_y = sigma_y > 0 ? sigma_y : ((kernel_height - 1) * 0.5 - 1) * 0.3 + 0.8; double scale_2y = -0.5 / (sigma_y * sigma_y); //! calc gaussian kernel @@ -133,34 +128,28 @@ __global__ void prepare_kernel(uint8_t* _kernel_ptr, } template <> -__global__ void prepare_kernel(uint8_t* _kernel_ptr, - size_t kernel_height, - size_t kernel_width, double sigma_x, - double sigma_y) { +__global__ void prepare_kernel( + uint8_t* _kernel_ptr, size_t kernel_height, size_t kernel_width, double sigma_x, + double sigma_y) { int32_t* kernel_ptr = reinterpret_cast(_kernel_ptr); const int kSmallGaussianSize = 7; const float small_gaussian_table[4][kSmallGaussianSize] = { {1.f}, {0.25f, 0.5f, 0.25f}, {0.0625f, 0.25f, 0.375f, 0.25f, 0.0625f}, - {0.03125f, 0.109375f, 0.21875f, 0.28125f, 0.21875f, 0.109375f, - 0.03125f}}; + {0.03125f, 0.109375f, 0.21875f, 0.28125f, 0.21875f, 0.109375f, 0.03125f}}; - const float* fixed_kernel_x = - (kernel_width % 2 == 1 && kernel_width <= kSmallGaussianSize && - sigma_x <= 0) - ? small_gaussian_table[kernel_width >> 1] - : NULL; - const float* fixed_kernel_y = - (kernel_height % 2 == 1 && kernel_height <= kSmallGaussianSize && - sigma_y <= 0) - ? small_gaussian_table[kernel_height >> 1] - : NULL; - sigma_x = - sigma_x > 0 ? sigma_x : ((kernel_width - 1) * 0.5 - 1) * 0.3 + 0.8; + const float* fixed_kernel_x = (kernel_width % 2 == 1 && + kernel_width <= kSmallGaussianSize && sigma_x <= 0) + ? small_gaussian_table[kernel_width >> 1] + : NULL; + const float* fixed_kernel_y = (kernel_height % 2 == 1 && + kernel_height <= kSmallGaussianSize && sigma_y <= 0) + ? small_gaussian_table[kernel_height >> 1] + : NULL; + sigma_x = sigma_x > 0 ? sigma_x : ((kernel_width - 1) * 0.5 - 1) * 0.3 + 0.8; double scale_2x = -0.5 / (sigma_x * sigma_x); - sigma_y = - sigma_y > 0 ? sigma_y : ((kernel_height - 1) * 0.5 - 1) * 0.3 + 0.8; + sigma_y = sigma_y > 0 ? sigma_y : ((kernel_height - 1) * 0.5 - 1) * 0.3 + 0.8; double scale_2y = -0.5 / (sigma_y * sigma_y); size_t kernel_size = kernel_width * kernel_height; @@ -179,8 +168,7 @@ __global__ void prepare_kernel(uint8_t* _kernel_ptr, //! calc the sum of vertical kernel filter double sum_x = 0; - float* kx_ptr = - reinterpret_cast(kernel_ptr + kernel_size) + kernel_height; + float* kx_ptr = reinterpret_cast(kernel_ptr + kernel_size) + kernel_height; rep(ix, kernel_width) { double x = ix - (kernel_width - 1) * 0.5; double kx = fixed_kernel_x ? 
static_cast(fixed_kernel_x[ix]) @@ -203,11 +191,10 @@ __global__ void prepare_kernel(uint8_t* _kernel_ptr, } template -__global__ void gaussian_blur_kern(const T* src, T* dst, size_t N, size_t H, - size_t W, size_t stride0, size_t stride1, - size_t stride2, size_t stride3, - uint8_t* kernel_ptr, size_t kernel_height, - size_t kernel_width) { +__global__ void gaussian_blur_kern( + const T* src, T* dst, size_t N, size_t H, size_t W, size_t stride0, + size_t stride1, size_t stride2, size_t stride3, uint8_t* kernel_ptr, + size_t kernel_height, size_t kernel_width) { int iw = blockIdx.x * blockDim.x + threadIdx.x; int ih = blockIdx.y * blockDim.y + threadIdx.y; if (iw < W && ih < H) { @@ -225,14 +212,12 @@ __global__ void gaussian_blur_kern(const T* src, T* dst, size_t N, size_t H, if (x != -1 && y != -1) { if (is_same::value) { val += static_cast(reinterpret_cast( - kernel_ptr)[iy * kernel_width + - ix]) * + kernel_ptr)[iy * kernel_width + ix]) * src[blockIdx.z * stride0 + y * stride1 + x * stride2 + c * stride3]; } else { val += static_cast(reinterpret_cast( - kernel_ptr)[iy * kernel_width + - ix]) * + kernel_ptr)[iy * kernel_width + ix]) * src[blockIdx.z * stride0 + y * stride1 + x * stride2 + c * stride3]; } @@ -241,13 +226,12 @@ __global__ void gaussian_blur_kern(const T* src, T* dst, size_t N, size_t H, } if (is_same::value) { - dst[blockIdx.z * stride0 + ih * stride1 + iw * stride2 + - c * stride3] = + dst[blockIdx.z * stride0 + ih * stride1 + iw * stride2 + c * stride3] = static_cast(static_cast(val) >> (2 * BITS)); } else { //! float32 - dst[blockIdx.z * stride0 + ih * stride1 + iw * stride2 + - c * stride3] = static_cast(val); + dst[blockIdx.z * stride0 + ih * stride1 + iw * stride2 + c * stride3] = + static_cast(val); } } } @@ -259,14 +243,14 @@ __global__ void gaussian_blur_kern(const T* src, T* dst, size_t N, size_t H, namespace gaussian_blur { template -void gaussian_blur(const T* src, T* dst, size_t N, size_t H, size_t W, - size_t stride0, size_t stride1, size_t stride2, - size_t stride3, uint8_t* kernel_ptr, size_t kernel_height, - size_t kernel_width, double sigma_x, double sigma_y, - cudaStream_t stream) { +void gaussian_blur( + const T* src, T* dst, size_t N, size_t H, size_t W, size_t stride0, + size_t stride1, size_t stride2, size_t stride3, uint8_t* kernel_ptr, + size_t kernel_height, size_t kernel_width, double sigma_x, double sigma_y, + cudaStream_t stream) { //! 
calc gaussian kernel - prepare_kernel<<<1, 1, 0, stream>>>(kernel_ptr, kernel_height, - kernel_width, sigma_x, sigma_y); + prepare_kernel<<<1, 1, 0, stream>>>( + kernel_ptr, kernel_height, kernel_width, sigma_x, sigma_y); cuda_check(cudaStreamSynchronize(stream)); static const int BX = 16; @@ -279,11 +263,11 @@ void gaussian_blur(const T* src, T* dst, size_t N, size_t H, size_t W, after_kernel_launch(); } -#define INST(T, CH, bmode) \ - template void gaussian_blur( \ - const T* src, T* dst, size_t N, size_t H, size_t W, \ - size_t stride0, size_t stride1, size_t stride2, size_t stride3, \ - uint8_t*, size_t, size_t, double, double, cudaStream_t); +#define INST(T, CH, bmode) \ + template void gaussian_blur( \ + const T* src, T* dst, size_t N, size_t H, size_t W, size_t stride0, \ + size_t stride1, size_t stride2, size_t stride3, uint8_t*, size_t, size_t, \ + double, double, cudaStream_t); #define cb(DType) \ INST(typename DTypeTrait::ctype, 1, BORDER_REPLICATE) \ diff --git a/dnn/src/cuda/gaussian_blur/gaussian_blur.cuh b/dnn/src/cuda/gaussian_blur/gaussian_blur.cuh index 4d878072..4aea4b82 100644 --- a/dnn/src/cuda/gaussian_blur/gaussian_blur.cuh +++ b/dnn/src/cuda/gaussian_blur/gaussian_blur.cuh @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once -#include #include +#include #include "src/common/cv/enums.h" #include @@ -20,11 +20,11 @@ namespace cuda { namespace gaussian_blur { template -void gaussian_blur(const T* src, T* dst, size_t N, size_t H, size_t W, - size_t stride0, size_t stride1, size_t stride2, - size_t stride3, uint8_t* kernel_ptr, size_t kernel_height, - size_t kernel_width, double sigma_x, double sigma_y, - cudaStream_t stream); +void gaussian_blur( + const T* src, T* dst, size_t N, size_t H, size_t W, size_t stride0, + size_t stride1, size_t stride2, size_t stride3, uint8_t* kernel_ptr, + size_t kernel_height, size_t kernel_width, double sigma_x, double sigma_y, + cudaStream_t stream); } // namespace gaussian_blur } // namespace cuda diff --git a/dnn/src/cuda/gaussian_blur/opr_impl.cpp b/dnn/src/cuda/gaussian_blur/opr_impl.cpp index b9cf616c..1fb9c04d 100644 --- a/dnn/src/cuda/gaussian_blur/opr_impl.cpp +++ b/dnn/src/cuda/gaussian_blur/opr_impl.cpp @@ -12,12 +12,12 @@ #include "./opr_impl.h" #include "./gaussian_blur.cuh" -#include "src/cuda/handle.h" -#include "src/cuda/utils.h" #include "src/common/cv/common.h" #include "src/common/cv/enums.h" #include "src/common/cv/filter.h" #include "src/common/utils.h" +#include "src/cuda/handle.h" +#include "src/cuda/utils.h" #include @@ -27,25 +27,21 @@ namespace cuda { namespace intl { template -void gaussian_blur_exec(const ctype* src, ctype* dst, size_t N, size_t IH, - size_t IW, size_t IC, size_t stride0, size_t stride1, - size_t stride2, size_t stride3, - uint8_t* kernel_ptr, size_t kernel_height, - size_t kernel_width, double sigma_x, double sigma_y, - param::GaussianBlur::BorderMode bmode, - cudaStream_t stream) { +void gaussian_blur_exec( + const ctype* src, ctype* dst, size_t N, size_t IH, size_t IW, size_t IC, + size_t stride0, size_t stride1, size_t stride2, size_t stride3, + uint8_t* kernel_ptr, size_t kernel_height, size_t kernel_width, double sigma_x, + double sigma_y, param::GaussianBlur::BorderMode bmode, cudaStream_t stream) { megdnn_assert(IC == 1_z || IC == 3_z); -#define INIT_KERN(bmode) \ - if (IC == 1) { \ - gaussian_blur::gaussian_blur( \ - src, dst, N, IH, IW, stride0, stride1, stride2, stride3, \ - kernel_ptr, kernel_height, kernel_width, 
sigma_x, sigma_y, \ - stream); \ - } else { \ - gaussian_blur::gaussian_blur( \ - src, dst, N, IH, IW, stride0, stride1, stride2, stride3, \ - kernel_ptr, kernel_height, kernel_width, sigma_x, sigma_y, \ - stream); \ +#define INIT_KERN(bmode) \ + if (IC == 1) { \ + gaussian_blur::gaussian_blur( \ + src, dst, N, IH, IW, stride0, stride1, stride2, stride3, kernel_ptr, \ + kernel_height, kernel_width, sigma_x, sigma_y, stream); \ + } else { \ + gaussian_blur::gaussian_blur( \ + src, dst, N, IH, IW, stride0, stride1, stride2, stride3, kernel_ptr, \ + kernel_height, kernel_width, sigma_x, sigma_y, stream); \ } switch (bmode) { @@ -68,29 +64,28 @@ void gaussian_blur_exec(const ctype* src, ctype* dst, size_t N, size_t IH, } // namespace intl -void GaussianBlurImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_workspace workspace) { - megdnn_assert(src.layout.dtype == dtype::Uint8() || - src.layout.dtype == dtype::Float32()); +void GaussianBlurImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_workspace workspace) { + megdnn_assert( + src.layout.dtype == dtype::Uint8() || src.layout.dtype == dtype::Float32()); check_exec(src.layout, dst.layout, workspace.size); auto stream = cuda_stream(handle()); //! src layout is the same as dst layout size_t N = src.layout.shape[0]; size_t batch_size = 0; -#define cb(DType) \ - if (src.layout.dtype == DType()) { \ - using ctype = typename DTypeTrait::ctype; \ - ctype* src_ptr = src.ptr() + curr_batch * src.layout.stride[0]; \ - ctype* dst_ptr = dst.ptr() + curr_batch * src.layout.stride[0]; \ - batch_size = std::min(N - curr_batch, max_batch_x_channel); \ - intl::gaussian_blur_exec( \ - src_ptr, dst_ptr, batch_size, src.layout.shape[1], \ - src.layout.shape[2], src.layout.shape[3], \ - src.layout.stride[0], src.layout.stride[1], \ - src.layout.stride[2], src.layout.stride[3], \ - workspace.ptr(), m_kernel_height, m_kernel_width, \ - m_sigma_x, m_sigma_y, param().border_mode, stream); \ +#define cb(DType) \ + if (src.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + ctype* src_ptr = src.ptr() + curr_batch * src.layout.stride[0]; \ + ctype* dst_ptr = dst.ptr() + curr_batch * src.layout.stride[0]; \ + batch_size = std::min(N - curr_batch, max_batch_x_channel); \ + intl::gaussian_blur_exec( \ + src_ptr, dst_ptr, batch_size, src.layout.shape[1], \ + src.layout.shape[2], src.layout.shape[3], src.layout.stride[0], \ + src.layout.stride[1], src.layout.stride[2], src.layout.stride[3], \ + workspace.ptr(), m_kernel_height, m_kernel_width, m_sigma_x, \ + m_sigma_y, param().border_mode, stream); \ } size_t max_batch_x_channel = max_batch_x_channel_size(); diff --git a/dnn/src/cuda/gaussian_blur/opr_impl.h b/dnn/src/cuda/gaussian_blur/opr_impl.h index 0486e2e5..8113ef34 100644 --- a/dnn/src/cuda/gaussian_blur/opr_impl.h +++ b/dnn/src/cuda/gaussian_blur/opr_impl.h @@ -11,78 +11,79 @@ #pragma once #include "megdnn/oprs.h" -#include "src/common/utils.h" -#include "src/common/cv/common.h" #include +#include "src/common/cv/common.h" +#include "src/common/utils.h" namespace megdnn { namespace cuda { class GaussianBlurImpl : public GaussianBlur { - public: - using GaussianBlur::GaussianBlur; +public: + using GaussianBlur::GaussianBlur; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout&) override 
{ - //! current only support float and uint8 - megdnn_assert(src.dtype == dtype::Float32() || - src.dtype == dtype::Uint8()); + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout&) override { + //! current only support float and uint8 + megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Uint8()); - //! Calc gaussian kernel real size - double sigma_x = param().sigma_x; - double sigma_y = param().sigma_y; - uint32_t kernel_height = param().kernel_height; - uint32_t kernel_width = param().kernel_width; + //! Calc gaussian kernel real size + double sigma_x = param().sigma_x; + double sigma_y = param().sigma_y; + uint32_t kernel_height = param().kernel_height; + uint32_t kernel_width = param().kernel_width; - if (sigma_y <= 0) - sigma_y = sigma_x; + if (sigma_y <= 0) + sigma_y = sigma_x; - auto get_size = [&src](double sigma) { - double num = 0; - if (src.dtype == dtype::Uint8()) { - num = sigma * 3 * 2 + 1; - } else { - num = sigma * 4 * 2 + 1; - } - return static_cast(num + (num >= 0 ? 0.5 : -0.5)) | 1; - }; - - if (kernel_width <= 0 && sigma_x > 0) { - m_kernel_width = get_size(sigma_x); - } else { - m_kernel_width = kernel_width; - } - if (kernel_height <= 0 && sigma_y > 0) { - m_kernel_height = get_size(sigma_y); + auto get_size = [&src](double sigma) { + double num = 0; + if (src.dtype == dtype::Uint8()) { + num = sigma * 3 * 2 + 1; } else { - m_kernel_height = kernel_height; + num = sigma * 4 * 2 + 1; } - megdnn_assert(m_kernel_width > 0 && m_kernel_width % 2 == 1 && - m_kernel_height > 0 && m_kernel_height % 2 == 1); + return static_cast(num + (num >= 0 ? 0.5 : -0.5)) | 1; + }; - m_sigma_x = std::max(sigma_x, 0.); - m_sigma_y = std::max(sigma_y, 0.); + if (kernel_width <= 0 && sigma_x > 0) { + m_kernel_width = get_size(sigma_x); + } else { + m_kernel_width = kernel_width; + } + if (kernel_height <= 0 && sigma_y > 0) { + m_kernel_height = get_size(sigma_y); + } else { + m_kernel_height = kernel_height; + } + megdnn_assert( + m_kernel_width > 0 && m_kernel_width % 2 == 1 && m_kernel_height > 0 && + m_kernel_height % 2 == 1); - if (src.dtype == dtype::Uint8()) { - //! element [0, m_kernel_width * m_kernel_height - 1] store the - //! filter matrix of type int32_t, others store float value - //! kernel_x and kernel_y. - return m_kernel_width * m_kernel_height * sizeof(int32_t) + - (m_kernel_width + m_kernel_height) * sizeof(float); - } else { - //! float32 - return m_kernel_width * m_kernel_height * sizeof(float); - } + m_sigma_x = std::max(sigma_x, 0.); + m_sigma_y = std::max(sigma_y, 0.); + + if (src.dtype == dtype::Uint8()) { + //! element [0, m_kernel_width * m_kernel_height - 1] store the + //! filter matrix of type int32_t, others store float value + //! kernel_x and kernel_y. + return m_kernel_width * m_kernel_height * sizeof(int32_t) + + (m_kernel_width + m_kernel_height) * sizeof(float); + } else { + //! 
float32 + return m_kernel_width * m_kernel_height * sizeof(float); } + } - private: - uint32_t m_kernel_height; - uint32_t m_kernel_width; - double m_sigma_x; - double m_sigma_y; +private: + uint32_t m_kernel_height; + uint32_t m_kernel_width; + double m_sigma_x; + double m_sigma_y; }; // class GaussianBlurImpl diff --git a/dnn/src/cuda/group_local/bwd_data.cpp b/dnn/src/cuda/group_local/bwd_data.cpp index 3b4b523b..fc98b48a 100644 --- a/dnn/src/cuda/group_local/bwd_data.cpp +++ b/dnn/src/cuda/group_local/bwd_data.cpp @@ -17,80 +17,62 @@ namespace megdnn { namespace cuda { -void GroupLocalBackwardDataImpl::exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ +void GroupLocalBackwardDataImpl::exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(filter.layout, diff.layout, grad.layout, workspace.size); auto G = filter.layout[0]; - auto N = grad.layout.shape[0], IC = grad.layout.shape[1]/G, + auto N = grad.layout.shape[0], IC = grad.layout.shape[1] / G, IH = grad.layout.shape[2], IW = grad.layout.shape[3], - OC = diff.layout.shape[1]/G, - OH = diff.layout.shape[2], OW = diff.layout.shape[3]; + OC = diff.layout.shape[1] / G, OH = diff.layout.shape[2], + OW = diff.layout.shape[3]; auto FH = filter.layout.shape[4], FW = filter.layout.shape[5]; auto PH = param().pad_h, PW = param().pad_w; auto SH = param().stride_h, SW = param().stride_w; - float *sptr = grad.ptr(); - const float *fptr = filter.ptr(); - const float *dptr = diff.ptr(); - float *wptr = workspace.ptr(); + float* sptr = grad.ptr(); + const float* fptr = filter.ptr(); + const float* dptr = diff.ptr(); + float* wptr = workspace.ptr(); auto handle = concrete_handle(this->handle()); auto stream = cuda_stream(this->handle()); auto cublas = cublas_handle(this->handle()); auto one = handle->one_device(); auto zero = handle->zero_device(); - megdnn_assert(local::can_backward_data_proxy_convnet(N, IC, IH, IW, - OC, OH, OW, - FH, FW, - G*IC*IH*IW, G*OC*OH*OW, - PH, PW, - SH, SW), + megdnn_assert( + local::can_backward_data_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, G * IC * IH * IW, + G * OC * OH * OW, PH, PW, SH, SW), "Cannot do Group Local bwd data."); for (size_t g = 0; g < G; ++g) { - local::backward_data_proxy_convnet(fptr + g*OH*OW*IC*FH*FW*OC, - dptr + g*OC*OH*OW, - sptr + g*IC*IH*IW, - wptr, - N, IC, IH, IW, - OC, OH, OW, - FH, FW, - G*IC*IH*IW, G*OC*OH*OW, - PH, PW, - SH, SW, - cublas, stream, one, zero); + local::backward_data_proxy_convnet( + fptr + g * OH * OW * IC * FH * FW * OC, dptr + g * OC * OH * OW, + sptr + g * IC * IH * IW, wptr, N, IC, IH, IW, OC, OH, OW, FH, FW, + G * IC * IH * IW, G * OC * OH * OW, PH, PW, SH, SW, cublas, stream, one, + zero); } } -GroupLocalBackwardDataImpl::GroupLocalBackwardDataImpl(Handle *handle): - GroupLocalBackwardData(handle) -{ -} +GroupLocalBackwardDataImpl::GroupLocalBackwardDataImpl(Handle* handle) + : GroupLocalBackwardData(handle) {} size_t GroupLocalBackwardDataImpl::get_workspace_in_bytes( - const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) -{ + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { auto G = filter[0]; - auto N = grad.shape[0], IC = grad.shape[1]/G, - IH = grad.shape[2], IW = grad.shape[3], - OC = diff.shape[1]/G, - OH = diff.shape[2], OW = diff.shape[3]; + auto N = grad.shape[0], IC = grad.shape[1] / G, IH = grad.shape[2], + IW = grad.shape[3], OC = diff.shape[1] 
/ G, OH = diff.shape[2], + OW = diff.shape[3]; auto FH = filter.shape[4], FW = filter.shape[5]; auto PH = param().pad_h, PW = param().pad_w; auto SH = param().stride_h, SW = param().stride_w; - auto res = local::get_workspace_in_floats_backward_data_proxy_convnet(N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - G*IC*IH*IW, G*OC*OH*OW, - PH, PW, - SH, SW) * sizeof(float); + auto res = local::get_workspace_in_floats_backward_data_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, G * IC * IH * IW, + G * OC * OH * OW, PH, PW, SH, SW) * + sizeof(float); return res; } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/group_local/bwd_filter.cpp b/dnn/src/cuda/group_local/bwd_filter.cpp index 7c7764d5..990b450b 100644 --- a/dnn/src/cuda/group_local/bwd_filter.cpp +++ b/dnn/src/cuda/group_local/bwd_filter.cpp @@ -19,81 +19,60 @@ namespace megdnn { namespace cuda { -void GroupLocalBackwardFilterImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ +void GroupLocalBackwardFilterImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(src.layout, diff.layout, grad.layout, workspace.size); auto G = grad.layout[0]; - auto N = src.layout.shape[0], IC = src.layout.shape[1]/G, + auto N = src.layout.shape[0], IC = src.layout.shape[1] / G, IH = src.layout.shape[2], IW = src.layout.shape[3], - OC = diff.layout.shape[1]/G, - OH = diff.layout.shape[2], OW = diff.layout.shape[3]; + OC = diff.layout.shape[1] / G, OH = diff.layout.shape[2], + OW = diff.layout.shape[3]; auto FH = grad.layout.shape[4], FW = grad.layout.shape[5]; auto PH = param().pad_h, PW = param().pad_w; auto SH = param().stride_h, SW = param().stride_w; - const float *sptr = src.ptr(); - float *fptr = grad.ptr(); - const float *dptr = diff.ptr(); - float *wptr = workspace.ptr(); + const float* sptr = src.ptr(); + float* fptr = grad.ptr(); + const float* dptr = diff.ptr(); + float* wptr = workspace.ptr(); auto handle = concrete_handle(this->handle()); auto stream = cuda_stream(this->handle()); auto cublas = cublas_handle(this->handle()); auto one = handle->one_device(); auto zero = handle->zero_device(); - megdnn_assert(local::can_backward_filter_proxy_convnet(N, IC, IH, IW, - OC, OH, OW, - FH, FW, - G*IC*IH*IW, G*OC*OH*OW, - PH, PW, - SH, SW), + megdnn_assert( + local::can_backward_filter_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, G * IC * IH * IW, + G * OC * OH * OW, PH, PW, SH, SW), "Cannot do Group Local bwd filter."); for (size_t g = 0; g < G; ++g) { - local::backward_filter_proxy_convnet(sptr + g*IC*IH*IW, - dptr + g*OC*OH*OW, - fptr + g*OH*OW*IC*FH*FW*OC, - wptr, - N, IC, IH, IW, - OC, OH, OW, - FH, FW, - G*IC*IH*IW, G*OC*OH*OW, - PH, PW, - SH, SW, - cublas, stream, one, zero); + local::backward_filter_proxy_convnet( + sptr + g * IC * IH * IW, dptr + g * OC * OH * OW, + fptr + g * OH * OW * IC * FH * FW * OC, wptr, N, IC, IH, IW, OC, OH, OW, + FH, FW, G * IC * IH * IW, G * OC * OH * OW, PH, PW, SH, SW, cublas, + stream, one, zero); } } -GroupLocalBackwardFilterImpl::GroupLocalBackwardFilterImpl(Handle *handle): - GroupLocalBackwardFilter(handle) -{ -} +GroupLocalBackwardFilterImpl::GroupLocalBackwardFilterImpl(Handle* handle) + : GroupLocalBackwardFilter(handle) {} size_t GroupLocalBackwardFilterImpl::get_workspace_in_bytes( - const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad) -{ 
+ const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { auto G = grad[0]; - auto N = src.shape[0], IC = src.shape[1]/G, - IH = src.shape[2], IW = src.shape[3], - OC = diff.shape[1]/G, - OH = diff.shape[2], OW = diff.shape[3]; + auto N = src.shape[0], IC = src.shape[1] / G, IH = src.shape[2], IW = src.shape[3], + OC = diff.shape[1] / G, OH = diff.shape[2], OW = diff.shape[3]; auto FH = grad.shape[4], FW = grad.shape[5]; auto PH = param().pad_h, PW = param().pad_w; auto SH = param().stride_h, SW = param().stride_w; - auto res = local::get_workspace_in_floats_backward_filter_proxy_convnet(N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - G*IC*IH*IW, G*OC*OH*OW, - PH, PW, - SH, SW) * sizeof(float); + auto res = local::get_workspace_in_floats_backward_filter_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, G * IC * IH * IW, + G * OC * OH * OW, PH, PW, SH, SW) * + sizeof(float); return res; } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/group_local/forward/kern.cu b/dnn/src/cuda/group_local/forward/kern.cu index 991cbdb3..04beeb7f 100644 --- a/dnn/src/cuda/group_local/forward/kern.cu +++ b/dnn/src/cuda/group_local/forward/kern.cu @@ -41,14 +41,11 @@ constexpr size_t NB = 4, ICB = 4; // src planes are loaded into shared memory (presumably src spatial size is // small). template -__global__ void forward_kernel(const float* __restrict__ src, - const float* __restrict__ filter, - float* __restrict__ dst, uint32_t N, uint32_t IC, - uint32_t IH, uint32_t IW, uint32_t OC, - uint32_t OH, uint32_t OW, uint32_t FH, - uint32_t FW, uint32_t INs, uint32_t ONs, - uint32_t PH, uint32_t PW, uint32_t SH, - uint32_t SW) { +__global__ void forward_kernel( + const float* __restrict__ src, const float* __restrict__ filter, + float* __restrict__ dst, uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW, + uint32_t OC, uint32_t OH, uint32_t OW, uint32_t FH, uint32_t FW, uint32_t INs, + uint32_t ONs, uint32_t PH, uint32_t PW, uint32_t SH, uint32_t SW) { // NB * ICB * sizeof(float) * IH * IW extern __shared__ float shared_mem[]; float* src_cache = shared_mem; @@ -74,9 +71,8 @@ __global__ void forward_kernel(const float* __restrict__ src, uint32_t ip = i % (IH * IW); uint32_t icb = i / (IH * IW) % ICB; uint32_t nb = i / (IH * IW) / ICB; - src_cache[i] = - (icb < ICB_cur) * - src[nb * INs + min(IC - 1, (ic + icb)) * IH * IW + ip]; + src_cache[i] = (icb < ICB_cur) * + src[nb * INs + min(IC - 1, (ic + icb)) * IH * IW + ip]; } __syncthreads(); if (oid < OC * OH * OW) @@ -102,9 +98,9 @@ __global__ void forward_kernel(const float* __restrict__ src, float src_reg[NB]; #pragma unroll for (uint32_t nb = 0; nb < NB; ++nb) { - src_reg[nb] = src_cache[nb * ICB * IH * IW + - icb * IH * IW + - ih * IW + iw]; + src_reg[nb] = src_cache + [nb * ICB * IH * IW + icb * IH * IW + + ih * IW + iw]; } #pragma unroll for (uint32_t nb = 0; nb < NB; ++nb) { @@ -122,22 +118,21 @@ __global__ void forward_kernel(const float* __restrict__ src, } } -} +} // namespace -void group_local::exec(const float* src, const float* filter, float* dst, - float* wptr, uint32_t N, uint32_t IC, uint32_t IH, - uint32_t IW, uint32_t OC, uint32_t OH, uint32_t OW, - uint32_t FH, uint32_t FW, uint32_t G, uint32_t PH, - uint32_t PW, uint32_t SH, uint32_t SW, - cudaStream_t stream) { +void group_local::exec( + const float* src, const float* filter, float* dst, float* wptr, uint32_t N, + uint32_t IC, uint32_t IH, uint32_t IW, uint32_t OC, uint32_t OH, uint32_t 
OW, + uint32_t FH, uint32_t FW, uint32_t G, uint32_t PH, uint32_t PW, uint32_t SH, + uint32_t SW, cudaStream_t stream) { MEGDNN_MARK_USED_VAR(wptr); size_t threads = 256; dim3 blocks = dim3(DIVUP(N, NB), DIVUP(OC * OH * OW, threads), G); uint32_t INs = G * IC * IH * IW, ONs = G * OC * OH * OW; forward_kernel <<>>( - src, filter, dst, N, IC, IH, IW, OC, OH, OW, FH, FW, INs, - ONs, PH, PW, SH, SW); + src, filter, dst, N, IC, IH, IW, OC, OH, OW, FH, FW, INs, ONs, PH, + PW, SH, SW); after_kernel_launch(); } diff --git a/dnn/src/cuda/group_local/forward/kern.cuh b/dnn/src/cuda/group_local/forward/kern.cuh index e656c2b3..900effd6 100644 --- a/dnn/src/cuda/group_local/forward/kern.cuh +++ b/dnn/src/cuda/group_local/forward/kern.cuh @@ -16,21 +16,17 @@ namespace megdnn { namespace cuda { namespace group_local { -void exec(const float *src, const float *filter, float *dst, - float *wptr, - uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW, - uint32_t OC, uint32_t OH, uint32_t OW, - uint32_t FH, uint32_t FW, - uint32_t G, - uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW, - cudaStream_t stream); +void exec( + const float* src, const float* filter, float* dst, float* wptr, uint32_t N, + uint32_t IC, uint32_t IH, uint32_t IW, uint32_t OC, uint32_t OH, uint32_t OW, + uint32_t FH, uint32_t FW, uint32_t G, uint32_t PH, uint32_t PW, uint32_t SH, + uint32_t SW, cudaStream_t stream); size_t get_share_mem_in_bytes(uint32_t IH, uint32_t IW); -} // namespace group_local +} // namespace group_local -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/group_local/forward/opr_impl.cpp b/dnn/src/cuda/group_local/forward/opr_impl.cpp index e51c498a..5776e82f 100644 --- a/dnn/src/cuda/group_local/forward/opr_impl.cpp +++ b/dnn/src/cuda/group_local/forward/opr_impl.cpp @@ -26,8 +26,7 @@ using namespace cuda; namespace { -std::unique_ptr get_opr(Handle* handle, - param::Convolution param) { +std::unique_ptr get_opr(Handle* handle, param::Convolution param) { auto&& opr = handle->create_operator(); opr->param() = param; return std::move(opr); @@ -56,18 +55,18 @@ TensorLayout prepare_filter(const TensorLayout& filter) { } // namespace -void GroupLocalForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - megdnn_assert(src.layout.dtype == dtype::Float32(), - "cuda do not support fp16 group local operator"); +void GroupLocalForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + megdnn_assert( + src.layout.dtype == dtype::Float32(), + "cuda do not support fp16 group local operator"); check_exec(src.layout, filter.layout, dst.layout, workspace.size); auto handle = concrete_handle(this->handle()); auto G = filter.layout[0]; - auto IH = src.layout.shape[2], IW = src.layout.shape[3], - OH = dst.layout.shape[2], OW = dst.layout.shape[3]; + auto IH = src.layout.shape[2], IW = src.layout.shape[3], OH = dst.layout.shape[2], + OW = dst.layout.shape[3]; if (prefer_inference_kernel(src.layout, filter.layout, dst.layout)) { auto N = src.layout.shape[0], ICg = src.layout.shape[1] / G, OCg = dst.layout.shape[1] / G; @@ -80,8 +79,9 @@ void GroupLocalForwardImpl::exec(_megdnn_tensor_in src, float* wptr = workspace.ptr(); auto stream = cuda_stream(this->handle()); - group_local::exec(sptr, fptr, dptr, wptr, N, ICg, IH, IW, OCg, OH, OW, - FH, FW, G, PH, PW, SH, SW, stream); + group_local::exec( + sptr, 
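[editor's note] The launcher above assigns NB images and one group to each thread block and caches NB x ICB source planes in shared memory (see the "NB * ICB * sizeof(float) * IH * IW" comment inside the kernel). A host-side sketch of the same launch arithmetic, assuming the NB = 4, ICB = 4 constants from kern.cu; divup mirrors DIVUP:

    #include <cstddef>
    #include <cstdio>

    constexpr size_t NB = 4, ICB = 4, THREADS = 256;
    constexpr size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
        size_t N = 32, G = 8, OC = 16, OH = 12, OW = 12, IH = 14, IW = 14;
        size_t grid_x = divup(N, NB);                  // NB images per block
        size_t grid_y = divup(OC * OH * OW, THREADS);  // output positions of one group
        size_t grid_z = G;                             // one grid z-slice per group
        size_t smem = NB * ICB * IH * IW * sizeof(float);  // cached src planes
        std::printf("grid=(%zu,%zu,%zu) smem=%zu B\n", grid_x, grid_y, grid_z, smem);
    }
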
fptr, dptr, wptr, N, ICg, IH, IW, OCg, OH, OW, FH, FW, G, PH, PW, + SH, SW, stream); } else { auto&& opr = get_opr(handle, param()); TensorND src_g = {src.raw_ptr, prepare_src_dst(src.layout, G)}; @@ -89,12 +89,12 @@ void GroupLocalForwardImpl::exec(_megdnn_tensor_in src, TensorND filter_g = {filter.raw_ptr, prepare_filter(filter.layout)}; for (size_t g = 0; g < G; ++g) { opr->exec(src_g, filter_g, dst_g, workspace); - incr_ptr(src_g.raw_ptr, src_g.layout.stride[1] * - src_g.layout.shape[1] * - src_g.layout.dtype.size()); - incr_ptr(dst_g.raw_ptr, dst_g.layout.stride[1] * - dst_g.layout.shape[1] * - dst_g.layout.dtype.size()); + incr_ptr( + src_g.raw_ptr, src_g.layout.stride[1] * src_g.layout.shape[1] * + src_g.layout.dtype.size()); + incr_ptr( + dst_g.raw_ptr, dst_g.layout.stride[1] * dst_g.layout.shape[1] * + dst_g.layout.dtype.size()); incr_ptr(filter_g.raw_ptr, filter_g.layout.span().dist_byte()); } } @@ -103,9 +103,8 @@ void GroupLocalForwardImpl::exec(_megdnn_tensor_in src, GroupLocalForwardImpl::GroupLocalForwardImpl(Handle* handle) : GroupLocalForward(handle) {} -size_t GroupLocalForwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +size_t GroupLocalForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { if (prefer_inference_kernel(src, filter, dst)) { return 0; } else { @@ -118,9 +117,8 @@ size_t GroupLocalForwardImpl::get_workspace_in_bytes(const TensorLayout& src, } } -bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +bool GroupLocalForwardImpl::prefer_inference_kernel( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { MEGDNN_MARK_USED_VAR(filter); MEGDNN_MARK_USED_VAR(dst); auto handle = concrete_handle(this->handle()); diff --git a/dnn/src/cuda/group_local/opr_impl.h b/dnn/src/cuda/group_local/opr_impl.h index 1cbfb50b..fc81f9c0 100644 --- a/dnn/src/cuda/group_local/opr_impl.h +++ b/dnn/src/cuda/group_local/opr_impl.h @@ -14,46 +14,44 @@ namespace megdnn { namespace cuda { -class GroupLocalForwardImpl: public GroupLocalForward { - public: - GroupLocalForwardImpl(Handle *handle); - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) override; - private: - bool prefer_inference_kernel(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst); +class GroupLocalForwardImpl : public GroupLocalForward { +public: + GroupLocalForwardImpl(Handle* handle); + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + +private: + bool prefer_inference_kernel( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst); }; -class GroupLocalBackwardDataImpl: public GroupLocalBackwardData { - public: - GroupLocalBackwardDataImpl(Handle *handle); - void exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) override; +class 
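[editor's note] In the fallback branch above, one dense Local operator is executed per group and the raw tensor pointers are advanced between iterations by stride[1] * shape[1] * dtype.size() bytes, i.e. by one group's worth of channels. A sketch of that pointer bump in isolation (the helper name advance_group is hypothetical):

    #include <cstddef>
    #include <cstdint>

    // Advance a raw tensor pointer past one channel group: channel stride (in
    // elements) times channels per group times element size gives the byte step.
    inline void advance_group(
            void*& raw, ptrdiff_t channel_stride_elems, size_t channels_per_group,
            size_t dtype_size_bytes) {
        raw = static_cast<uint8_t*>(raw) +
              channel_stride_elems * static_cast<ptrdiff_t>(channels_per_group) *
                      static_cast<ptrdiff_t>(dtype_size_bytes);
    }
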
GroupLocalBackwardDataImpl : public GroupLocalBackwardData { +public: + GroupLocalBackwardDataImpl(Handle* handle); + void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; }; -class GroupLocalBackwardFilterImpl: public GroupLocalBackwardFilter { - public: - GroupLocalBackwardFilterImpl(Handle *handle); - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad) override; +class GroupLocalBackwardFilterImpl : public GroupLocalBackwardFilter { +public: + GroupLocalBackwardFilterImpl(Handle* handle); + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle.cpp b/dnn/src/cuda/handle.cpp index 7b85bcee..39bd56b3 100644 --- a/dnn/src/cuda/handle.cpp +++ b/dnn/src/cuda/handle.cpp @@ -12,21 +12,23 @@ #include "src/common/handle_impl.h" #include "src/common/version_symbol.h" +#include "megdnn/common.h" #include "src/cuda/handle.h" #include "src/cuda/utils.h" -#include "megdnn/common.h" #include #include #define STR_HELPER(x) #x -#define STR(x) STR_HELPER(x) +#define STR(x) STR_HELPER(x) -#define CUDNN_VERSION_STR STR(CUDNN_MAJOR) "." STR(CUDNN_MINOR) "." STR(CUDNN_PATCHLEVEL) +#define CUDNN_VERSION_STR \ + STR(CUDNN_MAJOR) "." STR(CUDNN_MINOR) "." STR(CUDNN_PATCHLEVEL) #pragma message "compile with cuDNN " CUDNN_VERSION_STR " " -static_assert(!(CUDNN_MAJOR == 5 && CUDNN_MINOR == 1), +static_assert( + !(CUDNN_MAJOR == 5 && CUDNN_MINOR == 1), "cuDNN 5.1.x series has bugs. Use 5.0.x instead."); #undef STR @@ -35,9 +37,8 @@ static_assert(!(CUDNN_MAJOR == 5 && CUDNN_MINOR == 1), namespace megdnn { namespace cuda { -HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): - HandleImplHelper(comp_handle, HandleType::CUDA) -{ +HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle) + : HandleImplHelper(comp_handle, HandleType::CUDA) { // Get megcore device handle megcoreDeviceHandle_t dev_handle; megcoreGetDeviceHandle(comp_handle, &dev_handle); @@ -49,12 +50,14 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): m_device_id = dev_id; m_device_prop = get_device_prop(dev_id); // Get stream from MegCore computing handle. 
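[editor's note] CUDNN_VERSION_STR above relies on the classic two-level stringification trick: the extra STR_HELPER indirection forces CUDNN_MAJOR/CUDNN_MINOR/CUDNN_PATCHLEVEL to expand before # is applied. The same pattern in isolation, with MAJOR/MINOR/PATCH standing in for the cuDNN macros:

    #include <cstdio>

    #define STR_HELPER(x) #x
    #define STR(x) STR_HELPER(x)

    #define MAJOR 7
    #define MINOR 6
    #define PATCH 5
    // Expands to "7" "." "6" "." "5", i.e. the concatenated literal "7.6.5";
    // without the helper it would stringify the unexpanded names instead.
    #define VERSION_STR STR(MAJOR) "." STR(MINOR) "." STR(PATCH)

    int main() { std::puts(VERSION_STR); }
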
- megdnn_assert(CUDNN_VERSION == cudnnGetVersion(), - "cudnn version mismatch: compiled with %d; detected %zu at runtime", - CUDNN_VERSION, cudnnGetVersion()); + megdnn_assert( + CUDNN_VERSION == cudnnGetVersion(), + "cudnn version mismatch: compiled with %d; detected %zu at runtime", + CUDNN_VERSION, cudnnGetVersion()); #if CUDA_VERSION >= 10010 - megdnn_assert(cublasLtGetVersion() >= 10010, - "cuda library version is too low to run cublasLt"); + megdnn_assert( + cublasLtGetVersion() >= 10010, + "cuda library version is too low to run cublasLt"); #endif #if CUDNN_VERSION >= 8000 if (!MGB_GETENV("CUDA_CACHE_PATH")) { @@ -77,15 +80,15 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): // Note that all cublas scalars (alpha, beta) and scalar results such as dot // output resides at device side. - cublas_check(cublasSetPointerMode(m_cublas_handle, - CUBLAS_POINTER_MODE_DEVICE)); + cublas_check(cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE)); // init const scalars cuda_check(cudaMalloc(&m_const_scalars, sizeof(ConstScalars))); ConstScalars const_scalars_val; const_scalars_val.init(); - cuda_check(cudaMemcpyAsync(m_const_scalars, &const_scalars_val, - sizeof(ConstScalars), cudaMemcpyHostToDevice, stream())); + cuda_check(cudaMemcpyAsync( + m_const_scalars, &const_scalars_val, sizeof(ConstScalars), + cudaMemcpyHostToDevice, stream())); cuda_check(cudaStreamSynchronize(stream())); // check tk1 @@ -106,13 +109,16 @@ HandleImpl::~HandleImpl() noexcept { } void HandleImpl::ConstScalars::init() { - f16[0].megdnn_x = 0; f16[1].megdnn_x = 1; - f32[0] = 0; f32[1] = 1; - i32[0] = 0; i32[1] = 1; + f16[0].megdnn_x = 0; + f16[1].megdnn_x = 1; + f32[0] = 0; + f32[1] = 1; + i32[0] = 0; + i32[1] = 1; } size_t HandleImpl::alignment_requirement() const { - auto &&prop = m_device_prop; + auto&& prop = m_device_prop; return std::max(prop->textureAlignment, prop->texturePitchAlignment); } @@ -136,8 +142,8 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const { return HandleVendorType::CUDA; } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION); MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); diff --git a/dnn/src/cuda/handle.h b/dnn/src/cuda/handle.h index a886c511..e4506942 100644 --- a/dnn/src/cuda/handle.h +++ b/dnn/src/cuda/handle.h @@ -14,15 +14,15 @@ #include "megdnn/handle.h" #include "megdnn/oprs/general.h" -#include "src/common/utils.h" #include "src/common/handle_impl.h" +#include "src/common/utils.h" #include "src/cuda/cudnn_with_check.h" -#include -#include -#include #include +#include #include +#include +#include #include #if CUDA_VERSION >= 10010 @@ -32,134 +32,108 @@ namespace megdnn { namespace cuda { -class HandleImpl: public HandleImplHelper { - public: - HandleImpl(megcoreComputingHandle_t computing_handle); - ~HandleImpl() noexcept; +class HandleImpl : public HandleImplHelper { +public: + HandleImpl(megcoreComputingHandle_t computing_handle); + ~HandleImpl() noexcept; - size_t alignment_requirement() const override; + size_t alignment_requirement() const override; - bool check_cross_dev_copy_constraint(const TensorLayout &src) override; + bool check_cross_dev_copy_constraint(const TensorLayout& src) override; - const cudaDeviceProp& device_prop() const { - return *m_device_prop; - } + const cudaDeviceProp& device_prop() const { return *m_device_prop; } - template - std::unique_ptr create_operator(); + template + std::unique_ptr create_operator(); - 
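[editor's note] The handle switches cuBLAS to CUBLAS_POINTER_MODE_DEVICE and then keeps zero/one constants resident on the device (ConstScalars), because in that mode every alpha/beta argument must be a device pointer. A minimal sketch of the calling convention this enables, assuming CUDA and cuBLAS are available; error checking is omitted:

    #include <cublas_v2.h>
    #include <cuda_runtime.h>

    // y += 1 * x with a device-resident alpha, as required once the pointer
    // mode is CUBLAS_POINTER_MODE_DEVICE. A persistent "one" buffer (as in
    // ConstScalars) avoids repeating the host-to-device copy on every call.
    void saxpy_with_device_alpha(
            cublasHandle_t cublas, const float* d_x, float* d_y, int n) {
        cublasSetPointerMode(cublas, CUBLAS_POINTER_MODE_DEVICE);
        float* d_one;
        cudaMalloc(&d_one, sizeof(float));
        const float one = 1.f;
        cudaMemcpy(d_one, &one, sizeof(float), cudaMemcpyHostToDevice);
        cublasSaxpy(cublas, n, /*alpha=*/d_one, d_x, 1, d_y, 1);
        cudaFree(d_one);
    }
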
const megcore::CudaContext& megcore_context() const { - return m_megcore_context; - } + const megcore::CudaContext& megcore_context() const { return m_megcore_context; } - int device_id() const { return m_device_id; } + int device_id() const { return m_device_id; } - cudaStream_t stream() const { - return megcore_context().stream; - } - cudnnHandle_t cudnn_handle() { - return m_cudnn_handle; - } - cublasHandle_t cublas_handle() { - return m_cublas_handle; - } + cudaStream_t stream() const { return megcore_context().stream; } + cudnnHandle_t cudnn_handle() { return m_cudnn_handle; } + cublasHandle_t cublas_handle() { return m_cublas_handle; } #if CUDA_VERSION >= 10010 - cublasLtHandle_t cublasLt_handle() { - return m_cublasLt_handle; - } + cublasLtHandle_t cublasLt_handle() { return m_cublasLt_handle; } #endif - cusolverDnHandle_t cusolver_handle() { - std::call_once(m_cusolver_initialized, - [this] { initialize_cusolver(); }); - return m_cusolver_handle; - } - dt_float32 *zero_device() { - return &m_const_scalars->f32[0]; - } - dt_float32 *one_device() { - return &m_const_scalars->f32[1]; - } - __half* zero_device_h() { - return &m_const_scalars->f16[0].cuda_x; - } - __half* one_device_h() { - return &m_const_scalars->f16[1].cuda_x; - } - dt_int32 *zero_device_i32() { - return &m_const_scalars->i32[0]; - } - dt_int32 *one_device_i32() { - return &m_const_scalars->i32[1]; - } - - bool is_tegra_k1() const { - return m_is_tegra_k1; - } - - //! global matmul opr - MatrixMul* matmul_opr() override final { - return get_helper_opr(this); - } - - //! global matmul opr with first operand transposed - MatrixMul* matmul_aT_opr() override final { - return get_helper_opr(this, {true, false}); - } - - //! global matmul opr with second operand transposed - MatrixMul* matmul_bT_opr() override final { - return get_helper_opr(this, {false, true}); - } - - //! global relayout opr - Relayout* relayout_opr() override final { - return get_helper_opr(this); - } - - BatchedMatrixMulForward* batched_matrix_mul() { - return get_helper_opr(this); - } - - TypeCvt* typecvt_opr() { return get_helper_opr(this); } - - size_t image2d_pitch_alignment() const override; - HandleVendorType vendor_type() const override; - private: - bool m_is_tegra_k1; - int m_device_id; - //! MegDNN handle does not manage the lifetime of CUDA stream. - megcore::CudaContext m_megcore_context; - - cudnnHandle_t m_cudnn_handle; - cublasHandle_t m_cublas_handle; + cusolverDnHandle_t cusolver_handle() { + std::call_once(m_cusolver_initialized, [this] { initialize_cusolver(); }); + return m_cusolver_handle; + } + dt_float32* zero_device() { return &m_const_scalars->f32[0]; } + dt_float32* one_device() { return &m_const_scalars->f32[1]; } + __half* zero_device_h() { return &m_const_scalars->f16[0].cuda_x; } + __half* one_device_h() { return &m_const_scalars->f16[1].cuda_x; } + dt_int32* zero_device_i32() { return &m_const_scalars->i32[0]; } + dt_int32* one_device_i32() { return &m_const_scalars->i32[1]; } + + bool is_tegra_k1() const { return m_is_tegra_k1; } + + //! global matmul opr + MatrixMul* matmul_opr() override final { + return get_helper_opr(this); + } + + //! global matmul opr with first operand transposed + MatrixMul* matmul_aT_opr() override final { + return get_helper_opr(this, {true, false}); + } + + //! global matmul opr with second operand transposed + MatrixMul* matmul_bT_opr() override final { + return get_helper_opr(this, {false, true}); + } + + //! 
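[editor's note] cusolver_handle() above creates its handle lazily with std::call_once, so code paths that never touch cuSOLVER pay nothing and concurrent first calls still initialize it exactly once. The idiom in isolation (ExpensiveHandle and make_handle are placeholders for cusolverDnHandle_t and cusolverDnCreate):

    #include <mutex>

    struct ExpensiveHandle {};
    ExpensiveHandle* make_handle() { return new ExpensiveHandle{}; }  // stand-in

    class LazyOwner {
        std::once_flag m_init;
        ExpensiveHandle* m_handle = nullptr;

    public:
        ExpensiveHandle* handle() {
            // The lambda runs at most once, even under concurrent first calls.
            std::call_once(m_init, [this] { m_handle = make_handle(); });
            return m_handle;
        }
    };
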
global relayout opr + Relayout* relayout_opr() override final { + return get_helper_opr(this); + } + + BatchedMatrixMulForward* batched_matrix_mul() { + return get_helper_opr(this); + } + + TypeCvt* typecvt_opr() { return get_helper_opr(this); } + + size_t image2d_pitch_alignment() const override; + HandleVendorType vendor_type() const override; + +private: + bool m_is_tegra_k1; + int m_device_id; + //! MegDNN handle does not manage the lifetime of CUDA stream. + megcore::CudaContext m_megcore_context; + + cudnnHandle_t m_cudnn_handle; + cublasHandle_t m_cublas_handle; #if CUDA_VERSION >= 10010 - cublasLtHandle_t m_cublasLt_handle; + cublasLtHandle_t m_cublasLt_handle; #endif - cusolverDnHandle_t m_cusolver_handle; - std::once_flag m_cusolver_initialized; - - const cudaDeviceProp* m_device_prop; - - struct ConstScalars { - union FP16 { - __half cuda_x; - dt_float16 megdnn_x; - FP16() {} - }; - static_assert(sizeof(FP16) == 2, "bad FP16 size"); - FP16 f16[2]; - dt_float32 f32[2]; - dt_int32 i32[2]; - void init(); + cusolverDnHandle_t m_cusolver_handle; + std::once_flag m_cusolver_initialized; + + const cudaDeviceProp* m_device_prop; + + struct ConstScalars { + union FP16 { + __half cuda_x; + dt_float16 megdnn_x; + FP16() {} }; + static_assert(sizeof(FP16) == 2, "bad FP16 size"); + FP16 f16[2]; + dt_float32 f32[2]; + dt_int32 i32[2]; + void init(); + }; - //! device ptr to const scalars - ConstScalars* m_const_scalars; + //! device ptr to const scalars + ConstScalars* m_const_scalars; - void initialize_cusolver(); + void initialize_cusolver(); }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 2b96c782..03858f5f 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -12,7 +12,6 @@ #include "src/common/handle_impl.h" -#include "src/cuda/padding/opr_impl.h" #include "src/cuda/adaptive_pooling/opr_impl.h" #include "src/cuda/add_update/opr_impl.h" #include "src/cuda/argmxx/opr_impl.h" @@ -56,6 +55,7 @@ #include "src/cuda/matrix_mul/opr_impl.h" #include "src/cuda/max_tensor_diff/opr_impl.h" #include "src/cuda/mesh_indexing/opr_impl.h" +#include "src/cuda/padding/opr_impl.h" #include "src/cuda/param_pack/opr_impl.h" #include "src/cuda/pooling/opr_impl.h" #include "src/cuda/powc/opr_impl.h" diff --git a/dnn/src/cuda/images2neibs/kernel.cu b/dnn/src/cuda/images2neibs/kernel.cu index e8d150d2..c75e9cc8 100644 --- a/dnn/src/cuda/images2neibs/kernel.cu +++ b/dnn/src/cuda/images2neibs/kernel.cu @@ -10,24 +10,22 @@ */ #include "src/cuda/images2neibs/kernel.cuh" +#include #include "megdnn/dtype.h" #include "src/cuda/utils.cuh" -#include namespace megdnn { namespace cuda { namespace images2neibs { - #define grid_y_max 512 template -__global__ void forward_kernel(const T *src, T *dst, - int N, int C, int IH, int IW, int OH, int OW, - int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW) -{ +__global__ void forward_kernel( + const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, int ph, + int pw, int sh, int sw, int dh, int dw, int WH, int WW) { int NC = N * C; - int WP = WH*WW; + int WP = WH * WW; for (int wp = threadIdx.x; wp < WP; wp += blockDim.x) { int nc = blockIdx.y; while (nc < NC) { @@ -37,13 +35,12 @@ __global__ void forward_kernel(const T *src, T *dst, if (op < OH * OW) { int oh = op / OW; int ow = op % OW; - int ih = -ph + sh * oh + wh* dh; - int iw = -pw + sw * ow + ww* dw; + int ih = -ph + sh * 
oh + wh * dh; + int iw = -pw + sw * ow + ww * dw; int dst_pos = nc * OH * OW * WH * WW + op * WH * WW + wp; int src_pos = nc * IH * IW + ih * IW + iw; - dst[dst_pos] = (ih >= 0 && ih < IH && iw >= 0 && iw < IW) - ? src[src_pos] - : 0.0f; + dst[dst_pos] = (ih >= 0 && ih < IH && iw >= 0 && iw < IW) ? src[src_pos] + : 0.0f; } nc += grid_y_max; } @@ -51,9 +48,9 @@ __global__ void forward_kernel(const T *src, T *dst, } template -void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, - int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, - cudaStream_t stream) { +void forward( + const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, int ph, + int pw, int sh, int sw, int dh, int dw, int wh, int ww, cudaStream_t stream) { int spatial_size = OH * OW; int kernel_size = wh * ww; int tx = min(NR_THREADS, kernel_size); @@ -62,72 +59,64 @@ void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, int bx = DIVUP(spatial_size, ty); int by = N * C; - forward_kernel<<>>(src, dst, N, C, IH, IW, OH, OW, ph, pw, sh, sw, dh, dw, - wh, ww); + forward_kernel<<>>( + src, dst, N, C, IH, IW, OH, OW, ph, pw, sh, sw, dh, dw, wh, ww); after_kernel_launch(); } #undef grid_y_max template -__global__ void backward_kernel(const T *diff, T *grad, - int N, int C, int IH, int IW, int OH, int OW, - int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW) -{ +__global__ void backward_kernel( + const T* diff, T* grad, int N, int C, int IH, int IW, int OH, int OW, int ph, + int pw, int sh, int sw, int dh, int dw, int WH, int WW) { int id = threadIdx.x + blockIdx.x * blockDim.x; - if (id < N*C*IH*IW) { - int nc = id / (IH*IW); - int ih = id % (IH*IW) / IW; - int iw = id % (IH*IW) % IW; - grad[nc*IH*IW + ih*IW + iw] = 0.0f; - int oh_max = min((ih+ph) / sh, OH-1); - int oh_min = max((ih+ph-(WH-1)*dh+sh-1) / sh, 0); - int ow_max = min((iw+pw) / sw, OW-1); - int ow_min = max((iw+pw-(WW-1)*dw+sw-1) / sw, 0); + if (id < N * C * IH * IW) { + int nc = id / (IH * IW); + int ih = id % (IH * IW) / IW; + int iw = id % (IH * IW) % IW; + grad[nc * IH * IW + ih * IW + iw] = 0.0f; + int oh_max = min((ih + ph) / sh, OH - 1); + int oh_min = max((ih + ph - (WH - 1) * dh + sh - 1) / sh, 0); + int ow_max = min((iw + pw) / sw, OW - 1); + int ow_min = max((iw + pw - (WW - 1) * dw + sw - 1) / sw, 0); for (int oh = oh_min; oh <= oh_max; ++oh) - for (int ow = ow_min; ow <= ow_max; ++ow) - { - if ((ih+ph - sh*oh)%dh==0 && (iw+pw - sw*ow)%dw==0){ - int wh = ih+ph - sh*oh - (ih+ph - sh*oh)/dh * (dh-1); - int ww = iw+pw - sw*ow - (iw+pw - sw*ow)/dw * (dw-1); - grad[nc*IH*IW + ih*IW + iw] += - diff[nc*OH*OW*WH*WW + oh*OW*WH*WW + ow*WH*WW + - wh*WW + ww]; - + for (int ow = ow_min; ow <= ow_max; ++ow) { + if ((ih + ph - sh * oh) % dh == 0 && (iw + pw - sw * ow) % dw == 0) { + int wh = ih + ph - sh * oh - (ih + ph - sh * oh) / dh * (dh - 1); + int ww = iw + pw - sw * ow - (iw + pw - sw * ow) / dw * (dw - 1); + grad[nc * IH * IW + ih * IW + iw] += + diff[nc * OH * OW * WH * WW + oh * OW * WH * WW + + ow * WH * WW + wh * WW + ww]; + } } - } } } template -void backward(const T *diff, T *grad, - int N, int C, int IH, int IW, int OH, int OW, - int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, - cudaStream_t stream) -{ +void backward( + const T* diff, T* grad, int N, int C, int IH, int IW, int OH, int OW, int ph, + int pw, int sh, int sw, int dh, int dw, int wh, int ww, cudaStream_t stream) { int threads = NR_THREADS; - int blocks = DIVUP(N*C*IH*IW, threads); - 
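[editor's note] The images2neibs kernels above map each output element (n, c, oh, ow, wh, ww) to the input position ih = -ph + sh*oh + wh*dh, iw = -pw + sw*ow + ww*dw, writing zero where a window tap falls into the padding. A plain CPU reference of the forward direction, useful for checking the CUDA kernel; this is a sketch, not tuned code:

    #include <cstddef>
    #include <vector>

    // dst has layout (N, C, OH, OW, WH, WW); out-of-bounds taps stay zero.
    std::vector<float> images2neibs_ref(
            const std::vector<float>& src, int N, int C, int IH, int IW, int OH,
            int OW, int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW) {
        std::vector<float> dst(size_t(N) * C * OH * OW * WH * WW, 0.f);
        for (int nc = 0; nc < N * C; ++nc)
            for (int oh = 0; oh < OH; ++oh)
                for (int ow = 0; ow < OW; ++ow)
                    for (int wh = 0; wh < WH; ++wh)
                        for (int ww = 0; ww < WW; ++ww) {
                            int ih = -ph + sh * oh + wh * dh;
                            int iw = -pw + sw * ow + ww * dw;
                            size_t dpos =
                                    (((size_t(nc) * OH + oh) * OW + ow) * WH + wh) * WW + ww;
                            if (ih >= 0 && ih < IH && iw >= 0 && iw < IW)
                                dst[dpos] = src[(size_t(nc) * IH + ih) * IW + iw];
                        }
        return dst;
    }
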
backward_kernel<<>>(diff, grad, - N, C, IH, IW, OH, OW, - ph, pw, sh, sw, dh, dw, wh, ww); + int blocks = DIVUP(N * C * IH * IW, threads); + backward_kernel<<>>( + diff, grad, N, C, IH, IW, OH, OW, ph, pw, sh, sw, dh, dw, wh, ww); after_kernel_launch(); } -#define INST(T) \ - template void forward(const T *, T *, int, int, int, int, int, int, \ - int, int, int, int, int, int, int, int, \ - cudaStream_t); \ - template void backward(const T *, T *, int, int, int, int, int, int, \ - int, int, int, int, int, int, int, int, \ - cudaStream_t); -#define cb(DType) \ - INST(DTypeTrait::ctype) +#define INST(T) \ + template void forward( \ + const T*, T*, int, int, int, int, int, int, int, int, int, int, int, int, \ + int, int, cudaStream_t); \ + template void backward( \ + const T*, T*, int, int, int, int, int, int, int, int, int, int, int, int, \ + int, int, cudaStream_t); +#define cb(DType) INST(DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) -} // namespace images2neibs -} // namespace cuda -} // namespace megdnn +} // namespace images2neibs +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/images2neibs/kernel.cuh b/dnn/src/cuda/images2neibs/kernel.cuh index acef5416..f5c39389 100644 --- a/dnn/src/cuda/images2neibs/kernel.cuh +++ b/dnn/src/cuda/images2neibs/kernel.cuh @@ -16,19 +16,16 @@ namespace cuda { namespace images2neibs { template -void forward(const T *src, T *dst, - int N, int C, int IH, int IW, int OH, int OW, - int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, - cudaStream_t stream); +void forward( + const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, int ph, + int pw, int sh, int sw, int dh, int dw, int wh, int ww, cudaStream_t stream); template -void backward(const T *diff, T *grad, - int N, int C, int IH, int IW, int OH, int OW, - int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, - cudaStream_t stream); - -} // namespace images2neibs -} // namespace cuda -} // namespace megdnn -// vim: syntax=cpp.doxygen +void backward( + const T* diff, T* grad, int N, int C, int IH, int IW, int OH, int OW, int ph, + int pw, int sh, int sw, int dh, int dw, int wh, int ww, cudaStream_t stream); +} // namespace images2neibs +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/images2neibs/opr_impl.cpp b/dnn/src/cuda/images2neibs/opr_impl.cpp index 4761fd4e..5d984ca4 100644 --- a/dnn/src/cuda/images2neibs/opr_impl.cpp +++ b/dnn/src/cuda/images2neibs/opr_impl.cpp @@ -10,67 +10,60 @@ */ #include "src/cuda/images2neibs/opr_impl.h" -#include "src/cuda/utils.h" #include "src/cuda/images2neibs/kernel.cuh" +#include "src/cuda/utils.h" namespace megdnn { namespace cuda { -void Images2NeibsForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ +void Images2NeibsForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); auto stream = cuda_stream(handle()); - int N = src.layout[0], C = src.layout[1], - IH = src.layout[2], IW = src.layout[3]; + int N = src.layout[0], C = src.layout[1], IH = src.layout[2], IW = src.layout[3]; int OH = dst.layout[2], OW = dst.layout[3]; int ph = param().pad_h, pw = param().pad_w; int sh = param().stride_h, sw = param().stride_w; int dh = param().dilate_h, dw = param().dilate_w; int wh = param().window_h, ww = param().window_w; -#define cb(DType) \ - if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ - 
using T = DTypeTrait::ctype; \ - images2neibs::forward(src.ptr(), dst.ptr(), \ - N, C, IH, IW, OH, OW, \ - ph, pw, sh, sw, dh, dw, wh, ww, \ - stream); \ - return; \ +#define cb(DType) \ + if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using T = DTypeTrait::ctype; \ + images2neibs::forward( \ + src.ptr(), dst.ptr(), N, C, IH, IW, OH, OW, ph, pw, sh, sw, dh, \ + dw, wh, ww, stream); \ + return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb); #undef cb megdnn_assert_internal(0); } -void Images2NeibsBackwardImpl::exec(_megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ +void Images2NeibsBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) { check_exec(diff.layout, grad.layout, workspace.size); auto stream = cuda_stream(handle()); - int N = grad.layout[0], C = grad.layout[1], - IH = grad.layout[2], IW = grad.layout[3]; + int N = grad.layout[0], C = grad.layout[1], IH = grad.layout[2], + IW = grad.layout[3]; int OH = diff.layout[2], OW = diff.layout[3]; int ph = param().pad_h, pw = param().pad_w; int sh = param().stride_h, sw = param().stride_w; int dh = param().dilate_h, dw = param().dilate_w; int wh = param().window_h, ww = param().window_w; -#define cb(DType) \ - if (diff.layout.dtype == DType()) { \ - using T = DTypeTrait::ctype; \ - images2neibs::backward(diff.ptr(), grad.ptr(), \ - N, C, IH, IW, OH, OW, \ - ph, pw, sh, sw, dh, dw, wh, ww, \ - stream); \ - return; \ +#define cb(DType) \ + if (diff.layout.dtype == DType()) { \ + using T = DTypeTrait::ctype; \ + images2neibs::backward( \ + diff.ptr(), grad.ptr(), N, C, IH, IW, OH, OW, ph, pw, sh, sw, \ + dh, dw, wh, ww, stream); \ + return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb); #undef cb megdnn_assert_internal(0); } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/images2neibs/opr_impl.h b/dnn/src/cuda/images2neibs/opr_impl.h index 01e198a5..62958aee 100644 --- a/dnn/src/cuda/images2neibs/opr_impl.h +++ b/dnn/src/cuda/images2neibs/opr_impl.h @@ -16,30 +16,28 @@ namespace megdnn { namespace cuda { -class Images2NeibsForwardImpl: public Images2NeibsForward { - public: - using Images2NeibsForward::Images2NeibsForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &) override { - return 0; - } +class Images2NeibsForwardImpl : public Images2NeibsForward { +public: + using Images2NeibsForward::Images2NeibsForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { + return 0; + } }; -class Images2NeibsBackwardImpl: public Images2NeibsBackward { - public: - using Images2NeibsBackward::Images2NeibsBackward; - void exec(_megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &) override { - return 0; - } +class Images2NeibsBackwardImpl : public Images2NeibsBackward { +public: + using Images2NeibsBackward::Images2NeibsBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { + return 0; + } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // 
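[editor's note] Both exec functions above use the cb/MEGDNN_FOREACH_COMPUTING_DTYPE X-macro pattern: the FOREACH macro applies cb to every supported dtype, and each cb expansion is one if-branch that fixes the template type before calling the typed kernel. A self-contained illustration of the pattern; FOREACH_DTYPE, DTypeTraitOf, run_typed and dispatch are illustrative names, not MegDNN APIs:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    template <typename T>
    void run_typed(const void* ptr, size_t n) {
        std::printf("typed kernel: %zu elements of %zu bytes\n", n, sizeof(T));
        (void)ptr;
    }

    enum class DTypeEnum { Float32, Int32, Uint8 };
    template <typename T> struct DTypeTraitOf;
    template <> struct DTypeTraitOf<float>   { static constexpr DTypeEnum enumv = DTypeEnum::Float32; };
    template <> struct DTypeTraitOf<int32_t> { static constexpr DTypeEnum enumv = DTypeEnum::Int32; };
    template <> struct DTypeTraitOf<uint8_t> { static constexpr DTypeEnum enumv = DTypeEnum::Uint8; };

    // One cb expansion per dtype; #undef keeps the macro local to this dispatch.
    #define FOREACH_DTYPE(cb) cb(float) cb(int32_t) cb(uint8_t)

    void dispatch(DTypeEnum dtype, const void* ptr, size_t n) {
    #define cb(T)                              \
        if (dtype == DTypeTraitOf<T>::enumv) { \
            run_typed<T>(ptr, n);              \
            return;                            \
        }
        FOREACH_DTYPE(cb)
    #undef cb
    }
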
namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh b/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh index ca9e56ce..be26b47e 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern.cuh @@ -12,87 +12,84 @@ #pragma once #include "megdnn/arch.h" -#include "src/cuda/utils.cuh" -#include "src/cuda/int_fastdiv.cuh" #include "src/cuda/error_info.cuh" +#include "src/cuda/int_fastdiv.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { namespace indexing_multi_axis_vec { - //! AxisIndexer equiv in kernel - struct KAxisIndexer { - int stride; - const int *ptr; - }; - - //! param for gen_offset_base - template - struct GenOffsetBaseParam { - uint32_t size; //!< number of outputs; also size of each index - int *output; //!< output ptr - KAxisIndexer indexer[nidx]; - uint32_t data_shape[nidx]; - int data_stride[nidx]; - - void* error_tracker; - megcore::AsyncErrorInfo* error_info; - }; - - //! tensor layout for fast offset computing - template - struct FastLayout { - int stride[ndim]; +//! AxisIndexer equiv in kernel +struct KAxisIndexer { + int stride; + const int* ptr; +}; + +//! param for gen_offset_base +template +struct GenOffsetBaseParam { + uint32_t size; //!< number of outputs; also size of each index + int* output; //!< output ptr + KAxisIndexer indexer[nidx]; + uint32_t data_shape[nidx]; + int data_stride[nidx]; + + void* error_tracker; + megcore::AsyncErrorInfo* error_info; +}; + +//! tensor layout for fast offset computing +template +struct FastLayout { + int stride[ndim]; #ifdef WIN32 - Uint32Fastdiv shape[ndim]; + Uint32Fastdiv shape[ndim]; #else - Uint32Fastdiv shape[ndim - 1]; + Uint32Fastdiv shape[ndim - 1]; #endif - }; +}; - //! param for apply_opr - template - struct ApplyOprParam { - uint32_t tot_size; //!< total output size +//! param for apply_opr +template +struct ApplyOprParam { + uint32_t tot_size; //!< total output size - //! offset array generated by gen_offset_base for first output axis - const int *offset_base; - ctype *data, *value; + //! offset array generated by gen_offset_base for first output axis + const int* offset_base; + ctype *data, *value; - int idx_axis; + int idx_axis; - int value_stride; + int value_stride; - //! iterate on value, with strides from corresponding axes on data - FastLayout value_ly_on_data; - }; + //! iterate on value, with strides from corresponding axes on data + FastLayout value_ly_on_data; +}; - //! generate offset bases for first axis in the output - template - void gen_offset_base(const GenOffsetBaseParam ¶m, - cudaStream_t stream); +//! generate offset bases for first axis in the output +template +void gen_offset_base(const GenOffsetBaseParam& param, cudaStream_t stream); - struct OprAtomicIncr { +struct OprAtomicIncr { #if MEGDNN_CC_CUDA - template - __device__ static void apply(ctype &data, ctype value) { - atomicAdd(&data, value); - } + template + __device__ static void apply(ctype& data, ctype value) { + atomicAdd(&data, value); + } #endif - }; +}; - /*! - * \brief forward kernel: copy data to value - * \tparam ndim numer of axes except axis_0 in data, - * range from 0 to max_ndim - 1 - */ - template - void apply_opr(const ApplyOprParam ¶m, - cudaStream_t stream); +/*! 
+ * \brief forward kernel: copy data to value + * \tparam ndim numer of axes except axis_0 in data, + * range from 0 to max_ndim - 1 + */ +template +void apply_opr(const ApplyOprParam& param, cudaStream_t stream); -} // namespace indexing_multi_axis_vec -} // namespace cuda -} // namespace megdnn +} // namespace indexing_multi_axis_vec +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_fwd.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_fwd.cu index 1a25f30e..f656f635 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_fwd.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_fwd.cu @@ -9,10 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ - #include "src/common/indexing_multi_axis_vec_kdef.h" -#define KERN_APPLY_OPR_OPR ::megdnn::indexing_multi_axis_vec_kdef::OprFwd +#define KERN_APPLY_OPR_OPR ::megdnn::indexing_multi_axis_vec_kdef::OprFwd #include "./kern_apply_opr_impl.cuinl" // vim: ft=cuda syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu index 9879eb02..dd65ec87 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_incr.cu @@ -9,44 +9,41 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ - #include "megdnn/dtype.h" #include "src/cuda/utils.cuh" #if !MEGDNN_DISABLE_FLOAT16 -__device__ void atomicAdd(megdnn::dt_float16 * address, megdnn::dt_float16 val) { +__device__ void atomicAdd(megdnn::dt_float16* address, megdnn::dt_float16 val) { ::megdnn::cuda::atomic_add(address, val); } -__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) { +__device__ void atomicAdd(megdnn::dt_bfloat16*, megdnn::dt_bfloat16) { __trap(); ((int*)0)[0] = 1; } #endif -__device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) { +__device__ void atomicAdd(megdnn::dt_int8*, megdnn::dt_int8) { __trap(); ((int*)0)[0] = 1; } -__device__ void atomicAdd(megdnn::dt_uint8 *, megdnn::dt_uint8) { +__device__ void atomicAdd(megdnn::dt_uint8*, megdnn::dt_uint8) { __trap(); ((int*)0)[0] = 1; } -__device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) { +__device__ void atomicAdd(megdnn::dt_int16*, megdnn::dt_int16) { __trap(); ((int*)0)[0] = 1; } -__device__ void atomicAdd(megdnn::dt_bool *, megdnn::dt_bool) { +__device__ void atomicAdd(megdnn::dt_bool*, megdnn::dt_bool) { __trap(); ((int*)0)[0] = 1; } -#define KERN_APPLY_OPR_OPR \ - ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr +#define KERN_APPLY_OPR_OPR ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr #include "./kern_apply_opr_impl.cuinl" // vim: ft=cuda syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_set.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_set.cu index 6199fdb0..01cda70f 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_set.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_apply_opr_set.cu @@ -9,10 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ - #include "src/common/indexing_multi_axis_vec_kdef.h" -#define KERN_APPLY_OPR_OPR ::megdnn::indexing_multi_axis_vec_kdef::OprSet +#define KERN_APPLY_OPR_OPR ::megdnn::indexing_multi_axis_vec_kdef::OprSet #include "./kern_apply_opr_impl.cuinl" // vim: ft=cuda syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu b/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu index 27b857ff..1dc1bfc8 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu +++ b/dnn/src/cuda/indexing_multi_axis_vec/kern_gen_offset_base.cu @@ -18,37 +18,37 @@ using namespace cuda; using namespace indexing_multi_axis_vec; namespace { - template - __global__ void kgen_offset_base(GenOffsetBaseParam param) { - int oidx = threadIdx.x + blockDim.x * blockIdx.x; - if (oidx < param.size) { - int offset = 0; +template +__global__ void kgen_offset_base(GenOffsetBaseParam param) { + int oidx = threadIdx.x + blockDim.x * blockIdx.x; + if (oidx < param.size) { + int offset = 0; #pragma unroll - for (int i = 0; i < nidx; ++ i) { - int data_idx = param.indexer[i].ptr[ - param.indexer[i].stride * oidx]; - data_idx += (data_idx < 0 ? param.data_shape[i] : 0); - if (static_cast(data_idx) >= param.data_shape[i]) { - // cast to uint32 to handle both negative and overflow - set_async_error_info(param.error_info, param.error_tracker, - "invalid advanced indexing: " - "indexer=%d idx=%d shape=%d", - i, data_idx, param.data_shape[i]); - data_idx = 0; - } - offset += data_idx * param.data_stride[i]; + for (int i = 0; i < nidx; ++i) { + int data_idx = param.indexer[i].ptr[param.indexer[i].stride * oidx]; + data_idx += (data_idx < 0 ? param.data_shape[i] : 0); + if (static_cast(data_idx) >= param.data_shape[i]) { + // cast to uint32 to handle both negative and overflow + set_async_error_info( + param.error_info, param.error_tracker, + "invalid advanced indexing: " + "indexer=%d idx=%d shape=%d", + i, data_idx, param.data_shape[i]); + data_idx = 0; } - param.output[oidx] = offset; + offset += data_idx * param.data_stride[i]; } + param.output[oidx] = offset; } } +} // namespace -template +template void indexing_multi_axis_vec::gen_offset_base( - const GenOffsetBaseParam ¶m, cudaStream_t stream) { + const GenOffsetBaseParam& param, cudaStream_t stream) { void (*kptr)(GenOffsetBaseParam) = kgen_offset_base; int bsize = query_blocksize_for_kernel(kptr); - (*kptr) <<>> (param); + (*kptr)<<>>(param); } namespace megdnn { @@ -56,14 +56,12 @@ namespace cuda { namespace indexing_multi_axis_vec { #define INST(_n) \ - template void gen_offset_base( \ - const GenOffsetBaseParam<_n> &, cudaStream_t); - MEGDNN_FOREACH_TENSOR_NDIM(INST) + template void gen_offset_base(const GenOffsetBaseParam<_n>&, cudaStream_t); +MEGDNN_FOREACH_TENSOR_NDIM(INST) #undef INST -} // namespace indexing_multi_axis_vec -} // namespace cuda -} // namespace megdnn +} // namespace indexing_multi_axis_vec +} // namespace cuda +} // namespace megdnn // vim: ft=cuda syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp index 86c28fd4..88833ef9 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp +++ b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.cpp @@ -12,86 +12,88 @@ #include "./opr_impl.h" #include "./kern.cuh" -#include "src/cuda/utils.h" #include "src/common/indexing_multi_axis_vec_kdef.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; using namespace indexing_multi_axis_vec; namespace { - 
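[editor's note] kgen_offset_base above turns the index tensors into a flat per-position element offset: each indexer contributes index * stride, negative indices wrap by the axis size, and out-of-range values trigger the async error path. A CPU sketch of the same accumulation; it takes one plain vector per indexed axis and clamps bad indices to 0 instead of reporting an error, so it is only a reference, not the kernel's broadcasting-aware logic:

    #include <cstddef>
    #include <vector>

    std::vector<int> gen_offset_base_ref(
            const std::vector<std::vector<int>>& indexers,  // one vector per indexed axis
            const std::vector<int>& shape,                  // data shape on those axes
            const std::vector<int>& stride) {               // data stride on those axes
        size_t size = indexers.empty() ? 0 : indexers[0].size();
        std::vector<int> out(size, 0);
        for (size_t o = 0; o < size; ++o) {
            int offset = 0;
            for (size_t i = 0; i < indexers.size(); ++i) {
                int idx = indexers[i][o];
                if (idx < 0)
                    idx += shape[i];  // negative indices count from the end
                if (idx < 0 || idx >= shape[i])
                    idx = 0;          // the kernel records an async error here
                offset += idx * stride[i];
            }
            out[o] = offset;
        }
        return out;
    }
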
class ExecImplHelper { - template - void dispatch_gen_offset_base_nidx(); - - void dispatch_gen_offset_base(); - protected: - using IndexDesc = IndexingMultiAxisVec::IndexDesc; - using ExecInfo = IndexingMultiAxisVec::ExecInfo; - - cudaStream_t m_stream; - const TensorND * const m_data; - const TensorND * const m_value; - const IndexDesc * const m_index; - const ExecInfo* const m_exec_info; - int * const m_offset_base; - TensorLayout m_value_layout_on_data; - size_t m_idx_axis; - int m_value_stride; - - public: - ExecImplHelper(const TensorND &data, const TensorND &value, - const IndexDesc &index, const Workspace &workspace, - const ExecInfo &exec_info, cudaStream_t stream); - }; - - template - class ExecImpl : public ExecImplHelper { - - void dispatch_exec(); - - template - void dispatch_exec_ctype(); - - template - void dispatch_exec_ctype_ndim(); - - public: - using ExecImplHelper::ExecImplHelper; - - void operator() () { - dispatch_exec(); - after_kernel_launch(); - } - }; -} // anonymous namespace - -ExecImplHelper::ExecImplHelper(const TensorND &data, const TensorND &value, - const IndexDesc &index, const Workspace &workspace, - const ExecInfo &exec_info, cudaStream_t stream): - m_stream{stream}, m_data{&data}, m_value{&value}, m_index{&index}, - m_exec_info{&exec_info}, m_offset_base{workspace.ptr()} -{ +class ExecImplHelper { + template + void dispatch_gen_offset_base_nidx(); + + void dispatch_gen_offset_base(); + +protected: + using IndexDesc = IndexingMultiAxisVec::IndexDesc; + using ExecInfo = IndexingMultiAxisVec::ExecInfo; + + cudaStream_t m_stream; + const TensorND* const m_data; + const TensorND* const m_value; + const IndexDesc* const m_index; + const ExecInfo* const m_exec_info; + int* const m_offset_base; + TensorLayout m_value_layout_on_data; + size_t m_idx_axis; + int m_value_stride; + +public: + ExecImplHelper( + const TensorND& data, const TensorND& value, const IndexDesc& index, + const Workspace& workspace, const ExecInfo& exec_info, cudaStream_t stream); +}; + +template +class ExecImpl : public ExecImplHelper { + void dispatch_exec(); + + template + void dispatch_exec_ctype(); + + template + void dispatch_exec_ctype_ndim(); + +public: + using ExecImplHelper::ExecImplHelper; + + void operator()() { + dispatch_exec(); + after_kernel_launch(); + } +}; +} // anonymous namespace + +ExecImplHelper::ExecImplHelper( + const TensorND& data, const TensorND& value, const IndexDesc& index, + const Workspace& workspace, const ExecInfo& exec_info, cudaStream_t stream) + : m_stream{stream}, + m_data{&data}, + m_value{&value}, + m_index{&index}, + m_exec_info{&exec_info}, + m_offset_base{workspace.ptr()} { safe_size_in_kern(data.layout.total_nr_elems()); dispatch_gen_offset_base(); std::tie(m_value_layout_on_data, m_idx_axis) = - IndexingMultiAxisVec::get_value_iter_optimized_layout( - data.layout, value.layout, index, exec_info.idx_axis); + IndexingMultiAxisVec::get_value_iter_optimized_layout( + data.layout, value.layout, index, exec_info.idx_axis); m_value_stride = exec_info.value_stride; } -template +template void ExecImplHelper::dispatch_gen_offset_base_nidx() { - GenOffsetBaseParam param; param.size = m_value->layout.shape[m_exec_info->idx_axis]; param.output = m_offset_base; param.error_tracker = m_exec_info->error_tracker; param.error_info = m_exec_info->error_info; - for (int i = 0; i < nidx; ++ i) { - auto &&dst = param.indexer[i]; - auto &&src = m_index->operator[](i); + for (int i = 0; i < nidx; ++i) { + auto&& dst = param.indexer[i]; + auto&& src = 
m_index->operator[](i); megdnn_assert(src.vec.layout.ndim == 1); dst.stride = src.vec.layout.stride[0]; if (src.vec.layout.shape[0] == 1) { @@ -105,34 +107,36 @@ void ExecImplHelper::dispatch_gen_offset_base_nidx() { } void ExecImplHelper::dispatch_gen_offset_base() { - switch(m_index->size()) { -#define cb(_n) case _n: return dispatch_gen_offset_base_nidx<_n>(); + switch (m_index->size()) { +#define cb(_n) \ + case _n: \ + return dispatch_gen_offset_base_nidx<_n>(); MEGDNN_FOREACH_TENSOR_NDIM(cb) #undef cb } megdnn_throw("bad index size"); } -template +template void ExecImpl::dispatch_exec() { switch (m_data->layout.dtype.enumv()) { -#define cb(_dtype) \ - case DTypeTrait<_dtype>::enumv: \ - return dispatch_exec_ctype::ctype>(); +#define cb(_dtype) \ + case DTypeTrait<_dtype>::enumv: \ + return dispatch_exec_ctype::ctype>(); MEGDNN_FOREACH_COMPUTING_DTYPE(cb) cb(::megdnn::dtype::Bool) #undef cb - default: - megdnn_throw("bad dtype"); + default : megdnn_throw("bad dtype"); } } -template -template +template +template void ExecImpl::dispatch_exec_ctype() { switch (m_value_layout_on_data.ndim) { #define cb(_n) \ - case _n: return dispatch_exec_ctype_ndim(); + case _n: \ + return dispatch_exec_ctype_ndim(); MEGDNN_FOREACH_TENSOR_NDIM(cb) #undef cb default: @@ -140,8 +144,8 @@ void ExecImpl::dispatch_exec_ctype() { } } -template -template +template +template void ExecImpl::dispatch_exec_ctype_ndim() { ApplyOprParam param; param.tot_size = safe_size_in_kern(m_value->layout.total_nr_elems()); @@ -150,24 +154,21 @@ void ExecImpl::dispatch_exec_ctype_ndim() { param.value = m_value->ptr(); param.idx_axis = m_idx_axis; param.value_stride = m_value_stride; - for (int i = 0; i < ndim; ++ i) { + for (int i = 0; i < ndim; ++i) { param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i]; if (i) { - param.value_ly_on_data.shape[i - 1] = - m_value_layout_on_data.shape[i]; + param.value_ly_on_data.shape[i - 1] = m_value_layout_on_data.shape[i]; } } apply_opr(param, m_stream); } - size_t IndexingMultiAxisVecImpl::get_workspace_in_bytes(size_t dst_idx_size) { return dst_idx_size * sizeof(int); } void IndexingMultiAxisVecImpl::exec( - _megdnn_tensor_in src, const IndexDesc &index, - _megdnn_tensor_out dst, + _megdnn_tensor_in src, const IndexDesc& index, _megdnn_tensor_out dst, _megdnn_workspace workspace) { auto info = check_exec(src.layout, index, dst.layout, workspace.size); info.error_tracker = m_error_tracker; @@ -176,14 +177,13 @@ void IndexingMultiAxisVecImpl::exec( src, dst, index, workspace, info, cuda_stream(handle())}(); } -size_t IndexingSetMultiAxisVecImpl::get_workspace_in_bytes( - size_t value_idx_size) { +size_t IndexingSetMultiAxisVecImpl::get_workspace_in_bytes(size_t value_idx_size) { return value_idx_size * sizeof(int); } void IndexingSetMultiAxisVecImpl::exec( - _megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc &index, _megdnn_workspace workspace) { + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index, + _megdnn_workspace workspace) { auto info = check_exec(data.layout, value.layout, index, workspace.size); info.error_tracker = m_error_tracker; info.error_info = async_error_info(handle()); @@ -191,20 +191,18 @@ void IndexingSetMultiAxisVecImpl::exec( data, value, index, workspace, info, cuda_stream(handle())}(); } -size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes( - size_t value_idx_size) { +size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(size_t value_idx_size) { return value_idx_size * sizeof(int); } void 
IndexingIncrMultiAxisVecImpl::exec( - _megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc &index, _megdnn_workspace workspace) { + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index, + _megdnn_workspace workspace) { auto info = check_exec(data.layout, value.layout, index, workspace.size); info.error_tracker = m_error_tracker; info.error_info = async_error_info(handle()); - ExecImpl{data, value, index, workspace, info, - cuda_stream(handle())}(); + ExecImpl{data, value, index, + workspace, info, cuda_stream(handle())}(); } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.h b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.h index 6ebd88f6..60c24799 100644 --- a/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.h +++ b/dnn/src/cuda/indexing_multi_axis_vec/opr_impl.h @@ -16,58 +16,51 @@ namespace megdnn { namespace cuda { - class IndexingMultiAxisVecImpl final: public IndexingMultiAxisVec { - void* m_error_tracker = nullptr; +class IndexingMultiAxisVecImpl final : public IndexingMultiAxisVec { + void* m_error_tracker = nullptr; - public: - using IndexingMultiAxisVec::IndexingMultiAxisVec; +public: + using IndexingMultiAxisVec::IndexingMultiAxisVec; - size_t get_workspace_in_bytes(size_t dst_idx_size) override; + size_t get_workspace_in_bytes(size_t dst_idx_size) override; - void exec(_megdnn_tensor_in src, const IndexDesc &index, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, const IndexDesc& index, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } - }; + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } +}; - class IndexingSetMultiAxisVecImpl final: public IndexingSetMultiAxisVec { - void* m_error_tracker = nullptr; +class IndexingSetMultiAxisVecImpl final : public IndexingSetMultiAxisVec { + void* m_error_tracker = nullptr; - public: - using IndexingSetMultiAxisVec::IndexingSetMultiAxisVec; +public: + using IndexingSetMultiAxisVec::IndexingSetMultiAxisVec; - size_t get_workspace_in_bytes(size_t dst_idx_size) override; + size_t get_workspace_in_bytes(size_t dst_idx_size) override; - void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc &index, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } - }; + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } +}; - class IndexingIncrMultiAxisVecImpl final: public IndexingIncrMultiAxisVec { - void* m_error_tracker = nullptr; +class IndexingIncrMultiAxisVecImpl final : public IndexingIncrMultiAxisVec { + void* m_error_tracker = nullptr; - public: - using IndexingIncrMultiAxisVec::IndexingIncrMultiAxisVec; +public: + using IndexingIncrMultiAxisVec::IndexingIncrMultiAxisVec; - size_t get_workspace_in_bytes(size_t dst_idx_size) override; + size_t get_workspace_in_bytes(size_t dst_idx_size) override; - void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc &index, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& index, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } - }; 
-} -} + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } +}; +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_one_hot/kern.cu b/dnn/src/cuda/indexing_one_hot/kern.cu index 52d60497..437acba7 100644 --- a/dnn/src/cuda/indexing_one_hot/kern.cu +++ b/dnn/src/cuda/indexing_one_hot/kern.cu @@ -10,27 +10,26 @@ */ #include "./kern.cuh" -#include "src/cuda/utils.cuh" #include "src/cuda/elemwise_helper.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { -#define cb(_dt) \ +#define cb(_dt) \ typedef indexing_one_hot::OpGet::ctype, dt_int32> \ - OpGet##_dt; \ + OpGet##_dt; \ typedef indexing_one_hot::OpSet::ctype, dt_int32> \ - OpSet##_dt; \ - INST_RUN_ELEMWISE(OpGet##_dt, void, 0); \ + OpSet##_dt; \ + INST_RUN_ELEMWISE(OpGet##_dt, void, 0); \ INST_RUN_ELEMWISE(OpSet##_dt, void, 0); - MEGDNN_FOREACH_DTYPE_NAME(cb) - MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) +MEGDNN_FOREACH_DTYPE_NAME(cb) +MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) #undef cb -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen - diff --git a/dnn/src/cuda/indexing_one_hot/kern.cuh b/dnn/src/cuda/indexing_one_hot/kern.cuh index 08c7b5fe..79f4a6b5 100644 --- a/dnn/src/cuda/indexing_one_hot/kern.cuh +++ b/dnn/src/cuda/indexing_one_hot/kern.cuh @@ -37,10 +37,11 @@ struct KernParam { idx2 = offset - idx0 * shape_lo.divisor(); idx1 = idx[offset]; if (idx1 >= max_mid_index) { - set_async_error_info(error_info, error_tracker, - "invalid IndexingOneHot: " - "offset=%d idx0=%d indexer=%d idx2=%d", - offset, idx0, idx1, idx2); + set_async_error_info( + error_info, error_tracker, + "invalid IndexingOneHot: " + "offset=%d idx0=%d indexer=%d idx2=%d", + offset, idx0, idx1, idx2); idx1 = 0; } return idx0 * stride_hi + idx1 * shape_lo.divisor() + idx2; diff --git a/dnn/src/cuda/indexing_one_hot/opr_impl.cpp b/dnn/src/cuda/indexing_one_hot/opr_impl.cpp index f5a6fb47..2ed6990b 100644 --- a/dnn/src/cuda/indexing_one_hot/opr_impl.cpp +++ b/dnn/src/cuda/indexing_one_hot/opr_impl.cpp @@ -12,8 +12,8 @@ #include "./opr_impl.h" #include "./kern.cuh" -#include "src/cuda/utils.h" #include "src/cuda/elemwise_helper.cuh" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -21,20 +21,20 @@ using namespace indexing_one_hot; namespace { - KernParam make_kern_param(const TensorLayout &layout, size_t axis) { - KernParam ret; - memset(&ret, 0, sizeof(ret)); - ret.shape_lo = layout.stride[axis]; - ret.stride_hi = axis > 0 ? layout.stride[axis - 1] : 1; - ret.max_mid_index = layout[axis]; - return ret; - } +KernParam make_kern_param(const TensorLayout& layout, size_t axis) { + KernParam ret; + memset(&ret, 0, sizeof(ret)); + ret.shape_lo = layout.stride[axis]; + ret.stride_hi = axis > 0 ? 
layout.stride[axis - 1] : 1; + ret.max_mid_index = layout[axis]; + return ret; +} -} // anonymous namespace +} // anonymous namespace void IndexingOneHotForwardImpl::exec( - _megdnn_tensor_in src, _megdnn_tensor_in index, - _megdnn_tensor_out dst, _megdnn_workspace workspace) { + _megdnn_tensor_in src, _megdnn_tensor_in index, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(src.layout, index.layout, dst.layout, workspace.size); ElemwiseOpParamN<0> ele_param{dst.layout.total_nr_elems()}; auto kern_param = make_kern_param(src.layout, m_param.axis); @@ -42,13 +42,12 @@ void IndexingOneHotForwardImpl::exec( kern_param.error_tracker = m_error_tracker; kern_param.error_info = async_error_info(handle()); -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - using ctype = DTypeTrait<_dt>::ctype; \ - using Op = OpGet::ctype, dt_int32>; \ - Op op{src.ptr(), index.ptr(), dst.ptr(), \ - kern_param}; \ - return run_elemwise(ele_param, stream, op); \ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + using Op = OpGet::ctype, dt_int32>; \ + Op op{src.ptr(), index.ptr(), dst.ptr(), kern_param}; \ + return run_elemwise(ele_param, stream, op); \ } switch (src.layout.dtype.enumv()) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) @@ -59,8 +58,8 @@ void IndexingOneHotForwardImpl::exec( } void IndexingSetOneHotForwardImpl::exec( - _megdnn_tensor_inout data, _megdnn_tensor_in index, - _megdnn_tensor_in sub, _megdnn_workspace workspace) { + _megdnn_tensor_inout data, _megdnn_tensor_in index, _megdnn_tensor_in sub, + _megdnn_workspace workspace) { check_exec(data.layout, index.layout, sub.layout, workspace.size); ElemwiseOpParamN<0> ele_param{sub.layout.total_nr_elems()}; @@ -69,13 +68,12 @@ void IndexingSetOneHotForwardImpl::exec( kern_param.error_tracker = m_error_tracker; kern_param.error_info = async_error_info(handle()); -#define cb(_dt) \ - case DTypeTrait<_dt>::enumv: { \ - using ctype = DTypeTrait<_dt>::ctype; \ - using Op = OpSet::ctype, dt_int32>; \ - Op op{data.ptr(), index.ptr(), sub.ptr(), \ - kern_param}; \ - return run_elemwise(ele_param, stream, op); \ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: { \ + using ctype = DTypeTrait<_dt>::ctype; \ + using Op = OpSet::ctype, dt_int32>; \ + Op op{data.ptr(), index.ptr(), sub.ptr(), kern_param}; \ + return run_elemwise(ele_param, stream, op); \ } switch (data.layout.dtype.enumv()) { MEGDNN_FOREACH_COMPUTING_DTYPE(cb) @@ -86,5 +84,3 @@ void IndexingSetOneHotForwardImpl::exec( } // vim: syntax=cpp.doxygen - - diff --git a/dnn/src/cuda/indexing_one_hot/opr_impl.h b/dnn/src/cuda/indexing_one_hot/opr_impl.h index dac8b38a..153dfd31 100644 --- a/dnn/src/cuda/indexing_one_hot/opr_impl.h +++ b/dnn/src/cuda/indexing_one_hot/opr_impl.h @@ -16,42 +16,39 @@ namespace megdnn { namespace cuda { -class IndexingOneHotForwardImpl final: public IndexingOneHotForward { +class IndexingOneHotForwardImpl final : public IndexingOneHotForward { void* m_error_tracker = nullptr; - public: - using IndexingOneHotForward::IndexingOneHotForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in index, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override { - return 0; - } - - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + +public: + using IndexingOneHotForward::IndexingOneHotForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in index, _megdnn_tensor_out 
dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } + + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; -class IndexingSetOneHotForwardImpl final: public IndexingSetOneHotForward { +class IndexingSetOneHotForwardImpl final : public IndexingSetOneHotForward { void* m_error_tracker = nullptr; - public: - using IndexingSetOneHotForward::IndexingSetOneHotForward; - void exec(_megdnn_tensor_inout data, _megdnn_tensor_in index, - _megdnn_tensor_in sub, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override { - return 0; - } - - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + +public: + using IndexingSetOneHotForward::IndexingSetOneHotForward; + void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in index, _megdnn_tensor_in sub, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } + + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; -} -} +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/int_fastdiv.cpp b/dnn/src/cuda/int_fastdiv.cpp index f56e1beb..0317f5e1 100644 --- a/dnn/src/cuda/int_fastdiv.cpp +++ b/dnn/src/cuda/int_fastdiv.cpp @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ - #include "src/cuda/int_fastdiv.cuh" #include @@ -20,7 +19,7 @@ Uint32Fastdiv::Uint32Fastdiv() { memset(this, 0, sizeof(Uint32Fastdiv)); } -Uint32Fastdiv& Uint32Fastdiv::operator = (uint32_t d) { +Uint32Fastdiv& Uint32Fastdiv::operator=(uint32_t d) { megdnn_assert(d); m_divisor = d; MEGDNN_CONSTEXPR uint32_t MAX_U32 = ~0u; @@ -31,7 +30,7 @@ Uint32Fastdiv& Uint32Fastdiv::operator = (uint32_t d) { m_mul = 1u << 31; int p = 0; while ((1u << p) < d) - ++ p; + ++p; megdnn_assert((1u << p) == d); m_shift = p ? p - 1 : 0; if (d == 1) @@ -41,13 +40,13 @@ Uint32Fastdiv& Uint32Fastdiv::operator = (uint32_t d) { auto n_bound = uint64_t(d / 2 + 1) * MAX_U32; uint32_t shift = 32; while ((1ull << shift) < n_bound) - ++ shift; + ++shift; uint64_t mdst = 1ull << shift; int64_t delta = d - mdst % d; m_mul = mdst / d + 1; if ((uint64_t)delta > d / 2) { delta -= d; - -- m_mul; + --m_mul; m_inc_dividend = 1; } megdnn_assert((uint64_t)m_mul * d == mdst + delta); diff --git a/dnn/src/cuda/int_fastdiv.cuh b/dnn/src/cuda/int_fastdiv.cuh index 02d74628..7297cbc5 100644 --- a/dnn/src/cuda/int_fastdiv.cuh +++ b/dnn/src/cuda/int_fastdiv.cuh @@ -25,45 +25,40 @@ namespace cuda { class Uint32Fastdiv { uint32_t m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift; - public: - Uint32Fastdiv(); +public: + Uint32Fastdiv(); - Uint32Fastdiv(uint32_t d) { - operator=(d); - } + Uint32Fastdiv(uint32_t d) { operator=(d); } - //! set the divisor to be d - Uint32Fastdiv& operator = (uint32_t d); + //! set the divisor to be d + Uint32Fastdiv& operator=(uint32_t d); - //! caller must ensure that dividend would not exceed this number - static MEGDNN_CONSTEXPR uint32_t MAX_DIVIDEND = ~0u - 1; + //! 
caller must ensure that dividend would not exceed this number + static MEGDNN_CONSTEXPR uint32_t MAX_DIVIDEND = ~0u - 1; - __device__ __forceinline__ uint32_t divisor() const { - return m_divisor; - } + __device__ __forceinline__ uint32_t divisor() const { return m_divisor; } - __device__ __forceinline__ uint32_t divide(uint32_t dividend) const { - uint32_t - ans_for_one = dividend & ~m_divisor_is_not_1, - dfix = dividend + m_inc_dividend, + __device__ __forceinline__ uint32_t divide(uint32_t dividend) const { + uint32_t ans_for_one = dividend & ~m_divisor_is_not_1, + dfix = dividend + m_inc_dividend, #if MEGDNN_CC_CUDA - hi32 = __umulhi(dfix, m_mul), + hi32 = __umulhi(dfix, m_mul), #else - hi32 = ((uint64_t)dfix * m_mul) >> 32, + hi32 = ((uint64_t)dfix * m_mul) >> 32, #endif - ans = hi32 >> m_shift; + ans = hi32 >> m_shift; - return (ans & m_divisor_is_not_1) | ans_for_one; - } + return (ans & m_divisor_is_not_1) | ans_for_one; + } }; static __forceinline__ __device__ uint32_t -operator / (uint32_t a, const Uint32Fastdiv &d) { +operator/(uint32_t a, const Uint32Fastdiv& d) { return d.divide(a); } static __forceinline__ __device__ uint32_t -operator % (uint32_t a, const Uint32Fastdiv &d) { +operator%(uint32_t a, const Uint32Fastdiv& d) { return a - d.divisor() * d.divide(a); } @@ -71,10 +66,10 @@ operator % (uint32_t a, const Uint32Fastdiv &d) { * \brief maintain (a + k * x) / b and (a + k * x) % b for x >= 0 * \tparam need_quotient whether quotient need to be maintained */ -template +template class StridedDivSeq; -template<> +template <> class StridedDivSeq { Uint32Fastdiv m_b; @@ -84,32 +79,26 @@ class StridedDivSeq { //! current (a + k * x) % b uint32_t m_r; - public: - void host_init(uint32_t k, uint32_t b) { - m_b = b; - m_kr = k % b; - } - - //! init to k == 0 - __device__ __forceinline__ void device_init(uint32_t a) { - m_r = a % m_b; - } - - //! perform x += 1 - __device__ __forceinline__ void next() { - uint32_t b = m_b.divisor(), - r1 = m_r + m_kr, - carry_mask = (r1 < b) - 1; - m_r = r1 - (b & carry_mask); - } - - //! current remainder - __device__ __forceinline__ uint32_t r() const { - return m_r; - } +public: + void host_init(uint32_t k, uint32_t b) { + m_b = b; + m_kr = k % b; + } + + //! init to k == 0 + __device__ __forceinline__ void device_init(uint32_t a) { m_r = a % m_b; } + + //! perform x += 1 + __device__ __forceinline__ void next() { + uint32_t b = m_b.divisor(), r1 = m_r + m_kr, carry_mask = (r1 < b) - 1; + m_r = r1 - (b & carry_mask); + } + + //! current remainder + __device__ __forceinline__ uint32_t r() const { return m_r; } }; -template<> +template <> class StridedDivSeq { Uint32Fastdiv m_b; @@ -119,37 +108,31 @@ class StridedDivSeq { //! current (a + k * x) / b and (a + k * x) % b uint32_t m_q, m_r; - public: - void host_init(uint32_t k, uint32_t b) { - m_b = b; - m_kq = k / b; - m_kr = k % b; - } - - //! init to k == 0 - __device__ __forceinline__ void device_init(uint32_t a) { - m_q = m_b.divide(a); - m_r = a - m_b.divisor() * m_q; - } - - //! perform x += 1 - __device__ __forceinline__ void next() { - uint32_t b = m_b.divisor(), - r1 = m_r + m_kr, - carry_mask = (r1 < b) - 1; - m_q += m_kq + (r1 >= b); - m_r = r1 - (b & carry_mask); - } - - //! current quotient - __device__ __forceinline__ uint32_t q() const { - return m_q; - } - - //! current remainder - __device__ __forceinline__ uint32_t r() const { - return m_r; - } +public: + void host_init(uint32_t k, uint32_t b) { + m_b = b; + m_kq = k / b; + m_kr = k % b; + } + + //! 
init to k == 0 + __device__ __forceinline__ void device_init(uint32_t a) { + m_q = m_b.divide(a); + m_r = a - m_b.divisor() * m_q; + } + + //! perform x += 1 + __device__ __forceinline__ void next() { + uint32_t b = m_b.divisor(), r1 = m_r + m_kr, carry_mask = (r1 < b) - 1; + m_q += m_kq + (r1 >= b); + m_r = r1 - (b & carry_mask); + } + + //! current quotient + __device__ __forceinline__ uint32_t q() const { return m_q; } + + //! current remainder + __device__ __forceinline__ uint32_t r() const { return m_r; } }; /*! @@ -164,41 +147,35 @@ class StridedDivSeq2 { //! current (a + k * x) % b and (a + k * x) / b % c uint32_t m_cur_rkb, m_cur_ans; - public: - - void host_init(uint32_t k, uint32_t b, uint32_t c) { - m_b = b; - m_c = c; - m_qkb = k / b; - m_rkb = k % b; - m_rkbc = m_qkb % c; - } - - //! init to k == 0 - __device__ __forceinline__ void device_init(uint32_t a) { - uint32_t q = m_b.divide(a); - m_cur_rkb = a - m_b.divisor() * q; - m_cur_ans = q % m_c; - } - - //! perform x += 1 - __device__ __forceinline__ void next() { - uint32_t b = m_b.divisor(), - c = m_c.divisor(), - rkb = m_cur_rkb + m_rkb, - carry0 = (rkb < b) - 1, - next_ans = m_cur_ans + m_rkbc + (rkb >= b), - carry1 = (next_ans < c) - 1; - m_cur_rkb = rkb - (b & carry0); - m_cur_ans = next_ans - (c & carry1); - } - - __device__ __forceinline__ uint32_t get() const { - return m_cur_ans; - } +public: + void host_init(uint32_t k, uint32_t b, uint32_t c) { + m_b = b; + m_c = c; + m_qkb = k / b; + m_rkb = k % b; + m_rkbc = m_qkb % c; + } + + //! init to k == 0 + __device__ __forceinline__ void device_init(uint32_t a) { + uint32_t q = m_b.divide(a); + m_cur_rkb = a - m_b.divisor() * q; + m_cur_ans = q % m_c; + } + + //! perform x += 1 + __device__ __forceinline__ void next() { + uint32_t b = m_b.divisor(), c = m_c.divisor(), rkb = m_cur_rkb + m_rkb, + carry0 = (rkb < b) - 1, next_ans = m_cur_ans + m_rkbc + (rkb >= b), + carry1 = (next_ans < c) - 1; + m_cur_rkb = rkb - (b & carry0); + m_cur_ans = next_ans - (c & carry1); + } + + __device__ __forceinline__ uint32_t get() const { return m_cur_ans; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/integer_subbyte_utils.cuh b/dnn/src/cuda/integer_subbyte_utils.cuh index 0371933a..b6dfa1a0 100644 --- a/dnn/src/cuda/integer_subbyte_utils.cuh +++ b/dnn/src/cuda/integer_subbyte_utils.cuh @@ -43,8 +43,7 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_int4x8( "cvt.pack.sat.s4.s32.b32 %0, %2, %1, r4;" "}" : "=r"(out) - : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), - "r"(s7)); + : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), "r"(s7)); #else #define CVT_SAT_S4_S32(r, bits) \ r = r <= -8 ? -8 : r; \ @@ -78,8 +77,7 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8( "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" "}" : "=r"(out) - : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), - "r"(s7)); + : "r"(s0), "r"(s1), "r"(s2), "r"(s3), "r"(s4), "r"(s5), "r"(s6), "r"(s7)); #else #define CVT_SAT_U4_S32(r, bits) \ r = r <= 0 ? 0 : r; \ @@ -100,8 +98,7 @@ MEGDNN_DEVICE __forceinline__ static int transform_int8_to_uint4x8( } template -MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage, - int bits) { +MEGDNN_DEVICE __forceinline__ static int unpack_integer_4bits(T storage, int bits) { //! 
size in bits of 32 bit integer - 4 bits static constexpr int shift = 28; using type = typename integer_trait::type; diff --git a/dnn/src/cuda/linspace/linspace.cu b/dnn/src/cuda/linspace/linspace.cu index 4aa01df5..9526d631 100644 --- a/dnn/src/cuda/linspace/linspace.cu +++ b/dnn/src/cuda/linspace/linspace.cu @@ -8,43 +8,41 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/cuda/linspace/linspace.cuh" #include "megdnn/dtype.h" +#include "src/cuda/linspace/linspace.cuh" #include "src/cuda/utils.cuh" namespace { template -__global__ void kernel(T *dst, double start, double step, uint32_t n) -{ +__global__ void kernel(T* dst, double start, double step, uint32_t n) { uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; if (i < n) { - dst[i] = T(start + step*i); + dst[i] = T(start + step * i); } } -} // anonymous namespace +} // anonymous namespace namespace megdnn { namespace cuda { namespace linspace { template -void exec_internal(T *dst, double start, double step, size_t n, - cudaStream_t stream) -{ +void exec_internal(T* dst, double start, double step, size_t n, cudaStream_t stream) { uint32_t threads = NR_THREADS; uint32_t blocks = DIVUP(n, threads); kernel<<>>(dst, start, step, n); after_kernel_launch(); } -#define INST(T) template void exec_internal(T *dst, \ - double start, double step, size_t n, cudaStream_t stream); +#define INST(T) \ + template void exec_internal( \ + T * dst, double start, double step, size_t n, cudaStream_t stream); #define cb(DType) INST(typename DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) -} // namespace linspace -} // namespace cuda -} // namespace megdnn +} // namespace linspace +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/linspace/linspace.cuh b/dnn/src/cuda/linspace/linspace.cuh index 75a4d5a8..315e816b 100644 --- a/dnn/src/cuda/linspace/linspace.cuh +++ b/dnn/src/cuda/linspace/linspace.cuh @@ -16,10 +16,9 @@ namespace cuda { namespace linspace { template -void exec_internal(T *dst, double start, double step, size_t n, - cudaStream_t stream); +void exec_internal(T* dst, double start, double step, size_t n, cudaStream_t stream); -} // namespace linspace -} // namespace cuda -} // namespace megdnn +} // namespace linspace +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/linspace/opr_impl.cpp b/dnn/src/cuda/linspace/opr_impl.cpp index 8dde1b99..40755743 100644 --- a/dnn/src/cuda/linspace/opr_impl.cpp +++ b/dnn/src/cuda/linspace/opr_impl.cpp @@ -16,23 +16,21 @@ namespace megdnn { namespace cuda { -void LinspaceImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) -{ +void LinspaceImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(dst.layout, workspace.size); auto stream = cuda_stream(handle()); auto n = dst.layout.total_nr_elems(); auto step = (param().stop - param().start) / - std::max(static_cast(param().endpoint ? n-1 : n), 1.0); -#define cb(DType) \ - if (dst.layout.dtype == DType()) { \ - using ctype = typename DTypeTrait::ctype; \ - linspace::exec_internal(dst.ptr(), \ - param().start, step, n, \ - stream); \ + std::max(static_cast(param().endpoint ? 
n - 1 : n), 1.0); +#define cb(DType) \ + if (dst.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + linspace::exec_internal( \ + dst.ptr(), param().start, step, n, stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/linspace/opr_impl.h b/dnn/src/cuda/linspace/opr_impl.h index 7cac7da5..4c63efa1 100644 --- a/dnn/src/cuda/linspace/opr_impl.h +++ b/dnn/src/cuda/linspace/opr_impl.h @@ -14,15 +14,13 @@ namespace megdnn { namespace cuda { -class LinspaceImpl final: public Linspace { - public: - using Linspace::Linspace; - void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &) override { - return 0; - } +class LinspaceImpl final : public Linspace { +public: + using Linspace::Linspace; + void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/cuda/local/backward_data.cpp b/dnn/src/cuda/local/backward_data.cpp index 52881d44..25e0d45b 100644 --- a/dnn/src/cuda/local/backward_data.cpp +++ b/dnn/src/cuda/local/backward_data.cpp @@ -10,111 +10,86 @@ */ #include "src/cuda/local/opr_impl.h" -#include "src/cuda/local/local.cuh" #include "src/cuda/handle.h" +#include "src/cuda/local/local.cuh" #include "src/cuda/utils.h" namespace megdnn { namespace cuda { namespace local { -void boom_backward_data() -{ +void boom_backward_data() { megdnn_throw("Local bad param: cannot do backward_data by cuda_convnet"); } -} // namespace local -} // namespace cuda -} // namespace megdnn +} // namespace local +} // namespace cuda +} // namespace megdnn namespace megdnn { namespace cuda { -void LocalBackwardDataImpl::exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ +void LocalBackwardDataImpl::exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(filter.layout, diff.layout, grad.layout, workspace.size); megdnn_assert(param().mode == Mode::CROSS_CORRELATION); - auto N = grad.layout.shape[0], - IC = grad.layout.shape[1], - IH = grad.layout.shape[2], + auto N = grad.layout.shape[0], IC = grad.layout.shape[1], IH = grad.layout.shape[2], IW = grad.layout.shape[3]; - auto OC = diff.layout.shape[1], - OH = diff.layout.shape[2], + auto OC = diff.layout.shape[1], OH = diff.layout.shape[2], OW = diff.layout.shape[3]; - auto FH = filter.layout.shape[3], - FW = filter.layout.shape[4]; + auto FH = filter.layout.shape[3], FW = filter.layout.shape[4]; auto handle = concrete_handle(this->handle()); auto stream = cuda_stream(this->handle()); auto cublas = cublas_handle(this->handle()); auto one = handle->one_device(); auto zero = handle->zero_device(); if (use_cuda_convnet(filter.layout, diff.layout, grad.layout)) { - local::backward_data_proxy_convnet(filter.ptr(), - diff.ptr(), - grad.ptr(), - reinterpret_cast(workspace.raw_ptr), - N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - IC*IH*IW, OC*OH*OW, - param().pad_h, param().pad_w, - param().stride_h, param().stride_w, - cublas, stream, - one, zero); + local::backward_data_proxy_convnet( + filter.ptr(), diff.ptr(), + grad.ptr(), reinterpret_cast(workspace.raw_ptr), N, + IC, 
IH, IW, OC, OH, OW, FH, FW, IC * IH * IW, OC * OH * OW, + param().pad_h, param().pad_w, param().stride_h, param().stride_w, + cublas, stream, one, zero); } else { local::boom_backward_data(); } } -size_t LocalBackwardDataImpl::get_workspace_in_bytes(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) -{ - auto N = grad.shape[0], - IC = grad.shape[1], IH = grad.shape[2], IW = grad.shape[3], +size_t LocalBackwardDataImpl::get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { + auto N = grad.shape[0], IC = grad.shape[1], IH = grad.shape[2], IW = grad.shape[3], OC = diff.shape[1], OH = diff.shape[2], OW = diff.shape[3], FH = filter.shape[3], FW = filter.shape[4]; - auto PH = param().pad_h, PW = param().pad_w, - SH = param().stride_h, SW = param().stride_w; + auto PH = param().pad_h, PW = param().pad_w, SH = param().stride_h, + SW = param().stride_w; size_t res = 0u; if (use_cuda_convnet(filter, diff, grad)) { - res = local::get_workspace_in_floats_backward_data_proxy_convnet(N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - IC*IH*IW, OC*OH*OW, - PH, PW, - SH, SW) * sizeof(dt_float32); + res = local::get_workspace_in_floats_backward_data_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, IC * IH * IW, OC * OH * OW, PH, + PW, SH, SW) * + sizeof(dt_float32); } else { local::boom_backward_data(); } return res; } -bool LocalBackwardDataImpl::use_cuda_convnet(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) -{ - auto N = grad.shape[0], - IC = grad.shape[1], IH = grad.shape[2], IW = grad.shape[3], +bool LocalBackwardDataImpl::use_cuda_convnet( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { + auto N = grad.shape[0], IC = grad.shape[1], IH = grad.shape[2], IW = grad.shape[3], OC = diff.shape[1], OH = diff.shape[2], OW = diff.shape[3], FH = filter.shape[3], FW = filter.shape[4]; - auto PH = param().pad_h, PW = param().pad_w, - SH = param().stride_h, SW = param().stride_w; - return local::can_backward_data_proxy_convnet(N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - IC*IH*IW, OC*OH*OW, - PH, PW, - SH, SW); + auto PH = param().pad_h, PW = param().pad_w, SH = param().stride_h, + SW = param().stride_w; + return local::can_backward_data_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, IC * IH * IW, OC * OH * OW, PH, PW, SH, + SW); } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/backward_data.cu b/dnn/src/cuda/local/backward_data.cu index 9115d4b3..fc9bb829 100644 --- a/dnn/src/cuda/local/backward_data.cu +++ b/dnn/src/cuda/local/backward_data.cu @@ -10,31 +10,28 @@ */ #include "src/cuda/local/local.cuh" -#include "src/cuda/utils.cuh" -#include "src/cuda/local/cuda-convnet2/nvmatrix.cuh" #include "src/cuda/local/cuda-convnet2/cudaconv2.cuh" +#include "src/cuda/local/cuda-convnet2/nvmatrix.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { namespace local { -bool can_backward_data_proxy_convnet(size_t N, - size_t IC, size_t /* IH */, size_t /* IW */, - size_t /*OC*/, size_t /* OH */, size_t /* OW */, - size_t FH, size_t FW, - size_t /* INs */, size_t /* ONs */, - size_t PH, size_t PW, - size_t SH, size_t SW) -{ +bool can_backward_data_proxy_convnet( + size_t N, size_t IC, size_t /* IH */, size_t /* IW */, size_t /*OC*/, + size_t /* OH */, size_t /* OW */, size_t FH, size_t FW, size_t /* INs */, + size_t /* ONs */, size_t PH, size_t PW, 
size_t SH, size_t SW) { bool flag = true; // check pad flag &= (PH == PW); // check stride flag &= (SH == SW); - // megdnn_assert(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || numImgColors % 2 == 0))); + // megdnn_assert(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || + // numImgColors % 2 == 0))); flag &= (IC <= 3 || IC % 8 == 0); // megdnn_assert(numFilters % (16 * numGroups) == 0); - //flag &= (OC % 16 == 0); + // flag &= (OC % 16 == 0); // megdnn_assert(filterSize * filterSize == filterPixels); flag &= (FH == FW); flag &= (SH <= FH); @@ -42,53 +39,37 @@ bool can_backward_data_proxy_convnet(size_t N, return flag; } -size_t get_workspace_in_floats_backward_data_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t /* FH */, size_t /* FW */, - size_t /* INs */, size_t /* ONs */, - size_t /* PH */, size_t /* PW */, - size_t /* SH */, size_t /* SW */) -{ - return N*IC*IH*IW + N*OC*OH*OW; +size_t get_workspace_in_floats_backward_data_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t /* FH */, size_t /* FW */, size_t /* INs */, size_t /* ONs */, + size_t /* PH */, size_t /* PW */, size_t /* SH */, size_t /* SW */) { + return N * IC * IH * IW + N * OC * OH * OW; } -void backward_data_proxy_convnet(const float *filter, - const float *diff, - float *grad, - float *workspace, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t /* PW */, - size_t SH, size_t /* SW */, - cublasHandle_t cublas_handle, - cudaStream_t stream, - float *one, float *zero) -{ - MemorySegment mhid_n(const_cast(diff)), - mfilter(const_cast(filter)), - mtarget_n(grad), - mtarget_t(workspace), - mhid_t(workspace+N*IC*IH*IW); - NVMatrix nvhid_n(&mhid_n, N, OC*OH*OW, ONs), - nvfilter(&mfilter, OH*OW*IC*FH*FW, OC), - nvtarget_n(&mtarget_n, N, IC*IH*IW, INs), - nvhid_t(&mhid_t, OC*OH*OW, N), - nvtarget_t(&mtarget_t, IC*IH*IW, N); +void backward_data_proxy_convnet( + const float* filter, const float* diff, float* grad, float* workspace, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t INs, size_t ONs, size_t PH, size_t /* PW */, size_t SH, + size_t /* SW */, cublasHandle_t cublas_handle, cudaStream_t stream, float* one, + float* zero) { + MemorySegment mhid_n(const_cast(diff)), mfilter(const_cast(filter)), + mtarget_n(grad), mtarget_t(workspace), mhid_t(workspace + N * IC * IH * IW); + NVMatrix nvhid_n(&mhid_n, N, OC * OH * OW, ONs), + nvfilter(&mfilter, OH * OW * IC * FH * FW, OC), + nvtarget_n(&mtarget_n, N, IC * IH * IW, INs), + nvhid_t(&mhid_t, OC * OH * OW, N), nvtarget_t(&mtarget_t, IC * IH * IW, N); nvhid_n.transpose(nvhid_t, cublas_handle, one, zero); - localImgActs(stream, nvhid_t, nvfilter, nvtarget_t, - IH, IW, OH, -static_cast(PH), SH, IC, 1); + localImgActs( + stream, nvhid_t, nvfilter, nvtarget_t, IH, IW, OH, -static_cast(PH), + SH, IC, 1); after_kernel_launch(); nvtarget_t.transpose(nvtarget_n, cublas_handle, one, zero); } -} // namespace local -} // namespace cuda -} // namespace megdnn +} // namespace local +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/backward_filter.cpp b/dnn/src/cuda/local/backward_filter.cpp index 0bf31076..37ba2776 100644 --- a/dnn/src/cuda/local/backward_filter.cpp +++ b/dnn/src/cuda/local/backward_filter.cpp @@ -10,110 +10,83 @@ */ #include 
"src/cuda/local/opr_impl.h" -#include "src/cuda/local/local.cuh" #include "src/cuda/handle.h" +#include "src/cuda/local/local.cuh" #include "src/cuda/utils.h" namespace megdnn { namespace cuda { namespace local { -void boom_backward_filter() -{ +void boom_backward_filter() { megdnn_throw("Local bad param: cannot do backward_filter by cuda_convnet"); } -} // namespace local -} // namespace cuda -} // namespace megdnn +} // namespace local +} // namespace cuda +} // namespace megdnn namespace megdnn { namespace cuda { -void LocalBackwardFilterImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ +void LocalBackwardFilterImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { check_exec(src.layout, diff.layout, grad.layout, workspace.size); megdnn_assert(param().mode == Mode::CROSS_CORRELATION); - auto N = src.layout.shape[0], - IC = src.layout.shape[1], - IH = src.layout.shape[2], + auto N = src.layout.shape[0], IC = src.layout.shape[1], IH = src.layout.shape[2], IW = src.layout.shape[3]; - auto OC = diff.layout.shape[1], - OH = diff.layout.shape[2], + auto OC = diff.layout.shape[1], OH = diff.layout.shape[2], OW = diff.layout.shape[3]; - auto FH = grad.layout.shape[3], - FW = grad.layout.shape[4]; + auto FH = grad.layout.shape[3], FW = grad.layout.shape[4]; auto handle = concrete_handle(this->handle()); auto stream = cuda_stream(this->handle()); auto cublas = cublas_handle(this->handle()); auto one = handle->one_device(); auto zero = handle->zero_device(); if (use_cuda_convnet(src.layout, diff.layout, grad.layout)) { - local::backward_filter_proxy_convnet(src.ptr(), - diff.ptr(), - grad.ptr(), - reinterpret_cast(workspace.raw_ptr), - N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - IC*IH*IW, OC*OH*OW, - param().pad_h, param().pad_w, - param().stride_h, param().stride_w, - cublas, stream, - one, zero); + local::backward_filter_proxy_convnet( + src.ptr(), diff.ptr(), grad.ptr(), + reinterpret_cast(workspace.raw_ptr), N, IC, IH, IW, OC, OH, OW, + FH, FW, IC * IH * IW, OC * OH * OW, param().pad_h, param().pad_w, + param().stride_h, param().stride_w, cublas, stream, one, zero); } else { local::boom_backward_filter(); } } -size_t LocalBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad) -{ - auto N = src.shape[0], - IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], - OC = diff.shape[1], OH = diff.shape[2], OW = diff.shape[3], - FH = grad.shape[3], FW = grad.shape[4]; - auto SH = param().stride_h, SW = param().stride_w, - PH = param().pad_h, PW = param().pad_w; +size_t LocalBackwardFilterImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { + auto N = src.shape[0], IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], + OC = diff.shape[1], OH = diff.shape[2], OW = diff.shape[3], FH = grad.shape[3], + FW = grad.shape[4]; + auto SH = param().stride_h, SW = param().stride_w, PH = param().pad_h, + PW = param().pad_w; size_t res = 0u; if (use_cuda_convnet(src, diff, grad)) { - res = local::get_workspace_in_floats_backward_filter_proxy_convnet(N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - IC*IH*IW, OC*OH*OW, - SH, SW, - PH, PW) * sizeof(dt_float32); + res = local::get_workspace_in_floats_backward_filter_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, IC * IH * IW, OC * OH * OW, SH, + SW, PH, PW) * + sizeof(dt_float32); } else { 
local::boom_backward_filter(); } return res; } -bool LocalBackwardFilterImpl::use_cuda_convnet(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad) -{ - auto N = src.shape[0], - IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], - OC = diff.shape[1], OH = diff.shape[2], OW = diff.shape[3], - FH = grad.shape[3], FW = grad.shape[4]; - auto SH = param().stride_h, SW = param().stride_w, - PH = param().pad_h, PW = param().pad_w; - return local::can_backward_filter_proxy_convnet(N, IC, IH, IW, - OC, OH, OW, - FH, FW, - IC*IH*IW, OC*OH*OW, - PH, PW, - SH, SW); +bool LocalBackwardFilterImpl::use_cuda_convnet( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { + auto N = src.shape[0], IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], + OC = diff.shape[1], OH = diff.shape[2], OW = diff.shape[3], FH = grad.shape[3], + FW = grad.shape[4]; + auto SH = param().stride_h, SW = param().stride_w, PH = param().pad_h, + PW = param().pad_w; + return local::can_backward_filter_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, IC * IH * IW, OC * OH * OW, PH, PW, SH, + SW); } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/backward_filter.cu b/dnn/src/cuda/local/backward_filter.cu index 28fd6274..c2c7c8a2 100644 --- a/dnn/src/cuda/local/backward_filter.cu +++ b/dnn/src/cuda/local/backward_filter.cu @@ -10,31 +10,28 @@ */ #include "src/cuda/local/local.cuh" -#include "src/cuda/utils.cuh" -#include "src/cuda/local/cuda-convnet2/nvmatrix.cuh" #include "src/cuda/local/cuda-convnet2/cudaconv2.cuh" +#include "src/cuda/local/cuda-convnet2/nvmatrix.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { namespace local { -bool can_backward_filter_proxy_convnet(size_t N, - size_t IC, size_t /* IH */, size_t /* IW */, - size_t /*OC*/, size_t /* OH */, size_t /* OW */, - size_t FH, size_t FW, - size_t /* INs */, size_t /* ONs */, - size_t PH, size_t PW, - size_t SH, size_t SW) -{ +bool can_backward_filter_proxy_convnet( + size_t N, size_t IC, size_t /* IH */, size_t /* IW */, size_t /*OC*/, + size_t /* OH */, size_t /* OW */, size_t FH, size_t FW, size_t /* INs */, + size_t /* ONs */, size_t PH, size_t PW, size_t SH, size_t SW) { bool flag = true; // check pad flag &= (PH == PW); // check stride flag &= (SH == SW); - // megdnn_assert(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || numImgColors % 16 == 0))); + // megdnn_assert(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || + // numImgColors % 16 == 0))); flag &= (IC <= 3 || IC % 8 == 0); // megdnn_assert(numFilters % (16 * numGroups) == 0); - //flag &= (OC % 16 == 0); + // flag &= (OC % 16 == 0); // megdnn_assert(filterSize * filterSize == filterPixels); flag &= (FH == FW); flag &= (SH <= FH); @@ -42,53 +39,37 @@ bool can_backward_filter_proxy_convnet(size_t N, return flag; } -size_t get_workspace_in_floats_backward_filter_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t /* FH */, size_t /* FW */, - size_t /* INs */, size_t /* ONs */, - size_t /* PH */, size_t /* PW */, - size_t /* SH */, size_t /* SW */) -{ - return N*IC*IH*IW + N*OC*OH*OW; +size_t get_workspace_in_floats_backward_filter_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t /* FH */, size_t /* FW */, size_t /* INs */, size_t /* ONs */, + size_t /* PH */, size_t /* PW */, size_t /* SH */, size_t /* 
SW */) { + return N * IC * IH * IW + N * OC * OH * OW; } -void backward_filter_proxy_convnet(const float *src, - const float *diff, - float *grad, - float *workspace, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t /* PW */, - size_t SH, size_t /* SW */, - cublasHandle_t cublas_handle, - cudaStream_t stream, - float *one, float *zero) -{ - MemorySegment mimage_n(const_cast(src)), - mhid_n(const_cast(diff)), - mimage_t(workspace), - mhid_t(workspace+N*IC*IH*IW), - mtarget(grad); - NVMatrix nvimage_n(&mimage_n, N, IC*IH*IW, INs), - nvhid_n(&mhid_n, N, OC*OH*OW, ONs), - nvimage_t(&mimage_t, IC*IH*IW, N), - nvhid_t(&mhid_t, OC*OH*OW, N), - nvtarget(&mtarget, OH*OW*IC*FH*FW, OC); +void backward_filter_proxy_convnet( + const float* src, const float* diff, float* grad, float* workspace, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t INs, size_t ONs, size_t PH, size_t /* PW */, size_t SH, + size_t /* SW */, cublasHandle_t cublas_handle, cudaStream_t stream, float* one, + float* zero) { + MemorySegment mimage_n(const_cast(src)), mhid_n(const_cast(diff)), + mimage_t(workspace), mhid_t(workspace + N * IC * IH * IW), mtarget(grad); + NVMatrix nvimage_n(&mimage_n, N, IC * IH * IW, INs), + nvhid_n(&mhid_n, N, OC * OH * OW, ONs), + nvimage_t(&mimage_t, IC * IH * IW, N), nvhid_t(&mhid_t, OC * OH * OW, N), + nvtarget(&mtarget, OH * OW * IC * FH * FW, OC); nvhid_n.transpose(nvhid_t, cublas_handle, one, zero); nvimage_n.transpose(nvimage_t, cublas_handle, one, zero); - localWeightActs(stream, nvimage_t, nvhid_t, nvtarget, - IH, OH, OW, FH, -static_cast(PH), SH, IC, 1); + localWeightActs( + stream, nvimage_t, nvhid_t, nvtarget, IH, OH, OW, FH, -static_cast(PH), + SH, IC, 1); after_kernel_launch(); } -} // namespace local -} // namespace cuda -} // namespace megdnn +} // namespace local +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/cuda-convnet2/cudaconv2.cuh b/dnn/src/cuda/local/cuda-convnet2/cudaconv2.cuh index 3019506d..12e6d119 100644 --- a/dnn/src/cuda/local/cuda-convnet2/cudaconv2.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/cudaconv2.cuh @@ -25,69 +25,84 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ - #ifndef COMMON_CUH -#define COMMON_CUH +#define COMMON_CUH #define MIN(x, y) ((x) < (y) ? (x) : (y)) #define MAX(x, y) ((x) > (y) ? 
(x) : (y)) -#include "helper_cuda.h" // helper functions CUDA error checking and initialization +#include "helper_cuda.h" // helper functions CUDA error checking and initialization #include "nvmatrix.cuh" namespace megdnn { namespace cuda { -enum FILTER_OUTPUT_ORDER {MODULE_FILTER_IMAGE, FILTER_MODULE_IMAGE}; - -void convFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups); -void convFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups, - float scaleTargets, float scaleOutput); +enum FILTER_OUTPUT_ORDER { MODULE_FILTER_IMAGE, FILTER_MODULE_IMAGE }; -void localFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups); -void localFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups, - float scaleTargets, float scaleOutput); +void convFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups); +void convFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups, float scaleTargets, + float scaleOutput); -void convImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups); -void convImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups, - float scaleTargets, float scaleOutput); +void localFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups); +void localFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups, float scaleTargets, + float scaleOutput); -void localImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups); -void localImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups, - float scaleTargets, float scaleOutput); +void convImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups); +void convImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& 
filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups, float scaleTargets, float scaleOutput); -void convWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, - int moduleStride, int numImgColors, int numGroups, int sumWidth); -void convWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, int moduleStride, - int numImgColors, int numGroups, int sumWidth, - float scaleTargets, float scaleOutput); +void localImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups); +void localImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups, float scaleTargets, float scaleOutput); -void localWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, - int moduleStride, int numImgColors, int numGroups); +void convWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups, + int sumWidth); +void convWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups, + int sumWidth, float scaleTargets, float scaleOutput); -void localWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, int moduleStride, - int numImgColors, int numGroups, float scaleTargets, float scaleOutput); -} -} +void localWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups); -#endif /* COMMON_CUH */ +void localWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups, + float scaleTargets, float scaleOutput); +} // namespace cuda +} // namespace megdnn +#endif /* COMMON_CUH */ diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts.cu index d650cd03..d8ef9b46 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts.cu @@ -25,28 +25,40 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ -#include "nvmatrix.cuh" #include "cudaconv2.cuh" -#include "src/cuda/utils.cuh" #include "filter_acts/filter_act_templates.cuh" +#include "nvmatrix.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { -__device__ __forceinline__ void filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords(int fPidx, int imgLoadModPosY, int imgLoadModPosX, - int imgSizeX, int filterSize, int& iPidx) { +__device__ __forceinline__ void +filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords( + int fPidx, int imgLoadModPosY, int imgLoadModPosX, int imgSizeX, int filterSize, + int& iPidx) { int x = imgLoadModPosX + (fPidx) % filterSize; int y = imgLoadModPosY + (fPidx) / filterSize; iPidx = y >= 0 && y < imgSizeX && x >= 0 && x < imgSizeX ? y * imgSizeX + x : -1; } -#define FA_COLOR3_IMPRELOAD(c,i) imPreload[c][i] = iPidxNext < 0 || (checkImgBounds && myImgIdx + i * B_X >= numImages) ? 0 : mm[c * imgPixels * imgStride + i * B_X]; -#define FA_COLOR3_IMPRELOAD_TX(c,i) imPreload[c][i] = iPidxNext < 0 || (checkImgBounds && myImgIdx + i * B_X >= numImages) ? 0 : tex1Dfetch(images, imagesOffset2 + c * imgPixels * imgStride + i * B_X); - +#define FA_COLOR3_IMPRELOAD(c, i) \ + imPreload[c][i] = \ + iPidxNext < 0 || (checkImgBounds && myImgIdx + i * B_X >= numImages) \ + ? 0 \ + : mm[c * imgPixels * imgStride + i * B_X]; +#define FA_COLOR3_IMPRELOAD_TX(c, i) \ + imPreload[c][i] = \ + iPidxNext < 0 || (checkImgBounds && myImgIdx + i * B_X >= numImages) \ + ? 0 \ + : tex1Dfetch( \ + images, \ + imagesOffset2 + c * imgPixels * imgStride + i * B_X); /* * images: (numImgColors, imgSizeY, imgSizeX, numImages) with stride given @@ -56,24 +68,31 @@ __device__ __forceinline__ void filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_ * targets: (numFilters, numModulesY, numModulesX, numImages) * */ -template +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int numColors, + int pixelCache, bool scale, bool checkImgBounds> //__launch_bounds__(128,3) -__global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex(cudaTextureObject_t images, cudaTextureObject_t filters, float* targets, - const int numImages, const int numFilters, - const int imgSizeY, const int imgSizeX, const int filterSize, const int paddingStart, - const int moduleStride, - const int numModulesY, const int numModulesX, const int imgStride, - const float scaleTargets, const float scaleOutputs, - const bool conv/*, const bool noloads*/) { - __shared__ float shFilters[numColors][pixelCache][B_Y * filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters - __shared__ float shImages[numColors][pixelCache][B_X * imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); +__global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex( + cudaTextureObject_t images, cudaTextureObject_t filters, float* targets, + const int numImages, const int numFilters, const int imgSizeY, + const int imgSizeX, const int filterSize, const int paddingStart, + const int moduleStride, const int numModulesY, const int numModulesX, + const int imgStride, const float scaleTargets, const float scaleOutputs, + const bool conv /*, const bool noloads*/) { + __shared__ float shFilters + [numColors][pixelCache] + [B_Y * + filtersPerThread]; // pre-load 1 pixel from 
B_Y*filtersPerThread filters + __shared__ float + shImages[numColors][pixelCache] + [B_X * + imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); __syncthreads(); const int imgPixels = imgSizeY * imgSizeX; const int filterPixels = filterSize * filterSize; - const int blocksPerModule = numFilters / (B_Y*filtersPerThread); + const int blocksPerModule = numFilters / (B_Y * filtersPerThread); const int moduleIdx = blockIdx.y / blocksPerModule; const int blockFilterIdx = filtersPerThread * B_Y * (blockIdx.y % blocksPerModule); @@ -91,64 +110,71 @@ __global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex(c const int shFilterLoadX = tidx % (B_Y * filtersPerThread); const int myImgIdx = blockIdx.x * B_X * imgsPerThread + threadIdx.x; -// images += myImgIdx; -// filters += blockFilterIdx -// + shFilterLoadY * numFilters + shFilterLoadX; -// if (!conv) { // NOTE: UNTESTED! -// filters += moduleIdx * numColors * filterPixels * numFilters; -// } + // images += myImgIdx; + // filters += blockFilterIdx + // + shFilterLoadY * numFilters + shFilterLoadX; + // if (!conv) { // NOTE: UNTESTED! + // filters += moduleIdx * numColors * filterPixels * numFilters; + // } const int imagesOffset = myImgIdx; - const int filtersOffset = blockFilterIdx + shFilterLoadY * numFilters + shFilterLoadX - + (conv ? 0 : moduleIdx * numColors * filterPixels * numFilters); + const int filtersOffset = + blockFilterIdx + shFilterLoadY * numFilters + shFilterLoadX + + (conv ? 0 : moduleIdx * numColors * filterPixels * numFilters); - targets += moduleIdx * numImages - + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules - + myImgIdx; + targets += + moduleIdx * numImages + + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules + + myImgIdx; float prod[imgsPerThread][filtersPerThread]; - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { prod[i][f] = 0; } } int iPidxNext; float imPreload[numColors][imgsPerThread]; - float fPreload[numColors][pixelCache*filtersPerThread/B_X]; + float fPreload[numColors][pixelCache * filtersPerThread / B_X]; - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { - #pragma unroll - for (int p = 0; p < pixelCache; p += B_X/filtersPerThread) { +#pragma unroll + for (int p = 0; p < pixelCache; p += B_X / filtersPerThread) { if (p + shFilterLoadY < filterPixels) { - fPreload[c][p*filtersPerThread/B_X] = tex1Dfetch(filters, filtersOffset + p * numFilters + c * numFilters * filterPixels); - } else{ - fPreload[c][p*filtersPerThread/B_X] = 0; + fPreload[c][p * filtersPerThread / B_X] = tex1Dfetch( + filters, + filtersOffset + p * numFilters + c * numFilters * filterPixels); + } else { + fPreload[c][p * filtersPerThread / B_X] = 0; } } } - filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords(ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, iPidxNext); + filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords( + ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, iPidxNext); - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (iPidxNext >= 0 && 
(!checkImgBounds || myImgIdx + i * B_X < numImages)) { - imPreload[c][i] = tex1Dfetch(images, imagesOffset + (c * imgPixels + iPidxNext) * imgStride + i * B_X); + imPreload[c][i] = tex1Dfetch( + images, imagesOffset + (c * imgPixels + iPidxNext) * imgStride + + i * B_X); } else { - imPreload[c][i] = 0; + imPreload[c][i] = 0; } } } for (int p = 0; p < filterPixels; p += pixelCache) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { // NOTE: bank conflicts here! shImages[c][ty][tx * imgsPerThread + i] = imPreload[c][i]; @@ -156,51 +182,64 @@ __global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex(c } const int fPidxNext = p + pixelCache >= filterPixels ? 0 : p + pixelCache; - filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords(fPidxNext + ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, iPidxNext); + filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords( + fPidxNext + ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, + iPidxNext); -// const float* ff = &filters[numFilters * fPidxNext]; -// const float* mm = &images[imgStride * iPidxNext]; + // const float* ff = &filters[numFilters * fPidxNext]; + // const float* mm = &images[imgStride * iPidxNext]; const int filtersOffset2 = filtersOffset + numFilters * fPidxNext; const int imagesOffset2 = imagesOffset + imgStride * iPidxNext; - FA_COLOR3_IMPRELOAD_TX(0,0); - FA_COLOR3_IMPRELOAD_TX(0,1); - FA_COLOR3_IMPRELOAD_TX(0,2); - FA_COLOR3_IMPRELOAD_TX(0,3); + FA_COLOR3_IMPRELOAD_TX(0, 0); + FA_COLOR3_IMPRELOAD_TX(0, 1); + FA_COLOR3_IMPRELOAD_TX(0, 2); + FA_COLOR3_IMPRELOAD_TX(0, 3); - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { - #pragma unroll - for (int pp = 0; pp < pixelCache; pp += B_X/filtersPerThread) { - shFilters[c][pp + shFilterLoadY][shFilterLoadX] = fPreload[c][pp*filtersPerThread/B_X]; +#pragma unroll + for (int pp = 0; pp < pixelCache; pp += B_X / filtersPerThread) { + shFilters[c][pp + shFilterLoadY][shFilterLoadX] = + fPreload[c][pp * filtersPerThread / B_X]; } } __syncthreads(); - FA_COLOR3_IMPRELOAD_TX(1,0); - FA_COLOR3_IMPRELOAD_TX(1,1); - FA_COLOR3_IMPRELOAD_TX(1,2); - FA_COLOR3_IMPRELOAD_TX(1,3); - FA_COLOR3_IMPRELOAD_TX(2,0); - FA_COLOR3_IMPRELOAD_TX(2,1); - FA_COLOR3_IMPRELOAD_TX(2,2); - FA_COLOR3_IMPRELOAD_TX(2,3); - #pragma unroll + FA_COLOR3_IMPRELOAD_TX(1, 0); + FA_COLOR3_IMPRELOAD_TX(1, 1); + FA_COLOR3_IMPRELOAD_TX(1, 2); + FA_COLOR3_IMPRELOAD_TX(1, 3); + FA_COLOR3_IMPRELOAD_TX(2, 0); + FA_COLOR3_IMPRELOAD_TX(2, 1); + FA_COLOR3_IMPRELOAD_TX(2, 2); + FA_COLOR3_IMPRELOAD_TX(2, 3); +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll - for (int pp = 0; pp < pixelCache*filtersPerThread/B_X; pp++) { - fPreload[c][pp] = fPidxNext + pp*(B_X/filtersPerThread) + shFilterLoadY >= filterPixels ? 0 : tex1Dfetch(filters, filtersOffset2 + c * numFilters* filterPixels + pp*(B_X/filtersPerThread) * numFilters); +#pragma unroll + for (int pp = 0; pp < pixelCache * filtersPerThread / B_X; pp++) { + fPreload[c][pp] = + fPidxNext + pp * (B_X / filtersPerThread) + shFilterLoadY >= + filterPixels + ? 
0 + : tex1Dfetch( + filters, + filtersOffset2 + + c * numFilters * filterPixels + + pp * (B_X / filtersPerThread) * + numFilters); } } - #pragma unroll +#pragma unroll for (int pp = 0; pp < pixelCache; pp++) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - prod[i][f] += shImages[c][pp][tx * imgsPerThread + i] * shFilters[c][pp][ty * filtersPerThread + f]; +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { + prod[i][f] += shImages[c][pp][tx * imgsPerThread + i] * + shFilters[c][pp][ty * filtersPerThread + f]; } } } @@ -210,23 +249,27 @@ __global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex(c } if (scale) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleTargets * targets[i * B_X + f * numImages * numModules] + scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleTargets * + targets[i * B_X + f * numImages * numModules] + + scaleOutputs * prod[i][f]; } } } } else { - // Note: reversing order of these loops saves 2 registers, but costs time - #pragma unroll +// Note: reversing order of these loops saves 2 registers, but costs time +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleOutputs * prod[i][f]; } } } @@ -242,23 +285,30 @@ __global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex(c * * This won't be pretty. 
*/ -template -__global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex(cudaTextureObject_t images, cudaTextureObject_t filters, float* targets, - const int numImages, const int numFilters, - const int imgSizeY, const int imgSizeX, const int filterSize, const int paddingStart, - const int moduleStride, - const int numModulesY, const int numModulesX, const int imgStride, - const float scaleTargets, const float scaleOutputs, - const bool conv/*, const bool noloads*/) { - __shared__ float shFilters[numColors][pixelCache][B_Y * filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters - __shared__ float shImages[numColors][pixelCache][B_X * imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int numColors, + int pixelCache, bool scale, bool checkImgBounds> +__global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex( + cudaTextureObject_t images, cudaTextureObject_t filters, float* targets, + const int numImages, const int numFilters, const int imgSizeY, + const int imgSizeX, const int filterSize, const int paddingStart, + const int moduleStride, const int numModulesY, const int numModulesX, + const int imgStride, const float scaleTargets, const float scaleOutputs, + const bool conv /*, const bool noloads*/) { + __shared__ float shFilters + [numColors][pixelCache] + [B_Y * + filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters + __shared__ float + shImages[numColors][pixelCache] + [B_X * + imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); __syncthreads(); const int imgPixels = imgSizeY * imgSizeX; const int filterPixels = filterSize * filterSize; - const int blocksPerModule = numFilters / (B_Y*filtersPerThread); + const int blocksPerModule = numFilters / (B_Y * filtersPerThread); const int moduleIdx = blockIdx.y / blocksPerModule; const int blockFilterIdx = filtersPerThread * B_Y * (blockIdx.y % blocksPerModule); @@ -277,69 +327,78 @@ __global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex(c const int shFilterLoadX = tidx % (B_Y * filtersPerThread); const int myImgIdx = blockIdx.x * B_X * imgsPerThread + threadIdx.x; -// images += myImgIdx; -// filters += blockFilterIdx -// + shFilterLoadY * numFilters + shFilterLoadX; -// if (!conv) { // NOTE: UNTESTED! -// filters += moduleIdx * numColors * filterPixels * numFilters; -// } + // images += myImgIdx; + // filters += blockFilterIdx + // + shFilterLoadY * numFilters + shFilterLoadX; + // if (!conv) { // NOTE: UNTESTED! + // filters += moduleIdx * numColors * filterPixels * numFilters; + // } const int imagesOffset = myImgIdx; - const int filtersOffset = blockFilterIdx + shFilterLoadY * numFilters + shFilterLoadX - + (conv ? 0 : moduleIdx * numColors * filterPixels * numFilters); + const int filtersOffset = + blockFilterIdx + shFilterLoadY * numFilters + shFilterLoadX + + (conv ? 
0 : moduleIdx * numColors * filterPixels * numFilters); - targets += moduleIdx * numImages - + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules - + myImgIdx; + targets += + moduleIdx * numImages + + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules + + myImgIdx; float prod[imgsPerThread][filtersPerThread]; - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { prod[i][f] = 0; } } int iPidxNext; float imPreload[numColors][imgsPerThread]; - float fPreload[numColors][DIVUP(pixelCache*filtersPerThread,B_X)]; + float fPreload[numColors][DIVUP(pixelCache * filtersPerThread, B_X)]; if (warp < 3) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelCache; p += 2) { if (p + shFilterLoadY < filterPixels) { - fPreload[c][p/2] = tex1Dfetch(filters, filtersOffset + p * numFilters + c * numFilters * filterPixels); + fPreload[c][p / 2] = tex1Dfetch( + filters, filtersOffset + p * numFilters + + c * numFilters * filterPixels); } else { - fPreload[c][p/2] = 0; + fPreload[c][p / 2] = 0; } } } } - filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords(ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, iPidxNext); + filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords( + ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, iPidxNext); - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (iPidxNext >= 0 && (!checkImgBounds || myImgIdx + i * B_X < numImages)) { - imPreload[c][i] = tex1Dfetch(images, imagesOffset + (c * imgPixels + iPidxNext) * imgStride + i * B_X); + imPreload[c][i] = tex1Dfetch( + images, imagesOffset + (c * imgPixels + iPidxNext) * imgStride + + i * B_X); } else { - imPreload[c][i] = 0; + imPreload[c][i] = 0; } } } for (int p = 0; p < filterPixels; p += pixelCache) { const int fPidxNext = p + pixelCache >= filterPixels ? 0 : p + pixelCache; - filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords(fPidxNext + ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, iPidxNext); + filterActs_YxX_color_preload_ty_4_tx_32_f_16_cc_3_setImgCoords( + fPidxNext + ty, imgLoadModPosY, imgLoadModPosX, imgSizeX, filterSize, + iPidxNext); - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { // NOTE: bank conflicts here! 
shImages[c][ty][tx * imgsPerThread + i] = imPreload[c][i]; @@ -347,68 +406,80 @@ __global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex(c } if (warp < 3) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; ++c) { - #pragma unroll +#pragma unroll for (int pp = 0; pp < pixelCache; pp += 2) { - shFilters[c][pp + shFilterLoadY][shFilterLoadX] = fPreload[c][pp/2]; + shFilters[c][pp + shFilterLoadY][shFilterLoadX] = + fPreload[c][pp / 2]; } } } __syncthreads(); -// const float* ff = &filters[numFilters * fPidxNext]; -// const float* mm = &images[imgStride * iPidxNext]; + // const float* ff = &filters[numFilters * fPidxNext]; + // const float* mm = &images[imgStride * iPidxNext]; const int filtersOffset2 = filtersOffset + numFilters * fPidxNext; const int imagesOffset2 = imagesOffset + imgStride * iPidxNext; - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; ++i) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - FA_COLOR3_IMPRELOAD_TX(c,i); + FA_COLOR3_IMPRELOAD_TX(c, i); } } - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int pp = 0; pp < 2; pp++) { - fPreload[c][pp] = warp >= 3 || fPidxNext + pp*2 + shFilterLoadY >= filterPixels ? 0 : tex1Dfetch(filters, filtersOffset2 + c * numFilters* filterPixels + pp*2 * numFilters); + fPreload[c][pp] = + warp >= 3 || fPidxNext + pp * 2 + shFilterLoadY >= filterPixels + ? 0 + : tex1Dfetch( + filters, + filtersOffset2 + + c * numFilters * filterPixels + + pp * 2 * numFilters); } - #pragma unroll +#pragma unroll for (int pp = 0; pp < pixelCache; pp++) { - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[c][pp][tx * imgsPerThread + i] * shFilters[c][pp][ty * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[c][pp][tx * imgsPerThread + i] * + shFilters[c][pp][ty * filtersPerThread + f]; } } } - } __syncthreads(); } if (scale) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleTargets * targets[i * B_X + f * numImages * numModules] + scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleTargets * + targets[i * B_X + f * numImages * numModules] + + scaleOutputs * prod[i][f]; } } } } else { - // Note: reversing order of these loops costs 2 registers, but saves time - #pragma unroll +// Note: reversing order of these loops costs 2 registers, but saves time +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleOutputs * prod[i][f]; } } } @@ -422,29 +493,35 @@ __global__ void filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex(c * * targets: (numFilters, numModulesY, numModulesX, numImages) * - * Note: in git there's a 1.5% faster version of this which sues 167 registers instead of 154... - * it's basically the same thing, but it doesn't do the next-pixel computation. 
It just avoids - * pre-loading when it rolls over to the next pixel. + * Note: in git there's a 1.5% faster version of this which sues 167 registers instead + * of 154... it's basically the same thing, but it doesn't do the next-pixel + * computation. It just avoids pre-loading when it rolls over to the next pixel. */ -template -__global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* images, float* filters, float* targets, - const int numImages, const int numFilters, - const int imgSizeY, const int imgSizeX, const int filterSize, const int paddingStart, - const int moduleStride, - const int numModulesY, const int numModulesX, const int imgStride, const int numImgColors, - const int numGroups, - const float scaleTargets, const float scaleOutputs, - const bool conv/*, const bool noloads*/) { - __shared__ float shFilters[colorCache][B_Y * filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters - __shared__ float shImages[colorCache][B_X * imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int colorCache, + bool scale, bool checkImgBounds> +__global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4( + float* images, float* filters, float* targets, const int numImages, + const int numFilters, const int imgSizeY, const int imgSizeX, + const int filterSize, const int paddingStart, const int moduleStride, + const int numModulesY, const int numModulesX, const int imgStride, + const int numImgColors, const int numGroups, const float scaleTargets, + const float scaleOutputs, const bool conv /*, const bool noloads*/) { + __shared__ float shFilters + [colorCache] + [B_Y * + filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters + __shared__ float + shImages[colorCache] + [B_X * + imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); __syncthreads(); const int imgPixels = imgSizeY * imgSizeX; const int filterPixels = filterSize * filterSize; const int numFilterColors = numImgColors / numGroups; - const int blocksPerModule = numFilters / (B_Y*filtersPerThread); + const int blocksPerModule = numFilters / (B_Y * filtersPerThread); const int moduleIdx = blockIdx.y / blocksPerModule; const int blockFilterIdx = filtersPerThread * B_Y * (blockIdx.y % blocksPerModule); const int numFiltersPerGroup = numFilters / numGroups; @@ -466,22 +543,23 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im const int myImgIdx = blockIdx.x * B_X * imgsPerThread + threadIdx.x; images += (blockColorIdx + threadIdx.y) * imgPixels * imgStride + myImgIdx; - filters +=blockFilterIdx - + shFilterLoadY * numFilters * filterPixels + shFilterLoadX; + filters += + blockFilterIdx + shFilterLoadY * numFilters * filterPixels + shFilterLoadX; if (!conv) { filters += moduleIdx * numFilterColors * filterPixels * numFilters; } - targets += moduleIdx * numImages - + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules - + myImgIdx; + targets += + moduleIdx * numImages + + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules + + myImgIdx; float prod[imgsPerThread][filtersPerThread]; -// float 
fCache[filtersPerThread]; - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { + // float fCache[filtersPerThread]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { prod[i][f] = 0; } } @@ -490,16 +568,18 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im const int imgStartY = max(0, imgLoadModPosY); const int imgEndX = min(imgLoadModPosX + filterSize, imgSizeX); const int imgEndY = min(imgLoadModPosY + filterSize, imgSizeY); -// __shared__ int imgPos[] + // __shared__ int imgPos[] int fPidx, iPidx; float imPreload[imgsPerThread]; - float fPreload[colorCache*filtersPerThread/B_X]; -// float fCache[filtersPerThread]; + float fPreload[colorCache * filtersPerThread / B_X]; + // float fCache[filtersPerThread]; - filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgStartY, imgStartX, fPidx, iPidx); + filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords( + filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgStartY, imgStartX, + fPidx, iPidx); - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { imPreload[i] = images[imgStride * iPidx + i * B_X]; @@ -507,20 +587,23 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im imPreload[i] = 0; } } - if (/*B_X % filtersPerThread == 0 ||*/ shFilterLoadY < B_X/filtersPerThread) { // This if statement reduces reg usage.. - #pragma unroll - for (int c = 0; c < colorCache; c += B_X/filtersPerThread) { - fPreload[c*filtersPerThread/B_X] = filters[(c * filterPixels + fPidx) * numFilters]; + if (/*B_X % filtersPerThread == 0 ||*/ shFilterLoadY < + B_X / filtersPerThread) { // This if statement reduces reg usage.. +#pragma unroll + for (int c = 0; c < colorCache; c += B_X / filtersPerThread) { + fPreload[c * filtersPerThread / B_X] = + filters[(c * filterPixels + fPidx) * numFilters]; } } for (int imgY = imgStartY; imgY < imgEndY; ++imgY) { -// const int filterPxY = imgY - imgLoadModPosY; + // const int filterPxY = imgY - imgLoadModPosY; for (int imgX = imgStartX; imgX < imgEndX; ++imgX) { -// const int filterPxX = imgX - imgLoadModPosX; -// const int p = filterPxY * filterSize + filterPxX; -// const int pixIdx = imgY * imgSizeX + imgX;// Pixel index in img -// setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgY, imgX, &p, &pixIdx); -// float* m = &images[imgStride * pixIdx]; + // const int filterPxX = imgX - imgLoadModPosX; + // const int p = filterPxY * filterSize + filterPxX; + // const int pixIdx = imgY * imgSizeX + imgX;// Pixel index in + // img setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, + // imgLoadModPosX, imgY, imgX, &p, &pixIdx); float* m = + // &images[imgStride * pixIdx]; const bool lastPixel = imgY == imgEndY - 1 && imgX == imgEndX - 1; int imgYNext = imgY; int imgXNext = imgX; @@ -529,10 +612,16 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im imgYNext = imgY + (imgX + 1 == imgEndX); imgXNext = imgX + 1 == imgEndX ? 
imgStartX : imgX + 1; } - filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgYNext, imgXNext, fPidxNext, iPidxNext); - for (int oc = 0; oc < numFilterColors; oc += colorCache) { // oc stands for outer color (loop) - const float* ff = &filters[numFilters * ((oc + colorCache) * filterPixels + fPidx)]; - const float* mm = &images[imgStride * ((oc + colorCache) * imgPixels + iPidx)]; + filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords( + filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgYNext, + imgXNext, fPidxNext, iPidxNext); + for (int oc = 0; oc < numFilterColors; + oc += colorCache) { // oc stands for outer color (loop) + const float* ff = &filters + [numFilters * + ((oc + colorCache) * filterPixels + fPidx)]; + const float* mm = + &images[imgStride * ((oc + colorCache) * imgPixels + iPidx)]; if (oc == numFilterColors - colorCache) { ff = &filters[fPidxNext * numFilters]; mm = &images[iPidxNext * imgStride]; @@ -540,57 +629,70 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im iPidx = iPidxNext; } - #pragma unroll - for (int c = 0; c < colorCache; c += B_X/filtersPerThread) { - shFilters[c + shFilterLoadY][shFilterLoadX] = fPreload[c*filtersPerThread/B_X]; +#pragma unroll + for (int c = 0; c < colorCache; c += B_X / filtersPerThread) { + shFilters[c + shFilterLoadY][shFilterLoadX] = + fPreload[c * filtersPerThread / B_X]; } - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { // NOTE: bank conflicts here! shImages[ty][tx * imgsPerThread + i] = imPreload[i]; } - imPreload[0] = (checkImgBounds && myImgIdx + 0 * B_X >= numImages) ? 0 : mm[0 * B_X]; - imPreload[1] = (checkImgBounds && myImgIdx + 1 * B_X >= numImages) ? 0 : mm[1 * B_X]; - imPreload[2] = (checkImgBounds && myImgIdx + 2 * B_X >= numImages) ? 0 : mm[2 * B_X]; + imPreload[0] = (checkImgBounds && myImgIdx + 0 * B_X >= numImages) + ? 0 + : mm[0 * B_X]; + imPreload[1] = (checkImgBounds && myImgIdx + 1 * B_X >= numImages) + ? 0 + : mm[1 * B_X]; + imPreload[2] = (checkImgBounds && myImgIdx + 2 * B_X >= numImages) + ? 
0 + : mm[2 * B_X]; __syncthreads(); - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[0][threadIdx.x * imgsPerThread + i] * shFilters[0][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[0][threadIdx.x * imgsPerThread + i] * + shFilters[0][threadIdx.y * filtersPerThread + f]; } } fPreload[0] = ff[0]; - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[1][threadIdx.x * imgsPerThread + i] * shFilters[1][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[1][threadIdx.x * imgsPerThread + i] * + shFilters[1][threadIdx.y * filtersPerThread + f]; } } - fPreload[1] = ff[(B_X/filtersPerThread * filterPixels) * numFilters]; + fPreload[1] = ff[(B_X / filtersPerThread * filterPixels) * numFilters]; - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[2][threadIdx.x * imgsPerThread + i] * shFilters[2][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[2][threadIdx.x * imgsPerThread + i] * + shFilters[2][threadIdx.y * filtersPerThread + f]; } } - imPreload[3] = (checkImgBounds && myImgIdx + 3 * B_X >= numImages) ? 0 : mm[3 * B_X]; + imPreload[3] = (checkImgBounds && myImgIdx + 3 * B_X >= numImages) + ? 
0 + : mm[3 * B_X]; - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[3][threadIdx.x * imgsPerThread + i] * shFilters[3][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[3][threadIdx.x * imgsPerThread + i] * + shFilters[3][threadIdx.y * filtersPerThread + f]; } } __syncthreads(); @@ -599,23 +701,27 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im } if (scale) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleTargets * targets[i * B_X + f * numImages * numModules] + scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleTargets * + targets[i * B_X + f * numImages * numModules] + + scaleOutputs * prod[i][f]; } } } } else { - // Note: reversing order of these loops saves 2 registers, but costs time - #pragma unroll +// Note: reversing order of these loops saves 2 registers, but costs time +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleOutputs * prod[i][f]; } } } @@ -634,42 +740,50 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im * Other batch sizes will work, but but I made no attempt whatsoever * to make them work fast. */ - void _filterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups, - float scaleTargets, float scaleOutput, bool conv) { +void _filterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups, float scaleTargets, + float scaleOutput, bool conv) { int numFilterColors = numImgColors / numGroups; int numFilters = filters.getNumCols(); int numModules = numModulesY * numModulesX; int numImages = images.getNumCols(); - int imgPixels = images.getNumRows()/numImgColors; + int imgPixels = images.getNumRows() / numImgColors; int imgSizeX = imgPixels / imgSizeY; int filterModuleMult = conv ? 
1 : numModules; - megdnn_assert_internal(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || numImgColors % 4 == 0))); + megdnn_assert_internal( + numGroups > 1 || + (numImgColors > 0 && (numImgColors <= 3 || numImgColors % 4 == 0))); megdnn_assert_internal(numGroups == 1 || numFilterColors % 4 == 0); - //megdnn_assert_internal(numFilters % (16 * numGroups) == 0); + // megdnn_assert_internal(numFilters % (16 * numGroups) == 0); megdnn_assert_internal(numImgColors % numGroups == 0); bool previous_limit = (numFilters % (16 * numGroups)) == 0; - //images.printShape("images"); - //printf("rows: %d, pixels: %d, colors: %d\n", images.getNumRows(), imgPixels, numImgColors); - //images.printShape("images"); + // images.printShape("images"); + // printf("rows: %d, pixels: %d, colors: %d\n", images.getNumRows(), imgPixels, + // numImgColors); images.printShape("images"); megdnn_assert_internal(images.getNumRows() == imgPixels * numImgColors); megdnn_assert_internal(imgSizeY * imgSizeX == imgPixels); int numFiltersPerGroup = numFilters / numGroups; - int imgStride = images.getStride(); // images does not need to be a contiguous matrix + int imgStride = + images.getStride(); // images does not need to be a contiguous matrix int filterPixels = filters.getNumRows() / (filterModuleMult * numFilterColors); int filterSize = int(sqrt(filterPixels)); megdnn_assert_internal(filterSize * filterSize == filterPixels); - megdnn_assert_internal(filters.getNumRows() == filterModuleMult * numFilterColors * filterPixels); + megdnn_assert_internal( + filters.getNumRows() == filterModuleMult * numFilterColors * filterPixels); - // These routines don't handle the case when only part of the image is visited in the convolution + // These routines don't handle the case when only part of the image is visited in + // the convolution megdnn_assert_internal(paddingStart <= 0); - megdnn_assert_internal(paddingStart + (numModulesX-1)*moduleStride + filterSize >= imgSizeX); - megdnn_assert_internal(paddingStart + (numModulesY-1)*moduleStride + filterSize >= imgSizeY); + megdnn_assert_internal( + paddingStart + (numModulesX - 1) * moduleStride + filterSize >= imgSizeX); + megdnn_assert_internal( + paddingStart + (numModulesY - 1) * moduleStride + filterSize >= imgSizeY); megdnn_assert_internal(moduleStride <= filterSize); megdnn_assert_internal(!images.isTrans()); @@ -681,19 +795,29 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im int imgsPerThread = numImages % 128 == 0 ? 4 : numImages % 64 == 0 ? 2 : 1; int filtersPerThread, threadsY = 4; if (numImgColors <= 3) { - // Special kernels written for colors = 3, filters = 64 and colors = 3, filters = 48 cases. - // The remaining cases use the old routines. + // Special kernels written for colors = 3, filters = 64 and colors = 3, filters + // = 48 cases. The remaining cases use the old routines. // TODO: Modernize the remaining cases if you care about them. - filtersPerThread = numFiltersPerGroup % 64 == 0 ? 16 : numFiltersPerGroup % 48 == 0 ? 12 : numFiltersPerGroup % 32 == 0 ? 8 : 4; + filtersPerThread = numFiltersPerGroup % 64 == 0 ? 16 + : numFiltersPerGroup % 48 == 0 ? 12 + : numFiltersPerGroup % 32 == 0 ? 8 + : 4; } else { - filtersPerThread = numFiltersPerGroup % 64 == 0 ? 16 : numFiltersPerGroup % 32 == 0 ? 8 : 4; - threadsY = numFiltersPerGroup % 128 == 0 && numFilterColors % 8 == 0 && imgsPerThread != 4 ? 8 : 4; + filtersPerThread = numFiltersPerGroup % 64 == 0 ? 16 + : numFiltersPerGroup % 32 == 0 ? 
8 + : 4; + threadsY = numFiltersPerGroup % 128 == 0 && numFilterColors % 8 == 0 && + imgsPerThread != 4 + ? 8 + : 4; } int threadsX = 32; dim3 threads(threadsX, threadsY); - dim3 blocks = dim3(DIVUP(numImages, threads.x * imgsPerThread), numModules * DIVUP(numFilters, (threads.y * filtersPerThread))); + dim3 blocks = + dim3(DIVUP(numImages, threads.x * imgsPerThread), + numModules * DIVUP(numFilters, (threads.y * filtersPerThread))); - bool checkImgBounds = numImages % (threads.x*imgsPerThread) != 0; + bool checkImgBounds = numImages % (threads.x * imgsPerThread) != 0; bool scale = scaleTargets != 0; if (scaleTargets == 0) { targets.resize(numFilters * numModules, numImages); @@ -703,8 +827,9 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im } // Auto-generated calling code... - // NOTE: The calling code is set up such that if checkImgBounds is true, then imgsPerThread = 1. - // In principle it doesn't have to be this way, and you may want to optimize for that case. + // NOTE: The calling code is set up such that if checkImgBounds is true, then + // imgsPerThread = 1. In principle it doesn't have to be this way, and you may want + // to optimize for that case. if (scale == false) { if (checkImgBounds == false) { @@ -713,826 +838,2175 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im if (numFiltersPerGroup % 128 == 0) { if (previous_limit) { if (images.getNumDataBytes() < TEXTURE_SIZE_MAX) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, false, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, false, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, false, false> + <<>>( + images.getDevData(), + filters.getDevData(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 
< 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2< + 4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, numImgColors, + numGroups, scaleTargets, scaleOutput, conv); } - } - else if (numFiltersPerGroup % 64 == 0) { + } else if (numFiltersPerGroup % 64 == 0) { if (previous_limit) { if (images.getNumDataBytes() < TEXTURE_SIZE_MAX) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, false, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, false, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, false, false> + <<>>( + images.getDevData(), + filters.getDevData(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2< + 4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferShared); + 
filterActs_YxX_sparse2<4, 32, 4, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, numImgColors, + numGroups, scaleTargets, scaleOutput, conv); } - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 8, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 8, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 4, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 4, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 8, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 8, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 4, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 4, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 8, 32, 2, 16, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 8, 32, 2, 16, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 16, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 16, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 8, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 8, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, 
paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 4, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 4, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<8, 32, 2, 16, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<8, 32, 2, 16, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 16, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 16, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 8, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 8, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 4, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 4, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 
32, 1, 8, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<8, 32, 1, 16, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<8, 32, 1, 16, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 8, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 8, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors % 4 == 0) { + } else if (numFilterColors % 4 == 0) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), 
numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 8, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 8, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 4, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 4, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 16, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 8, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 8, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 4, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 4, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - 
cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 8, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 8, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 4, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 4, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 16, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 16, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 8, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 8, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 4, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 4, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), 
targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 3) { + } else if (numFilterColors == 3) { if (numImages % 
128 == 0) { if (numFiltersPerGroup % 64 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex < 4, 32, 4, 16, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex < 4, 32, 4, 16, 3, 4, false, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(),numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex< + 4, 32, 4, 16, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex< + 4, 32, 4, 16, 3, 4, false, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 16, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 16, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color< + 4, 32, 4, 16, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 16, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } - } - else if (numFiltersPerGroup % 48 == 0) { + } else if (numFiltersPerGroup % 48 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex < 4, 32, 4, 12, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex < 4, 32, 4, 12, 3, 4, false, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(),numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex< + 4, 32, 4, 12, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex< + 4, 32, 4, 12, 3, 4, false, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 12, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 12, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color< + 4, 32, 4, 12, 3, 4, false, false>, + cudaFuncCachePreferShared); + 
filterActs_YxX_color<4, 32, 4, 12, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 8, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 8, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 4, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 4, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 8, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 8, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 4, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 4, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 16, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 16, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 12, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 12, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 8, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 8, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - 
cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 4, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 4, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 16, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 16, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 12, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 12, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 8, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 8, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 4, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 4, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - 
} - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 3, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 3, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 3, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 3, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 2) { + } else if (numFilterColors == 2) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 16, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 16, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 12, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 12, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 8, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 8, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, 
numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 4, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 4, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 16, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 16, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 12, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 12, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 8, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 8, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 4, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 4, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 16, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 16, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 12, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 12, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 8, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 8, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, 
imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 4, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 4, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 16, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 16, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 12, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 12, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 8, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 8, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 4, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 4, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), 
targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 2, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 2, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 2, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 2, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 1) { + } else if (numFilterColors == 1) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 16, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 16, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 12, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 12, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 8, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 
8, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 4, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 4, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 16, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 16, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 12, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 12, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 8, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 8, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 4, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 4, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 16, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 16, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 12, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 12, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 8, 1, 4, false, false >, 
cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 8, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 4, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 4, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 16, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 16, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 12, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 12, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 8, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 8, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 4, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 4, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - 
cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 1, 4, false, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 1, 4, false, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 1, 4, false, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 1, 4, false, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } } - } - else if (checkImgBounds == true) { + } else if (checkImgBounds == true) { if (numFilterColors % 8 == 0) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, 
moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 8, 8, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 8, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<8, 32, 1, 16, 8, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<8, 32, 1, 16, 8, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 8, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 8, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 8, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 8, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 8, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 8, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors % 4 == 0) { + } else if (numFilterColors % 4 == 0) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 
32, 1, 16, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 3) { + } else if (numFilterColors == 3) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 3, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 3, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, 
imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 3, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 3, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 3, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 3, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 3, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 3, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 3, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 3, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 3, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 3, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 3, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 3, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 3, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 3, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 2) { + } else if (numFilterColors == 2) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 2, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 2, 4, false, true > <<>>(images.getDevData(), 
filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 2, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 2, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 2, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 2, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 2, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 2, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 2, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 2, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 2, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 2, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 2, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 2, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 2, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 2, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 1) { + } else if (numFilterColors == 1) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 1, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 
1, 16, 1, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 1, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 1, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 1, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 1, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 1, 4, false, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 1, 4, false, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 1, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 1, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 1, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 1, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 1, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 1, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 1, 4, false, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 1, 4, false, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } } } - } - else if (scale == true) { + } else if (scale == true) { if (checkImgBounds == false) { if (numFilterColors % 8 == 0) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 128 == 0) { if (previous_limit) { if 
(images.getNumDataBytes() < TEXTURE_SIZE_MAX) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, true, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, true, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, true, false> + <<>>( + images.getDevData(), + filters.getDevData(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2< + 4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, numImgColors, + numGroups, scaleTargets, scaleOutput, conv); } - } - else if (numFiltersPerGroup % 64 == 0) { + } else if (numFiltersPerGroup % 64 == 0) { if (previous_limit) { if (images.getNumDataBytes() < TEXTURE_SIZE_MAX) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex < 4, 32, 4, 16, 4, true, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, 
scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, true, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferL1); - filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4 < 4, 32, 4, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferL1); + filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4< + 4, 32, 4, 16, 4, true, false> + <<>>( + images.getDevData(), + filters.getDevData(), + targets.getDevData(), numImages, + numFilters, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + numModulesY, numModulesX, imgStride, + numImgColors, numGroups, scaleTargets, + scaleOutput, conv); } } else { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2< + 4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, numImgColors, + numGroups, scaleTargets, scaleOutput, conv); } - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 8, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 8, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 4, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 4, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 8, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 
32, 4, 8, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 4, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 4, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 8, 32, 2, 16, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 8, 32, 2, 16, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 16, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 16, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 8, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 8, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 4, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 4, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<8, 32, 2, 16, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<8, 32, 2, 16, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 16, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 16, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, 
numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 8, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 8, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 4, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 4, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 8, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<8, 32, 1, 16, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<8, 32, 1, 16, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 8, true, false> + <<>>( + images.getDevData(), 
filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 8, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 8, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors % 4 == 0) { + } else if (numFilterColors % 4 == 0) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 8, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 8, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 4, 4, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 4, 4, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, 
numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 8, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 8, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 4, 4, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 4, 4, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 8, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 8, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 2, 4, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 2, 4, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 
16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 8, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 8, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 2, 4, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 2, 4, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, 
scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 3) { + } else if (numFilterColors == 3) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 64 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex < 4, 32, 4, 16, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex < 4, 32, 4, 16, 3, 4, true, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(),numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex< + 4, 32, 4, 16, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_16_px_4_cc_3_tex< + 4, 32, 4, 16, 3, 4, true, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 16, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 16, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color< + 4, 32, 4, 16, 3, 4, true, false>, + 
cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 16, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } - } - else if (numFiltersPerGroup % 48 == 0) { + } else if (numFiltersPerGroup % 48 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex < 4, 32, 4, 12, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex < 4, 32, 4, 12, 3, 4, true, false > <<>>(images.getTextureObject(), filters.getTextureObject(), targets.getDevData(),numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex< + 4, 32, 4, 12, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color_preload_ty_4_tx_32_i_4_f_12_px_4_cc_3_tex< + 4, 32, 4, 12, 3, 4, true, false> + <<>>( + images.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } else { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 12, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 12, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color< + 4, 32, 4, 12, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 12, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, + paddingStart, moduleStride, numModulesY, + numModulesX, imgStride, scaleTargets, + scaleOutput, conv); } - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 8, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 8, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 4, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 4, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 8, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 8, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, 
numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 4, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 4, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 16, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 16, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 12, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 12, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 8, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 8, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 4, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 4, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 16, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 16, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 12, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 12, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 8, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 8, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, 
paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 4, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 4, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 3, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 3, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, 
paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 3, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 3, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 2) { + } else if (numFilterColors == 2) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 16, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 16, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 12, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 12, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 8, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 8, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 4, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 4, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 16, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 16, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 12, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 12, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 8, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 8, 2, 4, true, false> + <<>>( + images.getDevData(), 
filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 4, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 4, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 16, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 16, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 12, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 12, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 8, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 8, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 4, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 4, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 16, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 16, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 12, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 12, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 8, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 8, 2, 4, true, false> + 
<<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 4, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 4, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 2, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 2, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 2, 4, true, false> + <<>>( + 
images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 2, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 2, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 1) { + } else if (numFilterColors == 1) { if (numImages % 128 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 16, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 16, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 12, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 12, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 8, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 8, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 4, 4, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 4, 4, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 16, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 16, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 12, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 12, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 8, 1, 4, true, 
false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 8, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 4, 4, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 4, 4, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 64 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 16, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 16, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 12, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 12, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 8, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 8, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 2, 4, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 2, 4, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - } - else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 16, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 16, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 12, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 12, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + 
filterActs_YxX_color<4, 32, 2, 8, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 8, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 2, 4, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 2, 4, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } + } else if (numImages % 32 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 1, 4, true, false >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 1, 4, true, false > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + 
filterActs_YxX_color<4, 32, 1, 8, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 1, 4, true, false>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 1, 4, true, false> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } } - } - else if (checkImgBounds == true) { + } else if (checkImgBounds == true) { if (numFilterColors % 8 == 0) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 8, 32, 1, 16, 8, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 8, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 8, 8, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 8, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 8, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<8, 32, 1, 16, 8, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<8, 32, 1, 16, 8, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 8, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 8, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, 
imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 8, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 8, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 8, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 8, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors % 4 == 0) { + } else if (numFilterColors % 4 == 0) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 16, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 8, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_sparse2 < 4, 32, 1, 4, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, numImgColors, numGroups, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + 
filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 16, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 8, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 8, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_sparse2<4, 32, 1, 4, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_sparse2<4, 32, 1, 4, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, numImgColors, numGroups, + scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 3) { + } else if (numFilterColors == 3) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 3, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 3, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 3, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 3, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 3, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 3, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 3, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 3, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 3, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 3, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + 
imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 3, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 3, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 3, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 3, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 3, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 3, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 2) { + } else if (numFilterColors == 2) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 2, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 2, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 2, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 2, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 2, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 2, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 2, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 2, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 2, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 2, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + 
moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 2, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 2, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 2, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 2, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 2, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 2, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } - } - else if (numFilterColors == 1) { + } else if (numFilterColors == 1) { if (numImages % 1 == 0) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 16, 1, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 16, 1, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 12, 1, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 12, 1, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 8, 1, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 8, 1, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(filterActs_YxX_color < 4, 32, 1, 4, 1, 4, true, true >, cudaFuncCachePreferShared); - filterActs_YxX_color < 4, 32, 1, 4, 1, 4, true, true > <<>>(images.getDevData(), filters.getDevData(), targets.getDevData(), numImages, numFilters, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, numModulesY, numModulesX, imgStride, scaleTargets, scaleOutput, conv); + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 16, 1, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 16, 1, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, 
imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 12, 1, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 12, 1, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 8, 1, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 8, 1, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + filterActs_YxX_color<4, 32, 1, 4, 1, 4, true, true>, + cudaFuncCachePreferShared); + filterActs_YxX_color<4, 32, 1, 4, 1, 4, true, true> + <<>>( + images.getDevData(), filters.getDevData(), + targets.getDevData(), numImages, numFilters, + imgSizeY, imgSizeX, filterSize, paddingStart, + moduleStride, numModulesY, numModulesX, + imgStride, scaleTargets, scaleOutput, conv); } } } @@ -1542,31 +3016,45 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4(float* im getLastCudaError("filterActs: kernel execution failed"); } -void convFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups) { - convFilterActs(stream, images, filters, targets, imgSizeY, numModulesY, numModulesX, paddingStart, moduleStride, numImgColors, numGroups, 0, 1); +void convFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups) { + convFilterActs( + stream, images, filters, targets, imgSizeY, numModulesY, numModulesX, + paddingStart, moduleStride, numImgColors, numGroups, 0, 1); } -void convFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups, - float scaleTargets, float scaleOutput) { - _filterActs(stream, images, filters, targets, imgSizeY, numModulesY, numModulesX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput, true); +void convFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups, float scaleTargets, + float scaleOutput) { + _filterActs( + stream, images, filters, targets, imgSizeY, numModulesY, numModulesX, + paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, + scaleOutput, true); } -void localFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups) { - localFilterActs(stream, images, filters, targets, imgSizeY, 
numModulesY, numModulesX, paddingStart, moduleStride, numImgColors, numGroups, 0, 1); +void localFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups) { + localFilterActs( + stream, images, filters, targets, imgSizeY, numModulesY, numModulesX, + paddingStart, moduleStride, numImgColors, numGroups, 0, 1); } -void localFilterActs(cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int paddingStart, int moduleStride, - int numImgColors, int numGroups, - float scaleTargets, float scaleOutput) { - _filterActs(stream, images, filters, targets, imgSizeY, numModulesY, numModulesX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput, false); +void localFilterActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int paddingStart, + int moduleStride, int numImgColors, int numGroups, float scaleTargets, + float scaleOutput) { + _filterActs( + stream, images, filters, targets, imgSizeY, numModulesY, numModulesX, + paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, + scaleOutput, false); } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color.cuh b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color.cuh index 69ad2ee7..cfd5bae3 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_templates.cuh" @@ -34,9 +35,8 @@ namespace megdnn { namespace cuda { /* - * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * imgsPerThread images. - * threadIdx.x determines image - * threadIdx.y determines filter + * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * + * imgsPerThread images. threadIdx.x determines image threadIdx.y determines filter * * blockIdx.x determines image batch of B_X * imgsPerThread * blockIdx.y determines filter batch of module and B_Y * filtersPerThread @@ -54,17 +54,25 @@ namespace cuda { * The imgSize here is the size of the actual image without the padding. 
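As a hedged illustration of the blocking scheme described in the comment above (this sketch is not part of the diff; the helper name divup and the concrete sizes are assumptions chosen to mirror the surrounding code), the grid shape works out as follows: blockIdx.x indexes batches of B_X * imgsPerThread images, and blockIdx.y folds the module index together with batches of B_Y * filtersPerThread filters.

#include <cstdio>

static unsigned divup(unsigned a, unsigned b) { return (a + b - 1) / b; }

static void print_color_launch_shape(
        unsigned numImages, unsigned numFilters, unsigned numModulesY,
        unsigned numModulesX, unsigned B_X, unsigned B_Y, unsigned imgsPerThread,
        unsigned filtersPerThread) {
    // blockIdx.x: batches of B_X * imgsPerThread images
    unsigned gridX = divup(numImages, B_X * imgsPerThread);
    // blockIdx.y: one entry per (module, batch of B_Y * filtersPerThread filters)
    unsigned gridY = numModulesY * numModulesX * divup(numFilters, B_Y * filtersPerThread);
    std::printf("blocks=(%u,%u) threads=(%u,%u)\n", gridX, gridY, B_X, B_Y);
}

int main() {
    // Illustrative values only, matching the <4, 32, 1, 16, ...> tile shape.
    print_color_launch_shape(128, 64, 6, 6, 32, 4, 1, 16);
    return 0;
}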
* */ - template +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int numColors, + int pixelCache, bool scale, bool checkImgBounds> __global__ void filterActs_YxX_color(FILTER_COLOR_PARAMS) { - __shared__ float shFilters[pixelCache*numColors][B_Y * filtersPerThread]; // pre-load pixelCache pixels from B_Y*filtersPerThread filters - __shared__ float shImages[pixelCache*numColors][B_X * imgsPerThread]; // pre-load pixelCache pixels from B_X*imgsPerThread images - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); + __shared__ float + shFilters[pixelCache * numColors] + [B_Y * filtersPerThread]; // pre-load pixelCache pixels from + // B_Y*filtersPerThread filters + __shared__ float shImages + [pixelCache * numColors] + [B_X * + imgsPerThread]; // pre-load pixelCache pixels from B_X*imgsPerThread images + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); __syncthreads(); const int imgPixels = imgSizeY * imgSizeX; const int filterPixels = filterSize * filterSize; - const int blocksPerModule = DIVUP(numFilters, (B_Y*filtersPerThread)); + const int blocksPerModule = DIVUP(numFilters, (B_Y * filtersPerThread)); const int moduleIdx = blockIdx.y / blocksPerModule; const int blockFilterIdx = filtersPerThread * B_Y * (blockIdx.y % blocksPerModule); @@ -77,66 +85,69 @@ __global__ void filterActs_YxX_color(FILTER_COLOR_PARAMS) { const int shFilterLoadX = tidx % (B_Y * filtersPerThread); const int myImgIdx = blockIdx.x * B_X * imgsPerThread + threadIdx.x; images += myImgIdx; - filters += blockFilterIdx - + shFilterLoadY * numFilters + shFilterLoadX; + filters += blockFilterIdx + shFilterLoadY * numFilters + shFilterLoadX; if (!conv) { filters += moduleIdx * numColors * filterPixels * numFilters; } bool active_thread_y = (blockFilterIdx + shFilterLoadX) < numFilters; - targets += moduleIdx * numImages - + myImgIdx - + (blockFilterIdx + threadIdx.y*filtersPerThread) * numImages * numModulesY * numModulesX; - + targets += moduleIdx * numImages + myImgIdx + + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * + numModulesY * numModulesX; float prod[filtersPerThread][imgsPerThread]; - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - #pragma unroll - for(int g = 0; g < imgsPerThread; g++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { +#pragma unroll + for (int g = 0; g < imgsPerThread; g++) { prod[f][g] = 0; } } - //float* shImgLoad = &shImages[0][threadIdx.x]; + // float* shImgLoad = &shImages[0][threadIdx.x]; for (int p = 0; p < filterPixels; p += pixelCache) { /* * Load pixelCache pixels from B_Y*filtersPerThread filters * This condition covers the case when B_X is not divisible by filtersPerThread. - * In this case, not all of the threads will participate in the loading operation. - * This ensures that in each loop iteration, an integer number of rows of shFilters - * are filled, which makes indexing simple. + * In this case, not all of the threads will participate in the loading + * operation. This ensures that in each loop iteration, an integer number of + * rows of shFilters are filled, which makes indexing simple. 
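A small CPU-side sketch (illustrative values, not code from this diff) of the index arithmetic the kernel sets up before its main loop: the flat thread id tidx is split into a row/column pair for staging filters into shared memory, and blockIdx.y is split into a module index plus a filter-batch offset via blocksPerModule.

#include <cstdio>

int main() {
    const int B_X = 32, B_Y = 4, filtersPerThread = 8;
    const int numFilters = 96;
    // DIVUP(numFilters, B_Y * filtersPerThread), written out here
    const int blocksPerModule =
            (numFilters + B_Y * filtersPerThread - 1) / (B_Y * filtersPerThread);

    // How a few flat thread ids map onto the shFilters staging tile.
    for (int tidx = 0; tidx < 3; ++tidx) {
        int shFilterLoadY = tidx / (B_Y * filtersPerThread);  // preloaded pixel row
        int shFilterLoadX = tidx % (B_Y * filtersPerThread);  // filter column
        std::printf("tidx=%d -> shFilterLoadY=%d shFilterLoadX=%d\n",
                    tidx, shFilterLoadY, shFilterLoadX);
    }

    // How blockIdx.y splits into (moduleIdx, blockFilterIdx).
    for (int by = 0; by < 4; ++by) {
        int moduleIdx = by / blocksPerModule;
        int blockFilterIdx = filtersPerThread * B_Y * (by % blocksPerModule);
        std::printf("blockIdx.y=%d -> moduleIdx=%d blockFilterIdx=%d\n",
                    by, moduleIdx, blockFilterIdx);
    }
    return 0;
}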
*/ - if (B_X % filtersPerThread == 0 || shFilterLoadY < B_X/filtersPerThread) { - #pragma unroll - for (int p2 = 0; p2 < pixelCache; p2 += B_X/filtersPerThread) { + if (B_X % filtersPerThread == 0 || shFilterLoadY < B_X / filtersPerThread) { +#pragma unroll + for (int p2 = 0; p2 < pixelCache; p2 += B_X / filtersPerThread) { const bool omit = pixelCache % (B_X / filtersPerThread) == 0; const int preloadPx = shFilterLoadY + p2; if (omit || preloadPx < pixelCache) { if (p + preloadPx < filterPixels && active_thread_y) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shFilters[shFilterLoadY + p2 + c * pixelCache][shFilterLoadX] = filters[(c * filterPixels + p + p2) * numFilters]; + shFilters[shFilterLoadY + p2 + c * pixelCache] + [shFilterLoadX] = + filters[(c * filterPixels + p + p2) * + numFilters]; } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shFilters[shFilterLoadY + p2 + c * pixelCache][shFilterLoadX] = 0; + shFilters[shFilterLoadY + p2 + c * pixelCache] + [shFilterLoadX] = 0; } } } } } - /* - * Load pixelCache pixels from B_X*imgsPerThread images. - */ - #pragma unroll +/* + * Load pixelCache pixels from B_X*imgsPerThread images. + */ +#pragma unroll for (int ly = 0; ly < pixelCache; ly += B_Y) { const int preloadPx = ly + threadIdx.y; const int pixIdx = p + preloadPx; - const bool omit = pixelCache % B_Y == 0; // Compile-time condition + const bool omit = pixelCache % B_Y == 0; // Compile-time condition /* - * Don't load any image pixels corresponding to filter pixels that don't exist. + * Don't load any image pixels corresponding to filter pixels that don't + * exist. */ if (pixIdx < filterPixels && (omit || preloadPx < pixelCache)) { const int x = imgLoadModPosX + pixIdx % filterSize; @@ -145,23 +156,27 @@ __global__ void filterActs_YxX_color(FILTER_COLOR_PARAMS) { if (y >= 0 && y < imgSizeY && x >= 0 && x < imgSizeX) { float* m = &images[imgStride * (y * imgSizeX + x)]; - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - shImages[preloadPx + c * pixelCache][threadIdx.x * imgsPerThread + i] = m[c * imgStride * imgPixels + i * B_X]; + shImages[preloadPx + c * pixelCache] + [threadIdx.x * imgsPerThread + i] = + m[c * imgStride * imgPixels + i * B_X]; } else { - shImages[preloadPx + c * pixelCache][threadIdx.x * imgsPerThread + i] = 0; + shImages[preloadPx + c * pixelCache] + [threadIdx.x * imgsPerThread + i] = 0; } } } - } else { // Padding - #pragma unroll + } else { // Padding +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[preloadPx + c * pixelCache][threadIdx.x * imgsPerThread + i] = 0; + shImages[preloadPx + c * pixelCache] + [threadIdx.x * imgsPerThread + i] = 0; } } } @@ -170,45 +185,49 @@ __global__ void filterActs_YxX_color(FILTER_COLOR_PARAMS) { __syncthreads(); - #pragma unroll - for (int i = 0; i < pixelCache*numColors; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - #pragma unroll - for(int g = 0; g < imgsPerThread; g++) { - prod[f][g] += shImages[i][g + threadIdx.x * imgsPerThread] - * shFilters[i][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < pixelCache * numColors; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { +#pragma unroll + for (int g = 0; g < imgsPerThread; g++) { + prod[f][g] += 
shImages[i][g + threadIdx.x * imgsPerThread] * + shFilters[i][threadIdx.y * filtersPerThread + f]; } } } __syncthreads(); } - int filtersThisThread = numFilters - blockFilterIdx - threadIdx.y * filtersPerThread; + int filtersThisThread = + numFilters - blockFilterIdx - threadIdx.y * filtersPerThread; if (filtersThisThread > filtersPerThread) { filtersThisThread = filtersPerThread; } - //active_thread_y = (blockFilterIdx + threadIdx.y * filtersPerThread) < numFilters; + // active_thread_y = (blockFilterIdx + threadIdx.y * filtersPerThread) < numFilters; if (scale) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersThisThread; f++) { - #pragma unroll +#pragma unroll for (int g = 0; g < imgsPerThread; g++) { if (!checkImgBounds || myImgIdx + g * B_X < numImages) { targets[g * B_X + f * numImages * numModules] = - scaleTargets * targets[g * B_X + f * numImages * numModules] + scaleOutputs * prod[f][g]; + scaleTargets * + targets[g * B_X + f * numImages * numModules] + + scaleOutputs * prod[f][g]; } } } } else { - #pragma unroll +#pragma unroll for (int g = 0; g < imgsPerThread; g++) { if (!checkImgBounds || myImgIdx + g * B_X < numImages) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersThisThread; f++) { - //if (active_thread_y) { - targets[g * B_X + f * numImages * numModules] = scaleOutputs * prod[f][g]; + // if (active_thread_y) { + targets[g * B_X + f * numImages * numModules] = + scaleOutputs * prod[f][g]; //} } } @@ -216,55 +235,54 @@ __global__ void filterActs_YxX_color(FILTER_COLOR_PARAMS) { } } - #define FILTER_COLOR_HEAD template __global__ void filterActs_YxX_color -#define FILTER_COLOR(scale, ckImg) \ -FILTER_COLOR_HEAD < 4, 32, 4, 8, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 4, 4, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 2, 16, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 12, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 8, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 4, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 1, 16, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 12, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 8, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 4, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 4, 16, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 4, 12, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 4, 8, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 4, 4, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 2, 16, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 12, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 8, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 4, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 1, 16, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 12, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 8, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 4, 2, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 4, 16, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ 
-FILTER_COLOR_HEAD < 4, 32, 4, 12, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 4, 8, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 4, 4, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 2, 16, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 12, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 8, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 2, 4, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ - \ -FILTER_COLOR_HEAD < 4, 32, 1, 16, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 12, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 8, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 1, 4, 1, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -\ -FILTER_COLOR_HEAD < 4, 32, 4, 16, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ -FILTER_COLOR_HEAD < 4, 32, 4, 12, 3, 4, scale, ckImg > (FILTER_COLOR_PARAMS); \ +#define FILTER_COLOR(scale, ckImg) \ + FILTER_COLOR_HEAD<4, 32, 4, 8, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 4, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 2, 16, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 12, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 8, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 4, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 1, 16, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 12, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 8, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 4, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 4, 16, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 12, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 8, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 4, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 2, 16, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 12, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 8, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 4, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 1, 16, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 12, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 8, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 4, 2, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 4, 16, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 12, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 8, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 4, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 2, 16, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 12, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 8, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 2, 4, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 1, 16, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 
32, 1, 12, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 8, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 1, 4, 1, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + \ + FILTER_COLOR_HEAD<4, 32, 4, 16, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); \ + FILTER_COLOR_HEAD<4, 32, 4, 12, 3, 4, scale, ckImg>(FILTER_COLOR_PARAMS); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg0.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg0.cu index 540e3626..8cbab6d3 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg0.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg0.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_color.cuh" @@ -33,5 +34,5 @@ namespace megdnn { namespace cuda { FILTER_COLOR(false, false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg1.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg1.cu index fb64b2e1..d595fb5a 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg1.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale0_ckimg1.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_color.cuh" @@ -33,5 +34,5 @@ namespace megdnn { namespace cuda { FILTER_COLOR(false, true) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg0.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg0.cu index 402ccd1e..f6c04587 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg0.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg0.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "filter_act_color.cuh" @@ -33,5 +34,5 @@ namespace megdnn { namespace cuda { FILTER_COLOR(true, false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg1.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg1.cu index 4c3edc9c..9f5ce491 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg1.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_color_scale1_ckimg1.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_color.cuh" @@ -33,5 +34,5 @@ namespace megdnn { namespace cuda { FILTER_COLOR(true, true) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2.cuh b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2.cuh index 5d708568..5d5727d6 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_templates.cuh" @@ -34,9 +35,8 @@ namespace megdnn { namespace cuda { /* - * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * imgsPerThread images. - * threadIdx.x determines image - * threadIdx.y determines filter + * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * + * imgsPerThread images. threadIdx.x determines image threadIdx.y determines filter * * blockIdx.x determines image batch of B_X * imgsPerThread * blockIdx.y determines filter batch of B_Y * filtersPerThread @@ -60,30 +60,35 @@ namespace cuda { * numFilters must be divisible by numGroups. * no restrictions on pixelCache * The imgSize here is the size of the actual image without the padding. - * As always, try to make B_X * imgsPerThread == B_Y * filtersPerThread for maximum efficiency. + * As always, try to make B_X * imgsPerThread == B_Y * filtersPerThread for maximum + * efficiency. 
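The efficiency note in the comment above can be restated as a compile-time predicate. The following is a hedged sketch (not part of the diff) of how one could check that a given instantiation keeps B_X * imgsPerThread equal to B_Y * filtersPerThread, i.e. that the image tile and the filter tile stage the same number of values per iteration.

// Sketch only: a constexpr restatement of the balance rule quoted in the comment.
template <int B_Y, int B_X, int imgsPerThread, int filtersPerThread>
constexpr bool tile_is_balanced() {
    return B_X * imgsPerThread == B_Y * filtersPerThread;
}

// Shapes that appear in the instantiation lists below:
static_assert(tile_is_balanced<4, 32, 1, 8>(), "32 image values vs 32 filter values");
static_assert(tile_is_balanced<4, 32, 2, 16>(), "64 image values vs 64 filter values");
static_assert(!tile_is_balanced<4, 32, 1, 4>(), "image-heavy tile: 32 vs 16");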
* */ -template -__global__ void filterActs_YxX_sparse2(float* images, float* filters, float* targets, - const int numImages, const int numFilters, - const int imgSizeY, const int imgSizeX, - const int filterSize, const int paddingStart, - const int moduleStride, - const int numModulesY, const int numModulesX, - const int imgStride, const int numImgColors, - const int numGroups, - const float scaleTargets, const float scaleOutputs, - const bool conv) { - __shared__ float shFilters[colorCache][B_Y * filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters - __shared__ float shImages[colorCache][B_X * imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int colorCache, + bool scale, bool checkImgBounds> +__global__ void filterActs_YxX_sparse2( + float* images, float* filters, float* targets, const int numImages, + const int numFilters, const int imgSizeY, const int imgSizeX, + const int filterSize, const int paddingStart, const int moduleStride, + const int numModulesY, const int numModulesX, const int imgStride, + const int numImgColors, const int numGroups, const float scaleTargets, + const float scaleOutputs, const bool conv) { + __shared__ float shFilters + [colorCache] + [B_Y * + filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters + __shared__ float + shImages[colorCache] + [B_X * + imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); __syncthreads(); const int imgPixels = imgSizeY * imgSizeX; const int filterPixels = filterSize * filterSize; const int numFilterColors = numImgColors / numGroups; - const int blocksPerModule = DIVUP(numFilters, (B_Y*filtersPerThread)); + const int blocksPerModule = DIVUP(numFilters, (B_Y * filtersPerThread)); const int moduleIdx = blockIdx.y / blocksPerModule; const int blockFilterIdx = filtersPerThread * B_Y * (blockIdx.y % blocksPerModule); const int numFiltersPerGroup = numFilters / numGroups; @@ -102,22 +107,21 @@ __global__ void filterActs_YxX_sparse2(float* images, float* filters, float* tar const int myImgIdx = blockIdx.x * B_X * imgsPerThread + threadIdx.x; images += (blockColorIdx + threadIdx.y) * imgPixels * imgStride + myImgIdx; - filters +=blockFilterIdx + shFilterLoadX - + shFilterLoadY * numFilters * filterPixels; + filters += + blockFilterIdx + shFilterLoadX + shFilterLoadY * numFilters * filterPixels; if (!conv) { filters += moduleIdx * numFilterColors * filterPixels * numFilters; } bool active_thread_y = (blockFilterIdx + shFilterLoadX) < numFilters; - targets += moduleIdx * numImages - + (blockFilterIdx + threadIdx.y) * numImages * numModules - + myImgIdx; + targets += moduleIdx * numImages + + (blockFilterIdx + threadIdx.y) * numImages * numModules + myImgIdx; float prod[filtersPerThread][imgsPerThread]; - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - #pragma unroll - for(int g = 0; g < imgsPerThread; g++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { +#pragma unroll + for (int g = 0; g < imgsPerThread; g++) { prod[f][g] = 0; } } @@ -125,33 +129,42 @@ __global__ void filterActs_YxX_sparse2(float* images, float* filters, float* tar const int imgStartY = MAX(0, imgLoadModPosY); 
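Right before the main accumulation loop the kernel clamps each module's filter window to the real (unpadded) image. Below is a hedged host-side sketch of that clamping (assumed names, not code from this diff); it mirrors the imgStartY/imgEndY and imgStartX/imgEndX computations.

#include <algorithm>

struct PixelRange { int begin, end; };  // half-open [begin, end)

// Only the part of the filter window that lands inside the image is visited.
static PixelRange clamp_filter_window(int imgLoadModPos, int filterSize, int imgSize) {
    PixelRange r;
    r.begin = std::max(0, imgLoadModPos);                    // imgStart
    r.end = std::min(imgLoadModPos + filterSize, imgSize);   // imgEnd
    return r;
}

// Example: paddingStart = -2, filterSize = 5, imgSize = 24 gives rows [0, 3)
// for the first module, since two filter rows hang over the top padding.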
const int imgEndX = MIN(imgLoadModPosX + filterSize, imgSizeX); const int imgEndY = MIN(imgLoadModPosY + filterSize, imgSizeY); -// __shared__ int imgPos[] + // __shared__ int imgPos[] for (int imgY = imgStartY; imgY < imgEndY; ++imgY) { const int filterPxY = imgY - imgLoadModPosY; for (int imgX = imgStartX; imgX < imgEndX; ++imgX) { const int filterPxX = imgX - imgLoadModPosX; const int p = filterPxY * filterSize + filterPxX; - for (int oc = 0; oc < numFilterColors; oc += colorCache) { // oc stands for outer color (loop) + for (int oc = 0; oc < numFilterColors; + oc += colorCache) { // oc stands for outer color (loop) /* * Load a pixel from B_Y*filtersPerThread filters - * This condition covers the case when B_X is not divisible by filtersPerThread. - * In this case, not all of the threads will participate in the loading operation. - * This ensures that in each loop iteration, an integer number of rows of shFilters + * This condition covers the case when B_X is not divisible by + filtersPerThread. + * In this case, not all of the threads will participate in the loading + operation. + * This ensures that in each loop iteration, an integer number of rows + of shFilters * are filled, which makes indexing simple. - * nvcc is behaving in a completely insane way: removing this condition under + * nvcc is behaving in a completely insane way: removing this condition + under * template parameters that guarantee it to be true actually slows down * the computation. * */ - if (/*B_X % filtersPerThread == 0 ||*/ shFilterLoadY < B_X/filtersPerThread) { - #pragma unroll - for (int c = 0; c < colorCache; c += B_X/filtersPerThread) { - if (colorCache % (B_X/filtersPerThread) == 0 || c + shFilterLoadY < colorCache) { + if (/*B_X % filtersPerThread == 0 ||*/ shFilterLoadY < + B_X / filtersPerThread) { +#pragma unroll + for (int c = 0; c < colorCache; c += B_X / filtersPerThread) { + if (colorCache % (B_X / filtersPerThread) == 0 || + c + shFilterLoadY < colorCache) { if (active_thread_y) { - shFilters[c + shFilterLoadY][shFilterLoadX] = filters[((oc+c) * filterPixels + p) * numFilters]; + shFilters[c + shFilterLoadY][shFilterLoadX] = + filters[((oc + c) * filterPixels + p) * + numFilters]; } else { shFilters[c + shFilterLoadY][shFilterLoadX] = 0; } @@ -162,16 +175,17 @@ __global__ void filterActs_YxX_sparse2(float* images, float* filters, float* tar /* * Load a pixel from B_X*imgsPerThread images. 
*/ - const int pixIdx = imgY * imgSizeX + imgX;// Pixel index in img + const int pixIdx = imgY * imgSizeX + imgX; // Pixel index in img float* m = &images[imgStride * (oc * imgPixels + pixIdx)]; - #pragma unroll +#pragma unroll for (int c = 0; c < colorCache; c += B_Y) { if (colorCache % B_Y == 0 || threadIdx.y + c < colorCache) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - shImages[c + threadIdx.y][threadIdx.x + i * B_X] = m[c * imgStride * imgPixels + i * B_X]; + shImages[c + threadIdx.y][threadIdx.x + i * B_X] = + m[c * imgStride * imgPixels + i * B_X]; } else { shImages[c + threadIdx.y][threadIdx.x + i * B_X] = 0; } @@ -182,11 +196,12 @@ __global__ void filterActs_YxX_sparse2(float* images, float* filters, float* tar __syncthreads(); for (int c = 0; c < colorCache; c++) { - #pragma unroll - for(int g = 0; g < imgsPerThread; g++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[f][g] += shImages[c][g * B_X + threadIdx.x] * shFilters[c][threadIdx.y + f * B_Y]; +#pragma unroll + for (int g = 0; g < imgsPerThread; g++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[f][g] += shImages[c][g * B_X + threadIdx.x] * + shFilters[c][threadIdx.y + f * B_Y]; } } } @@ -196,32 +211,37 @@ __global__ void filterActs_YxX_sparse2(float* images, float* filters, float* tar } int filtersThisThread = filtersPerThread; - //if(checkFilterBounds) { - int filtersThisBlock = numFilters - (blockIdx.y % blocksPerModule) - * (B_Y*filtersPerThread); - if (filtersThisBlock < (B_Y * filtersPerThread)) { - filtersThisThread = (filtersThisBlock - threadIdx.y + filtersPerThread - 1) / filtersPerThread; - } + // if(checkFilterBounds) { + int filtersThisBlock = + numFilters - (blockIdx.y % blocksPerModule) * (B_Y * filtersPerThread); + if (filtersThisBlock < (B_Y * filtersPerThread)) { + filtersThisThread = (filtersThisBlock - threadIdx.y + filtersPerThread - 1) / + filtersPerThread; + } //} if (scale) { - #pragma unroll +#pragma unroll for (int g = 0; g < imgsPerThread; g++) { if (!checkImgBounds || myImgIdx + g * B_X < numImages) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersThisThread; f++) { - targets[g * B_X + f * B_Y * numImages * numModules] = scaleTargets * targets[g * B_X + f * B_Y * numImages * numModules] + scaleOutputs * prod[f][g]; + targets[g * B_X + f * B_Y * numImages * numModules] = + scaleTargets * targets[g * B_X + + f * B_Y * numImages * numModules] + + scaleOutputs * prod[f][g]; } } } } else { - // Note: reversing order of these loops saves 2 registers, but costs time - #pragma unroll +// Note: reversing order of these loops saves 2 registers, but costs time +#pragma unroll for (int f = 0; f < filtersThisThread; f++) { - #pragma unroll +#pragma unroll for (int g = 0; g < imgsPerThread; g++) { if (!checkImgBounds || myImgIdx + g * B_X < numImages) { - targets[g * B_X + f * B_Y * numImages * numModules] = scaleOutputs * prod[f][g]; + targets[g * B_X + f * B_Y * numImages * numModules] = + scaleOutputs * prod[f][g]; } } } @@ -231,31 +251,31 @@ __global__ void filterActs_YxX_sparse2(float* images, float* filters, float* tar #define FILTER_SPARSE2_HEAD template __global__ void filterActs_YxX_sparse2 // -#define FILTER_SPARSE2(scale, ckImg) \ -FILTER_SPARSE2_HEAD < 4, 32, 4, 8, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 4, 4, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -\ -FILTER_SPARSE2_HEAD < 8, 32, 2, 16, 8, scale, ckImg > 
(FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 2, 16, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 2, 8, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 2, 4, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -\ -FILTER_SPARSE2_HEAD < 8, 32, 1, 16, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 1, 16, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 1, 8, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 1, 4, 8, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -\ -FILTER_SPARSE2_HEAD < 4, 32, 4, 16, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 4, 8, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 4, 4, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -\ -FILTER_SPARSE2_HEAD < 4, 32, 2, 16, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 2, 8, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 2, 4, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -\ -FILTER_SPARSE2_HEAD < 4, 32, 1, 16, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 1, 8, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); \ -FILTER_SPARSE2_HEAD < 4, 32, 1, 4, 4, scale, ckImg > (FILTER_SPARSE2_PARAMS); +#define FILTER_SPARSE2(scale, ckImg) \ + FILTER_SPARSE2_HEAD<4, 32, 4, 8, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 4, 4, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + \ + FILTER_SPARSE2_HEAD<8, 32, 2, 16, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 2, 16, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 2, 8, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 2, 4, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + \ + FILTER_SPARSE2_HEAD<8, 32, 1, 16, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 1, 16, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 1, 8, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 1, 4, 8, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + \ + FILTER_SPARSE2_HEAD<4, 32, 4, 16, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 4, 8, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 4, 4, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + \ + FILTER_SPARSE2_HEAD<4, 32, 2, 16, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 2, 8, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 2, 4, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + \ + FILTER_SPARSE2_HEAD<4, 32, 1, 16, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 1, 8, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); \ + FILTER_SPARSE2_HEAD<4, 32, 1, 4, 4, scale, ckImg>(FILTER_SPARSE2_PARAMS); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg0.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg0.cu index 7776823f..4870604b 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg0.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg0.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). 
- * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_sparse2.cuh" @@ -34,6 +35,5 @@ namespace cuda { FILTER_SPARSE2(false, false) - -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg1.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg1.cu index 8aed4de9..eff2a9de 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg1.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale0_ckimg1.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_sparse2.cuh" @@ -34,6 +35,5 @@ namespace cuda { FILTER_SPARSE2(false, true) - -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg0.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg0.cu index 5f979d15..7e42cd2c 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg0.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg0.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_sparse2.cuh" @@ -34,6 +35,5 @@ namespace cuda { FILTER_SPARSE2(true, false) - -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg1.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg1.cu index a69bd797..1eb64d1a 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg1.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_scale1_ckimg1.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "filter_act_sparse2.cuh" @@ -34,6 +35,5 @@ namespace cuda { FILTER_SPARSE2(true, true) - -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_y4x32i4f16c4_tex.cu b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_y4x32i4f16c4_tex.cu index 3e49d026..7c16af09 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_y4x32i4f16c4_tex.cu +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_y4x32i4f16c4_tex.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_y4x32i4f16c4_tex.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_sparse2_y4x32i4f16c4_tex.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "filter_act_templates.cuh" @@ -33,18 +35,26 @@ namespace megdnn { namespace cuda { -template -__global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILTER_ACTS_PARAMS) { - __shared__ float shFilters[colorCache][B_Y * filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters - __shared__ float shImages[colorCache][B_X * imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int colorCache, + bool scale, bool checkImgBounds> +__global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex( + FILTER_ACTS_PARAMS) { + __shared__ float shFilters + [colorCache] + [B_Y * + filtersPerThread]; // pre-load 1 pixel from B_Y*filtersPerThread filters + __shared__ float + shImages[colorCache] + [B_X * + imgsPerThread]; // pre-load 1 pixel from B_X*imgsPerThread images + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); __syncthreads(); const int imgPixels = imgSizeY * imgSizeX; const int filterPixels = filterSize * filterSize; const int numFilterColors = numImgColors / numGroups; - const int blocksPerModule = numFilters / (B_Y*filtersPerThread); + const int blocksPerModule = numFilters / (B_Y * filtersPerThread); const int moduleIdx = blockIdx.y / blocksPerModule; const int blockFilterIdx = filtersPerThread * B_Y * (blockIdx.y % blocksPerModule); const int numFiltersPerGroup = numFilters / numGroups; @@ -64,27 +74,30 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILT const int shFilterLoadY = tidx / (B_Y * filtersPerThread); const int shFilterLoadX = tidx % (B_Y * filtersPerThread); const int myImgIdx = blockIdx.x * B_X * imgsPerThread + threadIdx.x; - const int imgOffset = (blockColorIdx + threadIdx.y) * imgPixels * imgStride + myImgIdx; + const int imgOffset = + 
(blockColorIdx + threadIdx.y) * imgPixels * imgStride + myImgIdx; -// images += (blockColorIdx + threadIdx.y) * imgPixels * imgStride + myImgIdx; - const int filterOffset = blockFilterIdx - + shFilterLoadY * numFilters * filterPixels + shFilterLoadX + (conv ? 0 : moduleIdx * numFilterColors * filterPixels * numFilters); -// filters +=blockFilterIdx -// + shFilterLoadY * numFilters * filterPixels + shFilterLoadX; -// if (!conv) { -// filters += moduleIdx * numFilterColors * filterPixels * numFilters; -// } + // images += (blockColorIdx + threadIdx.y) * imgPixels * imgStride + myImgIdx; + const int filterOffset = + blockFilterIdx + shFilterLoadY * numFilters * filterPixels + shFilterLoadX + + (conv ? 0 : moduleIdx * numFilterColors * filterPixels * numFilters); + // filters +=blockFilterIdx + // + shFilterLoadY * numFilters * filterPixels + shFilterLoadX; + // if (!conv) { + // filters += moduleIdx * numFilterColors * filterPixels * numFilters; + // } - targets += moduleIdx * numImages - + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules - + myImgIdx; + targets += + moduleIdx * numImages + + (blockFilterIdx + threadIdx.y * filtersPerThread) * numImages * numModules + + myImgIdx; float prod[imgsPerThread][filtersPerThread]; -// float fCache[filtersPerThread]; - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { + // float fCache[filtersPerThread]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { prod[i][f] = 0; } } @@ -93,37 +106,43 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILT const int imgStartY = max(0, imgLoadModPosY); const int imgEndX = min(imgLoadModPosX + filterSize, imgSizeX); const int imgEndY = min(imgLoadModPosY + filterSize, imgSizeY); -// __shared__ int imgPos[] + // __shared__ int imgPos[] int fPidx, iPidx; - float imPreload[imgsPerThread]; // [4] - float fPreload[colorCache*filtersPerThread/B_X]; // [2] -// float fCache[filtersPerThread]; + float imPreload[imgsPerThread]; // [4] + float fPreload[colorCache * filtersPerThread / B_X]; // [2] + // float fCache[filtersPerThread]; - filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgStartY, imgStartX, fPidx, iPidx); + filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords( + filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgStartY, imgStartX, + fPidx, iPidx); - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - imPreload[i] = tex1Dfetch(images, imgOffset + imgStride * iPidx + i * B_X); + imPreload[i] = + tex1Dfetch(images, imgOffset + imgStride * iPidx + i * B_X); } else { imPreload[i] = 0; } } - if (/*B_X % filtersPerThread == 0 ||*/ shFilterLoadY < B_X/filtersPerThread) { // This if statement reduces reg usage.. - #pragma unroll - for (int c = 0; c < colorCache; c += B_X/filtersPerThread) { - fPreload[c*filtersPerThread/B_X] = tex1Dfetch(filters, filterOffset + (c * filterPixels + fPidx) * numFilters); + if (/*B_X % filtersPerThread == 0 ||*/ shFilterLoadY < + B_X / filtersPerThread) { // This if statement reduces reg usage.. 
+#pragma unroll + for (int c = 0; c < colorCache; c += B_X / filtersPerThread) { + fPreload[c * filtersPerThread / B_X] = tex1Dfetch( + filters, filterOffset + (c * filterPixels + fPidx) * numFilters); } } for (int imgY = imgStartY; imgY < imgEndY; ++imgY) { -// const int filterPxY = imgY - imgLoadModPosY; + // const int filterPxY = imgY - imgLoadModPosY; for (int imgX = imgStartX; imgX < imgEndX; ++imgX) { -// const int filterPxX = imgX - imgLoadModPosX; -// const int p = filterPxY * filterSize + filterPxX; -// const int pixIdx = imgY * imgSizeX + imgX;// Pixel index in img -// setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgY, imgX, &p, &pixIdx); -// float* m = &images[imgStride * pixIdx]; + // const int filterPxX = imgX - imgLoadModPosX; + // const int p = filterPxY * filterSize + filterPxX; + // const int pixIdx = imgY * imgSizeX + imgX;// Pixel index in + // img setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, + // imgLoadModPosX, imgY, imgX, &p, &pixIdx); float* m = + // &images[imgStride * pixIdx]; const bool lastPixel = imgY == imgEndY - 1 && imgX == imgEndX - 1; int imgYNext = imgY; int imgXNext = imgX; @@ -132,12 +151,20 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILT imgYNext = imgY + (imgX + 1 == imgEndX); imgXNext = imgX + 1 == imgEndX ? imgStartX : imgX + 1; } - filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords(filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgYNext, imgXNext, fPidxNext, iPidxNext); - for (int oc = 0; oc < numFilterColors; oc += colorCache) { // oc stands for outer color (loop) -// const float* ff = &filters[numFilters * ((oc + colorCache) * filterPixels + fPidx)]; -// const float* mm = &images[imgStride * ((oc + colorCache) * imgPixels + iPidx)]; - int imgOffset2 = imgOffset + imgStride * ((oc + colorCache) * imgPixels + iPidx); - int filterOffset2 = filterOffset + numFilters * ((oc + colorCache) * filterPixels + fPidx); + filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords( + filterSize, imgSizeX, imgLoadModPosY, imgLoadModPosX, imgYNext, + imgXNext, fPidxNext, iPidxNext); + for (int oc = 0; oc < numFilterColors; + oc += colorCache) { // oc stands for outer color (loop) + // const float* ff = &filters[numFilters * ((oc + + // colorCache) * filterPixels + fPidx)]; const float* mm + // = &images[imgStride * ((oc + colorCache) * imgPixels + + // iPidx)]; + int imgOffset2 = + imgOffset + imgStride * ((oc + colorCache) * imgPixels + iPidx); + int filterOffset2 = + filterOffset + + numFilters * ((oc + colorCache) * filterPixels + fPidx); if (oc == numFilterColors - colorCache) { filterOffset2 = filterOffset + fPidxNext * numFilters; imgOffset2 = imgOffset + iPidxNext * imgStride; @@ -145,57 +172,73 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILT iPidx = iPidxNext; } - #pragma unroll - for (int c = 0; c < colorCache; c += B_X/filtersPerThread) { - shFilters[c + shFilterLoadY][shFilterLoadX] = fPreload[c*filtersPerThread/B_X]; +#pragma unroll + for (int c = 0; c < colorCache; c += B_X / filtersPerThread) { + shFilters[c + shFilterLoadY][shFilterLoadX] = + fPreload[c * filtersPerThread / B_X]; } - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { // NOTE: bank conflicts here! shImages[ty][tx * imgsPerThread + i] = imPreload[i]; } - imPreload[0] = (checkImgBounds && myImgIdx + 0 * B_X >= numImages) ? 
0 : tex1Dfetch(images, imgOffset2 + 0 * B_X); - imPreload[1] = (checkImgBounds && myImgIdx + 1 * B_X >= numImages) ? 0 : tex1Dfetch(images, imgOffset2 + 1 * B_X); - imPreload[2] = (checkImgBounds && myImgIdx + 2 * B_X >= numImages) ? 0 : tex1Dfetch(images, imgOffset2 + 2 * B_X); + imPreload[0] = (checkImgBounds && myImgIdx + 0 * B_X >= numImages) + ? 0 + : tex1Dfetch(images, imgOffset2 + 0 * B_X); + imPreload[1] = (checkImgBounds && myImgIdx + 1 * B_X >= numImages) + ? 0 + : tex1Dfetch(images, imgOffset2 + 1 * B_X); + imPreload[2] = (checkImgBounds && myImgIdx + 2 * B_X >= numImages) + ? 0 + : tex1Dfetch(images, imgOffset2 + 2 * B_X); __syncthreads(); - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[0][threadIdx.x * imgsPerThread + i] * shFilters[0][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[0][threadIdx.x * imgsPerThread + i] * + shFilters[0][threadIdx.y * filtersPerThread + f]; } } fPreload[0] = tex1Dfetch(filters, filterOffset2 + 0); - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[1][threadIdx.x * imgsPerThread + i] * shFilters[1][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[1][threadIdx.x * imgsPerThread + i] * + shFilters[1][threadIdx.y * filtersPerThread + f]; } } - fPreload[1] = tex1Dfetch(filters, filterOffset2 + (B_X/filtersPerThread * filterPixels) * numFilters); + fPreload[1] = tex1Dfetch( + filters, + filterOffset2 + + (B_X / filtersPerThread * filterPixels) * numFilters); - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[2][threadIdx.x * imgsPerThread + i] * shFilters[2][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[2][threadIdx.x * imgsPerThread + i] * + shFilters[2][threadIdx.y * filtersPerThread + f]; } } - imPreload[3] = (checkImgBounds && myImgIdx + 3 * B_X >= numImages) ? 0 : tex1Dfetch(images, imgOffset2 + 3 * B_X); + imPreload[3] = (checkImgBounds && myImgIdx + 3 * B_X >= numImages) + ? 
0 + : tex1Dfetch(images, imgOffset2 + 3 * B_X); - #pragma unroll - for(int i = 0; i < imgsPerThread; i++) { - #pragma unroll - for(int f = 0; f < filtersPerThread; f++) { - prod[i][f] += shImages[3][threadIdx.x * imgsPerThread + i] * shFilters[3][threadIdx.y * filtersPerThread + f]; +#pragma unroll + for (int i = 0; i < imgsPerThread; i++) { +#pragma unroll + for (int f = 0; f < filtersPerThread; f++) { + prod[i][f] += shImages[3][threadIdx.x * imgsPerThread + i] * + shFilters[3][threadIdx.y * filtersPerThread + f]; } } __syncthreads(); @@ -204,36 +247,38 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILT } if (scale) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleTargets * targets[i * B_X + f * numImages * numModules] + scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleTargets * + targets[i * B_X + f * numImages * numModules] + + scaleOutputs * prod[i][f]; } } } } else { - // Note: reversing order of these loops saves 2 registers, but costs time - #pragma unroll +// Note: reversing order of these loops saves 2 registers, but costs time +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if (!checkImgBounds || myImgIdx + i * B_X < numImages) { - targets[i * B_X + f * numImages * numModules] = scaleOutputs * prod[i][f]; + targets[i * B_X + f * numImages * numModules] = + scaleOutputs * prod[i][f]; } } } } } -template __global__ void -filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex -< 4, 32, 4, 16, 4, false, false >(FILTER_ACTS_PARAMS); +template __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, false, false>(FILTER_ACTS_PARAMS); -template __global__ void -filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex -< 4, 32, 4, 16, 4, true, false >(FILTER_ACTS_PARAMS); +template __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex< + 4, 32, 4, 16, 4, true, false>(FILTER_ACTS_PARAMS); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_templates.cuh b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_templates.cuh index 88f76bc7..1735334f 100644 --- a/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_templates.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/filter_acts/filter_act_templates.cuh @@ -25,38 +25,34 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ -#include "../nvmatrix.cuh" #include "../cudaconv2.cuh" +#include "../nvmatrix.cuh" #include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { -__device__ inline void - filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords - (int filterSize, int imgSizeX, - int imgLoadModPosY, int imgLoadModPosX, - int imgY, int imgX, int& fPidx, int& iPidx) { +__device__ inline void filterActs_YxX_sparse2_preload_ty_4_tx_32_f_16_c_4_setPixelCoords( + int filterSize, int imgSizeX, int imgLoadModPosY, int imgLoadModPosX, int imgY, + int imgX, int& fPidx, int& iPidx) { int filterPxY = imgY - imgLoadModPosY; int filterPxX = imgX - imgLoadModPosX; fPidx = filterPxY * filterSize + filterPxX; - iPidx = imgY * imgSizeX + imgX; // Pixel index in img + iPidx = imgY * imgSizeX + imgX; // Pixel index in img } -#define FILTER_ACTS_PARAMS cudaTextureObject_t images, \ - cudaTextureObject_t filters, float* targets, \ - const int numImages, const int numFilters, \ - const int imgSizeY, const int imgSizeX, \ - const int filterSize, const int paddingStart, \ - const int moduleStride, \ - const int numModulesY, const int numModulesX, \ - const int imgStride, const int numImgColors, \ - const int numGroups, \ - const float scaleTargets, const float scaleOutputs, \ - const bool conv/*, const bool noloads*/ +#define FILTER_ACTS_PARAMS \ + cudaTextureObject_t images, cudaTextureObject_t filters, float *targets, \ + const int numImages, const int numFilters, const int imgSizeY, \ + const int imgSizeX, const int filterSize, const int paddingStart, \ + const int moduleStride, const int numModulesY, const int numModulesX, \ + const int imgStride, const int numImgColors, const int numGroups, \ + const float scaleTargets, const float scaleOutputs, \ + const bool conv /*, const bool noloads*/ /* * images: (numImgColors, imgSizeY, imgSizeX, numImages) with stride given * filters: (numFilterColors, filterPixels, numFilters) if conv @@ -65,25 +61,21 @@ __device__ inline void * targets: (numFilters, numModulesY, numModulesX, numImages) * */ -template -__global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILTER_ACTS_PARAMS); - - +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int colorCache, + bool scale, bool checkImgBounds> +__global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex( + FILTER_ACTS_PARAMS); -#define FILTER_COLOR_PARAMS float* images, float* filters, float* targets, \ - const int numImages, const int numFilters, \ - const int imgSizeY, const int imgSizeX, \ - const int filterSize, const int paddingStart, \ - const int moduleStride, \ - const int numModulesY, const int numModulesX, \ - const int imgStride, \ - const float scaleTargets, const float scaleOutputs, \ - const bool conv +#define FILTER_COLOR_PARAMS \ + float *images, float *filters, float *targets, const int numImages, \ + const int numFilters, const int imgSizeY, const int imgSizeX, \ + const int filterSize, const int paddingStart, const int moduleStride, \ + const int numModulesY, const int numModulesX, const int imgStride, \ + const float scaleTargets, const float scaleOutputs, const bool conv /* - * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * imgsPerThread images. - * threadIdx.x determines image - * threadIdx.y determines filter + * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * + * imgsPerThread images. 
threadIdx.x determines image threadIdx.y determines filter * * blockIdx.x determines image batch of B_X * imgsPerThread * blockIdx.y determines filter batch of module and B_Y * filtersPerThread @@ -101,26 +93,21 @@ __global__ void filterActs_YxX_sparse2_preload_ty_4_tx_32_i_4_f_16_c_4_tex (FILT * The imgSize here is the size of the actual image without the padding. * */ - template +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int numColors, + int pixelCache, bool scale, bool checkImgBounds> __global__ void filterActs_YxX_color(FILTER_COLOR_PARAMS); - - - -#define FILTER_SPARSE2_PARAMS float* images, float* filters, float* targets, \ - const int numImages, const int numFilters, \ - const int imgSizeY, const int imgSizeX, \ - const int filterSize, const int paddingStart, \ - const int moduleStride, \ - const int numModulesY, const int numModulesX, \ - const int imgStride, const int numImgColors, \ - const int numGroups, \ - const float scaleTargets, const float scaleOutputs, \ - const bool conv +#define FILTER_SPARSE2_PARAMS \ + float *images, float *filters, float *targets, const int numImages, \ + const int numFilters, const int imgSizeY, const int imgSizeX, \ + const int filterSize, const int paddingStart, const int moduleStride, \ + const int numModulesY, const int numModulesX, const int imgStride, \ + const int numImgColors, const int numGroups, const float scaleTargets, \ + const float scaleOutputs, const bool conv /* - * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * imgsPerThread images. - * threadIdx.x determines image - * threadIdx.y determines filter + * Block size B_YxB_X. Each block applies B_Y * filtersPerThread filters to B_X * + * imgsPerThread images. threadIdx.x determines image threadIdx.y determines filter * * blockIdx.x determines image batch of B_X * imgsPerThread * blockIdx.y determines filter batch of B_Y * filtersPerThread @@ -144,12 +131,14 @@ __global__ void filterActs_YxX_color(FILTER_COLOR_PARAMS); * numFilters must be divisible by numGroups. * no restrictions on pixelCache * The imgSize here is the size of the actual image without the padding. - * As always, try to make B_X * imgsPerThread == B_Y * filtersPerThread for maximum efficiency. + * As always, try to make B_X * imgsPerThread == B_Y * filtersPerThread for maximum + * efficiency. * */ -template +template < + int B_Y, int B_X, int imgsPerThread, int filtersPerThread, int colorCache, + bool scale, bool checkImgBounds> __global__ void filterActs_YxX_sparse2(FILTER_SPARSE2_PARAMS); -} // namespace megdnn -} // namespace cuda +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/helper_cuda.h b/dnn/src/cuda/local/cuda-convnet2/helper_cuda.h index 25425895..ffccf459 100644 --- a/dnn/src/cuda/local/cuda-convnet2/helper_cuda.h +++ b/dnn/src/cuda/local/cuda-convnet2/helper_cuda.h @@ -11,15 +11,16 @@ /** * \file src/cuda/local/cuda-convnet2/helper_cuda.h * - * This file is part of MegDNN, a deep neural network run-time library * developed by Megvii. + * This file is part of MegDNN, a deep neural network run-time library * developed by + * Megvii. * * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved. 
*/ #pragma once -#include "src/cuda/utils.cuh" #include -#define checkCudaErrors(x) cuda_check(x) +#include "src/cuda/utils.cuh" +#define checkCudaErrors(x) cuda_check(x) #define getLastCudaError(x) cuda_check(cudaGetLastError()) // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts.cu b/dnn/src/cuda/local/cuda-convnet2/img_acts.cu index f6ca4257..aae659f9 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts.cu +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts.cu @@ -25,14 +25,15 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "cudaconv2.cuh" -#include "nvmatrix.cuh" #include "img_acts/img_act_templates.cuh" +#include "nvmatrix.cuh" #ifdef _WIN32 #define _Pragma(x) @@ -44,78 +45,94 @@ namespace cuda { * New Titan-optimized stuff. */ -__device__ __forceinline__ void conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords(const int my, const int mx, const int numModulesX, - const int paddingStart, const int moduleStride, const int blockPixelIdxY, const int blockPixelIdxX, const int filterSize, int &moduleIdx, int &pxIdxInFilter) { +__device__ __forceinline__ void +conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords( + const int my, const int mx, const int numModulesX, const int paddingStart, + const int moduleStride, const int blockPixelIdxY, const int blockPixelIdxX, + const int filterSize, int& moduleIdx, int& pxIdxInFilter) { const int moduleTop = paddingStart + my * moduleStride; const int pxInFilterY = blockPixelIdxY - moduleTop; - moduleIdx = my * numModulesX + mx; // out + moduleIdx = my * numModulesX + mx; // out const int moduleLeft = paddingStart + mx * moduleStride; const int pxInFilterX = blockPixelIdxX - moduleLeft; - pxIdxInFilter = pxInFilterY * filterSize + pxInFilterX; // out + pxIdxInFilter = pxInFilterY * filterSize + pxInFilterX; // out } -#define IA_PRELOAD_LOOP(w,offset) _Pragma("unroll") \ -for (int i = 0; i < imgsPerThread; i++) { \ - _Pragma("unroll") \ - for (int c = 0; c < colorsPerThread; c++) { \ - prod[c][i] += shFilters[c * B_Y + threadIdx.y][(w)+(offset)] * shHidActs[w][threadIdx.x * imgsPerThread + i]; \ - } \ -} \ +#define IA_PRELOAD_LOOP(w, offset) \ + _Pragma("unroll") for (int i = 0; i < imgsPerThread; i++) { \ + _Pragma("unroll") for (int c = 0; c < colorsPerThread; c++) { \ + prod[c][i] += shFilters[c * B_Y + threadIdx.y][(w) + (offset)] * \ + shHidActs[w][threadIdx.x * imgsPerThread + i]; \ + } \ + } /* * Same loop as above but inverted. 
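 * (IA_PRELOAD_LOOP iterates images in the outer loop and colors in the inner one;
 * IA_PRELOAD_LOOP2 swaps that order.)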
*/ -#define IA_PRELOAD_LOOP2(w,offset) _Pragma("unroll") \ -for (int c = 0; c < colorsPerThread; c++) { \ - _Pragma("unroll") \ - for (int i = 0; i < imgsPerThread; i++) { \ - prod[c][i] += shFilters[c * B_Y + threadIdx.y][(w)+(offset)] * shHidActs[w][threadIdx.x * imgsPerThread + i]; \ - } \ -} \ - -#define IA_PRELOAD_LOOP3(i,offset) _Pragma("unroll") \ -for (int w = 0; w < filterCacheH; w++) { \ - _Pragma("unroll") \ - for (int c = 0; c < colorsPerThread; c++) { \ - prod[c][i] += shFilters[c * B_Y + threadIdx.y][(w)+(offset)] * shHidActs[w][threadIdx.x * imgsPerThread + i]; \ - } \ -} \ - -#define IA_PRELOAD_W(z) wPreload[z] = fLoad[(z) * B_X*B_Y/filterCacheF * filterPixels * numFilters]; -#define IA_PRELOAD_W_TX(z) wPreload[z] = tex1Dfetch(filters, filtersLoadOffset + (z) * B_X*B_Y/filterCacheF * filterPixels * numFilters); -#define IA_PRELOAD_H(y,x) if (!checkCaseBounds || myCaseIdx + (x) * B_X < numImages) { \ - hPreload[y][x] = hLoad[(y) * B_Y * numModules * numImages + (x) * B_X]; \ -} -#define IA_PRELOAD_H_TX(y,x) if (!checkCaseBounds || myCaseIdx + (x) * B_X < numImages) { \ - hPreload[y][x] = tex1Dfetch(hidActs, hidActsLoadOffset + (y) * B_Y * numModules * numImages + (x) * B_X); \ -} +#define IA_PRELOAD_LOOP2(w, offset) \ + _Pragma("unroll") for (int c = 0; c < colorsPerThread; c++) { \ + _Pragma("unroll") for (int i = 0; i < imgsPerThread; i++) { \ + prod[c][i] += shFilters[c * B_Y + threadIdx.y][(w) + (offset)] * \ + shHidActs[w][threadIdx.x * imgsPerThread + i]; \ + } \ + } -template -__global__ void -__launch_bounds__(256, 2) // 256 threads per block, 2 blocks per multiprocessor - // These launch bounds ensure 25% occupancy (128 registers used) - // as oppposed to 13% (130 registers) achieved by defaults. -conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex(cudaTextureObject_t hidActs, cudaTextureObject_t filters, float* targets, - const int numModulesY, const int numModulesX, const int numImages, const int numFilters, - const int filterSize, const int imgSizeY, const int imgSizeX, const int paddingStart, const int moduleStride, - const int numImgColors, const int numGroups, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shFilters[colorsPerThread*B_Y][filterCacheF]; - __shared__ float shHidActs[filterCacheH][B_X*imgsPerThread]; - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +#define IA_PRELOAD_LOOP3(i, offset) \ + _Pragma("unroll") for (int w = 0; w < filterCacheH; w++) { \ + _Pragma("unroll") for (int c = 0; c < colorsPerThread; c++) { \ + prod[c][i] += shFilters[c * B_Y + threadIdx.y][(w) + (offset)] * \ + shHidActs[w][threadIdx.x * imgsPerThread + i]; \ + } \ + } + +#define IA_PRELOAD_W(z) \ + wPreload[z] = fLoad[(z)*B_X * B_Y / filterCacheF * filterPixels * numFilters]; +#define IA_PRELOAD_W_TX(z) \ + wPreload[z] = tex1Dfetch( \ + filters, filtersLoadOffset + (z)*B_X * B_Y / filterCacheF * filterPixels * \ + numFilters); +#define IA_PRELOAD_H(y, x) \ + if (!checkCaseBounds || myCaseIdx + (x)*B_X < numImages) { \ + hPreload[y][x] = hLoad[(y)*B_Y * numModules * numImages + (x)*B_X]; \ + } +#define IA_PRELOAD_H_TX(y, x) \ + if (!checkCaseBounds || myCaseIdx + (x)*B_X < numImages) { \ + hPreload[y][x] = tex1Dfetch( \ + hidActs, \ + hidActsLoadOffset + (y)*B_Y * numModules * numImages + (x)*B_X); \ + } + +template < + int B_Y, int B_X, int imgsPerThread, int colorsPerThread, int filterCacheF, + int filterCacheH, bool scale, bool 
checkCaseBounds, bool conv> +__global__ void __launch_bounds__( + 256, 2) // 256 threads per block, 2 blocks per multiprocessor + // These launch bounds ensure 25% occupancy (128 registers used) + // as oppposed to 13% (130 registers) achieved by defaults. + conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex( + cudaTextureObject_t hidActs, cudaTextureObject_t filters, + float* targets, const int numModulesY, const int numModulesX, + const int numImages, const int numFilters, const int filterSize, + const int imgSizeY, const int imgSizeX, const int paddingStart, + const int moduleStride, const int numImgColors, const int numGroups, + const float scaleTargets, const float scaleOutputs) { + __shared__ float shFilters[colorsPerThread * B_Y][filterCacheF]; + __shared__ float shHidActs[filterCacheH][B_X * imgsPerThread]; + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); - const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread); - const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X*imgsPerThread; + const int numImgBlocks = DIVUP(numImages, B_X * imgsPerThread); + const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread; const int myCaseIdx = blockCaseIdx + threadIdx.x; - const int imgColorIdx = (blockIdx.x / numImgBlocks) * B_Y*colorsPerThread; // color idx globally + const int imgColorIdx = + (blockIdx.x / numImgBlocks) * B_Y * colorsPerThread; // color idx globally const int numFilterColors = numImgColors / numGroups; const int blockGroupIdx = imgColorIdx / numFilterColors; - const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group + const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group const int numFiltersPerGroup = numFilters / numGroups; const int blockFilterIdx = blockGroupIdx * numFiltersPerGroup; @@ -126,68 +143,89 @@ conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex(cudaTextureObje const int filterPixels = filterSize * filterSize; const int imgPixels = imgSizeY * imgSizeX; const int tidx = threadIdx.y * B_X + threadIdx.x; -// const int hidActLoadY = threadIdx.y % B_Y, hidActLoadX = threadIdx.x % B_X; - //const int hidActLoadY = tidx / (B_X*imgsPerThread), hidActLoadX = tidx % (B_X*imgsPerThread); + // const int hidActLoadY = threadIdx.y % B_Y, hidActLoadX = threadIdx.x % B_X; + // const int hidActLoadY = tidx / (B_X*imgsPerThread), hidActLoadX = tidx % + // (B_X*imgsPerThread); const int filtersLoadY = tidx / filterCacheF, filtersLoadX = tidx % filterCacheF; // nvcc is behaving idiotically again, these useless declarations save registers - //const int outputY = threadIdx.y, outputX = threadIdx.x; - //const int ty = threadIdx.y, tx = threadIdx.x; + // const int outputY = threadIdx.y, outputX = threadIdx.x; + // const int ty = threadIdx.y, tx = threadIdx.x; const int numModules = numModulesY * numModulesX; - const int hidActsOffset = (blockFilterIdx + threadIdx.y) * numImages * numModules + myCaseIdx; - const int filtersOffset = blockFilterIdx + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + filtersLoadX; -// hidActs += (blockFilterIdx + threadIdx.y) * numImages * numModules + myCaseIdx; -// filters += blockFilterIdx + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + filtersLoadX; - targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + blockPixelIdx * numImages + myCaseIdx; + const int hidActsOffset = + (blockFilterIdx + 
threadIdx.y) * numImages * numModules + myCaseIdx; + const int filtersOffset = + blockFilterIdx + + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + filtersLoadX; + // hidActs += (blockFilterIdx + threadIdx.y) * numImages * numModules + + // myCaseIdx; filters += blockFilterIdx + (filterColorIdx + filtersLoadY) * + // filterPixels * numFilters + filtersLoadX; + targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + + blockPixelIdx * numImages + myCaseIdx; float prod[colorsPerThread][imgsPerThread]; - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { prod[c][i] = 0; } } - const int startY = blockPixelIdxY - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; - const int endY = min(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); - const int startX = blockPixelIdxX - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; - const int endX = min(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); + const int startY = + blockPixelIdxY - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; + const int endY = + min(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); + const int startX = + blockPixelIdxX - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; + const int endX = + min(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); float* shFilterLoad = &shFilters[filtersLoadY][filtersLoadX]; float* shHidActLoad = &shHidActs[threadIdx.y][threadIdx.x * imgsPerThread]; - //const bool noFLoop = filterCacheF == filterCacheH; + // const bool noFLoop = filterCacheF == filterCacheH; /* * Initial preload */ - float hPreload[filterCacheH/B_Y][imgsPerThread]; // [2][4] - float wPreload[filterCacheF*colorsPerThread/B_X]; // [8] + float hPreload[filterCacheH / B_Y][imgsPerThread]; // [2][4] + float wPreload[filterCacheF * colorsPerThread / B_X]; // [8] int moduleIdx, pxIdxInFilter; - conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords(startY, startX, numModulesX, paddingStart, moduleStride, blockPixelIdxY, - blockPixelIdxX, filterSize, moduleIdx, pxIdxInFilter); -// const float* fLoad = conv ? &filters[pxIdxInFilter * numFilters + 0] -// : &filters[moduleIdx * numFilterColors * filterPixels * numFilters + pxIdxInFilter * numFilters + 0]; - int filtersLoadOffset = filtersOffset + (conv ? pxIdxInFilter * numFilters + 0 - : moduleIdx * numFilterColors * filterPixels * numFilters + pxIdxInFilter * numFilters); - #pragma unroll - for (int i = 0; i < colorsPerThread*B_Y; i+= B_X*B_Y/filterCacheF) { - if ((colorsPerThread*B_Y) % (B_X*B_Y/filterCacheF) == 0 || i + filtersLoadY < colorsPerThread*B_Y) { - wPreload[i * filterCacheF/(B_X*B_Y)] = tex1Dfetch(filters, filtersLoadOffset + i * filterPixels * numFilters); + conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords( + startY, startX, numModulesX, paddingStart, moduleStride, blockPixelIdxY, + blockPixelIdxX, filterSize, moduleIdx, pxIdxInFilter); + // const float* fLoad = conv ? &filters[pxIdxInFilter * numFilters + 0] + // : &filters[moduleIdx * numFilterColors * filterPixels + // * numFilters + pxIdxInFilter * numFilters + 0]; + int filtersLoadOffset = + filtersOffset + + (conv ? 
pxIdxInFilter * numFilters + 0 + : moduleIdx * numFilterColors * filterPixels * numFilters + + pxIdxInFilter * numFilters); +#pragma unroll + for (int i = 0; i < colorsPerThread * B_Y; i += B_X * B_Y / filterCacheF) { + if ((colorsPerThread * B_Y) % (B_X * B_Y / filterCacheF) == 0 || + i + filtersLoadY < colorsPerThread * B_Y) { + wPreload[i * filterCacheF / (B_X * B_Y)] = tex1Dfetch( + filters, filtersLoadOffset + i * filterPixels * numFilters); } } -// const float* hLoad = &hidActs[(moduleIdx + 0 * numModules) * numImages]; + // const float* hLoad = &hidActs[(moduleIdx + 0 * numModules) * numImages]; int hidActsLoadOffset = hidActsOffset + (moduleIdx + 0 * numModules) * numImages; - #pragma unroll +#pragma unroll for (int j = 0; j < filterCacheH; j += B_Y) { if (filterCacheH % B_Y == 0 || threadIdx.y + j < filterCacheH) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - hPreload[j/B_Y][i] = tex1Dfetch(hidActs, hidActsLoadOffset + j * numModules * numImages + i * B_X); + hPreload[j / B_Y][i] = tex1Dfetch( + hidActs, + hidActsLoadOffset + j * numModules * numImages + i * B_X); } } } @@ -209,31 +247,46 @@ conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex(cudaTextureObje mxNext = mx + 1 == endX ? startX : mx + 1; myNext = my + (mx + 1 == endX); } - conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords(myNext, mxNext, numModulesX, paddingStart, moduleStride, blockPixelIdxY, - blockPixelIdxX, filterSize, moduleIdxNext, pxIdxInFilterNext); - for (int f = 0; f < numFiltersPerGroup; f += filterCacheF) { // multiply with filterCacheF filters at a time - #pragma unroll - for (int i = 0; i < colorsPerThread*B_Y; i+= B_X*B_Y/filterCacheF) { - if ((colorsPerThread*B_Y) % (B_X*B_Y/filterCacheF) == 0 || i + filtersLoadY < colorsPerThread*B_Y) { - shFilterLoad[i * filterCacheF] = wPreload[i * filterCacheF/(B_X*B_Y)]; + conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords( + myNext, mxNext, numModulesX, paddingStart, moduleStride, + blockPixelIdxY, blockPixelIdxX, filterSize, moduleIdxNext, + pxIdxInFilterNext); + for (int f = 0; f < numFiltersPerGroup; + f += filterCacheF) { // multiply with filterCacheF filters at a time +#pragma unroll + for (int i = 0; i < colorsPerThread * B_Y; + i += B_X * B_Y / filterCacheF) { + if ((colorsPerThread * B_Y) % (B_X * B_Y / filterCacheF) == 0 || + i + filtersLoadY < colorsPerThread * B_Y) { + shFilterLoad[i * filterCacheF] = + wPreload[i * filterCacheF / (B_X * B_Y)]; } } - filtersLoadOffset = filtersOffset + (conv ? pxIdxInFilter * numFilters + f + filterCacheF - : moduleIdx * numFilterColors * filterPixels * numFilters + pxIdxInFilter * numFilters + f + filterCacheF); + filtersLoadOffset = + filtersOffset + + (conv ? pxIdxInFilter * numFilters + f + filterCacheF + : moduleIdx * numFilterColors * filterPixels * + numFilters + + pxIdxInFilter * numFilters + f + filterCacheF); if (f == numFiltersPerGroup - filterCacheF) { - filtersLoadOffset = filtersOffset + (conv ? pxIdxInFilterNext * numFilters - : moduleIdxNext * numFilterColors * filterPixels * numFilters + pxIdxInFilterNext * numFilters); + filtersLoadOffset = + filtersOffset + + (conv ? 
pxIdxInFilterNext * numFilters + : moduleIdxNext * numFilterColors * filterPixels * + numFilters + + pxIdxInFilterNext * numFilters); } - #pragma unroll +#pragma unroll for (int j = 0; j < filterCacheH; j += B_Y) { if (filterCacheH % B_Y == 0 || threadIdx.y + j < filterCacheH) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { // NOTE: bank conflicts here! if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - shHidActLoad[j * B_X * imgsPerThread + i] = hPreload[j/B_Y][i]; + shHidActLoad[j * B_X * imgsPerThread + i] = + hPreload[j / B_Y][i]; } } } @@ -241,34 +294,37 @@ conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex(cudaTextureObje __syncthreads(); - hidActsLoadOffset = hidActsOffset + (moduleIdx + (f + filterCacheH) * numModules) * numImages; + hidActsLoadOffset = + hidActsOffset + + (moduleIdx + (f + filterCacheH) * numModules) * numImages; - #pragma unroll +#pragma unroll for (int z = 0; z < 4; ++z) { - IA_PRELOAD_LOOP(z,0); + IA_PRELOAD_LOOP(z, 0); IA_PRELOAD_W_TX(z); } - #pragma unroll +#pragma unroll for (int z = 4; z < 12; ++z) { - IA_PRELOAD_LOOP(z,0); - IA_PRELOAD_H_TX((z-4)/4,z%4); + IA_PRELOAD_LOOP(z, 0); + IA_PRELOAD_H_TX((z - 4) / 4, z % 4); } - #pragma unroll +#pragma unroll for (int z = 12; z < 16; ++z) { - IA_PRELOAD_LOOP(z,0); + IA_PRELOAD_LOOP(z, 0); } __syncthreads(); - #pragma unroll +#pragma unroll for (int j = 0; j < filterCacheH; j += B_Y) { if (filterCacheH % B_Y == 0 || threadIdx.y + j < filterCacheH) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - shHidActLoad[j * B_X * imgsPerThread + i] = hPreload[j/B_Y][i]; + shHidActLoad[j * B_X * imgsPerThread + i] = + hPreload[j / B_Y][i]; } } } @@ -276,26 +332,28 @@ conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex(cudaTextureObje __syncthreads(); - hidActsLoadOffset = hidActsOffset + (moduleIdx + (f + filterCacheF) * numModules) * numImages; + hidActsLoadOffset = + hidActsOffset + + (moduleIdx + (f + filterCacheF) * numModules) * numImages; if (f == numFiltersPerGroup - filterCacheF) { hidActsLoadOffset = hidActsOffset + moduleIdxNext * numImages; } - #pragma unroll +#pragma unroll for (int z = 0; z < 4; ++z) { - IA_PRELOAD_LOOP(z,filterCacheH); - IA_PRELOAD_W_TX(z+4); + IA_PRELOAD_LOOP(z, filterCacheH); + IA_PRELOAD_W_TX(z + 4); } - #pragma unroll +#pragma unroll for (int z = 4; z < 12; ++z) { - IA_PRELOAD_LOOP(z,filterCacheH); - IA_PRELOAD_H_TX((z-4)/4, z%4); + IA_PRELOAD_LOOP(z, filterCacheH); + IA_PRELOAD_H_TX((z - 4) / 4, z % 4); } - #pragma unroll +#pragma unroll for (int z = 12; z < 16; ++z) { - IA_PRELOAD_LOOP(z,filterCacheH); + IA_PRELOAD_LOOP(z, filterCacheH); } __syncthreads(); @@ -303,51 +361,59 @@ conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex(cudaTextureObje } } if (scale) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleTargets * targets[c * B_Y * imgPixels * numImages + i * B_X] + scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleTargets * + targets[c * B_Y * imgPixels * numImages + i * B_X] + + scaleOutputs * prod[c][i]; } } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; 
i++) { if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleOutputs * prod[c][i]; } } } } } - -template +template < + int B_Y, int B_X, int imgsPerThread, int colorsPerThread, int filterCacheF, + int filterCacheH, bool scale, bool checkCaseBounds, bool conv> __global__ void //__launch_bounds__(128, 3) // 128 threads per block, 3 blocks per multiprocessor -conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16(cudaTextureObject_t hidActs, cudaTextureObject_t filters, float* targets, - const int numModulesY, const int numModulesX, const int numImages, const int numFilters, - const int filterSize, const int imgSizeY, const int imgSizeX, const int paddingStart, const int moduleStride, - const int numImgColors, const int numGroups, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shFilters[colorsPerThread*B_Y][filterCacheF]; - __shared__ float shHidActs[filterCacheH][B_X*imgsPerThread]; - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16( + cudaTextureObject_t hidActs, cudaTextureObject_t filters, float* targets, + const int numModulesY, const int numModulesX, const int numImages, + const int numFilters, const int filterSize, const int imgSizeY, + const int imgSizeX, const int paddingStart, const int moduleStride, + const int numImgColors, const int numGroups, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shFilters[colorsPerThread * B_Y][filterCacheF]; + __shared__ float shHidActs[filterCacheH][B_X * imgsPerThread]; + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); - const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread); - const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X*imgsPerThread; + const int numImgBlocks = DIVUP(numImages, B_X * imgsPerThread); + const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread; const int myCaseIdx = blockCaseIdx + threadIdx.x; - const int imgColorIdx = (blockIdx.x / numImgBlocks) * B_Y*colorsPerThread; // color idx globally + const int imgColorIdx = + (blockIdx.x / numImgBlocks) * B_Y * colorsPerThread; // color idx globally const int numFilterColors = numImgColors / numGroups; const int blockGroupIdx = imgColorIdx / numFilterColors; - const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group + const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group const int numFiltersPerGroup = numFilters / numGroups; const int blockFilterIdx = blockGroupIdx * numFiltersPerGroup; @@ -358,70 +424,91 @@ conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16(cudaTextureObject_ const int filterPixels = filterSize * filterSize; const int imgPixels = imgSizeY * imgSizeX; const int tidx = threadIdx.y * B_X + threadIdx.x; -// const int hidActLoadY = threadIdx.y % B_Y, hidActLoadX = threadIdx.x % B_X; - //const int hidActLoadY = tidx / (B_X*imgsPerThread), hidActLoadX = tidx % (B_X*imgsPerThread); + // const int hidActLoadY = threadIdx.y % B_Y, hidActLoadX = threadIdx.x % B_X; + // const int hidActLoadY = tidx / (B_X*imgsPerThread), hidActLoadX = tidx % + // (B_X*imgsPerThread); const int filtersLoadY = tidx / 
filterCacheF, filtersLoadX = tidx % filterCacheF; // nvcc is behaving idiotically again, these useless declarations save registers - //const int outputY = threadIdx.y, outputX = threadIdx.x; - //const int ty = threadIdx.y, tx = threadIdx.x; + // const int outputY = threadIdx.y, outputX = threadIdx.x; + // const int ty = threadIdx.y, tx = threadIdx.x; const int numModules = numModulesY * numModulesX; - const int hidActsOffset = (blockFilterIdx + threadIdx.y) * numImages * numModules + myCaseIdx; - const int filtersOffset = blockFilterIdx + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + filtersLoadX; + const int hidActsOffset = + (blockFilterIdx + threadIdx.y) * numImages * numModules + myCaseIdx; + const int filtersOffset = + blockFilterIdx + + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + filtersLoadX; -// hidActs += (blockFilterIdx + threadIdx.y) * numImages * numModules + myCaseIdx; -// filters += blockFilterIdx + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + filtersLoadX; - targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + blockPixelIdx * numImages + myCaseIdx; + // hidActs += (blockFilterIdx + threadIdx.y) * numImages * numModules + + // myCaseIdx; filters += blockFilterIdx + (filterColorIdx + filtersLoadY) * + // filterPixels * numFilters + filtersLoadX; + targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + + blockPixelIdx * numImages + myCaseIdx; float prod[colorsPerThread][imgsPerThread]; - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { prod[c][i] = 0; } } - const int startY = blockPixelIdxY - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; - const int endY = min(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); - const int startX = blockPixelIdxX - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; - const int endX = min(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); + const int startY = + blockPixelIdxY - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; + const int endY = + min(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); + const int startX = + blockPixelIdxX - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; + const int endX = + min(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); float* shFilterLoad = &shFilters[filtersLoadY][filtersLoadX]; float* shHidActLoad = &shHidActs[threadIdx.y][threadIdx.x * imgsPerThread]; - //const bool noFLoop = filterCacheF == filterCacheH; + // const bool noFLoop = filterCacheF == filterCacheH; /* * Initial preload */ - float hPreload[filterCacheH/B_Y][imgsPerThread]; // [4][4] - float wPreload[filterCacheF*colorsPerThread/B_X]; // [6] + float hPreload[filterCacheH / B_Y][imgsPerThread]; // [4][4] + float wPreload[filterCacheF * colorsPerThread / B_X]; // [6] int moduleIdx, pxIdxInFilter; - conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords(startY, startX, numModulesX, paddingStart, moduleStride, blockPixelIdxY, - blockPixelIdxX, filterSize, moduleIdx, pxIdxInFilter); -// const float* fLoad = conv ? 
&filters[pxIdxInFilter * numFilters + 0] -// : &filters[moduleIdx * numFilterColors * filterPixels * numFilters + pxIdxInFilter * numFilters + 0]; - int filtersLoadOffset = filtersOffset + (conv ? pxIdxInFilter * numFilters - : moduleIdx * numFilterColors * filterPixels * numFilters + pxIdxInFilter * numFilters); - #pragma unroll - for (int i = 0; i < colorsPerThread*B_Y; i+= B_X*B_Y/filterCacheF) { - if ((colorsPerThread*B_Y) % (B_X*B_Y/filterCacheF) == 0 || i + filtersLoadY < colorsPerThread*B_Y) { - wPreload[i * filterCacheF/(B_X*B_Y)] = tex1Dfetch(filters, filtersLoadOffset + i * filterPixels * numFilters); + conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords( + startY, startX, numModulesX, paddingStart, moduleStride, blockPixelIdxY, + blockPixelIdxX, filterSize, moduleIdx, pxIdxInFilter); + // const float* fLoad = conv ? &filters[pxIdxInFilter * numFilters + 0] + // : &filters[moduleIdx * numFilterColors * filterPixels + // * numFilters + pxIdxInFilter * numFilters + 0]; + int filtersLoadOffset = + filtersOffset + + (conv ? pxIdxInFilter * numFilters + : moduleIdx * numFilterColors * filterPixels * numFilters + + pxIdxInFilter * numFilters); +#pragma unroll + for (int i = 0; i < colorsPerThread * B_Y; i += B_X * B_Y / filterCacheF) { + if ((colorsPerThread * B_Y) % (B_X * B_Y / filterCacheF) == 0 || + i + filtersLoadY < colorsPerThread * B_Y) { + wPreload[i * filterCacheF / (B_X * B_Y)] = tex1Dfetch( + filters, filtersLoadOffset + i * filterPixels * numFilters); } } -// const float* hLoad = &hidActs[moduleIdx * numImages]; + // const float* hLoad = &hidActs[moduleIdx * numImages]; int hidActsLoadOffset = hidActsOffset + moduleIdx * numImages; - #pragma unroll +#pragma unroll for (int j = 0; j < filterCacheH; j += B_Y) { if (filterCacheH % B_Y == 0 || threadIdx.y + j < filterCacheH) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - hPreload[j/B_Y][i] = tex1Dfetch(hidActs, hidActsLoadOffset + j * numModules * numImages + i * B_X); + hPreload[j / B_Y][i] = tex1Dfetch( + hidActs, + hidActsLoadOffset + j * numModules * numImages + i * B_X); } } } @@ -443,36 +530,53 @@ conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16(cudaTextureObject_ mxNext = mx + 1 == endX ? 
startX : mx + 1; myNext = my + (mx + 1 == endX); } - conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords(myNext, mxNext, numModulesX, paddingStart, moduleStride, blockPixelIdxY, - blockPixelIdxX, filterSize, moduleIdxNext, pxIdxInFilterNext); - for (int f = 0; f < numFiltersPerGroup; f += filterCacheF) { // multiply with filterCacheF filters at a time - #pragma unroll - for (int i = 0; i < colorsPerThread*B_Y; i+= B_X*B_Y/filterCacheF) { - if ((colorsPerThread*B_Y) % (B_X*B_Y/filterCacheF) == 0 || i + filtersLoadY < colorsPerThread*B_Y) { - shFilterLoad[i * filterCacheF] = wPreload[i * filterCacheF/(B_X*B_Y)]; + conv_img_acts_manycolor_preload_ty_8_tx_32_c_8_ff_32_fh_16_setCoords( + myNext, mxNext, numModulesX, paddingStart, moduleStride, + blockPixelIdxY, blockPixelIdxX, filterSize, moduleIdxNext, + pxIdxInFilterNext); + for (int f = 0; f < numFiltersPerGroup; + f += filterCacheF) { // multiply with filterCacheF filters at a time +#pragma unroll + for (int i = 0; i < colorsPerThread * B_Y; + i += B_X * B_Y / filterCacheF) { + if ((colorsPerThread * B_Y) % (B_X * B_Y / filterCacheF) == 0 || + i + filtersLoadY < colorsPerThread * B_Y) { + shFilterLoad[i * filterCacheF] = + wPreload[i * filterCacheF / (B_X * B_Y)]; } } - filtersLoadOffset = filtersOffset + (conv ? pxIdxInFilter * numFilters + f + filterCacheF - : moduleIdx * numFilterColors * filterPixels * numFilters + pxIdxInFilter * numFilters + f + filterCacheF); + filtersLoadOffset = + filtersOffset + + (conv ? pxIdxInFilter * numFilters + f + filterCacheF + : moduleIdx * numFilterColors * filterPixels * + numFilters + + pxIdxInFilter * numFilters + f + filterCacheF); if (f == numFiltersPerGroup - filterCacheF) { - filtersLoadOffset = filtersOffset + (conv ? pxIdxInFilterNext * numFilters - : moduleIdxNext * numFilterColors * filterPixels * numFilters + pxIdxInFilterNext * numFilters); + filtersLoadOffset = + filtersOffset + + (conv ? pxIdxInFilterNext * numFilters + : moduleIdxNext * numFilterColors * filterPixels * + numFilters + + pxIdxInFilterNext * numFilters); } - #pragma unroll +#pragma unroll for (int j = 0; j < filterCacheH; j += B_Y) { if (filterCacheH % B_Y == 0 || threadIdx.y + j < filterCacheH) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { // NOTE: bank conflicts here! if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - shHidActLoad[j * B_X * imgsPerThread + i] = hPreload[j/B_Y][i]; + shHidActLoad[j * B_X * imgsPerThread + i] = + hPreload[j / B_Y][i]; } } } } - hidActsLoadOffset = hidActsOffset + (moduleIdx + (f + filterCacheF) * numModules) * numImages; + hidActsLoadOffset = + hidActsOffset + + (moduleIdx + (f + filterCacheF) * numModules) * numImages; if (f == numFiltersPerGroup - filterCacheF) { hidActsLoadOffset = hidActsOffset + moduleIdxNext * numImages; } @@ -482,22 +586,22 @@ conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16(cudaTextureObject_ // It seems that there is no point explicitly interleaving loads // and computations because the scheduler does that anyway. 
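                // The 16 IA_PRELOAD_LOOP2 calls below cover the filterCacheH (= 16)
                // rows currently cached in shared memory, while the interleaved
                // IA_PRELOAD_W_TX(0..5) and IA_PRELOAD_H_TX(0..3, 0..3) calls refill
                // wPreload[6] and hPreload[4][4] from texture for the next
                // filterCacheF-wide tile.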
- IA_PRELOAD_LOOP2(0,0); - IA_PRELOAD_LOOP2(1,0); - IA_PRELOAD_LOOP2(2,0); - IA_PRELOAD_LOOP2(3,0); - IA_PRELOAD_LOOP2(4,0); - IA_PRELOAD_LOOP2(5,0); - IA_PRELOAD_LOOP2(6,0); - IA_PRELOAD_LOOP2(7,0); - IA_PRELOAD_LOOP2(8,0); - IA_PRELOAD_LOOP2(9,0); - IA_PRELOAD_LOOP2(10,0); - IA_PRELOAD_LOOP2(11,0); - IA_PRELOAD_LOOP2(12,0); - IA_PRELOAD_LOOP2(13,0); - IA_PRELOAD_LOOP2(14,0); - IA_PRELOAD_LOOP2(15,0); + IA_PRELOAD_LOOP2(0, 0); + IA_PRELOAD_LOOP2(1, 0); + IA_PRELOAD_LOOP2(2, 0); + IA_PRELOAD_LOOP2(3, 0); + IA_PRELOAD_LOOP2(4, 0); + IA_PRELOAD_LOOP2(5, 0); + IA_PRELOAD_LOOP2(6, 0); + IA_PRELOAD_LOOP2(7, 0); + IA_PRELOAD_LOOP2(8, 0); + IA_PRELOAD_LOOP2(9, 0); + IA_PRELOAD_LOOP2(10, 0); + IA_PRELOAD_LOOP2(11, 0); + IA_PRELOAD_LOOP2(12, 0); + IA_PRELOAD_LOOP2(13, 0); + IA_PRELOAD_LOOP2(14, 0); + IA_PRELOAD_LOOP2(15, 0); IA_PRELOAD_W_TX(0); IA_PRELOAD_W_TX(1); @@ -506,44 +610,48 @@ conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16(cudaTextureObject_ IA_PRELOAD_W_TX(4); IA_PRELOAD_W_TX(5); - IA_PRELOAD_H_TX(0,0); - IA_PRELOAD_H_TX(0,1); - IA_PRELOAD_H_TX(0,2); - IA_PRELOAD_H_TX(0,3); - IA_PRELOAD_H_TX(1,0); - IA_PRELOAD_H_TX(1,1); - IA_PRELOAD_H_TX(1,2); - IA_PRELOAD_H_TX(1,3); - IA_PRELOAD_H_TX(2,0); - IA_PRELOAD_H_TX(2,1); - IA_PRELOAD_H_TX(2,2); - IA_PRELOAD_H_TX(2,3); - IA_PRELOAD_H_TX(3,0); - IA_PRELOAD_H_TX(3,1); - IA_PRELOAD_H_TX(3,2); - IA_PRELOAD_H_TX(3,3); + IA_PRELOAD_H_TX(0, 0); + IA_PRELOAD_H_TX(0, 1); + IA_PRELOAD_H_TX(0, 2); + IA_PRELOAD_H_TX(0, 3); + IA_PRELOAD_H_TX(1, 0); + IA_PRELOAD_H_TX(1, 1); + IA_PRELOAD_H_TX(1, 2); + IA_PRELOAD_H_TX(1, 3); + IA_PRELOAD_H_TX(2, 0); + IA_PRELOAD_H_TX(2, 1); + IA_PRELOAD_H_TX(2, 2); + IA_PRELOAD_H_TX(2, 3); + IA_PRELOAD_H_TX(3, 0); + IA_PRELOAD_H_TX(3, 1); + IA_PRELOAD_H_TX(3, 2); + IA_PRELOAD_H_TX(3, 3); __syncthreads(); } } } if (scale) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleTargets * targets[c * B_Y * imgPixels * numImages + i * B_X] + scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleTargets * + targets[c * B_Y * imgPixels * numImages + i * B_X] + + scaleOutputs * prod[c][i]; } } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || myCaseIdx + i * B_X < numImages) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleOutputs * prod[c][i]; } } } @@ -561,9 +669,11 @@ conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16(cudaTextureObject_ * Other batch sizes will work, but but I made no attempt whatsoever * to make them work fast. 
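 * (In the manycolor paths below, imgsPerThread is 4, 2 or 1 depending on whether
 * numImages is a multiple of 128 or 64, and checkCaseBounds is enabled whenever
 * numImages % (threads.x * imgsPerThread) != 0; the texture-preload kernels are
 * only dispatched when numImages % 128 == 0 and previous_limit holds, otherwise
 * the plain conv_img_acts_manycolor_kepler variants are used.)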
*/ -void _imgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups, - float scaleTargets, float scaleOutput, bool conv) { +void _imgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups, float scaleTargets, float scaleOutput, + bool conv) { int numFilterColors = numImgColors / numGroups; int numImages = hidActs.getNumCols(); int numFilters = filters.getNumCols(); @@ -575,15 +685,19 @@ void _imgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatri int numModulesX = numModules / numModulesY; megdnn_assert_internal(numImgColors % numGroups == 0); - //megdnn_assert_internal(numFilters % (16*numGroups) == 0); // TODO: insisting on 32 filters due to bug in calling code below. fix that. + // megdnn_assert_internal(numFilters % (16*numGroups) == 0); // TODO: insisting on + // 32 filters due to bug in calling code below. fix that. bool previous_limit = (numFilters % (16 * numGroups)) == 0; - megdnn_assert_internal(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || numImgColors % 2 == 0))); + megdnn_assert_internal( + numGroups > 1 || + (numImgColors > 0 && (numImgColors <= 3 || numImgColors % 2 == 0))); megdnn_assert_internal(numGroups == 1 || numFilterColors % 4 == 0); megdnn_assert_internal(filterPixels == filterSize * filterSize); megdnn_assert_internal(hidActs.getNumRows() == numModules * numFilters); - megdnn_assert_internal(filters.getNumRows() == filterModuleMult * numFilterColors * filterPixels); + megdnn_assert_internal( + filters.getNumRows() == filterModuleMult * numFilterColors * filterPixels); megdnn_assert_internal(numModules == numModulesY * numModulesX); megdnn_assert_internal(hidActs.isContiguous()); @@ -592,13 +706,16 @@ void _imgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatri megdnn_assert_internal(!hidActs.isTrans()); megdnn_assert_internal(!filters.isTrans()); megdnn_assert_internal(!targets.isTrans()); - // These routines don't handle the case when only part of the image is visited in the convolution + // These routines don't handle the case when only part of the image is visited in + // the convolution megdnn_assert_internal(paddingStart <= 0); - megdnn_assert_internal(paddingStart + (numModulesX-1)*moduleStride + filterSize >= imgSizeX); - megdnn_assert_internal(paddingStart + (numModulesY-1)*moduleStride + filterSize >= imgSizeY); + megdnn_assert_internal( + paddingStart + (numModulesX - 1) * moduleStride + filterSize >= imgSizeX); + megdnn_assert_internal( + paddingStart + (numModulesY - 1) * moduleStride + filterSize >= imgSizeY); megdnn_assert_internal(moduleStride <= filterSize); - megdnn_assert_internal(targets.isContiguous()); // no stride support here! + megdnn_assert_internal(targets.isContiguous()); // no stride support here! dim3 blocks; dim3 threads; @@ -609,46 +726,59 @@ void _imgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatri : numFilterColors % 48 == 0 ? 12 : numFilterColors % 32 == 0 ? 8 : numFilterColors % 16 == 0 ? 4 - : 2; + : 2; imgsPerThread = numImages % 128 == 0 ? 4 : numImages % 64 == 0 ? 
2 : 1; megdnn_assert_internal(numFilterColors % (threads.y * colorsPerThread) == 0); - //previous_limit = numFilterColors % (threads.y * colorsPerThread) == 0; - - blocks = dim3(DIVUP(numImages, threads.x*imgsPerThread) * (numImgColors/(threads.y*colorsPerThread)), imgPixels); - // NOTE: the case when channels % 32 == 0 but channels % 48 != 0 and channels % 64 != 0 has not been optimized!! + // previous_limit = numFilterColors % (threads.y * colorsPerThread) == 0; + + blocks = + dim3(DIVUP(numImages, threads.x * imgsPerThread) * + (numImgColors / (threads.y * colorsPerThread)), + imgPixels); + // NOTE: the case when channels % 32 == 0 but channels % 48 != 0 and channels % + // 64 != 0 has not been optimized!! } else if (numFilterColors > 3) { // NOTE: THIS CASE HAS NOT BEEN OPTIMIZED FOR KEPLER!! imgsPerThread = numImages % 128 == 0 ? 8 : numImages % 64 == 0 ? 4 : 2; threads = dim3(16, 16); colorsPerThread = numFilterColors % 4 == 0 ? 4 : 2; - blocks = dim3(DIVUP(numImages,threads.x*imgsPerThread) * (numImgColors / colorsPerThread), DIVUP(imgSizeY,4) * DIVUP(imgSizeX,4)); + blocks = + dim3(DIVUP(numImages, threads.x * imgsPerThread) * + (numImgColors / colorsPerThread), + DIVUP(imgSizeY, 4) * DIVUP(imgSizeX, 4)); } else { // NOTE: THIS CASE HAS NOT BEEN OPTIMIZED FOR KEPLER!! imgsPerThread = numImages % 128 == 0 ? 8 : numImages % 64 == 0 ? 4 : 2; threads = dim3(16, 16); - blocks = dim3(DIVUP(numImages,threads.x*imgsPerThread), DIVUP(imgSizeY,4) * DIVUP(imgSizeX,4)); + blocks = + dim3(DIVUP(numImages, threads.x * imgsPerThread), + DIVUP(imgSizeY, 4) * DIVUP(imgSizeX, 4)); } bool checkCaseBounds = numImages % (threads.x * imgsPerThread) != 0; - if (scaleTargets == 0) { // do not scale or use targets matrix - targets.resize(numImgColors*imgPixels, numImages); + if (scaleTargets == 0) { // do not scale or use targets matrix + targets.resize(numImgColors * imgPixels, numImages); } else { megdnn_assert_internal(targets.getNumRows() == numImgColors * imgPixels); megdnn_assert_internal(targets.getNumCols() == numImages); } const bool scale = scaleTargets != 0; -// cudaFuncSetCacheConfig(conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< 4, 32, 4, 12, 16, 16, false, false, true >, cudaFuncCachePreferShared); -// conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< 4, 32, 4, 12, 16, 16, false, false, true ><<>>( -// hidActs.getTextureObject(), filters.getTextureObject(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, -// imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - - //return; -// printf("conv: %d\n", conv); -// printf("scale: %d\n", scale); -// printf("checkCaseBounds: %d\n", checkCaseBounds); -// printf("numFilterColors: %d\n", numFilterColors); -// printf("numImages: %d\n", numImages); -// cudaStream_t stream = NVMatrix::getDefaultStream(); + // cudaFuncSetCacheConfig(conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< + // 4, 32, 4, 12, 16, 16, false, false, true >, cudaFuncCachePreferShared); + // conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< 4, 32, 4, 12, + // 16, 16, false, false, true ><<>>( + // hidActs.getTextureObject(), filters.getTextureObject(), + // targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, + // filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, + // numImgColors, numGroups, scaleTargets, scaleOutput); + + // return; + // printf("conv: %d\n", conv); + // printf("scale: %d\n", scale); + // 
printf("checkCaseBounds: %d\n", checkCaseBounds); + // printf("numFilterColors: %d\n", numFilterColors); + // printf("numImages: %d\n", numImages); + // cudaStream_t stream = NVMatrix::getDefaultStream(); if (conv == false) { if (scale == false) { @@ -658,315 +788,993 @@ void _imgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatri if (numFilters % 32 == 0) { if (numImages % 128 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex< 8, 32, 4, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex< 8, 32, 4, 8, 32, 16, false, false, false ><<>>(hidActs.getTextureObject(), filters.getTextureObject(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex< + 8, 32, 4, 8, 32, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_preloadfh_ty_8_tx_32_c_8_ff_32_fh_16_tex< + 8, 32, 4, 8, 32, 16, false, false, false> + <<>>( + hidActs.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, + scaleTargets, scaleOutput); } else { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 4, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 4, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 4, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 4, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, + scaleTargets, scaleOutput); } + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 2, 8, 32, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 2, 8, 32, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 32, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 32, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 32, 16, false, false, + 
false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 32, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 2, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 2, 8, 32, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 32, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 32, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - } - else if ((numFilters % 1 == 0)) { + } else if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 4, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 4, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 2, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 2, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 16, 16, false, false, false 
><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 4, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 4, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 2, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 2, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 48 == 0) { + } else if (numFilterColors % 48 == 0) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< 4, 32, 4, 12, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< 4, 32, 4, 12, 16, 16, false, false, false ><<>>(hidActs.getTextureObject(), filters.getTextureObject(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< + 4, 32, 4, 12, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_preloadfh_ty_4_tx_32_c_12_ff_16_fh_16< + 4, 32, 4, 12, 16, 16, false, false, false> + <<>>( + hidActs.getTextureObject(), + filters.getTextureObject(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, + scaleTargets, scaleOutput); } else { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 4, 12, 16, 16, false, false, false >, 
cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 4, 12, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 4, 12, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 4, 12, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, + scaleTargets, scaleOutput); } - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 2, 12, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 2, 12, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 12, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 12, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 12, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 12, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 2, 12, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 2, 12, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 12, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 12, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 12, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 12, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + 
filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 32 == 0) { + } else if (numFilterColors % 32 == 0) { if (numFilters % 32 == 0) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 4, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 4, 8, 32, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 2, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 2, 8, 32, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 32, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 4, 8, 32, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 4, 8, 32, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 2, 8, 32, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 2, 8, 32, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 32, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 32, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 32, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 32, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + 
targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 32, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 32, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - } - else if ((numFilters % 1 == 0)) { + } else if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 4, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 4, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 2, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 2, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 4, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 4, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 2, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 2, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, 
moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 16 == 0) { + } else if (numFilterColors % 16 == 0) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 4, 4, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 4, 4, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 2, 4, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 2, 4, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 4, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 4, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 4, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 4, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 4, 4, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 4, 4, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + 
scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 2, 4, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 2, 4, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 4, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 4, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 4, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 4, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 8 == 0) { + } else if (numFilterColors % 8 == 0) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 4, 2, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 4, 2, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 2, 2, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 2, 2, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 2, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 2, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 2, 16, 16, false, false, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 2, 16, 16, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + 
cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 4, 2, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 4, 2, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 2, 2, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 2, 2, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 2, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 2, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 2, 16, 16, false, false, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 2, 16, 16, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } } - } - else if (numFilterColors > 3) { + } else if (numFilterColors > 3) { if (numFilterColors == 4) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(img_acts_mediumcolor < 8, 4, false, false, false >, cudaFuncCachePreferShared); - img_acts_mediumcolor < 8, 4, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(img_acts_mediumcolor < 4, 4, false, false, false >, cudaFuncCachePreferShared); - img_acts_mediumcolor < 4, 4, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(img_acts_mediumcolor < 2, 4, false, false, false >, cudaFuncCachePreferShared); - img_acts_mediumcolor < 2, 4, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(img_acts_mediumcolor < 2, 4, false, false, false >, cudaFuncCachePreferShared); - img_acts_mediumcolor < 
2, 4, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_mediumcolor<8, 4, false, false, false>, + cudaFuncCachePreferShared); + img_acts_mediumcolor<8, 4, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + img_acts_mediumcolor<4, 4, false, false, false>, + cudaFuncCachePreferShared); + img_acts_mediumcolor<4, 4, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + img_acts_mediumcolor<2, 4, false, false, false>, + cudaFuncCachePreferShared); + img_acts_mediumcolor<2, 4, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + img_acts_mediumcolor<2, 4, false, false, false>, + cudaFuncCachePreferShared); + img_acts_mediumcolor<2, 4, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors == 2) { + } else if (numFilterColors == 2) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 8, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 8, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 4, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 4, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, 
numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_color<8, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<8, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<4, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<4, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); } } } - } - else if (numFilterColors <= 3) { + } else if (numFilterColors <= 3) { if (numFilterColors == 3) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 8, 3, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 8, 3, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 4, 3, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 4, 3, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 3, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 3, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 3, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 3, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_color<8, 3, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<8, 3, false, false, false> + 
<<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<4, 3, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<4, 3, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 3, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 3, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 3, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 3, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); } } - } - else if (numFilterColors == 2) { + } else if (numFilterColors == 2) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 8, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 8, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 4, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 4, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 2, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 2, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_color<8, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<8, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + 
img_acts_color<4, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<4, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 2, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 2, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); } } - } - else if (numFilterColors == 1) { + } else if (numFilterColors == 1) { if ((numFilters % 1 == 0)) { if (numImages % 128 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 8, 1, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 8, 1, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 64 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 4, 1, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 4, 1, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 32 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 1, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 1, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); - } - else if (numImages % 16 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 1, false, false, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 1, false, false, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_color<8, 1, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<8, 1, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 64 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<4, 1, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<4, 1, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, 
moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 32 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 1, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 1, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); + } else if (numImages % 16 == 0) { + cudaFuncSetCacheConfig( + img_acts_color<2, 1, false, false, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 1, false, false, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); } } } } - } - else if (checkCaseBounds == true) { + } else if (checkCaseBounds == true) { if (numFilterColors % 8 == 0) { if (numFilterColors % 64 == 0) { if (numFilters % 32 == 0) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 32, 16, false, true, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 32, 16, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 32, 16, false, true, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 32, 16, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } - } - else if ((numFilters % 1 == 0)) { + } else if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 16, 16, false, true, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 8, 32, 1, 8, 16, 16, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 16, 16, false, true, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 8, 32, 1, 8, 16, 16, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 48 == 0) { + } else if (numFilterColors % 48 == 0) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 12, 16, 16, false, true, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 12, 16, 16, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, 
scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 12, 16, 16, false, true, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 12, 16, 16, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 32 == 0) { + } else if (numFilterColors % 32 == 0) { if (numFilters % 32 == 0) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 32, 16, false, true, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 32, 16, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 32, 16, false, true, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 32, 16, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } - } - else if ((numFilters % 1 == 0)) { + } else if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 16, 16, false, true, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 8, 16, 16, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 16, 16, false, true, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 8, 16, 16, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 16 == 0) { + } else if (numFilterColors % 16 == 0) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 4, 16, 16, false, true, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 4, 16, 16, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 4, 16, 16, false, true, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 4, 16, 16, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, 
scaleTargets, + scaleOutput); } } - } - else if (numFilterColors % 8 == 0) { + } else if (numFilterColors % 8 == 0) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(conv_img_acts_manycolor_kepler < 4, 32, 1, 2, 16, 16, false, true, false >, cudaFuncCachePreferShared); - conv_img_acts_manycolor_kepler < 4, 32, 1, 2, 16, 16, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_img_acts_manycolor_kepler< + 4, 32, 1, 2, 16, 16, false, true, + false>, + cudaFuncCachePreferShared); + conv_img_acts_manycolor_kepler< + 4, 32, 1, 2, 16, 16, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } } - } - else if (numFilterColors > 3) { + } else if (numFilterColors > 3) { if (numFilterColors == 4) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(img_acts_mediumcolor < 2, 4, false, true, false >, cudaFuncCachePreferShared); - img_acts_mediumcolor < 2, 4, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_mediumcolor<2, 4, false, true, false>, + cudaFuncCachePreferShared); + img_acts_mediumcolor<2, 4, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + numImgColors, numGroups, scaleTargets, + scaleOutput); } } } @@ -974,35 +1782,67 @@ void _imgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatri else if (numFilterColors == 2) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 2, false, true, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 2, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig(img_acts_color < 2, 2, false, + true, false >, cudaFuncCachePreferShared); img_acts_color < 2, 2, + false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), + targets.getDevData(), numModulesY, numModulesX, numImages, + numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, + moduleStride, scaleTargets, scaleOutput); } } } */ - } - else if (numFilterColors <= 3) { + } else if (numFilterColors <= 3) { if (numFilterColors == 3) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 3, false, true, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 3, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_color<2, 3, false, true, false>, + 
cudaFuncCachePreferShared); + img_acts_color<2, 3, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); } } - } - else if (numFilterColors == 2) { + } else if (numFilterColors == 2) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 2, false, true, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 2, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_color<2, 2, false, true, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 2, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); } } - } - else if (numFilterColors == 1) { + } else if (numFilterColors == 1) { if ((numFilters % 1 == 0)) { if (numImages % 1 == 0) { - cudaFuncSetCacheConfig(img_acts_color < 2, 1, false, true, false >, cudaFuncCachePreferShared); - img_acts_color < 2, 1, false, true, false ><<>>(hidActs.getDevData(), filters.getDevData(), targets.getDevData(), numModulesY, numModulesX, numImages, numFilters, filterSize, imgSizeY, imgSizeX, paddingStart, moduleStride, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + img_acts_color<2, 1, false, true, false>, + cudaFuncCachePreferShared); + img_acts_color<2, 1, false, true, false> + <<>>( + hidActs.getDevData(), + filters.getDevData(), + targets.getDevData(), numModulesY, + numModulesX, numImages, numFilters, + filterSize, imgSizeY, imgSizeX, + paddingStart, moduleStride, + scaleTargets, scaleOutput); } } } @@ -1014,29 +1854,43 @@ void _imgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatri getLastCudaError("imgActs: kernel execution failed"); } - -void convImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups) { - _imgActs(stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, paddingStart, moduleStride, numImgColors, numGroups, 0, 1, true); +void convImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups) { + _imgActs( + stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, + paddingStart, moduleStride, numImgColors, numGroups, 0, 1, true); } -void convImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups, - float scaleTargets, float scaleOutput) { - _imgActs(stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput, true); +void convImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups, float 
scaleTargets, float scaleOutput) { + _imgActs( + stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, + paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, + scaleOutput, true); } -void localImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups) { - _imgActs(stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, paddingStart, moduleStride, numImgColors, numGroups, 0, 1, false); +void localImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups) { + _imgActs( + stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, + paddingStart, moduleStride, numImgColors, numGroups, 0, 1, false); } -void localImgActs(cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, - int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, int numImgColors, int numGroups, - float scaleTargets, float scaleOutput) { - _imgActs(stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, scaleOutput, false); +void localImgActs( + cudaStream_t stream, NVMatrix& hidActs, NVMatrix& filters, NVMatrix& targets, + int imgSizeY, int imgSizeX, int numModulesY, int paddingStart, int moduleStride, + int numImgColors, int numGroups, float scaleTargets, float scaleOutput) { + _imgActs( + stream, hidActs, filters, targets, imgSizeY, imgSizeX, numModulesY, + paddingStart, moduleStride, numImgColors, numGroups, scaleTargets, + scaleOutput, false); } -} // namespace cuda -} // namespace megdnn - +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color.cuh b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color.cuh index f6c00cc3..52c746eb 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "img_act_templates.cuh" @@ -42,8 +43,8 @@ namespace cuda { * threadIdx.y determines pixel. * * hidActs: (numFilters, numModulesY, numModulesX, numImages) - * filters: (numColors, filterPixels, numFilters) if conv - * (numModulesY, numModulesX, numColors, filterPixels, numFilters) otherwise + * filters: (numColors, filterPixels, numFilters) if + * conv (numModulesY, numModulesX, numColors, filterPixels, numFilters) otherwise * targets: (numColors, imgSizeY, imgSizeX, numImages) * * Each block reconstructs one 4x4 pixels from 16*imgsPerThread cases. @@ -57,18 +58,19 @@ namespace cuda { * This version conserves shared memory by loading 16 filters at a time rather than 32. 
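[Editor's note] The header comment above, together with the _imgActs dispatch ladder earlier in this patch, fixes the launch geometry of img_acts_color: a 16x16 thread block reconstructs one 4x4 pixel region of the output image for 16*imgsPerThread cases, and the dispatcher picks imgsPerThread from the divisibility of numImages (8 for multiples of 128, 4 for 64, 2 for 32 or 16; the checkCaseBounds variants cover everything else). The sketch below only illustrates that geometry under those assumptions; the helper name and the local DIVUP definition are mine and are not part of the patch.

    #include <cuda_runtime.h>

    #define DIVUP(a, b) (((a) + (b)-1) / (b))

    // Illustrative only: launch sizing consistent with the kernel comment
    // above. Each 16x16 block covers a 4x4 output pixel region and
    // 16 * imgsPerThread images; grid.x walks the batch, grid.y walks the
    // 4x4 regions of the image plane.
    static void img_acts_color_launch_geometry(
            int numImages, int imgSizeY, int imgSizeX, dim3& blocks,
            dim3& threads, int& imgsPerThread) {
        // Mirror of the divisibility ladder used for the few-color kernels
        // in _imgActs (8 / 4 / 2 images per thread).
        imgsPerThread = numImages % 128 == 0 ? 8 : numImages % 64 == 0 ? 4 : 2;
        threads = dim3(16, 16);
        blocks = dim3(
                DIVUP(numImages, threads.x * imgsPerThread),
                DIVUP(imgSizeY, 4) * DIVUP(imgSizeX, 4));
    }

The real grid and block values are computed inside _imgActs before the branch ladder shown earlier in this file's diff; the sketch is only meant to make the template-parameter selection and the block-to-pixel mapping easier to follow while reading the reformatted launches.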
*/ template -__global__ void img_acts_color(const float* hidActs, const float* filters, float* targets, - const int numModulesY, const int numModulesX, const int numImages, const int numFilters, - const int filterSize, const int imgSizeY, const int imgSizeX, - const int paddingStart, const int moduleStride, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shFilters[numColors*16][16 + 1]; - __shared__ float shHidActs[16][16*imgsPerThread]; - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +__global__ void img_acts_color( + const float* hidActs, const float* filters, float* targets, + const int numModulesY, const int numModulesX, const int numImages, + const int numFilters, const int filterSize, const int imgSizeY, + const int imgSizeX, const int paddingStart, const int moduleStride, + const float scaleTargets, const float scaleOutputs) { + __shared__ float shFilters[numColors * 16][16 + 1]; + __shared__ float shHidActs[16][16 * imgsPerThread]; + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); - const int blockCaseIdx = blockIdx.x * 16*imgsPerThread; + const int blockCaseIdx = blockIdx.x * 16 * imgsPerThread; const int numRegionsX = DIVUP(imgSizeX, 4); const int blockRegionIdx = blockIdx.y; const int blockRegionIdxX = blockRegionIdx % numRegionsX; @@ -90,21 +92,26 @@ __global__ void img_acts_color(const float* hidActs, const float* filters, float filters += threadIdx.x; targets += pxIdx * numImages + blockCaseIdx + threadIdx.x; - float prod[numColors][imgsPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { prod[c][i] = 0; } } - const int startY = blockRegionTop - paddingStart < filterSize ? 0 - : 1 + (blockRegionTop - paddingStart - filterSize) / moduleStride; - const int endY = MIN(numModulesY, 1 + (blockRegionTop + 3 - paddingStart) / moduleStride); - const int startX = blockRegionLeft - paddingStart < filterSize ? 0 - : 1 + (blockRegionLeft - paddingStart - filterSize) / moduleStride; - const int endX = MIN(numModulesX, 1 + (blockRegionLeft + 3 - paddingStart) / moduleStride); + const int startY = + blockRegionTop - paddingStart < filterSize + ? 0 + : 1 + (blockRegionTop - paddingStart - filterSize) / moduleStride; + const int endY = + MIN(numModulesY, 1 + (blockRegionTop + 3 - paddingStart) / moduleStride); + const int startX = + blockRegionLeft - paddingStart < filterSize + ? 
0 + : 1 + (blockRegionLeft - paddingStart - filterSize) / moduleStride; + const int endX = + MIN(numModulesX, 1 + (blockRegionLeft + 3 - paddingStart) / moduleStride); float* shilterLoad = &shFilters[threadIdx.y][threadIdx.x]; float* shHidActLoad = &shHidActs[loadY][loadX]; @@ -118,59 +125,73 @@ __global__ void img_acts_color(const float* hidActs, const float* filters, float const int moduleLeft = paddingStart + mx * moduleStride; const int pxInModuleX = pxX - moduleLeft; - const bool isPxInModule = pxInModuleY >= 0 && pxInModuleY < filterSize && pxInModuleX >= 0 && pxInModuleX < filterSize; + const bool isPxInModule = pxInModuleY >= 0 && pxInModuleY < filterSize && + pxInModuleX >= 0 && pxInModuleX < filterSize; const int pxIdxInModule = pxInModuleY * filterSize + pxInModuleX; - for (int f = 0; f < numFilters; f += 16) { // multiply with 16 filters at a time - // Now the threads split up into half-warps, and each half-warp decides if it's interested. + for (int f = 0; f < numFilters; + f += 16) { // multiply with 16 filters at a time + // Now the threads split up into half-warps, and each half-warp decides + // if it's interested. const float* hLoad = &hidActs[(moduleIdx + f * numModules) * numImages]; - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread * 16; i += 32) { if (!checkCaseBounds || blockCaseIdx + i + loadX < numImages) { - #pragma unroll - for (int j = 0; j < 16; j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 elements at a time. +#pragma unroll + for (int j = 0; j < 16; + j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 + // elements at a time. if (f + loadY + j < numFilters) { - shHidActLoad[j * 16 * imgsPerThread + i] = hLoad[j * numModules * numImages + i]; + shHidActLoad[j * 16 * imgsPerThread + i] = + hLoad[j * numModules * numImages + i]; } else { shHidActLoad[j * 16 * imgsPerThread + i] = 0; } } } else { - #pragma unroll - for (int j = 0; j < 16; j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 elements at a time. +#pragma unroll + for (int j = 0; j < 16; + j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 + // elements at a time. shHidActLoad[j * 16 * imgsPerThread + i] = 0; } } } if (isPxInImg && isPxInModule) { - // This half-warp is interested, so it's going to load the weights from this module to its pixel. - // Not fully coalesced read :( - // But taking out this read entirely only reduces the runtime by ~2.8%, so it isn't costing me much. - const float* fLoad = conv ? &filters[pxIdxInModule * numFilters + f] - : &filters[(moduleIdx * numColors * filterPixels + pxIdxInModule) * numFilters + f]; - #pragma unroll + // This half-warp is interested, so it's going to load the weights + // from this module to its pixel. Not fully coalesced read :( But + // taking out this read entirely only reduces the runtime by ~2.8%, + // so it isn't costing me much. + const float* fLoad = + conv ? 
&filters[pxIdxInModule * numFilters + f] + : &filters + [(moduleIdx * numColors * filterPixels + + pxIdxInModule) * + numFilters + + f]; +#pragma unroll for (int c = 0; c < numColors; c++) { if (f + threadIdx.x < numFilters) { - shilterLoad[c * 16 * (16 + 1)] = fLoad[c * filterPixels * numFilters]; + shilterLoad[c * 16 * (16 + 1)] = + fLoad[c * filterPixels * numFilters]; } else { shilterLoad[c * 16 * (16 + 1)] = 0; } } - - } __syncthreads(); // Do some actual computation if (isPxInImg && isPxInModule) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int w = 0; w < 16; w++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - prod[c][i] += shFilters[threadIdx.y + c * 16][w] * shHidActs[w][threadIdx.x + i * 16]; + prod[c][i] += shFilters[threadIdx.y + c * 16][w] * + shHidActs[w][threadIdx.x + i * 16]; } } } @@ -179,25 +200,32 @@ __global__ void img_acts_color(const float* hidActs, const float* filters, float } } } - // Not fully coalesced write :(... shmem (and fully coalesced) version is actually slightly slower, though + // Not fully coalesced write :(... shmem (and fully coalesced) version is actually + // slightly slower, though if (isPxInImg) { if (scale) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * 16 < numImages) { - #pragma unroll + if (!checkCaseBounds || + blockCaseIdx + threadIdx.x + i * 16 < numImages) { +#pragma unroll for (int c = 0; c < numColors; c++) { - targets[c * imgPixels * numImages + i * 16] = scaleTargets * targets[c * imgPixels * numImages + i * 16] + scaleOutputs * prod[c][i]; + targets[c * imgPixels * numImages + i * 16] = + scaleTargets * + targets[c * imgPixels * numImages + i * 16] + + scaleOutputs * prod[c][i]; } } } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * 16 < numImages) { - #pragma unroll + if (!checkCaseBounds || + blockCaseIdx + threadIdx.x + i * 16 < numImages) { +#pragma unroll for (int c = 0; c < numColors; c++) { - targets[c * imgPixels * numImages + i * 16] = scaleOutputs * prod[c][i]; + targets[c * imgPixels * numImages + i * 16] = + scaleOutputs * prod[c][i]; } } } @@ -206,16 +234,16 @@ __global__ void img_acts_color(const float* hidActs, const float* filters, float } #define IMG_COLOR_K_HEAD template __global__ void img_acts_color -#define IMG_COLOR_K(scale, ckCase, conv) \ - IMG_COLOR_K_HEAD < 8, 2, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 4, 2, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 2, 2, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 8, 3, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 4, 3, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 2, 3, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 8, 1, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 4, 1, scale, ckCase, conv >(COLOR_KEP_PARAM); \ - IMG_COLOR_K_HEAD < 2, 1, scale, ckCase, conv >(COLOR_KEP_PARAM); - -} // namespace cuda -} // namespace megdnn +#define IMG_COLOR_K(scale, ckCase, conv) \ + IMG_COLOR_K_HEAD<8, 2, scale, ckCase, conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<4, 2, scale, ckCase, conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<2, 2, scale, ckCase, conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<8, 3, scale, ckCase, conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<4, 3, scale, ckCase, 
conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<2, 3, scale, ckCase, conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<8, 1, scale, ckCase, conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<4, 1, scale, ckCase, conv>(COLOR_KEP_PARAM); \ + IMG_COLOR_K_HEAD<2, 1, scale, ckCase, conv>(COLOR_KEP_PARAM); + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ff.cu b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ff.cu index 514026a5..3509d36b 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ff.cu +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ff.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "img_act_color.cuh" @@ -34,7 +35,7 @@ namespace megdnn { namespace cuda { IMG_COLOR_K(false, false, false) -//IMG_COLOR_K(false, false, true) +// IMG_COLOR_K(false, false, true) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ft.cu b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ft.cu index 20846ad5..fd7ea21d 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ft.cu +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_color_ft.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "img_act_color.cuh" @@ -34,7 +35,7 @@ namespace megdnn { namespace cuda { IMG_COLOR_K(false, true, false) -//IMG_COLOR_K(false, true, true) +// IMG_COLOR_K(false, true, true) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor.cuh b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor.cuh index ee2b0af8..a2bf228d 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor.cuh @@ -25,13 +25,15 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ /* * Block size: B_YxB_X. - * blockIdx.x determines case in batches of B_X*imgsPerThread, also color in batches of B_Y*colorsPerThread. + * blockIdx.x determines case in batches of B_X*imgsPerThread, also color in batches of + B_Y*colorsPerThread. * In essence, blockIdx.x.x = 1..numImages/(B_X*imgsPerThread) * blockIdx.x.y = 1..numImgColors/(B_Y*colorsPerThread) * blockIdx.y determines image pixel in target image. 
@@ -40,11 +42,13 @@ * threadIdx.y determines color. * * hidActs: (numFilters, numModulesY, numModulesX, numImages) - * filters: (numFilterColors, filterPixels, numFilters) if conv - * (numModulesY, numModulesX, numFilterColors, filterPixels, numFilters) otherwise + * filters: (numFilterColors, filterPixels, numFilters) if conv + * (numModulesY, numModulesX, numFilterColors, filterPixels, numFilters) + otherwise * targets: (numImageColors, imgSizeY, imgSizeX, numImages) * - * Each block reconstructs one B_Y*colorsPerThread colors from 1 pixel from B_X*imgsPerThread cases. + * Each block reconstructs one B_Y*colorsPerThread colors from 1 pixel from + B_X*imgsPerThread cases. * * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false. * numFiltersPerGroup must be divisible by filterCache. @@ -56,29 +60,35 @@ * B_X*B_Y must be divisible by filterCache * This version loads 32 cases at a time, so it gets full coalescing on that load. - * It only loads filterCache weights at a time, so those aren't fully coalesced (depending on size of filterCache). + * It only loads filterCache weights at a time, so those aren't fully coalesced + (depending on size of filterCache). * * To be used when there are >= 16 color channels. */ -template -__global__ void conv_img_acts_manycolor(const float* hidActs, const float* filters, float* targets, - const int numModulesY, const int numModulesX, const int numImages, const int numFilters, - const int filterSize, const int imgSizeY, const int imgSizeX, const int paddingStart, const int moduleStride, - const int numImgColors, const int numGroups, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shFilters[colorsPerThread*B_Y][filterCache + 1]; - __shared__ float shHidActs[filterCache][B_X*imgsPerThread]; - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int imgsPerThread, int colorsPerThread, int filterCache, + bool scale, bool checkCaseBounds, bool conv> +__global__ void conv_img_acts_manycolor( + const float* hidActs, const float* filters, float* targets, + const int numModulesY, const int numModulesX, const int numImages, + const int numFilters, const int filterSize, const int imgSizeY, + const int imgSizeX, const int paddingStart, const int moduleStride, + const int numImgColors, const int numGroups, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shFilters[colorsPerThread * B_Y][filterCache + 1]; + __shared__ float shHidActs[filterCache][B_X * imgsPerThread]; + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); - const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread); - const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X*imgsPerThread; + const int numImgBlocks = DIVUP(numImages, B_X * imgsPerThread); + const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread; - const int imgColorIdx = (blockIdx.x / numImgBlocks) * B_Y*colorsPerThread; // color idx globally + const int imgColorIdx = + (blockIdx.x / numImgBlocks) * B_Y * colorsPerThread; // color idx globally const int numFilterColors = numImgColors / numGroups; const int blockGroupIdx = imgColorIdx / numFilterColors; - const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group + const int filterColorIdx = imgColorIdx % numFilterColors; // color 
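Before the kernel itself, a small host-side sketch of the launch-time constraints quoted in the comment above; the template values in the example call are assumptions chosen only to make the check concrete, not a configuration taken from this diff.

// Checks the divisibility rules listed above for one assumed instantiation
// of conv_img_acts_manycolor.
#include <cassert>

template <int B_Y, int B_X, int imgsPerThread, int colorsPerThread, int filterCache>
void check_manycolor_constraints(int numImages, int numFilters, int numGroups,
                                 bool checkCaseBounds) {
    // B_X*B_Y must be divisible by filterCache (filtersLoadY/filtersLoadX in the
    // kernel are derived from tidx / filterCache and tidx % filterCache).
    static_assert((B_X * B_Y) % filterCache == 0,
                  "B_X*B_Y must be divisible by filterCache");
    // numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false.
    if (!checkCaseBounds) {
        assert(numImages % (B_X * imgsPerThread) == 0);
    }
    // numFiltersPerGroup must be divisible by filterCache.
    assert((numFilters / numGroups) % filterCache == 0);
}

// Example (assumed values):
//   check_manycolor_constraints<4, 32, 4, 8, 16>(128, 64, 1, /*checkCaseBounds=*/false);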
idx within group const int numFiltersPerGroup = numFilters / numGroups; const int blockFilterIdx = blockGroupIdx * numFiltersPerGroup; @@ -93,25 +103,35 @@ __global__ void conv_img_acts_manycolor(const float* hidActs, const float* filte const int filtersLoadY = tidx / filterCache, filtersLoadX = tidx % filterCache; const int numModules = numModulesY * numModulesX; - hidActs += blockCaseIdx + (blockFilterIdx + hidActLoadY) * numImages * numModules + hidActLoadX; - filters += blockFilterIdx + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + filtersLoadX; - targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + blockPixelIdx * numImages + blockCaseIdx + threadIdx.x; + hidActs += blockCaseIdx + (blockFilterIdx + hidActLoadY) * numImages * numModules + + hidActLoadX; + filters += blockFilterIdx + + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + + filtersLoadX; + targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + + blockPixelIdx * numImages + blockCaseIdx + threadIdx.x; float prod[colorsPerThread][imgsPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { prod[c][i] = 0; } } - const int startY = blockPixelIdxY - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; - const int endY = MIN(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); - const int startX = blockPixelIdxX - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; - const int endX = MIN(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); + const int startY = + blockPixelIdxY - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; + const int endY = + MIN(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); + const int startX = + blockPixelIdxX - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; + const int endX = + MIN(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); float* shFilterLoad = &shFilters[filtersLoadY][filtersLoadX]; float* shHidActLoad = &shHidActs[hidActLoadY][hidActLoadX]; @@ -127,40 +147,57 @@ __global__ void conv_img_acts_manycolor(const float* hidActs, const float* filte const int pxIdxInFilter = pxInFilterY * filterSize + pxInFilterX; - for (int f = 0; f < numFiltersPerGroup; f += filterCache) { // multiply with filterCache filters at a time + for (int f = 0; f < numFiltersPerGroup; + f += filterCache) { // multiply with filterCache filters at a time const float* hLoad = &hidActs[(moduleIdx + f * numModules) * numImages]; - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread * B_X; i += 32) { - if (!checkCaseBounds || blockCaseIdx + hidActLoadX + i < numImages) { - #pragma unroll - for (int j = 0; j < filterCache; j += B_X*B_Y/32) { // load filterCache rows of imgsPerThread*B_X cols, 8 * 32 elements at a time. - shHidActLoad[j * B_X * imgsPerThread + i] = hLoad[j * numModules * numImages + i]; + if (!checkCaseBounds || + blockCaseIdx + hidActLoadX + i < numImages) { +#pragma unroll + for (int j = 0; j < filterCache; + j += B_X * B_Y / + 32) { // load filterCache rows of imgsPerThread*B_X + // cols, 8 * 32 elements at a time. 
+ shHidActLoad[j * B_X * imgsPerThread + i] = + hLoad[j * numModules * numImages + i]; } } else { - #pragma unroll - for (int j = 0; j < filterCache; j += B_X*B_Y/32) { // load filterCache rows of imgsPerThread*B_X cols, 8 * 32 elements at a time. +#pragma unroll + for (int j = 0; j < filterCache; + j += B_X * B_Y / + 32) { // load filterCache rows of imgsPerThread*B_X + // cols, 8 * 32 elements at a time. shHidActLoad[j * B_X * imgsPerThread + i] = 0; } } } - const float* fLoad = conv ? &filters[pxIdxInFilter * numFilters + f] - : &filters[moduleIdx * numFilterColors * filterPixels * numFilters + pxIdxInFilter * numFilters + f]; - #pragma unroll - for (int i = 0; i < colorsPerThread*B_Y; i+= B_X*B_Y/filterCache) { - if ((colorsPerThread*B_Y) % (B_X*B_Y/filterCache) == 0 || i + filtersLoadY < colorsPerThread*B_Y) { - shFilterLoad[i * (filterCache + 1)] = fLoad[i * filterPixels * numFilters]; + const float* fLoad = + conv ? &filters[pxIdxInFilter * numFilters + f] + : &filters + [moduleIdx * numFilterColors * filterPixels * + numFilters + + pxIdxInFilter * numFilters + f]; +#pragma unroll + for (int i = 0; i < colorsPerThread * B_Y; + i += B_X * B_Y / filterCache) { + if ((colorsPerThread * B_Y) % (B_X * B_Y / filterCache) == 0 || + i + filtersLoadY < colorsPerThread * B_Y) { + shFilterLoad[i * (filterCache + 1)] = + fLoad[i * filterPixels * numFilters]; } } __syncthreads(); - // Do some actual computation - #pragma unroll +// Do some actual computation +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - #pragma unroll +#pragma unroll for (int w = 0; w < filterCache; w++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - prod[c][i] += shFilters[c * B_Y + threadIdx.y][w] * shHidActs[w][threadIdx.x + i * B_X]; + prod[c][i] += shFilters[c * B_Y + threadIdx.y][w] * + shHidActs[w][threadIdx.x + i * B_X]; } } } @@ -169,22 +206,26 @@ __global__ void conv_img_acts_manycolor(const float* hidActs, const float* filte } } if (scale) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * B_X < numImages) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleTargets * targets[c * B_Y * imgPixels * numImages + i * B_X] + scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleTargets * + targets[c * B_Y * imgPixels * numImages + i * B_X] + + scaleOutputs * prod[c][i]; } } } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * B_X < numImages) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleOutputs * prod[c][i]; } } } diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler.cuh b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler.cuh index 48c3c716..1977c118 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. 
+ * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "img_act_templates.cuh" @@ -35,8 +36,8 @@ namespace cuda { /* * Block size: B_YxB_X. - * blockIdx.x determines case in batches of B_X*imgsPerThread, also color in batches of B_Y*colorsPerThread. - * In essence, blockIdx.x.x = 1..numImages/(B_X*imgsPerThread) + * blockIdx.x determines case in batches of B_X*imgsPerThread, also color in batches of + * B_Y*colorsPerThread. In essence, blockIdx.x.x = 1..numImages/(B_X*imgsPerThread) * blockIdx.x.y = 1..numImgColors/(B_Y*colorsPerThread) * blockIdx.y determines image pixel in target image. * @@ -44,11 +45,12 @@ namespace cuda { * threadIdx.y determines color. * * hidActs: (numFilters, numModulesY, numModulesX, numImages) - * filters: (numFilterColors, filterPixels, numFilters) if conv - * (numModulesY, numModulesX, numFilterColors, filterPixels, numFilters) otherwise - * targets: (numImageColors, imgSizeY, imgSizeX, numImages) + * filters: (numFilterColors, filterPixels, numFilters) if conv (numModulesY, + * numModulesX, numFilterColors, filterPixels, numFilters) otherwise targets: + * (numImageColors, imgSizeY, imgSizeX, numImages) * - * Each block reconstructs one B_Y*colorsPerThread colors from 1 pixel from B_X*imgsPerThread cases. + * Each block reconstructs one B_Y*colorsPerThread colors from 1 pixel from + * B_X*imgsPerThread cases. * * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false. * numFiltersPerGroup must be divisible by filterCacheF. @@ -58,29 +60,35 @@ namespace cuda { * filterCacheF must be divisible by filterCacheH * * This version loads 32 cases at a time, so it gets full coalescing on that load. - * It only loads filterCacheF weights at a time, so those aren't fully coalesced (depending on size of filterCacheF). + * It only loads filterCacheF weights at a time, so those aren't fully coalesced + * (depending on size of filterCacheF). * * To be used when there are >= 16 color channels. 
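The blockIdx.x packing described above is the same one the kernel body below unpacks with `% numImgBlocks` and `/ numImgBlocks`. As a concrete restatement, here is a small host-side sketch of how such a grid would be built and decoded; the actual launch code lives elsewhere (img_acts.cu), so treat the grid construction here as an assumption that is merely consistent with the kernel's indexing.

// blockIdx.x = colorBlock * numImgBlocks + imgBlock, blockIdx.y = target pixel.
#define DIVUP_(a, b) (((a) + (b)-1) / (b))  // same role as the DIVUP macro used here

template <int B_Y, int B_X, int imgsPerThread, int colorsPerThread>
dim3 manycolor_kepler_grid(int numImages, int numImgColors, int imgSizeY, int imgSizeX) {
    const int numImgBlocks = DIVUP_(numImages, B_X * imgsPerThread);
    const int numColorBlocks = numImgColors / (B_Y * colorsPerThread);
    return dim3(numImgBlocks * numColorBlocks, imgSizeY * imgSizeX);
}

// Decoding inside the kernel (mirrors the first lines of the body below):
//   numImgBlocks = DIVUP(numImages, B_X * imgsPerThread);
//   blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread;
//   imgColorIdx  = (blockIdx.x / numImgBlocks) * B_Y * colorsPerThread;

For example, with B_Y = 4, B_X = 32, imgsPerThread = 4, colorsPerThread = 8 (an assumed configuration), 256 images and 32 image colors give numImgBlocks = 2 and numColorBlocks = 1, so gridDim.x = 2 and gridDim.y = imgSizeY * imgSizeX.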
*/ -template -__global__ void conv_img_acts_manycolor_kepler(const float* hidActs, const float* filters, float* targets, - const int numModulesY, const int numModulesX, const int numImages, const int numFilters, - const int filterSize, const int imgSizeY, const int imgSizeX, const int paddingStart, const int moduleStride, - const int numImgColors, const int numGroups, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shFilters[colorsPerThread*B_Y][filterCacheF]; - __shared__ float shHidActs[filterCacheH][B_X*imgsPerThread]; - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int imgsPerThread, int colorsPerThread, int filterCacheF, + int filterCacheH, bool scale, bool checkCaseBounds, bool conv> +__global__ void conv_img_acts_manycolor_kepler( + const float* hidActs, const float* filters, float* targets, + const int numModulesY, const int numModulesX, const int numImages, + const int numFilters, const int filterSize, const int imgSizeY, + const int imgSizeX, const int paddingStart, const int moduleStride, + const int numImgColors, const int numGroups, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shFilters[colorsPerThread * B_Y][filterCacheF]; + __shared__ float shHidActs[filterCacheH][B_X * imgsPerThread]; + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); - const int numImgBlocks = DIVUP(numImages,B_X*imgsPerThread); - const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X*imgsPerThread; + const int numImgBlocks = DIVUP(numImages, B_X * imgsPerThread); + const int blockCaseIdx = (blockIdx.x % numImgBlocks) * B_X * imgsPerThread; - const int imgColorIdx = (blockIdx.x / numImgBlocks) * B_Y*colorsPerThread; // color idx globally + const int imgColorIdx = + (blockIdx.x / numImgBlocks) * B_Y * colorsPerThread; // color idx globally const int numFilterColors = numImgColors / numGroups; const int blockGroupIdx = imgColorIdx / numFilterColors; - const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group + const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group const int numFiltersPerGroup = numFilters / numGroups; const int blockFilterIdx = blockGroupIdx * numFiltersPerGroup; @@ -92,37 +100,48 @@ __global__ void conv_img_acts_manycolor_kepler(const float* hidActs, const float const int imgPixels = imgSizeY * imgSizeX; const int tidx = threadIdx.y * B_X + threadIdx.x; const int hidActLoadY = threadIdx.y, hidActLoadX = threadIdx.x; - //const int hidActLoadY = tidx / (B_X*imgsPerThread), hidActLoadX = tidx % (B_X*imgsPerThread); + // const int hidActLoadY = tidx / (B_X*imgsPerThread), hidActLoadX = tidx % + // (B_X*imgsPerThread); const int filtersLoadY = tidx / filterCacheF, filtersLoadX = tidx % filterCacheF; // nvcc is behaving idiotically again, these useless declarations save registers - //const int outputY = threadIdx.y, outputX = threadIdx.x; - //const int ty = threadIdx.y, tx = threadIdx.x; + // const int outputY = threadIdx.y, outputX = threadIdx.x; + // const int ty = threadIdx.y, tx = threadIdx.x; const int numModules = numModulesY * numModulesX; - hidActs += blockCaseIdx + (blockFilterIdx + hidActLoadY) * numImages * numModules + hidActLoadX; - filters += blockFilterIdx + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + 
filtersLoadX; - targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + blockPixelIdx * numImages + blockCaseIdx + threadIdx.x; - //bool active_t = filtersLoadX < numFilters; + hidActs += blockCaseIdx + (blockFilterIdx + hidActLoadY) * numImages * numModules + + hidActLoadX; + filters += blockFilterIdx + + (filterColorIdx + filtersLoadY) * filterPixels * numFilters + + filtersLoadX; + targets += (imgColorIdx + threadIdx.y) * imgPixels * numImages + + blockPixelIdx * numImages + blockCaseIdx + threadIdx.x; + // bool active_t = filtersLoadX < numFilters; float prod[colorsPerThread][imgsPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { prod[c][i] = 0; } } - const int startY = blockPixelIdxY - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; - const int endY = min(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); - const int startX = blockPixelIdxX - paddingStart < filterSize ? 0 - : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; - const int endX = min(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); + const int startY = + blockPixelIdxY - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxY - paddingStart - filterSize) / moduleStride; + const int endY = + min(numModulesY, 1 + (blockPixelIdxY - paddingStart) / moduleStride); + const int startX = + blockPixelIdxX - paddingStart < filterSize + ? 0 + : 1 + (blockPixelIdxX - paddingStart - filterSize) / moduleStride; + const int endX = + min(numModulesX, 1 + (blockPixelIdxX - paddingStart) / moduleStride); float* shFilterLoad = &shFilters[filtersLoadY][filtersLoadX]; float* shHidActLoad = &shHidActs[hidActLoadY][hidActLoadX]; - //const bool noFLoop = filterCacheF == filterCacheH; + // const bool noFLoop = filterCacheF == filterCacheH; for (int my = startY; my < endY; my++) { const int moduleTop = paddingStart + my * moduleStride; const int pxInFilterY = blockPixelIdxY - moduleTop; @@ -134,34 +153,48 @@ __global__ void conv_img_acts_manycolor_kepler(const float* hidActs, const float const int pxIdxInFilter = pxInFilterY * filterSize + pxInFilterX; - for (int f = 0; f < numFiltersPerGroup; f += filterCacheF) { // multiply with filterCacheF filters at a time - const float* fLoad = conv ? &filters[pxIdxInFilter * numFilters + f] - : &filters[(moduleIdx * numFilterColors * filterPixels + pxIdxInFilter) * numFilters + f]; - #pragma unroll - for (int i = 0; i < colorsPerThread*B_Y; i+= B_X*B_Y/filterCacheF) { - if (((colorsPerThread*B_Y) % (B_X*B_Y/filterCacheF) == 0 || - i + filtersLoadY < colorsPerThread*B_Y) && - f + filtersLoadX < numFiltersPerGroup) { - shFilterLoad[i * filterCacheF] = fLoad[i * filterPixels * numFilters]; + for (int f = 0; f < numFiltersPerGroup; + f += filterCacheF) { // multiply with filterCacheF filters at a time + const float* fLoad = + conv ? 
&filters[pxIdxInFilter * numFilters + f] + : &filters + [(moduleIdx * numFilterColors * filterPixels + + pxIdxInFilter) * + numFilters + + f]; +#pragma unroll + for (int i = 0; i < colorsPerThread * B_Y; + i += B_X * B_Y / filterCacheF) { + if (((colorsPerThread * B_Y) % (B_X * B_Y / filterCacheF) == 0 || + i + filtersLoadY < colorsPerThread * B_Y) && + f + filtersLoadX < numFiltersPerGroup) { + shFilterLoad[i * filterCacheF] = + fLoad[i * filterPixels * numFilters]; } else { shFilterLoad[i * filterCacheF] = 0; - } } //#pragma unroll for (int fh = f; fh < f + filterCacheF; fh += filterCacheH) { - //conv_img_acts_manycolor_dummy_fhLoop(hidActs, shHidActLoad, shHidActs, shFilters, moduleIdx, numImages, hidActLoadY, hidActLoadX, blockCaseIdx, numModules, f, fh, prod); + // conv_img_acts_manycolor_dummy_fhLoop(hidActs, shHidActLoad, shHidActs, shFilters, + // moduleIdx, numImages, hidActLoadY, hidActLoadX, blockCaseIdx, + // numModules, f, fh, prod); - const float* hLoad = &hidActs[(moduleIdx + fh * numModules) * numImages]; + const float* hLoad = + &hidActs[(moduleIdx + fh * numModules) * numImages]; int hload_offset = blockFilterIdx + hidActLoadY + fh; - #pragma unroll +#pragma unroll for (int j = 0; j < filterCacheH; j += B_Y) { if (filterCacheH % B_Y == 0 || hidActLoadY + j < filterCacheH) { - #pragma unroll - for (int i = 0; i < imgsPerThread*B_X; i += B_X) { - if ((!checkCaseBounds || blockCaseIdx + hidActLoadX + i < numImages) - && hload_offset + j < numFilters) { - shHidActLoad[j * B_X * imgsPerThread + i] = hLoad[j * numModules * numImages + i]; +#pragma unroll + for (int i = 0; i < imgsPerThread * B_X; i += B_X) { + if ((!checkCaseBounds || + blockCaseIdx + hidActLoadX + i < numImages) && + hload_offset + j < numFilters) { + shHidActLoad[j * B_X * imgsPerThread + i] = + hLoad[j * numModules * numImages + i]; } else { shHidActLoad[j * B_X * imgsPerThread + i] = 0; } @@ -170,51 +203,55 @@ __global__ void conv_img_acts_manycolor_kepler(const float* hidActs, const float } __syncthreads(); - // Do some actual computation - // Using these variables causes register usage to go from 161 --> 123. - // But nonetheless, the high-register version is faster. - //const float* shF = &shFilters[threadIdx.y][fh-f]; - //const float* const shF2 = &shFilters[threadIdx.y][fh]; - //const float* shH = &shHidActs[0][threadIdx.x]; - #pragma unroll +// Do some actual computation +// Using these variables causes register usage to go from 161 --> 123. +// But nonetheless, the high-register version is faster. 
+// const float* shF = &shFilters[threadIdx.y][fh-f]; +// const float* const shF2 = &shFilters[threadIdx.y][fh]; +// const float* shH = &shHidActs[0][threadIdx.x]; +#pragma unroll for (int w = 0; w < filterCacheH; w++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { // for test (checking result) - //float hid_val = shHidActs[w][threadIdx.x + i * B_X]; - //if (isnan(hid_val)) { + // float hid_val = shHidActs[w][threadIdx.x + i * B_X]; + // if (isnan(hid_val)) { // hid_val = 0; //} - prod[c][i] += shFilters[c * B_Y + threadIdx.y][fh-f + w] * shHidActs[w][threadIdx.x + i * B_X]; - + prod[c][i] += + shFilters[c * B_Y + threadIdx.y][fh - f + w] * + shHidActs[w][threadIdx.x + i * B_X]; } } } __syncthreads(); - } } } } if (scale) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * B_X < numImages) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleTargets * targets[c * B_Y * imgPixels * numImages + i * B_X] + scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleTargets * + targets[c * B_Y * imgPixels * numImages + i * B_X] + + scaleOutputs * prod[c][i]; } } } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * B_X < numImages) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - targets[c * B_Y * imgPixels * numImages + i * B_X] = scaleOutputs * prod[c][i]; + targets[c * B_Y * imgPixels * numImages + i * B_X] = + scaleOutputs * prod[c][i]; } } } @@ -222,34 +259,55 @@ __global__ void conv_img_acts_manycolor_kepler(const float* hidActs, const float } #define IMG_MANY_COLOR_K_HEAD template __global__ void conv_img_acts_manycolor_kepler -#define IMG_MANY_COLOR_K(scale, ckCase, conv) \ - IMG_MANY_COLOR_K_HEAD< 8, 32, 4, 8, 32, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 8, 32, 2, 8, 32, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 8, 32, 1, 8, 32, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - \ - IMG_MANY_COLOR_K_HEAD< 8, 32, 4, 8, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 8, 32, 2, 8, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 8, 32, 1, 8, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 4, 12, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 2, 12, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 1, 12, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 4, 8, 32, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 2, 8, 32, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 1, 8, 32, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 4, 8, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 2, 8, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 1, 8, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 4, 4, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - 
IMG_MANY_COLOR_K_HEAD< 4, 32, 2, 4, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 1, 4, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 4, 2, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 2, 2, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ - IMG_MANY_COLOR_K_HEAD< 4, 32, 1, 2, 16, 16, scale, ckCase, conv > (MANYCOLOR_KEP_PARAM); \ +#define IMG_MANY_COLOR_K(scale, ckCase, conv) \ + IMG_MANY_COLOR_K_HEAD<8, 32, 4, 8, 32, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<8, 32, 2, 8, 32, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<8, 32, 1, 8, 32, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + \ + IMG_MANY_COLOR_K_HEAD<8, 32, 4, 8, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<8, 32, 2, 8, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<8, 32, 1, 8, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + \ + IMG_MANY_COLOR_K_HEAD<4, 32, 4, 12, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 2, 12, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 1, 12, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + \ + IMG_MANY_COLOR_K_HEAD<4, 32, 4, 8, 32, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 2, 8, 32, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 1, 8, 32, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + \ + IMG_MANY_COLOR_K_HEAD<4, 32, 4, 8, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 2, 8, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 1, 8, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + \ + IMG_MANY_COLOR_K_HEAD<4, 32, 4, 4, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 2, 4, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 1, 4, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + \ + IMG_MANY_COLOR_K_HEAD<4, 32, 4, 2, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 2, 2, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); \ + IMG_MANY_COLOR_K_HEAD<4, 32, 1, 2, 16, 16, scale, ckCase, conv>( \ + MANYCOLOR_KEP_PARAM); // ftt //< 8, 32, 1, 8, 32, 16, scale, conv, conv > @@ -260,5 +318,5 @@ __global__ void conv_img_acts_manycolor_kepler(const float* hidActs, const float //< 4, 32, 1, 4, 16, 16, scale, conv, conv > //< 4, 32, 1, 2, 16, 16, scale, conv, conv > -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_fff.cu b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_fff.cu index e9d330dc..822ff923 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_fff.cu +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_fff.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. 
All rights + * reserved. * -------------------------------------------------------------------------- */ #include "img_act_manycolor_kepler.cuh" @@ -34,6 +35,5 @@ namespace cuda { IMG_MANY_COLOR_K(false, false, false) - -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_ftf.cu b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_ftf.cu index 365947ba..7a5e9195 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_ftf.cu +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_manycolor_kepler_ftf.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "img_act_manycolor_kepler.cuh" @@ -34,6 +35,5 @@ namespace cuda { IMG_MANY_COLOR_K(false, true, false) - -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cu b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cu index c6ab1e03..d2d579a6 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cu +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "img_act_medium_color.cuh" @@ -34,14 +35,14 @@ namespace megdnn { namespace cuda { IMG_MED_COLOR_K(false, false, false) -//IMG_MED_COLOR_K(false, false, true) +// IMG_MED_COLOR_K(false, false, true) IMG_MED_COLOR_K(false, true, false) -//IMG_MED_COLOR_K(false, true, true) +// IMG_MED_COLOR_K(false, true, true) -//IMG_MED_COLOR_K(true, false, false) -//IMG_MED_COLOR_K(true, false, true) -//IMG_MED_COLOR_K(true, true, false) -//IMG_MED_COLOR_K(true, true, true) +// IMG_MED_COLOR_K(true, false, false) +// IMG_MED_COLOR_K(true, false, true) +// IMG_MED_COLOR_K(true, true, false) +// IMG_MED_COLOR_K(true, true, true) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cuh b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cuh index 2cce459f..0cdbea44 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_medium_color.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "img_act_templates.cuh" @@ -34,18 +35,18 @@ namespace megdnn { namespace cuda { /* * Block size: 16x16. - * blockIdx.x determines case in batches of 16*imgsPerThread, also color in batches of colorsPerThread. - * In essence, blockIdx.x.x = 1..numImages/(16*imgsPerThread) - * blockIdx.x.y = 1..numImgColors/colorsPerThread - * blockIdx.y determines 4x4 image region in target image. + * blockIdx.x determines case in batches of 16*imgsPerThread, also color in batches of + * colorsPerThread. In essence, blockIdx.x.x = 1..numImages/(16*imgsPerThread) + * blockIdx.x.y = 1..numImgColors/colorsPerThread blockIdx.y determines 4x4 image region + * in target image. * * threadIdx.x determines case. * threadIdx.y determines pixel. * * hidActs: (numFilters, numModulesY, numModulesX, numImages) - * filters: (numFilterColors, filterPixels, numFilters) if conv - * (numModulesY, numModulesX, numFilterColors, filterPixels, numFilters) otherwise - * targets: (numImageColors, imgSizeY, imgSizeX, numImages) + * filters: (numFilterColors, filterPixels, numFilters) if conv (numModulesY, + * numModulesX, numFilterColors, filterPixels, numFilters) otherwise targets: + * (numImageColors, imgSizeY, imgSizeX, numImages) * * Each block reconstructs one 4x4 pixels from 16*imgsPerThread cases. * @@ -59,25 +60,30 @@ namespace cuda { * * To be used when there are 4-16 color channels. */ -template -__global__ void img_acts_mediumcolor(const float* hidActs, const float* filters, float* targets, - const int numModulesY, const int numModulesX, const int numImages, const int numFilters, - const int filterSize, const int imgSizeY, const int imgSizeX, const int paddingStart, - const int moduleStride, const int numImgColors, const int numGroups, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shFilters[colorsPerThread*16][16 + 1]; - __shared__ float shHidActs[16][16*imgsPerThread]; - fill_shared_mem((float *)shFilters, sizeof(shFilters)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int imgsPerThread, int colorsPerThread, bool scale, bool checkCaseBounds, + bool conv> +__global__ void img_acts_mediumcolor( + const float* hidActs, const float* filters, float* targets, + const int numModulesY, const int numModulesX, const int numImages, + const int numFilters, const int filterSize, const int imgSizeY, + const int imgSizeX, const int paddingStart, const int moduleStride, + const int numImgColors, const int numGroups, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shFilters[colorsPerThread * 16][16 + 1]; + __shared__ float shHidActs[16][16 * imgsPerThread]; + fill_shared_mem((float*)shFilters, sizeof(shFilters) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); - const int numImgBlocks = DIVUP(numImages,16*imgsPerThread); - const int blockCaseIdx = (blockIdx.x % numImgBlocks) * 16*imgsPerThread; + const int numImgBlocks = DIVUP(numImages, 16 * imgsPerThread); + const int blockCaseIdx = (blockIdx.x % numImgBlocks) * 16 * imgsPerThread; - const int imgColorIdx = (blockIdx.x / numImgBlocks) * colorsPerThread; // color idx globally + const int imgColorIdx = + (blockIdx.x / numImgBlocks) * colorsPerThread; // color idx globally const int numFilterColors = numImgColors / numGroups; const int blockGroupIdx = imgColorIdx / numFilterColors; - const int filterColorIdx = imgColorIdx 
% numFilterColors; // color idx within group + const int filterColorIdx = imgColorIdx % numFilterColors; // color idx within group const int numFiltersPerGroup = numFilters / numGroups; const int blockFilterIdx = blockGroupIdx * numFiltersPerGroup; @@ -99,28 +105,35 @@ __global__ void img_acts_mediumcolor(const float* hidActs, const float* filters, const int loadY = tidx / 32, loadX = tidx % 32; hidActs += blockCaseIdx + (blockFilterIdx + loadY) * numImages * numModules + loadX; - filters += blockFilterIdx + filterColorIdx * filterPixels * numFilters + threadIdx.x; - targets += imgColorIdx * imgPixels * numImages + pxIdx * numImages + blockCaseIdx + threadIdx.x; + filters += + blockFilterIdx + filterColorIdx * filterPixels * numFilters + threadIdx.x; + targets += imgColorIdx * imgPixels * numImages + pxIdx * numImages + blockCaseIdx + + threadIdx.x; float prod[colorsPerThread][imgsPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { prod[c][i] = 0; } } - const int startY = blockRegionTop - paddingStart < filterSize ? 0 - : 1 + (blockRegionTop - paddingStart - filterSize) / moduleStride; - const int endY = MIN(numModulesY, 1 + (blockRegionTop + 3 - paddingStart) / moduleStride); - const int startX = blockRegionLeft - paddingStart < filterSize ? 0 - : 1 + (blockRegionLeft - paddingStart - filterSize) / moduleStride; - const int endX = MIN(numModulesX, 1 + (blockRegionLeft + 3 - paddingStart) / moduleStride); + const int startY = + blockRegionTop - paddingStart < filterSize + ? 0 + : 1 + (blockRegionTop - paddingStart - filterSize) / moduleStride; + const int endY = + MIN(numModulesY, 1 + (blockRegionTop + 3 - paddingStart) / moduleStride); + const int startX = + blockRegionLeft - paddingStart < filterSize + ? 0 + : 1 + (blockRegionLeft - paddingStart - filterSize) / moduleStride; + const int endX = + MIN(numModulesX, 1 + (blockRegionLeft + 3 - paddingStart) / moduleStride); float* shFilterLoad = &shFilters[threadIdx.y][threadIdx.x]; float* shHidActLoad = &shHidActs[loadY][loadX]; - for (int my = startY; my < endY; my++) { const int moduleTop = paddingStart + my * moduleStride; const int pxInModuleY = pxY - moduleTop; @@ -130,43 +143,59 @@ __global__ void img_acts_mediumcolor(const float* hidActs, const float* filters, const int moduleLeft = paddingStart + mx * moduleStride; const int pxInModuleX = pxX - moduleLeft; - const bool isPxInModule = pxInModuleY >= 0 && pxInModuleY < filterSize && pxInModuleX >= 0 && pxInModuleX < filterSize; + const bool isPxInModule = pxInModuleY >= 0 && pxInModuleY < filterSize && + pxInModuleX >= 0 && pxInModuleX < filterSize; const int pxIdxInModule = pxInModuleY * filterSize + pxInModuleX; - for (int f = 0; f < numFiltersPerGroup; f += 16) { // multipply with 16 filters at a time - // Now the threads split up into half-warps, and each half-warp decides if it's interested. + for (int f = 0; f < numFiltersPerGroup; + f += 16) { // multipply with 16 filters at a time + // Now the threads split up into half-warps, and each half-warp decides + // if it's interested. const float* hLoad = &hidActs[(moduleIdx + f * numModules) * numImages]; int hload_offset = blockFilterIdx + loadY + f; - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread * 16; i += 32) { if (!checkCaseBounds || blockCaseIdx + loadX + i < numImages) { - #pragma unroll - for (int j = 0; j < 16; j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 elements at a time. 
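// ---- Aside (not from this diff): the startY/endY/startX/endX block just above
// selects the modules whose filterSize x filterSize window covers the pixels this
// block reconstructs. Restated as a stand-alone helper (names are hypothetical;
// the 4x4-region kernels pass the region's last pixel, e.g. blockRegionTop + 3,
// as `lastPixel` when computing `end`):
static inline void module_range(int firstPixel, int lastPixel, int paddingStart,
                                int filterSize, int moduleStride, int numModules,
                                int* start, int* end) {
    // First module whose window [paddingStart + m*moduleStride, +filterSize)
    // still reaches firstPixel.
    *start = firstPixel - paddingStart < filterSize
                     ? 0
                     : 1 + (firstPixel - paddingStart - filterSize) / moduleStride;
    // One past the last module whose window starts at or before lastPixel.
    const int cand = 1 + (lastPixel - paddingStart) / moduleStride;
    *end = cand < numModules ? cand : numModules;
}
// Example: firstPixel = lastPixel = 5, paddingStart = 0, filterSize = 3,
// moduleStride = 1, numModules = 8  ->  start = 3, end = 6 (modules 3, 4, 5).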
+#pragma unroll + for (int j = 0; j < 16; + j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 + // elements at a time. if (hload_offset + j < numFilters) { - shHidActLoad[j * 16 * imgsPerThread + i] = hLoad[j * numModules * numImages + i]; + shHidActLoad[j * 16 * imgsPerThread + i] = + hLoad[j * numModules * numImages + i]; } else { shHidActLoad[j * 16 * imgsPerThread + i] = 0; } } } else { - #pragma unroll - for (int j = 0; j < 16; j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 elements at a time. +#pragma unroll + for (int j = 0; j < 16; + j += 8) { // load 16 rows of imgsPerThread*16 cols, 8 * 32 + // elements at a time. shHidActLoad[j * 16 * imgsPerThread + i] = 0; } } } if (isPxInImg && isPxInModule) { - // This half-warp is interested, so it's going to load the weights from this module to its pixel. + // This half-warp is interested, so it's going to load the weights + // from this module to its pixel. // Not fully coalesced read :( - // But taking out this read entirely only reduces the runtime by ~2.8%, so it isn't costing me much. + // But taking out this read entirely only reduces the runtime by + // ~2.8%, so it isn't costing me much. const float* fLoad = conv ? &filters[pxIdxInModule * numFilters + f] - : &filters[(moduleIdx * numFilterColors * filterPixels + pxIdxInModule) * numFilters + f]; - #pragma unroll + : &filters + [(moduleIdx * numFilterColors * + filterPixels + + pxIdxInModule) * + numFilters + + f]; +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { if (blockFilterIdx + threadIdx.x + f < numFilters) { - shFilterLoad[c * 16 * (16 + 1)] = fLoad[c * filterPixels * numFilters]; + shFilterLoad[c * 16 * (16 + 1)] = + fLoad[c * filterPixels * numFilters]; } else { shFilterLoad[c * 16 * (16 + 1)] = 0; } @@ -176,13 +205,14 @@ __global__ void img_acts_mediumcolor(const float* hidActs, const float* filters, __syncthreads(); // Do some actual computation if (isPxInImg && isPxInModule) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int w = 0; w < 16; w++) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - prod[c][i] += shFilters[threadIdx.y + c * 16][w] * shHidActs[w][threadIdx.x + i * 16]; + prod[c][i] += shFilters[threadIdx.y + c * 16][w] * + shHidActs[w][threadIdx.x + i * 16]; } } } @@ -191,25 +221,32 @@ __global__ void img_acts_mediumcolor(const float* hidActs, const float* filters, } } } - // Not fully coalesced write :(... shmem (and fully coalesced) version is actually slightly slower, though + // Not fully coalesced write :(... 
shmem (and fully coalesced) version is actually + // slightly slower, though if (isPxInImg) { if (scale) { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * 16 < numImages) { - #pragma unroll + if (!checkCaseBounds || + blockCaseIdx + threadIdx.x + i * 16 < numImages) { +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - targets[c * imgPixels * numImages + i * 16] = scaleTargets * targets[c * imgPixels * numImages + i * 16] + scaleOutputs * prod[c][i]; + targets[c * imgPixels * numImages + i * 16] = + scaleTargets * + targets[c * imgPixels * numImages + i * 16] + + scaleOutputs * prod[c][i]; } } } } else { - #pragma unroll +#pragma unroll for (int i = 0; i < imgsPerThread; i++) { - if (!checkCaseBounds || blockCaseIdx + threadIdx.x + i * 16 < numImages) { - #pragma unroll + if (!checkCaseBounds || + blockCaseIdx + threadIdx.x + i * 16 < numImages) { +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - targets[c * imgPixels * numImages + i * 16] = scaleOutputs * prod[c][i]; + targets[c * imgPixels * numImages + i * 16] = + scaleOutputs * prod[c][i]; } } } @@ -218,10 +255,10 @@ __global__ void img_acts_mediumcolor(const float* hidActs, const float* filters, } #define IMG_MED_COLOR_K_HEAD template __global__ void img_acts_mediumcolor -#define IMG_MED_COLOR_K(scale, ckCase, conv) \ - IMG_MED_COLOR_K_HEAD< 8, 4, scale, ckCase, conv >(MED_COLOR_KEP_PARAM); \ - IMG_MED_COLOR_K_HEAD< 4, 4, scale, ckCase, conv >(MED_COLOR_KEP_PARAM); \ - IMG_MED_COLOR_K_HEAD< 2, 4, scale, ckCase, conv >(MED_COLOR_KEP_PARAM); +#define IMG_MED_COLOR_K(scale, ckCase, conv) \ + IMG_MED_COLOR_K_HEAD<8, 4, scale, ckCase, conv>(MED_COLOR_KEP_PARAM); \ + IMG_MED_COLOR_K_HEAD<4, 4, scale, ckCase, conv>(MED_COLOR_KEP_PARAM); \ + IMG_MED_COLOR_K_HEAD<2, 4, scale, ckCase, conv>(MED_COLOR_KEP_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_templates.cuh b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_templates.cuh index 88ee9410..bb7747e0 100644 --- a/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_templates.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/img_acts/img_act_templates.cuh @@ -25,30 +25,28 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ -#include "../nvmatrix.cuh" #include "../cudaconv2.cuh" +#include "../nvmatrix.cuh" #include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { -#define MANYCOLOR_KEP_PARAM const float* hidActs, \ - const float* filters, float* targets, \ - const int numModulesY, const int numModulesX, \ - const int numImages, const int numFilters, \ - const int filterSize, const int imgSizeY, \ - const int imgSizeX, const int paddingStart, \ - const int moduleStride, \ - const int numImgColors, const int numGroups, \ - const float scaleTargets, const float scaleOutputs +#define MANYCOLOR_KEP_PARAM \ + const float *hidActs, const float *filters, float *targets, const int numModulesY, \ + const int numModulesX, const int numImages, const int numFilters, \ + const int filterSize, const int imgSizeY, const int imgSizeX, \ + const int paddingStart, const int moduleStride, const int numImgColors, \ + const int numGroups, const float scaleTargets, const float scaleOutputs /* * Block size: B_YxB_X. - * blockIdx.x determines case in batches of B_X*imgsPerThread, also color in batches of B_Y*colorsPerThread. - * In essence, blockIdx.x.x = 1..numImages/(B_X*imgsPerThread) + * blockIdx.x determines case in batches of B_X*imgsPerThread, also color in batches of + * B_Y*colorsPerThread. In essence, blockIdx.x.x = 1..numImages/(B_X*imgsPerThread) * blockIdx.x.y = 1..numImgColors/(B_Y*colorsPerThread) * blockIdx.y determines image pixel in target image. * @@ -56,11 +54,12 @@ namespace cuda { * threadIdx.y determines color. * * hidActs: (numFilters, numModulesY, numModulesX, numImages) - * filters: (numFilterColors, filterPixels, numFilters) if conv - * (numModulesY, numModulesX, numFilterColors, filterPixels, numFilters) otherwise - * targets: (numImageColors, imgSizeY, imgSizeX, numImages) + * filters: (numFilterColors, filterPixels, numFilters) if conv (numModulesY, + * numModulesX, numFilterColors, filterPixels, numFilters) otherwise targets: + * (numImageColors, imgSizeY, imgSizeX, numImages) * - * Each block reconstructs one B_Y*colorsPerThread colors from 1 pixel from B_X*imgsPerThread cases. + * Each block reconstructs one B_Y*colorsPerThread colors from 1 pixel from + * B_X*imgsPerThread cases. * * numImages must be divisible by B_X*imgsPerThread if checkCaseBounds is false. * numFiltersPerGroup must be divisible by filterCacheF. @@ -70,41 +69,36 @@ namespace cuda { * filterCacheF must be divisible by filterCacheH * * This version loads 32 cases at a time, so it gets full coalescing on that load. - * It only loads filterCacheF weights at a time, so those aren't fully coalesced (depending on size of filterCacheF). + * It only loads filterCacheF weights at a time, so those aren't fully coalesced + * (depending on size of filterCacheF). * * To be used when there are >= 16 color channels. 
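 *
 * Editor's note (illustrative sketch only, not part of this patch): the index
 * mapping documented above implies a launch configuration along the lines of
 * the snippet below. The DIVUP-based arithmetic mirrors how DIVUP is used
 * elsewhere in these kernels, but the actual dispatch code is not shown in
 * this diff, so treat the exact expression as an assumption:
 *
 *     dim3 threads(B_X, B_Y);
 *     dim3 blocks(DIVUP(numImages, B_X * imgsPerThread) *
 *                         (numImgColors / (B_Y * colorsPerThread)),
 *                 imgSizeY * imgSizeX);  // blockIdx.y = one target pixel
 *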
*/ -template +template < + int B_Y, int B_X, int imgsPerThread, int colorsPerThread, int filterCacheF, + int filterCacheH, bool scale, bool checkCaseBounds, bool conv> __global__ void conv_img_acts_manycolor_kepler(MANYCOLOR_KEP_PARAM); - - -#define MED_COLOR_KEP_PARAM const float* hidActs, \ - const float* filters, float* targets, \ - const int numModulesY, const int numModulesX, \ - const int numImages, const int numFilters, \ - const int filterSize, \ - const int imgSizeY, const int imgSizeX, \ - const int paddingStart, const int moduleStride, \ - const int numImgColors, const int numGroups, \ - const float scaleTargets, const float scaleOutputs +#define MED_COLOR_KEP_PARAM \ + const float *hidActs, const float *filters, float *targets, const int numModulesY, \ + const int numModulesX, const int numImages, const int numFilters, \ + const int filterSize, const int imgSizeY, const int imgSizeX, \ + const int paddingStart, const int moduleStride, const int numImgColors, \ + const int numGroups, const float scaleTargets, const float scaleOutputs /* * Block size: 16x16. - * blockIdx.x determines case in batches of 16*imgsPerThread, also color in batches of colorsPerThread. - * In essence, blockIdx.x.x = 1..numImages/(16*imgsPerThread) - * blockIdx.x.y = 1..numImgColors/colorsPerThread - * blockIdx.y determines 4x4 image region in target image. + * blockIdx.x determines case in batches of 16*imgsPerThread, also color in batches of + * colorsPerThread. In essence, blockIdx.x.x = 1..numImages/(16*imgsPerThread) + * blockIdx.x.y = 1..numImgColors/colorsPerThread blockIdx.y determines 4x4 image region + * in target image. * * threadIdx.x determines case. * threadIdx.y determines pixel. * * hidActs: (numFilters, numModulesY, numModulesX, numImages) - * filters: (numFilterColors, filterPixels, numFilters) if conv - * (numModulesY, numModulesX, numFilterColors, filterPixels, numFilters) otherwise - * targets: (numImageColors, imgSizeY, imgSizeX, numImages) + * filters: (numFilterColors, filterPixels, numFilters) if conv (numModulesY, + * numModulesX, numFilterColors, filterPixels, numFilters) otherwise targets: + * (numImageColors, imgSizeY, imgSizeX, numImages) * * Each block reconstructs one 4x4 pixels from 16*imgsPerThread cases. * @@ -118,18 +112,17 @@ __global__ void conv_img_acts_manycolor_kepler(MANYCOLOR_KEP_PARAM); * * To be used when there are 4-16 color channels. */ -template +template < + int imgsPerThread, int colorsPerThread, bool scale, bool checkCaseBounds, + bool conv> __global__ void img_acts_mediumcolor(MED_COLOR_KEP_PARAM); - -#define COLOR_KEP_PARAM const float* hidActs, \ - const float* filters, float* targets, \ - const int numModulesY, const int numModulesX, \ - const int numImages, const int numFilters, \ - const int filterSize, \ - const int imgSizeY, const int imgSizeX, \ - const int paddingStart, const int moduleStride, \ - const float scaleTargets, const float scaleOutputs +#define COLOR_KEP_PARAM \ + const float *hidActs, const float *filters, float *targets, const int numModulesY, \ + const int numModulesX, const int numImages, const int numFilters, \ + const int filterSize, const int imgSizeY, const int imgSizeX, \ + const int paddingStart, const int moduleStride, const float scaleTargets, \ + const float scaleOutputs /* * Block size: 16x16. @@ -140,8 +133,8 @@ __global__ void img_acts_mediumcolor(MED_COLOR_KEP_PARAM); * threadIdx.y determines pixel. 
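 *
 * Editor's note (illustrative only, not part of this patch): "determines
 * pixel" here means the usual cuda-convnet2 4x4-region decomposition; spelled
 * out as code (the blockRegionTop / blockRegionLeft / blockCaseIdx names are
 * assumptions for illustration, since the kernel body is not shown in this
 * diff):
 *
 *     const int pxYInRegion = threadIdx.y / 4, pxXInRegion = threadIdx.y % 4;
 *     const int pxY = blockRegionTop + pxYInRegion;   // row in target image
 *     const int pxX = blockRegionLeft + pxXInRegion;  // col in target image
 *     const int imgIdx = blockCaseIdx + threadIdx.x;  // one case per lane
 *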
* * hidActs: (numFilters, numModulesY, numModulesX, numImages) - * filters: (numColors, filterPixels, numFilters) if conv - * (numModulesY, numModulesX, numColors, filterPixels, numFilters) otherwise + * filters: (numColors, filterPixels, numFilters) if + * conv (numModulesY, numModulesX, numColors, filterPixels, numFilters) otherwise * targets: (numColors, imgSizeY, imgSizeX, numImages) * * Each block reconstructs one 4x4 pixels from 16*imgsPerThread cases. @@ -157,5 +150,5 @@ __global__ void img_acts_mediumcolor(MED_COLOR_KEP_PARAM); template __global__ void img_acts_color(COLOR_KEP_PARAM); -} // namespace megdnn -} // namespace cuda +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/nvmatrix.cuh b/dnn/src/cuda/local/cuda-convnet2/nvmatrix.cuh index 9846ca95..f8675f8e 100644 --- a/dnn/src/cuda/local/cuda-convnet2/nvmatrix.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/nvmatrix.cuh @@ -25,53 +25,41 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #pragma once -#include "src/cuda/utils.cuh" #include +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { -const int TEXTURE_SIZE_MAX = 1<<29; +const int TEXTURE_SIZE_MAX = 1 << 29; struct MemorySegment { - float *data; - MemorySegment(float *data): data(data) - {} + float* data; + MemorySegment(float* data) : data(data) {} }; struct NVMatrix { - NVMatrix(MemorySegment *seg, int row, int col): - seg(seg), row(row), col(col), stride(col), _texObj(0) - { - } - NVMatrix(MemorySegment *seg, int row, int col, int stride): - seg(seg), row(row), col(col), stride(stride), _texObj(0) - { - } - float *getDevData() - { - return seg->data; - } - MemorySegment *seg; + NVMatrix(MemorySegment* seg, int row, int col) + : seg(seg), row(row), col(col), stride(col), _texObj(0) {} + NVMatrix(MemorySegment* seg, int row, int col, int stride) + : seg(seg), row(row), col(col), stride(stride), _texObj(0) {} + float* getDevData() { return seg->data; } + MemorySegment* seg; int row, col, stride; cudaTextureObject_t _texObj; // target must be initialized before transpose. 
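    // Editor's note (illustrative usage sketch, not part of this patch): callers
    // typically wrap an existing device buffer and only then ask for the lazily
    // created texture object. A minimal host-side example, assuming a valid
    // device pointer `dev_buf` holding at least rows * cols floats:
    //
    //     using namespace megdnn::cuda;
    //     MemorySegment seg(dev_buf);                        // non-owning view
    //     NVMatrix mat(&seg, rows, cols);                    // stride defaults to cols
    //     float* raw = mat.getDevData();                     // raw pointer for kernels
    //     cudaTextureObject_t tex = mat.getTextureObject();  // created on first call,
    //                                                        // destroyed in ~NVMatrix()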
- void transpose(const NVMatrix &target, cublasHandle_t handle, - float *one, float *zero) - { - cublas_check(cublasSgeam(handle, - CUBLAS_OP_T, CUBLAS_OP_T, - row, col, - one, - seg->data, this->stride, - zero, - seg->data, this->stride, - target.seg->data, target.stride)); + void transpose( + const NVMatrix& target, cublasHandle_t handle, float* one, float* zero) { + cublas_check(cublasSgeam( + handle, CUBLAS_OP_T, CUBLAS_OP_T, row, col, one, seg->data, + this->stride, zero, seg->data, this->stride, target.seg->data, + target.stride)); } cudaTextureObject_t getTextureObject() { if (_texObj == 0) { @@ -80,8 +68,8 @@ struct NVMatrix { resDesc.resType = cudaResourceTypeLinear; resDesc.res.linear.devPtr = getDevData(); resDesc.res.linear.sizeInBytes = getNumDataBytes(); - resDesc.res.linear.desc = cudaCreateChannelDesc(32, 0, 0, 0, - cudaChannelFormatKindFloat); + resDesc.res.linear.desc = + cudaCreateChannelDesc(32, 0, 0, 0, cudaChannelFormatKindFloat); struct cudaTextureDesc texDesc; memset(&texDesc, 0, sizeof(texDesc)); cuda_check(cudaCreateTextureObject(&_texObj, &resDesc, &texDesc, NULL)); @@ -89,43 +77,23 @@ struct NVMatrix { megdnn_assert_internal(_texObj != 0); return _texObj; } - ~NVMatrix() - { + ~NVMatrix() { if (_texObj) { cuda_check(cudaDestroyTextureObject(_texObj)); } } - int getNumDataBytes() - { - return row * col * sizeof(float); - } - int getNumRows() - { - return row; - } - int getNumCols() - { - return col; - } - int getStride() - { - return stride; - } - bool isTrans() - { - return false; - } - bool isContiguous() - { - return true; - } - void resize(int row, int col) - { + int getNumDataBytes() { return row * col * sizeof(float); } + int getNumRows() { return row; } + int getNumCols() { return col; } + int getStride() { return stride; } + bool isTrans() { return false; } + bool isContiguous() { return true; } + void resize(int row, int col) { megdnn_assert_internal(row * col == this->row * this->col); this->row = row; this->col = col; } }; -} -} +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts.cu index ecd67d7d..2d6086de 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts.cu @@ -25,15 +25,16 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "cudaconv2.cuh" +#include #include "nvmatrix.cuh" #include "weight_acts/wet_act_templates.cuh" -#include #ifdef _WIN32 #define _Pragma(x) @@ -42,57 +43,59 @@ namespace megdnn { namespace cuda { -__device__ __forceinline__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( - const int my, const int mx, const int paddingStart, const int numModulesX, const int moduleStride, - const int blockPixelY, const int blockPixelX, const int imgSizeX, - const int imgStride, int& pixIdx, int& m) { +__device__ __forceinline__ void +conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( + const int my, const int mx, const int paddingStart, const int numModulesX, + const int moduleStride, const int blockPixelY, const int blockPixelX, + const int imgSizeX, const int imgStride, int& pixIdx, int& m) { const int imgLoadModPosY = paddingStart + my * moduleStride; const int imgLoadModPosX = paddingStart + mx * moduleStride; - const int pxY = imgLoadModPosY + blockPixelY; // pixel x,y coords in image + const int pxY = imgLoadModPosY + blockPixelY; // pixel x,y coords in image const int pxX = imgLoadModPosX + blockPixelX; - pixIdx = (pxY * imgSizeX + pxX) * imgStride; // pixel idx in image + pixIdx = (pxY * imgSizeX + pxX) * imgStride; // pixel idx in image m = my * numModulesX + mx; } +#define WA_C3_LOOP(pp, c) \ + _Pragma("unroll") for (int i = 0; i < preloadCases; i++) { \ + _Pragma("unroll") for (int p = 0; p < pixelCache; p++) { \ + _Pragma("unroll") for (int f = 0; f < filtersPerThread; f++) { \ + prod[c][(pp) + p][f] += \ + shImages[threadIdx.y + p * B_Y + (c)*pixelCache * B_Y][i] * \ + shHidActs[threadIdx.x * filtersPerThread + f][i]; \ + } \ + } \ + } -#define WA_C3_LOOP(pp, c) _Pragma("unroll") \ -for (int i = 0; i < preloadCases; i++) { \ - _Pragma("unroll") \ - for (int p = 0; p < pixelCache; p++) { \ - _Pragma("unroll") \ - for (int f = 0; f < filtersPerThread; f++) { \ - prod[c][(pp) + p][f] += shImages[threadIdx.y + p * B_Y + (c) * pixelCache * B_Y][i] * shHidActs[threadIdx.x * filtersPerThread + f][i]; \ - } \ - } \ -} - -#define WA_C3_LOOP2(pp) _Pragma("unroll") \ -for (int p = 0; p < pixelCache; p++) { \ - _Pragma("unroll") \ - for (int i = 0; i < preloadCases; i++) { \ - _Pragma("unroll") \ - for (int f = 0; f < filtersPerThread; f++) { \ - _Pragma("unroll") \ - for (int c = 0; c < 3; ++c) { \ - prod[c][(pp) + p][f] += shImages[threadIdx.y + p * B_Y + (c) * pixelCache * B_Y][i] * shHidActs[threadIdx.x * filtersPerThread + f][i]; \ - } \ - } \ - } \ -} - -#define WA_3_FIDX(y) (((loadY + (y)*B_X*B_Y/preloadCases) % filtersPerThread) * B_X + (loadY + (y)*B_X*B_Y/preloadCases) / filtersPerThread) +#define WA_C3_LOOP2(pp) \ + _Pragma("unroll") for (int p = 0; p < pixelCache; p++) { \ + _Pragma("unroll") for (int i = 0; i < preloadCases; i++) { \ + _Pragma("unroll") for (int f = 0; f < filtersPerThread; f++) { \ + _Pragma("unroll") for (int c = 0; c < 3; ++c) { \ + prod[c][(pp) + p][f] += \ + shImages[threadIdx.y + p * B_Y + (c)*pixelCache * B_Y] \ + [i] * \ + shHidActs[threadIdx.x * filtersPerThread + f][i]; \ + } \ + } \ + } \ + } +#define WA_3_FIDX(y) \ + (((loadY + (y)*B_X * B_Y / preloadCases) % filtersPerThread) * B_X + \ + (loadY + (y)*B_X * B_Y / preloadCases) / filtersPerThread) /* * Each block computes weight gradients for B_Y * pixelsPerThread pixels and B_X filters * threadIdx.x determines filter * threadIdx.y determines pixel in filter * - * blockIdx.x 
determines filter batch of B_X * filtersPerThread, module batch of partialSum - * blockIdx.y determines pixel batch of B_Y * pixelsPerThread + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + * partialSum blockIdx.y determines pixel batch of B_Y * pixelsPerThread * * Number of filters must be divisible by B_X * filtersPerThread - * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is false. + * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is + * false. * * images: (numColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) @@ -106,22 +109,29 @@ for (int p = 0; p < pixelCache; p++) { \ * numModules must be divisible by partialSum * pixelsPerThread must be divisible by pixelCache * - * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread = 1)... - * so the compiler is messing up here somehow. It's unable to optimize that case away. + * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread + * = 1)... so the compiler is messing up here somehow. It's unable to optimize that case + * away. */ -template +template < + int B_Y, int B_X, int pixelCache, int pixelsPerThread, int filtersPerThread, + int preloadCases, int numColors, bool scale, bool checkCaseBounds> //__launch_bounds__(256,2) -__global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int sumWidth, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[pixelCache * B_Y * numColors][preloadCases]; // preload preloadCases cases of B_Y * pixelsPerThread pixels - __shared__ float shHidActs[B_X * filtersPerThread][preloadCases + 1]; // preload preloadCases cases of B_X hidActs - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +__global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3( + cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, + const int numImages, const int numFilters, const int numModulesY, + const int numModulesX, const int imgSizeY, const int imgSizeX, + const int filterSize, const int paddingStart, const int moduleStride, + const int imgStride, const int sumWidth, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shImages[pixelCache * B_Y * numColors] + [preloadCases]; // preload preloadCases cases of B_Y * + // pixelsPerThread pixels + __shared__ float + shHidActs[B_X * filtersPerThread] + [preloadCases + 1]; // preload preloadCases cases of B_X hidActs + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -130,12 +140,12 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj const int filterPixels = filterSize * filterSize; const int imgPixels = imgSizeY * imgSizeX; - const int numFilterBlocks = numFilters / (B_X*filtersPerThread); + const int numFilterBlocks = numFilters / (B_X * filtersPerThread); const int 
blockModuleChunkIdx = blockIdx.x / numFilterBlocks; const int numModuleChunksX = DIVUP(numModulesX, sumWidth); -// const int numModuleChunksY = DIVUP(numModulesY, sumWidth); + // const int numModuleChunksY = DIVUP(numModulesY, sumWidth); const int blockModuleChunkX = blockModuleChunkIdx % numModuleChunksX; const int blockModuleChunkY = blockModuleChunkIdx / numModuleChunksX; @@ -143,32 +153,31 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj const int blockModuleStartX = blockModuleChunkX * sumWidth; const int blockModuleStartY = blockModuleChunkY * sumWidth; - const int blockFilterIdx = B_X * filtersPerThread* (blockIdx.x % numFilterBlocks); + const int blockFilterIdx = B_X * filtersPerThread * (blockIdx.x % numFilterBlocks); -// const int moduleStride = (imgSize - filterSize + 1) / numModulesX; + // const int moduleStride = (imgSize - filterSize + 1) / numModulesX; const int numModules = numModulesY * numModulesX; const int blockPixelOffset = blockIdx.y * B_Y * pixelsPerThread; const int imgOffset = loadX; const int hidActsOffset = blockFilterIdx * numImages * numModules + loadX; -// images += loadX; -// hidActs += blockFilterIdx * numImages * numModules -// + loadX; + // images += loadX; + // hidActs += blockFilterIdx * numImages * numModules + // + loadX; - targets += (blockModuleChunkIdx * numFilters) * filterPixels * numColors - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.y * numFilters + threadIdx.x; + targets += (blockModuleChunkIdx * numFilters) * filterPixels * numColors + + blockPixelOffset * numFilters + blockFilterIdx + + threadIdx.y * numFilters + threadIdx.x; - //float* shImgLoad = &shImages[loadY][loadX]; - //float* shHidActLoad = &shHidActs[loadY][loadX]; + // float* shImgLoad = &shImages[loadY][loadX]; + // float* shHidActLoad = &shHidActs[loadY][loadX]; float prod[numColors][pixelsPerThread][filtersPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { prod[c][p][f] = 0; } @@ -180,36 +189,39 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj const int mEndY = min(numModulesY, blockModuleStartY + sumWidth); const bool doWork = mStartY < mEndY && mStartX < mEndX; -// if (!doWork) { -// hidActs -= -// } -// if (mStartY == mEndY || mStartX == mEndX) { -// return; -// } - -// float imPreload[pixelCache * numColors * preloadCases / B_X]; // [12] - float haPreload[filtersPerThread * preloadCases / B_Y]; // [8] -// if (blockIdx.x != 0 || blockIdx.y !=0) { -// return; -// } -// printf("mStartX: %d, mStartX: %d, mStartX: %d, mStartX: %d\n", mStartX, mStartY, mEndX, mEndY); + // if (!doWork) { + // hidActs -= + // } + // if (mStartY == mEndY || mStartX == mEndX) { + // return; + // } + + // float imPreload[pixelCache * numColors * preloadCases / B_X]; // [12] + float haPreload[filtersPerThread * preloadCases / B_Y]; // [8] + // if (blockIdx.x != 0 || blockIdx.y !=0) { + // return; + // } + // printf("mStartX: %d, mStartX: %d, mStartX: %d, mStartX: %d\n", mStartX, + // mStartY, mEndX, mEndY); const int fYOff = (blockPixelOffset + tidx) / filterSize; const int fXOff = (blockPixelOffset + tidx) % filterSize; - __shared__ int pxIdxes[B_Y*pixelsPerThread]; - fill_shared_mem((int *)pxIdxes, sizeof(pxIdxes)/sizeof(int), 0); + __shared__ int pxIdxes[B_Y * pixelsPerThread]; + fill_shared_mem((int*)pxIdxes, sizeof(pxIdxes) / 
sizeof(int), 0); __syncthreads(); -// __shared__ int fidx[filtersPerThread * preloadCases / B_Y]; // [8] + // __shared__ int fidx[filtersPerThread * preloadCases / B_Y]; // [8] int m = mStartY * numModulesX + mStartX; int fidx[filtersPerThread * preloadCases / B_Y]; if (doWork) { - #pragma unroll +#pragma unroll for (int y = 0; y < filtersPerThread * preloadCases / B_Y; ++y) { const int fIdx = WA_3_FIDX(y); -// if (doWork) { - haPreload[y] = tex1Dfetch(hidActs, hidActsOffset + fIdx * numImages * numModules + m * numImages); -// } + // if (doWork) { + haPreload[y] = tex1Dfetch( + hidActs, + hidActsOffset + fIdx * numImages * numModules + m * numImages); + // } fidx[y] = fIdx * numImages * numModules; } } @@ -219,15 +231,18 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj for (int mx = mStartX; mx < mEndX; mx++) { m = my * numModulesX + mx; -// __syncthreads(); + // __syncthreads(); const int imgLoadModPosX = paddingStart + mx * moduleStride; if (tidx < B_Y * pixelsPerThread) { -// const int imgLoadModPosY = paddingStart + my * moduleStride; -// const int imgLoadModPosX = paddingStart + mx * moduleStride; + // const int imgLoadModPosY = paddingStart + my * + // moduleStride; const int imgLoadModPosX = paddingStart + // + mx * moduleStride; const int pxY = (imgLoadModPosY + fYOff); const int pxX = (imgLoadModPosX + fXOff); const int pixIdx = (pxY * imgSizeX + pxX) * imgStride; - pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX ? pixIdx : -1; + pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX + ? pixIdx + : -1; } __syncthreads(); @@ -242,40 +257,50 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { const bool lastBatch = caseIdx + preloadCases == numImages; -// const float* im = &images[caseIdx + preloadCases + pixIdx]; -// const float* ha = &hidActs[caseIdx + preloadCases + m * numImages]; - int hidActsOffset2 = hidActsOffset + caseIdx + preloadCases + m * numImages; + // const float* im = &images[caseIdx + preloadCases + + // pixIdx]; const float* ha = &hidActs[caseIdx + + // preloadCases + m * numImages]; + int hidActsOffset2 = + hidActsOffset + caseIdx + preloadCases + m * numImages; if (lastBatch) { -// ha = &hidActs[mNext * numImages]; + // ha = &hidActs[mNext * numImages]; hidActsOffset2 = hidActsOffset + mNext * numImages; } - #pragma unroll - for (int y = 0; y < B_X*filtersPerThread; y += (B_X * B_Y) / preloadCases) { - shHidActs[loadY+y][loadX] = haPreload[y*preloadCases/(B_X*B_Y)]; +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + shHidActs[loadY + y][loadX] = + haPreload[y * preloadCases / (B_X * B_Y)]; } - /* ================================================================================== - * Iteration 0 - * ================================================================================== - */ - #pragma unroll +/* ================================================================================== + * Iteration 0 + * ================================================================================== + */ +#pragma unroll for (int y = 0; y < B_Y * pixelCache; y += (B_X * B_Y) / preloadCases) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = 0; + shImages[loadY + y + c * pixelCache * B_Y][loadX] = 0; } } - #pragma unroll +#pragma unroll for (int y = 0; y < B_Y * pixelCache; y 
+= (B_X * B_Y) / preloadCases) { - const int pxIdx = 0 * B_Y + loadY + y; // pixel idx in filter + const int pxIdx = 0 * B_Y + loadY + y; // pixel idx in filter if (pxIdx + blockPixelOffset < filterPixels) { - const int pixIdx = pxIdxes[pxIdx];//(pxY * imgSizeX + pxX) * imgStride; + const int pixIdx = + pxIdxes[pxIdx]; //(pxY * imgSizeX + pxX) * imgStride; if (pixIdx >= 0) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = tex1Dfetch(images, imgOffset + caseIdx + c * imgPixels * imgStride + pixIdx); + shImages[loadY + y + c * pixelCache * B_Y][loadX] = + tex1Dfetch( + images, + imgOffset + caseIdx + + c * imgPixels * imgStride + + pixIdx); } } } @@ -285,13 +310,13 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj haPreload[0] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[0]); haPreload[1] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[1]); - WA_C3_LOOP(0,0); + WA_C3_LOOP(0, 0); haPreload[2] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[2]); haPreload[3] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[3]); - WA_C3_LOOP(0,1); + WA_C3_LOOP(0, 1); haPreload[4] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[4]); haPreload[5] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[5]); - WA_C3_LOOP(0,2); + WA_C3_LOOP(0, 2); haPreload[6] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[6]); haPreload[7] = tex1Dfetch(hidActs, hidActsOffset2 + fidx[7]); @@ -301,28 +326,34 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj } if (scale) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] = scaleTargets * targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] + scaleOutputs * prod[c][p][f]; + targets[p * B_Y * numFilters + c * filterPixels * numFilters + + f * B_X] = + scaleTargets * targets[p * B_Y * numFilters + + c * filterPixels * numFilters + + f * B_X] + + scaleOutputs * prod[c][p][f]; } } } } } else { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { -// if (threadIdx.x == 3) - targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[c][p][f]; + // if (threadIdx.x == 3) + targets[p * B_Y * numFilters + c * filterPixels * numFilters + + f * B_X] = scaleOutputs * prod[c][p][f]; } } } @@ -330,17 +361,17 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj } } - /* * Each block computes weight gradients for B_Y * pixelsPerThread pixels and B_X filters * threadIdx.x determines filter * threadIdx.y determines pixel in filter * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum - * blockIdx.y determines pixel batch of B_Y * pixelsPerThread + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + * partialSum blockIdx.y determines pixel batch of B_Y * pixelsPerThread * * Number of filters must be divisible by B_X * filtersPerThread - * Number of images (cases) should be divisible by preloadCases if 
checkCaseBounds is false. + * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is + * false. * * images: (numColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) @@ -354,22 +385,29 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3(cudaTextureObj * numModules must be divisible by partialSum * pixelsPerThread must be divisible by pixelCache * - * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread = 1)... - * so the compiler is messing up here somehow. It's unable to optimize that case away. + * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread + * = 1)... so the compiler is messing up here somehow. It's unable to optimize that case + * away. */ -template -__launch_bounds__(256,2) -__global__ void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3(cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int sumWidth, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[pixelCache * B_Y * numColors][preloadCases]; // preload preloadCases cases of B_Y * pixelsPerThread pixels - __shared__ float shHidActs[B_X * filtersPerThread][preloadCases + 1]; // preload preloadCases cases of B_X hidActs - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int pixelCache, int pixelsPerThread, int filtersPerThread, + int preloadCases, int numColors, bool scale, bool checkCaseBounds> +__launch_bounds__(256, 2) __global__ + void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3( + cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, + const int numImages, const int numFilters, const int numModulesY, + const int numModulesX, const int imgSizeY, const int imgSizeX, + const int filterSize, const int paddingStart, const int moduleStride, + const int imgStride, const int sumWidth, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shImages[pixelCache * B_Y * numColors] + [preloadCases]; // preload preloadCases cases of B_Y * + // pixelsPerThread pixels + __shared__ float + shHidActs[B_X * filtersPerThread] + [preloadCases + 1]; // preload preloadCases cases of B_X hidActs + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -378,12 +416,12 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3(cudaTextureObj const int filterPixels = filterSize * filterSize; const int imgPixels = imgSizeY * imgSizeX; - const int numFilterBlocks = numFilters / (B_X*filtersPerThread); + const int numFilterBlocks = numFilters / (B_X * filtersPerThread); const int blockModuleChunkIdx = blockIdx.x / numFilterBlocks; const int numModuleChunksX = DIVUP(numModulesX, sumWidth); -// const int numModuleChunksY = DIVUP(numModulesY, sumWidth); + // const int numModuleChunksY = DIVUP(numModulesY, sumWidth); const int blockModuleChunkX = blockModuleChunkIdx % numModuleChunksX; const int blockModuleChunkY = blockModuleChunkIdx / 
numModuleChunksX; @@ -391,33 +429,31 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3(cudaTextureObj const int blockModuleStartX = blockModuleChunkX * sumWidth; const int blockModuleStartY = blockModuleChunkY * sumWidth; - const int blockFilterIdx = B_X * filtersPerThread* (blockIdx.x % numFilterBlocks); + const int blockFilterIdx = B_X * filtersPerThread * (blockIdx.x % numFilterBlocks); -// const int moduleStride = (imgSize - filterSize + 1) / numModulesX; + // const int moduleStride = (imgSize - filterSize + 1) / numModulesX; const int numModules = numModulesY * numModulesX; const int blockPixelOffset = blockIdx.y * B_Y * pixelsPerThread; const int imgOffset = loadX; - const int hidActsOffset = blockFilterIdx * numImages * numModules - + loadX; -// images += loadX; -// hidActs += blockFilterIdx * numImages * numModules -// + loadX; + const int hidActsOffset = blockFilterIdx * numImages * numModules + loadX; + // images += loadX; + // hidActs += blockFilterIdx * numImages * numModules + // + loadX; - targets += (blockModuleChunkIdx * numFilters) * filterPixels * numColors - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.y * numFilters + threadIdx.x; + targets += (blockModuleChunkIdx * numFilters) * filterPixels * numColors + + blockPixelOffset * numFilters + blockFilterIdx + + threadIdx.y * numFilters + threadIdx.x; - //float* shImgLoad = &shImages[loadY][loadX]; - //float* shHidActLoad = &shHidActs[loadY][loadX]; + // float* shImgLoad = &shImages[loadY][loadX]; + // float* shHidActLoad = &shHidActs[loadY][loadX]; float prod[numColors][pixelsPerThread][filtersPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { prod[c][p][f] = 0; } @@ -429,99 +465,122 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3(cudaTextureObj const int mEndY = min(numModulesY, blockModuleStartY + sumWidth); const bool doWork = mStartY < mEndY && mStartX < mEndX; -// if (mStartY == mEndY || mStartX == mEndX) { -// return; -// } - -// float imPreload[pixelCache * numColors * preloadCases / B_X]; // [12] - float haPreload[filtersPerThread * preloadCases / B_Y]; // [6] -// if (blockIdx.x != 0 || blockIdx.y !=0) { -// return; -// } -// printf("mStartX: %d, mStartX: %d, mStartX: %d, mStartX: %d\n", mStartX, mStartY, mEndX, mEndY); + // if (mStartY == mEndY || mStartX == mEndX) { + // return; + // } + + // float imPreload[pixelCache * numColors * preloadCases / B_X]; // [12] + float haPreload[filtersPerThread * preloadCases / B_Y]; // [6] + // if (blockIdx.x != 0 || blockIdx.y !=0) { + // return; + // } + // printf("mStartX: %d, mStartX: %d, mStartX: %d, mStartX: %d\n", mStartX, + // mStartY, mEndX, mEndY); const int fYOff = (blockPixelOffset + tidx) / filterSize; const int fXOff = (blockPixelOffset + tidx) % filterSize; - __shared__ int pxIdxes[B_Y*pixelsPerThread]; - fill_shared_mem((int *)pxIdxes, sizeof(pxIdxes)/sizeof(int), 0); + __shared__ int pxIdxes[B_Y * pixelsPerThread]; + fill_shared_mem((int*)pxIdxes, sizeof(pxIdxes) / sizeof(int), 0); __syncthreads(); -// __shared__ int fidx[filtersPerThread * preloadCases / B_Y]; // [6] + // __shared__ int fidx[filtersPerThread * preloadCases / B_Y]; // [6] int m = mStartY * numModulesX + mStartX; int fidx[filtersPerThread * preloadCases / B_Y]; -// if (doWork) { - #pragma unroll + // if (doWork) { +#pragma unroll for (int y = 0; y < 
filtersPerThread * preloadCases / B_Y; ++y) { fidx[y] = WA_3_FIDX(y) * numImages * numModules; - if (doWork) { // Not actually necessary, I think - haPreload[y] = tex1Dfetch(hidActs, hidActsOffset + fidx[y] + m * numImages); + if (doWork) { // Not actually necessary, I think + haPreload[y] = + tex1Dfetch(hidActs, hidActsOffset + fidx[y] + m * numImages); } } -// } + // } int mNext = mStartY * numModulesX + mStartX; for (int my = mStartY; my < mEndY; my++) { -// const int imgLoadModPosY = paddingStart + my * moduleStride; + // const int imgLoadModPosY = paddingStart + my * moduleStride; for (int mx = mStartX; mx < mEndX; mx++) { - m = mNext;//my * numModulesX + mx; + m = mNext; // my * numModulesX + mx; -// __syncthreads(); -// const int imgLoadModPosX = paddingStart + mx * moduleStride; + // __syncthreads(); + // const int imgLoadModPosX = paddingStart + mx * moduleStride; if (tidx < B_Y * pixelsPerThread) { const int imgLoadModPosY = paddingStart + my * moduleStride; const int imgLoadModPosX = paddingStart + mx * moduleStride; const int pxY = (imgLoadModPosY + fYOff); const int pxX = (imgLoadModPosX + fXOff); const int pixIdx = (pxY * imgSizeX + pxX) * imgStride; - pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX ? pixIdx : -1; + pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX + ? pixIdx + : -1; } __syncthreads(); - const bool lastModule = my == mEndY - 1 && mx == mEndX - 1; - mNext = lastModule * m + !lastModule * ((my + (mx + 1 == mEndX)) * numModulesX + (mx + 1 == mEndX ? mStartX : mx + 1)); -// if (!lastModule) { -// const int mxNext = mx + 1 == mEndX ? mStartX : mx + 1; -// const int myNext = my + (mx + 1 == mEndX); -// mNext = myNext * numModulesX + mxNext; -// } + mNext = lastModule * m + + !lastModule * ((my + (mx + 1 == mEndX)) * numModulesX + + (mx + 1 == mEndX ? mStartX : mx + 1)); + // if (!lastModule) { + // const int mxNext = mx + 1 == mEndX ? 
mStartX : mx + 1; + // const int myNext = my + (mx + 1 == mEndX); + // mNext = myNext * numModulesX + mxNext; + // } for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { const bool lastBatch = caseIdx + preloadCases == numImages; -// const float* im = &images[caseIdx + preloadCases + pixIdx]; -// const float* ha = hidActs + !lastBatch * (caseIdx + preloadCases + m * numImages) + lastBatch * mNext * numImages; - const int hidActsOffset2 = hidActsOffset + !lastBatch * (caseIdx + preloadCases + m * numImages) + lastBatch * mNext * numImages; -// if (lastBatch) { -// ha = &hidActs[mNext * numImages]; -// } - - #pragma unroll - for (int y = 0; y < B_X*filtersPerThread; y += (B_X * B_Y) / preloadCases) { - shHidActs[loadY+y][loadX] = haPreload[y*preloadCases/(B_X*B_Y)]; + // const float* im = &images[caseIdx + preloadCases + + // pixIdx]; const float* ha = hidActs + !lastBatch * + // (caseIdx + preloadCases + m * numImages) + lastBatch * + // mNext * numImages; + const int hidActsOffset2 = + hidActsOffset + + !lastBatch * (caseIdx + preloadCases + m * numImages) + + lastBatch * mNext * numImages; + // if (lastBatch) { + // ha = &hidActs[mNext * numImages]; + // } + +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + shHidActs[loadY + y][loadX] = + haPreload[y * preloadCases / (B_X * B_Y)]; } - /* ================================================================================== - * Iteration 0 - * ================================================================================== - */ - #pragma unroll +/* ================================================================================== + * Iteration 0 + * ================================================================================== + */ +#pragma unroll for (int y = 0; y < B_Y * pixelCache; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y * pixelCache) { - #pragma unroll + // Make sure number of rows in the array is divisible by number of + // rows filled per iteration + if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * pixelCache) { +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = 0; + shImages[loadY + y + c * pixelCache * B_Y][loadX] = 0; } } } - #pragma unroll +#pragma unroll for (int y = 0; y < B_Y * pixelCache; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y * pixelCache) { - const int pxIdx = 0 * B_Y + loadY + y; // pixel idx in filter - const int pixIdx = pxIdxes[pxIdx];//(pxY * imgSizeX + pxX) * imgStride; - if (pixIdx >= 0 && pxIdx + blockPixelOffset < filterPixels && (!checkCaseBounds || caseIdx + loadX < numImages)) { - #pragma unroll + // Make sure number of rows in the array is divisible by number of + // rows filled per iteration + if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * pixelCache) { + const int pxIdx = 0 * B_Y + loadY + y; // pixel idx in filter + const int pixIdx = + pxIdxes[pxIdx]; //(pxY * imgSizeX + pxX) * imgStride; + if (pixIdx >= 0 && pxIdx + blockPixelOffset < filterPixels && + (!checkCaseBounds || caseIdx + loadX < numImages)) { +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * 
pixelCache * B_Y][loadX] = tex1Dfetch(images, imgOffset + caseIdx + c * imgPixels * imgStride + pixIdx); + shImages[loadY + y + c * pixelCache * B_Y][loadX] = + tex1Dfetch( + images, + imgOffset + caseIdx + + c * imgPixels * imgStride + + pixIdx); } } } @@ -540,32 +599,44 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3(cudaTextureObj __syncthreads(); - /* ================================================================================== - * Iteration 1 - * ================================================================================== - */ - #pragma unroll +/* ================================================================================== + * Iteration 1 + * ================================================================================== + */ +#pragma unroll for (int y = 0; y < B_Y * pixelCache; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y * pixelCache) { -// const int pxIdx = 2 * B_Y + loadY + y; // pixel idx in filter - #pragma unroll + // Make sure number of rows in the array is divisible by number of + // rows filled per iteration + if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * pixelCache) { + // const int pxIdx = 2 * B_Y + loadY + y; + // // pixel idx in filter +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = 0; + shImages[loadY + y + c * pixelCache * B_Y][loadX] = 0; } } } - #pragma unroll +#pragma unroll for (int y = 0; y < B_Y * pixelCache; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y * pixelCache) { - const int pxIdx = 2 * B_Y + loadY + y; // pixel idx in filter - const int pixIdx = pxIdxes[pxIdx];//(pxY * imgSizeX + pxX) * imgStride; - if (pixIdx >= 0 && pxIdx + blockPixelOffset < filterPixels && (!checkCaseBounds || caseIdx + loadX < numImages)) { - #pragma unroll + // Make sure number of rows in the array is divisible by number of + // rows filled per iteration + if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * pixelCache) { + const int pxIdx = 2 * B_Y + loadY + y; // pixel idx in filter + const int pixIdx = + pxIdxes[pxIdx]; //(pxY * imgSizeX + pxX) * imgStride; + if (pixIdx >= 0 && pxIdx + blockPixelOffset < filterPixels && + (!checkCaseBounds || caseIdx + loadX < numImages)) { +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = tex1Dfetch(images, imgOffset + caseIdx + c * imgPixels * imgStride + pixIdx); + shImages[loadY + y + c * pixelCache * B_Y][loadX] = + tex1Dfetch( + images, + imgOffset + caseIdx + + c * imgPixels * imgStride + + pixIdx); } } } @@ -576,33 +647,38 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3(cudaTextureObj WA_C3_LOOP2(2); __syncthreads(); - } } } if (scale) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] = scaleTargets * targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] + 
scaleOutputs * prod[c][p][f]; + targets[p * B_Y * numFilters + c * filterPixels * numFilters + + f * B_X] = + scaleTargets * targets[p * B_Y * numFilters + + c * filterPixels * numFilters + + f * B_X] + + scaleOutputs * prod[c][p][f]; } } } } } else { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[c][p][f]; + targets[p * B_Y * numFilters + c * filterPixels * numFilters + + f * B_X] = scaleOutputs * prod[c][p][f]; } } } @@ -610,26 +686,32 @@ __global__ void conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3(cudaTextureObj } } - /* * images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) * - * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, numFilters) + * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, + * numFilters) */ -template -__launch_bounds__(128, 4) -__global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int numImgColors, const int numGroups, const int sumWidth, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[colorsPerThread * B_Y][preloadCases]; // preload preloadCases cases - __shared__ float shHidActs[filtersPerThread * B_X][preloadCases + 1]; // preload preloadCases cases of B_X hidacts - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int filtersPerThread, int colorsPerThread, int preloadCases, + bool scale> +__launch_bounds__(128, 4) __global__ + void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16( + cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, + const int numImages, const int numFilters, const int numModulesY, + const int numModulesX, const int imgSizeY, const int imgSizeX, + const int filterSize, const int paddingStart, const int moduleStride, + const int imgStride, const int numImgColors, const int numGroups, + const int sumWidth, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shImages[colorsPerThread * B_Y] + [preloadCases]; // preload preloadCases cases + __shared__ float + shHidActs[filtersPerThread * B_X] + [preloadCases + 1]; // preload preloadCases cases of B_X hidacts + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -642,7 +724,7 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cu const int blockModuleChunkIdx = blockIdx.x / numFilterBlocks; const int numModuleChunksX = DIVUP(numModulesX, sumWidth); -// const int numModuleChunksY = DIVUP(numModulesY, sumWidth); + // const int numModuleChunksY = DIVUP(numModulesY, sumWidth); const int 
blockModuleChunkX = blockModuleChunkIdx % numModuleChunksX; const int blockModuleChunkY = blockModuleChunkIdx / numModuleChunksX; @@ -650,7 +732,7 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cu const int blockModuleStartX = blockModuleChunkX * sumWidth; const int blockModuleStartY = blockModuleChunkY * sumWidth; -// const int moduleIdx = partialSum * outputModuleIdx; + // const int moduleIdx = partialSum * outputModuleIdx; const int blockFilterIdx = filtersPerThread * B_X * (blockIdx.x % numFilterBlocks); const int numModules = numModulesY * numModulesX; @@ -658,49 +740,55 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cu const int blockGroupIdx = blockFilterIdx / numFiltersPerGroup; const int numFilterColors = numImgColors / numGroups; - const int blockPixelOffset = blockIdx.z; // pixel idx in filter - const int blockPixelY = blockPixelOffset / filterSize, blockPixelX = blockPixelOffset % filterSize; - const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; + const int blockPixelOffset = blockIdx.z; // pixel idx in filter + const int blockPixelY = blockPixelOffset / filterSize, + blockPixelX = blockPixelOffset % filterSize; + const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; const int imgColorIdx = blockFilterColorIdx + blockGroupIdx * numFilterColors; const int imgOffset = (imgColorIdx + loadY) * imgPixels * imgStride + loadX; -// images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; - const int hidActsOffset = blockFilterIdx * numImages * numModules - + loadY * numImages * numModules - + loadX; -// -// hidActs += -// blockFilterIdx * numImages * numModules -// + loadY * numImages * numModules -// + loadX; - - targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors - + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.x; -// if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; - - const int mStartX = max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); - const int mStartY = max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); - const int mEndX = min(numModulesX, min(blockModuleStartX + sumWidth, DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); - const int mEndY = min(numModulesY, min(blockModuleStartY + sumWidth, DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); - -// if (mStartY == mEndY || mStartX == mEndX) { -// return; -// } -// const bool doWork = mStartY < mEndY && mStartX < mEndX; + // images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; + const int hidActsOffset = blockFilterIdx * numImages * numModules + + loadY * numImages * numModules + loadX; + // + // hidActs += + // blockFilterIdx * numImages * numModules + // + loadY * numImages * numModules + // + loadX; + + targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors + + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters + + blockPixelOffset * numFilters + blockFilterIdx + threadIdx.x; + // if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; + + const int mStartX = + max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); + const int mStartY = + max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); + const int mEndX = + min(numModulesX, + min(blockModuleStartX + sumWidth, + DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); + 
const int mEndY = + min(numModulesY, + min(blockModuleStartY + sumWidth, + DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); + + // if (mStartY == mEndY || mStartX == mEndX) { + // return; + // } + // const bool doWork = mStartY < mEndY && mStartX < mEndX; float* shHidActLoad = &shHidActs[loadY][loadX]; float* shImgLoad = &shImages[loadY][loadX]; - float imPreload[preloadCases*colorsPerThread/B_X]; // [8] - float haPreload[preloadCases*filtersPerThread/B_Y]; // [8] + float imPreload[preloadCases * colorsPerThread / B_X]; // [8] + float haPreload[preloadCases * filtersPerThread / B_Y]; // [8] float prod[filtersPerThread][colorsPerThread]; - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { prod[f][c] = 0; } @@ -708,25 +796,27 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cu int pixIdx, pixIdxNext, m, mNext; conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( - mStartY, mStartX, paddingStart, numModulesX, moduleStride, - blockPixelY, blockPixelX, imgSizeX, imgStride, - pixIdx, m); + mStartY, mStartX, paddingStart, numModulesX, moduleStride, blockPixelY, + blockPixelX, imgSizeX, imgStride, pixIdx, m); - #pragma unroll +#pragma unroll for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - // It's bizarre, but this is the fastest way I've found to get it not to load nonexistent pixels. - // All other ways cause crazy excessive register usage. - const int idx = (mStartY < mEndY && mStartX < mEndX) * (0 + y * imgPixels * imgStride + pixIdx); - imPreload[y * preloadCases/(B_X * B_Y)] = tex1Dfetch(images, imgOffset + idx); + // It's bizarre, but this is the fastest way I've found to get it not to load + // nonexistent pixels. All other ways cause crazy excessive register usage. + const int idx = (mStartY < mEndY && mStartX < mEndX) * + (0 + y * imgPixels * imgStride + pixIdx); + imPreload[y * preloadCases / (B_X * B_Y)] = + tex1Dfetch(images, imgOffset + idx); } - #pragma unroll +#pragma unroll for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { // Almost certainly not necessary here. 
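                // Editor's note (illustration, not part of this patch): the
                // multiply-by-predicate pattern used for `idx` here and in the
                // image preload above is a branch-free clamp -- when this block
                // has no modules to process (mStartY >= mEndY or mStartX >= mEndX)
                // the offset collapses to 0, so the tex1Dfetch still reads a valid
                // location. The equivalent branchy form would be:
                //
                //     int idx = 0;
                //     if (mStartY < mEndY && mStartX < mEndX) {
                //         idx = y * numImages * numModules + m * numImages;
                //     }
                //
                // Per the comment in the image preload loop, the predicated form
                // was kept because the alternatives cost extra registers.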
- const int idx = (mStartY < mEndY && mStartX < mEndX) * (0 + y * numImages * numModules + m * numImages); - haPreload[y * preloadCases / (B_X * B_Y)] = tex1Dfetch(hidActs, hidActsOffset + idx); + const int idx = (mStartY < mEndY && mStartX < mEndX) * + (0 + y * numImages * numModules + m * numImages); + haPreload[y * preloadCases / (B_X * B_Y)] = + tex1Dfetch(hidActs, hidActsOffset + idx); } - for (int my = mStartY; my < mEndY; my++) { for (int mx = mStartX; mx < mEndX; mx++) { int myNext = my, mxNext = mx; @@ -739,42 +829,46 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cu conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( myNext, mxNext, paddingStart, numModulesX, moduleStride, - blockPixelY, blockPixelX, imgSizeX, imgStride, - pixIdxNext, mNext); + blockPixelY, blockPixelX, imgSizeX, imgStride, pixIdxNext, mNext); for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { - - #pragma unroll - for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - shImgLoad[(y) * preloadCases] = imPreload[y * preloadCases / (B_X * B_Y)]; +#pragma unroll + for (int y = 0; y < B_Y * colorsPerThread; + y += (B_X * B_Y) / preloadCases) { + shImgLoad[(y)*preloadCases] = + imPreload[y * preloadCases / (B_X * B_Y)]; } -// const float* im = &images[caseIdx + preloadCases + pixIdx]; -// const float* ha = &hidActs[caseIdx + preloadCases + m * numImages]; + // const float* im = &images[caseIdx + preloadCases + + // pixIdx]; const float* ha = &hidActs[caseIdx + + // preloadCases + m * numImages]; int imgOffset2 = imgOffset + caseIdx + preloadCases + pixIdx; - int hidActsOffset2 = hidActsOffset + caseIdx + preloadCases + m * numImages; + int hidActsOffset2 = + hidActsOffset + caseIdx + preloadCases + m * numImages; if (caseIdx + preloadCases == numImages) { pixIdx = pixIdxNext; m = mNext; imgOffset2 = imgOffset + pixIdxNext; hidActsOffset2 = hidActsOffset + mNext * numImages; } - #pragma unroll - for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - shHidActLoad[y * (preloadCases + 1)] = haPreload[y * preloadCases / (B_X * B_Y)]; +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + shHidActLoad[y * (preloadCases + 1)] = + haPreload[y * preloadCases / (B_X * B_Y)]; } __syncthreads(); - #pragma unroll +#pragma unroll for (int z = 0; z < 8; ++z) { WA_IMLOAD_TX(z); WA_LOOP2(z); } - #pragma unroll +#pragma unroll for (int z = 0; z < 8; ++z) { WA_HALOAD_TX(z); - WA_LOOP2(z+8); + WA_LOOP2(z + 8); } __syncthreads(); } @@ -782,19 +876,23 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cu } if (scale) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleTargets * targets[c * B_Y * filterPixels * numFilters + f * B_X] + scaleOutputs * prod[f][c]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleTargets * + targets[c * B_Y * filterPixels * numFilters + f * B_X] + + scaleOutputs * prod[f][c]; } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[f][c]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleOutputs * prod[f][c]; } } } @@ -804,21 
+902,28 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16(cu * images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) * - * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, numFilters) + * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, + * numFilters) */ -template -__launch_bounds__(256, 2) -__global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int numImgColors, const int numGroups, const int sumWidth, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[colorsPerThread * B_Y][preloadCases]; // preload preloadCases cases - __shared__ float shHidActs[filtersPerThread * B_X][preloadCases + 1]; // preload preloadCases cases of B_X hidacts - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int filtersPerThread, int colorsPerThread, int preloadCases, + bool scale> +__launch_bounds__(256, 2) __global__ + void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32( + cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, + const int numImages, const int numFilters, const int numModulesY, + const int numModulesX, const int imgSizeY, const int imgSizeX, + const int filterSize, const int paddingStart, const int moduleStride, + const int imgStride, const int numImgColors, const int numGroups, + const int sumWidth, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shImages[colorsPerThread * B_Y] + [preloadCases]; // preload preloadCases cases + __shared__ float + shHidActs[filtersPerThread * B_X] + [preloadCases + 1]; // preload preloadCases cases of B_X hidacts + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -831,7 +936,7 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cu const int blockModuleChunkIdx = blockIdx.x / numFilterBlocks; const int numModuleChunksX = DIVUP(numModulesX, sumWidth); -// const int numModuleChunksY = DIVUP(numModulesY, sumWidth); + // const int numModuleChunksY = DIVUP(numModulesY, sumWidth); const int blockModuleChunkX = blockModuleChunkIdx % numModuleChunksX; const int blockModuleChunkY = blockModuleChunkIdx / numModuleChunksX; @@ -839,7 +944,7 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cu const int blockModuleStartX = blockModuleChunkX * sumWidth; const int blockModuleStartY = blockModuleChunkY * sumWidth; -// const int moduleIdx = partialSum * outputModuleIdx; + // const int moduleIdx = partialSum * outputModuleIdx; const int blockFilterIdx = filtersPerThread * B_X * (blockIdx.x % numFilterBlocks); const int numModules = numModulesY * numModulesX; @@ -847,50 +952,56 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cu const int blockGroupIdx = blockFilterIdx / numFiltersPerGroup; const int numFilterColors = 
numImgColors / numGroups; - const int blockPixelOffset = blockIdx.z; // pixel idx in filter - const int blockPixelY = blockPixelOffset / filterSize, blockPixelX = blockPixelOffset % filterSize; - const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; + const int blockPixelOffset = blockIdx.z; // pixel idx in filter + const int blockPixelY = blockPixelOffset / filterSize, + blockPixelX = blockPixelOffset % filterSize; + const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; const int imgColorIdx = blockFilterColorIdx + blockGroupIdx * numFilterColors; const int imgOffset = (imgColorIdx + loadY) * imgPixels * imgStride + loadX; - const int hidActsOffset = blockFilterIdx * numImages * numModules - + loadY * numImages * numModules - + loadX; -// images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; -// -// hidActs += -// blockFilterIdx * numImages * numModules -// + loadY * numImages * numModules -// + loadX; - - targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors - + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.x; -// if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; - - const int mStartX = max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); - const int mStartY = max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); - const int mEndX = min(numModulesX, min(blockModuleStartX + sumWidth, DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); - const int mEndY = min(numModulesY, min(blockModuleStartY + sumWidth, DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); - -// if (mStartY == mEndY || mStartX == mEndX) { -// return; -// } + const int hidActsOffset = blockFilterIdx * numImages * numModules + + loadY * numImages * numModules + loadX; + // images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; + // + // hidActs += + // blockFilterIdx * numImages * numModules + // + loadY * numImages * numModules + // + loadX; + + targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors + + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters + + blockPixelOffset * numFilters + blockFilterIdx + threadIdx.x; + // if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; + + const int mStartX = + max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); + const int mStartY = + max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); + const int mEndX = + min(numModulesX, + min(blockModuleStartX + sumWidth, + DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); + const int mEndY = + min(numModulesY, + min(blockModuleStartY + sumWidth, + DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); + + // if (mStartY == mEndY || mStartX == mEndX) { + // return; + // } const bool doWork = mStartY < mEndY && mStartX < mEndX; float* shHidActLoad = &shHidActs[loadY][loadX]; float* shImgLoad = &shImages[loadY][loadX]; - float imPreload[preloadCases*colorsPerThread/B_X]; // [6] - float haPreload[preloadCases*filtersPerThread/B_Y]; // [16] + float imPreload[preloadCases * colorsPerThread / B_X]; // [6] + float haPreload[preloadCases * filtersPerThread / B_Y]; // [16] float prod[filtersPerThread][colorsPerThread]; - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { prod[f][c] = 0; } @@ -898,24 +1009,26 @@ 
__global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cu int pixIdx, pixIdxNext, m, mNext; conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( - mStartY, mStartX, paddingStart, numModulesX, moduleStride, - blockPixelY, blockPixelX, imgSizeX, imgStride, - pixIdx, m); + mStartY, mStartX, paddingStart, numModulesX, moduleStride, blockPixelY, + blockPixelX, imgSizeX, imgStride, pixIdx, m); if (doWork) { - #pragma unroll +#pragma unroll for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - imPreload[y * preloadCases/(B_X * B_Y)] = tex1Dfetch(images, imgOffset + y * imgPixels * imgStride + pixIdx); + imPreload[y * preloadCases / (B_X * B_Y)] = tex1Dfetch( + images, imgOffset + y * imgPixels * imgStride + pixIdx); } - #pragma unroll +#pragma unroll for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - haPreload[y * preloadCases / (B_X * B_Y)] = tex1Dfetch(hidActs, hidActsOffset + y * numImages * numModules + m * numImages); + haPreload[y * preloadCases / (B_X * B_Y)] = tex1Dfetch( + hidActs, + hidActsOffset + y * numImages * numModules + m * numImages); } } -// if (mStartY > mEndY || mStartX > mEndX) { -// printf("crzy!!\n"); -// } + // if (mStartY > mEndY || mStartX > mEndX) { + // printf("crzy!!\n"); + // } for (int my = mStartY; my < mEndY; my++) { for (int mx = mStartX; mx < mEndX; mx++) { @@ -929,26 +1042,31 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cu conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( myNext, mxNext, paddingStart, numModulesX, moduleStride, - blockPixelY, blockPixelX, imgSizeX, imgStride, - pixIdxNext, mNext); + blockPixelY, blockPixelX, imgSizeX, imgStride, pixIdxNext, mNext); for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { - #pragma unroll - for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - shImgLoad[(y) * preloadCases] = imPreload[y * preloadCases / (B_X * B_Y)]; +#pragma unroll + for (int y = 0; y < B_Y * colorsPerThread; + y += (B_X * B_Y) / preloadCases) { + shImgLoad[(y)*preloadCases] = + imPreload[y * preloadCases / (B_X * B_Y)]; } - #pragma unroll - for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - shHidActLoad[y * (preloadCases + 1)] = haPreload[y * preloadCases / (B_X * B_Y)]; +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + shHidActLoad[y * (preloadCases + 1)] = + haPreload[y * preloadCases / (B_X * B_Y)]; } __syncthreads(); -// const float* im = &images[caseIdx + preloadCases + pixIdx]; -// const float* ha = &hidActs[caseIdx + preloadCases + m * numImages]; + // const float* im = &images[caseIdx + preloadCases + + // pixIdx]; const float* ha = &hidActs[caseIdx + + // preloadCases + m * numImages]; int imgOffset2 = imgOffset + caseIdx + preloadCases + pixIdx; - int hidActsOffset2 = hidActsOffset + caseIdx + preloadCases + m * numImages; + int hidActsOffset2 = + hidActsOffset + caseIdx + preloadCases + m * numImages; if (caseIdx + preloadCases == numImages) { pixIdx = pixIdxNext; m = mNext; @@ -1020,19 +1138,23 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cu } if (scale) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleTargets * targets[c * B_Y * filterPixels * 
numFilters + f * B_X] + scaleOutputs * prod[f][c]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleTargets * + targets[c * B_Y * filterPixels * numFilters + f * B_X] + + scaleOutputs * prod[f][c]; } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[f][c]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleOutputs * prod[f][c]; } } } @@ -1042,21 +1164,28 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32(cu * images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) * - * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, numFilters) + * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, + * numFilters) */ -template -__launch_bounds__(256, 2) -__global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16(cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int numImgColors, const int numGroups, const int sumWidth, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[colorsPerThread * B_Y][preloadCases]; // preload preloadCases cases - __shared__ float shHidActs[filtersPerThread * B_X][preloadCases + 1]; // preload preloadCases cases of B_X hidacts - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int filtersPerThread, int colorsPerThread, int preloadCases, + bool scale> +__launch_bounds__(256, 2) __global__ + void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16( + cudaTextureObject_t images, cudaTextureObject_t hidActs, float* targets, + const int numImages, const int numFilters, const int numModulesY, + const int numModulesX, const int imgSizeY, const int imgSizeX, + const int filterSize, const int paddingStart, const int moduleStride, + const int imgStride, const int numImgColors, const int numGroups, + const int sumWidth, const float scaleTargets, + const float scaleOutputs) { + __shared__ float shImages[colorsPerThread * B_Y] + [preloadCases]; // preload preloadCases cases + __shared__ float + shHidActs[filtersPerThread * B_X] + [preloadCases + 1]; // preload preloadCases cases of B_X hidacts + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -1069,7 +1198,7 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16(cu const int blockModuleChunkIdx = blockIdx.x / numFilterBlocks; const int numModuleChunksX = DIVUP(numModulesX, sumWidth); -// const int numModuleChunksY = DIVUP(numModulesY, sumWidth); + // const int numModuleChunksY = DIVUP(numModulesY, sumWidth); const int blockModuleChunkX = blockModuleChunkIdx % numModuleChunksX; const int blockModuleChunkY = blockModuleChunkIdx / numModuleChunksX; @@ -1077,7 +1206,7 @@ __global__ void 
conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16(cu const int blockModuleStartX = blockModuleChunkX * sumWidth; const int blockModuleStartY = blockModuleChunkY * sumWidth; -// const int moduleIdx = partialSum * outputModuleIdx; + // const int moduleIdx = partialSum * outputModuleIdx; const int blockFilterIdx = filtersPerThread * B_X * (blockIdx.x % numFilterBlocks); const int numModules = numModulesY * numModulesX; @@ -1085,46 +1214,52 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16(cu const int blockGroupIdx = blockFilterIdx / numFiltersPerGroup; const int numFilterColors = numImgColors / numGroups; - const int blockPixelOffset = blockIdx.z; // pixel idx in filter - const int blockPixelY = blockPixelOffset / filterSize, blockPixelX = blockPixelOffset % filterSize; - const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; + const int blockPixelOffset = blockIdx.z; // pixel idx in filter + const int blockPixelY = blockPixelOffset / filterSize, + blockPixelX = blockPixelOffset % filterSize; + const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; const int imgColorIdx = blockFilterColorIdx + blockGroupIdx * numFilterColors; const int imgOffset = (imgColorIdx + loadY) * imgPixels * imgStride + loadX; -// images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; - const int hidActsOffset = blockFilterIdx * numImages * numModules - + loadY * numImages * numModules - + loadX; -// -// hidActs += -// blockFilterIdx * numImages * numModules -// + loadY * numImages * numModules -// + loadX; - - targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors - + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.x; -// if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; - - const int mStartX = max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); - const int mStartY = max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); - const int mEndX = min(numModulesX, min(blockModuleStartX + sumWidth, DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); - const int mEndY = min(numModulesY, min(blockModuleStartY + sumWidth, DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); + // images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; + const int hidActsOffset = blockFilterIdx * numImages * numModules + + loadY * numImages * numModules + loadX; + // + // hidActs += + // blockFilterIdx * numImages * numModules + // + loadY * numImages * numModules + // + loadX; + + targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors + + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters + + blockPixelOffset * numFilters + blockFilterIdx + threadIdx.x; + // if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; + + const int mStartX = + max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); + const int mStartY = + max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); + const int mEndX = + min(numModulesX, + min(blockModuleStartX + sumWidth, + DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); + const int mEndY = + min(numModulesY, + min(blockModuleStartY + sumWidth, + DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); const bool doWork = mStartY < mEndY && mStartX < mEndX; float* shHidActLoad = &shHidActs[loadY][loadX]; float* shImgLoad = &shImages[loadY][loadX]; - float 
imPreload[preloadCases*colorsPerThread/B_X]; // [4] - float haPreload[preloadCases*filtersPerThread/B_Y]; // [8] + float imPreload[preloadCases * colorsPerThread / B_X]; // [4] + float haPreload[preloadCases * filtersPerThread / B_Y]; // [8] float prod[filtersPerThread][colorsPerThread]; - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { prod[f][c] = 0; } @@ -1132,21 +1267,23 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16(cu int pixIdx, pixIdxNext, m, mNext; conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( - mStartY, mStartX, paddingStart, numModulesX, moduleStride, - blockPixelY, blockPixelX, imgSizeX, imgStride, - pixIdx, m); + mStartY, mStartX, paddingStart, numModulesX, moduleStride, blockPixelY, + blockPixelX, imgSizeX, imgStride, pixIdx, m); if (doWork && loadY < B_Y * colorsPerThread) { - #pragma unroll +#pragma unroll for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - imPreload[y * preloadCases/(B_X * B_Y)] = tex1Dfetch(images, imgOffset + y * imgPixels * imgStride + pixIdx); + imPreload[y * preloadCases / (B_X * B_Y)] = tex1Dfetch( + images, imgOffset + y * imgPixels * imgStride + pixIdx); } } if (doWork && loadY < B_X * filtersPerThread) { - #pragma unroll +#pragma unroll for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - haPreload[y * preloadCases / (B_X * B_Y)] = tex1Dfetch(hidActs, hidActsOffset + y * numImages * numModules + m * numImages); + haPreload[y * preloadCases / (B_X * B_Y)] = tex1Dfetch( + hidActs, + hidActsOffset + y * numImages * numModules + m * numImages); } } @@ -1162,37 +1299,42 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16(cu conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16_setCoords( myNext, mxNext, paddingStart, numModulesX, moduleStride, - blockPixelY, blockPixelX, imgSizeX, imgStride, - pixIdxNext, mNext); + blockPixelY, blockPixelX, imgSizeX, imgStride, pixIdxNext, mNext); for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { - -// const float* im = &images[caseIdx + preloadCases + pixIdx]; + // const float* im = &images[caseIdx + preloadCases + + // pixIdx]; int imgOffset2 = imgOffset + caseIdx + preloadCases + pixIdx; - int hidActsOffset2 = hidActsOffset + caseIdx + preloadCases + m * numImages; -// const float* ha = &hidActs[caseIdx + preloadCases + m * numImages]; + int hidActsOffset2 = + hidActsOffset + caseIdx + preloadCases + m * numImages; + // const float* ha = &hidActs[caseIdx + preloadCases + m + // * numImages]; if (caseIdx + preloadCases == numImages) { pixIdx = pixIdxNext; m = mNext; -// im = &images[pixIdxNext]; + // im = &images[pixIdxNext]; imgOffset2 = imgOffset + pixIdxNext; hidActsOffset2 = hidActsOffset + mNext * numImages; -// ha = &hidActs[mNext * numImages]; + // ha = &hidActs[mNext * numImages]; } if (loadY < B_Y * colorsPerThread) { - #pragma unroll - for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - shImgLoad[(y) * preloadCases] = imPreload[y * preloadCases / (B_X * B_Y)]; +#pragma unroll + for (int y = 0; y < B_Y * colorsPerThread; + y += (B_X * B_Y) / preloadCases) { + shImgLoad[(y)*preloadCases] = + imPreload[y * preloadCases / (B_X * B_Y)]; } } if (loadY < B_X * filtersPerThread) { - #pragma unroll - for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - shHidActLoad[y * (preloadCases + 
1)] = haPreload[y * preloadCases / (B_X * B_Y)]; +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + shHidActLoad[y * (preloadCases + 1)] = + haPreload[y * preloadCases / (B_X * B_Y)]; } } @@ -1233,49 +1375,59 @@ __global__ void conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16(cu } if (scale) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleTargets * targets[c * B_Y * filterPixels * numFilters + f * B_X] + scaleOutputs * prod[f][c]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleTargets * + targets[c * B_Y * filterPixels * numFilters + f * B_X] + + scaleOutputs * prod[f][c]; } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[f][c]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleOutputs * prod[f][c]; } } } } -std::pair<int, int> getWeightActsOutputSize(int numModulesY, int numModulesX, int numFilterColors, - int filterSize, int numFilters, int sumWidth) { +std::pair<int, int> getWeightActsOutputSize( + int numModulesY, int numModulesX, int numFilterColors, int filterSize, + int numFilters, int sumWidth) { const int outputModuleChunksX = DIVUP(numModulesX, sumWidth); const int outputModuleChunksY = DIVUP(numModulesY, sumWidth); const int outputModuleChunks = outputModuleChunksX * outputModuleChunksY; - return std::pair<int, int>(outputModuleChunks * numFilterColors * filterSize * filterSize, numFilters); + return std::pair<int, int>( + outputModuleChunks * numFilterColors * filterSize * filterSize, numFilters); } /* * images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModules, numImages) * - * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, numFilters) + * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, + * numFilters) * - * TODO: you can get a slight speed boost for local non-convolutional units by writing special - * routines for partialSum = 1. But I dunno if the code duplication is worth it... + * TODO: you can get a slight speed boost for local non-convolutional units by writing + * special routines for partialSum = 1. But I dunno if the code duplication is worth + * it... * * Note: all of these convolution routines are optimized for the case when * the number of images (i.e. the minibatch size) is a multiple of 128. * Other batch sizes will work, but I made no attempt whatsoever * to make them work fast.
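 *
 * As a concrete illustration (the sizes below are picked only for this example and do
 * not come from any caller): with numModulesY = numModulesX = 12 and sumWidth = 4
 * there are DIVUP(12, 4) * DIVUP(12, 4) = 9 module chunks, so for numFilterColors = 64,
 * filterSize = 3 and numFilters = 128, getWeightActsOutputSize() above returns
 * (9 * 64 * 3 * 3, 128) = (5184, 128), i.e. one (numFilterColors, filterPixels,
 * numFilters) block of partial results per module chunk.
 *
 * previous_limit, computed in _weightActs below, records whether the divisibility
 * constraints from the now commented-out asserts (e.g. numFilters % (16 * numGroups) == 0
 * and numImgColors % 16 == 0) still hold; only in that case are the texture-object
 * preload kernels launched, otherwise the more general *_sw kernels are dispatched
 * with raw device pointers.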
*/ -void _weightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, int moduleStride, int numImgColors, - int numGroups, int sumWidth, float scaleTargets, float scaleOutput) { +void _weightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups, + int sumWidth, float scaleTargets, float scaleOutput) { int numFilterColors = numImgColors / numGroups; int imgStride = images.getStride(); int numImages = images.getNumCols(); @@ -1286,10 +1438,12 @@ void _weightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMat int numFiltersPerGroup = numFilters / numGroups; megdnn_assert_internal(numImgColors % numGroups == 0); - //megdnn_assert_internal(numFilters % (16*numGroups) == 0); - bool previous_limit = numFilters % (16*numGroups) == 0; + // megdnn_assert_internal(numFilters % (16*numGroups) == 0); + bool previous_limit = numFilters % (16 * numGroups) == 0; - megdnn_assert_internal(numGroups > 1 || (numImgColors > 0 /*&& (numImgColors <= 3 || numImgColors % 16 == 0)*/)); + megdnn_assert_internal( + numGroups > 1 || + (numImgColors > 0 /*&& (numImgColors <= 3 || numImgColors % 16 == 0)*/)); previous_limit &= numImgColors % 16 == 0; megdnn_assert_internal(numGroups == 1 || numFilterColors % 16 == 0); @@ -1300,15 +1454,18 @@ void _weightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMat int outputModuleChunksX = DIVUP(numModulesX, sumWidth); int outputModuleChunksY = DIVUP(numModulesY, sumWidth); int outputModuleChunks = outputModuleChunksX * outputModuleChunksY; -// partialSum = partialSum == 0 ? numModules : partialSum; + // partialSum = partialSum == 0 ? numModules : partialSum; -// megdnn_assert_internal(numModules % partialSum == 0); + // megdnn_assert_internal(numModules % partialSum == 0); megdnn_assert_internal(hidActs.getNumCols() == numImages); - // These routines don't handle the case when only part of the image is visited in the convolution + // These routines don't handle the case when only part of the image is visited in + // the convolution megdnn_assert_internal(paddingStart <= 0); - megdnn_assert_internal(paddingStart + (numModulesX-1)*moduleStride + filterSize >= imgSizeX); - megdnn_assert_internal(paddingStart + (numModulesY-1)*moduleStride + filterSize >= imgSizeY); + megdnn_assert_internal( + paddingStart + (numModulesX - 1) * moduleStride + filterSize >= imgSizeX); + megdnn_assert_internal( + paddingStart + (numModulesY - 1) * moduleStride + filterSize >= imgSizeY); megdnn_assert_internal(moduleStride <= filterSize); megdnn_assert_internal(numModules * numFilters == hidActs.getNumRows()); @@ -1329,24 +1486,29 @@ void _weightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMat // These values work relatively well, but not optimal for all problems. if (numFilterColors > 3) { filtersPerThread = numFiltersPerGroup % 64 == 0 ? 4 - : numFiltersPerGroup % 32 == 0 ? 2 - : 1; + : numFiltersPerGroup % 32 == 0 ? 2 + : 1; colorsPerThread = numFilterColors % 64 == 0 ? 8 : numFilterColors % 48 == 0 ? 6 : numFilterColors % 32 == 0 ? 8 - : 4; + : 4; by = (numFilterColors / colorsPerThread) % 8 == 0 ? 8 : 4; bx = numFiltersPerGroup % 128 == 0 ? 32 : 16; preloadCases = filtersPerThread * colorsPerThread < 32 ? 
32 : 16; - blocks = dim3(outputModuleChunks * DIVUP(numFilters,bx*filtersPerThread), DIVUP(numFilterColors, (by*colorsPerThread)), filterPixels); - - //megdnn_assert_internal(numFilterColors % (by*colorsPerThread) == 0); - previous_limit &= numFilterColors % (by*colorsPerThread) == 0; - - } else { // This is ugly but it's nice to spell it out clearly - megdnn_assert_internal(numGroups == 1); // Just for sanity - // NOTE: these things are only optimized for colors = 3. I didn't really test other cases. - if (numFilters % 64 == 0) { // TODO: having a separate case for 128 would make things faster, but I probably don't care about 128 + blocks = + dim3(outputModuleChunks * DIVUP(numFilters, bx * filtersPerThread), + DIVUP(numFilterColors, (by * colorsPerThread)), filterPixels); + + // megdnn_assert_internal(numFilterColors % (by*colorsPerThread) == 0); + previous_limit &= numFilterColors % (by * colorsPerThread) == 0; + + } else { // This is ugly but it's nice to spell it out clearly + megdnn_assert_internal(numGroups == 1); // Just for sanity + // NOTE: these things are only optimized for colors = 3. I didn't really test + // other cases. + if (numFilters % 64 == + 0) { // TODO: having a separate case for 128 would make things faster, but + // I probably don't care about 128 filtersPerThread = 4; pixelsPerThread = 2; by = 16; @@ -1364,23 +1526,28 @@ void _weightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMat by = 8; bx = 16; preloadCases = 16; - } else { // This case is completely untested. It might be really slow. But no time now. + } else { // This case is completely untested. It might be really slow. But no + // time now. filtersPerThread = 1; pixelsPerThread = 16; by = 16; bx = 16; preloadCases = 32; } - blocks = dim3(outputModuleChunks * DIVUP(numFilters,bx*filtersPerThread), DIVUP(filterPixels, by*pixelsPerThread)); + blocks = + dim3(outputModuleChunks * DIVUP(numFilters, bx * filtersPerThread), + DIVUP(filterPixels, by * pixelsPerThread)); } megdnn_assert_internal((by * bx) % preloadCases == 0); - //megdnn_assert_internal(numFilters % (bx * filtersPerThread) == 0); + // megdnn_assert_internal(numFilters % (bx * filtersPerThread) == 0); previous_limit &= numFilters % (bx * filtersPerThread) == 0; threads = dim3(bx, by); bool checkCaseBounds = numImages % preloadCases != 0; bool scale = scaleTargets != 0; - std::pair targetSize = getWeightActsOutputSize(numModulesY, numModulesX, numFilterColors, filterSize, numFilters, sumWidth); + std::pair targetSize = getWeightActsOutputSize( + numModulesY, numModulesX, numFilterColors, filterSize, numFilters, + sumWidth); if (!scale) { targets.resize(targetSize.first, targetSize.second); } else { @@ -1390,288 +1557,860 @@ void _weightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMat if (scale == false) { if (checkCaseBounds == false) { - if (numFilterColors > 3) { + if (numFilterColors > 3) { if (numFilterColors % 64 == 0) { if (numFiltersPerGroup % 128 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16< 8, 32, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16< 8, 32, 4, 8, 16, false ><<>>(images.getTextureObject(), hidActs.getTextureObject(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + 
conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16< + 8, 32, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_8_r_16< + 8, 32, 4, 8, 16, false> + <<>>( + images.getTextureObject(), + hidActs.getTextureObject(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, numImgColors, + numGroups, sumWidth, scaleTargets, + scaleOutput); } else { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16< 8, 32, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16< + 8, 32, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 2, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, numImgColors, + numGroups, sumWidth, scaleTargets, + scaleOutput); } - } - else if (numFiltersPerGroup % 64 == 0) { + } else if (numFiltersPerGroup % 64 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16< 8, 16, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16< 8, 16, 4, 8, 16, false ><<>>(images.getTextureObject(), hidActs.getTextureObject(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16< + 8, 16, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_16_f_4_c_8_r_16< + 8, 16, 4, 8, 16, false> + <<>>( + images.getTextureObject(), + hidActs.getTextureObject(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, numImgColors, + numGroups, sumWidth, scaleTargets, + scaleOutput); } else { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 8, 16, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 4, 8, 16, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, numImgColors, + numGroups, sumWidth, scaleTargets, + scaleOutput); } - } - else if (numFiltersPerGroup % 32 == 0) { - 
cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors % 48 == 0) { + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 2, 8, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 2, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 1, 8, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 1, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } + } else if (numFilterColors % 48 == 0) { if (numFiltersPerGroup % 128 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32< 8, 32, 4, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32< 8, 32, 4, 6, 32, false ><<>>(images.getTextureObject(), hidActs.getTextureObject(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32< + 8, 32, 4, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_preload_ty_8_tx_32_f_4_c_6_r_32< + 8, 32, 4, 6, 32, false> + <<>>( + images.getTextureObject(), + hidActs.getTextureObject(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, numImgColors, + numGroups, sumWidth, scaleTargets, + scaleOutput); } else { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 32, 4, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 32, 4, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 
8, 32, 4, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 32, 4, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, numImgColors, + numGroups, sumWidth, scaleTargets, + scaleOutput); } - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors % 32 == 0) { + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 4, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 4, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 2, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 2, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 1, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 1, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } + } else if (numFilterColors % 32 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 8, 16, false >, cudaFuncCachePreferShared); - 
conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 8, 16, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 8, 16, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 32, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 32, 4, 8, 16, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 4, 8, 16, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 2, 8, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 2, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 1, 8, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 1, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, 
numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } + } else if (numFilterColors % 1 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 32, 4, 4, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 32, 4, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 4, 4, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 4, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 2, 4, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 2, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 
0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 1, 4, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 1, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); } } - } - else if (numFilterColors <= 3) { + } else if (numFilterColors <= 3) { if (numFilterColors == 3) { if (numFiltersPerGroup % 64 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3 < 16, 16, 2, 2, 4, 32, 3, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3 < 16, 16, 2, 2, 4, 32, 3, false, false ><<>>(images.getTextureObject(), hidActs.getTextureObject(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3< + 16, 16, 2, 2, 4, 32, 3, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_preload_pc_2_pt_2_f_4_r_32_c_3< + 16, 16, 2, 2, 4, 32, 3, false, false> + <<>>( + images.getTextureObject(), + hidActs.getTextureObject(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, sumWidth, + scaleTargets, scaleOutput); } else { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 3, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 3, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 3, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 3, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, sumWidth, + scaleTargets, scaleOutput); } - } - else if (numFiltersPerGroup % 48 == 0) { + } else if (numFiltersPerGroup % 48 == 0) { if (previous_limit) { - cudaFuncSetCacheConfig(conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3 < 16, 16, 2, 4, 3, 32, 3, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3 < 16, 16, 2, 4, 3, 32, 3, false, false ><<>>(images.getTextureObject(), hidActs.getTextureObject(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3< + 16, 16, 2, 4, 3, 32, 3, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_preload_pc_2_pt_4_f_3_r_32_c_3< + 16, 16, 2, 4, 3, 32, 3, false, false> + <<>>( + images.getTextureObject(), + hidActs.getTextureObject(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, sumWidth, + scaleTargets, 
scaleOutput); } else { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 3, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw< 16, 16, 2, 4, 3, 32, 3, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 3, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 3, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, + imgSizeX, filterSize, paddingStart, + moduleStride, imgStride, sumWidth, + scaleTargets, scaleOutput); } - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 3, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 3, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 3, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 3, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors == 2) { + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 3, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 3, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 3, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 3, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } + } else if (numFilterColors == 2) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 2, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 2, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 2, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 
16, 2, 4, 3, 32, 2, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 2, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 2, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 2, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 2, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors == 1) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 2, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 2, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 2, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 2, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 2, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 2, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 2, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 2, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } + } else if (numFilterColors == 1) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 1, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 1, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, 
imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 1, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 1, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 1, false, false >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 1, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 1, false, false >,cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 1, false, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 1, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 1, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 1, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 1, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 1, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 1, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 1, false, false>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 1, false, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); } } } - } - else if (checkCaseBounds == true) { + } else if (checkCaseBounds == true) { if (numFilterColors > 3) { if (numFilterColors % 64 == 0) { 
if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 32, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 32, 4, 8, 16, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 8, 16, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors % 48 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 32, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 32, 4, 8, 16, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 4, 8, 16, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 2, 8, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 2, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 1, 8, 32, false>, + cudaFuncCachePreferShared); + 
conv_weight_acts_mc_mf_kepler_sw<8, 16, 1, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } + } else if (numFilterColors % 48 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 32, 4, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 32, 4, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 4, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 2, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 6, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 8, 16, 1, 6, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 32, 4, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 32, 4, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 4, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 4, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 2, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 2, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), 
numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 8, 16, 1, 6, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<8, 16, 1, 6, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } + } else if (numFilterColors % 32 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 8, 16, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 8, 16, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 8, 16, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 8, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 8, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 32, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 32, 4, 8, 16, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 4, 8, 16, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 4, 8, 16, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, 
numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 2, 8, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 2, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 1, 8, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 1, 8, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } + } else if (numFilterColors % 1 == 0) { if (numFiltersPerGroup % 128 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 32, 4, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 4, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 2, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 4, 32, false >, cudaFuncCachePreferShared); - conv_weight_acts_mc_mf_kepler_sw < 4, 16, 1, 4, 32, false ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, numImgColors, numGroups, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 32, 4, 4, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 32, 4, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 64 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 4, 4, 32, false>, + 
cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 4, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 2, 4, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 2, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_mc_mf_kepler_sw< + 4, 16, 1, 4, 32, false>, + cudaFuncCachePreferShared); + conv_weight_acts_mc_mf_kepler_sw<4, 16, 1, 4, 32, false> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, numImgColors, numGroups, sumWidth, + scaleTargets, scaleOutput); } } - } - else if (numFilterColors <= 3) { + } else if (numFilterColors <= 3) { if (numFilterColors == 3) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 3, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 3, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 3, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 3, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 3, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 3, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 3, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 3, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors == 2) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 3, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 3, false, true> + <<>>( 
+ images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 3, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 3, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 3, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw<8, 16, 2, 2, 2, 16, 3, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 3, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 3, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } + } else if (numFilterColors == 2) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 2, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 2, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 2, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 2, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 2, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 2, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 2, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 2, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, 
sumWidth, scaleTargets, scaleOutput); - } - } - else if (numFilterColors == 1) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 2, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 2, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 2, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 2, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 2, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw<8, 16, 2, 2, 2, 16, 2, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 2, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 2, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } + } else if (numFilterColors == 1) { if (numFiltersPerGroup % 64 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 1, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 2, 4, 32, 1, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 48 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 1, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 4, 3, 32, 1, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 32 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 1, false, true >, cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 8, 16, 2, 2, 2, 16, 1, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); - } - else if (numFiltersPerGroup % 1 == 0) { - cudaFuncSetCacheConfig(conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 1, false, true >, 
cudaFuncCachePreferShared); - conv_weight_acts_c_kepler_sw < 16, 16, 2, 16, 1, 32, 1, false, true ><<>>(images.getDevData(), hidActs.getDevData(), targets.getDevData(), numImages, numFilters, numModulesY, numModulesX, imgSizeY, imgSizeX, filterSize, paddingStart, moduleStride, imgStride, sumWidth, scaleTargets, scaleOutput); + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 1, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 2, 4, 32, 1, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 48 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 1, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 4, 3, 32, 1, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 32 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 8, 16, 2, 2, 2, 16, 1, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw<8, 16, 2, 2, 2, 16, 1, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); + } else if (numFiltersPerGroup % 1 == 0) { + cudaFuncSetCacheConfig( + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 1, false, true>, + cudaFuncCachePreferShared); + conv_weight_acts_c_kepler_sw< + 16, 16, 2, 16, 1, 32, 1, false, true> + <<>>( + images.getDevData(), hidActs.getDevData(), + targets.getDevData(), numImages, numFilters, + numModulesY, numModulesX, imgSizeY, imgSizeX, + filterSize, paddingStart, moduleStride, + imgStride, sumWidth, scaleTargets, scaleOutput); } } } @@ -1681,28 +2420,47 @@ void _weightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMat getLastCudaError("weightActs: kernel execution failed"); } -void convWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, int moduleStride, int numImgColors, int numGroups, int partialSum) { - _weightActs(stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, filterSize, paddingStart, moduleStride, numImgColors, numGroups, partialSum, 0, 1); +void convWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups, + int partialSum) { + _weightActs( + stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, + filterSize, paddingStart, moduleStride, numImgColors, numGroups, partialSum, + 0, 1); } -void convWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, int moduleStride, int numImgColors, int numGroups, int partialSum, - float scaleTargets, float scaleOutput) { - 
_weightActs(stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, filterSize, paddingStart, moduleStride, numImgColors, numGroups, partialSum, scaleTargets, scaleOutput); +void convWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups, + int partialSum, float scaleTargets, float scaleOutput) { + _weightActs( + stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, + filterSize, paddingStart, moduleStride, numImgColors, numGroups, partialSum, + scaleTargets, scaleOutput); } -void localWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, int moduleStride, int numImgColors, int numGroups) { - _weightActs(stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, filterSize, paddingStart, moduleStride, numImgColors, numGroups, 1, 0, 1); +void localWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups) { + _weightActs( + stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, + filterSize, paddingStart, moduleStride, numImgColors, numGroups, 1, 0, 1); } -void localWeightActs(cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, - int imgSizeY, int numModulesY, int numModulesX, int filterSize, int paddingStart, int moduleStride, - int numImgColors, int numGroups, float scaleTargets, float scaleOutput) { - _weightActs(stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, filterSize, paddingStart, moduleStride, numImgColors, numGroups, 1, - scaleTargets, scaleOutput); +void localWeightActs( + cudaStream_t stream, NVMatrix& images, NVMatrix& hidActs, NVMatrix& targets, + int imgSizeY, int numModulesY, int numModulesX, int filterSize, + int paddingStart, int moduleStride, int numImgColors, int numGroups, + float scaleTargets, float scaleOutput) { + _weightActs( + stream, images, hidActs, targets, imgSizeY, numModulesY, numModulesX, + filterSize, paddingStart, moduleStride, numImgColors, numGroups, 1, + scaleTargets, scaleOutput); } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ff.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ff.cu index 0b1f80de..ea2b2325 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ff.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ff.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ff.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ff.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. 
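The reformatted dispatch above, together with the convWeightActs / localWeightActs wrappers, repeats one pattern throughout: pick a conv_weight_acts_* template specialisation from the divisibility of numFiltersPerGroup (and numFilterColors), call cudaFuncSetCacheConfig on that exact specialisation so it prefers shared memory, then launch it on the caller's stream. The launch configuration itself is elided in this rendering of the diff, so the self-contained sketch below uses a toy kernel and assumed grid/block/stream values purely to show the shape of that dispatch; it is not the project's kernel.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Toy stand-in for conv_weight_acts_c_kepler_sw<...>; the real kernels take
    // template parameters such as B_Y, B_X, filtersPerThread and many more
    // runtime arguments than shown here.
    template <int B_Y, int B_X>
    __global__ void toy_weight_acts(const float* in, float* out, int n) {
        int i = blockIdx.x * (B_X * B_Y) + threadIdx.y * B_X + threadIdx.x;
        if (i < n)
            out[i] = 2.0f * in[i];
    }

    // Same dispatch shape as the code above: choose a specialisation by
    // divisibility, prefer shared memory for it, launch on the given stream.
    void dispatch_weight_acts(const float* in, float* out, int n, cudaStream_t stream) {
        if (n % 128 == 0) {
            dim3 threads(32, 4), blocks(n / 128);
            cudaFuncSetCacheConfig(toy_weight_acts<4, 32>, cudaFuncCachePreferShared);
            toy_weight_acts<4, 32><<<blocks, threads, 0, stream>>>(in, out, n);
        } else {
            dim3 threads(32, 2), blocks((n + 63) / 64);
            cudaFuncSetCacheConfig(toy_weight_acts<2, 32>, cudaFuncCachePreferShared);
            toy_weight_acts<2, 32><<<blocks, threads, 0, stream>>>(in, out, n);
        }
    }

    int main() {
        const int n = 1 << 20;
        float *in = nullptr, *out = nullptr;
        cudaMalloc(&in, n * sizeof(float));
        cudaMalloc(&out, n * sizeof(float));
        dispatch_weight_acts(in, out, n, /*stream=*/0);
        cudaDeviceSynchronize();
        printf("dispatch: %s\n", cudaGetErrorString(cudaGetLastError()));
        cudaFree(in);
        cudaFree(out);
        return 0;
    }

In the real code the grid and block sizes are computed earlier from numModules, numFilters and partialSum; only the selection, cache-config and launch shape is reproduced in this sketch.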
+ * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" @@ -33,10 +35,10 @@ namespace megdnn { namespace cuda { - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, false, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 16, 1, 32, 1, false, false>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, false, true > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ft.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ft.cu index 4764a67a..967aca89 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ft.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ft.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ft.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_1_ft.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" @@ -33,10 +35,10 @@ namespace megdnn { namespace cuda { - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, true > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, false, false > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 16, 1, 32, 1, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 1, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ff.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ff.cu index 0dfb712e..4d7820e7 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ff.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ff.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ff.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ff.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" @@ -33,15 +35,14 @@ namespace megdnn { namespace cuda { - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, false, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 16, 1, 32, 2, false, false>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, false, true > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, true > (C_KEP_SW_PARAM); +// instead of preload +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 2, 4, 32, 3, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 4, 3, 32, 3, false, false>(C_KEP_SW_PARAM); - // instead of preload - WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 2, 4, 32, 3, false, false> (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 4, 3, 32, 3, false, false> (C_KEP_SW_PARAM); - -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ft.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ft.cu index a12d2275..fbeffa2e 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ft.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ft.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ft.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_2_ft.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" @@ -33,10 +35,10 @@ namespace megdnn { namespace cuda { - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, true > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, false, false > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 16, 1, 32, 2, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 2, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ff.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ff.cu index eb841c36..df86b139 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ff.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ff.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ff.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ff.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" @@ -33,10 +35,10 @@ namespace megdnn { namespace cuda { - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, false, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 16, 1, 32, 3, false, false>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, false, true > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ft.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ft.cu index 6addd933..63154c71 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ft.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ft.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ft.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_c_3_ft.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -33,10 +34,10 @@ namespace megdnn { namespace cuda { - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, true > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, false, false > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 16, 1, 32, 3, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 16, 1, 32, 3, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_f4.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_f4.cu index a260aa87..7d3d573d 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_f4.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_f4.cu @@ -25,23 +25,24 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" namespace megdnn { namespace cuda { - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 1, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 1, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 1, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 1, true, true > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 2, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 2, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 2, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 2, true, true > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 3, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 3, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 2, 4, 32, 1, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 2, 4, 32, 1, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 1, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 1, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 2, 4, 32, 2, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 2, 4, 32, 2, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 2, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 2, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 2, 4, 32, 3, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 2, 4, 32, 3, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_pt_4.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_pt_4.cu index 0125b3a1..b746d404 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_pt_4.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_pt_4.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_pt_4.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_16_pt_4.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,23 +26,24 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" namespace megdnn { namespace cuda { - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 1, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 1, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 1, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 1, true, true > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 2, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 2, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 2, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 2, true, true > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 3, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 3, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 4, 3, 32, 1, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 4, 3, 32, 1, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 1, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 1, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 4, 3, 32, 2, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 4, 3, 32, 2, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 2, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 2, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<16, 16, 2, 4, 3, 32, 3, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 16, 16, 2, 4, 3, 32, 3, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_8.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_8.cu index 9a7f9b67..94e00dbc 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_8.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/weight_acts_c_kepler_sw_by_8.cu @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_c_kepler_sw.cuh" @@ -33,18 +34,18 @@ namespace megdnn { namespace cuda { - WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 1, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 1, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 1, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 1, true, true > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 2, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 2, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 2, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 2, true, true > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 3, false, false > (C_KEP_SW_PARAM); - WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 3, false, true > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 3, true, false > (C_KEP_SW_PARAM); - //WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 3, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<8, 16, 2, 2, 2, 16, 1, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<8, 16, 2, 2, 2, 16, 1, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 1, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 1, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<8, 16, 2, 2, 2, 16, 2, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<8, 16, 2, 2, 2, 16, 2, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 2, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 2, true, true > (C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<8, 16, 2, 2, 2, 16, 3, false, false>(C_KEP_SW_PARAM); +WET_ACT_C_KEPLER_SW_HEAD<8, 16, 2, 2, 2, 16, 3, false, true>(C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 3, true, false > (C_KEP_SW_PARAM); +// WET_ACT_C_KEPLER_SW_HEAD< 8, 16, 2, 2, 2, 16, 3, true, true > (C_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler.cuh b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler.cuh index 28987e91..1fd98fb4 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
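Each of the small weight_acts_c_kepler_sw_by_* files reformatted above holds a handful of WET_ACT_C_KEPLER_SW_HEAD<...>(C_KEP_SW_PARAM) lines, one per kernel specialisation, with the unneeded variants left commented out. Those macros are defined elsewhere in the tree (their expansion is not part of this diff), but they appear to act as explicit template instantiations, which is what lets every specialisation build in its own translation unit. A generic sketch of that layout with a hypothetical wa_kernel template, not the project's macros:

    //// wa_kernel.cuh: kernel template definition included by every
    //// instantiation file (hypothetical stand-in for wet_act_c_kepler_sw.cuh).
    template <int B_Y, int B_X, int numColors>
    __global__ void wa_kernel(const float* images, float* targets, int n) {
        int i = blockIdx.x * (B_X * B_Y) + threadIdx.y * B_X + threadIdx.x;
        if (i < n)
            targets[i] = static_cast<float>(numColors) * images[i];
    }

    //// wa_kernel_c1.cu: one translation unit per specialisation, so the
    //// instantiations compile in parallel and commented-out variants cost
    //// nothing at build time.
    #include "wa_kernel.cuh"
    template __global__ void wa_kernel<16, 16, 1>(const float*, float*, int);
    // template __global__ void wa_kernel<16, 16, 2>(const float*, float*, int);

Splitting one or two specialisations per .cu file keeps per-file compile times manageable for a kernel family with this many template combinations.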
* -------------------------------------------------------------------------- */ #include "wet_act_templates.cuh" @@ -38,11 +39,12 @@ namespace cuda { * threadIdx.x determines filter * threadIdx.y determines pixel in filter * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum - * blockIdx.y determines pixel batch of B_Y * pixelsPerThread + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + * partialSum blockIdx.y determines pixel batch of B_Y * pixelsPerThread * * Number of filters must be divisible by B_X * filtersPerThread - * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is false. + * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is + * false. * * images: (numColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) @@ -56,21 +58,27 @@ namespace cuda { * numModules must be divisible by partialSum * pixelsPerThread must be divisible by pixelCache * - * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread = 1)... - * so the compiler is messing up here somehow. It's unable to optimize that case away. + * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread + * = 1)... so the compiler is messing up here somehow. It's unable to optimize that case + * away. */ -template -__global__ void conv_weight_acts_c_kepler(float* images, float* hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int partialSum, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[pixelCache * B_Y * numColors][preloadCases]; // preload preloadCases cases of B_Y * pixelsPerThread pixels - __shared__ float shHidActs[B_X * filtersPerThread][preloadCases + 1]; // preload preloadCases cases of B_X hidActs - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int pixelCache, int pixelsPerThread, int filtersPerThread, + int preloadCases, int numColors, bool scale, bool checkCaseBounds> +__global__ void conv_weight_acts_c_kepler( + float* images, float* hidActs, float* targets, const int numImages, + const int numFilters, const int numModulesY, const int numModulesX, + const int imgSizeY, const int imgSizeX, const int filterSize, + const int paddingStart, const int moduleStride, const int imgStride, + const int partialSum, const float scaleTargets, const float scaleOutputs) { + __shared__ float shImages[pixelCache * B_Y * numColors] + [preloadCases]; // preload preloadCases cases of B_Y * + // pixelsPerThread pixels + __shared__ float + shHidActs[B_X * filtersPerThread] + [preloadCases + 1]; // preload preloadCases cases of B_X hidActs + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -79,44 +87,42 @@ __global__ void conv_weight_acts_c_kepler(float* images, float* hidActs, float* const int filterPixels = filterSize * filterSize; const int imgPixels = imgSizeY * imgSizeX; - const int filterBlocksPerModule = numFilters / 
(B_X*filtersPerThread); + const int filterBlocksPerModule = numFilters / (B_X * filtersPerThread); const int outputModuleIdx = blockIdx.x / filterBlocksPerModule; const int moduleIdx = partialSum * outputModuleIdx; - const int blockFilterIdx = B_X * filtersPerThread* (blockIdx.x % filterBlocksPerModule); + const int blockFilterIdx = + B_X * filtersPerThread * (blockIdx.x % filterBlocksPerModule); -// const int moduleStride = (imgSize - filterSize + 1) / numModulesX; + // const int moduleStride = (imgSize - filterSize + 1) / numModulesX; const int numModules = numModulesY * numModulesX; const int blockPixelOffset = blockIdx.y * B_Y * pixelsPerThread; images += loadX; - hidActs += blockFilterIdx * numImages * numModules - + loadY * numImages * numModules - + loadX; + hidActs += blockFilterIdx * numImages * numModules + + loadY * numImages * numModules + loadX; - targets += (outputModuleIdx * numFilters) * filterPixels * numColors - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.y * numFilters + threadIdx.x; + targets += (outputModuleIdx * numFilters) * filterPixels * numColors + + blockPixelOffset * numFilters + blockFilterIdx + + threadIdx.y * numFilters + threadIdx.x; float prod[numColors][pixelsPerThread][filtersPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { prod[c][p][f] = 0; } } } - __shared__ int pxIdxes[B_Y*pixelsPerThread]; - fill_shared_mem((int *)pxIdxes, sizeof(pxIdxes)/sizeof(int), 0); + __shared__ int pxIdxes[B_Y * pixelsPerThread]; + fill_shared_mem((int*)pxIdxes, sizeof(pxIdxes) / sizeof(int), 0); __syncthreads(); //__shared__ bool isPxInImage[B_Y*pixelsPerThread]; for (int m = moduleIdx; m < moduleIdx + partialSum; m++) { - __syncthreads(); if (tidx < B_Y * pixelsPerThread) { const int imgLoadModPosY = paddingStart + (m / numModulesX) * moduleStride; @@ -124,72 +130,91 @@ __global__ void conv_weight_acts_c_kepler(float* images, float* hidActs, float* int pxY = (imgLoadModPosY + (blockPixelOffset + tidx) / filterSize); int pxX = (imgLoadModPosX + (blockPixelOffset + tidx) % filterSize); int pixIdx = (pxY * imgSizeX + pxX) * imgStride; - pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX ? pixIdx : -1; - //isPxInImage[tidx] = ; + pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX + ? 
pixIdx + : -1; + // isPxInImage[tidx] = ; } __syncthreads(); for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { - if (/*loadY < B_X*filtersPerThread &&*/ (!checkCaseBounds || caseIdx + loadX < numImages)) { - #pragma unroll - for (int y = 0; y < B_X*filtersPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_X*filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_X*filtersPerThread) { - shHidActs[loadY+y][loadX]= hidActs[caseIdx + y * numImages * numModules + m * numImages]; + if (/*loadY < B_X*filtersPerThread &&*/ ( + !checkCaseBounds || caseIdx + loadX < numImages)) { +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by number of + // rows filled per iteration + if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_X * filtersPerThread) { + shHidActs[loadY + y][loadX] = + hidActs[caseIdx + y * numImages * numModules + + m * numImages]; } } } - #pragma unroll +#pragma unroll for (int pp = 0; pp < pixelsPerThread; pp += pixelCache) { - //if (loadY < B_Y * pixelCache) { // This condition is not necessary for correctness, but it speeds things a bit - /* - * As long as B_Y * B_X is divisible by preloadCases this will loop the right - * number of times. - * - * This will load some imgGrads from filter pixels that don't exit (it'll set those to 0), - * but the code does not produce any output for those pixels (see last lines). - */ - #pragma unroll +// if (loadY < B_Y * pixelCache) { // This condition is not necessary for correctness, +// but it speeds things a bit +/* + * As long as B_Y * B_X is divisible by preloadCases this will loop the right + * number of times. + * + * This will load some imgGrads from filter pixels that don't exit (it'll set those to + * 0), but the code does not produce any output for those pixels (see last lines). 
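/*
 * A minimal standalone sketch of the cooperative-load indexing used above:
 * tidx = B_X * threadIdx.y + threadIdx.x is split into loadY / loadX, and each
 * pass of the y-loop fills (B_X * B_Y) / preloadCases rows of the shared tile.
 * All sizes and names below (demo_coop_load, ROWS, PRELOAD, ...) are
 * hypothetical illustrations, not MegDNN symbols.
 */
#include <cstdio>
#include <cuda_runtime.h>

constexpr int B_X = 32, B_Y = 4, PRELOAD = 32, ROWS = 16;  // B_X * B_Y divisible by PRELOAD

__global__ void demo_coop_load(const float* src, float* dst) {
    __shared__ float tile[ROWS][PRELOAD];
    const int tidx = B_X * threadIdx.y + threadIdx.x;
    const int loadY = tidx / PRELOAD, loadX = tidx % PRELOAD;
    // (B_X * B_Y) / PRELOAD == 4 rows are filled per pass; 4 passes cover ROWS == 16.
    for (int y = 0; y < ROWS; y += (B_X * B_Y) / PRELOAD)
        tile[loadY + y][loadX] = src[(loadY + y) * PRELOAD + loadX];
    __syncthreads();
    for (int y = 0; y < ROWS; y += (B_X * B_Y) / PRELOAD)
        dst[(loadY + y) * PRELOAD + loadX] = tile[loadY + y][loadX];
}

int main() {
    const int n = ROWS * PRELOAD;
    float *src, *dst;
    cudaMallocManaged(&src, n * sizeof(float));
    cudaMallocManaged(&dst, n * sizeof(float));
    for (int i = 0; i < n; ++i) src[i] = float(i);
    demo_coop_load<<<1, dim3(B_X, B_Y)>>>(src, dst);
    cudaDeviceSynchronize();
    printf("dst[5] = %.0f (expect 5)\n", dst[5]);
    cudaFree(src);
    cudaFree(dst);
    return 0;
}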
+ */ +#pragma unroll for (int y = 0; y < B_Y * pixelCache; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y * pixelCache) { - const int pxIdx = pp * B_Y + loadY + y; // pixel idx in filter + // Make sure number of rows in the array is divisible by number of + // rows filled per iteration + if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * pixelCache) { + const int pxIdx = pp * B_Y + loadY + y; // pixel idx in filter - if (pxIdx + blockPixelOffset < filterPixels && (!checkCaseBounds || caseIdx + loadX < numImages)) { - const int pixIdx = pxIdxes[pxIdx];//(pxY * imgSizeX + pxX) * imgStride; + if (pxIdx + blockPixelOffset < filterPixels && + (!checkCaseBounds || caseIdx + loadX < numImages)) { + const int pixIdx = pxIdxes + [pxIdx]; //(pxY * imgSizeX + pxX) * imgStride; if (pixIdx >= 0) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = images[caseIdx + c * imgPixels * imgStride + pixIdx]; + shImages[loadY + y + c * pixelCache * B_Y][loadX] = + images[caseIdx + c * imgPixels * imgStride + + pixIdx]; } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = 0; + shImages[loadY + y + c * pixelCache * B_Y][loadX] = + 0; } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX]= 0; + shImages[loadY + y + c * pixelCache * B_Y][loadX] = 0; } } } } //} - __syncthreads(); - #pragma unroll +#pragma unroll for (int i = 0; i < preloadCases; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelCache; p++) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - prod[c][pp + p][f] += shImages[threadIdx.y + p * B_Y + c * pixelCache * B_Y][i] * shHidActs[threadIdx.x + f * B_X][i]; + prod[c][pp + p][f] += + shImages + [threadIdx.y + p * B_Y + + c * pixelCache * B_Y][i] * + shHidActs[threadIdx.x + f * B_X][i]; } } } @@ -201,27 +226,33 @@ __global__ void conv_weight_acts_c_kepler(float* images, float* hidActs, float* } if (scale) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] = scaleTargets * targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] + scaleOutputs * prod[c][p][f]; + targets[p * B_Y * numFilters + c * filterPixels * numFilters + + f * B_X] = + scaleTargets * targets[p * B_Y * numFilters + + c * filterPixels * numFilters + + f * B_X] + + scaleOutputs * prod[c][p][f]; } } } } } else { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[p * B_Y * numFilters + c * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[c][p][f]; + targets[p * B_Y * numFilters + c * filterPixels * numFilters + + f * B_X] = 
scaleOutputs * prod[c][p][f]; } } } @@ -229,5 +260,5 @@ __global__ void conv_weight_acts_c_kepler(float* images, float* hidActs, float* } } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler_sw.cuh b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler_sw.cuh index 493d9b27..de1773b2 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler_sw.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_c_kepler_sw.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_templates.cuh" @@ -38,11 +39,12 @@ namespace cuda { * threadIdx.x determines filter * threadIdx.y determines pixel in filter * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum - * blockIdx.y determines pixel batch of B_Y * pixelsPerThread + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + * partialSum blockIdx.y determines pixel batch of B_Y * pixelsPerThread * * Number of filters must be divisible by B_X * filtersPerThread - * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is false. + * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is + * false. * * images: (numColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) @@ -56,22 +58,27 @@ namespace cuda { * numModules must be divisible by partialSum * pixelsPerThread must be divisible by pixelCache * - * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread = 1)... - * so the compiler is messing up here somehow. It's unable to optimize that case away. - * To be used when numFilterColors <= 3 + * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread + * = 1)... so the compiler is messing up here somehow. It's unable to optimize that case + * away. 
To be used when numFilterColors <= 3 */ -template -__global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int sumWidth, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[pixelCache * B_Y * numColors][preloadCases]; // preload preloadCases cases of B_Y * pixelsPerThread pixels - __shared__ float shHidActs[B_X * filtersPerThread][preloadCases + 1]; // preload preloadCases cases of B_X hidActs - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int pixelCache, int pixelsPerThread, int filtersPerThread, + int preloadCases, int numColors, bool scale, bool checkCaseBounds> +__global__ void conv_weight_acts_c_kepler_sw( + float* images, float* hidActs, float* targets, const int numImages, + const int numFilters, const int numModulesY, const int numModulesX, + const int imgSizeY, const int imgSizeX, const int filterSize, + const int paddingStart, const int moduleStride, const int imgStride, + const int sumWidth, const float scaleTargets, const float scaleOutputs) { + __shared__ float shImages[pixelCache * B_Y * numColors] + [preloadCases]; // preload preloadCases cases of B_Y * + // pixelsPerThread pixels + __shared__ float + shHidActs[B_X * filtersPerThread] + [preloadCases + 1]; // preload preloadCases cases of B_X hidActs + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; const int loadY = tidx / preloadCases, loadX = tidx % preloadCases; @@ -79,12 +86,12 @@ __global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, floa const int filterPixels = filterSize * filterSize; const int imgPixels = imgSizeY * imgSizeX; - const int numFilterBlocks = DIVUP(numFilters, B_X*filtersPerThread); + const int numFilterBlocks = DIVUP(numFilters, B_X * filtersPerThread); const int blockModuleChunkIdx = blockIdx.x / numFilterBlocks; const int numModuleChunksX = DIVUP(numModulesX, sumWidth); -// const int numModuleChunksY = DIVUP(numModulesY, sumWidth); + // const int numModuleChunksY = DIVUP(numModulesY, sumWidth); const int blockModuleChunkX = blockModuleChunkIdx % numModuleChunksX; const int blockModuleChunkY = blockModuleChunkIdx / numModuleChunksX; @@ -92,32 +99,31 @@ __global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, floa const int blockModuleStartX = blockModuleChunkX * sumWidth; const int blockModuleStartY = blockModuleChunkY * sumWidth; - const int blockFilterIdx = B_X * filtersPerThread* (blockIdx.x % numFilterBlocks); + const int blockFilterIdx = B_X * filtersPerThread * (blockIdx.x % numFilterBlocks); -// const int moduleStride = (imgSize - filterSize + 1) / numModulesX; + // const int moduleStride = (imgSize - filterSize + 1) / numModulesX; const int numModules = numModulesY * numModulesX; const int blockPixelOffset = blockIdx.y * B_Y * pixelsPerThread; images += loadX; hidActs += blockFilterIdx * numImages * numModules -// + loadY * numImages * numModules - + loadX; + // + loadY * numImages * numModules + + loadX; - targets += (blockModuleChunkIdx * numFilters) * 
filterPixels * numColors - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.y * numFilters + threadIdx.x; + targets += (blockModuleChunkIdx * numFilters) * filterPixels * numColors + + blockPixelOffset * numFilters + blockFilterIdx + + threadIdx.y * numFilters + threadIdx.x; - //float* shImgLoad = &shImages[loadY][loadX]; - //float* shHidActLoad = &shHidActs[loadY][loadX]; + // float* shImgLoad = &shImages[loadY][loadX]; + // float* shHidActLoad = &shHidActs[loadY][loadX]; float prod[numColors][pixelsPerThread][filtersPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { prod[c][p][f] = 0; } @@ -128,14 +134,14 @@ __global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, floa const int mEndX = min(numModulesX, blockModuleStartX + sumWidth); const int mEndY = min(numModulesY, blockModuleStartY + sumWidth); -// if (mStartY == mEndY || mStartX == mEndX) { -// return; -// } + // if (mStartY == mEndY || mStartX == mEndX) { + // return; + // } const int fYOff = (blockPixelOffset + tidx) / filterSize; const int fXOff = (blockPixelOffset + tidx) % filterSize; - __shared__ int pxIdxes[B_Y*pixelsPerThread]; - fill_shared_mem((int *)pxIdxes, sizeof(pxIdxes)/sizeof(int), 0); + __shared__ int pxIdxes[B_Y * pixelsPerThread]; + fill_shared_mem((int*)pxIdxes, sizeof(pxIdxes) / sizeof(int), 0); __syncthreads(); for (int my = mStartY; my < mEndY; my++) { const int imgLoadModPosY = paddingStart + my * moduleStride; @@ -145,73 +151,103 @@ __global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, floa __syncthreads(); const int imgLoadModPosX = paddingStart + mx * moduleStride; if (tidx < B_Y * pixelsPerThread) { -// const int imgLoadModPosY = paddingStart + my * moduleStride; -// const int imgLoadModPosX = paddingStart + mx * moduleStride; + // const int imgLoadModPosY = paddingStart + my * + // moduleStride; const int imgLoadModPosX = paddingStart + // + mx * moduleStride; int pxY = (imgLoadModPosY + fYOff); int pxX = (imgLoadModPosX + fXOff); int pixIdx = (pxY * imgSizeX + pxX) * imgStride; - pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX ? pixIdx : -1; + pxIdxes[tidx] = pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX + ? 
pixIdx + : -1; } __syncthreads(); for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { - if (//loadY < B_X*filtersPerThread && - (!checkCaseBounds || caseIdx + loadX < numImages)) { - #pragma unroll - for (int y = 0; y < B_X*filtersPerThread; y += (B_X * B_Y) / preloadCases) { - const int fIdx = ((loadY + y) % filtersPerThread) * B_X + (loadY + y) / filtersPerThread; - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_X*filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || loadY+y < B_X*filtersPerThread) { + if ( // loadY < B_X*filtersPerThread && + (!checkCaseBounds || caseIdx + loadX < numImages)) { +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + const int fIdx = ((loadY + y) % filtersPerThread) * B_X + + (loadY + y) / filtersPerThread; + // Make sure number of rows in the array is divisible by number + // of rows filled per iteration + if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == + 0 || + loadY + y < B_X * filtersPerThread) { if (blockFilterIdx + fIdx < numFilters) { - shHidActs[loadY+y][loadX]= hidActs[caseIdx + (fIdx * numModules + m) * numImages]; + shHidActs[loadY + y][loadX] = + hidActs[caseIdx + + (fIdx * numModules + m) * numImages]; } else { - shHidActs[loadY+y][loadX] = 0; + shHidActs[loadY + y][loadX] = 0; } } } } else { - #pragma unroll - for (int y = 0; y < B_X*filtersPerThread; y += (B_X * B_Y) / preloadCases) { - // const int fIdx = ((loadY + y) % filtersPerThread) * B_X + (loadY + y) / filtersPerThread; - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_X*filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || loadY+y < B_X*filtersPerThread) { - shHidActs[loadY+y][loadX] = 0; +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + // const int fIdx = ((loadY + y) % + // filtersPerThread) * B_X + (loadY + y) + // / filtersPerThread; + // Make sure number of rows in the array is divisible by number + // of rows filled per iteration + if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == + 0 || + loadY + y < B_X * filtersPerThread) { + shHidActs[loadY + y][loadX] = 0; } } } - #pragma unroll +#pragma unroll for (int pp = 0; pp < pixelsPerThread; pp += pixelCache) { - //if (loadY < B_Y * pixelCache) { // This condition is not necessary for correctness, but it speeds things a bit - // - //As long as B_Y * B_X is divisible by preloadCases this will loop the right - //number of times. - // - //This will load some imgGrads from filter pixels that don't exit (it'll set those to 0), - //but the code does not produce any output for those pixels (see last lines). - // - #pragma unroll - for (int y = 0; y < B_Y * pixelCache; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y * pixelCache) { - const int pxIdx = pp * B_Y + loadY + y; // pixel idx in filter +// if (loadY < B_Y * pixelCache) { // This condition is not necessary for correctness, +// but it speeds things a bit +// +// As long as B_Y * B_X is divisible by preloadCases this will loop the right +// number of times. +// +// This will load some imgGrads from filter pixels that don't exit (it'll set those to +// 0), but the code does not produce any output for those pixels (see last lines). 
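/*
 * A host-side sketch (hypothetical sizes, not MegDNN values) of how the _sw
 * variant above decodes blockIdx.x into a filter block plus a sum-window
 * "module chunk": numFilterBlocks = DIVUP(numFilters, B_X * filtersPerThread),
 * the chunk index is blockIdx.x / numFilterBlocks, and each chunk covers at
 * most sumWidth x sumWidth modules clipped to the module grid.
 */
#include <algorithm>
#include <cstdio>

static int divup(int a, int b) {
    return (a + b - 1) / b;
}

int main() {
    const int numFilters = 48, B_X = 16, filtersPerThread = 2;  // example values
    const int numModulesX = 7, numModulesY = 5, sumWidth = 4;

    const int numFilterBlocks = divup(numFilters, B_X * filtersPerThread);  // 2
    const int numModuleChunksX = divup(numModulesX, sumWidth);              // 2
    const int numModuleChunksY = divup(numModulesY, sumWidth);              // 2
    const int gridX = numFilterBlocks * numModuleChunksX * numModuleChunksY;

    for (int bx = 0; bx < gridX; ++bx) {
        const int chunk = bx / numFilterBlocks;
        const int blockFilterIdx = B_X * filtersPerThread * (bx % numFilterBlocks);
        const int chunkX = chunk % numModuleChunksX;
        const int chunkY = chunk / numModuleChunksX;
        const int mStartX = chunkX * sumWidth;
        const int mStartY = chunkY * sumWidth;
        const int mEndX = std::min(numModulesX, mStartX + sumWidth);
        const int mEndY = std::min(numModulesY, mStartY + sumWidth);
        printf("blockIdx.x=%d -> filters [%2d,%2d), modules x:[%d,%d) y:[%d,%d)\n", bx,
               blockFilterIdx, blockFilterIdx + B_X * filtersPerThread, mStartX, mEndX,
               mStartY, mEndY);
    }
    return 0;
}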
+// +#pragma unroll + for (int y = 0; y < B_Y * pixelCache; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by number + // of rows filled per iteration + if ((B_Y * pixelCache) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * pixelCache) { + const int pxIdx = + pp * B_Y + loadY + y; // pixel idx in filter - if (pxIdx + blockPixelOffset < filterPixels && (!checkCaseBounds || caseIdx + loadX < numImages)) { - const int pixIdx = pxIdxes[pxIdx];//(pxY * imgSizeX + pxX) * imgStride; + if (pxIdx + blockPixelOffset < filterPixels && + (!checkCaseBounds || caseIdx + loadX < numImages)) { + const int pixIdx = pxIdxes + [pxIdx]; //(pxY * imgSizeX + pxX) * imgStride; if (pixIdx >= 0) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = images[caseIdx + c * imgPixels * imgStride + pixIdx]; + shImages[loadY + y + c * pixelCache * B_Y] + [loadX] = + images[caseIdx + + c * imgPixels * + imgStride + + pixIdx]; } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX] = 0; + shImages[loadY + y + c * pixelCache * B_Y] + [loadX] = 0; } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - shImages[loadY+y + c * pixelCache * B_Y][loadX]= 0; + shImages[loadY + y + c * pixelCache * B_Y][loadX] = + 0; } } } @@ -220,16 +256,25 @@ __global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, floa __syncthreads(); - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int i = 0; i < preloadCases; i++) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelCache; p++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - if (blockFilterIdx + threadIdx.x + f * B_X < numFilters) { - prod[c][pp + p][f] += shImages[threadIdx.y + (p + c * pixelCache) * B_Y][i] * shHidActs[threadIdx.x * filtersPerThread + f][i]; + if (blockFilterIdx + threadIdx.x + f * B_X < + numFilters) { + prod[c][pp + p][f] += + shImages + [threadIdx.y + + (p + c * pixelCache) * B_Y] + [i] * + shHidActs + [threadIdx.x * + filtersPerThread + + f][i]; } } } @@ -242,28 +287,33 @@ __global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, floa } } if (scale) { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[(p * B_Y + c * filterPixels) * numFilters + f * B_X] = scaleTargets * targets[(p * B_Y + c * filterPixels) * numFilters + f * B_X] + scaleOutputs * prod[c][p][f]; + targets[(p * B_Y + c * filterPixels) * numFilters + f * B_X] = + scaleTargets * targets[(p * B_Y + c * filterPixels) * + numFilters + + f * B_X] + + scaleOutputs * prod[c][p][f]; } } } } } else { - #pragma unroll +#pragma unroll for (int p = 0; p < pixelsPerThread; p++) { if (blockPixelOffset + p * B_Y + threadIdx.y < filterPixels) { - #pragma unroll +#pragma unroll for (int c = 0; c < numColors; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if (blockFilterIdx + threadIdx.x + f * B_X < numFilters) { - targets[(p * B_Y + c * filterPixels) * numFilters + f * B_X] = scaleOutputs * prod[c][p][f]; + targets[(p * B_Y + c * filterPixels) * numFilters + + f * B_X] = 
scaleOutputs * prod[c][p][f]; } } } @@ -272,8 +322,7 @@ __global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, floa } } - #define WET_ACT_C_KEPLER_SW_HEAD template __global__ void conv_weight_acts_c_kepler_sw -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler.cuh b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler.cuh index b06094e1..8ca173df 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_templates.cuh" @@ -33,35 +34,45 @@ namespace megdnn { namespace cuda { /* - * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and B_X * filtersPerThread filters + * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and + B_X * filtersPerThread filters * threadIdx.x determines filter * threadIdx.y determines color * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + partialSum * blockIdx.y determines color batch of B_Y * colorsPerThread * blockIdx.z determines pixel in filter - * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine will - * fail for filters >= 256*256. I'm assuming I won't ever use such large filters. + * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine + will + * fail for filters >= 256*256. I'm assuming I won't ever use such + large filters. 
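/*
 * A host-side sketch of why the NOTE above matters: this kernel maps one
 * filter pixel to blockIdx.z, and gridDim.z is capped at 65535, so filters of
 * 256x256 or larger cannot be launched this way.  The x/y packing below is
 * inferred from the doc comment (filter/module batches in x, color batches in
 * y); the real launcher lives elsewhere, and all sizes are example values.
 */
#include <cassert>
#include <cstdio>
#include <cuda_runtime.h>

static int divup(int a, int b) {
    return (a + b - 1) / b;
}

int main() {
    const int B_X = 32, B_Y = 4, filtersPerThread = 2, colorsPerThread = 8;
    const int numFilters = 64, numFilterColors = 32, filterSize = 7;
    const int numModulesX = 14, numModulesY = 14, partialSum = 4;

    const int filterPixels = filterSize * filterSize;
    const int numModules = numModulesX * numModulesY;

    dim3 block(B_X, B_Y);
    dim3 grid(
            divup(numModules, partialSum) * divup(numFilters, B_X * filtersPerThread),
            divup(numFilterColors, B_Y * colorsPerThread), filterPixels);

    assert(filterPixels < 65536 && "blockIdx.z cannot index filters >= 256x256");
    printf("grid = (%u, %u, %u), block = (%u, %u)\n", grid.x, grid.y, grid.z, block.x,
           block.y);
    return 0;
}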
* images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) * - * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, numFilters) + * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, + numFilters) * B_X * B_Y must be divisible by preloadCases */ -template -__global__ void conv_weight_acts_mc_mf_kepler(float* images, float* hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int numImgColors, const int numGroups, const int partialSum, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[colorsPerThread * B_Y][preloadCases]; // preload preloadCases cases - __shared__ float shHidActs[filtersPerThread * B_X][preloadCases + 1]; // preload preloadCases cases of B_X hidacts - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int filtersPerThread, int colorsPerThread, int preloadCases, + bool scale> +__global__ void conv_weight_acts_mc_mf_kepler( + float* images, float* hidActs, float* targets, const int numImages, + const int numFilters, const int numModulesY, const int numModulesX, + const int imgSizeY, const int imgSizeX, const int filterSize, + const int paddingStart, const int moduleStride, const int imgStride, + const int numImgColors, const int numGroups, const int partialSum, + const float scaleTargets, const float scaleOutputs) { + __shared__ float shImages[colorsPerThread * B_Y] + [preloadCases]; // preload preloadCases cases + __shared__ float + shHidActs[filtersPerThread * B_X] + [preloadCases + 1]; // preload preloadCases cases of B_X hidacts + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -80,30 +91,27 @@ __global__ void conv_weight_acts_mc_mf_kepler(float* images, float* hidActs, flo const int blockGroupIdx = blockFilterIdx / numFiltersPerGroup; const int numFilterColors = numImgColors / numGroups; - const int blockPixelOffset = blockIdx.z; // pixel idx in filter - const int blockPixelY = blockPixelOffset / filterSize, blockPixelX = blockPixelOffset % filterSize; - const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; + const int blockPixelOffset = blockIdx.z; // pixel idx in filter + const int blockPixelY = blockPixelOffset / filterSize, + blockPixelX = blockPixelOffset % filterSize; + const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; const int imgColorIdx = blockFilterColorIdx + blockGroupIdx * numFilterColors; images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; - hidActs += - blockFilterIdx * numImages * numModules - + loadY * numImages * numModules - + loadX; + hidActs += blockFilterIdx * numImages * numModules + + loadY * numImages * numModules + loadX; - targets += outputModuleIdx * numFilters * filterPixels * numFilterColors - + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.x; - //if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; + targets += outputModuleIdx * numFilters * 
filterPixels * numFilterColors + + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters + + blockPixelOffset * numFilters + blockFilterIdx + threadIdx.x; + // if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; float* shHidActLoad = &shHidActs[loadY][loadX]; float* shImgLoad = &shImages[loadY][loadX]; float prod[colorsPerThread][filtersPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { prod[c][f] = 0; } @@ -112,65 +120,87 @@ __global__ void conv_weight_acts_mc_mf_kepler(float* images, float* hidActs, flo for (int m = moduleIdx; m < moduleIdx + partialSum; m++) { const int imgLoadModPosY = paddingStart + (m / numModulesX) * moduleStride; const int imgLoadModPosX = paddingStart + (m % numModulesX) * moduleStride; - const int pxY = imgLoadModPosY + blockPixelY; // pixel x,y coords in image + const int pxY = imgLoadModPosY + blockPixelY; // pixel x,y coords in image const int pxX = imgLoadModPosX + blockPixelX; - const int pixIdx = (pxY * imgSizeX + pxX) * imgStride; // pixel idx in image + const int pixIdx = (pxY * imgSizeX + pxX) * imgStride; // pixel idx in image if (pxY >= 0 && pxY < imgSizeY && pxX >= 0 && pxX < imgSizeX) { for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { // Checking this condition actually makes things faster ... :/ - // So I've removed the !checkCaseBounds flag and just check it all the time. + // So I've removed the !checkCaseBounds flag and just check it all the + // time. if (caseIdx + loadX < numImages) { /* - * As long as B_Y * B_X is divisible by preloadCases this will loop the right - * number of times. + * As long as B_Y * B_X is divisible by preloadCases this will loop + * the right number of times. * - * This will load some images from filter pixels that don't exist (it'll set those to 0), - * but the code does not produce any output for those pixels (see last lines). + * This will load some images from filter pixels that don't exist + * (it'll set those to 0), but the code does not produce any output + * for those pixels (see last lines). 
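/*
 * The shHidActs tile above is declared with preloadCases + 1 columns.  The
 * original comments do not say why, but an extra column is the usual way to
 * avoid shared-memory bank conflicts when a warp reads one column across many
 * rows: with a row pitch of 32 floats every row starts in the same bank, while
 * a pitch of 33 staggers the rows.  The toy kernel below (all names and sizes
 * hypothetical) shows that padded-column access pattern in isolation.
 */
#include <cstdio>
#include <cuda_runtime.h>

constexpr int ROWS = 32, COLS = 32;

__global__ void demo_padded_column_read(const float* src, float* out) {
    // Row r starts at float offset r * (COLS + 1), so its starting bank is
    // r % 32 rather than always bank 0.
    __shared__ float tile[ROWS][COLS + 1];
    tile[threadIdx.y][threadIdx.x] = src[threadIdx.y * COLS + threadIdx.x];
    __syncthreads();
    // At each step of this loop a warp (distinct threadIdx.x values) touches
    // 32 different rows at the same column -- conflict-free thanks to the pad.
    float acc = 0.f;
    for (int c = 0; c < COLS; ++c)
        acc += tile[threadIdx.x][c];
    out[threadIdx.y * COLS + threadIdx.x] = acc;
}

int main() {
    float *src, *out;
    cudaMallocManaged(&src, ROWS * COLS * sizeof(float));
    cudaMallocManaged(&out, ROWS * COLS * sizeof(float));
    for (int i = 0; i < ROWS * COLS; ++i) src[i] = 1.f;
    demo_padded_column_read<<<1, dim3(COLS, ROWS)>>>(src, out);
    cudaDeviceSynchronize();
    printf("out[0] = %.1f (expect %d)\n", out[0], COLS);
    cudaFree(src);
    cudaFree(out);
    return 0;
}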
*/ if (loadY < B_Y * colorsPerThread) { - #pragma unroll - for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y*colorsPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y*colorsPerThread) { - shImgLoad[(y) * preloadCases] = images[caseIdx + y * imgPixels * imgStride + pixIdx]; +#pragma unroll + for (int y = 0; y < B_Y * colorsPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by + // number of rows filled per iteration + if ((B_Y * colorsPerThread) % (B_X * B_Y / preloadCases) == + 0 || + y + loadY < B_Y * colorsPerThread) { + shImgLoad[(y)*preloadCases] = + images[caseIdx + y * imgPixels * imgStride + + pixIdx]; } } } if (loadY < B_X * filtersPerThread) { - #pragma unroll - for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_X * filtersPerThread) { - shHidActLoad[y * (preloadCases + 1)] = hidActs[caseIdx + (y * numModules + m) * numImages]; +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by + // number of rows filled per iteration + if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == + 0 || + y + loadY < B_X * filtersPerThread) { + shHidActLoad[y * (preloadCases + 1)] = + hidActs[caseIdx + + (y * numModules + m) * numImages]; } } } } else { - #pragma unroll - for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y*colorsPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y*colorsPerThread) { - shImgLoad[(y) * preloadCases] = 0; +#pragma unroll + for (int y = 0; y < B_Y * colorsPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by number + // of rows filled per iteration + if ((B_Y * colorsPerThread) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * colorsPerThread) { + shImgLoad[(y)*preloadCases] = 0; } } - #pragma unroll - for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_X * filtersPerThread) { +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by number + // of rows filled per iteration + if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == + 0 || + y + loadY < B_X * filtersPerThread) { shHidActLoad[y * (preloadCases + 1)] = 0; } } } __syncthreads(); - #pragma unroll +#pragma unroll for (int i = 0; i < preloadCases; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - prod[c][f] += shImages[threadIdx.y + c * B_Y][i] * shHidActs[threadIdx.x + f * B_X][i]; + prod[c][f] += shImages[threadIdx.y + c * B_Y][i] * + shHidActs[threadIdx.x + f * B_X][i]; } } } @@ -179,23 +209,27 @@ __global__ void conv_weight_acts_mc_mf_kepler(float* images, float* hidActs, flo } } if (scale) 
{ - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleTargets * targets[c * B_Y * filterPixels * numFilters + f * B_X] + scaleOutputs * prod[c][f]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleTargets * + targets[c * B_Y * filterPixels * numFilters + f * B_X] + + scaleOutputs * prod[c][f]; } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[c][f]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleOutputs * prod[c][f]; } } } } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw.cuh b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw.cuh index f46f1bd8..688acde1 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw.cuh @@ -25,7 +25,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_templates.cuh" @@ -34,15 +35,19 @@ namespace megdnn { namespace cuda { /* - * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and B_X * filtersPerThread filters + * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and + B_X * filtersPerThread filters * threadIdx.x determines filter * threadIdx.y determines color * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + partialSum * blockIdx.y determines color batch of B_Y * colorsPerThread * blockIdx.z determines pixel in filter - * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine will - * fail for filters >= 256*256. I'm assuming I won't ever use such large filters. + * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine + will + * fail for filters >= 256*256. I'm assuming I won't ever use such + large filters. 
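/*
 * A host-side sketch (example sizes only) of the grouped-convolution color
 * indexing used by this kernel: each group owns numFilterColors =
 * numImgColors / numGroups input channels, a block's group follows from its
 * filter range, and blockIdx.y selects a batch of B_Y * colorsPerThread colors
 * within that group, mirroring how imgColorIdx is computed further down.
 */
#include <cstdio>

int main() {
    const int B_X = 16, B_Y = 8, filtersPerThread = 2, colorsPerThread = 6;
    const int numFilters = 64, numImgColors = 96, numGroups = 2;

    const int numFiltersPerGroup = numFilters / numGroups;                 // 32
    const int numFilterColors = numImgColors / numGroups;                  // 48
    const int numFilterBlocks = numFilters / (B_X * filtersPerThread);     // 2
    const int numColorBlocks = numFilterColors / (B_Y * colorsPerThread);  // 1

    for (int bx = 0; bx < numFilterBlocks; ++bx) {
        const int blockFilterIdx = B_X * filtersPerThread * bx;
        const int blockGroupIdx = blockFilterIdx / numFiltersPerGroup;
        for (int by = 0; by < numColorBlocks; ++by) {
            const int blockFilterColorIdx = by * B_Y * colorsPerThread;
            const int imgColorIdx =
                    blockFilterColorIdx + blockGroupIdx * numFilterColors;
            printf("filters [%2d,%2d) in group %d read image colors [%2d,%2d)\n",
                   blockFilterIdx, blockFilterIdx + B_X * filtersPerThread,
                   blockGroupIdx, imgColorIdx, imgColorIdx + B_Y * colorsPerThread);
        }
    }
    return 0;
}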
* images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) @@ -52,18 +57,23 @@ namespace cuda { * B_X * B_Y must be divisible by preloadCases * To be used when numFilterColors > 3 && numFilterColors % 16 == 0 */ -template -__global__ void conv_weight_acts_mc_mf_kepler_sw(float* images, float* hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int numImgColors, const int numGroups, const int sumWidth, - const float scaleTargets, const float scaleOutputs) { - __shared__ float shImages[colorsPerThread * B_Y][preloadCases]; // preload preloadCases cases - __shared__ float shHidActs[filtersPerThread * B_X][preloadCases + 1]; // preload preloadCases cases of B_X hidacts - fill_shared_mem((float *)shImages, sizeof(shImages)/sizeof(float), 0); - fill_shared_mem((float *)shHidActs, sizeof(shHidActs)/sizeof(float), 0); +template < + int B_Y, int B_X, int filtersPerThread, int colorsPerThread, int preloadCases, + bool scale> +__global__ void conv_weight_acts_mc_mf_kepler_sw( + float* images, float* hidActs, float* targets, const int numImages, + const int numFilters, const int numModulesY, const int numModulesX, + const int imgSizeY, const int imgSizeX, const int filterSize, + const int paddingStart, const int moduleStride, const int imgStride, + const int numImgColors, const int numGroups, const int sumWidth, + const float scaleTargets, const float scaleOutputs) { + __shared__ float shImages[colorsPerThread * B_Y] + [preloadCases]; // preload preloadCases cases + __shared__ float + shHidActs[filtersPerThread * B_X] + [preloadCases + 1]; // preload preloadCases cases of B_X hidacts + fill_shared_mem((float*)shImages, sizeof(shImages) / sizeof(float), 0); + fill_shared_mem((float*)shHidActs, sizeof(shHidActs) / sizeof(float), 0); __syncthreads(); const int tidx = B_X * threadIdx.y + threadIdx.x; @@ -72,12 +82,12 @@ __global__ void conv_weight_acts_mc_mf_kepler_sw(float* images, float* hidActs, const int filterPixels = filterSize * filterSize; const int imgPixels = imgSizeY * imgSizeX; - //const int numFilterBlocks = numFilters / (B_X * filtersPerThread); + // const int numFilterBlocks = numFilters / (B_X * filtersPerThread); const int numFilterBlocks = DIVUP(numFilters, (B_X * filtersPerThread)); const int blockModuleChunkIdx = blockIdx.x / numFilterBlocks; const int numModuleChunksX = DIVUP(numModulesX, sumWidth); -// const int numModuleChunksY = DIVUP(numModulesY, sumWidth); + // const int numModuleChunksY = DIVUP(numModulesY, sumWidth); const int blockModuleChunkX = blockModuleChunkIdx % numModuleChunksX; const int blockModuleChunkY = blockModuleChunkIdx / numModuleChunksX; @@ -92,189 +102,219 @@ __global__ void conv_weight_acts_mc_mf_kepler_sw(float* images, float* hidActs, const int blockGroupIdx = blockFilterIdx / numFiltersPerGroup; const int numFilterColors = numImgColors / numGroups; - const int blockPixelOffset = blockIdx.z; // pixel idx in filter - const int blockPixelY = blockPixelOffset / filterSize, blockPixelX = blockPixelOffset % filterSize; - const int blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; + const int blockPixelOffset = blockIdx.z; // pixel idx in filter + const int blockPixelY = blockPixelOffset / filterSize, + blockPixelX = blockPixelOffset % filterSize; + const int 
blockFilterColorIdx = blockIdx.y * B_Y * colorsPerThread; const int imgColorIdx = blockFilterColorIdx + blockGroupIdx * numFilterColors; images += (imgColorIdx + loadY) * imgPixels * imgStride + loadX; - hidActs += - blockFilterIdx * numImages * numModules - + loadY * numImages * numModules - + loadX; + hidActs += blockFilterIdx * numImages * numModules + + loadY * numImages * numModules + loadX; - targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors - + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters - + blockPixelOffset * numFilters - + blockFilterIdx - + threadIdx.x; + targets += blockModuleChunkIdx * numFilters * filterPixels * numFilterColors + + (blockFilterColorIdx + threadIdx.y) * filterPixels * numFilters + + blockPixelOffset * numFilters + blockFilterIdx + threadIdx.x; - //if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; + // if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return; - const int mStartX = max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); - const int mStartY = max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); - const int mEndX = min(numModulesX, min(blockModuleStartX + sumWidth, DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); - const int mEndY = min(numModulesY, min(blockModuleStartY + sumWidth, DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); + const int mStartX = + max(blockModuleStartX, DIVUP(-blockPixelX - paddingStart, moduleStride)); + const int mStartY = + max(blockModuleStartY, DIVUP(-blockPixelY - paddingStart, moduleStride)); + const int mEndX = + min(numModulesX, + min(blockModuleStartX + sumWidth, + DIVUP(imgSizeX - blockPixelX - paddingStart, moduleStride))); + const int mEndY = + min(numModulesY, + min(blockModuleStartY + sumWidth, + DIVUP(imgSizeY - blockPixelY - paddingStart, moduleStride))); -// if (mStartY == mEndY || mStartX == mEndX) { -// return; -// } + // if (mStartY == mEndY || mStartX == mEndX) { + // return; + // } float* shHidActLoad = &shHidActs[loadY][loadX]; float* shImgLoad = &shImages[loadY][loadX]; float prod[colorsPerThread][filtersPerThread]; - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { prod[c][f] = 0; } } /* - * Note; iterating this way is about 1% slower and uses a few more registers than iterating - * over the modules linearly. But it's consistent with the preload routines, - * so I'm using it. + * Note; iterating this way is about 1% slower and uses a few more registers than + * iterating over the modules linearly. But it's consistent with the preload + * routines, so I'm using it. */ for (int my = mStartY; my < mEndY; my++) { const int imgLoadModPosY = paddingStart + my * moduleStride; - const int pxY = imgLoadModPosY + blockPixelY; // pixel x,y coords in image + const int pxY = imgLoadModPosY + blockPixelY; // pixel x,y coords in image for (int mx = mStartX; mx < mEndX; mx++) { const int m = my * numModulesX + mx; const int imgLoadModPosX = paddingStart + mx * moduleStride; const int pxX = imgLoadModPosX + blockPixelX; - const int pixIdx = (pxY * imgSizeX + pxX) * imgStride; // pixel idx in image + const int pixIdx = + (pxY * imgSizeX + pxX) * imgStride; // pixel idx in image for (int caseIdx = 0; caseIdx < numImages; caseIdx += preloadCases) { // Checking this condition actually makes things faster ... 
:/ - // So I've removed the !checkCaseBounds flag and just check it all the time. + // So I've removed the !checkCaseBounds flag and just check it all the + // time. if (caseIdx + loadX < numImages) { /* - * As long as B_Y * B_X is divisible by preloadCases this will loop the right - * number of times. + * As long as B_Y * B_X is divisible by preloadCases this will loop + * the right number of times. * - * This will load some images from filter pixels that don't exist (it'll set those to 0), - * but the code does not produce any output for those pixels (see last lines). + * This will load some images from filter pixels that don't exist + * (it'll set those to 0), but the code does not produce any output + * for those pixels (see last lines). */ if (loadY < B_Y * colorsPerThread) { - #pragma unroll - for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y*colorsPerThread) % (B_X * B_Y / preloadCases) == 0 || - y + loadY < B_Y*colorsPerThread) { - if(y + loadY + imgColorIdx < numImgColors) { - shImgLoad[(y) * preloadCases] = images[caseIdx + y * imgPixels * imgStride + pixIdx]; +#pragma unroll + for (int y = 0; y < B_Y * colorsPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by + // number of rows filled per iteration + if ((B_Y * colorsPerThread) % (B_X * B_Y / preloadCases) == + 0 || + y + loadY < B_Y * colorsPerThread) { + if (y + loadY + imgColorIdx < numImgColors) { + shImgLoad[(y)*preloadCases] = + images[caseIdx + y * imgPixels * imgStride + + pixIdx]; } else { - shImgLoad[(y) * preloadCases] = 0; + shImgLoad[(y)*preloadCases] = 0; } } } } if (loadY < B_X * filtersPerThread) { - #pragma unroll - for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_X * filtersPerThread) { +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by + // number of rows filled per iteration + if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == + 0 || + y + loadY < B_X * filtersPerThread) { if (blockFilterIdx + loadY + y < numFilters) { - shHidActLoad[y * (preloadCases + 1)] = hidActs[caseIdx + (y * numModules + m) * numImages]; + shHidActLoad[y * (preloadCases + 1)] = + hidActs[caseIdx + + (y * numModules + m) * numImages]; } else if (loadY + y < filtersPerThread * B_X) { - shHidActLoad[y * (preloadCases + 1)] = 0; + shHidActLoad[y * (preloadCases + 1)] = 0; } } } } } else { - #pragma unroll - for (int y = 0; y < B_Y * colorsPerThread; y += (B_X * B_Y) / preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_Y*colorsPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_Y*colorsPerThread) { - shImgLoad[(y) * preloadCases] = 0; +#pragma unroll + for (int y = 0; y < B_Y * colorsPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by number + // of rows filled per iteration + if ((B_Y * colorsPerThread) % (B_X * B_Y / preloadCases) == 0 || + y + loadY < B_Y * colorsPerThread) { + shImgLoad[(y)*preloadCases] = 0; } } - #pragma unroll - for (int y = 0; y < B_X * filtersPerThread; y += (B_X * B_Y) / 
preloadCases) { - // Make sure number of rows in the array is divisible by number of rows filled per iteration - if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == 0 || y + loadY < B_X * filtersPerThread) { +#pragma unroll + for (int y = 0; y < B_X * filtersPerThread; + y += (B_X * B_Y) / preloadCases) { + // Make sure number of rows in the array is divisible by number + // of rows filled per iteration + if ((B_X * filtersPerThread) % (B_X * B_Y / preloadCases) == + 0 || + y + loadY < B_X * filtersPerThread) { shHidActLoad[y * (preloadCases + 1)] = 0; } } } __syncthreads(); - #pragma unroll +#pragma unroll for (int i = 0; i < preloadCases; i++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { if (blockFilterIdx + threadIdx.x + f * B_X < numFilters) { - prod[c][f] += shImages[threadIdx.y + c * B_Y][i] * shHidActs[threadIdx.x + f * B_X][i]; + prod[c][f] += shImages[threadIdx.y + c * B_Y][i] * + shHidActs[threadIdx.x + f * B_X][i]; } } } } __syncthreads(); } - } } if (scale) { //#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if (blockFilterIdx + threadIdx.x + f * B_X < numFilters) { targets[c * B_Y * filterPixels * numFilters + f * B_X] = - scaleTargets * targets[c * B_Y * filterPixels * numFilters + f * B_X] + scaleOutputs * prod[c][f]; + scaleTargets * targets[c * B_Y * filterPixels * numFilters + + f * B_X] + + scaleOutputs * prod[c][f]; } } } } else { - #pragma unroll +#pragma unroll for (int c = 0; c < colorsPerThread; c++) { - #pragma unroll +#pragma unroll for (int f = 0; f < filtersPerThread; f++) { if ((blockFilterIdx + threadIdx.x + f * B_X < numFilters) && (c * B_Y + blockFilterColorIdx + threadIdx.y < numImgColors)) { - targets[c * B_Y * filterPixels * numFilters + f * B_X] = scaleOutputs * prod[c][f]; + targets[c * B_Y * filterPixels * numFilters + f * B_X] = + scaleOutputs * prod[c][f]; } } } } } -#define WET_ACT_MC_MF_KEPLER_SW_HEAD template __global__ void conv_weight_acts_mc_mf_kepler_sw -#define WET_ACT_MC_MF_KEPLER_SW_4_A(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<4,16,1,4,32,scale> (MC_MF_KEP_SW_PARAM); \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<4,16,1,8,32,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_HEAD \ + template __global__ void conv_weight_acts_mc_mf_kepler_sw +#define WET_ACT_MC_MF_KEPLER_SW_4_A(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 16, 1, 4, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 16, 1, 8, 32, scale>(MC_MF_KEP_SW_PARAM); -#define WET_ACT_MC_MF_KEPLER_SW_4_B(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<4,16,2,4,32,scale> (MC_MF_KEP_SW_PARAM); \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<4,16,2,8,32,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_4_B(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 16, 2, 4, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 16, 2, 8, 32, scale>(MC_MF_KEP_SW_PARAM); -#define WET_ACT_MC_MF_KEPLER_SW_4_C(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<4,16,4,4,32,scale> (MC_MF_KEP_SW_PARAM); \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<4,16,4,8,16,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_4_C(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 16, 4, 4, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 16, 4, 8, 16, scale>(MC_MF_KEP_SW_PARAM); -#define WET_ACT_MC_MF_KEPLER_SW_4_D(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<4,32,4,4,32,scale> (MC_MF_KEP_SW_PARAM); \ - 
WET_ACT_MC_MF_KEPLER_SW_HEAD<4,32,4,8,16,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_4_D(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 32, 4, 4, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<4, 32, 4, 8, 16, scale>(MC_MF_KEP_SW_PARAM); -#define WET_ACT_MC_MF_KEPLER_SW_8_A(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,16,1,6,32,scale> (MC_MF_KEP_SW_PARAM); \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,16,1,8,32,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_8_A(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 16, 1, 6, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 16, 1, 8, 32, scale>(MC_MF_KEP_SW_PARAM); -#define WET_ACT_MC_MF_KEPLER_SW_8_B(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,16,2,6,32,scale> (MC_MF_KEP_SW_PARAM); \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,16,2,8,32,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_8_B(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 16, 2, 6, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 16, 2, 8, 32, scale>(MC_MF_KEP_SW_PARAM); -#define WET_ACT_MC_MF_KEPLER_SW_8_C(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,16,4,6,32,scale> (MC_MF_KEP_SW_PARAM); \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,16,4,8,16,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_8_C(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 16, 4, 6, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 16, 4, 8, 16, scale>(MC_MF_KEP_SW_PARAM); -#define WET_ACT_MC_MF_KEPLER_SW_8_D(scale) \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,32,4,6,32,scale> (MC_MF_KEP_SW_PARAM); \ - WET_ACT_MC_MF_KEPLER_SW_HEAD<8,32,4,8,16,scale> (MC_MF_KEP_SW_PARAM); +#define WET_ACT_MC_MF_KEPLER_SW_8_D(scale) \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 32, 4, 6, 32, scale>(MC_MF_KEP_SW_PARAM); \ + WET_ACT_MC_MF_KEPLER_SW_HEAD<8, 32, 4, 8, 16, scale>(MC_MF_KEP_SW_PARAM); -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_A_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_A_scale_f.cu index 1cc4e1d3..95dc8e70 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_A_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_A_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_A_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_A_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_4_A(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_B_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_B_scale_f.cu index e0b11913..f1f72d2b 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_B_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_B_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_B_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_B_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_4_B(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_C_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_C_scale_f.cu index e412bf35..08ef997c 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_C_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_C_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_C_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_C_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_4_C(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_D_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_D_scale_f.cu index f0b9f65f..cad1eb38 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_D_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_D_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_D_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_4_D_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_4_D(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_A_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_A_scale_f.cu index 544a43e9..47f9f4b7 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_A_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_A_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_A_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_A_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_8_A(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_B_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_B_scale_f.cu index 520e38e2..06bb4259 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_B_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_B_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_B_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_B_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_8_B(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_C_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_C_scale_f.cu index 19add712..7df4705a 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_C_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_C_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_C_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_C_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_8_C(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_D_scale_f.cu b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_D_scale_f.cu index 879e9801..b7e5d1ad 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_D_scale_f.cu +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_D_scale_f.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_D_scale_f.cu + * \file + * dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_mc_mf_kepler_sw_by_8_D_scale_f.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -25,7 +26,8 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. * -------------------------------------------------------------------------- */ #include "wet_act_mc_mf_kepler_sw.cuh" @@ -34,5 +36,5 @@ namespace cuda { WET_ACT_MC_MF_KEPLER_SW_8_D(false) -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_templates.cuh b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_templates.cuh index 6b80bb78..1f562ace 100644 --- a/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_templates.cuh +++ b/dnn/src/cuda/local/cuda-convnet2/weight_acts/wet_act_templates.cuh @@ -25,58 +25,66 @@ * * -------------------------------------------------------------------------- * * This file has been modified by Megvii ("Megvii Modifications"). - * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights reserved. + * * All Megvii Modifications are Copyright (C) 2014-2021 Megvii Inc. All rights + * reserved. 
* -------------------------------------------------------------------------- */ -#include "../nvmatrix.cuh" #include "../cudaconv2.cuh" +#include "../nvmatrix.cuh" #include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { -#define LO16(x) ((x) & 0x0000FFFF) -#define HI16(x) ((x) >> 16) - -#define WA_LOOP(r) _Pragma("unroll") \ -for (int c = 0; c < colorsPerThread; c++) { \ - _Pragma("unroll") \ - for (int f = 0; f < filtersPerThread; f++) { \ - prod[f][c] += shImages[threadIdx.y + c * B_Y][(r)] * shHidActs[threadIdx.x + f * B_X][(r)]; \ - } \ -} - -#define WA_LOOP2(r) _Pragma("unroll") \ -for (int f = 0; f < filtersPerThread; f++) { \ - _Pragma("unroll") \ - for (int c = 0; c < colorsPerThread; c++) { \ - prod[f][c] += shImages[threadIdx.y + c * B_Y][(r)] * shHidActs[threadIdx.x + f * B_X][(r)]; \ - } \ -} - -#define WA_IMLOAD(r) imPreload[r] = im[(r) * B_X * B_Y / preloadCases * imgPixels * imgStride]; -#define WA_IMLOAD_TX(r) imPreload[r] = tex1Dfetch(images, imgOffset2 + (r) * B_X * B_Y / preloadCases * imgPixels * imgStride); -#define WA_HALOAD(r) haPreload[r] = ha[(r) * B_X * B_Y / preloadCases * numImages * numModules]; -#define WA_HALOAD_TX(r) haPreload[r] = tex1Dfetch(hidActs, hidActsOffset2 + (r) * B_X * B_Y / preloadCases * numImages * numModules); - -#define C_KEP_PARAM float* images, float* hidActs, float* targets, \ - const int numImages, const int numFilters, \ - const int numModulesY, const int numModulesX, \ - const int imgSizeY, const int imgSizeX, \ - const int filterSize, const int paddingStart, \ - const int moduleStride, const int imgStride, \ - const int partialSum, \ - const float scaleTargets, const float scaleOutputs +#define LO16(x) ((x)&0x0000FFFF) +#define HI16(x) ((x) >> 16) + +#define WA_LOOP(r) \ + _Pragma("unroll") for (int c = 0; c < colorsPerThread; c++) { \ + _Pragma("unroll") for (int f = 0; f < filtersPerThread; f++) { \ + prod[f][c] += shImages[threadIdx.y + c * B_Y][(r)] * \ + shHidActs[threadIdx.x + f * B_X][(r)]; \ + } \ + } + +#define WA_LOOP2(r) \ + _Pragma("unroll") for (int f = 0; f < filtersPerThread; f++) { \ + _Pragma("unroll") for (int c = 0; c < colorsPerThread; c++) { \ + prod[f][c] += shImages[threadIdx.y + c * B_Y][(r)] * \ + shHidActs[threadIdx.x + f * B_X][(r)]; \ + } \ + } + +#define WA_IMLOAD(r) \ + imPreload[r] = im[(r)*B_X * B_Y / preloadCases * imgPixels * imgStride]; +#define WA_IMLOAD_TX(r) \ + imPreload[r] = tex1Dfetch( \ + images, \ + imgOffset2 + (r)*B_X * B_Y / preloadCases * imgPixels * imgStride); +#define WA_HALOAD(r) \ + haPreload[r] = ha[(r)*B_X * B_Y / preloadCases * numImages * numModules]; +#define WA_HALOAD_TX(r) \ + haPreload[r] = tex1Dfetch( \ + hidActs, \ + hidActsOffset2 + (r)*B_X * B_Y / preloadCases * numImages * numModules); + +#define C_KEP_PARAM \ + float *images, float *hidActs, float *targets, const int numImages, \ + const int numFilters, const int numModulesY, const int numModulesX, \ + const int imgSizeY, const int imgSizeX, const int filterSize, \ + const int paddingStart, const int moduleStride, const int imgStride, \ + const int partialSum, const float scaleTargets, const float scaleOutputs /* * Each block computes weight gradients for B_Y * pixelsPerThread pixels and B_X filters * threadIdx.x determines filter * threadIdx.y determines pixel in filter * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum - * blockIdx.y determines pixel batch of B_Y * pixelsPerThread + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + * 
partialSum blockIdx.y determines pixel batch of B_Y * pixelsPerThread * * Number of filters must be divisible by B_X * filtersPerThread - * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is false. + * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is + * false. * * images: (numColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) @@ -90,98 +98,102 @@ for (int f = 0; f < filtersPerThread; f++) { \ * numModules must be divisible by partialSum * pixelsPerThread must be divisible by pixelCache * - * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread = 1)... - * so the compiler is messing up here somehow. It's unable to optimize that case away. + * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread + * = 1)... so the compiler is messing up here somehow. It's unable to optimize that case + * away. */ -template +template < + int B_Y, int B_X, int pixelCache, int pixelsPerThread, int filtersPerThread, + int preloadCases, int numColors, bool scale, bool checkCaseBounds> __global__ void conv_weight_acts_c_kepler(C_KEP_PARAM); - - -#define MC_MF_KEP_PARAM float* images, \ - float* hidActs, float* targets, \ - const int numImages, const int numFilters, \ - const int numModulesY, const int numModulesX, \ - const int imgSizeY, const int imgSizeX, \ - const int filterSize, const int paddingStart, \ - const int moduleStride, const int imgStride, \ - const int numImgColors, const int numGroups, \ - const int partialSum, \ - const float scaleTargets, const float scaleOutputs +#define MC_MF_KEP_PARAM \ + float *images, float *hidActs, float *targets, const int numImages, \ + const int numFilters, const int numModulesY, const int numModulesX, \ + const int imgSizeY, const int imgSizeX, const int filterSize, \ + const int paddingStart, const int moduleStride, const int imgStride, \ + const int numImgColors, const int numGroups, const int partialSum, \ + const float scaleTargets, const float scaleOutputs /* - * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and B_X * filtersPerThread filters + * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and + B_X * filtersPerThread filters * threadIdx.x determines filter * threadIdx.y determines color * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + partialSum * blockIdx.y determines color batch of B_Y * colorsPerThread * blockIdx.z determines pixel in filter - * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine will - * fail for filters >= 256*256. I'm assuming I won't ever use such large filters. + * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine + will + * fail for filters >= 256*256. I'm assuming I won't ever use such + large filters. 
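The comment above gives the thread/block mapping of conv_weight_acts_mc_mf_kepler only in prose; the kernel body is not part of this hunk. As a hedged illustration of that mapping (hypothetical helper, all names invented here; the real kernel computes these indices inline):

// Sketch of the grid decomposition described in the comment above. blockIdx.x packs
// the filter batch together with the module batch, blockIdx.y selects the color
// batch, and blockIdx.z is the pixel within the filter, which is why the NOTE caps
// filters at 256*256: blockIdx.z/gridDim.z cannot exceed 2^16 - 1.
template <int B_X, int filtersPerThread>
__device__ __forceinline__ void decode_weight_acts_block(
        int numFilters, int& moduleBatch, int& filterBatch, int& colorBatch,
        int& filterPixel) {
    const int numFilterBlocks = numFilters / (B_X * filtersPerThread);
    moduleBatch = blockIdx.x / numFilterBlocks;  // which group of partialSum modules
    filterBatch = blockIdx.x % numFilterBlocks;  // which B_X * filtersPerThread filters
    colorBatch = blockIdx.y;                     // which B_Y * colorsPerThread colors
    filterPixel = blockIdx.z;                    // pixel in filter, < 2^16 per the NOTE
}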
* images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) * - * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, numFilters) + * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, + numFilters) * B_X * B_Y must be divisible by preloadCases */ -template +template < + int B_Y, int B_X, int filtersPerThread, int colorsPerThread, int preloadCases, + bool scale> __global__ void conv_weight_acts_mc_mf_kepler(MC_MF_KEP_PARAM); -#define MC_MF_KEP_SW_PARAM float* images, \ - float* hidActs, float* targets, \ - const int numImages, const int numFilters, \ - const int numModulesY, const int numModulesX, \ - const int imgSizeY, const int imgSizeX, const \ - int filterSize, const int paddingStart, \ - const int moduleStride, const int imgStride, \ - const int numImgColors, const int numGroups, \ - const int sumWidth, \ - const float scaleTargets, const float scaleOutputs +#define MC_MF_KEP_SW_PARAM \ + float *images, float *hidActs, float *targets, const int numImages, \ + const int numFilters, const int numModulesY, const int numModulesX, \ + const int imgSizeY, const int imgSizeX, const int filterSize, \ + const int paddingStart, const int moduleStride, const int imgStride, \ + const int numImgColors, const int numGroups, const int sumWidth, \ + const float scaleTargets, const float scaleOutputs /* - * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and B_X * filtersPerThread filters + * Each block computes weight gradients for 1 pixel, B_Y * colorsPerThread colors and + B_X * filtersPerThread filters * threadIdx.x determines filter * threadIdx.y determines color * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + partialSum * blockIdx.y determines color batch of B_Y * colorsPerThread * blockIdx.z determines pixel in filter - * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine will - * fail for filters >= 256*256. I'm assuming I won't ever use such large filters. + * NOTE: blockIdx.z is limited to values < 2^16. This means that this routine + will + * fail for filters >= 256*256. I'm assuming I won't ever use such + large filters. 
* images: (numImgColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) * - * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, numFilters) + * targets: (numModulesY*numModulesX/partialSum, numFilterColors, filterPixels, + numFilters) * B_X * B_Y must be divisible by preloadCases */ -template +template < + int B_Y, int B_X, int filtersPerThread, int colorsPerThread, int preloadCases, + bool scale> __global__ void conv_weight_acts_mc_mf_kepler_sw(MC_MF_KEP_SW_PARAM); - - - -#define C_KEP_SW_PARAM float* images, \ - float* hidActs, float* targets, \ - const int numImages, const int numFilters, \ - const int numModulesY, const int numModulesX, \ - const int imgSizeY, const int imgSizeX, \ - const int filterSize, const int paddingStart, \ - const int moduleStride, const int imgStride, \ - const int sumWidth, \ - const float scaleTargets, const float scaleOutputs +#define C_KEP_SW_PARAM \ + float *images, float *hidActs, float *targets, const int numImages, \ + const int numFilters, const int numModulesY, const int numModulesX, \ + const int imgSizeY, const int imgSizeX, const int filterSize, \ + const int paddingStart, const int moduleStride, const int imgStride, \ + const int sumWidth, const float scaleTargets, const float scaleOutputs /* * Each block computes weight gradients for B_Y * pixelsPerThread pixels and B_X filters * threadIdx.x determines filter * threadIdx.y determines pixel in filter * - * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of partialSum - * blockIdx.y determines pixel batch of B_Y * pixelsPerThread + * blockIdx.x determines filter batch of B_X * filtersPerThread, module batch of + * partialSum blockIdx.y determines pixel batch of B_Y * pixelsPerThread * * Number of filters must be divisible by B_X * filtersPerThread - * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is false. + * Number of images (cases) should be divisible by preloadCases if checkCaseBounds is + * false. * * images: (numColors, imgSizeY, imgSizeX, numImages), with stride given * hidActs: (numFilters, numModulesY, numModulesX, numImages) @@ -195,17 +207,19 @@ __global__ void conv_weight_acts_mc_mf_kepler_sw(MC_MF_KEP_SW_PARAM); * numModules must be divisible by partialSum * pixelsPerThread must be divisible by pixelCache * - * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread = 1)... - * so the compiler is messing up here somehow. It's unable to optimize that case away. + * After adding pixelsPerThread, register usage went from 20 to 23 (when pixelsPerThread + * = 1)... so the compiler is messing up here somehow. It's unable to optimize that case + * away. 
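The divisibility requirements above are only documented in comments; whichever host-side launcher selects conv_weight_acts_c_kepler_sw would presumably have to enforce them before the launch. A minimal sketch of those checks, transcribed from the comment as written (hypothetical helper, not taken from the diff):

#include <cassert>

// Pre-launch checks mirroring the documented constraints. The template parameters are
// the subset of the kernel's that these checks need.
template <
        int B_X, int pixelCache, int pixelsPerThread, int filtersPerThread,
        int preloadCases, bool checkCaseBounds>
void check_weight_acts_c_kepler_preconditions(
        int numFilters, int numImages, int numModules, int partialSum) {
    static_assert(
            pixelsPerThread % pixelCache == 0,
            "pixelsPerThread must be divisible by pixelCache");
    assert(numFilters % (B_X * filtersPerThread) == 0);
    // Cases only need to be a multiple of preloadCases when bounds are not checked.
    assert(checkCaseBounds || numImages % preloadCases == 0);
    assert(numModules % partialSum == 0);
}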
*/ -template -__global__ void conv_weight_acts_c_kepler_sw(float* images, float* hidActs, float* targets, - const int numImages, const int numFilters, - const int numModulesY, const int numModulesX, - const int imgSizeY, const int imgSizeX, const int filterSize, - const int paddingStart, const int moduleStride, const int imgStride, - const int sumWidth, - const float scaleTargets, const float scaleOutputs); - -} // namespace cuda -} // namespace megdnn +template < + int B_Y, int B_X, int pixelCache, int pixelsPerThread, int filtersPerThread, + int preloadCases, int numColors, bool scale, bool checkCaseBounds> +__global__ void conv_weight_acts_c_kepler_sw( + float* images, float* hidActs, float* targets, const int numImages, + const int numFilters, const int numModulesY, const int numModulesX, + const int imgSizeY, const int imgSizeX, const int filterSize, + const int paddingStart, const int moduleStride, const int imgStride, + const int sumWidth, const float scaleTargets, const float scaleOutputs); + +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/local/forward.cpp b/dnn/src/cuda/local/forward.cpp index 26494604..22b4e458 100644 --- a/dnn/src/cuda/local/forward.cpp +++ b/dnn/src/cuda/local/forward.cpp @@ -10,33 +10,27 @@ */ #include "src/cuda/local/opr_impl.h" +#include "src/cuda/handle.h" #include "src/cuda/local/local.cuh" #include "src/cuda/utils.h" -#include "src/cuda/handle.h" #include "src/common/utils.cuh" namespace megdnn { namespace cuda { -void LocalForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ - megdnn_assert(src.layout.dtype == dtype::Float32(), - "cuda do not support fp16 local operator"); +void LocalForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + megdnn_assert( + src.layout.dtype == dtype::Float32(), + "cuda do not support fp16 local operator"); check_exec(src.layout, filter.layout, dst.layout, workspace.size); bool is_xcorr = param().mode == Mode::CROSS_CORRELATION; - auto N = src.layout.shape[0], - IC = src.layout.shape[1], - IH = src.layout.shape[2], + auto N = src.layout.shape[0], IC = src.layout.shape[1], IH = src.layout.shape[2], IW = src.layout.shape[3]; - auto OC = dst.layout.shape[1], - OH = dst.layout.shape[2], - OW = dst.layout.shape[3]; - auto FH = filter.layout.shape[3], - FW = filter.layout.shape[4]; + auto OC = dst.layout.shape[1], OH = dst.layout.shape[2], OW = dst.layout.shape[3]; + auto FH = filter.layout.shape[3], FW = filter.layout.shape[4]; auto handle = concrete_handle(this->handle()); auto stream = cuda_stream(this->handle()); auto cublas = cublas_handle(this->handle()); @@ -45,90 +39,61 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src, size_t src_batch_strd = src.layout.stride[0]; size_t dst_batch_strd = dst.layout.stride[0]; if (use_cuda_convnet(src.layout, filter.layout, dst.layout)) { - local::forward_proxy_convnet(src.ptr(), - filter.ptr(), - dst.ptr(), - reinterpret_cast(workspace.raw_ptr), - N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - src_batch_strd, dst_batch_strd, - param().pad_h, param().pad_w, - param().stride_h, param().stride_w, - cublas, stream, - one, zero); - } else if (local::forward_proxy_default_share_mem_in_bytes(IH, IW) <= - handle->device_prop().sharedMemPerBlock) { - local::forward_proxy_default(src.ptr(), - filter.ptr(), - dst.ptr(), - N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - src_batch_strd, dst_batch_strd, - param().pad_h, param().pad_w, - 
param().stride_h, param().stride_w, - is_xcorr, - stream); + local::forward_proxy_convnet( + src.ptr(), filter.ptr(), dst.ptr(), + reinterpret_cast(workspace.raw_ptr), N, IC, IH, IW, OC, OH, OW, + FH, FW, src_batch_strd, dst_batch_strd, param().pad_h, param().pad_w, + param().stride_h, param().stride_w, cublas, stream, one, zero); + } else if ( + local::forward_proxy_default_share_mem_in_bytes(IH, IW) <= + handle->device_prop().sharedMemPerBlock) { + local::forward_proxy_default( + src.ptr(), filter.ptr(), dst.ptr(), + N, IC, IH, IW, OC, OH, OW, FH, FW, src_batch_strd, dst_batch_strd, + param().pad_h, param().pad_w, param().stride_h, param().stride_w, + is_xcorr, stream); } else { megdnn_throw(ssprintf( "No usable kernel for local conv, src: %s filter: %s \n", - src.layout.to_string().c_str(), - filter.layout.to_string().c_str())); + src.layout.to_string().c_str(), filter.layout.to_string().c_str())); } } -size_t LocalForwardImpl::get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) -{ +size_t LocalForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { size_t res = 0u; - auto N = src.shape[0], - IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], - OC = dst.shape[1], OH = dst.shape[2], OW = dst.shape[3], - FH = filter.shape[3], FW = filter.shape[4]; - auto PH = param().pad_h, PW = param().pad_w, - SH = param().stride_h, SW = param().stride_w; + auto N = src.shape[0], IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], + OC = dst.shape[1], OH = dst.shape[2], OW = dst.shape[3], FH = filter.shape[3], + FW = filter.shape[4]; + auto PH = param().pad_h, PW = param().pad_w, SH = param().stride_h, + SW = param().stride_w; size_t src_batch_strd = src.stride[0]; size_t dst_batch_strd = dst.stride[0]; if (use_cuda_convnet(src, filter, dst)) { - res = local::get_workspace_in_floats_forward_proxy_convnet(N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - src_batch_strd, dst_batch_strd, - PH, PW, - SH, SW) * sizeof(dt_float32); + res = local::get_workspace_in_floats_forward_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, src_batch_strd, dst_batch_strd, + PH, PW, SH, SW) * + sizeof(dt_float32); } else { res = 0u; } return res; } -bool LocalForwardImpl::use_cuda_convnet(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) -{ - auto N = src.shape[0], - IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], - OC = dst.shape[1], OH = dst.shape[2], OW = dst.shape[3], - FH = filter.shape[3], FW = filter.shape[4]; - auto PH = param().pad_h, PW = param().pad_w, - SH = param().stride_h, SW = param().stride_w; +bool LocalForwardImpl::use_cuda_convnet( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { + auto N = src.shape[0], IC = src.shape[1], IH = src.shape[2], IW = src.shape[3], + OC = dst.shape[1], OH = dst.shape[2], OW = dst.shape[3], FH = filter.shape[3], + FW = filter.shape[4]; + auto PH = param().pad_h, PW = param().pad_w, SH = param().stride_h, + SW = param().stride_w; return param().mode == Mode::CROSS_CORRELATION && - local::can_forward_proxy_convnet(N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - IC*IH*IW, OC*OH*OW, - PH, PW, - SH, SW); + local::can_forward_proxy_convnet( + N, IC, IH, IW, OC, OH, OW, FH, FW, IC * IH * IW, OC * OH * OW, PH, + PW, SH, SW); } -} // namespace cuda -} // namespace megdnn - +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/forward.cu 
b/dnn/src/cuda/local/forward.cu index 0a1bf6e7..49eea89a 100644 --- a/dnn/src/cuda/local/forward.cu +++ b/dnn/src/cuda/local/forward.cu @@ -10,9 +10,9 @@ */ #include "src/cuda/local/local.cuh" -#include "src/cuda/utils.cuh" -#include "src/cuda/local/cuda-convnet2/nvmatrix.cuh" #include "src/cuda/local/cuda-convnet2/cudaconv2.cuh" +#include "src/cuda/local/cuda-convnet2/nvmatrix.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { @@ -28,134 +28,117 @@ size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW) { // blockIdx.x is N/4 // threadIdx.x is [0, 1024) template -__global__ void forward_kernel(const float * __restrict__ src, - const float * __restrict__ filter, - float * __restrict__ dst, - uint32_t N, - uint32_t IC, uint32_t IH, uint32_t IW, - uint32_t OC, uint32_t OH, uint32_t OW, - uint32_t FH, uint32_t FW, - uint32_t INs, size_t ONs, - uint32_t PH, uint32_t PW, - uint32_t SH, uint32_t SW) -{ +__global__ void forward_kernel( + const float* __restrict__ src, const float* __restrict__ filter, + float* __restrict__ dst, uint32_t N, uint32_t IC, uint32_t IH, uint32_t IW, + uint32_t OC, uint32_t OH, uint32_t OW, uint32_t FH, uint32_t FW, uint32_t INs, + size_t ONs, uint32_t PH, uint32_t PW, uint32_t SH, uint32_t SW) { // Ns*ICs*sizeof(float)*IH*IW extern __shared__ float shared_mem[]; - float *src_cache = shared_mem; + float* src_cache = shared_mem; uint32_t tid = threadIdx.x; uint32_t tstride = blockDim.x; uint32_t oid = tid + blockIdx.y * tstride; - src += blockIdx.x*Ns * INs; - dst += blockIdx.x*Ns * ONs; + src += blockIdx.x * Ns * INs; + dst += blockIdx.x * Ns * ONs; uint32_t op = oid / OC; uint32_t oc = oid % OC; uint32_t oh = op / OW; uint32_t ow = op % OW; float dst_reg[Ns]; - for (uint32_t no = 0; no < Ns; ++no) dst_reg[no] = 0.0f; - uint32_t Nb = min(N-blockIdx.x*Ns, Ns); + for (uint32_t no = 0; no < Ns; ++no) + dst_reg[no] = 0.0f; + uint32_t Nb = min(N - blockIdx.x * Ns, Ns); for (uint32_t ic = 0; ic < IC; ic += ICs) { // read ICs-channel src // (Ns, ICs, IHs, IWs) - uint32_t ICb = min(ICs, IC-ic); - for (uint32_t i = tid; i < Nb*ICs*IH*IW; i += tstride) { - uint32_t ip = i % (IH*IW); - uint32_t ico = i / (IH*IW) % ICs; - uint32_t no = i / (IH*IW) / ICs; - src_cache[i] = - (ico < ICb) * src[no*INs + min(IC-1, (ic+ico))*IH*IW + ip]; + uint32_t ICb = min(ICs, IC - ic); + for (uint32_t i = tid; i < Nb * ICs * IH * IW; i += tstride) { + uint32_t ip = i % (IH * IW); + uint32_t ico = i / (IH * IW) % ICs; + uint32_t no = i / (IH * IW) / ICs; + src_cache[i] = (ico < ICb) * + src[no * INs + min(IC - 1, (ic + ico)) * IH * IW + ip]; } __syncthreads(); - if (oid < OC*OH*OW) - for (uint32_t fh = 0; fh < FH; ++fh) - { - uint32_t ih; - if (is_xcorr) ih = oh*SH + fh - PH; else ih = oh*SH + (FH-fh-1) - PH; - if (ih < IH) - for (uint32_t fw = 0; fw < FW; ++fw) - { - uint32_t iw; - if (is_xcorr) iw = ow*SW + fw - PW; else iw = ow*SW + (FW-fw-1) - PW; - if (iw < IW) - for (uint32_t ico = 0; ico < ICb; ++ico) { - uint32_t fid = op*IC*FH*FW*OC + (ic+ico)*FH*FW*OC + - fh*FW*OC + fw*OC + oc; - float fval = filter[fid]; - float src_reg[Ns]; + if (oid < OC * OH * OW) + for (uint32_t fh = 0; fh < FH; ++fh) { + uint32_t ih; + if (is_xcorr) + ih = oh * SH + fh - PH; + else + ih = oh * SH + (FH - fh - 1) - PH; + if (ih < IH) + for (uint32_t fw = 0; fw < FW; ++fw) { + uint32_t iw; + if (is_xcorr) + iw = ow * SW + fw - PW; + else + iw = ow * SW + (FW - fw - 1) - PW; + if (iw < IW) + for (uint32_t ico = 0; ico < ICb; ++ico) { + uint32_t fid = op * IC * FH * FW * OC + + (ic + 
ico) * FH * FW * OC + + fh * FW * OC + fw * OC + oc; + float fval = filter[fid]; + float src_reg[Ns]; #pragma unroll - for (uint32_t no = 0; no < Ns; ++no) { - src_reg[no] = src_cache[no*ICs*IH*IW + ico*IH*IW + ih*IW + iw]; - } + for (uint32_t no = 0; no < Ns; ++no) { + src_reg[no] = src_cache + [no * ICs * IH * IW + ico * IH * IW + + ih * IW + iw]; + } #pragma unroll - for (uint32_t no = 0; no < Ns; ++no) { - dst_reg[no] += src_reg[no]*fval; - } + for (uint32_t no = 0; no < Ns; ++no) { + dst_reg[no] += src_reg[no] * fval; + } + } + } } - } - } __syncthreads(); } - if (oid < OC*OH*OW) { + if (oid < OC * OH * OW) { for (uint32_t no = 0; no < Nb; ++no) { - dst[no*ONs + oc*OH*OW + op] = dst_reg[no]; + dst[no * ONs + oc * OH * OW + op] = dst_reg[no]; } } } -void forward_proxy_default(const float *src, const float *filter, float *dst, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW, - bool is_xcorr, - cudaStream_t stream) -{ +void forward_proxy_default( + const float* src, const float* filter, float* dst, size_t N, size_t IC, + size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, size_t FW, + size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, size_t SW, + bool is_xcorr, cudaStream_t stream) { size_t threads = 256; - dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC*OH*OW, threads)); + dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC * OH * OW, threads)); if (is_xcorr) { - forward_kernel<<>>(src, filter, dst, - N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - INs, ONs, - PH, PW, - SH, SW); + forward_kernel + <<>>( + src, filter, dst, N, IC, IH, IW, OC, OH, OW, FH, FW, INs, ONs, + PH, PW, SH, SW); } else { - forward_kernel<<>>(src, filter, dst, - N, - IC, IH, IW, - OC, OH, OW, - FH, FW, - INs, ONs, - PH, PW, - SH, SW); + forward_kernel + <<>>( + src, filter, dst, N, IC, IH, IW, OC, OH, OW, FH, FW, INs, ONs, + PH, PW, SH, SW); } after_kernel_launch(); } -bool can_forward_proxy_convnet(size_t N, - size_t IC, size_t /* IH */, size_t /* IW */, - size_t /*OC*/, size_t /* OH */, size_t /* OW */, - size_t FH, size_t FW, - size_t /* INs */, size_t /* ONs */, - size_t PH, size_t PW, - size_t SH, size_t SW) -{ +bool can_forward_proxy_convnet( + size_t N, size_t IC, size_t /* IH */, size_t /* IW */, size_t /*OC*/, + size_t /* OH */, size_t /* OW */, size_t FH, size_t FW, size_t /* INs */, + size_t /* ONs */, size_t PH, size_t PW, size_t SH, size_t SW) { bool flag = true; // check pad flag &= (PH == PW); // check stride flag &= (SH == SW); - // megdnn_assert(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || numImgColors % 4 == 0))); + // megdnn_assert(numGroups > 1 || (numImgColors > 0 && (numImgColors <= 3 || + // numImgColors % 4 == 0))); flag &= (IC <= 3 || IC % 4 == 0); // megdnn_assert(numFilters % (16 * numGroups) == 0); - //flag &= (OC % 16 == 0); + // flag &= (OC % 16 == 0); // megdnn_assert(filterSize * filterSize == filterPixels); flag &= (FH == FW); flag &= (SH <= FH); @@ -163,53 +146,42 @@ bool can_forward_proxy_convnet(size_t N, return flag; } -size_t get_workspace_in_floats_forward_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t /* FH */, size_t /* FW */, - size_t /* INs */, size_t /* ONs */, - size_t /* PH */, size_t /* PW */, - size_t /* SH */, size_t /* SW */) -{ - return N*IC*IH*IW + N*OC*OH*OW; +size_t get_workspace_in_floats_forward_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t 
OC, size_t OH, size_t OW, + size_t /* FH */, size_t /* FW */, size_t /* INs */, size_t /* ONs */, + size_t /* PH */, size_t /* PW */, size_t /* SH */, size_t /* SW */) { + return N * IC * IH * IW + N * OC * OH * OW; } -void forward_proxy_convnet(const float *src, const float *filter, float *dst, - float *workspace, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, // IN stride and ON stride - size_t PH, size_t /* PW */, - size_t SH, size_t /* SW */, - cublasHandle_t cublas_handle, - cudaStream_t stream, - float *one, float *zero) +void forward_proxy_convnet( + const float* src, const float* filter, float* dst, float* workspace, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t INs, size_t ONs, // IN stride and ON stride + size_t PH, size_t /* PW */, size_t SH, size_t /* SW */, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero) { - MemorySegment msrc_n(const_cast(src)), - mdst_n(dst), - mfilter(const_cast(filter)), - msrc_t(workspace+0), - mdst_t(workspace+N*IC*IH*IW); - NVMatrix nvimage_n(&msrc_n, N, IC*IH*IW, INs); - NVMatrix nvtarget_n(&mdst_n, N, OC*OH*OW, ONs); - NVMatrix nvimage_t(&msrc_t, IC*IH*IW, N); - NVMatrix nvfilter(&mfilter, OH*OW*IC*FH*FW, OC); - NVMatrix nvtarget_t(&mdst_t, OC*OH*OW, N); + MemorySegment msrc_n(const_cast(src)), mdst_n(dst), + mfilter(const_cast(filter)), msrc_t(workspace + 0), + mdst_t(workspace + N * IC * IH * IW); + NVMatrix nvimage_n(&msrc_n, N, IC * IH * IW, INs); + NVMatrix nvtarget_n(&mdst_n, N, OC * OH * OW, ONs); + NVMatrix nvimage_t(&msrc_t, IC * IH * IW, N); + NVMatrix nvfilter(&mfilter, OH * OW * IC * FH * FW, OC); + NVMatrix nvtarget_t(&mdst_t, OC * OH * OW, N); nvimage_n.transpose(nvimage_t, cublas_handle, one, zero); - localFilterActs(stream, nvimage_t, nvfilter, nvtarget_t, - IH, OH, OW, -static_cast(PH), SH, IC, 1); + localFilterActs( + stream, nvimage_t, nvfilter, nvtarget_t, IH, OH, OW, -static_cast(PH), + SH, IC, 1); after_kernel_launch(); nvtarget_t.transpose(nvtarget_n, cublas_handle, one, zero); } -} // namespace local -} // namespace cuda -} // namespace megdnn +} // namespace local +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/local.cuh b/dnn/src/cuda/local/local.cuh index 142a2b3e..524b3f95 100644 --- a/dnn/src/cuda/local/local.cuh +++ b/dnn/src/cuda/local/local.cuh @@ -10,8 +10,8 @@ */ #pragma once -#include #include +#include namespace megdnn { namespace cuda { @@ -19,116 +19,71 @@ namespace local { size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW); -void forward_proxy_default(const float *src, const float *filter, float *dst, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW, - bool is_xcorr, - cudaStream_t stream); +void forward_proxy_default( + const float* src, const float* filter, float* dst, size_t N, size_t IC, + size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, size_t FW, + size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, size_t SW, + bool is_xcorr, cudaStream_t stream); /// forward -bool can_forward_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW); - -void forward_proxy_convnet(const float 
*src, const float *filter, float *dst, - float *workspace, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, // IN stride and ON stride - size_t PH, size_t PW, - size_t SH, size_t SW, - cublasHandle_t cublas_handle, - cudaStream_t stream, - float *one, float *zero); - -size_t get_workspace_in_floats_forward_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW); +bool can_forward_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, + size_t SW); + +void forward_proxy_convnet( + const float* src, const float* filter, float* dst, float* workspace, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t INs, size_t ONs, // IN stride and ON stride + size_t PH, size_t PW, size_t SH, size_t SW, cublasHandle_t cublas_handle, + cudaStream_t stream, float* one, float* zero); + +size_t get_workspace_in_floats_forward_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, + size_t SW); /// bwd data -bool can_backward_data_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW); - -void backward_data_proxy_convnet(const float *filter, - const float *diff, - float *grad, - float *workspace, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, // IN stride and ON stride - size_t PH, size_t PW, - size_t SH, size_t SW, - cublasHandle_t cublas_handle, - cudaStream_t stream, - float *one, float *zero); - -size_t get_workspace_in_floats_backward_data_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW); +bool can_backward_data_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, + size_t SW); + +void backward_data_proxy_convnet( + const float* filter, const float* diff, float* grad, float* workspace, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t INs, size_t ONs, // IN stride and ON stride + size_t PH, size_t PW, size_t SH, size_t SW, cublasHandle_t cublas_handle, + cudaStream_t stream, float* one, float* zero); + +size_t get_workspace_in_floats_backward_data_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, + size_t SW); /// bwd filter -bool can_backward_filter_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW); - -void backward_filter_proxy_convnet(const float *src, - const float *diff, - float *grad, - float *workspace, - size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, 
size_t FW, - size_t INs, size_t ONs, // IN stride and ON stride - size_t PH, size_t PW, - size_t SH, size_t SW, - cublasHandle_t cublas_handle, - cudaStream_t stream, - float *one, float *zero); - -size_t get_workspace_in_floats_backward_filter_proxy_convnet(size_t N, - size_t IC, size_t IH, size_t IW, - size_t OC, size_t OH, size_t OW, - size_t FH, size_t FW, - size_t INs, size_t ONs, - size_t PH, size_t PW, - size_t SH, size_t SW); - -} // namespace local -} // namespace cuda -} // namespace megdnn +bool can_backward_filter_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, + size_t SW); + +void backward_filter_proxy_convnet( + const float* src, const float* diff, float* grad, float* workspace, size_t N, + size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, size_t FH, + size_t FW, size_t INs, size_t ONs, // IN stride and ON stride + size_t PH, size_t PW, size_t SH, size_t SW, cublasHandle_t cublas_handle, + cudaStream_t stream, float* one, float* zero); + +size_t get_workspace_in_floats_backward_filter_proxy_convnet( + size_t N, size_t IC, size_t IH, size_t IW, size_t OC, size_t OH, size_t OW, + size_t FH, size_t FW, size_t INs, size_t ONs, size_t PH, size_t PW, size_t SH, + size_t SW); + +} // namespace local +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local/opr_impl.h b/dnn/src/cuda/local/opr_impl.h index 5ecbbd0d..e1011129 100644 --- a/dnn/src/cuda/local/opr_impl.h +++ b/dnn/src/cuda/local/opr_impl.h @@ -16,55 +16,55 @@ namespace megdnn { namespace cuda { -class LocalForwardImpl final: public LocalForward { - public: - using LocalForward::LocalForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) override; - private: - bool use_cuda_convnet(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst); +class LocalForwardImpl final : public LocalForward { +public: + using LocalForward::LocalForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + +private: + bool use_cuda_convnet( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst); }; -class LocalBackwardDataImpl final: public LocalBackwardData { - public: - using LocalBackwardData::LocalBackwardData; - void exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) override; - private: - bool use_cuda_convnet(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad); +class LocalBackwardDataImpl final : public LocalBackwardData { +public: + using LocalBackwardData::LocalBackwardData; + void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; + +private: + bool use_cuda_convnet( + const TensorLayout& 
filter, const TensorLayout& diff, + const TensorLayout& grad); }; -class LocalBackwardFilterImpl final: public LocalBackwardFilter { - public: - using LocalBackwardFilter::LocalBackwardFilter; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_in grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad) override; - private: - bool use_cuda_convnet(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad); +class LocalBackwardFilterImpl final : public LocalBackwardFilter { +public: + using LocalBackwardFilter::LocalBackwardFilter; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_in grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; + +private: + bool use_cuda_convnet( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad); }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local_share/backward_data/algo.cpp b/dnn/src/cuda/local_share/backward_data/algo.cpp index 2a5d6d4c..d7f9a0ea 100644 --- a/dnn/src/cuda/local_share/backward_data/algo.cpp +++ b/dnn/src/cuda/local_share/backward_data/algo.cpp @@ -33,11 +33,9 @@ LocalShareBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( const TensorLayout& diff, const TensorLayout& grad) : opr{o}, filter_layout{filter}, diff_layout{diff}, grad_layout{grad} {} -LocalShareBackwardDataImpl::AlgoBase::ExecArgs::ExecArgs(LocalShareBackwardDataImpl* opr, - _megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) +LocalShareBackwardDataImpl::AlgoBase::ExecArgs::ExecArgs( + LocalShareBackwardDataImpl* opr, _megdnn_tensor_in filter, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) : SizeArgs(opr, filter.layout, diff.layout, grad.layout), filter_tensor{&filter}, diff_tensor{&diff}, @@ -51,8 +49,8 @@ std::string LocalShareBackwardDataImpl::AlgoBase::SizeArgs::to_string() const { "filter=%s, diff=%s, grad=%s, " "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s->%s", filter_layout.to_string().c_str(), diff_layout.to_string().c_str(), - grad_layout.to_string().c_str(), param.pad_h, param.pad_w, - param.stride_h, param.stride_w, param.dilate_h, param.dilate_w, + grad_layout.to_string().c_str(), param.pad_h, param.pad_w, param.stride_h, + param.stride_w, param.dilate_h, param.dilate_w, static_cast(param.mode), filter_layout.dtype.name(), diff_layout.dtype.name(), grad_layout.dtype.name()); } diff --git a/dnn/src/cuda/local_share/backward_data/algo.h b/dnn/src/cuda/local_share/backward_data/algo.h index 7926b44c..23645ff4 100644 --- a/dnn/src/cuda/local_share/backward_data/algo.h +++ b/dnn/src/cuda/local_share/backward_data/algo.h @@ -41,16 +41,18 @@ public: TensorLayout filter_layout, diff_layout, grad_layout; std::string to_string() const; - SizeArgs(LocalShareBackwardDataImpl* opr, const TensorLayout& filter, - const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + LocalShareBackwardDataImpl* opr, const TensorLayout& filter, + const TensorLayout& diff, const TensorLayout& grad); }; struct ExecArgs : public SizeArgs { const TensorND *filter_tensor, *diff_tensor, *grad_tensor; Workspace workspace; - ExecArgs(LocalShareBackwardDataImpl* opr, 
_megdnn_tensor_in filter, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace); + ExecArgs( + LocalShareBackwardDataImpl* opr, _megdnn_tensor_in filter, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -65,16 +67,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "local share conv fwd algo %s: required workspace %zu " - "bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "local share conv fwd algo %s: required workspace %zu " + "bytes, got %zu", + name(), req, workspace.size); return *this; } }; @@ -85,32 +86,22 @@ public: size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } - const char* name() const override { - return "LOCAL_SHARE_IMPLICIT_GEMM"; - } + const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) }; -class LocalShareBackwardDataImpl::AlgoBatchedMatMul final - : public AlgoBase { +class LocalShareBackwardDataImpl::AlgoBatchedMatMul final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } - const char* name() const override { - return "LOCAL_SHARE_BATCHED_MATMUL"; - } + const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) }; diff --git a/dnn/src/cuda/local_share/backward_data/batched_matmul.cpp b/dnn/src/cuda/local_share/backward_data/batched_matmul.cpp index 2c82e27c..155a9aa2 100644 --- a/dnn/src/cuda/local_share/backward_data/batched_matmul.cpp +++ b/dnn/src/cuda/local_share/backward_data/batched_matmul.cpp @@ -31,24 +31,23 @@ bool LocalShareBackwardDataImpl::AlgoBatchedMatMul::is_available( available &= (format == Format::NCHW); // mode must be cross correlation available &= (mode == Mode::CROSS_CORRELATION); - auto filter_dtype = args.filter_layout.dtype, - diff_dtype = args.diff_layout.dtype, + auto filter_dtype = args.filter_layout.dtype, diff_dtype = args.diff_layout.dtype, grad_dtype = args.grad_layout.dtype; // only support float32 - available &= (filter_dtype == diff_dtype && filter_dtype == grad_dtype && - filter_dtype == dtype::Float32()); + available &= + (filter_dtype == diff_dtype && 
filter_dtype == grad_dtype && + filter_dtype == dtype::Float32()); // do not support dilate conv size_t dh = param.dilate_h, dw = param.dilate_w; available &= (dh == 1 && dw == 1); return available; } -WorkspaceBundle -LocalShareBackwardDataImpl::AlgoBatchedMatMul::get_workspace_bundle( +WorkspaceBundle LocalShareBackwardDataImpl::AlgoBatchedMatMul::get_workspace_bundle( dt_byte* raw_ptr, const SizeArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.grad_layout, args.filter_layout, - args.diff_layout, param); + unpack_local_share_params( + args.grad_layout, args.filter_layout, args.diff_layout, param); using Param = LocalShare::Param; using Sparse = Param::Sparse; size_t groups = 1; @@ -57,13 +56,11 @@ LocalShareBackwardDataImpl::AlgoBatchedMatMul::get_workspace_bundle( } size_t icpg = ci / groups, ocpg = co / groups; size_t ws_pretranspose = n * co * ho * wo * args.diff_layout.dtype.size(); - size_t ws_col2im = - n * ci * ho * wo * fh * fw * args.grad_layout.dtype.size(); + size_t ws_col2im = n * ci * ho * wo * fh * fw * args.grad_layout.dtype.size(); auto&& matmul_opr = args.opr->handle()->create_operator(); - TensorLayout A{{groups * sgh * sgw, icpg * fh * fw, ocpg}, - dtype::Float32()}; - TensorLayout B{{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, - dtype::Float32()}; + TensorLayout A{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; + TensorLayout B{ + {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; TensorLayout C{ {groups * sgh * sgw, icpg * fh * fw, ho / sgh * wo / sgw * n}, dtype::Float32()}; @@ -77,11 +74,10 @@ size_t LocalShareBackwardDataImpl::AlgoBatchedMatMul::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec( - const ExecArgs& args) const { +void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.grad_layout, args.filter_layout, - args.diff_layout, param); + unpack_local_share_params( + args.grad_layout, args.filter_layout, args.diff_layout, param); using Param = LocalShare::Param; using Sparse = Param::Sparse; size_t groups = 1; @@ -90,10 +86,10 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec( } size_t icpg = ci / groups, ocpg = co / groups; local_share::Param kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ph = ph, - kern_param.pw = pw, kern_param.grp_ho = ho / sgh, - kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, kern_param.sgw = sgw; + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ph = ph, kern_param.pw = pw, + kern_param.grp_ho = ho / sgh, kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, + kern_param.sgw = sgw; auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); auto ws_pretranspose = ws.get(0); @@ -101,8 +97,8 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec( auto ws_matmul = ws.get(2); { - TensorLayout B1{{groups, sgh, sgw, ocpg, ho / sgh, wo / sgw, n}, - dtype::Float32()}; + TensorLayout B1{ + {groups, sgh, sgw, ocpg, ho / sgh, wo / sgw, n}, dtype::Float32()}; B1.stride[0] = wo * ho * ocpg; B1.stride[1] = wo * ho / sgh; B1.stride[2] = wo / sgw; @@ -111,8 +107,8 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec( B1.stride[5] = 1; B1.stride[6] = co * ho * wo; TensorND ts_B1{args.diff_tensor->raw_ptr, B1}; - 
TensorLayout B2{{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, - dtype::Float32()}; + TensorLayout B2{ + {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; B2.init_contiguous_stride(); TensorND ts_B2{ws_pretranspose, B2}; auto&& relayout_opr = args.opr->handle()->create_operator(); @@ -120,10 +116,9 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec( } auto&& matmul_opr = args.opr->handle()->create_operator(); - TensorLayout A{{groups * sgh * sgw, icpg * fh * fw, ocpg}, - dtype::Float32()}; - TensorLayout B{{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, - dtype::Float32()}; + TensorLayout A{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; + TensorLayout B{ + {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; TensorLayout C{ {groups * sgh * sgw, icpg * fh * fw, ho / sgh * wo / sgw * n}, dtype::Float32()}; @@ -138,8 +133,8 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec( auto&& stream = cuda_stream(args.opr->handle()); local_share::_do_local_share_col2im( reinterpret_cast(ws_col2im), - args.grad_tensor->ptr(), fh, fw, sh, sw, groups, - kern_param, stream); + args.grad_tensor->ptr(), fh, fw, sh, sw, groups, kern_param, + stream); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/local_share/backward_data/implicit_gemm.cpp b/dnn/src/cuda/local_share/backward_data/implicit_gemm.cpp index 21ac3d2b..41d4dd78 100644 --- a/dnn/src/cuda/local_share/backward_data/implicit_gemm.cpp +++ b/dnn/src/cuda/local_share/backward_data/implicit_gemm.cpp @@ -35,45 +35,43 @@ bool LocalShareBackwardDataImpl::AlgoImplicitGemm::is_available( available &= (sparse == Sparse::DENSE); // mode must be cross correlation available &= (mode == Mode::CROSS_CORRELATION); - unpack_local_share_params(args.grad_layout, args.filter_layout, - args.diff_layout, param); + unpack_local_share_params( + args.grad_layout, args.filter_layout, args.diff_layout, param); available &= (ho % sgh == 0 && wo % sgw == 0); // not support dilated convolution available &= (dh == 1 && dw == 1); available &= (co % 4 == 0); - auto filter_dtype = args.filter_layout.dtype, - diff_dtype = args.diff_layout.dtype, + auto filter_dtype = args.filter_layout.dtype, diff_dtype = args.diff_layout.dtype, grad_dtype = args.grad_layout.dtype; // only support float32 - available &= (filter_dtype == diff_dtype && filter_dtype == grad_dtype && - filter_dtype == dtype::Float32()); + available &= + (filter_dtype == diff_dtype && filter_dtype == grad_dtype && + filter_dtype == dtype::Float32()); // only support sm_60 or later available &= is_compute_capability_required(6, 0); return available; } -size_t -LocalShareBackwardDataImpl::AlgoImplicitGemm::get_workspace_in_bytes( +size_t LocalShareBackwardDataImpl::AlgoImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.grad_layout, args.filter_layout, - args.diff_layout, param); + unpack_local_share_params( + args.grad_layout, args.filter_layout, args.diff_layout, param); size_t ws_size_grad = n * ci * hi * wi * args.grad_layout.dtype.size(); size_t ws_size_diff = n * co * ho * wo * args.diff_layout.dtype.size(); return ws_size_grad + ws_size_diff; } -void LocalShareBackwardDataImpl::AlgoImplicitGemm::exec( - const ExecArgs& args) const { +void LocalShareBackwardDataImpl::AlgoImplicitGemm::exec(const ExecArgs& args) const { local_share::Param kern_param; auto&& param = args.opr->param(); - unpack_local_share_params(args.grad_layout, args.filter_layout, 
- args.diff_layout, param); - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ph = ph, - kern_param.pw = pw, kern_param.grp_ho = ho / sgh, - kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, kern_param.sgw = sgw; + unpack_local_share_params( + args.grad_layout, args.filter_layout, args.diff_layout, param); + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ph = ph, kern_param.pw = pw, + kern_param.grp_ho = ho / sgh, kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, + kern_param.sgw = sgw; auto&& handle = concrete_handle(args.opr->handle()); auto&& cublas_hdl = cublas_handle(args.opr->handle()); auto&& stream = cuda_stream(args.opr->handle()); @@ -82,8 +80,7 @@ void LocalShareBackwardDataImpl::AlgoImplicitGemm::exec( auto zero = handle->zero_device(); local_share_bwd_data::_do_local_share_bwd_data_implicit_gemm( - args.filter_tensor->ptr(), - args.diff_tensor->ptr(), + args.filter_tensor->ptr(), args.diff_tensor->ptr(), args.grad_tensor->ptr(), reinterpret_cast(args.workspace.raw_ptr), fh, fw, sh, sw, kern_param, cublas_hdl, stream, one, zero); diff --git a/dnn/src/cuda/local_share/backward_data/local_share_bwd_data.cuh b/dnn/src/cuda/local_share/backward_data/local_share_bwd_data.cuh index feadb88c..295ba2db 100644 --- a/dnn/src/cuda/local_share/backward_data/local_share_bwd_data.cuh +++ b/dnn/src/cuda/local_share/backward_data/local_share_bwd_data.cuh @@ -15,10 +15,9 @@ namespace cuda { namespace local_share_bwd_data { void _do_local_share_bwd_data_implicit_gemm( - const float* d_filter, const float* d_diff, float* d_grad, - float* workspace, int fh, int fw, int sh, int sw, - const local_share::Param& param, cublasHandle_t cublas_handle, - cudaStream_t stream, float* one, float* zero); + const float* d_filter, const float* d_diff, float* d_grad, float* workspace, + int fh, int fw, int sh, int sw, const local_share::Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero); } // namespace local_share_bwd_data } // namespace cuda diff --git a/dnn/src/cuda/local_share/backward_data/local_share_bwd_data_f32_implicit_gemm.cu b/dnn/src/cuda/local_share/backward_data/local_share_bwd_data_f32_implicit_gemm.cu index f377a5ac..e58bc7be 100644 --- a/dnn/src/cuda/local_share/backward_data/local_share_bwd_data_f32_implicit_gemm.cu +++ b/dnn/src/cuda/local_share/backward_data/local_share_bwd_data_f32_implicit_gemm.cu @@ -31,8 +31,7 @@ struct ThreadConfig { template struct DiffTileCount { - static int const tile_batch = - UnrollConfig::unroll_n * ThreadConfig::nr_thread_x; + static int const tile_batch = UnrollConfig::unroll_n * ThreadConfig::nr_thread_x; static int const load_x = tile_batch > 32 ? 32 : tile_batch; static int const load_y = ThreadConfig::nr_threads / load_x; @@ -49,8 +48,7 @@ struct DiffTileCount { template struct FilterTileCount { - static int const tile_ci = - ThreadConfig::nr_thread_y * UnrollConfig::unroll_ci; + static int const tile_ci = ThreadConfig::nr_thread_y * UnrollConfig::unroll_ci; static int const smem_h = tile_ci; static int const smem_w = UnrollConfig::unroll_co; static int const smem_stride = smem_w % 2 == 0 ? 
smem_w + 1 : smem_w; @@ -176,8 +174,7 @@ struct FilterGlobal2ShareMemVisitor { copy_t reg[TileCount::reg_row][TileCount::reg_col]; - __device__ FilterGlobal2ShareMemVisitor(copy_t* smem, int stride, - int remain) + __device__ FilterGlobal2ShareMemVisitor(copy_t* smem, int stride, int remain) : smem{smem}, stride{stride}, remain{remain} {} __device__ __forceinline__ void first_copy() { @@ -248,9 +245,7 @@ struct FilterGlobal2ShareMemVisitor { return &smem[y * TileCount::smem_stride + x]; } - __device__ __forceinline__ void move_forward() { - g_ptr += UnrollConfig::unroll_co; - } + __device__ __forceinline__ void move_forward() { g_ptr += UnrollConfig::unroll_co; } }; template @@ -259,8 +254,7 @@ __device__ __forceinline__ void consume_block( diff_gl2sh_visitor, FilterGlobal2ShareMemVisitor& filter_gl2sh_visitor, - float r_diff[UnrollConfig::unroll_n], - float r_filter[UnrollConfig::unroll_ci], + float r_diff[UnrollConfig::unroll_n], float r_filter[UnrollConfig::unroll_ci], float r_grad[UnrollConfig::unroll_ci][UnrollConfig::unroll_n]) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -327,8 +321,7 @@ __global__ void local_share_bwd_data_device_template_f32( DiffGlobal2ShareMemVisitor diff_gl2sh_visitor{sh_diff, ho * wo * param.n, param.n - b_batch}; FilterGlobal2ShareMemVisitor - filter_gl2sh_visitor{sh_filter, param.co * fh * fw, - param.ci - b_ci}; + filter_gl2sh_visitor{sh_filter, param.co * fh * fw, param.ci - b_ci}; float r_diff[UnrollConfig::unroll_n]; float r_filter[UnrollConfig::unroll_ci]; @@ -358,11 +351,9 @@ __global__ void local_share_bwd_data_device_template_f32( int kw = b_wi + param.pw - width_start * sw; int sgh_idx = height_start / param.grp_ho; int sgw_idx = width_start / param.grp_wo; - diff_gl2sh_visitor.g_ptr = - g_ptr_diff + (height_start * wo + width_start) * param.n; + diff_gl2sh_visitor.g_ptr = g_ptr_diff + (height_start * wo + width_start) * param.n; filter_gl2sh_visitor.g_ptr = - g_ptr_filter + - (sgh_idx * param.sgw + sgw_idx) * nr_elems_per_filter_grp + + g_ptr_filter + (sgh_idx * param.sgw + sgw_idx) * nr_elems_per_filter_grp + (kh * fw + kw) * param.co; if (height_start <= height_end && width_start <= width_end) { @@ -386,11 +377,10 @@ __global__ void local_share_bwd_data_device_template_f32( int sgw_idx = w_next / param.grp_wo; diff_gl2sh_visitor.g_ptr = g_ptr_diff + (h_next * wo + w_next) * param.n; - filter_gl2sh_visitor.g_ptr = - g_ptr_filter + - (sgh_idx * param.sgw + sgw_idx) * - nr_elems_per_filter_grp + - (kh * fw + kw) * param.co; + filter_gl2sh_visitor.g_ptr = g_ptr_filter + + (sgh_idx * param.sgw + sgw_idx) * + nr_elems_per_filter_grp + + (kh * fw + kw) * param.co; diff_gl2sh_visitor.copy(); filter_gl2sh_visitor.copy(); } @@ -402,12 +392,11 @@ __global__ void local_share_bwd_data_device_template_f32( } consume_block( - diff_gl2sh_visitor, filter_gl2sh_visitor, r_diff, - r_filter, r_grad); + diff_gl2sh_visitor, filter_gl2sh_visitor, r_diff, r_filter, + r_grad); // last tile - if (!(h == height_end && w == width_end && - co_outer == co_blks - 1)) { + if (!(h == height_end && w == width_end && co_outer == co_blks - 1)) { __syncthreads(); diff_gl2sh_visitor.commit(); filter_gl2sh_visitor.commit(); @@ -423,55 +412,50 @@ __global__ void local_share_bwd_data_device_template_f32( for (int i = 0; i < UnrollConfig::unroll_ci; ++i) { #pragma unroll for (int j = 0; j < UnrollConfig::unroll_n; ++j) { - if (check_bounds && - (t_batch + j * ThreadConfig::nr_thread_x >= param.n || - t_ci + i * ThreadConfig::nr_thread_y >= param.ci)) { + if 
(check_bounds && (t_batch + j * ThreadConfig::nr_thread_x >= param.n || + t_ci + i * ThreadConfig::nr_thread_y >= param.ci)) { } else { - g_ptr_grad[j * ThreadConfig::nr_thread_x + - i * ThreadConfig::nr_thread_y * ci_stride] = - r_grad[i][j]; + g_ptr_grad + [j * ThreadConfig::nr_thread_x + + i * ThreadConfig::nr_thread_y * ci_stride] = r_grad[i][j]; } } } } void (*get_kern(const Param& param, LaunchConfig& launch_config))( - const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int) { - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int); + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param, int, int, int, int) { + void (*kern)( + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param, int, int, int, int); kern = nullptr; -#define CHK3(n_, ci_, co_, tx_, ty_) \ - if (param.n >= n_) { \ - if (param.ci >= ci_) { \ - if (param.co % co_ == 0) { \ - static constexpr int unroll_ci = (ci_ + ty_ - 1) / ty_; \ - static constexpr int unroll_co = co_; \ - static constexpr int unroll_n = (n_ + tx_ - 1) / tx_; \ - static constexpr int thread_x = tx_; \ - static constexpr int thread_y = ty_; \ - typedef UnrollConfig \ - UnrollConfig; \ - typedef ThreadConfig ThreadConfig; \ - typedef DiffTileCount \ - DiffTileCount; \ - typedef FilterTileCount \ - FilterTileCount; \ - kern = local_share_bwd_data_device_template_f32< \ - true, UnrollConfig, ThreadConfig>; \ - launch_config.nr_threads_x = thread_x; \ - launch_config.nr_threads_y = thread_y; \ - launch_config.nr_threads_z = 1; \ - launch_config.nr_blocks_x = param.hi * param.wi; \ - launch_config.nr_blocks_y = \ - DIVUP(param.n, DiffTileCount::tile_batch); \ - launch_config.nr_blocks_z = \ - DIVUP(param.ci, FilterTileCount::tile_ci); \ - launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - (DiffTileCount::smem_tot + FilterTileCount::smem_tot); \ - } \ - } \ +#define CHK3(n_, ci_, co_, tx_, ty_) \ + if (param.n >= n_) { \ + if (param.ci >= ci_) { \ + if (param.co % co_ == 0) { \ + static constexpr int unroll_ci = (ci_ + ty_ - 1) / ty_; \ + static constexpr int unroll_co = co_; \ + static constexpr int unroll_n = (n_ + tx_ - 1) / tx_; \ + static constexpr int thread_x = tx_; \ + static constexpr int thread_y = ty_; \ + typedef UnrollConfig UnrollConfig; \ + typedef ThreadConfig ThreadConfig; \ + typedef DiffTileCount DiffTileCount; \ + typedef FilterTileCount FilterTileCount; \ + kern = local_share_bwd_data_device_template_f32< \ + true, UnrollConfig, ThreadConfig>; \ + launch_config.nr_threads_x = thread_x; \ + launch_config.nr_threads_y = thread_y; \ + launch_config.nr_threads_z = 1; \ + launch_config.nr_blocks_x = param.hi * param.wi; \ + launch_config.nr_blocks_y = DIVUP(param.n, DiffTileCount::tile_batch); \ + launch_config.nr_blocks_z = DIVUP(param.ci, FilterTileCount::tile_ci); \ + launch_config.smem_size_in_bytes = \ + sizeof(float) * \ + (DiffTileCount::smem_tot + FilterTileCount::smem_tot); \ + } \ + } \ } #define CHK2(n_, ci_) \ CHK3(n_, ci_, 4, 8, 16) \ @@ -491,43 +475,41 @@ void (*get_kern(const Param& param, LaunchConfig& launch_config))( #undef CHK2 #undef CHK2_ #undef CHK3 -#define CHK3(n_, ci_, co_, tx_, ty_) \ - if (param.n % n_ == 0) { \ - if (param.ci % ci_ == 0) { \ - if (param.co % co_ == 0) { \ - static constexpr int unroll_ci = (ci_) / (ty_); \ - static constexpr int unroll_co = co_; \ - static constexpr int unroll_n = (n_) / (tx_); \ - static constexpr int 
thread_x = tx_; \ - static constexpr int thread_y = ty_; \ - typedef UnrollConfig \ - UnrollConfig; \ - typedef ThreadConfig ThreadConfig; \ - typedef DiffTileCount \ - DiffTileCount; \ - typedef FilterTileCount \ - FilterTileCount; \ - kern = local_share_bwd_data_device_template_f32< \ - false, UnrollConfig, ThreadConfig>; \ - launch_config.nr_threads_x = thread_x; \ - launch_config.nr_threads_y = thread_y; \ - launch_config.nr_threads_z = 1; \ - launch_config.nr_blocks_x = param.hi * param.wi; \ - launch_config.nr_blocks_y = \ - DIVUP(param.n, DiffTileCount::tile_batch); \ - launch_config.nr_blocks_z = \ - DIVUP(param.ci, FilterTileCount::tile_ci); \ - launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - (DiffTileCount::smem_tot + FilterTileCount::smem_tot); \ - } \ - } \ +#define CHK3(n_, ci_, co_, tx_, ty_) \ + if (param.n % n_ == 0) { \ + if (param.ci % ci_ == 0) { \ + if (param.co % co_ == 0) { \ + static constexpr int unroll_ci = (ci_) / (ty_); \ + static constexpr int unroll_co = co_; \ + static constexpr int unroll_n = (n_) / (tx_); \ + static constexpr int thread_x = tx_; \ + static constexpr int thread_y = ty_; \ + typedef UnrollConfig UnrollConfig; \ + typedef ThreadConfig ThreadConfig; \ + typedef DiffTileCount DiffTileCount; \ + typedef FilterTileCount FilterTileCount; \ + kern = local_share_bwd_data_device_template_f32< \ + false, UnrollConfig, ThreadConfig>; \ + launch_config.nr_threads_x = thread_x; \ + launch_config.nr_threads_y = thread_y; \ + launch_config.nr_threads_z = 1; \ + launch_config.nr_blocks_x = param.hi * param.wi; \ + launch_config.nr_blocks_y = DIVUP(param.n, DiffTileCount::tile_batch); \ + launch_config.nr_blocks_z = DIVUP(param.ci, FilterTileCount::tile_ci); \ + launch_config.smem_size_in_bytes = \ + sizeof(float) * \ + (DiffTileCount::smem_tot + FilterTileCount::smem_tot); \ + } \ + } \ } -#define CHK2(n_, ci_) CHK3(n_, ci_, 4, 8, 8) CHK3(n_, ci_, 8, 8, 8) CHK3(n_, ci_, 16, 8, 8) +#define CHK2(n_, ci_) \ + CHK3(n_, ci_, 4, 8, 8) CHK3(n_, ci_, 8, 8, 8) CHK3(n_, ci_, 16, 8, 8) #define CHK(n_) \ CHK2(n_, 8) \ CHK2(n_, 16) \ - CHK2(n_, 32) CHK2(n_, 64) CHK3(n_, 128, 4, 8, 16) CHK3(n_, 128, 8, 8, 16) CHK3(n_, 128, 16, 8, 16) + CHK2(n_, 32) \ + CHK2(n_, 64) \ + CHK3(n_, 128, 4, 8, 16) CHK3(n_, 128, 8, 8, 16) CHK3(n_, 128, 16, 8, 16) CHK(8); CHK(16); CHK(32); @@ -535,19 +517,19 @@ void (*get_kern(const Param& param, LaunchConfig& launch_config))( #undef CHK #undef CHK2 #undef CHK3 - megdnn_assert(kern != nullptr, - "no usable kernel implementation for local share " - "backward data (batch,co,ci)=(%d,%d,%d)", - param.n, param.co, param.ci); + megdnn_assert( + kern != nullptr, + "no usable kernel implementation for local share " + "backward data (batch,co,ci)=(%d,%d,%d)", + param.n, param.co, param.ci); return kern; } } // namespace void megdnn::cuda::local_share_bwd_data::_do_local_share_bwd_data_implicit_gemm( - const float* d_filter, const float* d_diff, float* d_grad, - float* workspace, int fh, int fw, int sh, int sw, const Param& param, - cublasHandle_t cublas_handle, cudaStream_t stream, float* one, - float* zero) { + const float* d_filter, const float* d_diff, float* d_grad, float* workspace, + int fh, int fw, int sh, int sw, const Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero) { int ho = param.grp_ho * param.sgh, wo = param.grp_wo * param.sgw; size_t nr_grad_total = param.n * param.ci * param.hi * param.wi; float* ws_grad = workspace; @@ -558,14 +540,15 @@ void 
megdnn::cuda::local_share_bwd_data::_do_local_share_bwd_data_implicit_gemm( int lda, ldb; lda = ldb = param.co * ho * wo; int ldc = param.n; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, d_diff, lda, zero, d_diff, ldb, ws_diff, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, d_diff, lda, zero, + d_diff, ldb, ws_diff, ldc)); } { - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int); + void (*kern)( + const float* __restrict__, const float* __restrict__, + float* __restrict__, Param, int, int, int, int); LaunchConfig launch_config; kern = get_kern(param, launch_config); @@ -591,9 +574,9 @@ void megdnn::cuda::local_share_bwd_data::_do_local_share_bwd_data_implicit_gemm( int lda, ldb; lda = ldb = param.n; int ldc = param.ci * param.hi * param.wi; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, ws_grad, lda, zero, ws_grad, ldb, d_grad, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, ws_grad, lda, zero, + ws_grad, ldb, d_grad, ldc)); } } diff --git a/dnn/src/cuda/local_share/backward_filter/algo.cpp b/dnn/src/cuda/local_share/backward_filter/algo.cpp index 0e7a2a46..22237024 100644 --- a/dnn/src/cuda/local_share/backward_filter/algo.cpp +++ b/dnn/src/cuda/local_share/backward_filter/algo.cpp @@ -33,27 +33,24 @@ LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( const TensorLayout& diff, const TensorLayout& grad) : opr{o}, src_layout{src}, diff_layout{diff}, grad_layout{grad} {} -LocalShareBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs(LocalShareBackwardFilterImpl* opr, - _megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) +LocalShareBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( + LocalShareBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) : SizeArgs(opr, src.layout, diff.layout, grad.layout), src_tensor{&src}, diff_tensor{&diff}, grad_tensor{&grad}, workspace{workspace} {} -std::string LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::to_string() - const { +std::string LocalShareBackwardFilterImpl::AlgoBase::SizeArgs::to_string() const { auto&& param = opr->param(); MEGDNN_MARK_USED_VAR(param); return ssprintf( "src=%s, diff=%s, grad=%s, " "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s->%s", src_layout.to_string().c_str(), diff_layout.to_string().c_str(), - grad_layout.to_string().c_str(), param.pad_h, param.pad_w, - param.stride_h, param.stride_w, param.dilate_h, param.dilate_w, + grad_layout.to_string().c_str(), param.pad_h, param.pad_w, param.stride_h, + param.stride_w, param.dilate_h, param.dilate_w, static_cast(param.mode), src_layout.dtype.name(), diff_layout.dtype.name(), grad_layout.dtype.name()); } diff --git a/dnn/src/cuda/local_share/backward_filter/algo.h b/dnn/src/cuda/local_share/backward_filter/algo.h index 3aa463ca..24deb869 100644 --- a/dnn/src/cuda/local_share/backward_filter/algo.h +++ b/dnn/src/cuda/local_share/backward_filter/algo.h @@ -41,16 +41,18 @@ public: TensorLayout src_layout, diff_layout, grad_layout; std::string to_string() const; - SizeArgs(LocalShareBackwardFilterImpl* opr, const TensorLayout& src, - const TensorLayout& diff, const TensorLayout& grad); + SizeArgs( + LocalShareBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad); }; struct 
ExecArgs : public SizeArgs { const TensorND *src_tensor, *diff_tensor, *grad_tensor; Workspace workspace; - ExecArgs(LocalShareBackwardFilterImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace); + ExecArgs( + LocalShareBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -65,16 +67,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "local share conv fwd algo %s: required workspace %zu " - "bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "local share conv fwd algo %s: required workspace %zu " + "bytes, got %zu", + name(), req, workspace.size); return *this; } }; @@ -85,9 +86,7 @@ public: size_t get_workspace_in_bytes(const SizeArgs& args) const override; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) @@ -97,13 +96,10 @@ class LocalShareBackwardFilterImpl::AlgoBatchedMatMul final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) diff --git a/dnn/src/cuda/local_share/backward_filter/batched_matmul.cpp b/dnn/src/cuda/local_share/backward_filter/batched_matmul.cpp index 2c27f1af..5ad6153a 100644 --- a/dnn/src/cuda/local_share/backward_filter/batched_matmul.cpp +++ b/dnn/src/cuda/local_share/backward_filter/batched_matmul.cpp @@ -34,20 +34,20 @@ bool LocalShareBackwardFilterImpl::AlgoBatchedMatMul::is_available( auto src_dtype = args.src_layout.dtype, diff_dtype = args.diff_layout.dtype, grad_dtype = args.grad_layout.dtype; // only support float32 - available &= (src_dtype == diff_dtype && src_dtype == grad_dtype && - src_dtype == dtype::Float32()); + available &= + (src_dtype == diff_dtype && src_dtype == grad_dtype && + src_dtype == dtype::Float32()); // do not support dilate conv size_t dh = param.dilate_h, dw = param.dilate_w; available &= (dh == 1 && dw == 1); return available; } -WorkspaceBundle -LocalShareBackwardFilterImpl::AlgoBatchedMatMul::get_workspace_bundle( +WorkspaceBundle 
LocalShareBackwardFilterImpl::AlgoBatchedMatMul::get_workspace_bundle( dt_byte* raw_ptr, const SizeArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.grad_layout, - args.diff_layout, param); + unpack_local_share_params( + args.src_layout, args.grad_layout, args.diff_layout, param); using Param = LocalShare::Param; using Sparse = Param::Sparse; size_t groups = 1; @@ -55,8 +55,7 @@ LocalShareBackwardFilterImpl::AlgoBatchedMatMul::get_workspace_bundle( groups = args.grad_layout.shape[0]; } size_t icpg = ci / groups, ocpg = co / groups; - size_t ws_im2col = - n * ci * ho * wo * fh * fw * args.src_layout.dtype.size(); + size_t ws_im2col = n * ci * ho * wo * fh * fw * args.src_layout.dtype.size(); size_t ws_pretranspose = n * co * ho * wo * args.diff_layout.dtype.size(); auto&& matmul_opr = args.opr->handle()->create_operator(); matmul_opr->param().transposeA = true; @@ -64,10 +63,9 @@ LocalShareBackwardFilterImpl::AlgoBatchedMatMul::get_workspace_bundle( TensorLayout A{ {groups * sgh * sgw, ho / sgh * wo / sgw * n, icpg * fh * fw}, dtype::Float32()}; - TensorLayout B{{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, - dtype::Float32()}; - TensorLayout C{{groups * sgh * sgw, icpg * fh * fw, ocpg}, - dtype::Float32()}; + TensorLayout B{ + {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; + TensorLayout C{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; size_t ws_matmul = matmul_opr->get_workspace_in_bytes(A, B, C); WorkspaceBundle ws{raw_ptr, {ws_im2col, ws_pretranspose, ws_matmul}}; return ws; @@ -78,11 +76,10 @@ size_t LocalShareBackwardFilterImpl::AlgoBatchedMatMul::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec( - const ExecArgs& args) const { +void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.grad_layout, - args.diff_layout, param); + unpack_local_share_params( + args.src_layout, args.grad_layout, args.diff_layout, param); using Param = LocalShare::Param; using Sparse = Param::Sparse; size_t groups = 1; @@ -91,10 +88,10 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec( } size_t icpg = ci / groups, ocpg = co / groups; local_share::Param kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ph = ph, - kern_param.pw = pw, kern_param.grp_ho = ho / sgh, - kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, kern_param.sgw = sgw; + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ph = ph, kern_param.pw = pw, + kern_param.grp_ho = ho / sgh, kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, + kern_param.sgw = sgw; auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); auto ws_im2col = ws.get(0); @@ -107,8 +104,8 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec( kern_param, stream); { - TensorLayout B1{{groups, sgh, sgw, ocpg, n, ho / sgh, wo / sgw}, - dtype::Float32()}; + TensorLayout B1{ + {groups, sgh, sgw, ocpg, n, ho / sgh, wo / sgw}, dtype::Float32()}; B1.stride[0] = wo * ho * ocpg; B1.stride[1] = wo * ho / sgh; B1.stride[2] = wo / sgw; @@ -117,8 +114,8 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec( B1.stride[5] = wo; B1.stride[6] = 1; TensorND ts_B1{args.diff_tensor->raw_ptr, B1}; - TensorLayout 
B2{{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, - dtype::Float32()}; + TensorLayout B2{ + {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; B2.init_contiguous_stride(); TensorND ts_B2{ws_pretranspose, B2}; auto&& relayout_opr = args.opr->handle()->create_operator(); @@ -131,10 +128,9 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec( TensorLayout A{ {groups * sgh * sgw, ho / sgh * wo / sgw * n, icpg * fh * fw}, dtype::Float32()}; - TensorLayout B{{groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, - dtype::Float32()}; - TensorLayout C{{groups * sgh * sgw, icpg * fh * fw, ocpg}, - dtype::Float32()}; + TensorLayout B{ + {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()}; + TensorLayout C{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; TensorND ts_A{ws_im2col, A}; TensorND ts_B{ws_pretranspose, B}; TensorND ts_C{args.grad_tensor->raw_ptr, C}; diff --git a/dnn/src/cuda/local_share/backward_filter/implicit_gemm.cpp b/dnn/src/cuda/local_share/backward_filter/implicit_gemm.cpp index e5c7ad65..262bfb87 100644 --- a/dnn/src/cuda/local_share/backward_filter/implicit_gemm.cpp +++ b/dnn/src/cuda/local_share/backward_filter/implicit_gemm.cpp @@ -35,8 +35,8 @@ bool LocalShareBackwardFilterImpl::AlgoImplicitGemm::is_available( available &= (sparse == Sparse::DENSE); // mode must be cross correlation available &= (mode == Mode::CROSS_CORRELATION); - unpack_local_share_params(args.src_layout, args.grad_layout, - args.diff_layout, param); + unpack_local_share_params( + args.src_layout, args.grad_layout, args.diff_layout, param); available &= (ho % sgh == 0 && wo % sgw == 0); // not support dilated convolution available &= (dh == 1 && dw == 1); @@ -44,8 +44,9 @@ bool LocalShareBackwardFilterImpl::AlgoImplicitGemm::is_available( auto src_dtype = args.src_layout.dtype, diff_dtype = args.diff_layout.dtype, grad_dtype = args.grad_layout.dtype; // only support float32 - available &= (src_dtype == diff_dtype && src_dtype == grad_dtype && - src_dtype == dtype::Float32()); + available &= + (src_dtype == diff_dtype && src_dtype == grad_dtype && + src_dtype == dtype::Float32()); // only support sm_60 or later available &= is_compute_capability_required(6, 0); @@ -55,23 +56,22 @@ bool LocalShareBackwardFilterImpl::AlgoImplicitGemm::is_available( size_t LocalShareBackwardFilterImpl::AlgoImplicitGemm::get_workspace_in_bytes( const SizeArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.grad_layout, - args.diff_layout, param); + unpack_local_share_params( + args.src_layout, args.grad_layout, args.diff_layout, param); size_t ws_size_src = n * ci * hi * wi * args.grad_layout.dtype.size(); size_t ws_size_diff = n * co * ho * wo * args.diff_layout.dtype.size(); return ws_size_src + ws_size_diff; } -void LocalShareBackwardFilterImpl::AlgoImplicitGemm::exec( - const ExecArgs& args) const { +void LocalShareBackwardFilterImpl::AlgoImplicitGemm::exec(const ExecArgs& args) const { local_share::Param kern_param; auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.grad_layout, - args.diff_layout, param); - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ph = ph, - kern_param.pw = pw, kern_param.grp_ho = ho / sgh, - kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, kern_param.sgw = sgw; + unpack_local_share_params( + args.src_layout, args.grad_layout, args.diff_layout, param); + kern_param.n = n, kern_param.co = 
co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ph = ph, kern_param.pw = pw, + kern_param.grp_ho = ho / sgh, kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, + kern_param.sgw = sgw; auto&& handle = concrete_handle(args.opr->handle()); auto&& cublas_hdl = cublas_handle(args.opr->handle()); auto&& stream = cuda_stream(args.opr->handle()); @@ -80,8 +80,7 @@ void LocalShareBackwardFilterImpl::AlgoImplicitGemm::exec( auto zero = handle->zero_device(); local_share_bwd_filter::_do_local_share_bwd_filter_implicit_gemm( - args.src_tensor->ptr(), - args.diff_tensor->ptr(), + args.src_tensor->ptr(), args.diff_tensor->ptr(), args.grad_tensor->ptr(), reinterpret_cast(args.workspace.raw_ptr), fh, fw, sh, sw, kern_param, cublas_hdl, stream, one, zero); diff --git a/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter.cuh b/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter.cuh index 5671cdc0..3f2ef4b1 100644 --- a/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter.cuh +++ b/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter.cuh @@ -15,10 +15,9 @@ namespace cuda { namespace local_share_bwd_filter { void _do_local_share_bwd_filter_implicit_gemm( - const float* d_src, const float* d_diff, float* d_grad, - float* workspace, int fh, int fw, int sh, int sw, - const local_share::Param& param, cublasHandle_t cublas_handle, - cudaStream_t stream, float* one, float* zero); + const float* d_src, const float* d_diff, float* d_grad, float* workspace, + int fh, int fw, int sh, int sw, const local_share::Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero); } // namespace local_share_bwd_filter } // namespace cuda diff --git a/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter_f32_implicit_gemm.cu b/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter_f32_implicit_gemm.cu index a5ae40a2..36bfab57 100644 --- a/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter_f32_implicit_gemm.cu +++ b/dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter_f32_implicit_gemm.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter_f32_implicit_gemm.cu + * \file + * dnn/src/cuda/local_share/backward_filter/local_share_bwd_filter_f32_implicit_gemm.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -32,8 +33,7 @@ struct ThreadConfig { template struct DiffTileCount { static int const tile_batch = UnrollConfig::unroll_n; - static int const tile_co = - UnrollConfig::unroll_co * ThreadConfig::nr_thread_x; + static int const tile_co = UnrollConfig::unroll_co * ThreadConfig::nr_thread_x; static int const load_x = tile_batch > 32 ? 32 : tile_batch; static int const load_y = ThreadConfig::nr_threads / load_x; @@ -52,8 +52,7 @@ struct DiffTileCount { template struct DataTileCount { static int const tile_batch = UnrollConfig::unroll_n; - static int const tile_ci = - ThreadConfig::nr_thread_y * UnrollConfig::unroll_ci; + static int const tile_ci = ThreadConfig::nr_thread_y * UnrollConfig::unroll_ci; static int const load_x = tile_batch > 32 ? 
32 : tile_batch; static int const load_y = ThreadConfig::nr_threads / load_x; @@ -156,21 +155,16 @@ struct Global2ShareMemVisitor { return &smem[y * TileCount::smem_stride + x]; } - __device__ __forceinline__ void move_forward() { - g_ptr += TileCount::tile_batch; - } + __device__ __forceinline__ void move_forward() { g_ptr += TileCount::tile_batch; } }; template __device__ __forceinline__ void consume_block( - Global2ShareMemVisitor>& + Global2ShareMemVisitor>& src_gl2sh_visitor, - Global2ShareMemVisitor>& + Global2ShareMemVisitor>& diff_gl2sh_visitor, - float r_src[UnrollConfig::unroll_ci], - float r_diff[UnrollConfig::unroll_co], + float r_src[UnrollConfig::unroll_ci], float r_diff[UnrollConfig::unroll_co], float r_grad[UnrollConfig::unroll_ci][UnrollConfig::unroll_co]) { const int tidx = threadIdx.x; const int tidy = threadIdx.y; @@ -237,11 +231,10 @@ __global__ void local_share_bwd_filter_device_template_f32( const float* __restrict__ g_ptr_diff = diff + b_co * ho * wo * param.n; float* __restrict__ g_ptr_grad = grad + - sp_grp_idx * filter_sizes * param.co * - param.ci // spatial group stride - + t_ci * filter_sizes * param.co // input channel stride - + kern_spatial * param.co // kernel spatial stride - + t_co; // output channel stride + sp_grp_idx * filter_sizes * param.co * param.ci // spatial group stride + + t_ci * filter_sizes * param.co // input channel stride + + kern_spatial * param.co // kernel spatial stride + + t_co; // output channel stride Global2ShareMemVisitor src_gl2sh_visitor{ sh_src, param.hi * param.wi * param.n, param.ci - b_ci}; @@ -265,8 +258,7 @@ __global__ void local_share_bwd_filter_device_template_f32( int sp_grp_w_start = sgw_idx * param.grp_wo; int sp_grp_w_end = sgw_idx * param.grp_wo + param.grp_wo - 1; int height_start = (param.ph - kh + sh - 1) / sh; - height_start = - sp_grp_h_start >= height_start ? sp_grp_h_start : height_start; + height_start = sp_grp_h_start >= height_start ? sp_grp_h_start : height_start; int width_start = (param.pw - kw + sw - 1) / sw; width_start = sp_grp_w_start >= width_start ? sp_grp_w_start : width_start; int height_end = (param.hi - 1 + param.ph - kh) / sh; @@ -274,15 +266,12 @@ __global__ void local_share_bwd_filter_device_template_f32( int width_end = (param.wi - 1 + param.pw - kw) / sw; width_end = sp_grp_w_end <= width_end ? 
sp_grp_w_end : width_end; - const int b_blks = - (param.n + UnrollConfig::unroll_n - 1) / UnrollConfig::unroll_n; + const int b_blks = (param.n + UnrollConfig::unroll_n - 1) / UnrollConfig::unroll_n; int ih_idx = height_start * sh - param.ph + kh; int iw_idx = width_start * sw - param.pw + kw; - src_gl2sh_visitor.g_ptr = - g_ptr_src + (ih_idx * param.wi + iw_idx) * param.n; - diff_gl2sh_visitor.g_ptr = - g_ptr_diff + (height_start * wo + width_start) * param.n; + src_gl2sh_visitor.g_ptr = g_ptr_src + (ih_idx * param.wi + iw_idx) * param.n; + diff_gl2sh_visitor.g_ptr = g_ptr_diff + (height_start * wo + width_start) * param.n; if (height_start <= height_end && width_start <= width_end) { src_gl2sh_visitor.first_copy(); @@ -303,8 +292,7 @@ __global__ void local_share_bwd_filter_device_template_f32( int iw_idx = w_next * sw - param.pw + kw; src_gl2sh_visitor.g_ptr = - g_ptr_src + - (ih_idx * param.wi + iw_idx) * param.n; + g_ptr_src + (ih_idx * param.wi + iw_idx) * param.n; diff_gl2sh_visitor.g_ptr = g_ptr_diff + (h_next * wo + w_next) * param.n; src_gl2sh_visitor.copy(); @@ -318,12 +306,10 @@ __global__ void local_share_bwd_filter_device_template_f32( } consume_block( - src_gl2sh_visitor, diff_gl2sh_visitor, r_src, r_diff, - r_grad); + src_gl2sh_visitor, diff_gl2sh_visitor, r_src, r_diff, r_grad); // last tile - if (!(h == height_end && w == width_end && - b_outer == b_blks - 1)) { + if (!(h == height_end && w == width_end && b_outer == b_blks - 1)) { __syncthreads(); src_gl2sh_visitor.commit(); diff_gl2sh_visitor.commit(); @@ -339,58 +325,51 @@ __global__ void local_share_bwd_filter_device_template_f32( for (int i = 0; i < UnrollConfig::unroll_ci; ++i) { #pragma unroll for (int j = 0; j < UnrollConfig::unroll_co; ++j) { - if (check_bounds && - (t_co + j * ThreadConfig::nr_thread_x >= param.co || - t_ci + i * ThreadConfig::nr_thread_y >= param.ci)) { + if (check_bounds && (t_co + j * ThreadConfig::nr_thread_x >= param.co || + t_ci + i * ThreadConfig::nr_thread_y >= param.ci)) { } else { - g_ptr_grad[j * ThreadConfig::nr_thread_x + - i * ThreadConfig::nr_thread_y * ci_stride] = - r_grad[i][j]; + g_ptr_grad + [j * ThreadConfig::nr_thread_x + + i * ThreadConfig::nr_thread_y * ci_stride] = r_grad[i][j]; } } } } -void (*get_kern(const Param& param, const int filter_sizes, - LaunchConfig& launch_config))(const float* __restrict__, - const float* __restrict__, - float* __restrict__, Param, int, - int, int, int) { - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int); +void (*get_kern( + const Param& param, const int filter_sizes, LaunchConfig& launch_config))( + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param, int, int, int, int) { + void (*kern)( + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param, int, int, int, int); kern = nullptr; -#define CHK3(ci_, co_, n_, tx_, ty_) \ - if (param.ci >= ci_) { \ - if (param.co >= co_) { \ - if (param.n % n_ == 0) { \ - static constexpr int unroll_ci = (ci_ + ty_ - 1) / ty_; \ - static constexpr int unroll_co = (co_ + tx_ - 1) / tx_; \ - static constexpr int unroll_n = n_; \ - static constexpr int thread_x = tx_; \ - static constexpr int thread_y = ty_; \ - typedef UnrollConfig \ - UnrollConfig; \ - typedef ThreadConfig ThreadConfig; \ - typedef DataTileCount \ - DataTileCount; \ - typedef DiffTileCount \ - DiffTileCount; \ - kern = local_share_bwd_filter_device_template_f32< \ - true, UnrollConfig, ThreadConfig>; \ - 
launch_config.nr_threads_x = thread_x; \ - launch_config.nr_threads_y = thread_y; \ - launch_config.nr_threads_z = 1; \ - launch_config.nr_blocks_x = \ - param.sgh * param.sgw * filter_sizes; \ - launch_config.nr_blocks_y = \ - DIVUP(param.co, DiffTileCount::tile_co); \ - launch_config.nr_blocks_z = \ - DIVUP(param.ci, DataTileCount::tile_ci); \ - launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - (DataTileCount::smem_tot + DiffTileCount::smem_tot); \ - } \ - } \ +#define CHK3(ci_, co_, n_, tx_, ty_) \ + if (param.ci >= ci_) { \ + if (param.co >= co_) { \ + if (param.n % n_ == 0) { \ + static constexpr int unroll_ci = (ci_ + ty_ - 1) / ty_; \ + static constexpr int unroll_co = (co_ + tx_ - 1) / tx_; \ + static constexpr int unroll_n = n_; \ + static constexpr int thread_x = tx_; \ + static constexpr int thread_y = ty_; \ + typedef UnrollConfig UnrollConfig; \ + typedef ThreadConfig ThreadConfig; \ + typedef DataTileCount DataTileCount; \ + typedef DiffTileCount DiffTileCount; \ + kern = local_share_bwd_filter_device_template_f32< \ + true, UnrollConfig, ThreadConfig>; \ + launch_config.nr_threads_x = thread_x; \ + launch_config.nr_threads_y = thread_y; \ + launch_config.nr_threads_z = 1; \ + launch_config.nr_blocks_x = param.sgh * param.sgw * filter_sizes; \ + launch_config.nr_blocks_y = DIVUP(param.co, DiffTileCount::tile_co); \ + launch_config.nr_blocks_z = DIVUP(param.ci, DataTileCount::tile_ci); \ + launch_config.smem_size_in_bytes = \ + sizeof(float) * \ + (DataTileCount::smem_tot + DiffTileCount::smem_tot); \ + } \ + } \ } #define CHK2(ci_, co_) \ CHK3(ci_, co_, 4, 16, 8) \ @@ -411,46 +390,39 @@ void (*get_kern(const Param& param, const int filter_sizes, #undef CHK2 #undef CHK2_ #undef CHK3 -#define CHK3(ci_, co_, n_, tx_, ty_) \ - if (param.ci % ci_ == 0) { \ - if (param.co % co_ == 0) { \ - if (param.n % n_ == 0) { \ - static constexpr int unroll_ci = (ci_) / (ty_); \ - static constexpr int unroll_co = (co_) / (tx_); \ - static constexpr int unroll_n = n_; \ - static constexpr int thread_x = tx_; \ - static constexpr int thread_y = ty_; \ - typedef UnrollConfig \ - UnrollConfig; \ - typedef ThreadConfig ThreadConfig; \ - typedef DataTileCount \ - DataTileCount; \ - typedef DiffTileCount \ - DiffTileCount; \ - kern = local_share_bwd_filter_device_template_f32< \ - false, UnrollConfig, ThreadConfig>; \ - launch_config.nr_threads_x = thread_x; \ - launch_config.nr_threads_y = thread_y; \ - launch_config.nr_threads_z = 1; \ - launch_config.nr_blocks_x = \ - param.sgh * param.sgw * filter_sizes; \ - launch_config.nr_blocks_y = \ - DIVUP(param.co, DiffTileCount::tile_co); \ - launch_config.nr_blocks_z = \ - DIVUP(param.ci, DataTileCount::tile_ci); \ - launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - (DataTileCount::smem_tot + DiffTileCount::smem_tot); \ - } \ - } \ +#define CHK3(ci_, co_, n_, tx_, ty_) \ + if (param.ci % ci_ == 0) { \ + if (param.co % co_ == 0) { \ + if (param.n % n_ == 0) { \ + static constexpr int unroll_ci = (ci_) / (ty_); \ + static constexpr int unroll_co = (co_) / (tx_); \ + static constexpr int unroll_n = n_; \ + static constexpr int thread_x = tx_; \ + static constexpr int thread_y = ty_; \ + typedef UnrollConfig UnrollConfig; \ + typedef ThreadConfig ThreadConfig; \ + typedef DataTileCount DataTileCount; \ + typedef DiffTileCount DiffTileCount; \ + kern = local_share_bwd_filter_device_template_f32< \ + false, UnrollConfig, ThreadConfig>; \ + launch_config.nr_threads_x = thread_x; \ + launch_config.nr_threads_y = thread_y; \ + 
launch_config.nr_threads_z = 1; \ + launch_config.nr_blocks_x = param.sgh * param.sgw * filter_sizes; \ + launch_config.nr_blocks_y = DIVUP(param.co, DiffTileCount::tile_co); \ + launch_config.nr_blocks_z = DIVUP(param.ci, DataTileCount::tile_ci); \ + launch_config.smem_size_in_bytes = \ + sizeof(float) * \ + (DataTileCount::smem_tot + DiffTileCount::smem_tot); \ + } \ + } \ } -#define CHK2(ci_, co_) \ - CHK3(ci_, co_, 4, 8, 8) CHK3(ci_, co_, 8, 8, 8) -#define CHK(ci_) \ - CHK2(ci_, 8) \ - CHK2(ci_, 16) \ - CHK2(ci_, 32) \ - CHK2(ci_, 64) \ +#define CHK2(ci_, co_) CHK3(ci_, co_, 4, 8, 8) CHK3(ci_, co_, 8, 8, 8) +#define CHK(ci_) \ + CHK2(ci_, 8) \ + CHK2(ci_, 16) \ + CHK2(ci_, 32) \ + CHK2(ci_, 64) \ CHK3(ci_, 128, 4, 16, 8) CHK3(ci_, 128, 8, 16, 8) CHK(8); CHK(16); @@ -460,20 +432,19 @@ void (*get_kern(const Param& param, const int filter_sizes, #undef CHK #undef CHK2 #undef CHK3 - megdnn_assert(kern != nullptr, - "no usable kernel implementation for local share " - "backward data (batch,co,ci)=(%d,%d,%d)", - param.n, param.co, param.ci); + megdnn_assert( + kern != nullptr, + "no usable kernel implementation for local share " + "backward data (batch,co,ci)=(%d,%d,%d)", + param.n, param.co, param.ci); return kern; } } // namespace -void megdnn::cuda::local_share_bwd_filter:: - _do_local_share_bwd_filter_implicit_gemm( - const float* d_src, const float* d_diff, float* d_grad, - float* workspace, int fh, int fw, int sh, int sw, - const Param& param, cublasHandle_t cublas_handle, - cudaStream_t stream, float* one, float* zero) { +void megdnn::cuda::local_share_bwd_filter::_do_local_share_bwd_filter_implicit_gemm( + const float* d_src, const float* d_diff, float* d_grad, float* workspace, + int fh, int fw, int sh, int sw, const Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero) { int ho = param.grp_ho * param.sgh, wo = param.grp_wo * param.sgw; size_t nr_src_total = param.n * param.ci * param.hi * param.wi; float* ws_src = workspace; @@ -484,9 +455,9 @@ void megdnn::cuda::local_share_bwd_filter:: int lda, ldb; lda = ldb = param.ci * param.hi * param.wi; int ldc = param.n; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, d_src, lda, zero, d_src, ldb, ws_src, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, d_src, lda, zero, + d_src, ldb, ws_src, ldc)); } { @@ -494,15 +465,16 @@ void megdnn::cuda::local_share_bwd_filter:: int lda, ldb; lda = ldb = param.co * ho * wo; int ldc = param.n; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, d_diff, lda, zero, d_diff, ldb, ws_diff, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, d_diff, lda, zero, + d_diff, ldb, ws_diff, ldc)); } { int filter_sizes = fh * fw; - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int); + void (*kern)( + const float* __restrict__, const float* __restrict__, + float* __restrict__, Param, int, int, int, int); LaunchConfig launch_config; kern = get_kern(param, filter_sizes, launch_config); diff --git a/dnn/src/cuda/local_share/forward/algo.cpp b/dnn/src/cuda/local_share/forward/algo.cpp index f915d70d..db33d98c 100644 --- a/dnn/src/cuda/local_share/forward/algo.cpp +++ b/dnn/src/cuda/local_share/forward/algo.cpp @@ -29,17 +29,14 @@ MEGDNN_DEF_GET_ALGO_FROM_DESC(LocalShareForwardImpl) LocalShareForwardImpl::AlgoPack LocalShareForwardImpl::sm_algo_pack; 
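// ----------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of this patch): the
// LOCAL_SHARE_BATCHED_MATMUL algorithms reformatted in this diff lower a
// local-share convolution to an im2col buffer plus one batched GEMM, with one
// matrix per (channel group, spatial group) pair; the operand shapes are the
// TensorLayout A/B/C constructions visible in the hunks above. The standalone
// program below merely recomputes those forward-path shapes so the batching
// scheme is easier to follow. `LocalShareShape` and
// `print_batched_matmul_shapes` are hypothetical helper names, not MegDNN
// APIs, and the output-size formula assumes plain cross-correlation with
// dilation 1 (the only case these algorithms accept).
#include <cstddef>
#include <cstdio>

namespace {

struct LocalShareShape {
    std::size_t n, ci, co, hi, wi;       // batch, channels, input spatial size
    std::size_t fh, fw, sh, sw, ph, pw;  // kernel, stride, padding
    std::size_t sgh, sgw, groups;        // spatial groups, channel groups
};

void print_batched_matmul_shapes(const LocalShareShape& p) {
    // Dense convolution output size with dilation 1, matching the layouts
    // built in AlgoBatchedMatMul::get_workspace_bundle above.
    std::size_t ho = (p.hi + 2 * p.ph - p.fh) / p.sh + 1;
    std::size_t wo = (p.wi + 2 * p.pw - p.fw) / p.sw + 1;
    std::size_t icpg = p.ci / p.groups, ocpg = p.co / p.groups;
    std::size_t batch = p.groups * p.sgh * p.sgw;           // one GEMM per spatial group
    std::size_t rows = (ho / p.sgh) * (wo / p.sgw) * p.n;   // output pixels owned by one group
    std::size_t k = icpg * p.fh * p.fw;                     // im2col patch length
    std::printf("A (im2col src): {%zu, %zu, %zu}\n", batch, rows, k);
    std::printf("B (filter)    : {%zu, %zu, %zu}\n", batch, k, ocpg);
    std::printf("C (conv dst)  : {%zu, %zu, %zu}\n", batch, rows, ocpg);
}

}  // namespace

int main() {
    // Example: N=8, 16->32 channels, 32x32 input, 3x3 kernel, stride 1,
    // pad 1, 2x2 spatial groups, dense sparsity (groups == 1).
    LocalShareShape p{8, 16, 32, 32, 32, 3, 3, 1, 1, 1, 1, 2, 2, 1};
    print_batched_matmul_shapes(p);
    return 0;
}
// The backward-filter variant in this diff reuses the same decomposition with
// the roles swapped: A comes from im2col of src, B is the pre-transposed diff
// laid out as {batch, ocpg, rows}, and C is the filter gradient
// {batch, k, ocpg}, as in the backward_filter/batched_matmul.cpp hunks above.
// ----------------------------------------------------------------------------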
-LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs(LocalShareForwardImpl* o, - const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) +LocalShareForwardImpl::AlgoBase::SizeArgs::SizeArgs( + LocalShareForwardImpl* o, const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) : opr{o}, src_layout{src}, filter_layout{filter}, dst_layout{dst} {} -LocalShareForwardImpl::AlgoBase::ExecArgs::ExecArgs(LocalShareForwardImpl* opr, - _megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) +LocalShareForwardImpl::AlgoBase::ExecArgs::ExecArgs( + LocalShareForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_out dst, _megdnn_workspace workspace) : SizeArgs(opr, src.layout, filter.layout, dst.layout), src_tensor{&src}, filter_tensor{&filter}, @@ -53,8 +50,8 @@ std::string LocalShareForwardImpl::AlgoBase::SizeArgs::to_string() const { "src=%s, filter=%s, dst=%s, " "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s", src_layout.to_string().c_str(), filter_layout.to_string().c_str(), - dst_layout.to_string().c_str(), param.pad_h, param.pad_w, - param.stride_h, param.stride_w, param.dilate_h, param.dilate_w, + dst_layout.to_string().c_str(), param.pad_h, param.pad_w, param.stride_h, + param.stride_w, param.dilate_h, param.dilate_w, static_cast(param.mode), src_layout.dtype.name(), dst_layout.dtype.name()); } diff --git a/dnn/src/cuda/local_share/forward/algo.h b/dnn/src/cuda/local_share/forward/algo.h index 44498099..6939e534 100644 --- a/dnn/src/cuda/local_share/forward/algo.h +++ b/dnn/src/cuda/local_share/forward/algo.h @@ -13,9 +13,9 @@ #include "megdnn/oprs.h" -#include "src/common/utils.h" #include "src/common/algo_base.h" #include "src/common/metahelper.h" +#include "src/common/utils.h" #include "src/cuda/handle.h" #include "src/cuda/local_share/opr_impl.h" @@ -42,16 +42,18 @@ public: TensorLayout src_layout, filter_layout, dst_layout; std::string to_string() const; - SizeArgs(LocalShareForwardImpl* opr, const TensorLayout& src, - const TensorLayout& filter, const TensorLayout& dst); + SizeArgs( + LocalShareForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const TensorLayout& dst); }; struct ExecArgs : public SizeArgs { const TensorND *src_tensor, *filter_tensor, *dst_tensor; Workspace workspace; - ExecArgs(LocalShareForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in filter, _megdnn_tensor_out dst, - _megdnn_workspace workspace); + ExecArgs( + LocalShareForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -66,16 +68,15 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); - megdnn_assert(req <= workspace.size, - "local share conv fwd algo %s: required workspace %zu " - "bytes, got %zu", - name(), req, workspace.size); + megdnn_assert( + req <= workspace.size, + "local 
share conv fwd algo %s: required workspace %zu " + "bytes, got %zu", + name(), req, workspace.size); return *this; } }; @@ -84,33 +85,26 @@ class LocalShareForwardImpl::AlgoCHWNBatchSizeAware final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } - const char* name() const override { - return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; - } + const char* name() const override { return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHWN_BATCH_SIZE_AWARE) }; -class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final - : public AlgoBase { +class LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { @@ -123,13 +117,10 @@ class LocalShareForwardImpl::AlgoBatchedMatMul final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; - WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, - const SizeArgs& args) const; + WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const; void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) diff --git a/dnn/src/cuda/local_share/forward/batch_size_aware_chwn.cpp b/dnn/src/cuda/local_share/forward/batch_size_aware_chwn.cpp index 3db59138..c42cb877 100644 --- a/dnn/src/cuda/local_share/forward/batch_size_aware_chwn.cpp +++ b/dnn/src/cuda/local_share/forward/batch_size_aware_chwn.cpp @@ -35,36 +35,35 @@ bool LocalShareForwardImpl::AlgoCHWNBatchSizeAware::is_available( available &= (sparse == Sparse::DENSE); // mode must be cross correlation available &= (mode == Mode::CROSS_CORRELATION); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); available &= (ho % sgh == 0 && wo % sgw == 0); // not support dilated convolution available &= (dh == 1 && dw == 1); available &= (n % 32 == 0); // kernel size should be 3, 5, 7 - available &= (fh == 1 && fw == 1) || (fh == 3 && fw == 3) || - (fh == 5 && fw == 5) || (fh == 7 || fw == 7); + available &= (fh == 1 && fw == 1) || (fh == 3 && fw == 3) || (fh 
== 5 && fw == 5) || + (fh == 7 || fw == 7); // stride should be 1 or 2 available &= (sh == sw && (sh == 1 || sh == 2)); available &= (ci % 4 == 0) || (fh == 3 && ci % 2 == 0); - auto src_dtype = args.src_layout.dtype, - filter_dtype = args.filter_layout.dtype, + auto src_dtype = args.src_layout.dtype, filter_dtype = args.filter_layout.dtype, dst_dtype = args.dst_layout.dtype; // only support float32 - available &= (src_dtype == filter_dtype && src_dtype == dst_dtype && - src_dtype == dtype::Float32()); + available &= + (src_dtype == filter_dtype && src_dtype == dst_dtype && + src_dtype == dtype::Float32()); // only support sm_60 or later available &= is_compute_capability_required(6, 0); return available; } -WorkspaceBundle -LocalShareForwardImpl::AlgoCHWNBatchSizeAware::get_workspace_bundle( +WorkspaceBundle LocalShareForwardImpl::AlgoCHWNBatchSizeAware::get_workspace_bundle( dt_byte* raw_ptr, const SizeArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); size_t ws_size_src = n * ci * hi * wi * args.src_layout.dtype.size(); size_t ws_size_dst = n * co * ho * wo * args.dst_layout.dtype.size(); WorkspaceBundle ws{raw_ptr, {ws_size_src, ws_size_dst}}; @@ -76,16 +75,15 @@ size_t LocalShareForwardImpl::AlgoCHWNBatchSizeAware::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void LocalShareForwardImpl::AlgoCHWNBatchSizeAware::exec( - const ExecArgs& args) const { +void LocalShareForwardImpl::AlgoCHWNBatchSizeAware::exec(const ExecArgs& args) const { local_share::Param kern_param; auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ph = ph, - kern_param.pw = pw, kern_param.grp_ho = ho / sgh, - kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, kern_param.sgw = sgw; + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ph = ph, kern_param.pw = pw, + kern_param.grp_ho = ho / sgh, kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, + kern_param.sgw = sgw; auto&& handle = concrete_handle(args.opr->handle()); auto&& cublas_hdl = cublas_handle(args.opr->handle()); auto&& stream = cuda_stream(args.opr->handle()); @@ -94,8 +92,7 @@ void LocalShareForwardImpl::AlgoCHWNBatchSizeAware::exec( auto zero = handle->zero_device(); local_share::_do_local_share_convolution_large_batch_size( - args.src_tensor->ptr(), - args.filter_tensor->ptr(), + args.src_tensor->ptr(), args.filter_tensor->ptr(), args.dst_tensor->ptr(), reinterpret_cast(args.workspace.raw_ptr), fh, fw, sh, sw, kern_param, cublas_hdl, stream, one, zero); diff --git a/dnn/src/cuda/local_share/forward/batch_size_aware_chwn_small_image.cpp b/dnn/src/cuda/local_share/forward/batch_size_aware_chwn_small_image.cpp index 64961562..3bd90a25 100644 --- a/dnn/src/cuda/local_share/forward/batch_size_aware_chwn_small_image.cpp +++ b/dnn/src/cuda/local_share/forward/batch_size_aware_chwn_small_image.cpp @@ -35,38 +35,36 @@ bool LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage::is_available( available &= (sparse == Sparse::DENSE); // mode must be cross correlation available &= (mode == 
Mode::CROSS_CORRELATION); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); available &= (ho % sgh == 0 && wo % sgw == 0); // not support dilated convolution available &= (dh == 1 && dw == 1); available &= (ci % 4 == 0); - auto src_dtype = args.src_layout.dtype, - filter_dtype = args.filter_layout.dtype, + auto src_dtype = args.src_layout.dtype, filter_dtype = args.filter_layout.dtype, dst_dtype = args.dst_layout.dtype; // only support float32 - available &= (src_dtype == filter_dtype && src_dtype == dst_dtype && - src_dtype == dtype::Float32()); + available &= + (src_dtype == filter_dtype && src_dtype == dst_dtype && + src_dtype == dtype::Float32()); // only support sm_60 or later available &= is_compute_capability_required(6, 0); return available; } -WorkspaceBundle -LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage::get_workspace_bundle( - dt_byte* raw_ptr, const SizeArgs& args) const { +WorkspaceBundle LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage:: + get_workspace_bundle(dt_byte* raw_ptr, const SizeArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); size_t ws_size_src = n * ci * hi * wi * args.src_layout.dtype.size(); size_t ws_size_dst = n * co * ho * wo * args.dst_layout.dtype.size(); WorkspaceBundle ws{raw_ptr, {ws_size_src, ws_size_dst}}; return ws; } -size_t -LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage::get_workspace_in_bytes( +size_t LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage::get_workspace_in_bytes( const SizeArgs& args) const { return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } @@ -75,12 +73,12 @@ void LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage::exec( const ExecArgs& args) const { local_share::Param kern_param; auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ph = ph, - kern_param.pw = pw, kern_param.grp_ho = ho / sgh, - kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, kern_param.sgw = sgw; + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ph = ph, kern_param.pw = pw, + kern_param.grp_ho = ho / sgh, kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, + kern_param.sgw = sgw; auto&& handle = concrete_handle(args.opr->handle()); auto&& cublas_hdl = cublas_handle(args.opr->handle()); auto&& stream = cuda_stream(args.opr->handle()); @@ -89,8 +87,7 @@ void LocalShareForwardImpl::AlgoCHWNBatchSizeAwareSmallImage::exec( auto zero = handle->zero_device(); local_share::_do_local_share_convolution_large_batch_size_small_image( - args.src_tensor->ptr(), - args.filter_tensor->ptr(), + args.src_tensor->ptr(), args.filter_tensor->ptr(), args.dst_tensor->ptr(), reinterpret_cast(args.workspace.raw_ptr), fh, fw, sh, sw, kern_param, cublas_hdl, stream, one, zero); diff --git a/dnn/src/cuda/local_share/forward/batched_matmul.cpp b/dnn/src/cuda/local_share/forward/batched_matmul.cpp index 64264567..d0eaf061 100644 --- 
a/dnn/src/cuda/local_share/forward/batched_matmul.cpp +++ b/dnn/src/cuda/local_share/forward/batched_matmul.cpp @@ -24,8 +24,7 @@ bool LocalShareForwardImpl::AlgoBatchedMatMul::is_available( // NCHW format available &= param.format == Format::NCHW; // only support float - auto src_dtype = args.src_layout.dtype, - filter_dtype = args.filter_layout.dtype, + auto src_dtype = args.src_layout.dtype, filter_dtype = args.filter_layout.dtype, dst_dtype = args.dst_layout.dtype; available &= (src_dtype == filter_dtype) && (src_dtype == dst_dtype) && (src_dtype == dtype::Float32()); @@ -38,8 +37,8 @@ bool LocalShareForwardImpl::AlgoBatchedMatMul::is_available( WorkspaceBundle LocalShareForwardImpl::AlgoBatchedMatMul::get_workspace_bundle( dt_byte* raw_ptr, const SizeArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); using Param = LocalShare::Param; using Sparse = Param::Sparse; size_t groups = 1; @@ -47,17 +46,15 @@ WorkspaceBundle LocalShareForwardImpl::AlgoBatchedMatMul::get_workspace_bundle( groups = args.filter_layout.shape[0]; } size_t icpg = ci / groups, ocpg = co / groups; - size_t ws_im2col = - n * ci * ho * wo * fh * fw * args.src_layout.dtype.size(); + size_t ws_im2col = n * ci * ho * wo * fh * fw * args.src_layout.dtype.size(); size_t ws_posttranspose = n * co * ho * wo * args.dst_layout.dtype.size(); auto&& matmul_opr = args.opr->handle()->create_operator(); TensorLayout A{ {groups * sgh * sgw, ho / sgh * wo / sgw * n, icpg * fh * fw}, dtype::Float32()}; - TensorLayout B{{groups * sgh * sgw, icpg * fh * fw, ocpg}, - dtype::Float32()}; - TensorLayout C{{groups * sgh * sgw, ho / sgh * wo / sgw * n, ocpg}, - dtype::Float32()}; + TensorLayout B{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; + TensorLayout C{ + {groups * sgh * sgw, ho / sgh * wo / sgw * n, ocpg}, dtype::Float32()}; size_t ws_matmul = matmul_opr->get_workspace_in_bytes(A, B, C); WorkspaceBundle ws{raw_ptr, {ws_im2col, ws_matmul, ws_posttranspose}}; return ws; @@ -68,11 +65,10 @@ size_t LocalShareForwardImpl::AlgoBatchedMatMul::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void LocalShareForwardImpl::AlgoBatchedMatMul::exec( - const ExecArgs& args) const { +void LocalShareForwardImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) const { auto&& param = args.opr->param(); - unpack_local_share_params(args.src_layout, args.filter_layout, - args.dst_layout, param); + unpack_local_share_params( + args.src_layout, args.filter_layout, args.dst_layout, param); using Param = LocalShare::Param; using Sparse = Param::Sparse; size_t groups = 1; @@ -81,10 +77,10 @@ void LocalShareForwardImpl::AlgoBatchedMatMul::exec( } size_t icpg = ci / groups, ocpg = co / groups; local_share::Param kern_param; - kern_param.n = n, kern_param.co = co, kern_param.ci = ci, - kern_param.hi = hi, kern_param.wi = wi, kern_param.ph = ph, - kern_param.pw = pw, kern_param.grp_ho = ho / sgh, - kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, kern_param.sgw = sgw; + kern_param.n = n, kern_param.co = co, kern_param.ci = ci, kern_param.hi = hi, + kern_param.wi = wi, kern_param.ph = ph, kern_param.pw = pw, + kern_param.grp_ho = ho / sgh, kern_param.grp_wo = wo / sgw, kern_param.sgh = sgh, + kern_param.sgw = sgw; auto ws = get_workspace_bundle(args.workspace.raw_ptr, args); auto ws_im2col = ws.get(0); @@ -100,10 +96,9 @@ 
void LocalShareForwardImpl::AlgoBatchedMatMul::exec( TensorLayout A{ {groups * sgh * sgw, ho / sgh * wo / sgw * n, icpg * fh * fw}, dtype::Float32()}; - TensorLayout B{{groups * sgh * sgw, icpg * fh * fw, ocpg}, - dtype::Float32()}; - TensorLayout C{{groups * sgh * sgw, ho / sgh * wo / sgw * n, ocpg}, - dtype::Float32()}; + TensorLayout B{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()}; + TensorLayout C{ + {groups * sgh * sgw, ho / sgh * wo / sgw * n, ocpg}, dtype::Float32()}; TensorND ts_A{ws_im2col, A}; TensorND ts_B{args.filter_tensor->raw_ptr, B}; TensorND ts_C{ws_posttranspose, C}; @@ -113,8 +108,8 @@ void LocalShareForwardImpl::AlgoBatchedMatMul::exec( matmul_opr->exec(ts_A, ts_B, ts_C, ws_wrapper); { - TensorLayout C1{{n, groups, ocpg, sgh, ho / sgh, sgw, wo / sgw}, - dtype::Float32()}; + TensorLayout C1{ + {n, groups, ocpg, sgh, ho / sgh, sgw, wo / sgw}, dtype::Float32()}; C1.stride[0] = ho / sgh * wo / sgw * ocpg; C1.stride[1] = n * ho * wo * ocpg; C1.stride[2] = 1; diff --git a/dnn/src/cuda/local_share/forward/local_share_forward.cuh b/dnn/src/cuda/local_share/forward/local_share_forward.cuh index 867ec820..371354fc 100644 --- a/dnn/src/cuda/local_share/forward/local_share_forward.cuh +++ b/dnn/src/cuda/local_share/forward/local_share_forward.cuh @@ -15,16 +15,14 @@ namespace cuda { namespace local_share { void _do_local_share_convolution_large_batch_size( - const float* d_src, const float* d_filter, float* d_dst, - float* workspace, int fh, int fw, int sh, int sw, const Param& param, - cublasHandle_t cublas_handle, cudaStream_t stream, float* one, - float* zero); + const float* d_src, const float* d_filter, float* d_dst, float* workspace, + int fh, int fw, int sh, int sw, const Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero); void _do_local_share_convolution_large_batch_size_small_image( - const float* d_src, const float* d_filter, float* d_dst, - float* workspace, int fh, int fw, int sh, int sw, const Param& param, - cublasHandle_t cublas_handle, cudaStream_t stream, float* one, - float* zero); + const float* d_src, const float* d_filter, float* d_dst, float* workspace, + int fh, int fw, int sh, int sw, const Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero); } // namespace local_share } // namespace cuda diff --git a/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware.cu b/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware.cu index 4def8df1..0dd1a889 100644 --- a/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware.cu +++ b/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware.cu @@ -28,47 +28,42 @@ struct ThreadConfig { static int const nr_thread_y = thread_y; }; -template +template struct DataTileCount { static int const tile_hi = LocalShareConfig::fh; static int const tile_wi = UnrollConfig::unroll_wo * LocalShareConfig::sw + LocalShareConfig::fw - LocalShareConfig::sw; static int const tile_hw = tile_hi * tile_wi; static int const tile_chw = UnrollConfig::unroll_ci * tile_hi * tile_wi; - static int const reg_gl2sh = (tile_chw + ThreadConfig::nr_thread_y - 1) / - ThreadConfig::nr_thread_y; + static int const reg_gl2sh = + (tile_chw + ThreadConfig::nr_thread_y - 1) / ThreadConfig::nr_thread_y; static int const smem_h = tile_chw; static int const smem_w = ThreadConfig::nr_thread_x; static int const smem_stride = smem_w; static int const smem_tot = smem_h * smem_stride; }; -template 
+template struct FilterTileCount { - static int const tile_co = - ThreadConfig::nr_thread_y * UnrollConfig::unroll_co; + static int const tile_co = ThreadConfig::nr_thread_y * UnrollConfig::unroll_co; static int const tile_ci = UnrollConfig::unroll_ci; - static int const smem_h = - tile_ci * LocalShareConfig::fh * LocalShareConfig::fw; + static int const smem_h = tile_ci * LocalShareConfig::fh * LocalShareConfig::fw; static int const smem_w = tile_co; static int const smem_stride = smem_w + 1; static int const smem_tot = smem_h * smem_stride; - MEGDNN_STATIC_ASSERT(smem_w % ThreadConfig::nr_thread_x == 0, - "col of share memory must be divided by nr_thread_x"); - static int const reg_h = (smem_h + ThreadConfig::nr_thread_y - 1) / - ThreadConfig::nr_thread_y; + MEGDNN_STATIC_ASSERT( + smem_w % ThreadConfig::nr_thread_x == 0, + "col of share memory must be divided by nr_thread_x"); + static int const reg_h = + (smem_h + ThreadConfig::nr_thread_y - 1) / ThreadConfig::nr_thread_y; static int const reg_w = smem_w / ThreadConfig::nr_thread_x; }; -template +template struct DataGlobal2ShareMemVisitor { typedef float copy_t; - typedef DataTileCount - DataTileCount; + typedef DataTileCount DataTileCount; float* smem; const float* g_ptr; int c_stride; @@ -81,10 +76,9 @@ struct DataGlobal2ShareMemVisitor { copy_t reg[DataTileCount::reg_gl2sh]; - __device__ DataGlobal2ShareMemVisitor(float* smem, const float* g_ptr, - int c_stride, int h_stride, - int w_stride, int h1, int h2, int w1, - int w2) + __device__ DataGlobal2ShareMemVisitor( + float* smem, const float* g_ptr, int c_stride, int h_stride, int w_stride, + int h1, int h2, int w1, int w2) : smem{smem}, g_ptr{g_ptr}, c_stride{c_stride}, @@ -152,8 +146,7 @@ struct DataGlobal2ShareMemVisitor { }; }; -template +template struct FilterGlobal2ShareMemVisitor { typedef float copy_t; typedef FilterTileCount @@ -167,8 +160,8 @@ struct FilterGlobal2ShareMemVisitor { copy_t reg[FilterTileCount::reg_h][FilterTileCount::reg_w]; - __device__ FilterGlobal2ShareMemVisitor(float* smem, const float* g_ptr, - int remain, int stride) + __device__ FilterGlobal2ShareMemVisitor( + float* smem, const float* g_ptr, int remain, int stride) : smem{smem}, g_ptr{g_ptr}, remain{remain}, stride{stride} {}; __device__ __forceinline__ void first_copy() { @@ -189,7 +182,7 @@ struct FilterGlobal2ShareMemVisitor { } __device__ __forceinline__ void copy() { - // TODO: co bound check + // TODO: co bound check #pragma unroll for (int i = 0; i < FilterTileCount::reg_h; ++i) { int h_idx = tid_y + i * ThreadConfig::nr_thread_y; @@ -225,29 +218,26 @@ struct FilterGlobal2ShareMemVisitor { } __device__ __forceinline__ void move_forward() { - g_ptr += UnrollConfig::unroll_ci * LocalShareConfig::fh * - LocalShareConfig::fw * stride; + g_ptr += UnrollConfig::unroll_ci * LocalShareConfig::fh * LocalShareConfig::fw * + stride; } }; -template +template __device__ __forceinline__ void consume_block( - DataGlobal2ShareMemVisitor& src_gl2sh_visitor, - FilterGlobal2ShareMemVisitor& filter_gl2sh_visitor, - float r_src[DataTileCount::tile_wi], + DataGlobal2ShareMemVisitor& + src_gl2sh_visitor, + FilterGlobal2ShareMemVisitor& + filter_gl2sh_visitor, + float r_src + [DataTileCount::tile_wi], float r_filter[UnrollConfig::unroll_co][LocalShareConfig::fw], float r_acc[UnrollConfig::unroll_co][UnrollConfig::unroll_wo]) { - typedef DataTileCount - DataTileCount; + typedef DataTileCount DataTileCount; const int tidx = threadIdx.x; const int tidy = threadIdx.y; for (int ci_inner = 0; ci_inner < 
UnrollConfig::unroll_ci; ++ci_inner) { - int sh_flt_row_base = - ci_inner * LocalShareConfig::fh * LocalShareConfig::fw; + int sh_flt_row_base = ci_inner * LocalShareConfig::fh * LocalShareConfig::fw; int sh_flt_col_base = tidy * UnrollConfig::unroll_co; int sh_src_row_base = ci_inner * DataTileCount::tile_hw; #pragma unroll @@ -264,8 +254,8 @@ __device__ __forceinline__ void consume_block( #pragma unroll for (int i = 0; i < DataTileCount::tile_wi; ++i) { int sh_src_row = kh * DataTileCount::tile_wi + i; - r_src[i] = *(src_gl2sh_visitor.sh_ptr( - sh_src_row_base + sh_src_row, tidx)); + r_src[i] = + *(src_gl2sh_visitor.sh_ptr(sh_src_row_base + sh_src_row, tidx)); } #pragma unroll for (int kw = 0; kw < LocalShareConfig::fw; ++kw) { @@ -273,8 +263,8 @@ __device__ __forceinline__ void consume_block( for (int i = 0; i < UnrollConfig::unroll_co; ++i) { #pragma unroll for (int j = 0; j < UnrollConfig::unroll_wo; ++j) { - r_acc[i][j] += r_src[j * LocalShareConfig::sw + kw] * - r_filter[i][kw]; + r_acc[i][j] += + r_src[j * LocalShareConfig::sw + kw] * r_filter[i][kw]; } } } @@ -289,13 +279,11 @@ __device__ __forceinline__ void consume_block( * of one slice with height ho and width wo of the output tensor. Each block * compute 32 batches and BY x UnrollConfig::unroll_co output channels. */ -template +template __global__ void local_share_device_template_f32( const float* __restrict__ src, const float* __restrict__ filter, float* __restrict__ dst, Param param) { - typedef DataTileCount - DataTileCount; + typedef DataTileCount DataTileCount; typedef FilterTileCount FilterTileCount; @@ -306,8 +294,8 @@ __global__ void local_share_device_template_f32( const int bidy = blockIdx.y; const int bidz = blockIdx.z; - const int blks_per_grp_wo = (param.grp_wo + UnrollConfig::unroll_wo - 1) / - UnrollConfig::unroll_wo; + const int blks_per_grp_wo = + (param.grp_wo + UnrollConfig::unroll_wo - 1) / UnrollConfig::unroll_wo; const int b_co = bidy / param.grp_ho; const int b_grp_ho = bidy - b_co * param.grp_ho; const int b_n = bidx / blks_per_grp_wo; @@ -324,16 +312,13 @@ __global__ void local_share_device_template_f32( const int ho = param.sgh * param.grp_ho; const int wo = param.sgw * param.grp_wo; - const int t_co = - b_co * FilterTileCount::tile_co + tidy * UnrollConfig::unroll_co; + const int t_co = b_co * FilterTileCount::tile_co + tidy * UnrollConfig::unroll_co; - const float* __restrict__ g_ptr_src = - src + (b_hi * param.wi + b_wi) * param.n + - b_n * ThreadConfig::nr_thread_x + tidx; + const float* __restrict__ g_ptr_src = src + (b_hi * param.wi + b_wi) * param.n + + b_n * ThreadConfig::nr_thread_x + tidx; const float* __restrict__ g_ptr_filter = filter + - (b_sgh * param.sgw + b_sgw) * param.co * param.ci * - LocalShareConfig::fh * + (b_sgh * param.sgw + b_sgw) * param.co * param.ci * LocalShareConfig::fh * LocalShareConfig::fw // spatial group + b_co; // output channel float* __restrict__ g_ptr_dst = dst + t_co * ho * wo * param.n + @@ -347,18 +332,18 @@ __global__ void local_share_device_template_f32( // TODO check register DataGlobal2ShareMemVisitor - src_gl2sh_visitor{sh_src, - g_ptr_src, - param.hi * param.wi * param.n, - param.wi * param.n, - param.n, - -b_hi, - param.hi - b_hi, - -b_wi, - param.wi - b_wi}; + src_gl2sh_visitor{ + sh_src, + g_ptr_src, + param.hi * param.wi * param.n, + param.wi * param.n, + param.n, + -b_hi, + param.hi - b_hi, + -b_wi, + param.wi - b_wi}; FilterGlobal2ShareMemVisitor - filter_gl2sh_visitor{sh_filter, g_ptr_filter, param.co - b_co, - param.co}; + 
filter_gl2sh_visitor{sh_filter, g_ptr_filter, param.co - b_co, param.co}; float r_src[DataTileCount::tile_wi]; float r_filter[UnrollConfig::unroll_co][LocalShareConfig::fw]; @@ -377,8 +362,7 @@ __global__ void local_share_device_template_f32( __syncthreads(); - int ci_blks = - (param.ci + UnrollConfig::unroll_ci - 1) / UnrollConfig::unroll_ci; + int ci_blks = (param.ci + UnrollConfig::unroll_ci - 1) / UnrollConfig::unroll_ci; for (int ci_outer = 0; ci_outer < ci_blks - 1; ci_outer++) { src_gl2sh_visitor.move_forward(); @@ -387,8 +371,7 @@ __global__ void local_share_device_template_f32( filter_gl2sh_visitor.copy(); consume_block( - src_gl2sh_visitor, filter_gl2sh_visitor, r_src, r_filter, - r_acc); + src_gl2sh_visitor, filter_gl2sh_visitor, r_src, r_filter, r_acc); __syncthreads(); src_gl2sh_visitor.commit(); @@ -414,12 +397,14 @@ __global__ void local_share_device_template_f32( } } -void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, - LaunchConfig& launch_config))(const float* __restrict__, - const float* __restrict__, - float* __restrict__, Param) { - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param); +void (*get_kern( + int fh, int fw, int sh, int sw, const Param& param, + LaunchConfig& launch_config))( + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param) { + void (*kern)( + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param); kern = nullptr; if (fh == 1 && fw == 1 && sh == 1 && sw == 1) { static constexpr int fh_ = 1; @@ -436,8 +421,8 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, typedef LocalShareConfig LocalShareConfig_; \ typedef UnrollConfig UnrollConfig_; \ typedef ThreadConfig ThreadConfig_; \ - kern = local_share_device_template_f32; \ + kern = local_share_device_template_f32< \ + LocalShareConfig_, UnrollConfig_, ThreadConfig_>; \ launch_config.nr_threads_x = nr_thread_x; \ launch_config.nr_threads_y = nr_thread_y; \ launch_config.nr_threads_z = 1; \ @@ -447,11 +432,11 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; \ launch_config.nr_blocks_z = param.sgh * param.sgw; \ launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - DataTileCount::smem_tot + \ - sizeof(float) * \ - FilterTileCount::smem_tot + \ + sizeof(float) * FilterTileCount< \ + LocalShareConfig_, UnrollConfig_, \ ThreadConfig_>::smem_tot; \ } CK_GRP_WO(1); @@ -474,8 +459,8 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, typedef LocalShareConfig LocalShareConfig_; \ typedef UnrollConfig UnrollConfig_; \ typedef ThreadConfig ThreadConfig_; \ - kern = local_share_device_template_f32; \ + kern = local_share_device_template_f32< \ + LocalShareConfig_, UnrollConfig_, ThreadConfig_>; \ launch_config.nr_threads_x = nr_thread_x; \ launch_config.nr_threads_y = nr_thread_y; \ launch_config.nr_threads_z = 1; \ @@ -485,11 +470,11 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; \ launch_config.nr_blocks_z = param.sgh * param.sgw; \ launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - DataTileCount::smem_tot + \ - sizeof(float) * \ - FilterTileCount::smem_tot + \ + sizeof(float) * FilterTileCount< \ + LocalShareConfig_, UnrollConfig_, \ ThreadConfig_>::smem_tot; \ } CK_GRP_WO(1); @@ -516,8 +501,8 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, typedef 
LocalShareConfig LocalShareConfig_; \ typedef UnrollConfig UnrollConfig_; \ typedef ThreadConfig ThreadConfig_; \ - kern = local_share_device_template_f32; \ + kern = local_share_device_template_f32< \ + LocalShareConfig_, UnrollConfig_, ThreadConfig_>; \ launch_config.nr_threads_x = nr_thread_x; \ launch_config.nr_threads_y = nr_thread_y; \ launch_config.nr_threads_z = 1; \ @@ -527,11 +512,11 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; \ launch_config.nr_blocks_z = param.sgh * param.sgw; \ launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - DataTileCount::smem_tot + \ - sizeof(float) * \ - FilterTileCount::smem_tot + \ + sizeof(float) * FilterTileCount< \ + LocalShareConfig_, UnrollConfig_, \ ThreadConfig_>::smem_tot; \ } CK_GRP_WO(1); @@ -558,8 +543,8 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, typedef LocalShareConfig LocalShareConfig_; \ typedef UnrollConfig UnrollConfig_; \ typedef ThreadConfig ThreadConfig_; \ - kern = local_share_device_template_f32; \ + kern = local_share_device_template_f32< \ + LocalShareConfig_, UnrollConfig_, ThreadConfig_>; \ launch_config.nr_threads_x = nr_thread_x; \ launch_config.nr_threads_y = nr_thread_y; \ launch_config.nr_threads_z = 1; \ @@ -569,11 +554,11 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; \ launch_config.nr_blocks_z = param.sgh * param.sgw; \ launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - DataTileCount::smem_tot + \ - sizeof(float) * \ - FilterTileCount::smem_tot + \ + sizeof(float) * FilterTileCount< \ + LocalShareConfig_, UnrollConfig_, \ ThreadConfig_>::smem_tot; \ } CK_GRP_WO(1); @@ -605,17 +590,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } else if (param.grp_wo >= 4) { @@ -632,17 +617,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } else { @@ -659,17 +644,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - 
launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } } else if (fh == 5 && fw == 5 && sh == 2 && sw == 2) { @@ -691,17 +676,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } else { static constexpr int unroll_co = 16; @@ -717,17 +702,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } } else if (fh == 7 && fw == 7 && sh == 1 && sw == 1) { @@ -749,17 +734,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } else if (param.grp_wo >= 4) { @@ -776,17 +761,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; 
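// A minimal standalone sketch of the grid-size arithmetic shared by the
// launch-config branches of get_kern above. It assumes DIVUP(a, b) is the
// usual round-up division macro, i.e. (a + b - 1) / b; the problem and tile
// sizes below are illustrative placeholders, not values taken from this file.
#include <cstdio>
#define DIVUP(a, b) (((a) + (b) - 1) / (b))
int main() {
    int n = 96, co = 64, grp_ho = 6, grp_wo = 6, sgh = 2, sgw = 2;        // hypothetical shapes
    int nr_thread_x = 32, nr_thread_y = 8, unroll_co = 8, unroll_wo = 4;  // hypothetical tiles
    // x: one block per (batch tile, grp_wo tile); y: one per (co tile, grp_ho row);
    // z: one block per spatial group of the local-share convolution.
    int blocks_x = DIVUP(n, nr_thread_x) * DIVUP(grp_wo, unroll_wo);
    int blocks_y = DIVUP(co, nr_thread_y * unroll_co) * grp_ho;
    int blocks_z = sgh * sgw;
    printf("grid = (%d, %d, %d)\n", blocks_x, blocks_y, blocks_z);  // prints grid = (6, 6, 4)
    return 0;
}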
launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } else { @@ -803,17 +788,17 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; - launch_config.nr_blocks_x = DIVUP(param.n, nr_thread_x) * - DIVUP(param.grp_wo, unroll_wo); + launch_config.nr_blocks_x = + DIVUP(param.n, nr_thread_x) * DIVUP(param.grp_wo, unroll_wo); launch_config.nr_blocks_y = DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * - DataTileCount::smem_tot + - sizeof(float) * - FilterTileCount::smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } } else if (fh == 7 && fw == 7 && sh == 2 && sw == 2) { @@ -829,8 +814,8 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, typedef LocalShareConfig LocalShareConfig_; typedef UnrollConfig UnrollConfig_; typedef ThreadConfig ThreadConfig_; - kern = local_share_device_template_f32; + kern = local_share_device_template_f32< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>; launch_config.nr_threads_x = nr_thread_x; launch_config.nr_threads_y = nr_thread_y; launch_config.nr_threads_z = 1; @@ -840,16 +825,18 @@ void (*get_kern(int fh, int fw, int sh, int sw, const Param& param, DIVUP(param.co, nr_thread_y * unroll_co) * param.grp_ho; launch_config.nr_blocks_z = param.sgh * param.sgw; launch_config.smem_size_in_bytes = - sizeof(float) * DataTileCount::smem_tot + sizeof(float) * - FilterTileCount:: + smem_tot + + sizeof(float) * FilterTileCount< + LocalShareConfig_, UnrollConfig_, ThreadConfig_>::smem_tot; } else { - megdnn_assert(false, - "no usable kernel implementation for local share " - "convolution (fh,fw)=(%d,%d), (sh,sw)=(%d,%d)", - fh, fw, sh, sw); + megdnn_assert( + false, + "no usable kernel implementation for local share " + "convolution (fh,fw)=(%d,%d), (sh,sw)=(%d,%d)", + fh, fw, sh, sw); } return kern; } @@ -1252,10 +1239,9 @@ __global__ void local_share_device_template_f32( #endif void megdnn::cuda::local_share::_do_local_share_convolution_large_batch_size( - const float* d_src, const float* d_filter, float* d_dst, - float* workspace, int fh, int fw, int sh, int sw, const Param& param, - cublasHandle_t cublas_handle, cudaStream_t stream, float* one, - float* zero) { + const float* d_src, const float* d_filter, float* d_dst, float* workspace, + int fh, int fw, int sh, int sw, const Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, float* zero) { float* ws_src = workspace; int nr_elem_total = param.n * param.ci * param.hi * param.wi; float* ws_dst = workspace + nr_elem_total; @@ -1265,14 +1251,15 @@ void megdnn::cuda::local_share::_do_local_share_convolution_large_batch_size( int lda, ldb; lda = ldb = param.ci * param.hi * param.wi; int ldc = param.n; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, d_src, lda, zero, d_src, ldb, ws_src, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, d_src, lda, zero, + d_src, ldb, ws_src, ldc)); } { - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param); + void (*kern)( + const float* __restrict__, const 
float* __restrict__, + float* __restrict__, Param); LaunchConfig launch_config; kern = get_kern(fh, fw, sh, sw, param, launch_config); @@ -1299,9 +1286,9 @@ void megdnn::cuda::local_share::_do_local_share_convolution_large_batch_size( int lda, ldb; lda = ldb = param.n; int ldc = param.co * ho * wo; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, ws_dst, lda, zero, ws_dst, ldb, d_dst, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, ws_dst, lda, zero, + ws_dst, ldb, d_dst, ldc)); } } diff --git a/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware_small_image.cu b/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware_small_image.cu index 5279b295..5fd46d85 100644 --- a/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware_small_image.cu +++ b/dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware_small_image.cu @@ -1,5 +1,6 @@ /** - * \file dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware_small_image.cu + * \file + * dnn/src/cuda/local_share/forward/local_share_fwd_chwn_f32_batch_size_aware_small_image.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -33,8 +34,7 @@ template struct DataTileCount { typedef UnrollConfig_ UnrollConfig; typedef ThreadConfig_ ThreadConfig; - static int const tile_batch = - UnrollConfig::unroll_n * ThreadConfig::nr_thread_x; + static int const tile_batch = UnrollConfig::unroll_n * ThreadConfig::nr_thread_x; static int const load_x = tile_batch > 32 ? 32 : tile_batch; static int const load_y = ThreadConfig::nr_threads / load_x; @@ -53,8 +53,7 @@ template struct FilterTileCount { typedef UnrollConfig_ UnrollConfig; typedef ThreadConfig_ ThreadConfig; - static int const tile_co = - ThreadConfig::nr_thread_y * UnrollConfig::unroll_co; + static int const tile_co = ThreadConfig::nr_thread_y * UnrollConfig::unroll_co; static int const smem_h = UnrollConfig::unroll_ci; static int const smem_w = tile_co; static int const smem_stride = smem_w + 1; @@ -178,8 +177,7 @@ struct FilterGlobal2ShareMemVisitor { copy_t reg[TileCount::reg_row][TileCount::reg_col]; - __device__ FilterGlobal2ShareMemVisitor(copy_t* smem, int stride, - int remain) + __device__ FilterGlobal2ShareMemVisitor(copy_t* smem, int stride, int remain) : smem{smem}, stride{stride}, remain{remain} {} __device__ __forceinline__ void first_copy() { @@ -261,8 +259,7 @@ __device__ __forceinline__ void consume_block( data_gl2sh_visitor, FilterGlobal2ShareMemVisitor& filter_gl2sh_visitor, - float r_src[UnrollConfig::unroll_n], - float r_filter[UnrollConfig::unroll_co], + float r_src[UnrollConfig::unroll_n], float r_filter[UnrollConfig::unroll_co], float r_acc[UnrollConfig::unroll_co][UnrollConfig::unroll_n]) { typedef DataTileCount DataTileCount; const int tidx = threadIdx.x; @@ -329,19 +326,17 @@ __global__ void local_share_device_template_f32( param.co * param.ci * fh * fw; // spatial group - float* __restrict__ g_ptr_dst = - dst + t_co * ho * wo * param.n // output channel stride+ - + (b_ho * wo + b_wo) * param.n // spatial stride - + t_batch; + float* __restrict__ g_ptr_dst = dst + + t_co * ho * wo * param.n // output channel stride+ + + (b_ho * wo + b_wo) * param.n // spatial stride + + t_batch; // TODO check register DataGlobal2ShareMemVisitor - src_gl2sh_visitor{sh_src, param.hi * param.wi * param.n, - param.n - b_batch}; + 
src_gl2sh_visitor{sh_src, param.hi * param.wi * param.n, param.n - b_batch}; FilterGlobal2ShareMemVisitor - filter_gl2sh_visitor{sh_filter, param.co * fh * fw, - param.co - b_co}; + filter_gl2sh_visitor{sh_filter, param.co * fh * fw, param.co - b_co}; float r_src[UnrollConfig::unroll_n]; float r_filter[UnrollConfig::unroll_co]; @@ -368,8 +363,7 @@ __global__ void local_share_device_template_f32( int kh = h_start - h_base; int kw = w_start - w_base; - src_gl2sh_visitor.g_ptr = - g_ptr_src + (h_start * param.wi + w_start) * param.n; + src_gl2sh_visitor.g_ptr = g_ptr_src + (h_start * param.wi + w_start) * param.n; filter_gl2sh_visitor.g_ptr = g_ptr_filter + (kh * fw + kw) * param.co; src_gl2sh_visitor.first_copy(); filter_gl2sh_visitor.first_copy(); @@ -386,8 +380,7 @@ __global__ void local_share_device_template_f32( int kh = h_next - h_base; int kw = w_next - w_base; src_gl2sh_visitor.g_ptr = - g_ptr_src + - (h_next * param.wi + w_next) * param.n; + g_ptr_src + (h_next * param.wi + w_next) * param.n; filter_gl2sh_visitor.g_ptr = g_ptr_filter + (kh * fw + kw) * param.co; src_gl2sh_visitor.copy(); @@ -401,8 +394,8 @@ __global__ void local_share_device_template_f32( } consume_block( - src_gl2sh_visitor, filter_gl2sh_visitor, r_src, - r_filter, r_acc); + src_gl2sh_visitor, filter_gl2sh_visitor, r_src, r_filter, + r_acc); if (!(ci_outer == ci_blks - 1 && h == h_end && w == w_end)) { __syncthreads(); @@ -419,55 +412,51 @@ __global__ void local_share_device_template_f32( for (int i = 0; i < UnrollConfig::unroll_co; ++i) { #pragma unroll for (int j = 0; j < UnrollConfig::unroll_n; ++j) { - if (check_bounds && - (t_co + i * ThreadConfig::nr_thread_y >= param.co || - t_batch + j * ThreadConfig::nr_thread_x >= param.n)) { + if (check_bounds && (t_co + i * ThreadConfig::nr_thread_y >= param.co || + t_batch + j * ThreadConfig::nr_thread_x >= param.n)) { } else { - g_ptr_dst[i * ThreadConfig::nr_thread_y * co_stride + - j * ThreadConfig::nr_thread_x] = r_acc[i][j]; + g_ptr_dst + [i * ThreadConfig::nr_thread_y * co_stride + + j * ThreadConfig::nr_thread_x] = r_acc[i][j]; } } } } void (*get_kern(const Param& param, LaunchConfig& launch_config))( - const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int) { - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int); + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param, int, int, int, int) { + void (*kern)( + const float* __restrict__, const float* __restrict__, float* __restrict__, + Param, int, int, int, int); kern = nullptr; -#define CHK3(n_, co_, ci_, tx_, ty_) \ - if (param.n >= n_) { \ - if (param.co >= co_) { \ - if (param.ci % ci_ == 0) { \ - static constexpr int unroll_ci = (ci_); \ - static constexpr int unroll_co = (co_ + ty_ - 1) / ty_; \ - static constexpr int unroll_n = (n_ + tx_ - 1) / tx_; \ - static constexpr int thread_x = tx_; \ - static constexpr int thread_y = ty_; \ - typedef UnrollConfig \ - UnrollConfig; \ - typedef ThreadConfig ThreadConfig; \ - typedef DataTileCount \ - DataTileCount; \ - typedef FilterTileCount \ - FilterTileCount; \ - kern = local_share_device_template_f32; \ - launch_config.nr_threads_x = thread_x; \ - launch_config.nr_threads_y = thread_y; \ - launch_config.nr_threads_z = 1; \ - launch_config.nr_blocks_x = \ - param.grp_ho * param.grp_wo * param.sgh * param.sgw; \ - launch_config.nr_blocks_y = \ - DIVUP(param.n, DataTileCount::tile_batch); \ - launch_config.nr_blocks_z = \ - 
DIVUP(param.co, FilterTileCount::tile_co); \ - launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - (DataTileCount::smem_tot + FilterTileCount::smem_tot); \ - } \ - } \ +#define CHK3(n_, co_, ci_, tx_, ty_) \ + if (param.n >= n_) { \ + if (param.co >= co_) { \ + if (param.ci % ci_ == 0) { \ + static constexpr int unroll_ci = (ci_); \ + static constexpr int unroll_co = (co_ + ty_ - 1) / ty_; \ + static constexpr int unroll_n = (n_ + tx_ - 1) / tx_; \ + static constexpr int thread_x = tx_; \ + static constexpr int thread_y = ty_; \ + typedef UnrollConfig UnrollConfig; \ + typedef ThreadConfig ThreadConfig; \ + typedef DataTileCount DataTileCount; \ + typedef FilterTileCount FilterTileCount; \ + kern = local_share_device_template_f32< \ + true, UnrollConfig, ThreadConfig>; \ + launch_config.nr_threads_x = thread_x; \ + launch_config.nr_threads_y = thread_y; \ + launch_config.nr_threads_z = 1; \ + launch_config.nr_blocks_x = \ + param.grp_ho * param.grp_wo * param.sgh * param.sgw; \ + launch_config.nr_blocks_y = DIVUP(param.n, DataTileCount::tile_batch); \ + launch_config.nr_blocks_z = DIVUP(param.co, FilterTileCount::tile_co); \ + launch_config.smem_size_in_bytes = \ + sizeof(float) * \ + (DataTileCount::smem_tot + FilterTileCount::smem_tot); \ + } \ + } \ } #define CHK2(n_, co_) \ CHK3(n_, co_, 4, 8, 16) \ @@ -487,38 +476,33 @@ void (*get_kern(const Param& param, LaunchConfig& launch_config))( #undef CHK2 #undef CHK2_ #undef CHK3 -#define CHK3(n_, co_, ci_, tx_, ty_) \ - if (param.n % n_ == 0) { \ - if (param.co % co_ == 0) { \ - if (param.ci % ci_ == 0) { \ - static constexpr int unroll_ci = (ci_); \ - static constexpr int unroll_co = (co_) / (ty_); \ - static constexpr int unroll_n = (n_) / (tx_); \ - static constexpr int thread_x = tx_; \ - static constexpr int thread_y = ty_; \ - typedef UnrollConfig \ - UnrollConfig; \ - typedef ThreadConfig ThreadConfig; \ - typedef DataTileCount \ - DataTileCount; \ - typedef FilterTileCount \ - FilterTileCount; \ - kern = local_share_device_template_f32; \ - launch_config.nr_threads_x = thread_x; \ - launch_config.nr_threads_y = thread_y; \ - launch_config.nr_threads_z = 1; \ - launch_config.nr_blocks_x = \ - param.grp_ho * param.grp_wo * param.sgh * param.sgw; \ - launch_config.nr_blocks_y = \ - DIVUP(param.n, DataTileCount::tile_batch); \ - launch_config.nr_blocks_z = \ - DIVUP(param.co, FilterTileCount::tile_co); \ - launch_config.smem_size_in_bytes = \ - sizeof(float) * \ - (DataTileCount::smem_tot + FilterTileCount::smem_tot); \ - } \ - } \ +#define CHK3(n_, co_, ci_, tx_, ty_) \ + if (param.n % n_ == 0) { \ + if (param.co % co_ == 0) { \ + if (param.ci % ci_ == 0) { \ + static constexpr int unroll_ci = (ci_); \ + static constexpr int unroll_co = (co_) / (ty_); \ + static constexpr int unroll_n = (n_) / (tx_); \ + static constexpr int thread_x = tx_; \ + static constexpr int thread_y = ty_; \ + typedef UnrollConfig UnrollConfig; \ + typedef ThreadConfig ThreadConfig; \ + typedef DataTileCount DataTileCount; \ + typedef FilterTileCount FilterTileCount; \ + kern = local_share_device_template_f32< \ + false, UnrollConfig, ThreadConfig>; \ + launch_config.nr_threads_x = thread_x; \ + launch_config.nr_threads_y = thread_y; \ + launch_config.nr_threads_z = 1; \ + launch_config.nr_blocks_x = \ + param.grp_ho * param.grp_wo * param.sgh * param.sgw; \ + launch_config.nr_blocks_y = DIVUP(param.n, DataTileCount::tile_batch); \ + launch_config.nr_blocks_z = DIVUP(param.co, FilterTileCount::tile_co); \ + launch_config.smem_size_in_bytes = \ + 
sizeof(float) * \ + (DataTileCount::smem_tot + FilterTileCount::smem_tot); \ + } \ + } \ } #define CHK2(n_, co_) CHK3(n_, co_, 4, 8, 8) CHK3(n_, co_, 8, 8, 8) #define CHK(n_) \ @@ -532,10 +516,11 @@ void (*get_kern(const Param& param, LaunchConfig& launch_config))( #undef CHK #undef CHK2 #undef CHK3 - megdnn_assert(kern != nullptr, - "no usable kernel implementation for local share " - "convolution (batch,co,ci)=(%d,%d,%d)", - param.n, param.co, param.ci); + megdnn_assert( + kern != nullptr, + "no usable kernel implementation for local share " + "convolution (batch,co,ci)=(%d,%d,%d)", + param.n, param.co, param.ci); return kern; } @@ -544,9 +529,9 @@ void (*get_kern(const Param& param, LaunchConfig& launch_config))( void megdnn::cuda::local_share:: _do_local_share_convolution_large_batch_size_small_image( const float* d_src, const float* d_filter, float* d_dst, - float* workspace, int fh, int fw, int sh, int sw, - const Param& param, cublasHandle_t cublas_handle, - cudaStream_t stream, float* one, float* zero) { + float* workspace, int fh, int fw, int sh, int sw, const Param& param, + cublasHandle_t cublas_handle, cudaStream_t stream, float* one, + float* zero) { float* ws_src = workspace; int nr_src_total = param.n * param.ci * param.hi * param.wi; float* ws_dst = ws_src + nr_src_total; @@ -556,14 +541,15 @@ void megdnn::cuda::local_share:: int lda, ldb; lda = ldb = param.ci * param.hi * param.wi; int ldc = param.n; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, d_src, lda, zero, d_src, ldb, ws_src, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, d_src, lda, zero, + d_src, ldb, ws_src, ldc)); } - + { - void (*kern)(const float* __restrict__, const float* __restrict__, - float* __restrict__, Param, int, int, int, int); + void (*kern)( + const float* __restrict__, const float* __restrict__, + float* __restrict__, Param, int, int, int, int); LaunchConfig launch_config; kern = get_kern(param, launch_config); @@ -590,9 +576,9 @@ void megdnn::cuda::local_share:: int lda, ldb; lda = ldb = param.n; int ldc = param.co * ho * wo; - cublas_check(cublasSgeam(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, - one, ws_dst, lda, zero, ws_dst, ldb, d_dst, - ldc)); + cublas_check(cublasSgeam( + cublas_handle, CUBLAS_OP_T, CUBLAS_OP_T, m, n, one, ws_dst, lda, zero, + ws_dst, ldb, d_dst, ldc)); } } diff --git a/dnn/src/cuda/local_share/helper.cuh b/dnn/src/cuda/local_share/helper.cuh index 67ae1dee..af3cba33 100644 --- a/dnn/src/cuda/local_share/helper.cuh +++ b/dnn/src/cuda/local_share/helper.cuh @@ -52,38 +52,37 @@ uint32_t _get_kern_block_size(const void* kern); } // namespace cuda } // namespace megdnn -#define unpack_local_share_params(_src, _filter, _dst, _param) \ - size_t n = _src[0], ci = _src[1], hi = _src[2], wi = _src[3]; \ - size_t weight_spatial_pos; \ - if (_param.sparse == LocalShare::Param::Sparse::DENSE) { \ - weight_spatial_pos = 3; \ - } else { \ - megdnn_assert(_param.sparse == LocalShare::Param::Sparse::GROUP); \ - weight_spatial_pos = 4; \ - } \ - size_t fh = _filter[weight_spatial_pos], \ - fw = _filter[weight_spatial_pos + 1]; \ - size_t co = _dst[1], ho = _dst[2], wo = _dst[3]; \ - size_t ph = _param.pad_h, pw = _param.pad_w; \ - size_t sh = _param.stride_h, sw = _param.stride_w; \ - size_t dh = _param.dilate_h, dw = _param.dilate_w; \ - size_t sgh = _param.spatial_groups_h, sgw = _param.spatial_groups_w; \ - MEGDNN_MARK_USED_VAR(n); \ - MEGDNN_MARK_USED_VAR(ci); \ - MEGDNN_MARK_USED_VAR(hi); \ - 
MEGDNN_MARK_USED_VAR(wi); \ - MEGDNN_MARK_USED_VAR(co); \ - MEGDNN_MARK_USED_VAR(fh); \ - MEGDNN_MARK_USED_VAR(fw); \ - MEGDNN_MARK_USED_VAR(ho); \ - MEGDNN_MARK_USED_VAR(wo); \ - MEGDNN_MARK_USED_VAR(ph); \ - MEGDNN_MARK_USED_VAR(pw); \ - MEGDNN_MARK_USED_VAR(sh); \ - MEGDNN_MARK_USED_VAR(sw); \ - MEGDNN_MARK_USED_VAR(dh); \ - MEGDNN_MARK_USED_VAR(dw); \ - MEGDNN_MARK_USED_VAR(sgh); \ +#define unpack_local_share_params(_src, _filter, _dst, _param) \ + size_t n = _src[0], ci = _src[1], hi = _src[2], wi = _src[3]; \ + size_t weight_spatial_pos; \ + if (_param.sparse == LocalShare::Param::Sparse::DENSE) { \ + weight_spatial_pos = 3; \ + } else { \ + megdnn_assert(_param.sparse == LocalShare::Param::Sparse::GROUP); \ + weight_spatial_pos = 4; \ + } \ + size_t fh = _filter[weight_spatial_pos], fw = _filter[weight_spatial_pos + 1]; \ + size_t co = _dst[1], ho = _dst[2], wo = _dst[3]; \ + size_t ph = _param.pad_h, pw = _param.pad_w; \ + size_t sh = _param.stride_h, sw = _param.stride_w; \ + size_t dh = _param.dilate_h, dw = _param.dilate_w; \ + size_t sgh = _param.spatial_groups_h, sgw = _param.spatial_groups_w; \ + MEGDNN_MARK_USED_VAR(n); \ + MEGDNN_MARK_USED_VAR(ci); \ + MEGDNN_MARK_USED_VAR(hi); \ + MEGDNN_MARK_USED_VAR(wi); \ + MEGDNN_MARK_USED_VAR(co); \ + MEGDNN_MARK_USED_VAR(fh); \ + MEGDNN_MARK_USED_VAR(fw); \ + MEGDNN_MARK_USED_VAR(ho); \ + MEGDNN_MARK_USED_VAR(wo); \ + MEGDNN_MARK_USED_VAR(ph); \ + MEGDNN_MARK_USED_VAR(pw); \ + MEGDNN_MARK_USED_VAR(sh); \ + MEGDNN_MARK_USED_VAR(sw); \ + MEGDNN_MARK_USED_VAR(dh); \ + MEGDNN_MARK_USED_VAR(dw); \ + MEGDNN_MARK_USED_VAR(sgh); \ MEGDNN_MARK_USED_VAR(sgw); // vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/local_share/im2col.cu b/dnn/src/cuda/local_share/im2col.cu index 8856fab3..ad2c8e8c 100644 --- a/dnn/src/cuda/local_share/im2col.cu +++ b/dnn/src/cuda/local_share/im2col.cu @@ -16,9 +16,9 @@ using namespace local_share; namespace { template -__global__ void local_share_im2col(const T* __restrict__ img, - T* __restrict__ col, int fh, int fw, int sh, - int sw, int nr_groups, Param param) { +__global__ void local_share_im2col( + const T* __restrict__ img, T* __restrict__ col, int fh, int fw, int sh, int sw, + int nr_groups, Param param) { const int in_ch_idx = threadIdx.x + blockIdx.y * blockDim.x; const int batch = threadIdx.y + blockIdx.z * blockDim.y; if (in_ch_idx >= param.ci || batch >= param.n) @@ -36,16 +36,15 @@ __global__ void local_share_im2col(const T* __restrict__ img, const int ch_grp_idx = in_ch_idx / icpg; const int grp_ch_idx = in_ch_idx - icpg * ch_grp_idx; - const T* __restrict__ img_ptr = img + - batch * param.ci * param.hi * param.wi + + const T* __restrict__ img_ptr = img + batch * param.ci * param.hi * param.wi + in_ch_idx * param.hi * param.wi; const int ld = icpg * fh * fw; T* __restrict__ col_ptr = col + ch_grp_idx * (param.sgh * param.sgw) * param.n * grp_sizes * ld // channel group stride - + (sgh_idx * param.sgw + sgw_idx) * param.n * grp_sizes * - ld // batch stride + + + (sgh_idx * param.sgw + sgw_idx) * param.n * grp_sizes * ld // batch stride + grp_ch_idx * fh * fw // input channel stride + (batch * grp_sizes + (grp_oh_idx * param.grp_wo + grp_ow_idx)) * ld; // row stride @@ -55,8 +54,7 @@ __global__ void local_share_im2col(const T* __restrict__ img, int ih_idx = oh_idx * sh - param.ph + kh; int iw_idx = ow_idx * sw - param.pw + kw; float val = 0.f; - if (ih_idx < param.hi && ih_idx >= 0 && iw_idx < param.wi && - iw_idx >= 0) { + if (ih_idx < param.hi && ih_idx >= 0 && iw_idx < param.wi && iw_idx >= 0) { 
val = img_ptr[ih_idx * param.wi + iw_idx]; } *(col_ptr++) = val; @@ -65,9 +63,9 @@ __global__ void local_share_im2col(const T* __restrict__ img, } template -__global__ void local_share_col2im(const T* __restrict__ col, - T* __restrict__ img, int fh, int fw, int sh, - int sw, int nr_groups, Param param) { +__global__ void local_share_col2im( + const T* __restrict__ col, T* __restrict__ img, int fh, int fw, int sh, int sw, + int nr_groups, Param param) { const int batch = threadIdx.x + blockIdx.y * blockDim.x; const int in_ch_idx = threadIdx.y + blockIdx.z * blockDim.y; if (in_ch_idx >= param.ci || batch >= param.n) @@ -87,10 +85,9 @@ __global__ void local_share_col2im(const T* __restrict__ col, const T* __restrict__ col_ptr = col + ch_grp_idx * param.sgh * param.sgw * ch_filter_sizes * grp_sizes * - param.n // channel group stride - + batch // batch stride - + - grp_ch_idx * filter_sizes * grp_sizes * param.n; // channel stride + param.n // channel group stride + + batch // batch stride + + grp_ch_idx * filter_sizes * grp_sizes * param.n; // channel stride T res(0); for (int kh = 0; kh < fh; ++kh) { @@ -117,22 +114,22 @@ __global__ void local_share_col2im(const T* __restrict__ col, } } } - img[batch * param.ci * param.hi * param.wi + - in_ch_idx * param.hi * param.wi + ih_idx * param.wi + iw_idx] = res; + img[batch * param.ci * param.hi * param.wi + in_ch_idx * param.hi * param.wi + + ih_idx * param.wi + iw_idx] = res; } } // namespace void megdnn::cuda::local_share::_do_local_share_im2col( - const float* d_im, float* d_col, int fh, int fw, int sh, int sw, - int nr_groups, const Param& param, cudaStream_t stream) { - void (*kern)(const float* __restrict__, float* __restrict__, int, int, int, - int, int, Param); + const float* d_im, float* d_col, int fh, int fw, int sh, int sw, int nr_groups, + const Param& param, cudaStream_t stream) { + void (*kern)( + const float* __restrict__, float* __restrict__, int, int, int, int, int, + Param); kern = local_share_im2col; constexpr int threads_x = 256; - uint32_t nr_threads = - _get_kern_block_size(reinterpret_cast(kern)); + uint32_t nr_threads = _get_kern_block_size(reinterpret_cast(kern)); uint32_t nr_threads_x = std::min(threads_x, param.ci); uint32_t nr_threads_y = std::min(static_cast(nr_threads / nr_threads_x), param.n); @@ -141,21 +138,20 @@ void megdnn::cuda::local_share::_do_local_share_im2col( nr_blocks_z = DIVUP(param.n, nr_threads_y); dim3 threads{nr_threads_x, nr_threads_y, 1}; dim3 blocks{nr_blocks_x, nr_blocks_y, nr_blocks_z}; - kern<<>>(d_im, d_col, fh, fw, sh, sw, nr_groups, - param); + kern<<>>(d_im, d_col, fh, fw, sh, sw, nr_groups, param); after_kernel_launch(); } void megdnn::cuda::local_share::_do_local_share_col2im( - const float* d_col, float* d_im, int fh, int fw, int sh, int sw, - int nr_groups, const Param& param, cudaStream_t stream) { - void (*kern)(const float* __restrict__, float* __restrict__, int, int, int, - int, int, Param); + const float* d_col, float* d_im, int fh, int fw, int sh, int sw, int nr_groups, + const Param& param, cudaStream_t stream) { + void (*kern)( + const float* __restrict__, float* __restrict__, int, int, int, int, int, + Param); kern = local_share_col2im; constexpr int threads_x = 256; - uint32_t nr_threads = - _get_kern_block_size(reinterpret_cast(kern)); + uint32_t nr_threads = _get_kern_block_size(reinterpret_cast(kern)); uint32_t nr_threads_x = std::min(threads_x, param.n); uint32_t nr_threads_y = std::min(static_cast(nr_threads / nr_threads_x), param.ci); @@ -164,8 +160,7 @@ void 
megdnn::cuda::local_share::_do_local_share_col2im( nr_blocks_z = DIVUP(param.ci, nr_threads_y); dim3 threads{nr_threads_x, nr_threads_y, 1}; dim3 blocks{nr_blocks_x, nr_blocks_y, nr_blocks_z}; - kern<<>>(d_col, d_im, fh, fw, sh, sw, nr_groups, - param); + kern<<>>(d_col, d_im, fh, fw, sh, sw, nr_groups, param); after_kernel_launch(); } diff --git a/dnn/src/cuda/local_share/im2col.cuh b/dnn/src/cuda/local_share/im2col.cuh index c57fdae2..9abb61d0 100644 --- a/dnn/src/cuda/local_share/im2col.cuh +++ b/dnn/src/cuda/local_share/im2col.cuh @@ -8,20 +8,20 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/cuda/utils.cuh" #include "./helper.cuh" +#include "src/cuda/utils.cuh" namespace megdnn { namespace cuda { namespace local_share { -void _do_local_share_im2col(const float* d_im, float* d_col, int fh, int fw, - int sh, int sw, int nr_groups, const Param& param, - cudaStream_t stream); +void _do_local_share_im2col( + const float* d_im, float* d_col, int fh, int fw, int sh, int sw, int nr_groups, + const Param& param, cudaStream_t stream); -void _do_local_share_col2im(const float* d_col, float* d_im, int fh, int fw, - int sh, int sw, int nr_groups, const Param& param, - cudaStream_t stream); +void _do_local_share_col2im( + const float* d_col, float* d_im, int fh, int fw, int sh, int sw, int nr_groups, + const Param& param, cudaStream_t stream); } // namespace local_share } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/local_share/opr_impl.cpp b/dnn/src/cuda/local_share/opr_impl.cpp index c43b6c1b..474714d2 100644 --- a/dnn/src/cuda/local_share/opr_impl.cpp +++ b/dnn/src/cuda/local_share/opr_impl.cpp @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ #include "src/cuda/local_share/opr_impl.h" -#include "./forward/algo.h" #include "./backward_data/algo.h" #include "./backward_filter/algo.h" +#include "./forward/algo.h" #include "src/common/algo_chooser.h" #include "src/cuda/utils.h" @@ -19,16 +19,13 @@ using namespace megdnn; using namespace cuda; /* ============== LocalShareForwardImpl ============== */ -LocalShareForwardImpl::Algorithm* -LocalShareForwardImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& filter, - const TensorLayout& dst, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, +LocalShareForwardImpl::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, src, filter, dst); - if (sm_algo_pack.batch_size_aware_chwn_small_image - .is_available_attribute(args, positive_attr, negative_attr, - workspace_limit_in_bytes)) { + if (sm_algo_pack.batch_size_aware_chwn_small_image.is_available_attribute( + args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.batch_size_aware_chwn_small_image; } if (sm_algo_pack.batch_size_aware_chwn.is_available_attribute( @@ -39,40 +36,38 @@ LocalShareForwardImpl::get_algorithm_heuristic( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.batched_matmul; } - megdnn_throw( - ssprintf("no local share conv algorithm without attribute(%s) with " - "attribute(%s), args(%s) and " - "workspace limit (%zu bytes)", - Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); -} -std::vector -LocalShareForwardImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { + megdnn_throw(ssprintf( + "no local share conv algorithm without attribute(%s) with " + "attribute(%s), args(%s) and " + "workspace limit (%zu bytes)", + Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); +} +std::vector LocalShareForwardImpl:: + get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, dst}; return megdnn::get_all_algorithms(args); } -std::vector -LocalShareForwardImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +std::vector LocalShareForwardImpl:: + get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) { AlgoBase::SizeArgs args{this, src, filter, dst}; return megdnn::get_all_algorithms_safe(args); } -size_t LocalShareForwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) { +size_t LocalShareForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, const TensorLayout& dst) { return get_dnn_workspace(this, src, filter, dst); } -void LocalShareForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void LocalShareForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(src.layout, filter.layout, 
dst.layout, workspace.size); AlgoBase::ExecArgs args(this, src, filter, dst, workspace); auto algo = get_algorithm(this, src.layout, filter.layout, dst.layout); @@ -84,12 +79,12 @@ const char* LocalShareForwardImpl::get_algorithm_set_name() const { } /* ============== LocalShareBackwardDataImpl ============== */ -LocalShareBackwardDataImpl::Algorithm* -LocalShareBackwardDataImpl::get_algorithm_heuristic( - const TensorLayout& filter, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +LocalShareBackwardDataImpl::Algorithm* LocalShareBackwardDataImpl:: + get_algorithm_heuristic( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, filter, diff, grad); if (sm_algo_pack.implicit_gemm.is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { @@ -99,41 +94,40 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.batched_matmul; } - megdnn_throw( - ssprintf("no local share bwd data algorithm without attribute(%s) " - "with attribute(%s) args(%s) and " - "workspace limit (%zu bytes)", - Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); -} - -std::vector -LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) { + megdnn_throw(ssprintf( + "no local share bwd data algorithm without attribute(%s) " + "with attribute(%s) args(%s) and " + "workspace limit (%zu bytes)", + Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); +} + +std::vector LocalShareBackwardDataImpl:: + get_all_algorithms( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { AlgoBase::SizeArgs args{this, filter, diff, grad}; return megdnn::get_all_algorithms(args); } -std::vector -LocalShareBackwardDataImpl::get_all_algorithms_safe(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector LocalShareBackwardDataImpl:: + get_all_algorithms_safe( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { AlgoBase::SizeArgs args{this, filter, diff, grad}; return megdnn::get_all_algorithms_safe(args); } -size_t LocalShareBackwardDataImpl::get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) { +size_t LocalShareBackwardDataImpl::get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { return get_dnn_workspace(this, filter, diff, grad); } -void LocalShareBackwardDataImpl::exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void LocalShareBackwardDataImpl::exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); auto algo = get_algorithm(this, filter.layout, diff.layout, grad.layout); algo->check_workspace(args, workspace).exec(args); @@ -144,12 +138,12 @@ const char* 
LocalShareBackwardDataImpl::get_algorithm_set_name() const { } /* ============== LocalShareBackwardFilterImpl ============== */ -LocalShareBackwardFilterImpl::Algorithm* -LocalShareBackwardFilterImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { +LocalShareBackwardFilterImpl::Algorithm* LocalShareBackwardFilterImpl:: + get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { AlgoBase::SizeArgs args(this, src, diff, grad); if (sm_algo_pack.implicit_gemm.is_available_attribute( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { @@ -159,42 +153,40 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic( args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return &sm_algo_pack.batched_matmul; } - megdnn_throw( - ssprintf("no local share bwd filter algorithm without " - "attribute(%s) with attribute(%s), " - "args(%s) and " - "workspace limit (%zu bytes)", - Algorithm::attribute_str(negative_attr).c_str(), - Algorithm::attribute_str(positive_attr).c_str(), - args.to_string().c_str(), workspace_limit_in_bytes)); -} - -std::vector -LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) { + megdnn_throw(ssprintf( + "no local share bwd filter algorithm without " + "attribute(%s) with attribute(%s), " + "args(%s) and " + "workspace limit (%zu bytes)", + Algorithm::attribute_str(negative_attr).c_str(), + Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), + workspace_limit_in_bytes)); +} + +std::vector LocalShareBackwardFilterImpl:: + get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) { AlgoBase::SizeArgs args{this, src, diff, grad}; return megdnn::get_all_algorithms(args); } -std::vector -LocalShareBackwardFilterImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector LocalShareBackwardFilterImpl:: + get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) { AlgoBase::SizeArgs args{this, src, diff, grad}; return megdnn::get_all_algorithms_safe(args); } -size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) { +size_t LocalShareBackwardFilterImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) { return get_dnn_workspace(this, src, diff, grad); } -void LocalShareBackwardFilterImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) { +void LocalShareBackwardFilterImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) { AlgoBase::ExecArgs args(this, src, diff, grad, workspace); auto algo = get_algorithm(this, src.layout, diff.layout, grad.layout); algo->check_workspace(args, workspace).exec(args); diff --git a/dnn/src/cuda/local_share/opr_impl.h b/dnn/src/cuda/local_share/opr_impl.h index a261bed4..d532f29a 100644 --- a/dnn/src/cuda/local_share/opr_impl.h +++ b/dnn/src/cuda/local_share/opr_impl.h @@ -18,11 +18,12 @@ namespace 
cuda { class LocalShareForwardImpl : public LocalShareForward { public: using LocalShareForward::LocalShareForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -55,11 +56,12 @@ private: class LocalShareBackwardDataImpl : public LocalShareBackwardData { public: using LocalShareBackwardData::LocalShareBackwardData; - void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad) override; + void exec( + _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -91,11 +93,12 @@ private: class LocalShareBackwardFilterImpl : public LocalShareBackwardFilter { public: using LocalShareBackwardFilter::LocalShareBackwardFilter; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; const char* get_algorithm_set_name() const override; class AlgoBase; @@ -115,9 +118,8 @@ protected: const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& diff, - const TensorLayout& grad, size_t workspace_limit_in_bytes, - const AlgoAttribute& positive_attr, + const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) override; private: diff --git a/dnn/src/cuda/lrn/opr_impl.cpp b/dnn/src/cuda/lrn/opr_impl.cpp index 50afa1d1..ffc37c23 100644 --- a/dnn/src/cuda/lrn/opr_impl.cpp +++ b/dnn/src/cuda/lrn/opr_impl.cpp @@ -15,34 +15,26 @@ namespace megdnn { namespace cuda { -void LRNForwardImpl::setup_descs(const TensorLayout &src, - const TensorLayout &dst) -{ +void LRNForwardImpl::setup_descs(const TensorLayout& src, const TensorLayout& dst) { src_desc.set(src); dst_desc.set(dst); lrn_desc.set(this->param()); } -void LRNForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ +void LRNForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); auto handle = cudnn_handle(this->handle()); setup_descs(src.layout, dst.layout); float alpha = 1.0f, beta = 0.0f; - 
cudnn_check(cudnnLRNCrossChannelForward(handle, - lrn_desc.desc, - CUDNN_LRN_CROSS_CHANNEL_DIM1, - &alpha, src_desc.desc, src.raw_ptr, - &beta, dst_desc.desc, dst.raw_ptr)); + cudnn_check(cudnnLRNCrossChannelForward( + handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, src_desc.desc, + src.raw_ptr, &beta, dst_desc.desc, dst.raw_ptr)); } -void LRNBackwardImpl::setup_descs(const TensorLayout &src, - const TensorLayout &dst, - const TensorLayout &diff, - const TensorLayout &grad) -{ +void LRNBackwardImpl::setup_descs( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) { src_desc.set(src); dst_desc.set(dst); diff_desc.set(diff); @@ -50,30 +42,20 @@ void LRNBackwardImpl::setup_descs(const TensorLayout &src, lrn_desc.set(this->param()); } -void LRNBackwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in dst, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ - check_exec(src.layout, dst.layout, diff.layout, grad.layout, - workspace.size); +void LRNBackwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) { + check_exec(src.layout, dst.layout, diff.layout, grad.layout, workspace.size); auto handle = cudnn_handle(this->handle()); setup_descs(src.layout, dst.layout, diff.layout, grad.layout); float alpha = 1.0f, beta = 0.0f; - cudnn_check(cudnnLRNCrossChannelBackward(handle, - lrn_desc.desc, - CUDNN_LRN_CROSS_CHANNEL_DIM1, - &alpha, - dst_desc.desc, dst.raw_ptr, - diff_desc.desc, diff.raw_ptr, - src_desc.desc, src.raw_ptr, - &beta, - grad_desc.desc, grad.raw_ptr)); + cudnn_check(cudnnLRNCrossChannelBackward( + handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dst_desc.desc, + dst.raw_ptr, diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, + &beta, grad_desc.desc, grad.raw_ptr)); } -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/lrn/opr_impl.h b/dnn/src/cuda/lrn/opr_impl.h index 1a91f2ef..aea33121 100644 --- a/dnn/src/cuda/lrn/opr_impl.h +++ b/dnn/src/cuda/lrn/opr_impl.h @@ -16,46 +16,43 @@ namespace megdnn { namespace cuda { -class LRNForwardImpl final: public LRNForward { - public: - using LRNForward::LRNForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &) override { - return 0; - } - private: - TensorDesc src_desc, dst_desc; - LRNDesc lrn_desc; - void setup_descs(const TensorLayout &src, const TensorLayout &dst); +class LRNForwardImpl final : public LRNForward { +public: + using LRNForward::LRNForward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { + return 0; + } + +private: + TensorDesc src_desc, dst_desc; + LRNDesc lrn_desc; + void setup_descs(const TensorLayout& src, const TensorLayout& dst); }; -class LRNBackwardImpl final: public LRNBackward { - public: - using LRNBackward::LRNBackward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in dst, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout &, - const TensorLayout &, - const TensorLayout &, - const TensorLayout &) override { - return 0; - } - private: - TensorDesc 
src_desc, dst_desc, diff_desc, grad_desc; - LRNDesc lrn_desc; - void setup_descs(const TensorLayout &src, - const TensorLayout &dst, - const TensorLayout &diff, - const TensorLayout &grad); +class LRNBackwardImpl final : public LRNBackward { +public: + using LRNBackward::LRNBackward; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } + +private: + TensorDesc src_desc, dst_desc, diff_desc, grad_desc; + LRNDesc lrn_desc; + void setup_descs( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad); }; -} // namespace cuda -} // namespace megdnn +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/lsq/kern.cu b/dnn/src/cuda/lsq/kern.cu index 74950fc7..6bdcdc9f 100644 --- a/dnn/src/cuda/lsq/kern.cu +++ b/dnn/src/cuda/lsq/kern.cu @@ -15,15 +15,17 @@ namespace megdnn { namespace cuda { -#define cb(_dtype) \ - INST_RUN_ELEMWISE(LSQKernOp::ctype>, \ - DTypeTrait<_dtype>::ctype, 3); \ - INST_RUN_ELEMWISE(LSQBwdKernOp::ctype>, \ - DTypeTrait<_dtype>::ctype, 3); \ - INST_RUN_ELEMWISE(LSQKernOpNonContig::ctype>, \ - DTypeTrait<_dtype>::ctype, 5); \ - INST_RUN_ELEMWISE(LSQBwdKernOpNonContig::ctype>, \ - DTypeTrait<_dtype>::ctype, 7); +#define cb(_dtype) \ + INST_RUN_ELEMWISE( \ + LSQKernOp::ctype>, DTypeTrait<_dtype>::ctype, 3); \ + INST_RUN_ELEMWISE( \ + LSQBwdKernOp::ctype>, DTypeTrait<_dtype>::ctype, 3); \ + INST_RUN_ELEMWISE( \ + LSQKernOpNonContig::ctype>, DTypeTrait<_dtype>::ctype, \ + 5); \ + INST_RUN_ELEMWISE( \ + LSQBwdKernOpNonContig::ctype>, \ + DTypeTrait<_dtype>::ctype, 7); cb(megdnn::dtype::Float32) } // namespace cuda diff --git a/dnn/src/cuda/lsq/kern.cuh b/dnn/src/cuda/lsq/kern.cuh index 6bed31be..70db5ccd 100644 --- a/dnn/src/cuda/lsq/kern.cuh +++ b/dnn/src/cuda/lsq/kern.cuh @@ -27,8 +27,8 @@ struct LSQKernOp { ctype* output; ctype qmin, qmax; - __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, - ctype grad_scale) { + __device__ void operator()( + uint32_t idx, ctype scale, ctype zero_point, ctype grad_scale) { ctype x = input[idx] / scale + zero_point; x = fmaxf(fminf(x, qmax), qmin); x = round(x); @@ -36,8 +36,7 @@ struct LSQKernOp { } #if MEGDNN_CC_HOST - LSQKernOp(const TensorND& input, const TensorND& output, - const LSQ::Param& param) + LSQKernOp(const TensorND& input, const TensorND& output, const LSQ::Param& param) : input{input.ptr()}, output{output.ptr()}, qmin(param.qmin), @@ -53,23 +52,22 @@ struct LSQBwdKernOp { ctype* grad_s; ctype qmin, qmax; - __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, - ctype grad_scale) { + __device__ void operator()( + uint32_t idx, ctype scale, ctype zero_point, ctype grad_scale) { ctype x = input[idx] / scale + zero_point; bool ind_small = x < qmin; bool ind_big = x > qmax; bool ind_middle = ind_small ^ ind_big; ind_middle = !ind_middle; - grad_s[idx] = ind_small * qmin + ind_big * qmax + - ind_middle * (-x + round(x)); + grad_s[idx] = ind_small * qmin + ind_big * qmax + ind_middle * (-x + round(x)); grad_s[idx] = grad_s[idx] * grad_scale * diff[idx]; grad_x[idx] = ind_middle * diff[idx]; } #if MEGDNN_CC_HOST - LSQBwdKernOp(const TensorND& diff, const TensorND& input, - const TensorND& grad_x, const TensorND& grad_s, - const LSQ::Param& param) + LSQBwdKernOp( + 
const TensorND& diff, const TensorND& input, const TensorND& grad_x, + const TensorND& grad_s, const LSQ::Param& param) : diff{diff.ptr()}, input{input.ptr()}, grad_x{grad_x.ptr()}, @@ -84,17 +82,16 @@ struct LSQKernOpNonContig { ctype qmin; ctype qmax; - __device__ void operator()(uint32_t, ctype& output, ctype& input, - ctype& scale, ctype& zero_point, - ctype grad_scale) { + __device__ void operator()( + uint32_t, ctype& output, ctype& input, ctype& scale, ctype& zero_point, + ctype grad_scale) { ctype x = input / scale + zero_point; x = fmaxf(fminf(x, qmax), qmin); x = round(x); output = (x - zero_point) * scale; } #if MEGDNN_CC_HOST - LSQKernOpNonContig(const LSQ::Param& param) - : qmin(param.qmin), qmax(param.qmax) {} + LSQKernOpNonContig(const LSQ::Param& param) : qmin(param.qmin), qmax(param.qmax) {} #endif }; @@ -103,16 +100,15 @@ struct LSQBwdKernOpNonContig { ctype qmin; ctype qmax; - __device__ void operator()(uint32_t, ctype& grad_x, ctype& grad_s, - ctype& diff, ctype& input, ctype& scale, - ctype& zero_point, ctype grad_scale) { + __device__ void operator()( + uint32_t, ctype& grad_x, ctype& grad_s, ctype& diff, ctype& input, + ctype& scale, ctype& zero_point, ctype grad_scale) { ctype x = input / scale + zero_point; bool ind_small = x < qmin; bool ind_big = x > qmax; bool ind_middle = ind_small ^ ind_big; ind_middle = !ind_middle; - grad_s = ind_small * qmin + ind_big * qmax + - ind_middle * (-x + round(x)); + grad_s = ind_small * qmin + ind_big * qmax + ind_middle * (-x + round(x)); grad_s = grad_s * grad_scale * diff; grad_x = ind_middle * diff; } diff --git a/dnn/src/cuda/lsq/opr_impl.cpp b/dnn/src/cuda/lsq/opr_impl.cpp index d4338c09..d0a87205 100644 --- a/dnn/src/cuda/lsq/opr_impl.cpp +++ b/dnn/src/cuda/lsq/opr_impl.cpp @@ -16,13 +16,13 @@ namespace megdnn { namespace cuda { -void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, - _megdnn_tensor_out output, - _megdnn_workspace workspace) { - check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout, - output.layout, workspace.size); +void LSQForwardImpl::exec( + _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, + _megdnn_tensor_in grad_scale, _megdnn_tensor_out output, + _megdnn_workspace workspace) { + check_exec( + input.layout, scale.layout, zero_point.layout, grad_scale.layout, + output.layout, workspace.size); if (!input.layout.is_contiguous() || !output.layout.is_contiguous()) return exec_noncontig(input, scale, zero_point, grad_scale, output); @@ -38,22 +38,19 @@ void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, auto m_param = param(); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (input.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - run_elemwise, T, 3>(ele_param, stream, \ - {input, output, m_param}); \ - return; \ +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 3>(ele_param, stream, {input, output, m_param}); \ + return; \ } cb(megdnn::dtype::Float32) #undef cb } -void LSQForwardImpl::exec_noncontig(_megdnn_tensor_in input, - _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, - _megdnn_tensor_out output) { +void LSQForwardImpl::exec_noncontig( + _megdnn_tensor_in input, _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, + _megdnn_tensor_in grad_scale, _megdnn_tensor_out output) { 
ElemwiseOpParamN<5> ele_param; ele_param[0] = output; ele_param[1] = input; @@ -67,30 +64,29 @@ void LSQForwardImpl::exec_noncontig(_megdnn_tensor_in input, auto m_param = param(); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (input.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - run_elemwise, T, 5>(ele_param, stream, \ - {m_param}); \ - return; \ +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 5>(ele_param, stream, {m_param}); \ + return; \ } cb(megdnn::dtype::Float32) #undef cb } -void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, - _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, - _megdnn_workspace workspace) { - check_exec(diff.layout, input.layout, scale.layout, zero_point.layout, - grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size); +void LSQBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, + _megdnn_workspace workspace) { + check_exec( + diff.layout, input.layout, scale.layout, zero_point.layout, + grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size); if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() || !grad_x.layout.is_contiguous() || !grad_s.layout.is_contiguous()) - return exec_noncontig(diff, input, scale, zero_point, grad_scale, - grad_x, grad_s); + return exec_noncontig( + diff, input, scale, zero_point, grad_scale, grad_x, grad_s); ElemwiseOpParamN<3> ele_param; ele_param[0] = scale; @@ -114,13 +110,10 @@ void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, #undef cb } -void LSQBackwardImpl::exec_noncontig(_megdnn_tensor_in diff, - _megdnn_tensor_in input, - _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, - _megdnn_tensor_out grad_x, - _megdnn_tensor_out grad_s) { +void LSQBackwardImpl::exec_noncontig( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s) { ElemwiseOpParamN<7> ele_param; ele_param[0] = grad_x; ele_param[1] = grad_s; @@ -136,12 +129,11 @@ void LSQBackwardImpl::exec_noncontig(_megdnn_tensor_in diff, auto m_param = param(); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (input.layout.dtype == DType()) { \ - using T = typename DTypeTrait::ctype; \ - run_elemwise, T, 7>(ele_param, stream, \ - {m_param}); \ - return; \ +#define cb(DType) \ + if (input.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + run_elemwise, T, 7>(ele_param, stream, {m_param}); \ + return; \ } cb(megdnn::dtype::Float32) #undef cb diff --git a/dnn/src/cuda/lsq/opr_impl.h b/dnn/src/cuda/lsq/opr_impl.h index aba0caf4..e066a566 100644 --- a/dnn/src/cuda/lsq/opr_impl.h +++ b/dnn/src/cuda/lsq/opr_impl.h @@ -19,46 +19,47 @@ namespace cuda { class LSQForwardImpl final : public LSQForward { public: using LSQForward::LSQForward; - void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, - _megdnn_tensor_out output, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, /* input */ - const TensorLayout&, /* scale */ - const 
TensorLayout&, /* zero_point */ - const TensorLayout&, /* grad_scale */ - const TensorLayout& /* output */) override { + void exec( + _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out output, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, /* input */ + const TensorLayout&, /* scale */ + const TensorLayout&, /* zero_point */ + const TensorLayout&, /* grad_scale */ + const TensorLayout& /* output */) override { return 0; } private: - void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale, - _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, - _megdnn_tensor_out output); + void exec_noncontig( + _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out output); }; class LSQBackwardImpl final : public LSQBackward { public: using LSQBackward::LSQBackward; - void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, - _megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& /* diff */, - const TensorLayout& /* input */, - const TensorLayout& /* scale */, - const TensorLayout& /* zero_point */, - const TensorLayout& /* grad_scale */, - const TensorLayout& /* grad_x */, - const TensorLayout& /* grad_s */) override { + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& /* diff */, const TensorLayout& /* input */, + const TensorLayout& /* scale */, const TensorLayout& /* zero_point */, + const TensorLayout& /* grad_scale */, const TensorLayout& /* grad_x */, + const TensorLayout& /* grad_s */) override { return 0; } private: - void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input, - _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, - _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, - _megdnn_tensor_out grad_s); + void exec_noncontig( + _megdnn_tensor_in diff, _megdnn_tensor_in input, _megdnn_tensor_in scale, + _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, + _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s); }; } // namespace cuda diff --git a/dnn/src/cuda/mask_conv/mask_conv.cu b/dnn/src/cuda/mask_conv/mask_conv.cu index c89638d8..12cb05f6 100644 --- a/dnn/src/cuda/mask_conv/mask_conv.cu +++ b/dnn/src/cuda/mask_conv/mask_conv.cu @@ -16,8 +16,8 @@ namespace { template -__global__ void set_zero_by_mask_kernel(float* dst, const ctype* mask, size_t N, - size_t mask_size) { +__global__ void set_zero_by_mask_kernel( + float* dst, const ctype* mask, size_t N, size_t mask_size) { int dst_offset = blockIdx.x * blockDim.x + threadIdx.x; int mask_idx = blockIdx.y * blockDim.y + threadIdx.y; if (dst_offset >= N || mask_idx >= mask_size) { @@ -29,11 +29,10 @@ __global__ void set_zero_by_mask_kernel(float* dst, const ctype* mask, size_t N, } template -__global__ void mask_propagate_kernel(const ctype* src, ctype* dst, size_t IH, - size_t IW, size_t OH, size_t OW, - size_t FH, size_t FW, size_t SH, - size_t SW, size_t PH, size_t PW, - size_t DH, size_t DW) { +__global__ void mask_propagate_kernel( + const ctype* src, ctype* 
dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t FH, size_t FW, size_t SH, size_t SW, size_t PH, size_t PW, size_t DH, + size_t DW) { int dst_idx = blockIdx.x * blockDim.x + threadIdx.x; if (dst_idx >= OH * OW) { return; @@ -45,8 +44,7 @@ __global__ void mask_propagate_kernel(const ctype* src, ctype* dst, size_t IH, for (int fw = 0; fw < FW; ++fw) { int ih = oh * SH + fh * DH - PH; int iw = ow * SW + fw * DW - PW; - if (ih < 0 || ih >= IH || iw < 0 || iw >= IW || - src[ih * IW + iw] == 0) { + if (ih < 0 || ih >= IH || iw < 0 || iw >= IW || src[ih * IW + iw] == 0) { continue; } dst[dst_idx] = 1; @@ -62,8 +60,9 @@ namespace cuda { namespace mask_conv { template -void set_zero_by_mask_proxy(float* dst, const ctype* mask, size_t N, size_t OC, - size_t OH, size_t OW, cudaStream_t stream) { +void set_zero_by_mask_proxy( + float* dst, const ctype* mask, size_t N, size_t OC, size_t OH, size_t OW, + cudaStream_t stream) { dim3 threads(NR_THREADS_X, NR_THREADS_Y); dim3 blocks(DIVUP(N * OC, threads.x), DIVUP(OH * OW, threads.y)); set_zero_by_mask_kernel @@ -71,25 +70,23 @@ void set_zero_by_mask_proxy(float* dst, const ctype* mask, size_t N, size_t OC, } template -void mask_propagate_exec_proxy(const ctype* src, ctype* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t FH, - size_t FW, size_t SH, size_t SW, size_t PH, - size_t PW, size_t DH, size_t DW, - cudaStream_t stream) { - mask_propagate_kernel - <<>>( - src, dst, IH, IW, OH, OW, FH, FW, SH, SW, PH, PW, DH, DW); +void mask_propagate_exec_proxy( + const ctype* src, ctype* dst, size_t IH, size_t IW, size_t OH, size_t OW, + size_t FH, size_t FW, size_t SH, size_t SW, size_t PH, size_t PW, size_t DH, + size_t DW, cudaStream_t stream) { + mask_propagate_kernel<<>>( + src, dst, IH, IW, OH, OW, FH, FW, SH, SW, PH, PW, DH, DW); } -#define INST(ctype) \ - template void mask_propagate_exec_proxy( \ - const ctype* src, ctype* dst, size_t IH, size_t IW, size_t OH, \ - size_t OW, size_t FH, size_t FW, size_t SH, size_t SW, size_t PH, \ - size_t PW, size_t DH, size_t DW, cudaStream_t stream); \ - \ - template void set_zero_by_mask_proxy( \ - float* dst, const ctype* mask, size_t N, size_t OC, size_t OH, \ - size_t OW, cudaStream_t stream); +#define INST(ctype) \ + template void mask_propagate_exec_proxy( \ + const ctype* src, ctype* dst, size_t IH, size_t IW, size_t OH, size_t OW, \ + size_t FH, size_t FW, size_t SH, size_t SW, size_t PH, size_t PW, \ + size_t DH, size_t DW, cudaStream_t stream); \ + \ + template void set_zero_by_mask_proxy( \ + float* dst, const ctype* mask, size_t N, size_t OC, size_t OH, size_t OW, \ + cudaStream_t stream); #define cb(DType) INST(DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) diff --git a/dnn/src/cuda/mask_conv/mask_conv.cuh b/dnn/src/cuda/mask_conv/mask_conv.cuh index 31ffd6a7..25e7d6a3 100644 --- a/dnn/src/cuda/mask_conv/mask_conv.cuh +++ b/dnn/src/cuda/mask_conv/mask_conv.cuh @@ -14,15 +14,15 @@ namespace cuda { namespace mask_conv { template -void set_zero_by_mask_proxy(float* dst, const ctype* mask, size_t N, size_t OC, - size_t OH, size_t OW, cudaStream_t stream); +void set_zero_by_mask_proxy( + float* dst, const ctype* mask, size_t N, size_t OC, size_t OH, size_t OW, + cudaStream_t stream); template -void mask_propagate_exec_proxy(const ctype* src, ctype* dst, size_t IH, - size_t IW, size_t OH, size_t OW, size_t FH, - size_t FW, size_t SH, size_t SW, size_t PH, - size_t PW, size_t DH, size_t DW, - cudaStream_t stream); +void mask_propagate_exec_proxy( + const ctype* src, ctype* dst, size_t 
IH, size_t IW, size_t OH, size_t OW, + size_t FH, size_t FW, size_t SH, size_t SW, size_t PH, size_t PW, size_t DH, + size_t DW, cudaStream_t stream); } // namespace mask_conv diff --git a/dnn/src/cuda/mask_conv/opr_impl.cpp b/dnn/src/cuda/mask_conv/opr_impl.cpp index e3b29597..24a88464 100644 --- a/dnn/src/cuda/mask_conv/opr_impl.cpp +++ b/dnn/src/cuda/mask_conv/opr_impl.cpp @@ -16,26 +16,26 @@ namespace megdnn { namespace cuda { -MaskConvForwardImpl::MaskConvForwardImpl(Handle* handle) - : MaskConvForward(handle) { - m_conv_opr = static_cast(handle) - ->create_operator(); +MaskConvForwardImpl::MaskConvForwardImpl(Handle* handle) : MaskConvForward(handle) { + m_conv_opr = + static_cast(handle)->create_operator(); } -void MaskConvForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in mask, _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - megdnn_assert(dst.layout.dtype.enumv() == DTypeTrait::enumv, - "Mask conv only support Float32 dtype."); +void MaskConvForwardImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in mask, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + megdnn_assert( + dst.layout.dtype.enumv() == DTypeTrait::enumv, + "Mask conv only support Float32 dtype."); m_conv_opr->exec(src, filter, dst, nullptr, workspace); auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (mask.layout.dtype == DType()) { \ - using ctype = typename DTypeTrait::ctype; \ - mask_conv::set_zero_by_mask_proxy( \ - dst.ptr(), mask.ptr(), dst.layout[0], \ - dst.layout[1], dst.layout[2], dst.layout[3], stream); \ - return; \ +#define cb(DType) \ + if (mask.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + mask_conv::set_zero_by_mask_proxy( \ + dst.ptr(), mask.ptr(), dst.layout[0], dst.layout[1], \ + dst.layout[2], dst.layout[3], stream); \ + return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) @@ -43,20 +43,19 @@ void MaskConvForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, megdnn_assert_internal(0); } -void MaskPropagateImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace) { +void MaskPropagateImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace) { auto stream = cuda_stream(handle()); -#define cb(DType) \ - if (src.layout.dtype == DType()) { \ - using ctype = typename DTypeTrait::ctype; \ - mask_conv::mask_propagate_exec_proxy( \ - src.ptr(), dst.ptr(), src.layout[0], \ - src.layout[1], dst.layout[0], dst.layout[1], param().kernel_h, \ - param().kernel_w, param().stride_h, param().stride_w, \ - param().pad_h, param().pad_w, param().dilate_h, \ - param().dilate_w, stream); \ - return; \ +#define cb(DType) \ + if (src.layout.dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + mask_conv::mask_propagate_exec_proxy( \ + src.ptr(), dst.ptr(), src.layout[0], src.layout[1], \ + dst.layout[0], dst.layout[1], param().kernel_h, param().kernel_w, \ + param().stride_h, param().stride_w, param().pad_h, param().pad_w, \ + param().dilate_h, param().dilate_w, stream); \ + return; \ } MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb); diff --git a/dnn/src/cuda/mask_conv/opr_impl.h b/dnn/src/cuda/mask_conv/opr_impl.h index 03e18059..7fb6ad99 100644 --- a/dnn/src/cuda/mask_conv/opr_impl.h +++ b/dnn/src/cuda/mask_conv/opr_impl.h @@ -20,14 +20,13 @@ class MaskConvForwardImpl : public MaskConvForward { public: MaskConvForwardImpl(Handle* handle); - void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, - _megdnn_tensor_in mask, 
_megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& mask, - const TensorLayout& dst) override { + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in mask, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& mask, const TensorLayout& dst) override { MEGDNN_MARK_USED_VAR(mask); m_conv_opr->param() = param(); return m_conv_opr->get_workspace_in_bytes(src, filter, dst, nullptr); @@ -41,10 +40,11 @@ class MaskPropagateImpl : public MaskPropagate { public: MaskPropagateImpl(Handle* handle) : MaskPropagate(handle) {} - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace worksapce) override final; - size_t get_workspace_in_bytes(const TensorLayout&, - const TensorLayout&) override final { + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace worksapce) override final; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&) override final { return 0; } }; diff --git a/dnn/src/cuda/matrix_inverse/helper.cu b/dnn/src/cuda/matrix_inverse/helper.cu index ea184778..270a0f68 100644 --- a/dnn/src/cuda/matrix_inverse/helper.cu +++ b/dnn/src/cuda/matrix_inverse/helper.cu @@ -18,30 +18,30 @@ using namespace matrix_inverse; namespace { -__global__ void kern_check_error(const int* src_info, uint32_t n, - megcore::AsyncErrorInfo* dst_info, - void* tracker) { +__global__ void kern_check_error( + const int* src_info, uint32_t n, megcore::AsyncErrorInfo* dst_info, + void* tracker) { uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n && src_info[i]) { - set_async_error_info(dst_info, tracker, - "The U is exactly singular and the inversion " - "failed on %d-th input matrix (U(%d, %d) = 0)", i, - src_info[i], src_info[i]); + set_async_error_info( + dst_info, tracker, + "The U is exactly singular and the inversion " + "failed on %d-th input matrix (U(%d, %d) = 0)", + i, src_info[i], src_info[i]); } } } // anonymous namespace -void matrix_inverse::check_error(const int* src_info, uint32_t n, - megcore::AsyncErrorInfo* dst_info, - void* tracker, cudaStream_t stream) { +void matrix_inverse::check_error( + const int* src_info, uint32_t n, megcore::AsyncErrorInfo* dst_info, + void* tracker, cudaStream_t stream) { if (!dst_info) { return; } uint32_t threads = NR_THREADS; uint32_t blocks = DIVUP(n, threads); - kern_check_error<<>>(src_info, n, dst_info, - tracker); + kern_check_error<<>>(src_info, n, dst_info, tracker); after_kernel_launch(); } diff --git a/dnn/src/cuda/matrix_inverse/helper.cuh b/dnn/src/cuda/matrix_inverse/helper.cuh index 61c6ff57..a969b465 100644 --- a/dnn/src/cuda/matrix_inverse/helper.cuh +++ b/dnn/src/cuda/matrix_inverse/helper.cuh @@ -17,9 +17,9 @@ namespace megdnn { namespace cuda { namespace matrix_inverse { -void check_error(const int* src_info, uint32_t n, - megcore::AsyncErrorInfo* dst_info, void* tracker, - cudaStream_t stream); +void check_error( + const int* src_info, uint32_t n, megcore::AsyncErrorInfo* dst_info, + void* tracker, cudaStream_t stream); } // namespace matrix_inverse } // namespace cuda diff --git a/dnn/src/cuda/matrix_inverse/opr_impl.cpp b/dnn/src/cuda/matrix_inverse/opr_impl.cpp index ef96a45c..fe775708 100644 --- a/dnn/src/cuda/matrix_inverse/opr_impl.cpp +++ b/dnn/src/cuda/matrix_inverse/opr_impl.cpp @@ -8,8 +8,8 @@ 
* software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "./helper.cuh" #include "./opr_impl.h" +#include "./helper.cuh" #include "src/cuda/batched_matrix_mul/helper.cuh" #include "src/cuda/handle.h" #include "src/cuda/utils.h" @@ -21,11 +21,12 @@ size_t MatrixInverseImpl::get_workspace_in_bytes(size_t batch, size_t, size_t) { return batch * (sizeof(int) + sizeof(void*) + sizeof(void*)); } -void MatrixInverseImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - megdnn_assert(src.layout.dtype == dtype::Float32(), - "Matrix Inverse only support Float32 dtype, got: %s", - src.layout.dtype.name()); +void MatrixInverseImpl::exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { + megdnn_assert( + src.layout.dtype == dtype::Float32(), + "Matrix Inverse only support Float32 dtype, got: %s", + src.layout.dtype.name()); size_t batch, n; check_exec(src.layout, dst.layout, workspace, &batch, &n); auto handle = concrete_handle(this->handle()); @@ -36,17 +37,16 @@ void MatrixInverseImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, auto stream = handle->stream(); batched_matrix_mul::arange( reinterpret_cast(psrc_batch), - reinterpret_cast(src.raw_ptr), n * n * sizeof(float), - batch, stream); + reinterpret_cast(src.raw_ptr), n * n * sizeof(float), batch, + stream); batched_matrix_mul::arange( reinterpret_cast(pdst_batch), - reinterpret_cast(dst.raw_ptr), n * n * sizeof(float), - batch, stream); - cublas_check(cublasSmatinvBatched(handle->cublas_handle(), n, psrc_batch, n, - pdst_batch, n, info, batch)); - matrix_inverse::check_error(info, batch, - handle->megcore_context().error_info, - m_error_tracker, stream); + reinterpret_cast(dst.raw_ptr), n * n * sizeof(float), batch, + stream); + cublas_check(cublasSmatinvBatched( + handle->cublas_handle(), n, psrc_batch, n, pdst_batch, n, info, batch)); + matrix_inverse::check_error( + info, batch, handle->megcore_context().error_info, m_error_tracker, stream); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/matrix_inverse/opr_impl.h b/dnn/src/cuda/matrix_inverse/opr_impl.h index 40bb96f4..c6fac603 100644 --- a/dnn/src/cuda/matrix_inverse/opr_impl.h +++ b/dnn/src/cuda/matrix_inverse/opr_impl.h @@ -17,17 +17,15 @@ namespace cuda { class MatrixInverseImpl : public MatrixInverse { public: using MatrixInverse::MatrixInverse; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } protected: void* m_error_tracker = nullptr; - size_t get_workspace_in_bytes(size_t batch, size_t n, - size_t dtype_size) override; + size_t get_workspace_in_bytes(size_t batch, size_t n, size_t dtype_size) override; }; } // namespace cuda diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp index 93d2450d..1aaf8609 100644 --- a/dnn/src/cuda/matrix_mul/algos.cpp +++ b/dnn/src/cuda/matrix_mul/algos.cpp @@ -135,17 +135,14 @@ MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; MEGDNN_DEF_GET_ALGO_FROM_DESC(MatrixMulForwardImpl) -MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs(MatrixMulForwardImpl* o, - const TensorLayout& A, - const TensorLayout& B, - 
const TensorLayout& C) +MatrixMulForwardImpl::AlgoBase::SizeArgs::SizeArgs( + MatrixMulForwardImpl* o, const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) : opr{o}, layout_a{A}, layout_b{B}, layout_c{C} {} -MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs(MatrixMulForwardImpl* opr, - _megdnn_tensor_in A, - _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace) +MatrixMulForwardImpl::AlgoBase::ExecArgs::ExecArgs( + MatrixMulForwardImpl* opr, _megdnn_tensor_in A, _megdnn_tensor_in B, + _megdnn_tensor_out C, _megdnn_workspace workspace) : SizeArgs(opr, A.layout, B.layout, C.layout), tensor_a{A}, tensor_b{B}, @@ -162,8 +159,8 @@ std::string MatrixMulForwardImpl::AlgoBase::SizeArgs::to_string() const { return ssprintf( "A={%zux%zu},B={%zux%zu},C={%zux%zu},Transpose A=%d,Transpose " "B=%d,ldA=%zu,ldB=%zu,ldC=%zu", - m, k, k, n, m, n, param.transposeA, param.transposeB, - layout_a.stride[0], layout_b.stride[0], layout_c.stride[0]); + m, k, k, n, m, n, param.transposeA, param.transposeB, layout_a.stride[0], + layout_b.stride[0], layout_c.stride[0]); } // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h index 193d08f5..7b76b3a9 100644 --- a/dnn/src/cuda/matrix_mul/algos.h +++ b/dnn/src/cuda/matrix_mul/algos.h @@ -59,8 +59,9 @@ public: TensorLayout layout_a, layout_b, layout_c; std::string to_string() const; - SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, - const TensorLayout& B, const TensorLayout& C); + SizeArgs( + MatrixMulForwardImpl* opr, const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C); bool can_be_treated_as_int8x8x32() const { return layout_a.dtype.enumv() == layout_b.dtype.enumv() && @@ -75,9 +76,9 @@ public: TensorND tensor_a, tensor_b, tensor_c; Workspace workspace; - ExecArgs(MatrixMulForwardImpl* opr, _megdnn_tensor_in A, - _megdnn_tensor_in B, _megdnn_tensor_out C, - _megdnn_workspace workspace); + ExecArgs( + MatrixMulForwardImpl* opr, _megdnn_tensor_in A, _megdnn_tensor_in B, + _megdnn_tensor_out C, _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; @@ -92,16 +93,14 @@ public: const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT, size_t limit = std::numeric_limits::max()) const { return contain_attribute_all(positive_attr) && - !contain_attribute_any(negative_attr) && - is_available_wk(args, limit); + !contain_attribute_any(negative_attr) && is_available_wk(args, limit); } - AlgoBase& check_workspace(const SizeArgs& args, - const Workspace& workspace) { + AlgoBase& check_workspace(const SizeArgs& args, const Workspace& workspace) { auto req = get_workspace_in_bytes(args); megdnn_assert( req <= workspace.size, - "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", - name(), req, workspace.size); + "matrix mul fwd algo %s: required workspace %zu bytes, got %zu", name(), + req, workspace.size); return *this; } }; @@ -117,8 +116,7 @@ public: void exec(const ExecArgs& args) const override; MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE | + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } }; @@ -132,9 +130,7 @@ public: const char* name() const override { return "UINT4x4x32_WMMA"; } void exec(const ExecArgs& args) const override; 
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } }; #endif #if CUDA_VERSION >= 10010 @@ -146,8 +142,7 @@ public: void exec(const ExecArgs& args) const override; MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; } }; #endif @@ -176,14 +171,11 @@ public: MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) std::vector get_subopr_list( - const TensorLayoutArray& layouts, - const OperatorBase* opr) const override; + const TensorLayoutArray& layouts, const OperatorBase* opr) const override; const char* name() const override { return "MATMUL_BFLOAT16"; } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } private: WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; @@ -229,9 +221,10 @@ public: int threadblock_m, threadblock_n, threadblock_k; int warp_m, warp_n, warp_k; int instruction_m, instruction_n, instruction_k; - AlgoParam(int threadblock_m_, int threadblock_n_, int threadblock_k_, - int warp_m_, int warp_n_, int warp_k_, int instruction_m_ = 1, - int instruction_n_ = 1, int instruction_k_ = 1) + AlgoParam( + int threadblock_m_, int threadblock_n_, int threadblock_k_, int warp_m_, + int warp_n_, int warp_k_, int instruction_m_ = 1, + int instruction_n_ = 1, int instruction_k_ = 1) : threadblock_m{threadblock_m_}, threadblock_n{threadblock_n_}, threadblock_k{threadblock_k_}, @@ -260,20 +253,17 @@ protected: AlgoParam m_algo_param; }; -class MatrixMulForwardImpl::AlgoFloat32SIMT final - : public AlgoCutlassMatrixMulBase { +class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoCutlassMatrixMulBase { public: AlgoFloat32SIMT(AlgoParam algo_param) : AlgoCutlassMatrixMulBase{algo_param}, - m_name{ssprintf("CUTLASS_FLOAT32_SIMT_%s", - m_algo_param.to_string().c_str())} {} + m_name{ssprintf( + "CUTLASS_FLOAT32_SIMT_%s", m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT) std::string param() const override { std::string ret; @@ -283,10 +273,9 @@ public: int threadblock_m, threadblock_n, threadblock_k; int warp_m, warp_n, warp_k; }; - AlgoParam_ algo_param{ - m_algo_param.threadblock_m, m_algo_param.threadblock_n, - m_algo_param.threadblock_k, m_algo_param.warp_m, - m_algo_param.warp_n, m_algo_param.warp_k}; + AlgoParam_ algo_param{m_algo_param.threadblock_m, m_algo_param.threadblock_n, + m_algo_param.threadblock_k, m_algo_param.warp_m, + m_algo_param.warp_n, m_algo_param.warp_k}; serialize_write_pod(algo_param, ret); return ret; } @@ -303,15 +292,15 @@ class MatrixMulForwardImpl::AlgoFloat32SIMTSplitK final public: AlgoFloat32SIMTSplitK(AlgoParam algo_param) : AlgoCutlassMatrixMulBase{algo_param}, - m_name{ssprintf("CUTLASS_FLOAT32_SIMT_SPLIT_K_%s", - m_algo_param.to_string().c_str())} {} + m_name{ssprintf( + 
"CUTLASS_FLOAT32_SIMT_SPLIT_K_%s", + m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; const char* name() const override { return m_name.c_str(); } AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) std::string param() const override { @@ -322,10 +311,9 @@ public: int threadblock_m, threadblock_n, threadblock_k; int warp_m, warp_n, warp_k; }; - AlgoParam_ algo_param{ - m_algo_param.threadblock_m, m_algo_param.threadblock_n, - m_algo_param.threadblock_k, m_algo_param.warp_m, - m_algo_param.warp_n, m_algo_param.warp_k}; + AlgoParam_ algo_param{m_algo_param.threadblock_m, m_algo_param.threadblock_n, + m_algo_param.threadblock_k, m_algo_param.warp_m, + m_algo_param.warp_n, m_algo_param.warp_k}; serialize_write_pod(algo_param, ret); return ret; } @@ -337,20 +325,18 @@ private: const void* get_available_op(const SizeArgs& args) const; }; -class MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided final - : public AlgoBase { +class MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided final : public AlgoBase { public: AlgoFloat32SIMTGemvBatchedStrided(int threadblock_n) : m_threadblock_n{threadblock_n}, - m_name{ssprintf("CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_%d", - m_threadblock_n)} {} + m_name{ssprintf( + "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_%d", + m_threadblock_n)} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; const char* name() const override { return m_name.c_str(); } void exec(const ExecArgs& args) const override; - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED) std::string param() const override { @@ -370,17 +356,14 @@ class MatrixMulForwardImpl::AlgoFloat16TensorOp final public: AlgoFloat16TensorOp(AlgoParam algo_param) : AlgoCutlassMatrixMulBase{algo_param}, - m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s", - m_algo_param.instruction_m, - m_algo_param.instruction_n, - m_algo_param.instruction_k, - m_algo_param.to_string().c_str())} {} + m_name{ssprintf( + "CUTLASS_FLOAT16_TENSOR_OP_h%d%d%d_%s", + m_algo_param.instruction_m, m_algo_param.instruction_n, + m_algo_param.instruction_k, m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; const char* name() const override { return m_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP) private: @@ -394,17 +377,15 @@ class MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK final public: AlgoFloat16TensorOpSplitK(AlgoParam algo_param) : AlgoCutlassMatrixMulBase{algo_param}, - m_name{ssprintf("CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s", - m_algo_param.instruction_m, - m_algo_param.instruction_n, - m_algo_param.instruction_k, - m_algo_param.to_string().c_str())} {} + m_name{ssprintf( + "CUTLASS_FLOAT16_TENSOR_OP_SPLIT_K_h%d%d%d_%s", + m_algo_param.instruction_m, 
m_algo_param.instruction_n, + m_algo_param.instruction_k, m_algo_param.to_string().c_str())} {} bool is_available(const SizeArgs& args) const override; size_t get_workspace_in_bytes(const SizeArgs& args) const override; const char* name() const override { return m_name.c_str(); } AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE | - AlgoAttribute::USABLE_DEPEND_ON_SHAPE; + return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT16_TENSOR_OP_SPLIT_K) @@ -436,8 +417,7 @@ public: #if CUDA_VERSION >= 9020 std::vector simt_float32; std::vector simt_float32_split_k; - std::vector - simt_float32_gemv_batched_strided; + std::vector simt_float32_gemv_batched_strided; #if CUDA_VERSION >= 10020 std::vector tensorop_float16; std::vector tensorop_float16_split_k; diff --git a/dnn/src/cuda/matrix_mul/bfloat16.cpp b/dnn/src/cuda/matrix_mul/bfloat16.cpp index 329112ea..f6a9f189 100644 --- a/dnn/src/cuda/matrix_mul/bfloat16.cpp +++ b/dnn/src/cuda/matrix_mul/bfloat16.cpp @@ -10,11 +10,11 @@ * implied. */ +#include "src/common/algo_base.h" +#include "src/common/algo_chooser.h" #include "src/cuda/handle.h" #include "src/cuda/matrix_mul/algos.h" #include "src/cuda/utils.h" -#include "src/common/algo_chooser.h" -#include "src/common/algo_base.h" using namespace megdnn; using namespace cuda; @@ -41,27 +41,25 @@ std::pair sub_opr_config( std::pair> prepare_sub_opr( const MatrixMulForwardImpl::AlgoBase::SizeArgs& args) { - auto&& config = sub_opr_config( - {args.layout_a, args.layout_b, args.layout_c}, args.opr); + auto&& config = + sub_opr_config({args.layout_a, args.layout_b, args.layout_c}, args.opr); auto matmul_opr = args.opr->handle()->create_operator(); matmul_opr->param() = config.second; return {config.first, std::move(matmul_opr)}; } } // namespace -std::vector -MatrixMulForwardImpl::AlgoBFloat16::get_subopr_list( +std::vector MatrixMulForwardImpl::AlgoBFloat16::get_subopr_list( const TensorLayoutArray& layouts, const OperatorBase* opr) const { - auto&& config = sub_opr_config( - layouts, static_cast(opr)); + auto&& config = + sub_opr_config(layouts, static_cast(opr)); std::string param_str; Algorithm::serialize_write_pod(config.second, param_str); return {{Algorithm::OprType::MATRIX_MUL_FORWARD, param_str, config.first}}; } -bool MatrixMulForwardImpl::AlgoBFloat16::is_available( - const SizeArgs& args) const { +bool MatrixMulForwardImpl::AlgoBFloat16::is_available(const SizeArgs& args) const { auto config = prepare_sub_opr(args); return args.layout_a.dtype == dtype::BFloat16() && get_algorithm( @@ -74,8 +72,7 @@ WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle( auto config = prepare_sub_opr(args); SmallVector sizes; - auto get_workspace = [&sizes](const TensorLayout& src, - const TensorLayout& dst) { + auto get_workspace = [&sizes](const TensorLayout& src, const TensorLayout& dst) { if (src.dtype != dst.dtype) { sizes.push_back(dst.span().dist_byte()); } @@ -99,8 +96,8 @@ void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { TensorND b = args.tensor_b; TensorND c = args.tensor_c; auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); - auto ctypecvt = CompTypeCvter( - args.opr->handle(), &bundle); + auto ctypecvt = + CompTypeCvter(args.opr->handle(), &bundle); ctypecvt.src_to_comp_type(args.tensor_a, a) .src_to_comp_type(args.tensor_b, b) .src_to_comp_type(args.tensor_c, c); diff --git a/dnn/src/cuda/matrix_mul/conv1x1.cpp b/dnn/src/cuda/matrix_mul/conv1x1.cpp 
index 31fdc7fa..3ef07a73 100644 --- a/dnn/src/cuda/matrix_mul/conv1x1.cpp +++ b/dnn/src/cuda/matrix_mul/conv1x1.cpp @@ -13,25 +13,24 @@ namespace { std::unique_ptr prepare_conv_opr( const MatrixMulForwardImpl::AlgoBase::SizeArgs& args) { - auto conv_bias_opr_ptr = - args.opr->handle()->create_operator(); + auto conv_bias_opr_ptr = args.opr->handle()->create_operator(); auto conv_param_computemode = - (args.opr->param().compute_mode == - param::MatrixMul::ComputeMode::DEFAULT) + (args.opr->param().compute_mode == param::MatrixMul::ComputeMode::DEFAULT) ? param::Convolution::ComputeMode::DEFAULT : param::Convolution::ComputeMode::FLOAT32; - conv_bias_opr_ptr->param() = {param::ConvBias::NonlineMode::IDENTITY, - param::Convolution::Mode::CROSS_CORRELATION, - param::Convolution::Sparse::DENSE, - param::Convolution::Format::NCHW, - 0, // pad_h - 0, // pad_w - 1, // stride_h - 1, // stride_w - 1, // dilate_h - 1, // dilate_w - conv_param_computemode}; + conv_bias_opr_ptr->param() = { + param::ConvBias::NonlineMode::IDENTITY, + param::Convolution::Mode::CROSS_CORRELATION, + param::Convolution::Sparse::DENSE, + param::Convolution::Format::NCHW, + 0, // pad_h + 0, // pad_w + 1, // stride_h + 1, // stride_w + 1, // dilate_h + 1, // dilate_w + conv_param_computemode}; return conv_bias_opr_ptr; } @@ -52,13 +51,12 @@ std::tuple gen_matrixmul_shape( megdnn_assert(k == args.layout_b.shape[1]); n = args.layout_b.shape[0]; } - return std::tuple {m, k, n}; + return std::tuple{m, k, n}; } } // namespace -bool MatrixMulForwardImpl::AlgoConv1X1CUDNN::is_available( - const SizeArgs& args) const { +bool MatrixMulForwardImpl::AlgoConv1X1CUDNN::is_available(const SizeArgs& args) const { if (!(args.layout_a.ndim == 2 && args.layout_b.ndim == 2 && args.layout_c.ndim == 2)) return false; @@ -133,8 +131,7 @@ void MatrixMulForwardImpl::AlgoConv1X1CUDNN::exec(const ExecArgs& args) const { if (args.opr->param().transposeA || args.opr->param().transposeB) { auto trans = args.opr->handle()->create_operator(); - auto trans_tensor = [&](size_t workspace_pos, - const TensorND& ori_tensor, + auto trans_tensor = [&](size_t workspace_pos, const TensorND& ori_tensor, TensorND& dst_tensor) { TensorLayout dst_layout( {ori_tensor.layout.shape[1], ori_tensor.layout.shape[0]}, @@ -151,8 +148,7 @@ void MatrixMulForwardImpl::AlgoConv1X1CUDNN::exec(const ExecArgs& args) const { trans_tensor(1, args.tensor_a, A_dst_tensor); } if (args.opr->param().transposeB) { - trans_tensor(bundle.nr_workspace() - 1, args.tensor_b, - B_dst_tensor); + trans_tensor(bundle.nr_workspace() - 1, args.tensor_b, B_dst_tensor); } } @@ -167,7 +163,7 @@ void MatrixMulForwardImpl::AlgoConv1X1CUDNN::exec(const ExecArgs& args) const { TensorND dst(args.tensor_c.raw_ptr, dst_layout); ConvBiasForwardImpl::AlgoBase::ExecArgs conv_exec_args( - static_cast(conv_opr_ptr.get()), src, filter, - bias, z, dst, bundle.get_workspace(0)); + static_cast(conv_opr_ptr.get()), src, filter, bias, z, + dst, bundle.get_workspace(0)); m_impl->exec(conv_exec_args); } diff --git a/dnn/src/cuda/matrix_mul/cublas.cpp b/dnn/src/cuda/matrix_mul/cublas.cpp index 1ece0a10..bf689ea0 100644 --- a/dnn/src/cuda/matrix_mul/cublas.cpp +++ b/dnn/src/cuda/matrix_mul/cublas.cpp @@ -29,23 +29,22 @@ using namespace cuda; #define CUBLAS_COMPUTE_32I CUDA_R_32I #endif -bool MatrixMulForwardImpl::AlgoCuBlas::is_available( - const SizeArgs& args) const { +bool MatrixMulForwardImpl::AlgoCuBlas::is_available(const SizeArgs& args) const { if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) return 
false; if (args.layout_a.dtype == dtype::Float32() || args.layout_a.dtype == dtype::Float16()) { return true; - } else if (args.layout_a.dtype.enumv() == DTypeEnum::Int8 || - args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) { + } else if ( + args.layout_a.dtype.enumv() == DTypeEnum::Int8 || + args.layout_a.dtype.enumv() == DTypeEnum::QuantizedS8) { /** * \note When passing in the strides which can not be divided by 4, the * cublas rontine cublasGemmEx will raise a Error * CUBLAS_STATUS_INVALID_VALUE. The error occured because the leading * dimension of matrix A or B is illegal. */ - return args.layout_a.stride[0] % 4 == 0 && - args.layout_b.stride[0] % 4 == 0 && + return args.layout_a.stride[0] % 4 == 0 && args.layout_b.stride[0] % 4 == 0 && is_compute_capability_required(6, 1); } return false; @@ -65,9 +64,8 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, args.tensor_b.ptr(), args.tensor_b.layout.stride[0], - args.tensor_a.ptr(), args.tensor_a.layout.stride[0], - zero, args.tensor_c.ptr(), - args.tensor_c.layout.stride[0])); + args.tensor_a.ptr(), args.tensor_a.layout.stride[0], zero, + args.tensor_c.ptr(), args.tensor_c.layout.stride[0])); }; auto sgemm_ex = [&]() { @@ -117,11 +115,10 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { cublas_check(cublasGemmEx( cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, - args.tensor_b.raw_ptr, CUDA_R_8I, - args.tensor_b.layout.stride[0], args.tensor_a.raw_ptr, - CUDA_R_8I, args.tensor_a.layout.stride[0], zero, - args.tensor_c.raw_ptr, CUDA_R_32I, - args.tensor_c.layout.stride[0], CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DFALT)); + args.tensor_b.raw_ptr, CUDA_R_8I, args.tensor_b.layout.stride[0], + args.tensor_a.raw_ptr, CUDA_R_8I, args.tensor_a.layout.stride[0], zero, + args.tensor_c.raw_ptr, CUDA_R_32I, args.tensor_c.layout.stride[0], + CUBLAS_COMPUTE_32I, CUBLAS_GEMM_DFALT)); }; // Note that cublas takes column-major matrices as inputs, diff --git a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp index 1226fb6c..d5cad0d9 100644 --- a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp +++ b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/common/utils.h" #include "src/cuda/matrix_mul/cublasLt_wrapper.h" +#include "src/common/utils.h" #include "src/cuda/utils.h" #if CUDA_VERSION >= 10010 @@ -140,9 +140,8 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) { stride_a_trans = round_up(k, 32) / 32 * ldatransform; stride_c_trans = round_up(m, 32) / 32 * ldctransform; trans_b = CUBLAS_OP_T; - cublas_check(cublasLtMatmulDescSetAttribute(matmul_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &trans_b, sizeof(trans_b))); + cublas_check(cublasLtMatmulDescSetAttribute( + matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b))); // origin layout cublas_check(cublasLtMatrixLayoutCreate( &layout_b, dt_b, n, k, args.layout_b.stride[batched ? 1 : 0])); @@ -151,18 +150,18 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) { cublas_check(cublasLtMatrixLayoutCreate( &layout_c, dt_c, n, m, args.layout_c.stride[batched ? 
1 : 0])); // transformed layout - cublas_check(cublasLtMatrixLayoutCreate(&layout_trans_b, dt_b, n, k, - ldbtransform)); - cublas_check(cublasLtMatrixLayoutCreate(&layout_trans_a, dt_a, m, k, - ldatransform)); - cublas_check(cublasLtMatrixLayoutCreate(&layout_trans_c, dt_c, n, m, - ldctransform)); + cublas_check( + cublasLtMatrixLayoutCreate(&layout_trans_b, dt_b, n, k, ldbtransform)); + cublas_check( + cublasLtMatrixLayoutCreate(&layout_trans_a, dt_a, m, k, ldatransform)); + cublas_check( + cublasLtMatrixLayoutCreate(&layout_trans_c, dt_c, n, m, ldctransform)); cublas_check(cublasLtMatrixLayoutSetAttribute( layout_trans_b, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); cublas_check(cublasLtMatrixLayoutSetAttribute( - layout_trans_a, CUBLASLT_MATRIX_LAYOUT_ORDER, - &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C))); + layout_trans_a, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL4_4R2_8C, + sizeof(order_COL4_4R2_8C))); cublas_check(cublasLtMatrixLayoutSetAttribute( layout_trans_c, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32, sizeof(order_COL32))); @@ -192,20 +191,16 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) { } else { trans_b = args.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N; trans_a = args.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N; - cublas_check(cublasLtMatmulDescSetAttribute(matmul_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &trans_b, sizeof(trans_b))); - cublas_check(cublasLtMatmulDescSetAttribute(matmul_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &trans_a, sizeof(trans_a))); + cublas_check(cublasLtMatmulDescSetAttribute( + matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_b, sizeof(trans_b))); + cublas_check(cublasLtMatmulDescSetAttribute( + matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_a, sizeof(trans_a))); cublas_check(cublasLtMatrixLayoutCreate( &layout_b, dt_b, trans_b == CUBLAS_OP_N ? n : k, - trans_b == CUBLAS_OP_N ? k : n, - args.layout_b.stride[batched ? 1 : 0])); + trans_b == CUBLAS_OP_N ? k : n, args.layout_b.stride[batched ? 1 : 0])); cublas_check(cublasLtMatrixLayoutCreate( &layout_a, dt_a, trans_a == CUBLAS_OP_N ? k : m, - trans_a == CUBLAS_OP_N ? m : k, - args.layout_a.stride[batched ? 1 : 0])); + trans_a == CUBLAS_OP_N ? m : k, args.layout_a.stride[batched ? 1 : 0])); cublas_check(cublasLtMatrixLayoutCreate( &layout_c, dt_c, n, m, args.layout_c.stride[batched ? 
1 : 0])); } @@ -213,14 +208,11 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) { size_t stride_a = args.layout_a.stride[0]; size_t stride_c = args.layout_c.stride[0]; cublas_check(cublasLtMatrixLayoutSetAttribute( - layout_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, - sizeof(batch))); + layout_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublas_check(cublasLtMatrixLayoutSetAttribute( - layout_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, - sizeof(batch))); + layout_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublas_check(cublasLtMatrixLayoutSetAttribute( - layout_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, - sizeof(batch))); + layout_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); cublas_check(cublasLtMatrixLayoutSetAttribute( layout_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); @@ -239,8 +231,7 @@ bool CUBLASLTMatmulDesc::is_available(const SizeArgs& args, size_t ws_limit) { support = (dt_a == CUDA_R_16F); break; case CUDA_R_32I: { - support = (dt_a == CUDA_R_8I) && - (!args.transposeA && !args.transposeB); + support = (dt_a == CUDA_R_8I) && (!args.transposeA && !args.transposeB); break; } case CUDA_R_32F: @@ -274,13 +265,12 @@ WorkspaceBundle CUBLASLTMatmulDesc::get_workspace_bundle( algo_workspace_size = result.workspaceSize; return {nullptr, (dt_c == CUDA_R_32I) - ? SmallVector{algo_workspace_size, workspace_b, - workspace_a, workspace_c} + ? SmallVector< + size_t>{algo_workspace_size, workspace_b, workspace_a, workspace_c} : SmallVector{algo_workspace_size}}; } -bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args, - size_t ws_limit, - cublasLtMatmulAlgo_t& algo) { +bool CUBLASLTMatmulDesc::get_algorithm_heuristic( + const SizeArgs& args, size_t ws_limit, cublasLtMatmulAlgo_t& algo) { bool result; int return_algo_count; size_t algo_ws_limit; @@ -331,8 +321,8 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args, dt_c == CUDA_R_32I ? layout_trans_b : layout_b, dt_c == CUDA_R_32I ? layout_trans_a : layout_a, dt_c == CUDA_R_32I ? layout_trans_c : layout_c, - dt_c == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1, - &algo_result, &return_algo_count); + dt_c == CUDA_R_32I ? 
layout_trans_c : layout_c, algo_pref, 1, &algo_result, + &return_algo_count); if (status == CUBLAS_STATUS_SUCCESS && return_algo_count > 0 && // perform cublasLtAlgoCheck() to make sure the algo is correct get_workspace_bundle(args, algo_result.algo).nr_workspace() > 0) { diff --git a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.h b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.h index f9615252..c8804c7c 100644 --- a/dnn/src/cuda/matrix_mul/cublasLt_wrapper.h +++ b/dnn/src/cuda/matrix_mul/cublasLt_wrapper.h @@ -26,9 +26,9 @@ struct CUBLASLTMatmulDesc { bool transposeA, transposeB; TensorLayout layout_a, layout_b, layout_c; std::string to_string() const; - SizeArgs(HandleImpl* handle, bool transposeA, bool transposeB, - const TensorLayout& A, const TensorLayout& B, - const TensorLayout& C) + SizeArgs( + HandleImpl* handle, bool transposeA, bool transposeB, + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) : handle(handle), transposeA(transposeA), transposeB(transposeB), @@ -73,10 +73,10 @@ struct CUBLASLTMatmulDesc { ~CUBLASLTMatmulDesc(); void set(const SizeArgs& args, bool batched = false); void reset(); - bool get_algorithm_heuristic(const SizeArgs& args, size_t ws_limit, - cublasLtMatmulAlgo_t& algo); - WorkspaceBundle get_workspace_bundle(const SizeArgs& args, - const cublasLtMatmulAlgo_t& algo); + bool get_algorithm_heuristic( + const SizeArgs& args, size_t ws_limit, cublasLtMatmulAlgo_t& algo); + WorkspaceBundle get_workspace_bundle( + const SizeArgs& args, const cublasLtMatmulAlgo_t& algo); bool is_available(const SizeArgs& args, size_t ws_limit); }; } // namespace cuda diff --git a/dnn/src/cuda/matrix_mul/cublas_lt.cpp b/dnn/src/cuda/matrix_mul/cublas_lt.cpp index 111777d3..e90e071b 100644 --- a/dnn/src/cuda/matrix_mul/cublas_lt.cpp +++ b/dnn/src/cuda/matrix_mul/cublas_lt.cpp @@ -11,14 +11,13 @@ #include "./algos.h" #include "src/cuda/handle.h" -#include "src/cuda/utils.h" #include "src/cuda/matrix_mul/cublasLt_wrapper.h" +#include "src/cuda/utils.h" #if CUDA_VERSION >= 10010 using namespace megdnn; using namespace cuda; -bool MatrixMulForwardImpl::AlgoCuBlasLt::is_available( - const SizeArgs& args) const { +bool MatrixMulForwardImpl::AlgoCuBlasLt::is_available(const SizeArgs& args) const { if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) return false; if (args.layout_a.dtype.enumv() == DTypeEnum::Quantized4Asymm || @@ -49,83 +48,69 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const { auto sgemm = [&]() { auto zero = handle->zero_device(); auto one = handle->one_device(); - megdnn_assert(ws_bundle.nr_workspace() == 1, - "workspace bundle size should be 1(ws_algo)"); - cublas_check(cublasLtMatmul(cublasLt_handle, - desc.matmul_desc, - one, - static_cast(args.tensor_b.ptr()), desc.layout_b, - static_cast(args.tensor_a.ptr()), desc.layout_a, - zero, - static_cast(args.tensor_c.ptr()), desc.layout_c, - static_cast(args.tensor_c.ptr()), desc.layout_c, - &algo, - ws_bundle.get(0), ws_bundle.get_size(0), - stream - )); + megdnn_assert( + ws_bundle.nr_workspace() == 1, + "workspace bundle size should be 1(ws_algo)"); + cublas_check(cublasLtMatmul( + cublasLt_handle, desc.matmul_desc, one, + static_cast(args.tensor_b.ptr()), desc.layout_b, + static_cast(args.tensor_a.ptr()), desc.layout_a, + zero, static_cast(args.tensor_c.ptr()), + desc.layout_c, static_cast(args.tensor_c.ptr()), + desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream)); }; auto hgemm = [&]() { auto zero_half = handle->zero_device_h(); auto one_half = 
handle->one_device_h(); - megdnn_assert(ws_bundle.nr_workspace() == 1, - "workspace bundle size should be 1(ws_algo)"); - cublas_check(cublasLtMatmul(cublasLt_handle, - desc.matmul_desc, - one_half, - static_cast(args.tensor_b.raw_ptr), desc.layout_b, - static_cast(args.tensor_a.raw_ptr), desc.layout_a, - zero_half, - static_cast(args.tensor_c.raw_ptr), desc.layout_c, - static_cast<__half *>(args.tensor_c.raw_ptr), desc.layout_c, - &algo, - ws_bundle.get(0), ws_bundle.get_size(0), - stream - )); + megdnn_assert( + ws_bundle.nr_workspace() == 1, + "workspace bundle size should be 1(ws_algo)"); + cublas_check(cublasLtMatmul( + cublasLt_handle, desc.matmul_desc, one_half, + static_cast(args.tensor_b.raw_ptr), desc.layout_b, + static_cast(args.tensor_a.raw_ptr), desc.layout_a, + zero_half, static_cast(args.tensor_c.raw_ptr), + desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr), + desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream)); }; auto igemm = [&]() { auto zero = handle->zero_device(); auto one = handle->one_device(); - megdnn_assert(ws_bundle.nr_workspace() == 4, - "workspace bundle size should be 4(ws_algo, ws_a, ws_b, ws_c)"); - void *ws_b = ws_bundle.get(1); - void *ws_a = ws_bundle.get(2); - void *ws_c = ws_bundle.get(3); - int32_t pm=CUBLAS_POINTER_MODE_DEVICE; - cublasOperation_t trans_a=CUBLAS_OP_T, trans_c=CUBLAS_OP_N; + megdnn_assert( + ws_bundle.nr_workspace() == 4, + "workspace bundle size should be 4(ws_algo, ws_a, ws_b, ws_c)"); + void* ws_b = ws_bundle.get(1); + void* ws_a = ws_bundle.get(2); + void* ws_c = ws_bundle.get(3); + int32_t pm = CUBLAS_POINTER_MODE_DEVICE; + cublasOperation_t trans_a = CUBLAS_OP_T, trans_c = CUBLAS_OP_N; cublasLtMatrixTransformDesc_t transform_desc = nullptr; cublas_check(cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F)); - cublas_check(cublasLtMatrixTransformDescSetAttribute(transform_desc, - CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm, sizeof(pm))); - cublas_check(cublasLtMatrixTransform(cublasLt_handle, transform_desc, - one, args.tensor_b.raw_ptr, desc.layout_b, - zero, nullptr, nullptr, - ws_b, desc.layout_trans_b, - stream)); - cublas_check(cublasLtMatrixTransformDescSetAttribute(transform_desc, - CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a, sizeof(trans_a))); - cublas_check(cublasLtMatrixTransform(cublasLt_handle, transform_desc, - one, args.tensor_a.raw_ptr, desc.layout_a, - zero, nullptr, nullptr, - ws_a, desc.layout_trans_a, - stream)); - cublas_check(cublasLtMatmul(cublasLt_handle, desc.matmul_desc, - one, - ws_b, desc.layout_trans_b, - ws_a, desc.layout_trans_a, - zero, - ws_c, desc.layout_trans_c, - ws_c, desc.layout_trans_c, - &algo, - ws_bundle.get(0), - ws_bundle.get_size(0), - stream)); - cublas_check(cublasLtMatrixTransformDescSetAttribute(transform_desc, - CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_c, sizeof(trans_c))); - cublas_check(cublasLtMatrixTransform(cublasLt_handle, transform_desc, - one, ws_c, desc.layout_trans_c, - zero, nullptr, nullptr, - args.tensor_c.raw_ptr, desc.layout_c, - stream)); + cublas_check(cublasLtMatrixTransformDescSetAttribute( + transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm, + sizeof(pm))); + cublas_check(cublasLtMatrixTransform( + cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr, + desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b, + stream)); + cublas_check(cublasLtMatrixTransformDescSetAttribute( + transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a, + sizeof(trans_a))); + 
cublas_check(cublasLtMatrixTransform( + cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr, + desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a, + stream)); + cublas_check(cublasLtMatmul( + cublasLt_handle, desc.matmul_desc, one, ws_b, desc.layout_trans_b, ws_a, + desc.layout_trans_a, zero, ws_c, desc.layout_trans_c, ws_c, + desc.layout_trans_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), + stream)); + cublas_check(cublasLtMatrixTransformDescSetAttribute( + transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_c, + sizeof(trans_c))); + cublas_check(cublasLtMatrixTransform( + cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero, + nullptr, nullptr, args.tensor_c.raw_ptr, desc.layout_c, stream)); cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc)); }; #if CUDA_VERSION >= 11000 diff --git a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp index 1dcf6d84..ca054be8 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop.cpp @@ -21,24 +21,23 @@ using namespace cuda; bool MatrixMulForwardImpl::AlgoFloat16TensorOp::is_available( const SizeArgs& args) const { - bool available = - args.opr->param().format == param::MatrixMul::Format::DEFAULT && - args.layout_b.dtype == dtype::Float16() && - args.layout_c.dtype == dtype::Float16(); + bool available = args.opr->param().format == param::MatrixMul::Format::DEFAULT && + args.layout_b.dtype == dtype::Float16() && + args.layout_c.dtype == dtype::Float16(); int n = args.layout_c.shape[1]; auto&& device_prop = cuda::current_device_prop(); int y_grid_limit = device_prop.maxGridSize[1]; // limit y grid - available &= ((n + m_algo_param.threadblock_n - 1) / - m_algo_param.threadblock_n <= - y_grid_limit); + available &= + ((n + m_algo_param.threadblock_n - 1) / m_algo_param.threadblock_n <= + y_grid_limit); if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && m_algo_param.instruction_k == 4) { available &= is_compute_capability_required(7, 0); } else { - megdnn_assert(m_algo_param.instruction_m == 16 && - m_algo_param.instruction_n == 8 && - m_algo_param.instruction_k == 8); + megdnn_assert( + m_algo_param.instruction_m == 16 && m_algo_param.instruction_n == 8 && + m_algo_param.instruction_k == 8); available &= is_compute_capability_required(7, 5); } @@ -58,20 +57,18 @@ size_t MatrixMulForwardImpl::AlgoFloat16TensorOp::get_workspace_in_bytes( return ws_size; } -void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec( - const ExecArgs& args) const { - int64_t lda = args.tensor_a.layout.stride[0], - ldb = args.tensor_b.layout.stride[0], +void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec(const ExecArgs& args) const { + int64_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0], ldc = args.tensor_c.layout.stride[0]; int alignment = max_alignment(args); int min_alignment = min_alignment_requirement(); auto&& param = args.opr->param(); int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], k = args.tensor_a.layout.shape[param.transposeA ? 
0 : 1]; - megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && - ldc % alignment == 0 && m % alignment == 0 && - n % alignment == 0 && k % alignment == 0 && - alignment >= min_alignment); + megdnn_assert( + lda % alignment == 0 && ldb % alignment == 0 && ldc % alignment == 0 && + m % alignment == 0 && n % alignment == 0 && k % alignment == 0 && + alignment >= min_alignment); cutlass::gemm::GemmCoord problem_size{m, n, k}; auto&& stream = cuda_stream(args.opr->handle()); int* workspace = reinterpret_cast(args.workspace.raw_ptr); @@ -87,10 +84,10 @@ void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec( using namespace cutlass::library; - auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; - auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; + auto layoutA = + param.transposeA ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + auto layoutB = + param.transposeB ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; void *host_one, *host_zero; NumericTypeID element_accumulator; @@ -99,53 +96,54 @@ void MatrixMulForwardImpl::AlgoFloat16TensorOp::do_exec( host_one = &one_f16; host_zero = &zero_f16; } else { - megdnn_assert(param.compute_mode == - param::MatrixMul::ComputeMode::FLOAT32); + megdnn_assert(param.compute_mode == param::MatrixMul::ComputeMode::FLOAT32); element_accumulator = NumericTypeID::kF32; host_one = &one; host_zero = &zero; } - GemmKey key{NumericTypeID::kF16, - layoutA, - NumericTypeID::kF16, - layoutB, - NumericTypeID::kF16, - LayoutTypeID::kRowMajor, - element_accumulator, - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k, - m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k, - m_algo_param.instruction_m, - m_algo_param.instruction_n, - m_algo_param.instruction_k, - 2, - alignment, - alignment, - SplitKMode::kNone}; + GemmKey key{ + NumericTypeID::kF16, + layoutA, + NumericTypeID::kF16, + layoutB, + NumericTypeID::kF16, + LayoutTypeID::kRowMajor, + element_accumulator, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + 2, + alignment, + alignment, + SplitKMode::kNone}; const auto& table = Singleton::get().operation_table; - megdnn_assert(table.gemm_operations.count(key) > 0, - "key not found in cutlass operation table"); + megdnn_assert( + table.gemm_operations.count(key) > 0, + "key not found in cutlass operation table"); const auto& ops = table.gemm_operations.at(key); - megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", - ops.size()); + megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", ops.size()); - GemmArguments gemm_args{problem_size, - args.tensor_a.raw_ptr, - args.tensor_b.raw_ptr, - args.tensor_c.raw_ptr, - args.tensor_c.raw_ptr, - lda, - ldb, - ldc, - ldc, - 1, - host_one, - host_zero}; + GemmArguments gemm_args{ + problem_size, + args.tensor_a.raw_ptr, + args.tensor_b.raw_ptr, + args.tensor_c.raw_ptr, + args.tensor_c.raw_ptr, + lda, + ldb, + ldc, + ldc, + 1, + host_one, + host_zero}; cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); } diff --git a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp index c9b9adf5..cd44216b 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp +++ 
b/dnn/src/cuda/matrix_mul/cutlass_float16_tensorop_split_k.cpp @@ -24,24 +24,23 @@ bool MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::is_available( auto&& param = args.opr->param(); int m = args.layout_c.shape[0], n = args.layout_c.shape[1], k = args.layout_a.shape[param.transposeA ? 0 : 1]; - bool available = - args.opr->param().format == param::MatrixMul::Format::DEFAULT && - args.layout_a.dtype == dtype::Float16() && - args.layout_b.dtype == dtype::Float16() && - args.layout_c.dtype == dtype::Float16() && k > n; + bool available = args.opr->param().format == param::MatrixMul::Format::DEFAULT && + args.layout_a.dtype == dtype::Float16() && + args.layout_b.dtype == dtype::Float16() && + args.layout_c.dtype == dtype::Float16() && k > n; auto&& device_prop = cuda::current_device_prop(); int y_grid_limit = device_prop.maxGridSize[1]; // limit y grid - available &= ((m + m_algo_param.threadblock_m - 1) / - m_algo_param.threadblock_m <= - y_grid_limit); + available &= + ((m + m_algo_param.threadblock_m - 1) / m_algo_param.threadblock_m <= + y_grid_limit); if (m_algo_param.instruction_m == 8 && m_algo_param.instruction_n == 8 && m_algo_param.instruction_k == 4) { available &= is_compute_capability_required(7, 0); } else { - megdnn_assert(m_algo_param.instruction_m == 16 && - m_algo_param.instruction_n == 8 && - m_algo_param.instruction_k == 8); + megdnn_assert( + m_algo_param.instruction_m == 16 && m_algo_param.instruction_n == 8 && + m_algo_param.instruction_k == 8); available &= is_compute_capability_required(7, 5); } @@ -61,8 +60,7 @@ size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes( int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1], align_k = layouts[0].shape[1]; split_k_slices = std::max(1, align_k / align_n); - size_t ws_size = - args.layout_c.dtype.size(align_m * align_n * split_k_slices); + size_t ws_size = args.layout_c.dtype.size(align_m * align_n * split_k_slices); for (auto&& ly : layouts) ws_size += ly.span().dist_byte(); return ws_size; @@ -70,18 +68,17 @@ size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes( void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec( const ExecArgs& args) const { - int64_t lda = args.tensor_a.layout.stride[0], - ldb = args.tensor_b.layout.stride[0], + int64_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0], ldc = args.tensor_c.layout.stride[0]; int alignment = max_alignment(args); int min_alignment = min_alignment_requirement(); auto&& param = args.opr->param(); int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; - megdnn_assert(lda % alignment == 0 && ldb % alignment == 0 && - ldc % alignment == 0 && m % alignment == 0 && - n % alignment == 0 && k % alignment == 0 && - alignment >= min_alignment); + megdnn_assert( + lda % alignment == 0 && ldb % alignment == 0 && ldc % alignment == 0 && + m % alignment == 0 && n % alignment == 0 && k % alignment == 0 && + alignment >= min_alignment); cutlass::gemm::GemmCoord problem_size{m, n, k}; int split_k_slices = std::max(1, k / n); auto&& stream = cuda_stream(args.opr->handle()); @@ -98,10 +95,10 @@ void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec( using namespace cutlass::library; - auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; - auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; + auto layoutA = + param.transposeA ? 
LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + auto layoutB = + param.transposeB ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; void *host_one, *host_zero; NumericTypeID element_accumulator; @@ -110,53 +107,54 @@ void MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::do_exec( host_one = &one_f16; host_zero = &zero_f16; } else { - megdnn_assert(param.compute_mode == - param::MatrixMul::ComputeMode::FLOAT32); + megdnn_assert(param.compute_mode == param::MatrixMul::ComputeMode::FLOAT32); element_accumulator = NumericTypeID::kF32; host_one = &one; host_zero = &zero; } - GemmKey key{NumericTypeID::kF16, - layoutA, - NumericTypeID::kF16, - layoutB, - NumericTypeID::kF16, - LayoutTypeID::kRowMajor, - element_accumulator, - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k, - m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k, - m_algo_param.instruction_m, - m_algo_param.instruction_n, - m_algo_param.instruction_k, - 2, - alignment, - alignment, - SplitKMode::kParallel}; + GemmKey key{ + NumericTypeID::kF16, + layoutA, + NumericTypeID::kF16, + layoutB, + NumericTypeID::kF16, + LayoutTypeID::kRowMajor, + element_accumulator, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + m_algo_param.instruction_m, + m_algo_param.instruction_n, + m_algo_param.instruction_k, + 2, + alignment, + alignment, + SplitKMode::kParallel}; const auto& table = Singleton::get().operation_table; - megdnn_assert(table.gemm_operations.count(key) > 0, - "key not found in cutlass operation table"); + megdnn_assert( + table.gemm_operations.count(key) > 0, + "key not found in cutlass operation table"); const auto& ops = table.gemm_operations.at(key); - megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", - ops.size()); + megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", ops.size()); - GemmArguments gemm_args{problem_size, - args.tensor_a.raw_ptr, - args.tensor_b.raw_ptr, - args.tensor_c.raw_ptr, - args.tensor_c.raw_ptr, - lda, - ldb, - ldc, - ldc, - split_k_slices, - host_one, - host_zero}; + GemmArguments gemm_args{ + problem_size, + args.tensor_a.raw_ptr, + args.tensor_b.raw_ptr, + args.tensor_c.raw_ptr, + args.tensor_c.raw_ptr, + lda, + ldb, + ldc, + ldc, + split_k_slices, + host_one, + host_zero}; cutlass_check(ops[0]->run(&gemm_args, workspace, stream)); } diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp index 96d35303..f9ce2d0d 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp @@ -23,49 +23,48 @@ const void* MatrixMulForwardImpl::AlgoFloat32SIMT::get_available_op( const SizeArgs& args) const { using namespace cutlass::library; auto&& param = args.opr->param(); - auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; - auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; + auto layoutA = + param.transposeA ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + auto layoutB = + param.transposeB ? 
LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; int alignment = min_alignment_requirement(); - GemmKey key{NumericTypeID::kF32, - layoutA, - NumericTypeID::kF32, - layoutB, - NumericTypeID::kF32, - LayoutTypeID::kRowMajor, - NumericTypeID::kF32, - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k, - m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k, - 1, - 1, - 1, - 2, - alignment, - alignment, - SplitKMode::kNone}; + GemmKey key{ + NumericTypeID::kF32, + layoutA, + NumericTypeID::kF32, + layoutB, + NumericTypeID::kF32, + LayoutTypeID::kRowMajor, + NumericTypeID::kF32, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + 1, + 1, + 1, + 2, + alignment, + alignment, + SplitKMode::kNone}; return (void*)Singleton::get().operation_table.find_op(key); } -bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( - const SizeArgs& args) const { - bool available = - args.opr->param().format == param::MatrixMul::Format::DEFAULT && - args.layout_a.dtype == dtype::Float32() && - args.layout_b.dtype == dtype::Float32() && - args.layout_c.dtype == dtype::Float32(); +bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(const SizeArgs& args) const { + bool available = args.opr->param().format == param::MatrixMul::Format::DEFAULT && + args.layout_a.dtype == dtype::Float32() && + args.layout_b.dtype == dtype::Float32() && + args.layout_c.dtype == dtype::Float32(); int n = args.layout_c.shape[1]; auto&& device_prop = cuda::current_device_prop(); int y_grid_limit = device_prop.maxGridSize[1]; // limit y grid - available &= ((n + m_algo_param.threadblock_n - 1) / - m_algo_param.threadblock_n <= - y_grid_limit); + available &= + ((n + m_algo_param.threadblock_n - 1) / m_algo_param.threadblock_n <= + y_grid_limit); available &= (get_available_op(args) != nullptr); @@ -77,10 +76,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( return 0_z; } -void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec( - const ExecArgs& args) const { - int64_t lda = args.tensor_a.layout.stride[0], - ldb = args.tensor_b.layout.stride[0], +void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec(const ExecArgs& args) const { + int64_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0], ldc = args.tensor_c.layout.stride[0]; auto&& param = args.opr->param(); int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], @@ -98,18 +95,19 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::do_exec( const Operation* op = (const Operation*)get_available_op(args); - GemmArguments gemm_args{problem_size, - args.tensor_a.raw_ptr, - args.tensor_b.raw_ptr, - args.tensor_c.raw_ptr, - args.tensor_c.raw_ptr, - lda, - ldb, - ldc, - ldc, - 1, - &alpha, - &beta}; + GemmArguments gemm_args{ + problem_size, + args.tensor_a.raw_ptr, + args.tensor_b.raw_ptr, + args.tensor_c.raw_ptr, + args.tensor_c.raw_ptr, + lda, + ldb, + ldc, + ldc, + 1, + &alpha, + &beta}; cutlass_check(op->run(&gemm_args, workspace, stream)); } diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp index d0b94e35..daa6067b 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp @@ -30,16 +30,14 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::is_available( 
args.layout_c.dtype == dtype::Float32() && ((!ta) && (!tb)); } -size_t -MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::get_workspace_in_bytes( +size_t MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::get_workspace_in_bytes( const SizeArgs& /* args */) const { return 0; } void MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::exec( const ExecArgs& args) const { - size_t lda = args.tensor_a.layout.stride[0], - ldb = args.tensor_b.layout.stride[0], + size_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0], ldc = args.tensor_c.layout.stride[0]; auto&& param = args.opr->param(); int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], @@ -48,9 +46,8 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::exec( BatchedGemmCoord problem_size{1, n, k, m}; auto&& stream = cuda_stream(args.opr->handle()); return cutlass_matrix_mul_float32_simt_gemv_batched_strided( - args.tensor_a.ptr(), lda, lda, - args.tensor_b.ptr(), ldb, 0, - args.tensor_c.ptr(), ldc, ldc, problem_size, + args.tensor_a.ptr(), lda, lda, args.tensor_b.ptr(), + ldb, 0, args.tensor_c.ptr(), ldc, ldc, problem_size, m_threadblock_n, stream); } #endif diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp index d2128a1a..74cfc465 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp @@ -23,32 +23,33 @@ const void* MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_available_op( const SizeArgs& args) const { using namespace cutlass::library; auto&& param = args.opr->param(); - auto layoutA = param.transposeA ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; - auto layoutB = param.transposeB ? LayoutTypeID::kColumnMajor - : LayoutTypeID::kRowMajor; + auto layoutA = + param.transposeA ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + auto layoutB = + param.transposeB ? LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; int alignment = min_alignment_requirement(); - GemmKey key{NumericTypeID::kF32, - layoutA, - NumericTypeID::kF32, - layoutB, - NumericTypeID::kF32, - LayoutTypeID::kRowMajor, - NumericTypeID::kF32, - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k, - m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k, - 1, - 1, - 1, - 2, - alignment, - alignment, - SplitKMode::kParallel}; + GemmKey key{ + NumericTypeID::kF32, + layoutA, + NumericTypeID::kF32, + layoutB, + NumericTypeID::kF32, + LayoutTypeID::kRowMajor, + NumericTypeID::kF32, + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k, + m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k, + 1, + 1, + 1, + 2, + alignment, + alignment, + SplitKMode::kParallel}; return (void*)Singleton::get().operation_table.find_op(key); } @@ -57,17 +58,16 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( auto&& param = args.opr->param(); int m = args.layout_c.shape[0], n = args.layout_c.shape[1], k = args.layout_a.shape[param.transposeA ? 
0 : 1]; - bool available = - args.opr->param().format == param::MatrixMul::Format::DEFAULT && - args.layout_a.dtype == dtype::Float32() && - args.layout_b.dtype == dtype::Float32() && - args.layout_c.dtype == dtype::Float32() && k > n; + bool available = args.opr->param().format == param::MatrixMul::Format::DEFAULT && + args.layout_a.dtype == dtype::Float32() && + args.layout_b.dtype == dtype::Float32() && + args.layout_c.dtype == dtype::Float32() && k > n; auto&& device_prop = cuda::current_device_prop(); int y_grid_limit = device_prop.maxGridSize[1]; // limit y grid - available &= ((m + m_algo_param.threadblock_m - 1) / - m_algo_param.threadblock_m <= - y_grid_limit); + available &= + ((m + m_algo_param.threadblock_m - 1) / m_algo_param.threadblock_m <= + y_grid_limit); available &= (get_available_op(args) != nullptr); return available; @@ -82,10 +82,8 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( return args.layout_c.dtype.size(m * n * split_k_slices); } -void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec( - const ExecArgs& args) const { - int64_t lda = args.tensor_a.layout.stride[0], - ldb = args.tensor_b.layout.stride[0], +void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec(const ExecArgs& args) const { + int64_t lda = args.tensor_a.layout.stride[0], ldb = args.tensor_b.layout.stride[0], ldc = args.tensor_c.layout.stride[0]; auto&& param = args.opr->param(); int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], @@ -103,18 +101,19 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::do_exec( using namespace cutlass::library; const Operation* op = (const Operation*)get_available_op(args); - GemmArguments gemm_args{problem_size, - args.tensor_a.raw_ptr, - args.tensor_b.raw_ptr, - args.tensor_c.raw_ptr, - args.tensor_c.raw_ptr, - lda, - ldb, - ldc, - ldc, - split_k_slices, - &alpha, - &beta}; + GemmArguments gemm_args{ + problem_size, + args.tensor_a.raw_ptr, + args.tensor_b.raw_ptr, + args.tensor_c.raw_ptr, + args.tensor_c.raw_ptr, + lda, + ldb, + ldc, + ldc, + split_k_slices, + &alpha, + &beta}; cutlass_check(op->run(&gemm_args, workspace, stream)); } diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp index 5169e498..30d11960 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_base.cpp @@ -19,15 +19,15 @@ using namespace megdnn; using namespace cuda; -std::string -MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() const { - return ssprintf("%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, - threadblock_k, warp_m, warp_n, warp_k); +std::string MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::AlgoParam::to_string() + const { + return ssprintf( + "%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n, threadblock_k, warp_m, + warp_n, warp_k); } -std::pair -MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts( - const SizeArgs& args) const { +std::pair MatrixMulForwardImpl::AlgoCutlassMatrixMulBase:: + construct_aligned_layouts(const SizeArgs& args) const { int alignment = max_alignment(args); int min_alignment = min_alignment_requirement(); bool aligned = alignment >= min_alignment; @@ -46,8 +46,7 @@ MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::construct_aligned_layouts( return std::make_pair(!aligned, std::move(layouts)); } -void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec( - const ExecArgs& args) const { +void 
MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec(const ExecArgs& args) const { auto aligned = construct_aligned_layouts(args); if (!aligned.first) return do_exec(args); @@ -65,8 +64,7 @@ void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec( auto&& relayout = args.opr->handle()->create_operator(); - auto copy_stride = [](const TensorLayout& src, TensorLayout& dst, - bool trans) { + auto copy_stride = [](const TensorLayout& src, TensorLayout& dst, bool trans) { dst.stride[0] = src.stride[0], dst.stride[1] = src.stride[1]; if (trans) std::swap(dst.stride[0], dst.stride[1]); @@ -94,8 +92,9 @@ void MatrixMulForwardImpl::AlgoCutlassMatrixMulBase::exec( tensor_a.layout = layouts[0]; tensor_b.layout = layouts[1]; - ExecArgs args_{static_cast(matmul.get()), tensor_a, - tensor_b, tensor_c, workspace}; + ExecArgs args_{ + static_cast(matmul.get()), tensor_a, tensor_b, + tensor_c, workspace}; do_exec(args_); tensor_c.layout.TensorShape::operator=(args.layout_c); diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh index 371c04a9..9e23b45a 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh +++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh @@ -23,17 +23,16 @@ using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord; template void cutlass_vector_matrix_mul_batched_strided_wrapper( - BatchedGemmCoord const& problem_size, - const typename GemvKernel::ElementA* d_A, size_t lda, - size_t batch_stride_a, const typename GemvKernel::ElementB* d_B, + BatchedGemmCoord const& problem_size, const typename GemvKernel::ElementA* d_A, + size_t lda, size_t batch_stride_a, const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, cudaStream_t stream); void cutlass_matrix_mul_float32_simt_gemv_batched_strided( const float* d_A, size_t lda, size_t batch_stride_a, const float* d_B, size_t ldb, size_t batch_stride_b, float* d_C, size_t ldc, - size_t batch_stride_c, BatchedGemmCoord const& problem_size, - int threadblock_n, cudaStream_t stream); + size_t batch_stride_c, BatchedGemmCoord const& problem_size, int threadblock_n, + cudaStream_t stream); } // namespace cutlass_wrapper } // namespace cuda diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu index 9366b25d..bd8c464d 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu +++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu @@ -11,8 +11,7 @@ */ // ignore warning of cutlass #include "cuda.h" -#if __CUDACC_VER_MAJOR__ > 9 || \ - (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wunused-parameter" #pragma GCC diagnostic ignored "-Wstrict-aliasing" @@ -31,46 +30,45 @@ using namespace cutlass_wrapper; /* ============ cutlass kernel wrapper for f32 vector-matrix mul batched strided * =========== */ -#define DISPATCH(cb) \ - cb(128, 4, 4); \ - cb(128, 4, 2); \ - cb(128, 4, 1); \ - cb(128, 2, 4); \ - cb(128, 1, 4); \ - cb(128, 2, 2); \ - cb(128, 1, 2); \ - cb(128, 2, 1); \ - cb(128, 1, 1); \ - cb(64, 4, 4); \ - cb(64, 4, 2); \ - cb(64, 4, 1); \ - cb(64, 2, 4); \ - cb(64, 1, 4); \ - cb(64, 2, 2); \ - cb(64, 1, 2); \ - cb(64, 2, 1); \ - cb(64, 1, 1); \ - 
cb(32, 4, 4); \ - cb(32, 4, 2); \ - cb(32, 4, 1); \ - cb(32, 2, 4); \ - cb(32, 1, 4); \ - cb(32, 2, 2); \ - cb(32, 1, 2); \ - cb(32, 2, 1); \ - cb(32, 1, 1); \ - megdnn_assert(false, \ - "unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d", \ - problem_size.batch(), problem_size.m(), problem_size.k(), \ - problem_size.batch(), problem_size.k(), problem_size.n()); +#define DISPATCH(cb) \ + cb(128, 4, 4); \ + cb(128, 4, 2); \ + cb(128, 4, 1); \ + cb(128, 2, 4); \ + cb(128, 1, 4); \ + cb(128, 2, 2); \ + cb(128, 1, 2); \ + cb(128, 2, 1); \ + cb(128, 1, 1); \ + cb(64, 4, 4); \ + cb(64, 4, 2); \ + cb(64, 4, 1); \ + cb(64, 2, 4); \ + cb(64, 1, 4); \ + cb(64, 2, 2); \ + cb(64, 1, 2); \ + cb(64, 2, 1); \ + cb(64, 1, 1); \ + cb(32, 4, 4); \ + cb(32, 4, 2); \ + cb(32, 4, 1); \ + cb(32, 2, 4); \ + cb(32, 1, 4); \ + cb(32, 2, 2); \ + cb(32, 1, 2); \ + cb(32, 2, 1); \ + cb(32, 1, 1); \ + megdnn_assert( \ + false, "unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d", \ + problem_size.batch(), problem_size.m(), problem_size.k(), \ + problem_size.batch(), problem_size.k(), problem_size.n()); void megdnn::cuda::cutlass_wrapper:: cutlass_matrix_mul_float32_simt_gemv_batched_strided( - const float* d_A, size_t lda, size_t batch_stride_a, - const float* d_B, size_t ldb, size_t batch_stride_b, float* d_C, - size_t ldc, size_t batch_stride_c, - BatchedGemmCoord const& problem_size, int threadblock_n, - cudaStream_t stream) { + const float* d_A, size_t lda, size_t batch_stride_a, const float* d_B, + size_t ldb, size_t batch_stride_b, float* d_C, size_t ldc, + size_t batch_stride_c, BatchedGemmCoord const& problem_size, + int threadblock_n, cudaStream_t stream) { int LDG_K, LDG_N; if (lda % 4 == 0) LDG_K = 4; @@ -85,21 +83,17 @@ void megdnn::cuda::cutlass_wrapper:: LDG_N = 2; else LDG_N = 1; -#define cb(threadblock_n_, LDG_K_, LDG_N_) \ - if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ && \ - LDG_N == LDG_N_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape<1, threadblock_n_, \ - (256 * LDG_K_) / \ - (threadblock_n_ / LDG_N_)>; \ - using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>; \ - using GemvKernel = cutlass::gemm::kernel::DefaultGemv< \ - ThreadBlockShape, ThreadShape, float, \ - cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, \ - float, cutlass::layout::RowMajor>; \ - return cutlass_vector_matrix_mul_batched_strided_wrapper( \ - problem_size, d_A, lda, batch_stride_a, d_B, ldb, \ - batch_stride_b, d_C, ldc, batch_stride_c, stream); \ +#define cb(threadblock_n_, LDG_K_, LDG_N_) \ + if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ && LDG_N == LDG_N_) { \ + using ThreadBlockShape = cutlass::gemm::GemmShape< \ + 1, threadblock_n_, (256 * LDG_K_) / (threadblock_n_ / LDG_N_)>; \ + using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>; \ + using GemvKernel = cutlass::gemm::kernel::DefaultGemv< \ + ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, \ + float, cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>; \ + return cutlass_vector_matrix_mul_batched_strided_wrapper( \ + problem_size, d_A, lda, batch_stride_a, d_B, ldb, batch_stride_b, d_C, \ + ldc, batch_stride_c, stream); \ } DISPATCH(cb) #undef cb diff --git a/dnn/src/cuda/matrix_mul/naive.cpp b/dnn/src/cuda/matrix_mul/naive.cpp index a9145e59..2ab85d17 100644 --- a/dnn/src/cuda/matrix_mul/naive.cpp +++ b/dnn/src/cuda/matrix_mul/naive.cpp @@ -37,33 +37,31 @@ void MatrixMulForwardImpl::AlgoNaive::exec(const ExecArgs& args) const { auto&& param = args.opr->param(); 
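For orientation while reading this hunk and the naive.cu kernel further below: AlgoNaive is the reference path, and the computation it launches is the textbook triple loop over (m, n, k) with explicit leading strides and the same transposeA/transposeB convention used throughout these files. A minimal host-side sketch of that reference semantics (illustrative only; the function name is not part of this patch):

template <typename AType, typename BType, typename CType, typename CompType>
void naive_gemm_ref(
        const AType* A, const BType* B, CType* C, size_t M, size_t N, size_t K,
        size_t LDA, size_t LDB, size_t LDC, bool transA, bool transB) {
    for (size_t m = 0; m < M; ++m) {
        for (size_t n = 0; n < N; ++n) {
            CompType acc = static_cast<CompType>(0);
            for (size_t k = 0; k < K; ++k) {
                // when transposed, A is stored K x M (leading stride LDA) and B is N x K
                AType av = transA ? A[k * LDA + m] : A[m * LDA + k];
                BType bv = transB ? B[n * LDB + k] : B[k * LDB + n];
                acc += static_cast<CompType>(av) * static_cast<CompType>(bv);
            }
            C[m * LDC + n] = static_cast<CType>(acc);
        }
    }
}

The CUDA kernel in naive.cu below parallelizes the same loop nest: blockIdx.x strides over output rows m and threadIdx.x over columns n, with the k reduction kept serial in each thread.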
auto m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; - auto LDA = args.tensor_a.layout.stride[0], - LDB = args.tensor_b.layout.stride[0], + auto LDA = args.tensor_a.layout.stride[0], LDB = args.tensor_b.layout.stride[0], LDC = args.tensor_c.layout.stride[0]; auto&& handle = concrete_handle(args.opr->handle()); using ComputeMode = Param::ComputeMode; -#define DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, cmode) \ - MIDOUT_BEGIN(megdnn_naive_matmul, midout_iv(#in_dt #out_dt #in_ct, \ - #out_ct, #comp_ct, #cmode)) { \ - do { \ - using namespace dtype; \ - if (args.tensor_a.layout.dtype.enumv() == \ - DTypeTrait::enumv && \ - args.tensor_c.layout.dtype.enumv() == \ - DTypeTrait::enumv && \ - param.compute_mode == cmode) { \ - in_ct* A = args.tensor_a.compatible_ptr(); \ - in_ct* B = args.tensor_b.compatible_ptr(); \ - out_ct* C = args.tensor_c.compatible_ptr(); \ - exec_gemm_naive( \ - A, B, C, m, n, k, LDA, LDB, LDC, param.transposeA, \ - param.transposeB, cuda_stream(handle)); \ - return; \ - } \ - } while (0); \ - } \ +#define DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, cmode) \ + MIDOUT_BEGIN( \ + megdnn_naive_matmul, \ + midout_iv(#in_dt #out_dt #in_ct, #out_ct, #comp_ct, #cmode)) { \ + do { \ + using namespace dtype; \ + if (args.tensor_a.layout.dtype.enumv() == DTypeTrait::enumv && \ + args.tensor_c.layout.dtype.enumv() == DTypeTrait::enumv && \ + param.compute_mode == cmode) { \ + in_ct* A = args.tensor_a.compatible_ptr(); \ + in_ct* B = args.tensor_b.compatible_ptr(); \ + out_ct* C = args.tensor_c.compatible_ptr(); \ + exec_gemm_naive( \ + A, B, C, m, n, k, LDA, LDB, LDC, param.transposeA, \ + param.transposeB, cuda_stream(handle)); \ + return; \ + } \ + } while (0); \ + } \ MIDOUT_END(); #define DISPATCH(in_dt, out_dt, in_ct, out_ct, comp_ct) \ DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, ComputeMode::DEFAULT) @@ -72,8 +70,9 @@ void MatrixMulForwardImpl::AlgoNaive::exec(const ExecArgs& args) const { DISPATCH(Float16, Float16, dt_float16, dt_float16, dt_float16); DISPATCH(Int8, Int32, dt_int8, dt_int32, dt_int32); DISPATCH(QuantizedS8, QuantizedS32, dt_int8, dt_int32, dt_int32); - DNN_INC_FLOAT16(DISPATCH_CMODE(Float16, Float16, dt_float16, dt_float16, - dt_float32, ComputeMode::FLOAT32)); + DNN_INC_FLOAT16(DISPATCH_CMODE( + Float16, Float16, dt_float16, dt_float16, dt_float32, + ComputeMode::FLOAT32)); #undef DISPATCH_CMODE #undef DISPATCH megdnn_throw(ssprintf( diff --git a/dnn/src/cuda/matrix_mul/naive.cu b/dnn/src/cuda/matrix_mul/naive.cu index 6a892a6a..c7a81a27 100644 --- a/dnn/src/cuda/matrix_mul/naive.cu +++ b/dnn/src/cuda/matrix_mul/naive.cu @@ -16,9 +16,9 @@ namespace { template -__global__ void do_exec(const AType* A, const BType* B, CType* C, size_t M, - size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC, - bool transA, bool transB) { +__global__ void do_exec( + const AType* A, const BType* B, CType* C, size_t M, size_t N, size_t K, + size_t LDA, size_t LDB, size_t LDC, bool transA, bool transB) { size_t m = blockIdx.x; for (; m < M; m += gridDim.x) { size_t n = threadIdx.x; @@ -26,7 +26,7 @@ __global__ void do_exec(const AType* A, const BType* B, CType* C, size_t M, CompType res = static_cast(0); for (size_t k = 0; k < K; ++k) { AType av = transA ? A[k * LDA + m] : A[m * LDA + k], - bv = transB ? B[n * LDB + k] : B[k * LDB + n]; + bv = transB ? 
B[n * LDB + k] : B[k * LDB + n]; res += av * bv; } C[m * LDC + n] = res; @@ -39,19 +39,20 @@ namespace megdnn { namespace cuda { template -void exec_gemm_naive(const AType* A, const BType* B, CType* C, size_t M, - size_t N, size_t K, size_t LDA, size_t LDB, size_t LDC, - bool transA, bool transB, cudaStream_t stream) { - do_exec<<<128, 128, 0, stream>>>( - A, B, C, M, N, K, LDA, LDB, LDC, transA, transB); +void exec_gemm_naive( + const AType* A, const BType* B, CType* C, size_t M, size_t N, size_t K, + size_t LDA, size_t LDB, size_t LDC, bool transA, bool transB, + cudaStream_t stream) { + do_exec + <<<128, 128, 0, stream>>>(A, B, C, M, N, K, LDA, LDB, LDC, transA, transB); } -#define INST(in_ct, out_ct, comp_ct) \ - template void exec_gemm_naive( \ - const in_ct* A, const in_ct* B, out_ct* C, size_t M, size_t N, \ - size_t K, size_t LDA, size_t LDB, size_t LDC, bool transA, \ - bool transB, cudaStream_t stream); +#define INST(in_ct, out_ct, comp_ct) \ + template void exec_gemm_naive< \ + typename in_ct, typename in_ct, typename out_ct, typename comp_ct>( \ + const in_ct* A, const in_ct* B, out_ct* C, size_t M, size_t N, size_t K, \ + size_t LDA, size_t LDB, size_t LDC, bool transA, bool transB, \ + cudaStream_t stream); INST(megdnn::dt_float32, megdnn::dt_float32, megdnn::dt_float32) INST(megdnn::dt_float16, megdnn::dt_float16, megdnn::dt_float16) diff --git a/dnn/src/cuda/matrix_mul/naive.cuh b/dnn/src/cuda/matrix_mul/naive.cuh index 615d6bef..397ad366 100644 --- a/dnn/src/cuda/matrix_mul/naive.cuh +++ b/dnn/src/cuda/matrix_mul/naive.cuh @@ -16,10 +16,10 @@ namespace megdnn { namespace cuda { template -void exec_gemm_naive(const AType* A, const BType* B, CType* C, size_t m, - size_t n, size_t k, size_t ldA, size_t ldB, - size_t ldC, bool transA, bool transB, - cudaStream_t stream); +void exec_gemm_naive( + const AType* A, const BType* B, CType* C, size_t m, size_t n, size_t k, + size_t ldA, size_t ldB, size_t ldC, bool transA, bool transB, + cudaStream_t stream); } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/matrix_mul/opr_impl.cpp b/dnn/src/cuda/matrix_mul/opr_impl.cpp index b47b31b5..be912202 100644 --- a/dnn/src/cuda/matrix_mul/opr_impl.cpp +++ b/dnn/src/cuda/matrix_mul/opr_impl.cpp @@ -14,24 +14,21 @@ #include #include "src/cuda/handle.h" -#include "src/cuda/utils.h" #include "src/cuda/matrix_mul/cublasLt_wrapper.h" +#include "src/cuda/utils.h" namespace megdnn { namespace cuda { -std::vector -MatrixMulForwardImpl::get_all_algorithms(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) { +std::vector MatrixMulForwardImpl::get_all_algorithms( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { AlgoBase::SizeArgs args{this, A, B, C}; return megdnn::get_all_algorithms(args); } -std::vector -MatrixMulForwardImpl::get_all_algorithms_safe(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) { +std::vector MatrixMulForwardImpl:: + get_all_algorithms_safe( + const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { AlgoBase::SizeArgs args{this, A, B, C}; return megdnn::get_all_algorithms_safe(args); } @@ -64,15 +61,14 @@ MatrixMulForwardImpl::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( "matrix mul forward", positive_attr, negative_attr); } -size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) { +size_t MatrixMulForwardImpl::get_workspace_in_bytes( + const TensorLayout& A, const TensorLayout& B, const 
TensorLayout& C) { return get_dnn_workspace(this, A, B, C); } -void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace) { +void MatrixMulForwardImpl::exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) { check_exec(A.layout, B.layout, C.layout, workspace.size); AlgoBase::ExecArgs args(this, A, B, C, workspace); auto&& algo = get_algorithm(this, A.layout, B.layout, C.layout); diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h index 09ccf35c..69999fe5 100644 --- a/dnn/src/cuda/matrix_mul/opr_impl.h +++ b/dnn/src/cuda/matrix_mul/opr_impl.h @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once -#include "megdnn/oprs.h" #include +#include "megdnn/oprs.h" namespace megdnn { namespace cuda { @@ -18,16 +18,15 @@ namespace cuda { class MatrixMulForwardImpl : public MatrixMulForward { public: using MatrixMulForward::MatrixMulForward; - void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&) override; + void exec( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&) override; bool is_thread_safe() const override { return true; } - const char* get_algorithm_set_name() const override { - return "CUDA MATMUL"; - } + const char* get_algorithm_set_name() const override { return "CUDA MATMUL"; } class AlgoBase; class AlgoCuBlas; @@ -52,19 +51,17 @@ public: #endif class AlgoPack; - static const AlgoPack& algo_pack() { - return sm_algo_pack; - } + static const AlgoPack& algo_pack() { return sm_algo_pack; } Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: - std::vector get_all_algorithms(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) override; + std::vector get_all_algorithms( + const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) override; - std::vector get_all_algorithms_safe(const TensorLayout& A, - const TensorLayout& B, - const TensorLayout& C) override; + std::vector get_all_algorithms_safe( + const TensorLayout& A, const TensorLayout& B, + const TensorLayout& C) override; Algorithm* get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, diff --git a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma.cpp b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma.cpp index 380940d2..8f4325ff 100644 --- a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma.cpp +++ b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma.cpp @@ -11,9 +11,9 @@ #include "./algos.h" -#include "src/cuda/utils.h" #include "src/cuda/handle.h" #include "src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.h" +#include "src/cuda/utils.h" using namespace megdnn; using namespace cuda; @@ -25,15 +25,13 @@ bool MatrixMulForwardImpl::AlgoUInt4x4x32WMMA::is_available( if (args.opr->param().format != param::MatrixMul::Format::DEFAULT) return false; auto&& device_prop = current_device_prop(); - if (device_prop.major < 7 || - (device_prop.major == 7 && device_prop.minor < 5)) { + if (device_prop.major < 7 || (device_prop.major == 7 && device_prop.minor < 5)) { return false; } auto&& param = 
args.opr->param(); if (!param.transposeA && param.transposeB) { - bool available = - args.layout_a.dtype.enumv() == DTypeEnum::Quantized4Asymm && - args.layout_c.dtype.enumv() == DTypeEnum::QuantizedS32; + bool available = args.layout_a.dtype.enumv() == DTypeEnum::Quantized4Asymm && + args.layout_c.dtype.enumv() == DTypeEnum::QuantizedS32; size_t m = args.layout_c.shape[0], n = args.layout_c.shape[1]; available &= (m % 8 == 0) && (n % 8 == 0); available &= (args.layout_a.stride[0] % 2 == 0) && @@ -53,9 +51,9 @@ void MatrixMulForwardImpl::AlgoUInt4x4x32WMMA::exec(const ExecArgs& args) const auto&& handle = concrete_handle(args.opr->handle()); auto&& param = args.opr->param(); if (!param.transposeA && param.transposeB) { - exec_wmma_matrix_mul_quint4_nt(args.tensor_a, args.tensor_b, - args.tensor_c, args.workspace, - handle->stream()); + exec_wmma_matrix_mul_quint4_nt( + args.tensor_a, args.tensor_b, args.tensor_c, args.workspace, + handle->stream()); } } #endif diff --git a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cu b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cu index af7d38a6..84915c5e 100644 --- a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cu +++ b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cu @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -46,12 +47,9 @@ namespace { using namespace megdnn::cuda; template -__global__ void reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, - int rows, int cols_int32, - int ld_in_bytes, - int nr_thread_per_row_log2, - int sm_width_in_bytes, - int32_t* dst) { +__global__ void reduce_column_with_scale_u4( + const uint8_t* src, int32_t scale, int rows, int cols_int32, int ld_in_bytes, + int nr_thread_per_row_log2, int sm_width_in_bytes, int32_t* dst) { constexpr int warp_size = 32; extern __shared__ uint8_t sub_block_raw[]; @@ -63,8 +61,7 @@ __global__ void reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, if (row_idx >= rows) return; - volatile int32_t* row = - (int32_t*)(sub_block_raw + row_num * sm_width_in_bytes); + volatile int32_t* row = (int32_t*)(sub_block_raw + row_num * sm_width_in_bytes); const int32_t* sptr = (const int32_t*)(src + row_idx * ld_in_bytes); sptr += tid; int32_t local = 0; @@ -98,9 +95,9 @@ __global__ void reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, } template -__global__ void span_qsum(const int32_t* qSumA, const uint32_t M, - const int32_t* qSumB, const uint32_t N, int32_t* dst, - const uint32_t strd, const int32_t scaler_bias) { +__global__ void span_qsum( + const int32_t* qSumA, const uint32_t M, const int32_t* qSumB, const uint32_t N, + int32_t* dst, const uint32_t strd, const int32_t scaler_bias) { constexpr size_t mm = (BY + TY - 1) / TY; constexpr size_t nn = (BX + TX - 1) / TX; @@ -111,18 +108,16 @@ __global__ void span_qsum(const int32_t* qSumA, const uint32_t M, int gtidx = threadIdx.x + TX * j + blockIdx.x * BX; int gtidy = threadIdx.y + TY * i + blockIdx.y * BY; if (gtidx < N && gtidy < M) { - dst[gtidy * strd + gtidx] += - qSumA[gtidy] + qSumB[gtidx] + scaler_bias; + dst[gtidy * strd + gtidx] += qSumA[gtidy] + qSumB[gtidx] + scaler_bias; } } } } template -void _do_dispatch_reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, - int rows, int cols_int32, - int ld_in_bytes, int32_t* dst, - cudaStream_t stream) { +void _do_dispatch_reduce_column_with_scale_u4( + const uint8_t* src, int32_t scale, int rows, int cols_int32, int ld_in_bytes, + 
int32_t* dst, cudaStream_t stream) { constexpr int warp_size = 32; int block_size = 1 << block_size_log2; int nr_thread_per_row = 1, nr_thread_per_row_log2 = 0; @@ -152,8 +147,7 @@ void _do_dispatch_reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, } int nr_row_per_block = block_size / nr_thread_per_row, - nr_blk = DIVUP(rows, nr_row_per_block), - sm_width_word32 = nr_thread_per_row; + nr_blk = DIVUP(rows, nr_row_per_block), sm_width_word32 = nr_thread_per_row; // gcd(sm_width_word32, BANKS) should be 1 to avoid bank confliction // iff sm_width_word32 is odd @@ -161,18 +155,16 @@ void _do_dispatch_reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, int sm_width_in_bytes = sm_width_word32 * 4, sm_size = nr_row_per_block * sm_width_in_bytes; - void (*kptr)(const uint8_t* src, int32_t scale, int rows, int cols_int32, - int ld_in_bytes, int nr_thread_per_row_log2, - int sm_width_in_bytes, int32_t* dst); + void (*kptr)( + const uint8_t* src, int32_t scale, int rows, int cols_int32, + int ld_in_bytes, int nr_thread_per_row_log2, int sm_width_in_bytes, + int32_t* dst); if (nr_thread_per_row <= max_nr_threads_per_row / 4) { - kptr = reduce_column_with_scale_u4; + kptr = reduce_column_with_scale_u4; } else if (nr_thread_per_row <= max_nr_threads_per_row / 2) { - kptr = reduce_column_with_scale_u4; + kptr = reduce_column_with_scale_u4; } else { - kptr = reduce_column_with_scale_u4; + kptr = reduce_column_with_scale_u4; } kptr<<>>( src, scale, rows, cols_int32, ld_in_bytes, nr_thread_per_row_log2, @@ -183,17 +175,16 @@ void _do_dispatch_reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, } // namespace void megdnn::cuda::exec_reduce_sum_with_scale_uint4( - const uint8_t* A, int32_t scale, uint32_t M, uint32_t K, - uint32_t ldA_in_byte, int32_t* dst, cudaStream_t stream) { - _do_dispatch_reduce_column_with_scale_u4<7, 64>(A, scale, M, K / 8, - ldA_in_byte, dst, stream); + const uint8_t* A, int32_t scale, uint32_t M, uint32_t K, uint32_t ldA_in_byte, + int32_t* dst, cudaStream_t stream) { + _do_dispatch_reduce_column_with_scale_u4<7, 64>( + A, scale, M, K / 8, ldA_in_byte, dst, stream); } -void megdnn::cuda::exec_span_qsum(const int32_t* qSumA, const uint32_t M, - const int32_t* qSumB, const uint32_t N, - int32_t* dst, const uint32_t strd, - const int32_t scaler_bias, - cudaStream_t stream) { +void megdnn::cuda::exec_span_qsum( + const int32_t* qSumA, const uint32_t M, const int32_t* qSumB, const uint32_t N, + int32_t* dst, const uint32_t strd, const int32_t scaler_bias, + cudaStream_t stream) { constexpr uint32_t TX = 32, TY = 32, BX = 32, BY = 32; dim3 nthreads{TX, TY}; dim3 nblocks{DIVUP(N, BX), DIVUP(M, BY)}; diff --git a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cuh b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cuh index df98b444..05d88dd3 100644 --- a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cuh +++ b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/preprocess_quantize_sum.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. 
- * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
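The dispatch code above pads the per-row shared-memory scratch to an odd number of 32-bit words; with the usual shared-memory layout of 32 four-byte banks, an odd word stride is coprime with 32, so lanes of one warp that land on the same column of different rows hit distinct banks. The following standalone host sketch only demonstrates that arithmetic (it models a bank as word_index % 32 and is not code from this patch):

    // Why an odd 32-bit row stride avoids bank conflicts, assuming 32 banks of 4 bytes.
    #include <cstdio>

    int bank_of(int word_index) { return word_index % 32; }

    int main() {
        const int strides[] = {32, 33};  // even (conflicting) vs. odd (conflict-free)
        for (int stride : strides) {
            bool conflict = false;
            // lanes a = 0..31 each touch word 0 of "row" a, i.e. word a * stride
            for (int a = 0; a < 32 && !conflict; ++a)
                for (int b = a + 1; b < 32; ++b)
                    if (bank_of(a * stride) == bank_of(b * stride)) { conflict = true; break; }
            std::printf("stride %d words: %s\n", stride,
                        conflict ? "bank conflicts" : "conflict-free");
        }
    }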
* **************************************************************************************************/ /** @@ -38,16 +39,15 @@ namespace megdnn { namespace cuda { -void exec_reduce_sum_with_scale_uint4(const uint8_t* A, int32_t scale, - uint32_t M, uint32_t K, - uint32_t ldA_in_byte, int32_t* dst, - cudaStream_t stream); +void exec_reduce_sum_with_scale_uint4( + const uint8_t* A, int32_t scale, uint32_t M, uint32_t K, uint32_t ldA_in_byte, + int32_t* dst, cudaStream_t stream); -void exec_span_qsum(const int32_t* qSumA, const uint32_t M, - const int32_t* qSumB, const uint32_t N, int32_t* dst, - const uint32_t strd, const int32_t scaler_bias, - cudaStream_t stream); -} // namespace cuda -} // namespace megdnn +void exec_span_qsum( + const int32_t* qSumA, const uint32_t M, const int32_t* qSumB, const uint32_t N, + int32_t* dst, const uint32_t strd, const int32_t scaler_bias, + cudaStream_t stream); +} // namespace cuda +} // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.cpp b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.cpp index 7a4eb5d3..5ed4ed1b 100644 --- a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.cpp +++ b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.cpp @@ -28,17 +28,18 @@ void megdnn::cuda::matrix_mul::exec_wmma_matrix_mul_quint4_nt( ldC = C.layout.stride[0]; int32_t zA = A.layout.dtype.param().zero_point, zB = B.layout.dtype.param().zero_point; - exec_reduce_sum_with_scale_uint4(static_cast(A.raw_ptr), -zB, M, - K, ldA / 2, workspace.ptr(), - stream); - exec_reduce_sum_with_scale_uint4(static_cast(B.raw_ptr), -zA, N, - K, ldB / 2, workspace.ptr() + M, - stream); + exec_reduce_sum_with_scale_uint4( + static_cast(A.raw_ptr), -zB, M, K, ldA / 2, + workspace.ptr(), stream); + exec_reduce_sum_with_scale_uint4( + static_cast(B.raw_ptr), -zA, N, K, ldB / 2, + workspace.ptr() + M, stream); exec_wmma_gemm_u4( static_cast(A.raw_ptr), static_cast(B.raw_ptr), C.compatible_ptr(), M, N, K, ldA, ldB, ldC, stream); - exec_span_qsum(workspace.ptr(), M, workspace.ptr() + M, N, - C.compatible_ptr(), ldC, K * zA * zB, stream); + exec_span_qsum( + workspace.ptr(), M, workspace.ptr() + M, N, + C.compatible_ptr(), ldC, K * zA * zB, stream); } #endif // CUDA_VERSION diff --git a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.h b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.h index 573f4aac..8e1ec520 100644 --- a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.h +++ b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul.h @@ -15,10 +15,9 @@ namespace megdnn { namespace cuda { namespace matrix_mul { -void exec_wmma_matrix_mul_quint4_nt(_megdnn_tensor_in A, _megdnn_tensor_in B, - _megdnn_tensor_out C, - _megdnn_workspace workspace, - cudaStream_t stream); +void exec_wmma_matrix_mul_quint4_nt( + _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, + _megdnn_workspace workspace, cudaStream_t stream); } // namespace matrix_mul } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cu b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cu index fc3bb2a2..8ecd6b42 100644 --- a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cu +++ b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cu @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. 
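The exec_wmma_matrix_mul_quint4_nt hunk above is the standard zero-point expansion of an asymmetric uint4 product: sum_k (A[m][k] - zA)(B[n][k] - zB) = sum_k A[m][k]B[n][k] - zB * sum_k A[m][k] - zA * sum_k B[n][k] + K * zA * zB. The raw product comes from exec_wmma_gemm_u4, the two row-sum terms are produced by exec_reduce_sum_with_scale_uint4 with scales -zB and -zA, and exec_span_qsum folds them in together with the K * zA * zB bias. A scalar reference of the same math (a sketch using unpacked 8-bit storage rather than the packed 4-bit layout the kernels operate on):

    // Scalar reference for the quantized uint4 NT matmul dispatched above:
    // C[m][n] = sum_k (A[m][k] - zA) * (B[n][k] - zB)
    //         = sum_k A[m][k]*B[n][k]      (raw u4 GEMM)
    //           - zB * sum_k A[m][k]       (row sums of A, scale = -zB)
    //           - zA * sum_k B[n][k]       (row sums of B, scale = -zA)
    //           + K * zA * zB              (scaler_bias passed to exec_span_qsum)
    #include <cstdint>
    #include <vector>

    void quint4_nt_reference(const std::vector<uint8_t>& A,  // M x K, values 0..15
                             const std::vector<uint8_t>& B,  // N x K, values 0..15
                             std::vector<int32_t>& C,        // M x N
                             int M, int N, int K, int32_t zA, int32_t zB) {
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                int32_t acc = 0;
                for (int k = 0; k < K; ++k)
                    acc += (int32_t(A[m * K + k]) - zA) * (int32_t(B[n * K + k]) - zB);
                C[m * N + n] = acc;
            }
    }

Keeping the correction outside the WMMA kernel lets the inner product run on unsigned 4-bit fragments while the zero points are handled by two cheap row reductions and one elementwise pass.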
* - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ /** @@ -48,8 +49,7 @@ namespace wmma_matrix_mul_u4 { constexpr uint32_t WMMA_M = 8, WMMA_N = 8, WMMA_K = 32, WARP_SIZE = 32; -template +template struct BlockConfig { static const size_t WARP_X = WARP_X_; static const size_t WARP_Y = WARP_Y_; @@ -64,22 +64,22 @@ struct BlockConfig { template struct GlobalToShareMemStreamConfig { static const size_t BlockSize = BlockSize_; - static const size_t CACHE_SIZE = - (BlockSize + BlockConfig_::WARPS_PER_BLOCK - 1) / - BlockConfig_::WARPS_PER_BLOCK; + static const size_t CACHE_SIZE = (BlockSize + BlockConfig_::WARPS_PER_BLOCK - 1) / + BlockConfig_::WARPS_PER_BLOCK; static const size_t SMEM_ROW = BlockSize; static const size_t SMEM_COL = BlockConfig_::BK; - static const size_t SMEM_SKEW = - WMMA_K * ((BlockConfig_::BK / WMMA_K) % 2 == 0); + static const size_t SMEM_SKEW = WMMA_K * ((BlockConfig_::BK / WMMA_K) % 2 == 0); static const size_t SMEM_STRIDE = SMEM_COL + SMEM_SKEW; }; -#if __CUDA_ARCH__ >= 730 +#if __CUDA_ARCH__ >= 730 template struct GlobalToShareMemStream { - MEGDNN_STATIC_ASSERT(GlobalToShareMemStreamConfig_::BlockSize == - GlobalToShareMemStreamConfig_::CACHE_SIZE * BlockConfig_::WARPS_PER_BLOCK, - "Block size mismatch"); + MEGDNN_STATIC_ASSERT( + GlobalToShareMemStreamConfig_::BlockSize == + GlobalToShareMemStreamConfig_::CACHE_SIZE * + BlockConfig_::WARPS_PER_BLOCK, + "Block size mismatch"); uint8_t* smem; const uint8_t* g_ptr; @@ -96,10 +96,10 @@ struct GlobalToShareMemStream { typedef int32_t copy_t; copy_t reg_cache[GlobalToShareMemStreamConfig_::CACHE_SIZE]; - __device__ GlobalToShareMemStream(uint8_t* smem, const uint8_t* g_ptr, - int ld, int row_remain, int K) + __device__ GlobalToShareMemStream( + uint8_t* smem, const uint8_t* g_ptr, int ld, int row_remain, int K) : smem{smem}, g_ptr{g_ptr}, ld{ld}, row_remain{row_remain}, K{K} { - k_base = 0; + k_base = 0; } __device__ __forceinline__ void copy() { @@ -137,11 +137,9 @@ struct GlobalToShareMemStream { template __device__ inline void load_share_mem( - wmma::fragment + wmma::fragment a_frag[BlockConfig_::ROW_PER_WARP], - wmma::fragment + wmma::fragment b_frag[BlockConfig_::COL_PER_WARP], GlobalToShareMemStream< BlockConfig_, @@ -152,10 +150,8 @@ __device__ inline void load_share_mem( GlobalToShareMemStreamConfig>& gbl2smem_b, int warp_k) { - typedef GlobalToShareMemStreamConfig - Config_A; - typedef GlobalToShareMemStreamConfig - Config_B; + typedef GlobalToShareMemStreamConfig Config_A; + typedef GlobalToShareMemStreamConfig Config_B; const int warp_x = threadIdx.x / WARP_SIZE; const int warp_y = threadIdx.y; uint8_t* __restrict__ s_ptr_a = @@ -180,19 +176,18 @@ __device__ inline void load_share_mem( } template -__device__ inline void -calc(wmma::fragment - a_frag[ROW_PER_WARP], - wmma::fragment - b_frag[COL_PER_WARP], - wmma::fragment - acc_frag[ROW_PER_WARP][COL_PER_WARP]) { +__device__ inline void calc( + wmma::fragment + a_frag[ROW_PER_WARP], + wmma::fragment + b_frag[COL_PER_WARP], + wmma::fragment + acc_frag[ROW_PER_WARP][COL_PER_WARP]) { #pragma unroll for (int i = 0; i < ROW_PER_WARP; ++i) { #pragma unroll for (int j = 0; j < COL_PER_WARP; ++j) { - wmma::mma_sync(acc_frag[i][j], a_frag[i], b_frag[j], - acc_frag[i][j]); + wmma::mma_sync(acc_frag[i][j], a_frag[i], b_frag[j], acc_frag[i][j]); } } } @@ -207,15 +202,12 @@ __device__ void inline consume_tile( BlockConfig_, GlobalToShareMemStreamConfig>& gbl2smem_b, - wmma::fragment + wmma::fragment 
a_frag[2][BlockConfig_::ROW_PER_WARP], - wmma::fragment + wmma::fragment b_frag[2][BlockConfig_::COL_PER_WARP], wmma::fragment - acc_frag[BlockConfig_::ROW_PER_WARP] - [BlockConfig_::COL_PER_WARP]) { + acc_frag[BlockConfig_::ROW_PER_WARP][BlockConfig_::COL_PER_WARP]) { if (!last_block) { gbl2smem_a.inc_stage(); gbl2smem_b.inc_stage(); @@ -225,9 +217,9 @@ __device__ void inline consume_tile( int warp_k = 0; #pragma unroll for (warp_k = 0; warp_k < BlockConfig_::BK / WMMA_K - 1; ++warp_k) { - load_share_mem(a_frag[(warp_k + 1) % 2], - b_frag[(warp_k + 1) % 2], gbl2smem_a, - gbl2smem_b, warp_k + 1); + load_share_mem( + a_frag[(warp_k + 1) % 2], b_frag[(warp_k + 1) % 2], gbl2smem_a, + gbl2smem_b, warp_k + 1); calc( a_frag[warp_k % 2], b_frag[warp_k % 2], acc_frag); } @@ -238,19 +230,16 @@ __device__ void inline consume_tile( gbl2smem_a.commit(); gbl2smem_b.commit(); __syncthreads(); - load_share_mem(a_frag[0], b_frag[0], gbl2smem_a, - gbl2smem_b, 0); + load_share_mem(a_frag[0], b_frag[0], gbl2smem_a, gbl2smem_b, 0); } } template -__global__ void u4_gemm_template_device_nt(const uint8_t* A, const uint8_t* B, - int32_t* C, int M, int N, int K, - int lda, int ldb, int ldc) { - typedef GlobalToShareMemStreamConfig - Config_A; - typedef GlobalToShareMemStreamConfig - Config_B; +__global__ void u4_gemm_template_device_nt( + const uint8_t* A, const uint8_t* B, int32_t* C, int M, int N, int K, int lda, + int ldb, int ldc) { + typedef GlobalToShareMemStreamConfig Config_A; + typedef GlobalToShareMemStreamConfig Config_B; __shared__ uint8_t smem_a[BlockConfig_::BM][Config_A::SMEM_STRIDE / 2]; __shared__ uint8_t smem_b[BlockConfig_::BN][Config_B::SMEM_STRIDE / 2]; @@ -297,22 +286,20 @@ __global__ void u4_gemm_template_device_nt(const uint8_t* A, const uint8_t* B, const int BLK_K = (K + BlockConfig_::BK - 1) / BlockConfig_::BK; #pragma unroll 1 for (int blk_k = 0; blk_k < BLK_K - 1; ++blk_k) { - consume_tile(gbl2smem_a, gbl2smem_b, a_frag, - b_frag, acc_frag); + consume_tile( + gbl2smem_a, gbl2smem_b, a_frag, b_frag, acc_frag); } - consume_tile(gbl2smem_a, gbl2smem_b, a_frag, b_frag, - acc_frag); + consume_tile(gbl2smem_a, gbl2smem_b, a_frag, b_frag, acc_frag); #pragma unroll for (int i = 0; i < BlockConfig_::ROW_PER_WARP; ++i) { #pragma unroll for (int j = 0; j < BlockConfig_::COL_PER_WARP; ++j) { - if (warp_row_start + i * BlockConfig_::WARP_Y * WMMA_M <= - M - WMMA_M && - warp_col_start + j * BlockConfig_::WARP_X * WMMA_N <= - N - WMMA_N) { + if (warp_row_start + i * BlockConfig_::WARP_Y * WMMA_M <= M - WMMA_M && + warp_col_start + j * BlockConfig_::WARP_X * WMMA_N <= N - WMMA_N) { wmma::store_matrix_sync( - &g_ptr_c[(i * BlockConfig_::WARP_Y * WMMA_M) * ldc + + &g_ptr_c + [(i * BlockConfig_::WARP_Y * WMMA_M) * ldc + (j * BlockConfig_::WARP_X * WMMA_N)], acc_frag[i][j], ldc, wmma::mem_row_major); } @@ -321,23 +308,20 @@ __global__ void u4_gemm_template_device_nt(const uint8_t* A, const uint8_t* B, } #else template -__global__ void u4_gemm_template_device_nt(const uint8_t* /*A*/, - const uint8_t* /*B*/, int32_t* /*C*/, - int /*M*/, int /*N*/, int /*K*/, - int /*lda*/, int /*ldb*/, - int /*ldc*/) {} +__global__ void u4_gemm_template_device_nt( + const uint8_t* /*A*/, const uint8_t* /*B*/, int32_t* /*C*/, int /*M*/, + int /*N*/, int /*K*/, int /*lda*/, int /*ldb*/, int /*ldc*/) {} #endif -void _do_dispatch_wmma_matrix_mul_u4(const uint8_t* A, const uint8_t* B, - int32_t* C, int M, int N, int K, int lda, - int ldb, int ldc, cudaStream_t stream) { - constexpr uint32_t warp_x = 4, warp_y = 4, row_per_warp = 
4, - col_per_warp = 4; - typedef BlockConfig - BlockConfig_; +void _do_dispatch_wmma_matrix_mul_u4( + const uint8_t* A, const uint8_t* B, int32_t* C, int M, int N, int K, int lda, + int ldb, int ldc, cudaStream_t stream) { + constexpr uint32_t warp_x = 4, warp_y = 4, row_per_warp = 4, col_per_warp = 4; + typedef BlockConfig BlockConfig_; dim3 block{warp_x * WARP_SIZE, warp_y}; - dim3 grid{static_cast(DIVUP(N, BlockConfig_::BN)), - static_cast(DIVUP(M, BlockConfig_::BM))}; + dim3 grid{ + static_cast(DIVUP(N, BlockConfig_::BN)), + static_cast(DIVUP(M, BlockConfig_::BM))}; u4_gemm_template_device_nt <<>>(A, B, C, M, N, K, lda, ldb, ldc); after_kernel_launch(); @@ -346,11 +330,11 @@ void _do_dispatch_wmma_matrix_mul_u4(const uint8_t* A, const uint8_t* B, namespace megdnn { namespace cuda { -void exec_wmma_gemm_u4(const uint8_t* A, const uint8_t* B, int32_t* C, int M, - int N, int K, int lda, int ldb, int ldc, - cudaStream_t stream) { - wmma_matrix_mul_u4::_do_dispatch_wmma_matrix_mul_u4(A, B, C, M, N, K, lda, - ldb, ldc, stream); +void exec_wmma_gemm_u4( + const uint8_t* A, const uint8_t* B, int32_t* C, int M, int N, int K, int lda, + int ldb, int ldc, cudaStream_t stream) { + wmma_matrix_mul_u4::_do_dispatch_wmma_matrix_mul_u4( + A, B, C, M, N, K, lda, ldb, ldc, stream); } } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cuh b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cuh index 6ae7c778..931390c4 100644 --- a/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cuh +++ b/dnn/src/cuda/matrix_mul/uint4x4x32_wmma/wmma_matrix_mul_u4.cuh @@ -1,25 +1,26 @@ /*************************************************************************************************** * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved. * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without modification, are + *permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this + *list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this + *list of conditions and the following disclaimer in the documentation and/or other + *materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors + *may be used to endorse or promote products derived from this software without specific + *prior written permission. * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
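For the dispatch above, warp_x = warp_y = 4 gives a 128 x 4 thread block (512 threads, 16 warps), and with row_per_warp = col_per_warp = 4 each block covers a BM x BN = 128 x 128 tile of C, provided BlockConfig defines BM = WMMA_M * ROW_PER_WARP * WARP_Y and BN = WMMA_N * COL_PER_WARP * WARP_X; the BlockConfig body is not part of this hunk, so treat those formulas as an assumption. A small host-side sketch of the same launch arithmetic (the M x N problem size is made up):

    // Launch-configuration arithmetic for _do_dispatch_wmma_matrix_mul_u4,
    // assuming BM = WMMA_M*ROW_PER_WARP*WARP_Y and BN = WMMA_N*COL_PER_WARP*WARP_X.
    #include <cstdio>

    int main() {
        const unsigned WMMA_M = 8, WMMA_N = 8, WARP_SIZE = 32;
        const unsigned warp_x = 4, warp_y = 4, row_per_warp = 4, col_per_warp = 4;
        const unsigned BM = WMMA_M * row_per_warp * warp_y;  // 128 rows of C per block
        const unsigned BN = WMMA_N * col_per_warp * warp_x;  // 128 cols of C per block
        const unsigned M = 1024, N = 768;                    // example problem size
        auto divup = [](unsigned a, unsigned b) { return (a + b - 1) / b; };
        std::printf("block = {%u, %u} (%u threads, %u warps)\n",
                    warp_x * WARP_SIZE, warp_y, warp_x * WARP_SIZE * warp_y, warp_x * warp_y);
        std::printf("grid  = {%u, %u} for C of size %ux%u\n",
                    divup(N, BN), divup(M, BM), M, N);
    }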
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY + *EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + *OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT + *SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + *EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + *HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + *SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** @@ -38,8 +39,9 @@ namespace megdnn { namespace cuda { -void exec_wmma_gemm_u4(const uint8_t* A, const uint8_t* B, int32_t* C, int M, - int N, int K, int ldA, int ldB, int ldC, cudaStream_t stream); +void exec_wmma_gemm_u4( + const uint8_t* A, const uint8_t* B, int32_t* C, int M, int N, int K, int ldA, + int ldB, int ldC, cudaStream_t stream); } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/max_tensor_diff/opr_impl.cpp b/dnn/src/cuda/max_tensor_diff/opr_impl.cpp index b90b5c6f..75bdcc14 100644 --- a/dnn/src/cuda/max_tensor_diff/opr_impl.cpp +++ b/dnn/src/cuda/max_tensor_diff/opr_impl.cpp @@ -15,8 +15,7 @@ using namespace megdnn; using namespace cuda; -float MaxTensorDiffImpl::exec(_megdnn_tensor_in, _megdnn_tensor_in, - _megdnn_workspace) { +float MaxTensorDiffImpl::exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_workspace) { megdnn_throw("MaxTensorDiff not support in cuda"); } diff --git a/dnn/src/cuda/max_tensor_diff/opr_impl.h b/dnn/src/cuda/max_tensor_diff/opr_impl.h index 71027f55..6aac645a 100644 --- a/dnn/src/cuda/max_tensor_diff/opr_impl.h +++ b/dnn/src/cuda/max_tensor_diff/opr_impl.h @@ -20,13 +20,13 @@ public: bool is_thread_safe() const override { return true; } - size_t get_workspace_in_bytes(const TensorLayout&, - const TensorLayout&) override { + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { return 0; }; - float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, - _megdnn_workspace workspace) override; + float exec( + _megdnn_tensor_in src1, _megdnn_tensor_in src2, + _megdnn_workspace workspace) override; }; } // namespace cuda diff --git a/dnn/src/cuda/megcore/cuda_computing_context.cpp b/dnn/src/cuda/megcore/cuda_computing_context.cpp index 434273a3..7095b805 100644 --- a/dnn/src/cuda/megcore/cuda_computing_context.cpp +++ b/dnn/src/cuda/megcore/cuda_computing_context.cpp @@ -13,38 +13,34 @@ #include "src/common/utils.h" #include "src/cuda/utils.h" - #include "./cuda_computing_context.hpp" using namespace megcore; using namespace megcore::cuda; -CUDAComputingContext::CUDAComputingContext(megcoreDeviceHandle_t dev_handle, - unsigned int flags, const 
CudaContext& ctx): - ComputingContext(dev_handle, flags), - own_stream_{ctx.stream == nullptr}, - context_{ctx} -{ +CUDAComputingContext::CUDAComputingContext( + megcoreDeviceHandle_t dev_handle, unsigned int flags, const CudaContext& ctx) + : ComputingContext(dev_handle, flags), + own_stream_{ctx.stream == nullptr}, + context_{ctx} { megcorePlatform_t platform; megcoreGetPlatform(dev_handle, &platform); - megdnn_throw_if(platform != megcorePlatformCUDA, megdnn_error, - "platform should be CUDA Platform"); + megdnn_throw_if( + platform != megcorePlatformCUDA, megdnn_error, + "platform should be CUDA Platform"); if (own_stream_) { - cuda_check(cudaStreamCreateWithFlags(&context_.stream, - cudaStreamNonBlocking)); + cuda_check(cudaStreamCreateWithFlags(&context_.stream, cudaStreamNonBlocking)); } } -CUDAComputingContext::~CUDAComputingContext() -{ +CUDAComputingContext::~CUDAComputingContext() { if (own_stream_) { cuda_check(cudaStreamDestroy(context_.stream)); } } -void CUDAComputingContext::memcpy(void *dst, const void *src, - size_t size_in_bytes, megcoreMemcpyKind_t kind) -{ +void CUDAComputingContext::memcpy( + void* dst, const void* src, size_t size_in_bytes, megcoreMemcpyKind_t kind) { cudaMemcpyKind cuda_kind; switch (kind) { case megcoreMemcpyDeviceToHost: @@ -59,19 +55,15 @@ void CUDAComputingContext::memcpy(void *dst, const void *src, default: megdnn_throw("bad cuda memcpy kind"); } - cuda_check(cudaMemcpyAsync(dst, src, size_in_bytes, cuda_kind, - context_.stream)); + cuda_check(cudaMemcpyAsync(dst, src, size_in_bytes, cuda_kind, context_.stream)); } -void CUDAComputingContext::memset(void *dst, int value, size_t size_in_bytes) -{ +void CUDAComputingContext::memset(void* dst, int value, size_t size_in_bytes) { cuda_check(cudaMemsetAsync(dst, value, size_in_bytes, context_.stream)); } -void CUDAComputingContext::synchronize() -{ +void CUDAComputingContext::synchronize() { cuda_check(cudaStreamSynchronize(context_.stream)); } - // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/megcore/cuda_device_context.cpp b/dnn/src/cuda/megcore/cuda_device_context.cpp index 914f1998..d34fe764 100644 --- a/dnn/src/cuda/megcore/cuda_device_context.cpp +++ b/dnn/src/cuda/megcore/cuda_device_context.cpp @@ -16,21 +16,21 @@ #include "./cuda_device_context.hpp" #define STR_HELPER(x) #x -#define STR(x) STR_HELPER(x) +#define STR(x) STR_HELPER(x) #pragma message "compile with cuda " STR(CUDART_VERSION) " " using namespace megcore; using namespace cuda; -CUDADeviceContext::CUDADeviceContext(int device_id, unsigned int flags): - DeviceContext(megcorePlatformCUDA, device_id, flags) -{ +CUDADeviceContext::CUDADeviceContext(int device_id, unsigned int flags) + : DeviceContext(megcorePlatformCUDA, device_id, flags) { int version; cuda_check(cudaRuntimeGetVersion(&version)); - megdnn_assert(version == CUDART_VERSION, - "megcore compiled with cuda %d, get %d at runtime", - CUDART_VERSION, version); + megdnn_assert( + version == CUDART_VERSION, + "megcore compiled with cuda %d, get %d at runtime", CUDART_VERSION, + version); int id = device_id; if (id < 0) { cuda_check(cudaGetDevice(&id)); @@ -44,23 +44,20 @@ size_t CUDADeviceContext::mem_alignment_in_bytes() const noexcept { return std::max(prop_.textureAlignment, prop_.texturePitchAlignment); } -void CUDADeviceContext::activate() -{ +void CUDADeviceContext::activate() { int id = device_id(); if (id >= 0) { cuda_check(cudaSetDevice(id)); } } -void *CUDADeviceContext::malloc(size_t size_in_bytes) -{ - void *ptr; +void* CUDADeviceContext::malloc(size_t 
size_in_bytes) { + void* ptr; cuda_check(cudaMalloc(&ptr, size_in_bytes)); return ptr; } -void CUDADeviceContext::free(void *ptr) -{ +void CUDADeviceContext::free(void* ptr) { cuda_check(cudaFree(ptr)); } diff --git a/dnn/src/cuda/megcore/public_api/computing.cpp b/dnn/src/cuda/megcore/public_api/computing.cpp index 9b45d81c..8270c46a 100644 --- a/dnn/src/cuda/megcore/public_api/computing.cpp +++ b/dnn/src/cuda/megcore/public_api/computing.cpp @@ -10,41 +10,36 @@ */ #include "megcore_cuda.h" -#include "src/common/utils.h" -#include "src/common/megcore/public_api/computing.hpp" #include "../cuda_computing_context.hpp" +#include "src/common/megcore/public_api/computing.hpp" +#include "src/common/utils.h" using namespace megcore; megcoreStatus_t megcore::createComputingHandleWithCUDAContext( - megcoreComputingHandle_t *compHandle, - megcoreDeviceHandle_t devHandle, - unsigned int flags, - const CudaContext& ctx) -{ - auto content = megdnn::make_unique( - devHandle, flags, ctx); - auto &H = *compHandle; + megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, + unsigned int flags, const CudaContext& ctx) { + auto content = + megdnn::make_unique(devHandle, flags, ctx); + auto& H = *compHandle; H = new megcoreComputingContext; H->content = std::move(content); return megcoreSuccess; } -megcoreStatus_t megcore::getCUDAContext(megcoreComputingHandle_t handle, - CudaContext* ctx) -{ - auto &&H = handle; +megcoreStatus_t megcore::getCUDAContext( + megcoreComputingHandle_t handle, CudaContext* ctx) { + auto&& H = handle; megdnn_assert(H); megcoreDeviceHandle_t dev_handle = H->content->dev_handle(); megcorePlatform_t platform; megcoreGetPlatform(dev_handle, &platform); - megdnn_throw_if(platform != megcorePlatformCUDA, megdnn_error, - "platform should be CUDA Platform"); - auto context = - static_cast(H->content.get()); + megdnn_throw_if( + platform != megcorePlatformCUDA, megdnn_error, + "platform should be CUDA Platform"); + auto context = static_cast(H->content.get()); *ctx = context->context(); return megcoreSuccess; } // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/memory_utils.cuh b/dnn/src/cuda/memory_utils.cuh index 82c718ed..1469d87d 100644 --- a/dnn/src/cuda/memory_utils.cuh +++ b/dnn/src/cuda/memory_utils.cuh @@ -33,8 +33,8 @@ struct global_load; // initialize data to zero before ld.global template struct global_load { - MEGDNN_DEVICE __forceinline__ global_load(AccessType& D, void const* ptr, - bool pred_guard, int val = 0) { + MEGDNN_DEVICE __forceinline__ + global_load(AccessType& D, void const* ptr, bool pred_guard, int val = 0) { uint4* data = reinterpret_cast(&D); asm volatile( @@ -52,19 +52,17 @@ struct global_load { " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%8];\n" " @p ld.global.v4.u32 {%4, %5, %6, %7}, [%11];\n" "}\n" - : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), - "=r"(data[0].w), "=r"(data[1].x), "=r"(data[1].y), - "=r"(data[1].z), "=r"(data[1].w) - : "l"(ptr), "r"((int)pred_guard), - "r"(reinterpret_cast(val)), + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), + "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) + : "l"(ptr), "r"((int)pred_guard), "r"(reinterpret_cast(val)), "l"(((uint8_t*)ptr) + 16)); } }; template struct global_load { - MEGDNN_DEVICE __forceinline__ global_load(AccessType& D, void const* ptr, - bool pred_guard, int val) { + MEGDNN_DEVICE __forceinline__ + global_load(AccessType& D, void const* ptr, bool pred_guard, int val) { uint4& data = reinterpret_cast(D); asm volatile( @@ -85,8 +83,8 @@ 
struct global_load { template struct global_load { - MEGDNN_DEVICE __forceinline__ global_load(AccessType& D, void const* ptr, - bool pred_guard, int val) { + MEGDNN_DEVICE __forceinline__ + global_load(AccessType& D, void const* ptr, bool pred_guard, int val) { uint2& data = reinterpret_cast(D); asm volatile( @@ -105,8 +103,8 @@ struct global_load { template struct global_load { - MEGDNN_DEVICE __forceinline__ global_load(AccessType& D, void const* ptr, - bool pred_guard, int val) { + MEGDNN_DEVICE __forceinline__ + global_load(AccessType& D, void const* ptr, bool pred_guard, int val) { unsigned& data = reinterpret_cast(D); asm volatile( @@ -124,8 +122,8 @@ struct global_load { template struct global_load { - MEGDNN_DEVICE __forceinline__ global_load(AccessType& D, void const* ptr, - bool pred_guard, int val) { + MEGDNN_DEVICE __forceinline__ + global_load(AccessType& D, void const* ptr, bool pred_guard, int val) { if (pred_guard) D = *(reinterpret_cast(ptr)); else { @@ -152,8 +150,8 @@ struct global_store; template struct global_store { - MEGDNN_DEVICE __forceinline__ global_store(AccessType const& D, void* ptr, - bool pred_guard) { + MEGDNN_DEVICE __forceinline__ + global_store(AccessType const& D, void* ptr, bool pred_guard) { uint4 const* data = reinterpret_cast(&D); asm volatile( @@ -165,16 +163,15 @@ struct global_store { "}\n" : : "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), - "r"(data[0].w), "r"((int)pred_guard), - "l"(((uint8_t*)ptr) + 16), "r"(data[1].x), "r"(data[1].y), - "r"(data[1].z), "r"(data[1].w)); + "r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t*)ptr) + 16), + "r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w)); } }; template struct global_store { - MEGDNN_DEVICE __forceinline__ global_store(AccessType const& D, void* ptr, - bool pred_guard) { + MEGDNN_DEVICE __forceinline__ + global_store(AccessType const& D, void* ptr, bool pred_guard) { uint4 const& data = reinterpret_cast(D); asm volatile( "{\n" @@ -190,8 +187,8 @@ struct global_store { template struct global_store { - MEGDNN_DEVICE __forceinline__ global_store(AccessType const& D, void* ptr, - bool pred_guard) { + MEGDNN_DEVICE __forceinline__ + global_store(AccessType const& D, void* ptr, bool pred_guard) { uint2 const& data = reinterpret_cast(D); asm volatile( "{\n" @@ -206,8 +203,8 @@ struct global_store { template struct global_store { - MEGDNN_DEVICE __forceinline__ global_store(AccessType const& D, void* ptr, - bool pred_guard) { + MEGDNN_DEVICE __forceinline__ + global_store(AccessType const& D, void* ptr, bool pred_guard) { uint32_t const& data = reinterpret_cast(D); asm volatile( "{\n" @@ -222,8 +219,8 @@ struct global_store { template struct global_store { - MEGDNN_DEVICE __forceinline__ global_store(AccessType const& D, void* ptr, - bool pred_guard) { + MEGDNN_DEVICE __forceinline__ + global_store(AccessType const& D, void* ptr, bool pred_guard) { uint16_t const& data = reinterpret_cast(D); asm volatile( "{\n" @@ -238,8 +235,8 @@ struct global_store { template struct global_store { - MEGDNN_DEVICE __forceinline__ global_store(AccessType const& D, void* ptr, - bool pred_guard) { + MEGDNN_DEVICE __forceinline__ + global_store(AccessType const& D, void* ptr, bool pred_guard) { if (pred_guard) *(reinterpret_cast(ptr)) = D; } diff --git a/dnn/src/cuda/mesh_indexing/mesh_indexing.cu b/dnn/src/cuda/mesh_indexing/mesh_indexing.cu index faa70fe1..875eab1b 100644 --- a/dnn/src/cuda/mesh_indexing/mesh_indexing.cu +++ b/dnn/src/cuda/mesh_indexing/mesh_indexing.cu @@ -18,8 +18,7 @@ 
#define KERN_APPLY_OPR_INDEXING ::megdnn::indexing_multi_axis_vec_kdef::OprFwd -#define KERN_APPLY_OPR_INCR \ - ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr +#define KERN_APPLY_OPR_INCR ::megdnn::cuda::indexing_multi_axis_vec::OprAtomicIncr #define KERN_APPLY_OPR_SET ::megdnn::indexing_multi_axis_vec_kdef::OprSet @@ -30,8 +29,8 @@ using namespace cuda; using namespace mesh_indexing; template -__global__ void mesh_indexing_general_kernel(T* src, T* dst, - const KernIndexer indexer) { +__global__ void mesh_indexing_general_kernel( + T* src, T* dst, const KernIndexer indexer) { uint32_t dst_idx = blockIdx.x * blockDim.x + threadIdx.x; if (dst_idx < indexer.size) { int src_idx = indexer.convert_indxer(dst_idx); @@ -45,26 +44,22 @@ namespace cuda { namespace mesh_indexing { template -void mesh_indexing_proxy(T* src, T* dst, KernIndexer* indexer, - cudaStream_t stream) { +void mesh_indexing_proxy(T* src, T* dst, KernIndexer* indexer, cudaStream_t stream) { mesh_indexing_general_kernel <<size, NR_THREADS), NR_THREADS, 0, stream>>>( src, dst, *indexer); } -#define INST(_ctype) \ - template void mesh_indexing_proxy<_ctype, KERN_APPLY_OPR_INDEXING>( \ - _ctype * src, _ctype * dst, KernIndexer * indexer, \ - cudaStream_t stream); \ - \ - template void mesh_indexing_proxy<_ctype, KERN_APPLY_OPR_SET>( \ - _ctype * src, _ctype * dst, KernIndexer * indexer, \ - cudaStream_t stream); +#define INST(_ctype) \ + template void mesh_indexing_proxy<_ctype, KERN_APPLY_OPR_INDEXING>( \ + _ctype * src, _ctype * dst, KernIndexer * indexer, cudaStream_t stream); \ + \ + template void mesh_indexing_proxy<_ctype, KERN_APPLY_OPR_SET>( \ + _ctype * src, _ctype * dst, KernIndexer * indexer, cudaStream_t stream); #define INST_ATOMIC_ADD(_ctype) \ template void mesh_indexing_proxy<_ctype, KERN_APPLY_OPR_INCR>( \ - _ctype * src, _ctype * dst, KernIndexer * indexer, \ - cudaStream_t stream); + _ctype * src, _ctype * dst, KernIndexer * indexer, cudaStream_t stream); #define cb(_dtype) INST(DTypeTrait<_dtype>::ctype) diff --git a/dnn/src/cuda/mesh_indexing/mesh_indexing.cuh b/dnn/src/cuda/mesh_indexing/mesh_indexing.cuh index e88baa97..22b3274c 100644 --- a/dnn/src/cuda/mesh_indexing/mesh_indexing.cuh +++ b/dnn/src/cuda/mesh_indexing/mesh_indexing.cuh @@ -35,12 +35,10 @@ struct KernIndexer { uint32_t batch_stride; uint32_t size; - KernIndexer(const TensorLayout& origin_layout, - const TensorLayout& indexed_layout, int** _ptrs, - const TensorLayout* desc_layouts, - void* _err_tracker = nullptr, - megcore::AsyncErrorInfo* _err_info = nullptr, - bool _batch_mode = false) + KernIndexer( + const TensorLayout& origin_layout, const TensorLayout& indexed_layout, + int** _ptrs, const TensorLayout* desc_layouts, void* _err_tracker = nullptr, + megcore::AsyncErrorInfo* _err_info = nullptr, bool _batch_mode = false) : error_tracker(_err_tracker), error_info(_err_info), batch_mode(_batch_mode), @@ -76,10 +74,11 @@ struct KernIndexer { pos += (pos < 0 ? 
origin_shape[i] : 0); } if (static_cast(pos) >= origin_shape[i]) { - set_async_error_info(error_info, error_tracker, - "invalid mesh indexing: " - "indexer=%d idx=%d shape=%d", - i, pos, origin_shape[i]); + set_async_error_info( + error_info, error_tracker, + "invalid mesh indexing: " + "indexer=%d idx=%d shape=%d", + i, pos, origin_shape[i]); } data_offset += pos * origin_stride[i]; index /= indexed_shape[i]; @@ -91,8 +90,8 @@ struct KernIndexer { }; template -void mesh_indexing_proxy(T* origin, T* indexed, KernIndexer* indexer, - cudaStream_t stream); +void mesh_indexing_proxy( + T* origin, T* indexed, KernIndexer* indexer, cudaStream_t stream); } // namespace mesh_indexing } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/mesh_indexing/opr_impl.cpp b/dnn/src/cuda/mesh_indexing/opr_impl.cpp index b0a4b586..5f11f1d8 100644 --- a/dnn/src/cuda/mesh_indexing/opr_impl.cpp +++ b/dnn/src/cuda/mesh_indexing/opr_impl.cpp @@ -18,9 +18,10 @@ namespace { using namespace megdnn; using namespace cuda; using namespace mesh_indexing; -KernIndexer get_indexer(const TensorND& origin, const TensorND& indexed, - const MeshBase::IndexDesc& desc, void* error_tracker, - megcore::AsyncErrorInfo* error_info, bool batched) { +KernIndexer get_indexer( + const TensorND& origin, const TensorND& indexed, + const MeshBase::IndexDesc& desc, void* error_tracker, + megcore::AsyncErrorInfo* error_info, bool batched) { int* tmp_ptrs[TensorShape::MAX_NDIM] = {nullptr}; TensorLayout desc_layouts[TensorShape::MAX_NDIM]; for (size_t i = 0; i < desc.size(); ++i) { @@ -34,12 +35,11 @@ KernIndexer get_indexer(const TensorND& origin, const TensorND& indexed, } template -void do_exec(const TensorND& data, const TensorND& value, - const MeshBase::IndexDesc& desc, Handle* handle, - void* error_tracker) { +void do_exec( + const TensorND& data, const TensorND& value, const MeshBase::IndexDesc& desc, + Handle* handle, void* error_tracker) { auto error_info = async_error_info(handle); - auto indexer = - get_indexer(data, value, desc, error_tracker, error_info, batched); + auto indexer = get_indexer(data, value, desc, error_tracker, error_info, batched); auto stream = cuda_stream(handle); mesh_indexing::mesh_indexing_proxy( @@ -53,8 +53,9 @@ namespace cuda { /* =========================== MeshIndexing ============================ */ -void MeshIndexingImpl::exec(_megdnn_tensor_in src, const IndexDesc& desc, - _megdnn_tensor_out dst, _megdnn_workspace) { +void MeshIndexingImpl::exec( + _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst, + _megdnn_workspace) { check_exec(src.layout, dst.layout, desc); #define cb(DType) \ if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ @@ -70,8 +71,9 @@ void MeshIndexingImpl::exec(_megdnn_tensor_in src, const IndexDesc& desc, /* ========================= BatchedMeshIndexing ========================== */ -void BatchedMeshIndexingImpl::exec(_megdnn_tensor_in src, const IndexDesc& desc, - _megdnn_tensor_out dst, _megdnn_workspace) { +void BatchedMeshIndexingImpl::exec( + _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst, + _megdnn_workspace) { check_exec(src.layout, dst.layout, desc); #define cb(DType) \ @@ -88,9 +90,9 @@ void BatchedMeshIndexingImpl::exec(_megdnn_tensor_in src, const IndexDesc& desc, /* ============================ Mesh ============================= */ -void IncrMeshIndexingImpl::exec(_megdnn_tensor_inout data, - _megdnn_tensor_in value, const IndexDesc& desc, - _megdnn_workspace) { +void IncrMeshIndexingImpl::exec( + 
_megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace) { check_exec(data.layout, value.layout, desc); #define cb(DType) \ @@ -107,9 +109,9 @@ void IncrMeshIndexingImpl::exec(_megdnn_tensor_inout data, megdnn_assert_internal(0); } -void SetMeshIndexingImpl::exec(_megdnn_tensor_inout data, - _megdnn_tensor_in value, const IndexDesc& desc, - _megdnn_workspace) { +void SetMeshIndexingImpl::exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace) { check_exec(data.layout, value.layout, desc); #define cb(DType) \ @@ -126,10 +128,9 @@ void SetMeshIndexingImpl::exec(_megdnn_tensor_inout data, } /* ========================== BatchedMesh ============================= */ -void BatchedIncrMeshIndexingImpl::exec(_megdnn_tensor_inout data, - _megdnn_tensor_in value, - const IndexDesc& desc, - _megdnn_workspace) { +void BatchedIncrMeshIndexingImpl::exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace) { check_exec(data.layout, value.layout, desc); #define cb(DType) \ @@ -145,10 +146,9 @@ void BatchedIncrMeshIndexingImpl::exec(_megdnn_tensor_inout data, megdnn_assert_internal(0); } -void BatchedSetMeshIndexingImpl::exec(_megdnn_tensor_inout data, - _megdnn_tensor_in value, - const IndexDesc& desc, - _megdnn_workspace) { +void BatchedSetMeshIndexingImpl::exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace) { check_exec(data.layout, value.layout, desc); #define cb(DType) \ diff --git a/dnn/src/cuda/mesh_indexing/opr_impl.h b/dnn/src/cuda/mesh_indexing/opr_impl.h index 22c0dfc1..1184c798 100644 --- a/dnn/src/cuda/mesh_indexing/opr_impl.h +++ b/dnn/src/cuda/mesh_indexing/opr_impl.h @@ -22,12 +22,11 @@ class MeshIndexingImpl : public MeshIndexing { public: using MeshIndexing::MeshIndexing; - void exec(_megdnn_tensor_in src, const IndexDesc& desc, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; class IncrMeshIndexingImpl : public IncrMeshIndexing { @@ -36,12 +35,11 @@ class IncrMeshIndexingImpl : public IncrMeshIndexing { public: using IncrMeshIndexing::IncrMeshIndexing; - void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc& desc, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; class SetMeshIndexingImpl : public SetMeshIndexing { @@ -50,12 +48,11 @@ class SetMeshIndexingImpl : public SetMeshIndexing { public: using SetMeshIndexing::SetMeshIndexing; - void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc& desc, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; class BatchedMeshIndexingImpl : public 
BatchedMeshIndexing { @@ -64,12 +61,11 @@ class BatchedMeshIndexingImpl : public BatchedMeshIndexing { public: using BatchedMeshIndexing::BatchedMeshIndexing; - void exec(_megdnn_tensor_in src, const IndexDesc& desc, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in src, const IndexDesc& desc, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; class BatchedIncrMeshIndexingImpl : public BatchedIncrMeshIndexing { @@ -78,12 +74,11 @@ class BatchedIncrMeshIndexingImpl : public BatchedIncrMeshIndexing { public: using BatchedIncrMeshIndexing::BatchedIncrMeshIndexing; - void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc& desc, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; class BatchedSetMeshIndexingImpl : public BatchedSetMeshIndexing { @@ -92,12 +87,11 @@ class BatchedSetMeshIndexingImpl : public BatchedSetMeshIndexing { public: using BatchedSetMeshIndexing::BatchedSetMeshIndexing; - void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value, - const IndexDesc& desc, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_inout data, _megdnn_tensor_in value, const IndexDesc& desc, + _megdnn_workspace workspace) override; - void set_error_tracker(void* tracker) override { - m_error_tracker = tracker; - } + void set_error_tracker(void* tracker) override { m_error_tracker = tracker; } }; } // namespace cuda diff --git a/dnn/src/cuda/padding/opr_impl.cpp b/dnn/src/cuda/padding/opr_impl.cpp index 9c2c9528..5da1c744 100644 --- a/dnn/src/cuda/padding/opr_impl.cpp +++ b/dnn/src/cuda/padding/opr_impl.cpp @@ -27,12 +27,12 @@ void PaddingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { offsets[5], offsets[6], offsets[7], offsets[8], offsets[9], offsets[10], offsets[11], offsets[12], offsets[13]}; auto stream = cuda_stream(this->handle()); -#define cb(DType) \ - if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ - using ctype = typename DTypeTrait::ctype; \ - padding::padding_forward_proxy(src, dst, param_offsets, \ - uint32_t(param().padding_mode), \ - param().padding_val, stream); \ +#define cb(DType) \ + if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + padding::padding_forward_proxy( \ + src, dst, param_offsets, uint32_t(param().padding_mode), \ + param().padding_val, stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb @@ -47,24 +47,23 @@ void PaddingBackwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { offsets[5], offsets[6], offsets[7], offsets[8], offsets[9], offsets[10], offsets[11], offsets[12], offsets[13]}; auto stream = cuda_stream(this->handle()); -#define cb(DType) \ - if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ - using ctype = typename DTypeTrait::ctype; \ - padding::padding_backward_proxy(src, dst, param_offsets, \ - uint32_t(param().padding_mode), \ - stream); \ +#define cb(DType) \ + if (src.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + padding::padding_backward_proxy( \ + src, dst, 
param_offsets, uint32_t(param().padding_mode), stream); \ } MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb } -size_t PaddingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) { +size_t PaddingForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { return 0; } -size_t PaddingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) { +size_t PaddingBackwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { return 0; } } // namespace cuda diff --git a/dnn/src/cuda/padding/opr_impl.h b/dnn/src/cuda/padding/opr_impl.h index 9cd495ac..946f2489 100644 --- a/dnn/src/cuda/padding/opr_impl.h +++ b/dnn/src/cuda/padding/opr_impl.h @@ -19,8 +19,8 @@ class PaddingForwardImpl : public PaddingForward { public: void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; }; class PaddingBackwardImpl : public PaddingBackward { @@ -28,8 +28,8 @@ class PaddingBackwardImpl : public PaddingBackward { public: void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; }; } // namespace cuda } // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/padding/padding.cu b/dnn/src/cuda/padding/padding.cu index 9c7b3074..d1a761b4 100644 --- a/dnn/src/cuda/padding/padding.cu +++ b/dnn/src/cuda/padding/padding.cu @@ -30,25 +30,26 @@ struct ShapeParams { }; template -__global__ void paddingConst_kernel(const size_t ndim, - const size_t total_out_nr, - const T* const src, T* const dst, - ShapeParams params, - const float_t padding_val) { +__global__ void paddingConst_kernel( + const size_t ndim, const size_t total_out_nr, const T* const src, T* const dst, + ShapeParams params, const float_t padding_val) { KERN_FOR(out_index, total_out_nr) { bool in_src_valid_area = true; size_t in_index = 0; size_t out_index_tmp = out_index; for (size_t dim = 0; dim <= ndim - 1; ++dim) { - Uint32Fastdiv dst_stride = params.dst_stride[dim], src_stride = params.src_stride[dim]; + Uint32Fastdiv dst_stride = params.dst_stride[dim], + src_stride = params.src_stride[dim]; size_t src_shape = params.src_shape[dim]; - size_t offset = params.offsets[dim*2]; + size_t offset = params.offsets[dim * 2]; size_t dim_index = out_index_tmp / dst_stride; - in_src_valid_area &= (dim_index >= offset && dim_index < offset+src_shape); - if(!in_src_valid_area) break; + in_src_valid_area &= + (dim_index >= offset && dim_index < offset + src_shape); + if (!in_src_valid_area) + break; out_index_tmp -= dim_index * dst_stride.divisor(); - in_index += (dim_index - offset)*src_stride.divisor(); + in_index += (dim_index - offset) * src_stride.divisor(); /* size_t dim_index = out_index_tmp / params.dst_stride[dim]; out_index_tmp -= dim_index * params.dst_stride[dim].divisor(); @@ -64,10 +65,9 @@ __global__ void paddingConst_kernel(const size_t ndim, } template -__global__ void paddingReplicate_kernel(const size_t ndim, - const size_t total_out_nr, - const T* const src, T* const dst, - ShapeParams params, const float_t) { +__global__ void paddingReplicate_kernel( + const size_t ndim, const size_t total_out_nr, const T* const src, T* 
const dst, + ShapeParams params, const float_t) { KERN_FOR(out_index, total_out_nr) { size_t in_index = 0; size_t out_index_tmp = out_index; @@ -76,8 +76,7 @@ __global__ void paddingReplicate_kernel(const size_t ndim, out_index_tmp -= dim_index * params.dst_stride[dim].divisor(); dim_index = (size_t)llmin( (long long)params.src_shape[dim] - 1, - llmax((long long)dim_index - - (long long)params.offsets[dim * 2], + llmax((long long)dim_index - (long long)params.offsets[dim * 2], (long long)0)); in_index += dim_index * params.src_stride[dim].divisor(); } @@ -86,10 +85,9 @@ __global__ void paddingReplicate_kernel(const size_t ndim, } template -__global__ void paddingReflect_kernel(const size_t ndim, - const size_t total_out_nr, - const T* const src, T* const dst, - ShapeParams params, const float_t) { +__global__ void paddingReflect_kernel( + const size_t ndim, const size_t total_out_nr, const T* const src, T* const dst, + ShapeParams params, const float_t) { KERN_FOR(out_index, total_out_nr) { size_t in_index = 0; size_t out_index_tmp = out_index; @@ -98,20 +96,18 @@ __global__ void paddingReflect_kernel(const size_t ndim, out_index_tmp -= dim_index * params.dst_stride[dim].divisor(); dim_index -= (long long)params.offsets[dim * 2]; dim_index = llmax(dim_index, -dim_index); - dim_index = llmin(dim_index, 2 * (long long)params.src_shape[dim] - - dim_index - 2); - in_index += size_t(dim_index) * - (size_t)params.src_stride[dim].divisor(); + dim_index = llmin( + dim_index, 2 * (long long)params.src_shape[dim] - dim_index - 2); + in_index += size_t(dim_index) * (size_t)params.src_stride[dim].divisor(); } dst[out_index] = src[in_index]; } } template -__global__ void paddingConstBackward_kernel(const size_t ndim, - const size_t total_in_nr, - const T* const src, T* const dst, - ShapeParams params) { +__global__ void paddingConstBackward_kernel( + const size_t ndim, const size_t total_in_nr, const T* const src, T* const dst, + ShapeParams params) { KERN_FOR(in_index, total_in_nr) { bool in_dst_valid_area = true; size_t out_index = 0; @@ -119,9 +115,9 @@ __global__ void paddingConstBackward_kernel(const size_t ndim, for (size_t dim = 0; dim <= ndim - 1; ++dim) { size_t dim_index = in_index_tmp / params.src_stride[dim]; in_index_tmp -= dim_index * params.src_stride[dim].divisor(); - in_dst_valid_area &= (dim_index >= params.offsets[dim * 2] && - dim_index < params.offsets[dim * 2] + - params.dst_shape[dim]); + in_dst_valid_area &= + (dim_index >= params.offsets[dim * 2] && + dim_index < params.offsets[dim * 2] + params.dst_shape[dim]); out_index += (dim_index - params.offsets[dim * 2]) * params.dst_stride[dim].divisor(); } @@ -132,11 +128,9 @@ __global__ void paddingConstBackward_kernel(const size_t ndim, } template -__global__ void paddingReplicateBackward_kernel(const size_t ndim, - const size_t total_in_nr, - const T* const src, - T* const dst, - ShapeParams params) { +__global__ void paddingReplicateBackward_kernel( + const size_t ndim, const size_t total_in_nr, const T* const src, T* const dst, + ShapeParams params) { KERN_FOR(in_index, total_in_nr) { size_t out_index = 0; size_t in_index_tmp = in_index; @@ -145,8 +139,7 @@ __global__ void paddingReplicateBackward_kernel(const size_t ndim, in_index_tmp -= dim_index * params.src_stride[dim].divisor(); dim_index = (size_t)llmin( (long long)params.dst_shape[dim] - 1, - llmax((long long)dim_index - - (long long)params.offsets[dim * 2], + llmax((long long)dim_index - (long long)params.offsets[dim * 2], (long long)0)); out_index += dim_index * 
params.dst_stride[dim].divisor(); } @@ -155,10 +148,9 @@ __global__ void paddingReplicateBackward_kernel(const size_t ndim, } template -__global__ void paddingReflectBackward_kernel(const size_t ndim, - const size_t total_in_nr, - const T* const src, T* const dst, - ShapeParams params) { +__global__ void paddingReflectBackward_kernel( + const size_t ndim, const size_t total_in_nr, const T* const src, T* const dst, + ShapeParams params) { KERN_FOR(in_index, total_in_nr) { size_t out_index = 0; size_t in_index_tmp = in_index; @@ -167,19 +159,18 @@ __global__ void paddingReflectBackward_kernel(const size_t ndim, in_index_tmp -= dim_index * params.src_stride[dim].divisor(); dim_index -= (long long)params.offsets[dim * 2]; dim_index = llmax(dim_index, -dim_index); - dim_index = llmin(dim_index, 2 * (long long)params.dst_shape[dim] - - dim_index - 2); - out_index += size_t(dim_index) * - (size_t)params.dst_stride[dim].divisor(); + dim_index = llmin( + dim_index, 2 * (long long)params.dst_shape[dim] - dim_index - 2); + out_index += size_t(dim_index) * (size_t)params.dst_stride[dim].divisor(); } atomic_add(&dst[out_index], src[in_index]); } } template -void padding_forward_proxy(const TensorND& src, const TensorND& dst, - size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, - const float_t padding_val, cudaStream_t stream) { +void padding_forward_proxy( + const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2], + uint32_t mode, const float_t padding_val, cudaStream_t stream) { ShapeParams params; for (size_t i = 0; i < src.layout.ndim; ++i) { params.src_shape[i] = src.layout.shape[i]; @@ -190,8 +181,9 @@ void padding_forward_proxy(const TensorND& src, const TensorND& dst, params.offsets[i * 2 + 1] = offsets[i * 2 + 1]; } - void (*fwd_kern)(const size_t, const size_t, const T* const, T* const, - ShapeParams, const float_t); + void (*fwd_kern)( + const size_t, const size_t, const T* const, T* const, ShapeParams, + const float_t); switch (mode) { case param_enumv::Padding::PaddingMode::CONSTANT: fwd_kern = paddingConst_kernel; @@ -211,16 +203,15 @@ void padding_forward_proxy(const TensorND& src, const TensorND& dst, uint32_t nr_threads = query_blocksize_for_kernel(fwd_kern); dim3 threads(nr_threads); dim3 blocks(DIVUP(total_nr, nr_threads)); - fwd_kern<<>>(src.layout.ndim, total_nr, - src.ptr(), dst.ptr(), params, - padding_val); + fwd_kern<<>>( + src.layout.ndim, total_nr, src.ptr(), dst.ptr(), params, padding_val); after_kernel_launch(); } template -void padding_backward_proxy(const TensorND& src, const TensorND& dst, - size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, - cudaStream_t stream) { +void padding_backward_proxy( + const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2], + uint32_t mode, cudaStream_t stream) { ShapeParams params; for (size_t i = 0; i < src.layout.ndim; ++i) { @@ -234,8 +225,7 @@ void padding_backward_proxy(const TensorND& src, const TensorND& dst, cudaMemset(dst.raw_ptr, 0, dst.layout.access_bytes()); - void (*bwd_kern)(const size_t, const size_t, const T* const, T* const, - ShapeParams); + void (*bwd_kern)(const size_t, const size_t, const T* const, T* const, ShapeParams); switch (mode) { case param_enumv::Padding::PaddingMode::CONSTANT: @@ -269,11 +259,10 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb #undef INST -#define INST(T) \ - template void padding_backward_proxy( \ - const TensorND& src, const TensorND& dst, \ - size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, \ - cudaStream_t stream); +#define INST(T) \ + 
template void padding_backward_proxy( \ + const TensorND& src, const TensorND& dst, \ + size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, cudaStream_t stream); #define cb(DType) INST(typename DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb diff --git a/dnn/src/cuda/padding/padding.cuh b/dnn/src/cuda/padding/padding.cuh index 17629d49..3f496630 100644 --- a/dnn/src/cuda/padding/padding.cuh +++ b/dnn/src/cuda/padding/padding.cuh @@ -20,14 +20,14 @@ namespace cuda { namespace padding { template -void padding_forward_proxy(const TensorND& src, const TensorND& dst, - size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, - const float_t padding_val, cudaStream_t stream); +void padding_forward_proxy( + const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2], + uint32_t mode, const float_t padding_val, cudaStream_t stream); template -void padding_backward_proxy(const TensorND& src, const TensorND& dst, - size_t offsets[MEGDNN_MAX_NDIM * 2], uint32_t mode, - cudaStream_t stream); +void padding_backward_proxy( + const TensorND& src, const TensorND& dst, size_t offsets[MEGDNN_MAX_NDIM * 2], + uint32_t mode, cudaStream_t stream); } // namespace padding } // namespace cuda diff --git a/dnn/src/cuda/param_pack/opr_impl.cpp b/dnn/src/cuda/param_pack/opr_impl.cpp index 354dd2e7..8bb72851 100644 --- a/dnn/src/cuda/param_pack/opr_impl.cpp +++ b/dnn/src/cuda/param_pack/opr_impl.cpp @@ -16,19 +16,16 @@ namespace megdnn { namespace cuda { -size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs, - const TensorShape&, - const TensorShape&) { +size_t ParamPackConcatImpl::get_workspace_in_bytes( + const TensorShapeArray& srcs, const TensorShape&, const TensorShape&) { return sizeof(size_t) * srcs.size(); } template -void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, - _megdnn_tensor_in offsets, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { - size_t inp_size = srcs.layout.shape[0], - out_size = dst.layout.total_nr_elems(); +void ParamPackConcatImpl::exec_internal( + _megdnn_tensor_in srcs, _megdnn_tensor_in offsets, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + size_t inp_size = srcs.layout.shape[0], out_size = dst.layout.total_nr_elems(); auto stream = cuda_stream(this->handle()); auto src_cpu = static_cast(srcs.raw_ptr); @@ -37,17 +34,17 @@ void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, auto offsets_gpu = offsets.ptr(); - cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size, - cudaMemcpyHostToDevice, stream)); + cuda_check(cudaMemcpyAsync( + src_gpu, src_cpu, sizeof(const T*) * inp_size, cudaMemcpyHostToDevice, + stream)); - param_pack::concat_proxy(src_gpu, dst.ptr(), inp_size, out_size, - offsets_gpu, stream); + param_pack::concat_proxy( + src_gpu, dst.ptr(), inp_size, out_size, offsets_gpu, stream); } -void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, - _megdnn_tensor_in offsets, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) { +void ParamPackConcatImpl::exec( + _megdnn_tensor_in srcs, _megdnn_tensor_in offsets, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { check_exec(dst.layout, offsets.layout, srcs.layout); #define cb(DType) \ if (dst.layout.dtype == DType()) { \ diff --git a/dnn/src/cuda/param_pack/opr_impl.h b/dnn/src/cuda/param_pack/opr_impl.h index 87e521ea..569ecc34 100644 --- a/dnn/src/cuda/param_pack/opr_impl.h +++ b/dnn/src/cuda/param_pack/opr_impl.h @@ -18,17 +18,19 @@ namespace cuda { class ParamPackConcatImpl final : public 
ParamPackConcat { public: using ParamPackConcat::ParamPackConcat; - void exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, - _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + void exec( + _megdnn_tensor_in srcs, _megdnn_tensor_in table, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorShapeArray& srcs, - const TensorShape& table, - const TensorShape& dst) override; + size_t get_workspace_in_bytes( + const TensorShapeArray& srcs, const TensorShape& table, + const TensorShape& dst) override; private: template - void exec_internal(_megdnn_tensor_in srcs, _megdnn_tensor_in table, - _megdnn_tensor_out dst, _megdnn_workspace workspace); + void exec_internal( + _megdnn_tensor_in srcs, _megdnn_tensor_in table, _megdnn_tensor_out dst, + _megdnn_workspace workspace); }; } // namespace cuda diff --git a/dnn/src/cuda/param_pack/param_pack.cu b/dnn/src/cuda/param_pack/param_pack.cu index c907db34..23c40555 100644 --- a/dnn/src/cuda/param_pack/param_pack.cu +++ b/dnn/src/cuda/param_pack/param_pack.cu @@ -18,10 +18,9 @@ namespace cuda { namespace param_pack { template -__global__ void concat_kernel(const T** srcs, T* dst, - const int32_t* offsets, - size_t srcs_size, - size_t total_size) { +__global__ void concat_kernel( + const T** srcs, T* dst, const int32_t* offsets, size_t srcs_size, + size_t total_size) { size_t addr = threadIdx.x + blockIdx.x * blockDim.x; if (addr < total_size) { size_t l = 0, r = srcs_size - 1, mid; @@ -41,19 +40,18 @@ __global__ void concat_kernel(const T** srcs, T* dst, } template -void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, - const int32_t* offsets, - cudaStream_t stream) { +void concat_proxy( + const T** srcs, T* dst, size_t srcs_size, size_t total_size, + const int32_t* offsets, cudaStream_t stream) { size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS); concat_kernel<<>>( srcs, dst, offsets, srcs_size, total_size); after_kernel_launch(); } -#define INST(T) \ - template void concat_proxy(const T**, T*, size_t, size_t, \ - const int32_t*, \ - cudaStream_t); +#define INST(T) \ + template void concat_proxy( \ + const T**, T*, size_t, size_t, const int32_t*, cudaStream_t); #define cb(DType) INST(typename DTypeTrait::ctype) MEGDNN_FOREACH_COMPUTING_DTYPE(cb) #undef cb diff --git a/dnn/src/cuda/param_pack/param_pack.cuh b/dnn/src/cuda/param_pack/param_pack.cuh index 35c9fe52..481d03c7 100644 --- a/dnn/src/cuda/param_pack/param_pack.cuh +++ b/dnn/src/cuda/param_pack/param_pack.cuh @@ -20,8 +20,9 @@ namespace cuda { namespace param_pack { template -void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, - const int32_t* offsets, cudaStream_t stream); +void concat_proxy( + const T** srcs, T* dst, size_t srcs_size, size_t total_size, + const int32_t* offsets, cudaStream_t stream); } // namespace param_pack } // namespace cuda diff --git a/dnn/src/cuda/pooling/algo.cpp b/dnn/src/cuda/pooling/algo.cpp index 9d6ecce2..a689f12e 100644 --- a/dnn/src/cuda/pooling/algo.cpp +++ b/dnn/src/cuda/pooling/algo.cpp @@ -19,7 +19,7 @@ using namespace cuda; namespace { #define V1(v) #v -#define V(v) V1(v) +#define V(v) V1(v) #define DEF_NAME(NAME) \ #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." 
V(CUDNN_PATCHLEVEL) } // namespace @@ -43,26 +43,25 @@ PoolingForwardImpl::AlgoPack::AlgoPack() { PoolingForwardImpl::AlgoPack PoolingForwardImpl::sm_algo_pack; MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingForwardImpl) -PoolingForwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingForwardImpl* o, - const TensorLayout& src, - const TensorLayout& dst) +PoolingForwardImpl::AlgoBase::SizeArgs::SizeArgs( + PoolingForwardImpl* o, const TensorLayout& src, const TensorLayout& dst) : handle{concrete_handle(o->handle())}, opr{o}, layout_src{&src}, layout_dst{&dst} {} -PoolingForwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingForwardImpl* opr, - _megdnn_tensor_in src, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) +PoolingForwardImpl::AlgoBase::ExecArgs::ExecArgs( + PoolingForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) : SizeArgs(opr, src.layout, dst.layout), src_tensor{&src}, dst_tensor{&dst}, workspace{workspace} {} std::string PoolingForwardImpl::AlgoBase::SizeArgs::to_string() const { - return ssprintf("src=%s, dst=%s", layout_src->to_string().c_str(), - layout_dst->to_string().c_str()); + return ssprintf( + "src=%s, dst=%s", layout_src->to_string().c_str(), + layout_dst->to_string().c_str()); } WorkspaceBundle PoolingForwardImpl::AlgoBase::get_workspace_bundle( @@ -103,8 +102,8 @@ bool PoolingForwardImpl::AlgoCUDNN::is_available(const SizeArgs& args) const { args.layout_src->dtype.enumv() == DTypeEnum::Quantized8Asymm))); } -void PoolingForwardImpl::AlgoCUDNN::init_mode(const ExecArgs& args, - cudnnPoolingMode_t& mode) const { +void PoolingForwardImpl::AlgoCUDNN::init_mode( + const ExecArgs& args, cudnnPoolingMode_t& mode) const { switch (args.opr->param().mode) { case param::Pooling::Mode::MAX: mode = CUDNN_POOLING_MAX; @@ -116,8 +115,9 @@ void PoolingForwardImpl::AlgoCUDNN::init_mode(const ExecArgs& args, mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; break; default: - megdnn_throw(ssprintf("Unspport pooling mode : {%d}", - static_cast(args.opr->param().mode))); + megdnn_throw(ssprintf( + "Unspport pooling mode : {%d}", + static_cast(args.opr->param().mode))); } } @@ -143,13 +143,13 @@ void PoolingForwardImpl::AlgoCUDNN::exec(const ExecArgs& args) const { cudnnPoolingDescriptor_t cudnn_desc; cudnn_check(cudnnCreatePoolingDescriptor(&cudnn_desc)); cudnn_check(cudnnSetPooling2dDescriptor( - cudnn_desc, mode, CUDNN_NOT_PROPAGATE_NAN, - args.opr->param().window_h, args.opr->param().window_w, - args.opr->param().pad_h, args.opr->param().pad_w, - args.opr->param().stride_h, args.opr->param().stride_w)); - cudnn_check(cudnnPoolingForward(args.handle->cudnn_handle(), cudnn_desc, - &alpha, src_desc.desc, src.raw_ptr, - &beta, dst_desc.desc, dst.raw_ptr)); + cudnn_desc, mode, CUDNN_NOT_PROPAGATE_NAN, args.opr->param().window_h, + args.opr->param().window_w, args.opr->param().pad_h, + args.opr->param().pad_w, args.opr->param().stride_h, + args.opr->param().stride_w)); + cudnn_check(cudnnPoolingForward( + args.handle->cudnn_handle(), cudnn_desc, &alpha, src_desc.desc, + src.raw_ptr, &beta, dst_desc.desc, dst.raw_ptr)); cudnn_check(cudnnDestroyPoolingDescriptor(cudnn_desc)); } if (args.layout_src->dtype.enumv() == DTypeTrait::enumv) { @@ -184,13 +184,13 @@ void PoolingForwardImpl::AlgoCUDNNMAXDETERMINISTIC::init_mode( mode = CUDNN_POOLING_MAX_DETERMINISTIC; break; default: - megdnn_throw(ssprintf("Unspport pooling mode : {%d}", - static_cast(args.opr->param().mode))); + megdnn_throw(ssprintf( + "Unspport pooling mode : {%d}", + 
static_cast(args.opr->param().mode))); } } -void PoolingForwardImpl::AlgoCUDNNMAXDETERMINISTIC::exec( - const ExecArgs& args) const { +void PoolingForwardImpl::AlgoCUDNNMAXDETERMINISTIC::exec(const ExecArgs& args) const { TensorND src = *args.src_tensor; TensorND dst = *args.dst_tensor; auto wsb = get_workspace_bundle(args.workspace.raw_ptr, args); @@ -212,13 +212,13 @@ void PoolingForwardImpl::AlgoCUDNNMAXDETERMINISTIC::exec( cudnnPoolingDescriptor_t cudnn_desc; cudnn_check(cudnnCreatePoolingDescriptor(&cudnn_desc)); cudnn_check(cudnnSetPooling2dDescriptor( - cudnn_desc, mode, CUDNN_NOT_PROPAGATE_NAN, - args.opr->param().window_h, args.opr->param().window_w, - args.opr->param().pad_h, args.opr->param().pad_w, - args.opr->param().stride_h, args.opr->param().stride_w)); - cudnn_check(cudnnPoolingForward(args.handle->cudnn_handle(), cudnn_desc, - &alpha, src_desc.desc, src.raw_ptr, - &beta, dst_desc.desc, dst.raw_ptr)); + cudnn_desc, mode, CUDNN_NOT_PROPAGATE_NAN, args.opr->param().window_h, + args.opr->param().window_w, args.opr->param().pad_h, + args.opr->param().pad_w, args.opr->param().stride_h, + args.opr->param().stride_w)); + cudnn_check(cudnnPoolingForward( + args.handle->cudnn_handle(), cudnn_desc, &alpha, src_desc.desc, + src.raw_ptr, &beta, dst_desc.desc, dst.raw_ptr)); cudnn_check(cudnnDestroyPoolingDescriptor(cudnn_desc)); } if (args.layout_src->dtype.enumv() == DTypeTrait::enumv) { @@ -241,13 +241,12 @@ void PoolingForwardImpl::AlgoCHWN4::exec(const ExecArgs& args) const { ho = (*args.layout_dst)[1], wo = (*args.layout_dst)[2]; c = c * 4; size_t ph = args.opr->param().pad_h, pw = args.opr->param().pad_w; - size_t window_h = args.opr->param().window_h, - window_w = args.opr->param().window_w; + size_t window_h = args.opr->param().window_h, window_w = args.opr->param().window_w; size_t sh = args.opr->param().stride_h, sw = args.opr->param().stride_w; kern_param.n = n, kern_param.c = c, kern_param.hi = hi, kern_param.wi = wi, - kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, - kern_param.pw = pw, kern_param.window_h = window_h, - kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw; + kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, + kern_param.window_h = window_h, kern_param.window_w = window_w, kern_param.sh = sh, + kern_param.sw = sw; auto&& stream = cuda_stream(args.handle); pooling2d::do_pooling2d_int8_cdiv4hwn4( args.src_tensor->compatible_ptr(), @@ -269,13 +268,12 @@ void PoolingForwardImpl::AlgoNCHW4::exec(const ExecArgs& args) const { ho = (*args.layout_dst)[2], wo = (*args.layout_dst)[3]; c = c * 4; size_t ph = args.opr->param().pad_h, pw = args.opr->param().pad_w; - size_t window_h = args.opr->param().window_h, - window_w = args.opr->param().window_w; + size_t window_h = args.opr->param().window_h, window_w = args.opr->param().window_w; size_t sh = args.opr->param().stride_h, sw = args.opr->param().stride_w; kern_param.n = n, kern_param.c = c, kern_param.hi = hi, kern_param.wi = wi, - kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, - kern_param.pw = pw, kern_param.window_h = window_h, - kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw; + kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, + kern_param.window_h = window_h, kern_param.window_w = window_w, kern_param.sh = sh, + kern_param.sw = sw; auto&& stream = cuda_stream(args.handle); pooling2d::do_pooling2d_int8_ncdiv4hw4( args.src_tensor->compatible_ptr(), @@ -297,13 +295,12 @@ void 
PoolingForwardImpl::AlgoNCHW32::exec(const ExecArgs& args) const { ho = (*args.layout_dst)[2], wo = (*args.layout_dst)[3]; c = c * 32; size_t ph = args.opr->param().pad_h, pw = args.opr->param().pad_w; - size_t window_h = args.opr->param().window_h, - window_w = args.opr->param().window_w; + size_t window_h = args.opr->param().window_h, window_w = args.opr->param().window_w; size_t sh = args.opr->param().stride_h, sw = args.opr->param().stride_w; kern_param.n = n, kern_param.c = c, kern_param.hi = hi, kern_param.wi = wi, - kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, - kern_param.pw = pw, kern_param.window_h = window_h, - kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw; + kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, + kern_param.window_h = window_h, kern_param.window_w = window_w, kern_param.sh = sh, + kern_param.sw = sw; auto&& stream = cuda_stream(args.handle); pooling2d::do_pooling2d_int8_ncdiv32hw32( args.src_tensor->compatible_ptr(), @@ -322,8 +319,9 @@ void PoolingForwardImpl::AlgoNHWC::exec(const ExecArgs& args) const { TensorND src = *args.src_tensor; TensorND dst = *args.dst_tensor; { - megdnn_assert(src.layout.dtype.enumv() == dst.layout.dtype.enumv(), - "src and dst dtype must equal"); + megdnn_assert( + src.layout.dtype.enumv() == dst.layout.dtype.enumv(), + "src and dst dtype must equal"); pooling2d::Param kern_param; size_t n = src.layout[0], hi = src.layout[1], wi = src.layout[2], c = src.layout[3], ho = dst.layout[1], wo = dst.layout[2]; @@ -331,29 +329,26 @@ void PoolingForwardImpl::AlgoNHWC::exec(const ExecArgs& args) const { size_t window_h = args.opr->param().window_h, window_w = args.opr->param().window_w; size_t sh = args.opr->param().stride_h, sw = args.opr->param().stride_w; - kern_param.n = n, kern_param.c = c, kern_param.hi = hi, - kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, - kern_param.ph = ph, kern_param.pw = pw, kern_param.window_h = window_h, - kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw; + kern_param.n = n, kern_param.c = c, kern_param.hi = hi, kern_param.wi = wi, + kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, + kern_param.window_h = window_h, kern_param.window_w = window_w, + kern_param.sh = sh, kern_param.sw = sw; bool uint_case = false; int zero_point = 0; if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { uint_case = true; - zero_point = - src.layout.dtype.param().zero_point; + zero_point = src.layout.dtype.param().zero_point; } auto&& stream = cuda_stream(args.handle); pooling2d::do_pooling2d_int4_nhwc( (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param, stream, - static_cast(args.opr->param().mode), uint_case, - zero_point); + static_cast(args.opr->param().mode), uint_case, zero_point); } } inline void PoolingForwardImpl::AlgoNCHW64::deduce_reformat_layout( - std::unique_ptr& relayout, - const TensorLayout& src_layout, TensorLayout& dst_layout, - RelayoutFormat::Param::Mode mode, const int oc = 0, + std::unique_ptr& relayout, const TensorLayout& src_layout, + TensorLayout& dst_layout, RelayoutFormat::Param::Mode mode, const int oc = 0, const int group = 1) const { if (src_layout.ndim > 0) { RelayoutFormat::Param trans_param; @@ -368,14 +363,16 @@ inline void PoolingForwardImpl::AlgoNCHW64::deduce_reformat_layout( } void PoolingForwardImpl::AlgoNCHW64::get_inner_layout( - const TensorLayout& src, const TensorLayout& dst, - TensorLayout& inner_src, TensorLayout& inner_dst, Handle* handle, + 
const TensorLayout& src, const TensorLayout& dst, TensorLayout& inner_src, + TensorLayout& inner_dst, Handle* handle, PoolingForwardImpl::Param::Format format) const { auto relayout_opr = handle->create_operator(); - deduce_reformat_layout(relayout_opr, src, inner_src, - RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1); - deduce_reformat_layout(relayout_opr, dst, inner_dst, - RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1); + deduce_reformat_layout( + relayout_opr, src, inner_src, RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, + 1); + deduce_reformat_layout( + relayout_opr, dst, inner_dst, RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, + 1); } WorkspaceBundle PoolingForwardImpl::AlgoNCHW64::get_workspace_bundle( @@ -385,8 +382,9 @@ WorkspaceBundle PoolingForwardImpl::AlgoNCHW64::get_workspace_bundle( TensorLayout fsrc = *args.layout_src; TensorLayout fdst = *args.layout_dst; if (args.opr->param().format == Format::NCHW) { - get_inner_layout(*args.layout_src, *args.layout_dst, fsrc, fdst, - args.handle, args.opr->param().format); + get_inner_layout( + *args.layout_src, *args.layout_dst, fsrc, fdst, args.handle, + args.opr->param().format); sizes.push_back(fsrc.span().dist_byte()); sizes.push_back(fdst.span().dist_byte()); } @@ -410,8 +408,9 @@ void PoolingForwardImpl::AlgoNCHW64::exec(const ExecArgs& args) const { if (args.opr->param().format == Format::NCHW) { auto wsb = get_workspace_bundle(args.workspace.raw_ptr, args); auto handle_ptr = args.handle; - get_inner_layout(*args.layout_src, *args.layout_dst, src.layout, - dst.layout, handle_ptr, args.opr->param().format); + get_inner_layout( + *args.layout_src, *args.layout_dst, src.layout, dst.layout, handle_ptr, + args.opr->param().format); src.raw_ptr = wsb.get(0); dst.raw_ptr = wsb.get(1); auto relayout_opr = handle_ptr->create_operator(); @@ -430,22 +429,20 @@ void PoolingForwardImpl::AlgoNCHW64::exec(const ExecArgs& args) const { size_t window_h = args.opr->param().window_h, window_w = args.opr->param().window_w; size_t sh = args.opr->param().stride_h, sw = args.opr->param().stride_w; - kern_param.n = n, kern_param.c = c, kern_param.hi = hi, - kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo, - kern_param.ph = ph, kern_param.pw = pw, kern_param.window_h = window_h, - kern_param.window_w = window_w, kern_param.sh = sh, kern_param.sw = sw; + kern_param.n = n, kern_param.c = c, kern_param.hi = hi, kern_param.wi = wi, + kern_param.ho = ho, kern_param.wo = wo, kern_param.ph = ph, kern_param.pw = pw, + kern_param.window_h = window_h, kern_param.window_w = window_w, + kern_param.sh = sh, kern_param.sw = sw; bool uint_case = false; int zero_point = 0; if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) { uint_case = true; - zero_point = - src.layout.dtype.param().zero_point; + zero_point = src.layout.dtype.param().zero_point; } auto&& stream = cuda_stream(args.handle); pooling2d::do_pooling2d_int4_ncdiv64hw64( (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param, stream, - static_cast(args.opr->param().mode), uint_case, - zero_point); + static_cast(args.opr->param().mode), uint_case, zero_point); } if (args.layout_dst->ndim == 4) { auto relayout_opr = args.handle->create_operator(); @@ -472,11 +469,9 @@ PoolingBackwardImpl::AlgoPack::AlgoPack() { PoolingBackwardImpl::AlgoPack PoolingBackwardImpl::sm_algo_pack; MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingBackwardImpl) -PoolingBackwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingBackwardImpl* o, - const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const 
TensorLayout& grad) +PoolingBackwardImpl::AlgoBase::SizeArgs::SizeArgs( + PoolingBackwardImpl* o, const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) : handle{concrete_handle(o->handle())}, opr{o}, layout_src{&src}, @@ -484,12 +479,9 @@ PoolingBackwardImpl::AlgoBase::SizeArgs::SizeArgs(PoolingBackwardImpl* o, layout_diff{&diff}, layout_grad{&grad} {} -PoolingBackwardImpl::AlgoBase::ExecArgs::ExecArgs(PoolingBackwardImpl* opr, - _megdnn_tensor_in src, - _megdnn_tensor_in dst, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) +PoolingBackwardImpl::AlgoBase::ExecArgs::ExecArgs( + PoolingBackwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, _megdnn_workspace workspace) : SizeArgs(opr, src.layout, dst.layout, diff.layout, grad.layout), src_tensor{&src}, dst_tensor{&dst}, @@ -518,8 +510,7 @@ bool PoolingBackwardImpl::AlgoCUDNN::is_available(const SizeArgs& args) const { args.opr->param().format == Format::NHWC || args.opr->param().format == Format::NCHW4 || args.opr->param().format == Format::NCHW32) && - (m_is_reproducible || - args.opr->param().mode == param::Pooling::Mode::MAX)); + (m_is_reproducible || args.opr->param().mode == param::Pooling::Mode::MAX)); #endif } @@ -548,8 +539,8 @@ size_t PoolingBackwardImpl::AlgoBase::get_workspace_in_bytes( return get_workspace_bundle(nullptr, args).total_size_in_bytes(); } -void PoolingBackwardImpl::AlgoCUDNN::init_mode(const ExecArgs& args, - cudnnPoolingMode_t& mode) const { +void PoolingBackwardImpl::AlgoCUDNN::init_mode( + const ExecArgs& args, cudnnPoolingMode_t& mode) const { if (m_is_reproducible) { switch (args.opr->param().mode) { #if CUDNN_VERSION >= 6000 @@ -564,9 +555,9 @@ void PoolingBackwardImpl::AlgoCUDNN::init_mode(const ExecArgs& args, mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; break; default: - megdnn_throw( - ssprintf("Unspport pooling mode : {%d}", - static_cast(args.opr->param().mode))); + megdnn_throw(ssprintf( + "Unspport pooling mode : {%d}", + static_cast(args.opr->param().mode))); } } else if (args.opr->param().mode == param::Pooling::Mode::MAX) { mode = CUDNN_POOLING_MAX; @@ -603,14 +594,14 @@ void PoolingBackwardImpl::AlgoCUDNN::exec(const ExecArgs& args) const { cudnnPoolingDescriptor_t cudnn_desc; cudnn_check(cudnnCreatePoolingDescriptor(&cudnn_desc)); cudnn_check(cudnnSetPooling2dDescriptor( - cudnn_desc, mode, CUDNN_NOT_PROPAGATE_NAN, - args.opr->param().window_h, args.opr->param().window_w, - args.opr->param().pad_h, args.opr->param().pad_w, - args.opr->param().stride_h, args.opr->param().stride_w)); + cudnn_desc, mode, CUDNN_NOT_PROPAGATE_NAN, args.opr->param().window_h, + args.opr->param().window_w, args.opr->param().pad_h, + args.opr->param().pad_w, args.opr->param().stride_h, + args.opr->param().stride_w)); cudnn_check(cudnnPoolingBackward( args.handle->cudnn_handle(), cudnn_desc, &alpha, dst_desc.desc, - dst.raw_ptr, diff_desc.desc, diff.raw_ptr, src_desc.desc, - src.raw_ptr, &beta, grad_desc.desc, grad.raw_ptr)); + dst.raw_ptr, diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, + &beta, grad_desc.desc, grad.raw_ptr)); cudnn_check(cudnnDestroyPoolingDescriptor(cudnn_desc)); } if (args.layout_src->dtype.enumv() == DTypeTrait::enumv) { diff --git a/dnn/src/cuda/pooling/algo.h b/dnn/src/cuda/pooling/algo.h index 5cf137a6..f499ff15 100644 --- a/dnn/src/cuda/pooling/algo.h +++ b/dnn/src/cuda/pooling/algo.h @@ -23,7 +23,7 @@ namespace cuda { namespace { 
#define V1(v) #v -#define V(v) V1(v) +#define V(v) V1(v) #define DEF_NAME(NAME) \ #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) } // namespace @@ -50,15 +50,17 @@ public: const TensorLayout *layout_src, *layout_dst; std::string to_string() const; - SizeArgs(PoolingForwardImpl* opr, const TensorLayout& src, - const TensorLayout& dst); + SizeArgs( + PoolingForwardImpl* opr, const TensorLayout& src, + const TensorLayout& dst); }; struct ExecArgs : public SizeArgs { const TensorND *src_tensor, *dst_tensor; Workspace workspace; - ExecArgs(PoolingForwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_out dst, _megdnn_workspace workspace); + ExecArgs( + PoolingForwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; @@ -75,8 +77,7 @@ public: protected: ~AlgoBase() = default; - virtual WorkspaceBundle get_workspace_bundle(void* ptr, - const SizeArgs& args) const; + virtual WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; }; class PoolingForwardImpl::AlgoCUDNN final : public AlgoBase { @@ -90,9 +91,7 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return m_algo_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN) @@ -111,9 +110,7 @@ public: void exec(const ExecArgs& args) const override; const char* name() const override { return m_algo_name.c_str(); } - AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; - } + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_MAXDETERMINISTIC) @@ -121,42 +118,46 @@ public: }; #endif -#define ALGO_LAYOUT_POOLING_IMPL(_layout) \ - class PoolingForwardImpl::Algo##_layout final : public AlgoBase { \ - std::string m_algo_name; \ - \ - public: \ - Algo##_layout( \ - std::string name = std::string("CUDA_").append(#_layout)) \ - : m_algo_name(name) {} \ - bool is_available(const SizeArgs& args) const override; \ - void exec(const ExecArgs& args) const override; \ - const char* name() const override { return m_algo_name.c_str(); } \ - AlgoAttribute attribute() const override { \ - return AlgoAttribute::REPRODUCIBLE; \ - } \ +#define ALGO_LAYOUT_POOLING_IMPL(_layout) \ + class PoolingForwardImpl::Algo##_layout final : public AlgoBase { \ + std::string m_algo_name; \ + \ + public: \ + Algo##_layout(std::string name = std::string("CUDA_").append(#_layout)) \ + : m_algo_name(name) {} \ + bool is_available(const SizeArgs& args) const override; \ + void exec(const ExecArgs& args) const override; \ + const char* name() const override { return m_algo_name.c_str(); } \ + AlgoAttribute attribute() const override { \ + return AlgoAttribute::REPRODUCIBLE; \ + } \ MEGDNN_DECL_ALGO_TYPE(CUDA_##_layout) -ALGO_LAYOUT_POOLING_IMPL(CHWN4)}; -ALGO_LAYOUT_POOLING_IMPL(NCHW4)}; -ALGO_LAYOUT_POOLING_IMPL(NCHW32)}; -ALGO_LAYOUT_POOLING_IMPL(NHWC)}; -ALGO_LAYOUT_POOLING_IMPL(NCHW64) //{ +ALGO_LAYOUT_POOLING_IMPL(CHWN4) +}; +ALGO_LAYOUT_POOLING_IMPL(NCHW4) +}; +ALGO_LAYOUT_POOLING_IMPL(NCHW32) +} +; +ALGO_LAYOUT_POOLING_IMPL(NHWC) +} +; +ALGO_LAYOUT_POOLING_IMPL(NCHW64) //{ protected: - WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) - const override; +WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) 
const override; private: - inline void deduce_reformat_layout( - std::unique_ptr & relayout, - const TensorLayout& src_layout, TensorLayout& dst_layout, - RelayoutFormat::Param::Mode mode, const int oc, const int group) - const; - void get_inner_layout(const TensorLayout& src, const TensorLayout& dst, - TensorLayout& inner_src, TensorLayout& inner_dst, - Handle* handle, - PoolingForwardImpl::Param::Format format) const; -}; +inline void deduce_reformat_layout( + std::unique_ptr& relayout, const TensorLayout& src_layout, + TensorLayout& dst_layout, RelayoutFormat::Param::Mode mode, const int oc, + const int group) const; +void get_inner_layout( + const TensorLayout& src, const TensorLayout& dst, TensorLayout& inner_src, + TensorLayout& inner_dst, Handle* handle, + PoolingForwardImpl::Param::Format format) const; +} +; #undef ALGO_LAYOUT_POOLING_IMPL @@ -194,17 +195,19 @@ public: const TensorLayout *layout_src, *layout_dst, *layout_diff, *layout_grad; std::string to_string() const; - SizeArgs(PoolingBackwardImpl* opr, const TensorLayout& src, - const TensorLayout& dst, const TensorLayout& diff, - const TensorLayout& grad); + SizeArgs( + PoolingBackwardImpl* opr, const TensorLayout& src, + const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad); }; struct ExecArgs : public SizeArgs { const TensorND *src_tensor, *dst_tensor, *diff_tensor, *grad_tensor; Workspace workspace; - ExecArgs(PoolingBackwardImpl* opr, _megdnn_tensor_in src, - _megdnn_tensor_in dst, _megdnn_tensor_in diff, - _megdnn_tensor_out grad, _megdnn_workspace workspace); + ExecArgs( + PoolingBackwardImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_in dst, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); }; virtual bool is_available(const SizeArgs& args) const = 0; @@ -221,8 +224,7 @@ public: protected: ~AlgoBase() = default; - virtual WorkspaceBundle get_workspace_bundle(void* ptr, - const SizeArgs& args) const; + virtual WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; }; class PoolingBackwardImpl::AlgoCUDNN final : public AlgoBase { diff --git a/dnn/src/cuda/pooling/opr_impl.cpp b/dnn/src/cuda/pooling/opr_impl.cpp index c4f3d39d..728fca31 100644 --- a/dnn/src/cuda/pooling/opr_impl.cpp +++ b/dnn/src/cuda/pooling/opr_impl.cpp @@ -19,8 +19,8 @@ namespace megdnn { namespace cuda { -size_t PoolingForwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) { +size_t PoolingForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) { return get_dnn_workspace(this, src, dst); } @@ -28,14 +28,12 @@ const char* PoolingForwardImpl::get_algorithm_set_name() const { return "CUDA_POOLING_FORWARD"; } -std::vector -PoolingForwardImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& dst) { +std::vector PoolingForwardImpl::get_all_algorithms( + const TensorLayout& src, const TensorLayout& dst) { return megdnn::get_all_algorithms({this, src, dst}); } -std::vector -PoolingForwardImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& dst) { +std::vector PoolingForwardImpl::get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst) { return megdnn::get_all_algorithms_safe({this, src, dst}); } @@ -51,16 +49,16 @@ PoolingForwardImpl::Algorithm* PoolingForwardImpl::get_algorithm_heuristic( return iter; } } - megdnn_throw( - ssprintf("require algorithm with attribute(%s) and without " - "attribute(%s), but can't get suitable algo.\n", - 
Algorithm::attribute_str(positive_attr).c_str(), - Algorithm::attribute_str(negative_attr).c_str())); + megdnn_throw(ssprintf( + "require algorithm with attribute(%s) and without " + "attribute(%s), but can't get suitable algo.\n", + Algorithm::attribute_str(positive_attr).c_str(), + Algorithm::attribute_str(negative_attr).c_str())); return nullptr; } -void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, - _megdnn_workspace sworkspace) { +void PoolingForwardImpl::exec( + _megdnn_tensor_in ssrc, _megdnn_tensor_out sdst, _megdnn_workspace sworkspace) { check_exec(ssrc.layout, sdst.layout, sworkspace.size); { AlgoBase::ExecArgs args(this, ssrc, sdst, sworkspace); @@ -73,29 +71,25 @@ const char* PoolingBackwardImpl::get_algorithm_set_name() const { return "CUDA_POOLING_BACKWARD"; } -std::vector -PoolingBackwardImpl::get_all_algorithms(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector PoolingBackwardImpl::get_all_algorithms( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) { return megdnn::get_all_algorithms( {this, src, dst, diff, grad}); } -std::vector -PoolingBackwardImpl::get_all_algorithms_safe(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) { +std::vector PoolingBackwardImpl:: + get_all_algorithms_safe( + const TensorLayout& src, const TensorLayout& dst, + const TensorLayout& diff, const TensorLayout& grad) { return megdnn::get_all_algorithms_safe( {this, src, dst, diff, grad}); } PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); AlgoBase::SizeArgs args(this, src, dst, diff, grad); @@ -104,32 +98,29 @@ PoolingBackwardImpl::Algorithm* PoolingBackwardImpl::get_algorithm_heuristic( return iter; } } - megdnn_throw( - ssprintf("require algorithm with attribute(%s) and without " - "attribute(%s), but can't get suitable algo.\n", - Algorithm::attribute_str(positive_attr).c_str(), - Algorithm::attribute_str(negative_attr).c_str())); + megdnn_throw(ssprintf( + "require algorithm with attribute(%s) and without " + "attribute(%s), but can't get suitable algo.\n", + Algorithm::attribute_str(positive_attr).c_str(), + Algorithm::attribute_str(negative_attr).c_str())); return nullptr; } -void PoolingBackwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_in sdst, - _megdnn_tensor_in sdiff, - _megdnn_tensor_out sgrad, - _megdnn_workspace sworkspace) { - check_exec(ssrc.layout, sdst.layout, sdiff.layout, sgrad.layout, - sworkspace.size); +void PoolingBackwardImpl::exec( + _megdnn_tensor_in ssrc, _megdnn_tensor_in sdst, _megdnn_tensor_in sdiff, + _megdnn_tensor_out sgrad, _megdnn_workspace sworkspace) { + check_exec(ssrc.layout, sdst.layout, sdiff.layout, sgrad.layout, sworkspace.size); { AlgoBase::ExecArgs args(this, ssrc, sdst, sdiff, sgrad, sworkspace); - auto algo = get_algorithm(this, ssrc.layout, sdst.layout, sdiff.layout, - sgrad.layout); + auto algo = get_algorithm( + this, ssrc.layout, 
sdst.layout, sdiff.layout, sgrad.layout); algo->exec(args); } } -size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) { +size_t PoolingBackwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) { return get_dnn_workspace(this, src, dst, diff, grad); } diff --git a/dnn/src/cuda/pooling/opr_impl.h b/dnn/src/cuda/pooling/opr_impl.h index 3096290d..4ee923f1 100644 --- a/dnn/src/cuda/pooling/opr_impl.h +++ b/dnn/src/cuda/pooling/opr_impl.h @@ -20,10 +20,11 @@ namespace cuda { class PoolingForwardImpl final : public PoolingForward { public: using PoolingForward::PoolingForward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst) override; const char* get_algorithm_set_name() const override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; @@ -32,8 +33,8 @@ public: const TensorLayout& src, const TensorLayout& dst, size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { - return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes, - positive_attr, negative_attr) + return get_algorithm_heuristic( + src, dst, workspace_limit_in_bytes, positive_attr, negative_attr) ->info(); } @@ -69,25 +70,23 @@ private: class PoolingBackwardImpl final : public PoolingBackward { public: using PoolingBackward::PoolingBackward; - void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst, - _megdnn_tensor_in diff, _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - size_t get_workspace_in_bytes(const TensorLayout& src, - const TensorLayout& dst, - const TensorLayout& diff, - const TensorLayout& grad) override; + void exec( + _megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) override; const char* get_algorithm_set_name() const override; Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; AlgorithmInfo get_algorithm_info_heuristic( - const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, - const AlgoAttribute& negative_attr) { - return get_algorithm_heuristic(src, dst, diff, grad, - workspace_limit_in_bytes, positive_attr, - negative_attr) + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { + return get_algorithm_heuristic( + src, dst, diff, grad, workspace_limit_in_bytes, positive_attr, + negative_attr) ->info(); } @@ -99,15 +98,15 @@ public: protected: std::vector get_all_algorithms( - const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad) override; + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& 
grad) override; std::vector get_all_algorithms_safe( - const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad) override; + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad) override; Algorithm* get_algorithm_heuristic( - const TensorLayout& src, const TensorLayout& dst, - const TensorLayout& diff, const TensorLayout& grad, - size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, + const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) override; private: diff --git a/dnn/src/cuda/pooling/pooling2d_qint.cu b/dnn/src/cuda/pooling/pooling2d_qint.cu index e0a7e6ca..e9da2725 100644 --- a/dnn/src/cuda/pooling/pooling2d_qint.cu +++ b/dnn/src/cuda/pooling/pooling2d_qint.cu @@ -18,10 +18,10 @@ using namespace cuda; using namespace pooling2d; namespace { -__device__ __forceinline__ int pack_int8_to_int8x4(int8_t x, int8_t y, int8_t z, - int8_t w) { - int ix = static_cast(x), iy = static_cast(y), - iz = static_cast(z), iw = static_cast(w); +__device__ __forceinline__ int pack_int8_to_int8x4( + int8_t x, int8_t y, int8_t z, int8_t w) { + int ix = static_cast(x), iy = static_cast(y), iz = static_cast(z), + iw = static_cast(w); asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy)); asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw)); @@ -50,8 +50,9 @@ __device__ __forceinline__ int4 pack_int8<16, 8, int4>(int8_t (&x)[16]) { int8_t x1[4]{x[4], x[5], x[6], x[7]}; int8_t x2[4]{x[8], x[9], x[10], x[11]}; int8_t x3[4]{x[12], x[13], x[14], x[15]}; - return ::make_int4(pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1), - pack_int8<4, 8, int>(x2), pack_int8<4, 8, int>(x3)); + return ::make_int4( + pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1), + pack_int8<4, 8, int>(x2), pack_int8<4, 8, int>(x3)); } __device__ __forceinline__ int8_t pack_int8_to_int4x2(int8_t x0, int8_t x1) { @@ -72,8 +73,9 @@ __device__ __forceinline__ int4 pack_int8<32, 4, int4>(int8_t (&x)[32]) { int8_t x1[8]{x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]}; int8_t x2[8]{x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23]}; int8_t x3[8]{x[24], x[25], x[26], x[27], x[28], x[29], x[30], x[31]}; - return ::make_int4(pack_int8<8, 4, int>(x0), pack_int8<8, 4, int>(x1), - pack_int8<8, 4, int>(x2), pack_int8<8, 4, int>(x3)); + return ::make_int4( + pack_int8<8, 4, int>(x0), pack_int8<8, 4, int>(x1), + pack_int8<8, 4, int>(x2), pack_int8<8, 4, int>(x3)); } template @@ -129,7 +131,7 @@ struct MaxPooler { #pragma unroll for (int i = 0; i < unroll_n; i++) { int8_t temp = ((x >> (i * bit_width)) & TypeTrait::mask) - << shift_fix_sign; + << shift_fix_sign; temp = temp >> shift_fix_sign; res[idx + i] = res[idx + i] > temp ? 
res[idx + i] : temp; } @@ -183,7 +185,7 @@ struct MeanIncludeRoundedPooler { #pragma unroll for (int i = 0; i < unroll_n; i++) { int8_t temp = ((x >> (i * bit_width)) & TypeTrait::mask) - << shift_fix_sign; + << shift_fix_sign; temp = temp >> shift_fix_sign; res[idx + i] += static_cast(temp); } @@ -217,14 +219,13 @@ struct MeanIncludeRoundedPooler { for (int i = 0; i < nr_results; i++) { float f32_res = roundf(static_cast(res[i]) * fi_count); if (need_zero_pad) { - f32_res = roundf((static_cast(res[i]) + - (count - real_fi_count) * zero_pad) * - fi_count); + f32_res = + roundf((static_cast(res[i]) + + (count - real_fi_count) * zero_pad) * + fi_count); } int i8_res; - asm volatile("cvt.rni.s8.f32 %0, %1;" - : "=r"(i8_res) - : "f"(f32_res)); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(i8_res) : "f"(f32_res)); out_res[i] = i8_res; } ans = pack_int8(out_res); @@ -255,7 +256,7 @@ struct MeanExcludeRoundedPooler { #pragma unroll for (int i = 0; i < unroll_n; i++) { int8_t temp = ((x >> (i * bit_width)) & TypeTrait::mask) - << shift_fix_sign; + << shift_fix_sign; temp = temp >> shift_fix_sign; res[idx + i] += static_cast(temp); } @@ -284,9 +285,7 @@ struct MeanExcludeRoundedPooler { for (int i = 0; i < nr_results; i++) { float f32_res = roundf(static_cast(res[i]) / count); int i8_res; - asm volatile("cvt.rni.s8.f32 %0, %1;" - : "=r"(i8_res) - : "f"(f32_res)); + asm volatile("cvt.rni.s8.f32 %0, %1;" : "=r"(i8_res) : "f"(f32_res)); out_res[i] = i8_res; } ans = pack_int8(out_res); @@ -331,8 +330,7 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( if (ih < param.hi && iw < param.wi) { const int8_t* __restrict__ cur_src_ptr = g_src_ptr + (ih * param.wi + iw) * npack; - ldg_type sval = - __ldg(reinterpret_cast(cur_src_ptr)); + ldg_type sval = __ldg(reinterpret_cast(cur_src_ptr)); pooler.feed(sval); } } @@ -341,11 +339,10 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( *(reinterpret_cast(g_dst_ptr)) = res; } -template -__global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, - int8_t* __restrict__ dst, - Param param, int zero_point) { +template +__global__ void pooling2d_device_template_nchwc( + const int8_t* __restrict__ src, int8_t* __restrict__ dst, Param param, + int zero_point) { const int tid = blockIdx.x * blockDim.x + threadIdx.x; using ldg_type = typename Pooler::feed_type; static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); @@ -356,8 +353,7 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, "pooling2d (NCHW64) kernel must use 128bit width ldg instruction"); const int c_packed = param.c / pack_size; const int batch = tid / (param.ho * param.wo * c_packed * section); - const int batch_residual = - tid - batch * param.ho * param.wo * c_packed * section; + const int batch_residual = tid - batch * param.ho * param.wo * c_packed * section; const int oc = batch_residual / (param.ho * param.wo * section); const int oc_residual = batch_residual - oc * param.ho * param.wo * section; const int oh = oc_residual / (param.wo * section); @@ -367,15 +363,13 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, if (batch >= param.n || oc >= c_packed || oh >= param.ho || ow >= param.wo) return; - const int in_batch_stride = - param.hi * param.wi * param.c * pack_byte / pack_size; - const int out_batch_stride = - param.ho * param.wo * param.c * pack_byte / pack_size; + const int in_batch_stride = param.hi * param.wi * param.c * pack_byte / pack_size; + const int out_batch_stride = 
param.ho * param.wo * param.c * pack_byte / pack_size; const int in_channel_stride = param.hi * param.wi * pack_byte; const int out_channel_stride = param.ho * param.wo * pack_byte; const int8_t* __restrict__ g_src_ptr = - src + (batch * in_batch_stride + oc * in_channel_stride + - sec * ldg_width_bytes); + src + + (batch * in_batch_stride + oc * in_channel_stride + sec * ldg_width_bytes); int8_t* __restrict__ g_dst_ptr = dst + (batch * out_batch_stride + oc * out_channel_stride + (oh * param.wo + ow) * pack_byte + sec * ldg_width_bytes); @@ -389,8 +383,7 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, if (ih < param.hi && iw < param.wi) { const int8_t* __restrict__ cur_src_ptr = g_src_ptr + (ih * param.wi + iw) * pack_byte; - ldg_type sval = - __ldg(reinterpret_cast(cur_src_ptr)); + ldg_type sval = __ldg(reinterpret_cast(cur_src_ptr)); pooler.feed(sval); } } @@ -399,11 +392,10 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, *(reinterpret_cast(g_dst_ptr)) = res; } -template -__global__ void pooling2d_device_template_nhwc(const int8_t* __restrict__ src, - int8_t* __restrict__ dst, - Param param, int zero_point) { +template +__global__ void pooling2d_device_template_nhwc( + const int8_t* __restrict__ src, int8_t* __restrict__ dst, Param param, + int zero_point) { const int tid = blockIdx.x * blockDim.x + threadIdx.x; using ldg_type = typename Pooler::feed_type; static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); @@ -422,10 +414,8 @@ __global__ void pooling2d_device_template_nhwc(const int8_t* __restrict__ src, if (batch >= param.n || oh >= param.ho || ow >= param.wo) return; - const int in_batch_stride = - param.hi * param.wi * param.c * pack_byte / pack_size; - const int out_batch_stride = - param.ho * param.wo * param.c * pack_byte / pack_size; + const int in_batch_stride = param.hi * param.wi * param.c * pack_byte / pack_size; + const int out_batch_stride = param.ho * param.wo * param.c * pack_byte / pack_size; const int w_stride = param.c * pack_byte / pack_size; const int8_t* __restrict__ g_src_ptr = src + (batch * in_batch_stride + sec * ldg_width_bytes); @@ -442,8 +432,7 @@ __global__ void pooling2d_device_template_nhwc(const int8_t* __restrict__ src, if (ih < param.hi && iw < param.wi) { const int8_t* __restrict__ cur_src_ptr = g_src_ptr + (ih * param.wi + iw) * w_stride; - ldg_type sval = - __ldg(reinterpret_cast(cur_src_ptr)); + ldg_type sval = __ldg(reinterpret_cast(cur_src_ptr)); pooler.feed(sval); } } @@ -454,11 +443,9 @@ __global__ void pooling2d_device_template_nhwc(const int8_t* __restrict__ src, }; // namespace -void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, - int8_t* d_dst, - const Param& param, - cudaStream_t stream, - uint32_t mode) { +void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4( + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode) { using Mode = megdnn::param_enumv::Pooling::Mode; void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); uint32_t vthreads_x = 0, vthreads_y = param.c / 4; @@ -504,34 +491,33 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, } void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( - const int8_t* d_src, int8_t* d_dst, const Param& param, - cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) { + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool 
/* uint_case */, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; - void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, - int zero_point); + void (*kern)( + const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); constexpr int ldg_byte = 4; constexpr int elem_per_byte = 1; constexpr int pack_size = 4; constexpr int pack_byte = pack_size / elem_per_byte; constexpr int elem_per_thread = ldg_byte * elem_per_byte; constexpr int ldg_assert_width = ldg_byte / sizeof(int32_t); - uint32_t vthreads = - param.n * param.c * param.ho * param.wo / elem_per_thread; + uint32_t vthreads = param.n * param.c * param.ho * param.wo / elem_per_thread; switch (mode) { case Mode::MAX: - kern = pooling2d_device_template_nchwc, - pack_size, pack_byte, - ldg_assert_width>; + kern = pooling2d_device_template_nchwc< + MaxPooler, pack_size, pack_byte, ldg_assert_width>; break; case Mode::AVERAGE: kern = pooling2d_device_template_nchwc< - MeanIncludeRoundedPooler, - pack_size, pack_byte, ldg_assert_width>; + MeanIncludeRoundedPooler, pack_size, + pack_byte, ldg_assert_width>; break; case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: kern = pooling2d_device_template_nchwc< - MeanExcludeRoundedPooler, - pack_size, pack_byte, ldg_assert_width>; + MeanExcludeRoundedPooler, pack_size, + pack_byte, ldg_assert_width>; break; default: megdnn_assert(false, "invalid pooling mode"); @@ -544,22 +530,22 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( } void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32( - const int8_t* d_src, int8_t* d_dst, const Param& param, - cudaStream_t stream, uint32_t mode, bool /* uint_case */, int zero_point) { + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool /* uint_case */, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; - void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, - int zero_point); + void (*kern)( + const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); constexpr int ldg_byte = 16; constexpr int elem_per_byte = 1; constexpr int pack_size = 32; constexpr int pack_byte = pack_size / elem_per_byte; constexpr int elem_per_thread = ldg_byte * elem_per_byte; - uint32_t vthreads = - param.n * param.c * param.ho * param.wo / elem_per_thread; + uint32_t vthreads = param.n * param.c * param.ho * param.wo / elem_per_thread; switch (mode) { case Mode::MAX: - kern = pooling2d_device_template_nchwc, - pack_size, pack_byte>; + kern = pooling2d_device_template_nchwc< + MaxPooler, pack_size, pack_byte>; break; case Mode::AVERAGE: kern = pooling2d_device_template_nchwc< @@ -582,18 +568,18 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32( } void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( - const int8_t* d_src, int8_t* d_dst, const Param& param, - cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool uint_case, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; - void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, - int zero_point); + void (*kern)( + const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); constexpr int ldg_byte = 16; constexpr int elem_per_byte = 2; constexpr int pack_size = 64; constexpr int pack_byte = pack_size / elem_per_byte; constexpr int elem_per_thread = ldg_byte * elem_per_byte; - 
uint32_t vthreads = - param.n * param.c * param.ho * param.wo / elem_per_thread; + uint32_t vthreads = param.n * param.c * param.ho * param.wo / elem_per_thread; if (uint_case) { switch (mode) { case Mode::MAX: @@ -602,13 +588,13 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( break; case Mode::AVERAGE: kern = pooling2d_device_template_nchwc< - MeanIncludeRoundedPooler, - pack_size, pack_byte>; + MeanIncludeRoundedPooler, pack_size, + pack_byte>; break; case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: kern = pooling2d_device_template_nchwc< - MeanExcludeRoundedPooler, - pack_size, pack_byte>; + MeanExcludeRoundedPooler, pack_size, + pack_byte>; break; default: megdnn_assert(false, "invalid pooling mode"); @@ -617,18 +603,18 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( } else { switch (mode) { case Mode::MAX: - kern = pooling2d_device_template_nchwc, - pack_size, pack_byte>; + kern = pooling2d_device_template_nchwc< + MaxPooler, pack_size, pack_byte>; break; case Mode::AVERAGE: kern = pooling2d_device_template_nchwc< - MeanIncludeRoundedPooler, - pack_size, pack_byte>; + MeanIncludeRoundedPooler, pack_size, + pack_byte>; break; case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: kern = pooling2d_device_template_nchwc< - MeanExcludeRoundedPooler, - pack_size, pack_byte>; + MeanExcludeRoundedPooler, pack_size, + pack_byte>; break; default: megdnn_assert(false, "invalid pooling mode"); @@ -642,11 +628,12 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( } void megdnn::cuda::pooling2d::do_pooling2d_int4_nhwc( - const int8_t* d_src, int8_t* d_dst, const Param& param, - cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool uint_case, int zero_point) { using Mode = megdnn::param_enumv::Pooling::Mode; - void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, - int zero_point); + void (*kern)( + const int8_t* __restrict__, int8_t* __restrict__, Param param, + int zero_point); megdnn_assert(param.c % 8 == 0); constexpr int ldg_byte = 4; @@ -655,8 +642,7 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_nhwc( constexpr int pack_size = ldg_byte * elem_per_byte; constexpr int pack_byte = pack_size / elem_per_byte; constexpr int elem_per_thread = ldg_byte * elem_per_byte; - uint32_t vthreads = - param.n * param.c * param.ho * param.wo / elem_per_thread; + uint32_t vthreads = param.n * param.c * param.ho * param.wo / elem_per_thread; if (uint_case) { switch (mode) { case Mode::MAX: @@ -687,13 +673,13 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_nhwc( break; case Mode::AVERAGE: kern = pooling2d_device_template_nhwc< - MeanIncludeRoundedPooler, - pack_size, pack_byte, ldg_width_assert>; + MeanIncludeRoundedPooler, pack_size, + pack_byte, ldg_width_assert>; break; case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: kern = pooling2d_device_template_nhwc< - MeanExcludeRoundedPooler, - pack_size, pack_byte, ldg_width_assert>; + MeanExcludeRoundedPooler, pack_size, + pack_byte, ldg_width_assert>; break; default: megdnn_assert(false, "invalid pooling mode"); diff --git a/dnn/src/cuda/pooling/pooling2d_qint.cuh b/dnn/src/cuda/pooling/pooling2d_qint.cuh index 5ad2ef6e..150b9973 100644 --- a/dnn/src/cuda/pooling/pooling2d_qint.cuh +++ b/dnn/src/cuda/pooling/pooling2d_qint.cuh @@ -21,29 +21,25 @@ struct Param { int n, c, hi, wi, ho, wo, ph, pw, window_h, window_w, sh, sw; }; -void do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, int8_t* d_dst, - const 
Param& param, cudaStream_t stream, - uint32_t mode); - -void do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, int8_t* d_dst, - const Param& param, cudaStream_t stream, - uint32_t mode, bool uint_case = false, - int zero_point = 0); - -void do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, int8_t* d_dst, - const Param& param, cudaStream_t stream, - uint32_t mode, bool uint_case = false, - int zero_point = 0); - -void do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, int8_t* d_dst, - const Param& param, cudaStream_t stream, - uint32_t mode, bool uint_case = false, - int zero_point = 0); - -void do_pooling2d_int4_nhwc(const int8_t* d_src, int8_t* d_dst, - const Param& param, cudaStream_t stream, - uint32_t mode, bool uint_case = false, - int zero_point = 0); +void do_pooling2d_int8_cdiv4hwn4( + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode); + +void do_pooling2d_int8_ncdiv4hw4( + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool uint_case = false, int zero_point = 0); + +void do_pooling2d_int8_ncdiv32hw32( + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool uint_case = false, int zero_point = 0); + +void do_pooling2d_int4_ncdiv64hw64( + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool uint_case = false, int zero_point = 0); + +void do_pooling2d_int4_nhwc( + const int8_t* d_src, int8_t* d_dst, const Param& param, cudaStream_t stream, + uint32_t mode, bool uint_case = false, int zero_point = 0); } // namespace pooling2d } // namespace cuda diff --git a/dnn/src/cuda/powc/kern.cu b/dnn/src/cuda/powc/kern.cu index 7380ca36..46c42e35 100644 --- a/dnn/src/cuda/powc/kern.cu +++ b/dnn/src/cuda/powc/kern.cu @@ -120,9 +120,7 @@ template struct PowCFloat { T exp; - __device__ __forceinline__ T apply(T x) { - return static_cast(powf(x, exp)); - } + __device__ __forceinline__ T apply(T x) { return static_cast(powf(x, exp)); } }; template @@ -142,8 +140,8 @@ using namespace cuda_kern; namespace { template -void invoke(const TensorND& dest, const TensorND& src, PowOp pow_op, - cudaStream_t stream) { +void invoke( + const TensorND& dest, const TensorND& src, PowOp pow_op, cudaStream_t stream) { ElemwiseOpParamN<1> param; param[0] = src; param.init_from_given_tensor(); @@ -159,8 +157,9 @@ bool feq(float a, float b) { } template -void dispatch_op(const TensorND& dest, const TensorND& src, const float* exp_f, - const int* exp_i, cudaStream_t stream) { +void dispatch_op( + const TensorND& dest, const TensorND& src, const float* exp_f, const int* exp_i, + cudaStream_t stream) { #define CALL(_op) invoke(dest, src, _op, stream) if (exp_f) { float exp = *exp_f; @@ -213,14 +212,13 @@ void dispatch_op(const TensorND& dest, const TensorND& src, const float* exp_f, } } // anonymous namespace -void cuda::powc_kern(const TensorND& dest, const TensorND& src, - const float* exp_f, const int* exp_i, - cudaStream_t stream) { +void cuda::powc_kern( + const TensorND& dest, const TensorND& src, const float* exp_f, const int* exp_i, + cudaStream_t stream) { switch (src.layout.dtype.enumv().ev) { -#define cb(dt) \ - case DTypeTrait
<dt>::enumv: \ - return dispatch_op<DTypeTrait<dt>::ctype>(dest, src, exp_f, exp_i, \ - stream); +#define cb(dt) \ + case DTypeTrait<dt>
::enumv: \ + return dispatch_op::ctype>(dest, src, exp_f, exp_i, stream); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) #undef cb default: diff --git a/dnn/src/cuda/powc/kern.cuh b/dnn/src/cuda/powc/kern.cuh index 41b03212..106a98e3 100644 --- a/dnn/src/cuda/powc/kern.cuh +++ b/dnn/src/cuda/powc/kern.cuh @@ -15,8 +15,9 @@ namespace megdnn { namespace cuda { -void powc_kern(const TensorND& dest, const TensorND& src, const float* exp_f, - const int* exp_i, cudaStream_t stream); +void powc_kern( + const TensorND& dest, const TensorND& src, const float* exp_f, const int* exp_i, + cudaStream_t stream); } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/powc/opr_impl.cpp b/dnn/src/cuda/powc/opr_impl.cpp index 4b6f6ce5..1155367b 100644 --- a/dnn/src/cuda/powc/opr_impl.cpp +++ b/dnn/src/cuda/powc/opr_impl.cpp @@ -17,8 +17,9 @@ using namespace megdnn; using namespace cuda; -void PowCImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - const float* exp_f, const int* exp_i) { +void PowCImpl::do_exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f, + const int* exp_i) { powc_kern(dst, src, exp_f, exp_i, cuda_stream(handle())); } diff --git a/dnn/src/cuda/powc/opr_impl.h b/dnn/src/cuda/powc/opr_impl.h index bc5fe0c6..3a0adbca 100644 --- a/dnn/src/cuda/powc/opr_impl.h +++ b/dnn/src/cuda/powc/opr_impl.h @@ -18,12 +18,12 @@ namespace cuda { class PowCImpl final : public PowC { public: using PowC::PowC; - void do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, - const float* exp_f, const int* exp_i) override; + void do_exec( + _megdnn_tensor_in src, _megdnn_tensor_out dst, const float* exp_f, + const int* exp_i) override; }; } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/cuda/query_blocksize.cpp b/dnn/src/cuda/query_blocksize.cpp index 022979d9..9f9f249f 100644 --- a/dnn/src/cuda/query_blocksize.cpp +++ b/dnn/src/cuda/query_blocksize.cpp @@ -29,17 +29,15 @@ struct pairhash { public: template size_t operator()(const std::pair& x) const { - return hash_pair_combine(std::hash{}(x.first), - std::hash{}(x.second)); + return hash_pair_combine(std::hash{}(x.first), std::hash{}(x.second)); } }; } // anonymous namespace -LaunchConfig cuda::query_launch_config_for_kernel(const void* kern, - const SmemGetter& smem) { +LaunchConfig cuda::query_launch_config_for_kernel( + const void* kern, const SmemGetter& smem) { static std::mutex mtx; - static std::unordered_map, LaunchConfig, - pairhash> + static std::unordered_map, LaunchConfig, pairhash> cache; std::lock_guard _lock{mtx}; @@ -47,11 +45,9 @@ LaunchConfig cuda::query_launch_config_for_kernel(const void* kern, cuda_check(cudaGetDevice(&device)); auto ins = cache.insert({{device, kern}, LaunchConfig{}}); if (ins.second) { - ins.first->second = - detail::query_launch_config_for_kernel_uncached(kern, smem); + ins.first->second = detail::query_launch_config_for_kernel_uncached(kern, smem); } return ins.first->second; } // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/dnn/src/cuda/query_blocksize.cuh b/dnn/src/cuda/query_blocksize.cuh index c3af3df0..d1e8e6a5 100644 --- a/dnn/src/cuda/query_blocksize.cuh +++ b/dnn/src/cuda/query_blocksize.cuh @@ -49,12 +49,11 @@ static inline int query_blocksize_for_kernel(T kern) { } namespace detail { -LaunchConfig query_launch_config_for_kernel_uncached(const void* kern, - const SmemGetter& smem); +LaunchConfig query_launch_config_for_kernel_uncached( + const void* kern, const SmemGetter& smem); } } // 
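Note on the powc hunks above: the reflowed cb macro in kern.cu expands one switch case per floating-point dtype (via MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT) and forwards to dispatch_op, which then chooses the concrete power operator from exp_f / exp_i. The self-contained C++ sketch below illustrates only that X-macro dispatch shape; DTYPE_LIST, run_typed and powc_dispatch are invented stand-ins, not MegDNN's API.

// Rough stand-alone illustration of the X-macro dtype dispatch used by
// powc_kern above. DTYPE_LIST / run_typed / powc_dispatch are stand-ins.
#include <cmath>
#include <stdexcept>

namespace sketch {

enum class DTypeEnum { Float32, Float16 };

// One entry per supported dtype; cb(name, ctype) is expanded for each.
#define DTYPE_LIST(cb) \
    cb(Float32, float) \
    cb(Float16, double)

template <typename ctype>
void run_typed(const void* src, void* dst, float exp) {
    // A real kernel would iterate over the whole tensor; one element keeps
    // the sketch short.
    const ctype* s = static_cast<const ctype*>(src);
    ctype* d = static_cast<ctype*>(dst);
    d[0] = static_cast<ctype>(std::pow(static_cast<double>(s[0]), exp));
}

inline void powc_dispatch(DTypeEnum dt, const void* src, void* dst, float exp) {
    switch (dt) {
#define cb(_name, _ctype)  \
    case DTypeEnum::_name: \
        return run_typed<_ctype>(src, dst, exp);
        DTYPE_LIST(cb)
#undef cb
        default:
            throw std::runtime_error("unsupported dtype");
    }
}

}  // namespace sketch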
namespace cuda } // namespace megdnn // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/dnn/src/cuda/query_blocksize_impl.cu b/dnn/src/cuda/query_blocksize_impl.cu index 9f43bc3b..86044a96 100644 --- a/dnn/src/cuda/query_blocksize_impl.cu +++ b/dnn/src/cuda/query_blocksize_impl.cu @@ -52,4 +52,3 @@ LaunchConfig cuda::detail::query_launch_config_for_kernel_uncached( } // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/dnn/src/cuda/reduce/opr_impl.cpp b/dnn/src/cuda/reduce/opr_impl.cpp index 0449e4e6..410b29e7 100644 --- a/dnn/src/cuda/reduce/opr_impl.cpp +++ b/dnn/src/cuda/reduce/opr_impl.cpp @@ -23,18 +23,17 @@ using namespace megdnn; using namespace cuda; template
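Note on the query_blocksize hunks above: query_launch_config_for_kernel memoizes one LaunchConfig per (device, kernel-pointer) pair in a mutex-guarded unordered_map keyed by a pair hash, so query_launch_config_for_kernel_uncached runs only on the first request for each pair. The simplified, CUDA-free sketch below mirrors only that caching shape; query_uncached, the hash-combine constant and the placeholder config values are arbitrary stand-ins (a CUDA build would compute the real config, e.g. with cudaOccupancyMaxPotentialBlockSize).

// Simplified, CUDA-free sketch of the (device, kernel) -> LaunchConfig cache
// in query_blocksize.cpp. Only the caching shape mirrors the real code.
#include <cstddef>
#include <mutex>
#include <unordered_map>
#include <utility>

namespace sketch {

struct LaunchConfig {
    int grid_size = 0;
    int block_size = 0;
};

struct PairHash {
    template <typename T, typename U>
    std::size_t operator()(const std::pair<T, U>& x) const {
        std::size_t a = std::hash<T>{}(x.first);
        std::size_t b = std::hash<U>{}(x.second);
        // Arbitrary hash-combine; the real code uses its own hash_pair_combine.
        return a ^ (b + 0x9e3779b97f4a7c15ull + (a << 6) + (a >> 2));
    }
};

// Pretend occupancy query; a CUDA build would call something like
// cudaOccupancyMaxPotentialBlockSize here.
inline LaunchConfig query_uncached(const void* /*kern*/) {
    return LaunchConfig{256, 128};  // placeholder values
}

inline LaunchConfig query_cached(int device, const void* kern) {
    static std::mutex mtx;
    static std::unordered_map<std::pair<int, const void*>, LaunchConfig, PairHash>
            cache;
    std::lock_guard<std::mutex> lock{mtx};
    // insert() reports via .second whether this (device, kernel) key is new;
    // only then do we pay for the uncached query.
    auto ins = cache.insert({{device, kern}, LaunchConfig{}});
    if (ins.second) {
        ins.first->second = query_uncached(kern);
    }
    return ins.first->second;
}

}  // namespace sketch

Keying on the kernel's host function address works because that address is stable for the life of the process, so the cache never needs invalidation.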