GitOrigin-RevId: 843d885f82
release-1.1
@@ -11,6 +11,7 @@
#pragma once
#include "megdnn/basic_types.h"
+ #include "megdnn/handle.h"
#include "megdnn/internal/visibility_prologue.h"
namespace megdnn {
@@ -105,11 +106,11 @@ public:
virtual bool is_reproducible() const = 0;
virtual const char* name() const = 0;
- //! a pointer to represent class type
- virtual void* type() const { return nullptr; }
+ Handle::HandleType handle_type() const { return m_handle_type; }
protected:
~Algorithm() = default;
+ Handle::HandleType m_handle_type = Handle::HandleType::NAIVE;
};
/*!
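Note: the hunk above replaces the opaque void* tag on megdnn::Algorithm with an explicit Handle::HandleType stored in m_handle_type, so a backend can recognize its own algorithms without comparing raw pointers. A condensed sketch of the dispatch pattern this enables (the same pattern appears in the arm_common ConvolutionBackwardDataImpl hunks later in this diff; reflowed here for readability only):

    ConvolutionBackwardDataImpl::ncb_kern_t
    ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(Algorithm* algo,
                                                      const NCBKernSizeParam& param) {
        // Algorithms tagged ARM_COMMON were registered by this backend,
        // so the downcast to the backend-specific AlgoBase is safe.
        if (algo->handle_type() == Handle::HandleType::ARM_COMMON) {
            return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
        }
        // Anything else (NAIVE, FALLBACK, ...) is deferred to the parent class.
        return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, param);
    }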
@@ -45,7 +45,7 @@ public:
SmallVector<AlgoBase*> matmul_algos;
};
- SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
+ SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
@@ -18,11 +18,16 @@ namespace aarch64 {
class ConvBiasImpl : public arm_common::ConvBiasImpl {
public:
using arm_common::ConvBiasImpl::ConvBiasImpl;
+ class AlgoBase : public arm_common::ConvBiasImpl::AlgoBase {
+ public:
+ AlgoBase() : arm_common::ConvBiasImpl::AlgoBase() {
+ m_handle_type = Handle::HandleType::AARCH64;
+ }
+ };
- SmallVector<AlgoBase*> algo_pack() override;
+ SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
protected:
const char* get_algorithm_set_name() const override;
private:
@@ -26,7 +26,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -37,7 +36,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -48,7 +46,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -59,7 +56,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4)
};
@@ -75,7 +71,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -86,7 +81,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
};
@@ -103,7 +97,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -116,7 +109,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#else
@@ -129,7 +121,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -143,7 +134,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -156,7 +146,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#endif
@@ -169,7 +158,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -182,7 +170,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -196,7 +183,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -212,7 +198,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -226,7 +211,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -240,7 +224,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -251,7 +234,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
};
@@ -266,7 +248,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
@@ -278,7 +259,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT)
@@ -292,7 +272,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
};
#endif
@@ -52,7 +52,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
public:
- SmallVector<MatrixMulImpl::AlgoBase*> all_algos;
+ SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
AlgoPack() {
all_algos.emplace_back(&f32_gemv);
@@ -89,7 +89,7 @@ public:
}
};
- SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
+ SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack;
auto&& algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
@@ -18,8 +18,14 @@ namespace aarch64 {
class MatrixMulImpl : public arm_common::MatrixMulImpl {
public:
using arm_common::MatrixMulImpl::MatrixMulImpl;
+ class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase {
+ public:
+ AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() {
+ m_handle_type = Handle::HandleType::AARCH64;
+ }
+ };
- SmallVector<AlgoBase*> algo_pack() override;
+ SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;
private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
@@ -57,7 +63,7 @@ private:
#else
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif
- class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16
+ class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16
class AlgoPack;
};
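Note: with the nested AlgoBase classes added above, a concrete backend algorithm is tagged with its handle type at construction time and no longer overrides type(). A hypothetical declaration, for illustration only (AlgoExampleF32 and its name string are invented and not part of this diff):

    class MatrixMulImpl::AlgoExampleF32 final : public AlgoBase {
    public:
        // handle_type() already reports Handle::HandleType::AARCH64, because the
        // AlgoBase constructor above sets m_handle_type; no void* type() needed.
        bool is_reproducible() const override { return true; }
        const char* name() const override { return "AARCH64_EXAMPLE_F32"; }
        bool usable(const KernSizeParam&) const override;
        size_t get_workspace(const KernSizeParam&) const override;
        kern_t get_kern(const KernSizeParam&) const override;
        MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
    };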
@@ -11,6 +11,7 @@
*/
#include "megdnn/opr_param_defs.h"
+ #include "megdnn/oprs/base.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/quint8/algos.h"
@@ -18,6 +19,7 @@
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/metahelper.h"
#include "src/common/utils.h"
+ #include "src/fallback/conv_bias/opr_impl.h"
#include "src/naive/handle.h"
#include "src/arm_common/convolution/opr_impl.h"
@@ -37,7 +39,12 @@ using namespace megdnn;
using namespace arm_common;
namespace {
- uint8_t arm_common_algo_type_storage;
+ bool is_fallback_or_naive(const detail::Algorithm* algo) {
+ return algo->handle_type() == Handle::HandleType::NAIVE ||
+ algo->handle_type() == Handle::HandleType::FALLBACK;
+ }
} // anonymous namespace
class ConvBiasImpl::AlgoPack : NonCopyableObj {
@@ -50,7 +57,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride1 s8_direct_stride1;
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
- AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44;
+ AlgoS8x8x16ChanWiseStride1Stride2NCHW44
+ s8x8x16_channel_wise_stride1_stride2_nchw44;
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectStride1 ds8_direct_stride1;
@@ -129,7 +137,7 @@ public:
->select_algo_type(
{AlgoDataType::FLOAT32, MatmulFormat::MK4});
for (auto&& algo : matmul_algos) {
- if (algo->type() == nullptr)
+ if (is_fallback_or_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
@@ -166,7 +174,7 @@ public:
->select_algo_type({AlgoDataType::FLOAT32,
MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) {
- if (algo->type() == nullptr)
+ if (is_fallback_or_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF63(
@@ -189,7 +197,7 @@ public:
->select_algo_type({AlgoDataType::FLOAT16,
MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) {
- if (algo->type() == nullptr)
+ if (is_fallback_or_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP16WinogradF23(
@@ -210,7 +218,7 @@ public:
->select_algo_type({AlgoDataType::FLOAT16,
MatmulFormat::MK8});
for (auto&& algo : matmul_algos) {
- if (algo->type() == nullptr)
+ if (is_fallback_or_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
@@ -224,7 +232,7 @@ public:
->select_algo_type({AlgoDataType::INT16X16X32,
MatmulFormat::MK8});
for (auto&& algo : matmul_algos) {
- if (algo->type() == nullptr)
+ if (is_fallback_or_naive(algo))
continue;
for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoS8WinogradF23_8x8(
@@ -242,7 +250,7 @@ public:
SmallVector<AlgoBase*> winograd_algos;
};
- SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
+ SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack;
auto&& algos = fallback::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
@@ -252,9 +260,6 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
return std::move(algos);
}
- void* const ConvBiasImpl::sm_arm_common_algo_type =
- &arm_common_algo_type_storage;
bool ConvBiasImpl::is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& param) const {
fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param(
@@ -19,23 +19,25 @@ namespace arm_common {
class ConvBiasImpl : public fallback::ConvBiasImpl {
public:
using fallback::ConvBiasImpl::ConvBiasImpl;
+ using FallbackConvBiasImpl = fallback::ConvBiasImpl;
+ using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;
bool is_thread_safe() const override { return true; }
+ class AlgoBase : public fallback::ConvBiasImpl::AlgoBase {
+ public:
+ AlgoBase() : fallback::ConvBiasImpl::AlgoBase() {
+ m_handle_type = Handle::HandleType::ARM_COMMON;
+ }
+ };
- SmallVector<AlgoBase*> algo_pack() override;
+ SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
bool is_matmul_quantized_prefer(
- const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override;
+ const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param)
+ const override;
SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const override;
class AlgoPack;
protected:
- static void* const sm_arm_common_algo_type;
const char* get_algorithm_set_name() const override;
private:
@@ -93,7 +95,7 @@ private:
class AlgoF16Direct;
class AlgoF16DirectStride1;
#endif
- };
+ };
} // namespace arm_common
} // namespace megdnn
@@ -26,12 +26,14 @@ using namespace arm_common;
/* ===================== ConvolutionBackwardData ===================== */
/* ===================== direct stride 1 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
return deconv::can_stride1_int8x8x32_dot(param);
}
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
@@ -42,7 +44,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
+ fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_int8x8x32_dot;
@@ -53,12 +55,14 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
/* ===================== direct stride 2 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
return deconv::can_stride2_int8x8x32_dot(param);
}
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
@@ -69,7 +73,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
+ fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_int8x8x32_dot;
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
*/
#pragma once
@@ -19,38 +20,40 @@ namespace arm_common {
#if __ARM_FEATURE_DOTPROD
/* ===================== ConvolutionBackwardData ===================== */
- class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final : public AlgoBase {
+ class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final
+ : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
- const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE1"; }
+ const char* name() const override {
+ return "AARCH32_I8x8x32_DECONV_STRIDE1";
+ }
- bool usable(ConvolutionBackwardDataImpl*,
+ bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- size_t get_workspace(ConvolutionBackwardDataImpl*,
+ size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
+ ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
};
- class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final : public AlgoBase {
+ class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final
+ : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
- const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE2"; }
+ const char* name() const override {
+ return "AARCH32_I8x8x32_DECONV_STRIDE2";
+ }
- bool usable(ConvolutionBackwardDataImpl*,
+ bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- size_t get_workspace(ConvolutionBackwardDataImpl*,
+ size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
+ ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
};
#endif
@@ -21,9 +21,6 @@
using namespace megdnn;
using namespace arm_common;
- namespace {
- uint8_t arm_common_algo_type_storage;
- } // anonymous namespace
/* ===================== ConvolutionBackwardData ===================== */
struct ConvolutionBackwardDataImpl::AlgoPack {
@@ -36,46 +33,44 @@ struct ConvolutionBackwardDataImpl::AlgoPack {
};
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;
- void* const ConvolutionBackwardDataImpl::sm_arm_common_algo_type =
- &arm_common_algo_type_storage;
- ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
+ ConvolutionBackwardDataImpl::ncb_kern_t
+ ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
Algorithm* algo, const NCBKernSizeParam& param) {
- if (algo->type() == sm_arm_common_algo_type) {
+ if (algo->handle_type() == Handle::HandleType::ARM_COMMON) {
return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
}
- return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, param);
+ return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo,
+ param);
}
- size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(Algorithm* algo,
- const NCBKernSizeParam& param) {
- if (algo->type() == sm_arm_common_algo_type) {
+ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
+ Algorithm* algo, const NCBKernSizeParam& param) {
+ if (algo->handle_type() == Handle::HandleType::ARM_COMMON) {
return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
}
- return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, param);
+ return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo,
+ param);
}
std::vector<ConvolutionBackwardDataImpl::Algorithm*>
- ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) {
- auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(param);
+ ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
+ const NCBKernSizeParam& param) {
+ auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
+ param);
#if __ARM_FEATURE_DOTPROD
- if((param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
- param.filter_type.enumv() == DTypeEnum::Int8) &&
- (param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
- param.grad_type.enumv() == DTypeEnum::Int32)) {
+ if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
+ param.filter_type.enumv() == DTypeEnum::Int8) &&
+ (param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
+ param.grad_type.enumv() == DTypeEnum::Int32)) {
if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot);
}
if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot);
}
- }
- else if(param.filter_type.enumv() == DTypeEnum::Quantized8Asymm &&
- param.grad_type.enumv() == DTypeEnum::QuantizedS32) {
+ } else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm &&
+ param.grad_type.enumv() == DTypeEnum::QuantizedS32) {
if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot);
}
@@ -18,24 +18,27 @@ namespace arm_common {
class ConvBiasImpl;
- class ConvolutionBackwardDataImpl : public fallback::ConvolutionBackwardDataImpl {
+ class ConvolutionBackwardDataImpl
+ : public fallback::ConvolutionBackwardDataImpl {
public:
using fallback::ConvolutionBackwardDataImpl::ConvolutionBackwardDataImpl;
protected:
- static void* const sm_arm_common_algo_type;
- class AlgoBase : public Algorithm {
+ class AlgoBase : public fallback::ConvolutionBackwardDataImpl::AlgoBase {
protected:
~AlgoBase() = default;
public:
- virtual bool usable(ConvolutionBackwardDataImpl* opr,
+ AlgoBase() : fallback::ConvolutionBackwardDataImpl::AlgoBase() {
+ m_handle_type = Handle::HandleType::ARM_COMMON;
+ }
+ virtual bool usable(fallback::ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0;
- virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr,
+ virtual size_t get_workspace(fallback::ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0;
virtual ncb_kern_t dispatch_kern(
- ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0;
+ fallback::ConvolutionBackwardDataImpl* opr,
+ const NCBKernSizeParam& param) const = 0;
};
ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo,
@@ -49,7 +52,7 @@ protected:
const char* get_algorithm_set_name() const override;
- private:
+ private:
#if __ARM_FEATURE_DOTPROD
class AlgoSdot8DirectStride1;
class AlgoSdot8DirectStride2;
@@ -62,4 +65,4 @@ protected:
} // namespace arm_common
} // namespace megdnn
- // vim: syntax=cpp.doxygen
+ // vim: syntax=cpp.doxygen
@@ -27,12 +27,14 @@ using namespace arm_common;
/* ===================== direct stride 1 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
return deconv::can_stride1_quint8_dot(param);
}
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param);
@@ -43,7 +45,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
+ fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_quint8_dot;
@@ -54,12 +56,14 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
/* ===================== direct stride 2 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
return deconv::can_stride2_quint8_dot(param);
}
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+ fallback::ConvolutionBackwardDataImpl*,
+ const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param);
@@ -70,7 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern(
- ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
+ fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_quint8_dot;
@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
*/
#pragma once
@@ -18,38 +19,42 @@ namespace arm_common {
#if __ARM_FEATURE_DOTPROD
/* ===================== ConvolutionBackwardData ===================== */
- class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final : public AlgoBase {
+ class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final
+ : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
- const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; }
+ const char* name() const override {
+ return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1";
+ }
- bool usable(ConvolutionBackwardDataImpl*,
+ bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- size_t get_workspace(ConvolutionBackwardDataImpl*,
+ size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
+ ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
};
- class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final : public AlgoBase {
+ class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final
+ : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
- const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; }
+ const char* name() const override {
+ return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2";
+ }
- bool usable(ConvolutionBackwardDataImpl*,
+ bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- size_t get_workspace(ConvolutionBackwardDataImpl*,
+ size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override;
- ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
+ ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
};
#endif
} // namespace arm_common
@@ -24,7 +24,6 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
};
@@ -37,7 +36,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
@@ -51,7 +49,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
@@ -66,7 +63,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT)
@@ -84,7 +80,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
@@ -98,7 +93,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
@@ -113,7 +107,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT)
@@ -128,7 +121,6 @@ public:
bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override;
- void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(
@@ -15,13 +15,6 @@
using namespace megdnn;
using namespace arm_common;
- namespace {
- uint8_t arm_common_algo_type_storage;
- } // anonymous namespace
- void* const MatrixMulImpl::sm_arm_common_algo_type =
- &arm_common_algo_type_storage;
class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x16 int8x8x16;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -49,10 +42,10 @@ public:
all_algos.emplace_back(&f32_gemv_mk4);
all_algos.emplace_back(&gevm);
}
- SmallVector<AlgoBase*> all_algos;
+ SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
};
- SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
+ SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack;
auto&& algos = fallback::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),
@@ -18,13 +18,18 @@ namespace arm_common { | |||||
class MatrixMulImpl : public fallback::MatrixMulImpl { | class MatrixMulImpl : public fallback::MatrixMulImpl { | ||||
public: | public: | ||||
using fallback::MatrixMulImpl::MatrixMulImpl; | using fallback::MatrixMulImpl::MatrixMulImpl; | ||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
SmallVector<AlgoBase*> algo_pack() override; | |||||
class AlgoBase : public fallback::MatrixMulImpl::AlgoBase { | |||||
public: | |||||
AlgoBase() : fallback::MatrixMulImpl::AlgoBase() { | |||||
m_handle_type = Handle::HandleType::ARM_COMMON; | |||||
} | |||||
}; | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||||
protected: | protected: | ||||
static void* const sm_arm_common_algo_type; | |||||
class AlgoF32Gemv; // Arm_common F32 Gemv | class AlgoF32Gemv; // Arm_common F32 Gemv | ||||
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 | class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 | ||||
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv | class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv | ||||
@@ -32,7 +32,7 @@ public: | |||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
}; | }; | ||||
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
static AlgoPack sl_algo_pack; | static AlgoPack sl_algo_pack; | ||||
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | ||||
//! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, | //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, | ||||
@@ -18,11 +18,16 @@ namespace armv7 { | |||||
class ConvBiasImpl : public arm_common::ConvBiasImpl { | class ConvBiasImpl : public arm_common::ConvBiasImpl { | ||||
public: | public: | ||||
using arm_common::ConvBiasImpl::ConvBiasImpl; | using arm_common::ConvBiasImpl::ConvBiasImpl; | ||||
class AlgoBase : public arm_common::ConvBiasImpl::AlgoBase { | |||||
public: | |||||
AlgoBase() : arm_common::ConvBiasImpl::AlgoBase() { | |||||
m_handle_type = Handle::HandleType::ARMV7; | |||||
} | |||||
}; | |||||
SmallVector<AlgoBase*> algo_pack() override; | |||||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||||
protected: | protected: | ||||
const char* get_algorithm_set_name() const override; | const char* get_algorithm_set_name() const override; | ||||
private: | private: | ||||
@@ -26,7 +26,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -37,7 +36,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -48,7 +46,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | ||||
}; | }; | ||||
@@ -61,7 +58,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | ||||
@@ -71,7 +67,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | ||||
}; | }; | ||||
@@ -121,7 +116,6 @@ public: | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -133,7 +127,6 @@ public: | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -144,7 +137,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -156,7 +148,6 @@ public: | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -168,7 +159,6 @@ public: | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -180,7 +170,6 @@ public: | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -192,7 +181,6 @@ public: | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -203,7 +191,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | ||||
}; | }; | ||||
@@ -216,7 +203,6 @@ public: | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -44,7 +44,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; | ||||
public: | public: | ||||
SmallVector<MatrixMulImpl::AlgoBase*> all_algos; | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||||
AlgoPack() { | AlgoPack() { | ||||
all_algos.emplace_back(&f32_gemv); | all_algos.emplace_back(&f32_gemv); | ||||
@@ -73,7 +73,7 @@ public: | |||||
} | } | ||||
}; | }; | ||||
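The element-type widening in the hunk just below, from SmallVector<AlgoBase*> to SmallVector<fallback::MatrixMulImpl::AlgoBase*>, is what lets a backend splice its own algorithms in front of the pack it inherits: the containers themselves do not convert, but each derived-class pointer does. A small standalone illustration of the same idiom, assuming std::vector in place of SmallVector:

#include <vector>

struct Base { virtual ~Base() = default; };
struct Derived : Base {};

// A pack of Derived* cannot be handed over as a std::vector<Base*>, but its
// elements can be spliced in front of an existing pack of Base* one by one,
// which is the insert(begin, ...) idiom the algo_pack() bodies use.
std::vector<Base*> merge(std::vector<Base*> inherited,
                         const std::vector<Derived*>& backend) {
    inherited.insert(inherited.begin(), backend.begin(), backend.end());
    return inherited;
}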
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
static AlgoPack s_algo_pack; | static AlgoPack s_algo_pack; | ||||
auto algos = arm_common::MatrixMulImpl::algo_pack(); | auto algos = arm_common::MatrixMulImpl::algo_pack(); | ||||
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | ||||
@@ -18,7 +18,14 @@ namespace armv7 { | |||||
class MatrixMulImpl : public arm_common::MatrixMulImpl { | class MatrixMulImpl : public arm_common::MatrixMulImpl { | ||||
public: | public: | ||||
using arm_common::MatrixMulImpl::MatrixMulImpl; | using arm_common::MatrixMulImpl::MatrixMulImpl; | ||||
SmallVector<AlgoBase*> algo_pack() override; | |||||
class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase { | |||||
public: | |||||
AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() { | |||||
m_handle_type = Handle::HandleType::ARMV7; | |||||
} | |||||
}; | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||||
private: | private: | ||||
class AlgoF32; // Armv7 F32 | class AlgoF32; // Armv7 F32 | ||||
@@ -26,6 +26,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
BatchConvBiasForwardImpl* opr; | BatchConvBiasForwardImpl* opr; | ||||
TensorLayout src_layout, filter_layout, bias_layout, z_layout, | TensorLayout src_layout, filter_layout, bias_layout, z_layout, | ||||
@@ -28,6 +28,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
BatchedMatrixMulForwardImpl* opr; | BatchedMatrixMulForwardImpl* opr; | ||||
TensorLayout layout_a, layout_b, layout_c; | TensorLayout layout_a, layout_b, layout_c; | ||||
@@ -38,6 +38,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs : public conv_bias::BiasForwardSizeArgs { | struct SizeArgs : public conv_bias::BiasForwardSizeArgs { | ||||
ConvBiasForwardImpl* opr; | ConvBiasForwardImpl* opr; | ||||
@@ -28,6 +28,7 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl *handle; | HandleImpl *handle; | ||||
CanonizedFilterMeta filter_meta; | CanonizedFilterMeta filter_meta; | ||||
@@ -28,6 +28,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl *handle; | HandleImpl *handle; | ||||
const TensorLayout *src_layout, *diff_layout, *grad_layout; | const TensorLayout *src_layout, *diff_layout, *grad_layout; | ||||
@@ -28,6 +28,7 @@ class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm { | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl *handle; | HandleImpl *handle; | ||||
CanonizedFilterMeta filter_meta; | CanonizedFilterMeta filter_meta; | ||||
@@ -22,6 +22,7 @@ class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm { | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl *handle; | HandleImpl *handle; | ||||
const TensorLayout *src_layout, *diff_layout; | const TensorLayout *src_layout, *diff_layout; | ||||
@@ -128,8 +129,8 @@ class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase | |||||
const char* name() const override { | const char* name() const override { | ||||
return "INPLACE_MATMUL"; | return "INPLACE_MATMUL"; | ||||
} | } | ||||
bool is_reproducible() const override { | |||||
return false; | |||||
bool is_reproducible() const override { | |||||
return false; | |||||
} | } | ||||
}; | }; | ||||
@@ -34,6 +34,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm { | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs: public convolution3d::ForwardSizeArgs { | struct SizeArgs: public convolution3d::ForwardSizeArgs { | ||||
Convolution3DForwardImpl *opr; | Convolution3DForwardImpl *opr; | ||||
@@ -42,11 +43,11 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm { | |||||
desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); | desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); | ||||
} | } | ||||
SizeArgs(Convolution3DForwardImpl *opr, | SizeArgs(Convolution3DForwardImpl *opr, | ||||
const TensorLayout &src, | |||||
const TensorLayout &src, | |||||
const TensorLayout &filter, | const TensorLayout &filter, | ||||
const TensorLayout &dst); | const TensorLayout &dst); | ||||
SizeArgs(Convolution3DForwardImpl *opr, | SizeArgs(Convolution3DForwardImpl *opr, | ||||
const TensorLayout &src, | |||||
const TensorLayout &src, | |||||
const CanonizedFilterMeta &filter, | const CanonizedFilterMeta &filter, | ||||
const TensorLayout &dst); | const TensorLayout &dst); | ||||
}; | }; | ||||
@@ -26,6 +26,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
DeformableConvBackwardDataImpl* opr; | DeformableConvBackwardDataImpl* opr; | ||||
HandleImpl* handle; | HandleImpl* handle; | ||||
@@ -26,6 +26,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
DeformableConvBackwardFilterImpl* opr; | DeformableConvBackwardFilterImpl* opr; | ||||
HandleImpl* handle; | HandleImpl* handle; | ||||
@@ -24,6 +24,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
DeformableConvForwardImpl* opr; | DeformableConvForwardImpl* opr; | ||||
HandleImpl* handle; | HandleImpl* handle; | ||||
@@ -25,6 +25,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
LocalShareBackwardDataImpl* opr; | LocalShareBackwardDataImpl* opr; | ||||
TensorLayout filter_layout, diff_layout, grad_layout; | TensorLayout filter_layout, diff_layout, grad_layout; | ||||
@@ -25,6 +25,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
LocalShareBackwardFilterImpl* opr; | LocalShareBackwardFilterImpl* opr; | ||||
TensorLayout src_layout, diff_layout, grad_layout; | TensorLayout src_layout, diff_layout, grad_layout; | ||||
@@ -25,6 +25,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
LocalShareForwardImpl* opr; | LocalShareForwardImpl* opr; | ||||
TensorLayout src_layout, filter_layout, dst_layout; | TensorLayout src_layout, filter_layout, dst_layout; | ||||
@@ -32,13 +32,14 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
MatrixMulForwardImpl* opr; | MatrixMulForwardImpl* opr; | ||||
TensorLayout layout_a, layout_b, layout_c; | TensorLayout layout_a, layout_b, layout_c; | ||||
std::string to_string() const; | std::string to_string() const; | ||||
SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, const TensorLayout& B, | |||||
const TensorLayout& C); | |||||
SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, | |||||
const TensorLayout& B, const TensorLayout& C); | |||||
bool can_be_treated_as_int8x8x32() const { | bool can_be_treated_as_int8x8x32() const { | ||||
return layout_a.dtype.enumv() == layout_b.dtype.enumv() && | return layout_a.dtype.enumv() == layout_b.dtype.enumv() && | ||||
@@ -213,6 +213,9 @@ public: | |||||
class AlgoBase : public Algorithm { | class AlgoBase : public Algorithm { | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { | |||||
m_handle_type = Handle::HandleType::FALLBACK; | |||||
} | |||||
virtual ~AlgoBase() = default; | virtual ~AlgoBase() = default; | ||||
virtual bool usable( | virtual bool usable( | ||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
@@ -141,8 +141,6 @@ public: | |||||
return get_kimpl(m_algorithm, param); | return get_kimpl(m_algorithm, param); | ||||
} | } | ||||
void* type() const override { return sm_fallback_conv_algo_type; } | |||||
//! select matmul to the highest preference | //! select matmul to the highest preference | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
@@ -168,7 +166,6 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ||||
const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
void* type() const override { return sm_fallback_deconv_algo_type; } | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase { | ||||
@@ -181,7 +178,6 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ||||
const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
void* type() const override { return sm_fallback_deconv_algo_type; } | |||||
}; | }; | ||||
} // namespace fallback | } // namespace fallback | ||||
@@ -37,8 +37,6 @@ class NaiveConvolutionBackwardData final | |||||
const char* name() const override { return "NCBD"; } | const char* name() const override { return "NCBD"; } | ||||
}; | }; | ||||
NaiveConvolutionBackwardData naive_conv_backward_data; | NaiveConvolutionBackwardData naive_conv_backward_data; | ||||
uint8_t fallback_deconv_algo_type_storage; | |||||
uint8_t fallback_conv_algo_type_storage; | |||||
template <typename T> | template <typename T> | ||||
void incr_ptr(T*& dst, ptrdiff_t delta) { | void incr_ptr(T*& dst, ptrdiff_t delta) { | ||||
@@ -69,9 +67,6 @@ public: | |||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
}; | }; | ||||
void* const ConvolutionImpl::sm_fallback_conv_algo_type = | |||||
&fallback_conv_algo_type_storage; | |||||
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() { | SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() { | ||||
static AlgoPack sl_algo_pack; | static AlgoPack sl_algo_pack; | ||||
return sl_algo_pack.all_algos; | return sl_algo_pack.all_algos; | ||||
@@ -412,9 +407,6 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const { | |||||
/* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type = | |||||
&fallback_deconv_algo_type_storage; | |||||
struct ConvolutionBackwardDataImpl::AlgoPack { | struct ConvolutionBackwardDataImpl::AlgoPack { | ||||
AlgoDirect direct; | AlgoDirect direct; | ||||
AlgoMatrixMul matmul; | AlgoMatrixMul matmul; | ||||
@@ -630,7 +622,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb( | |||||
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | ||||
Algorithm* algo, const NCBKernSizeParam& param) { | Algorithm* algo, const NCBKernSizeParam& param) { | ||||
megdnn_assert(param.filter_meta.group == 1); | megdnn_assert(param.filter_meta.group == 1); | ||||
if (algo->type() == sm_fallback_deconv_algo_type) { | |||||
if (algo->handle_type() == Handle::HandleType::FALLBACK) { | |||||
return static_cast<AlgoBase*>(algo)->get_workspace(this, param); | return static_cast<AlgoBase*>(algo)->get_workspace(this, param); | ||||
} | } | ||||
megdnn_assert(algo == &naive_conv_backward_data); | megdnn_assert(algo == &naive_conv_backward_data); | ||||
@@ -642,7 +634,7 @@ ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( | |||||
Algorithm* algo, const NCBKernSizeParam& param) { | Algorithm* algo, const NCBKernSizeParam& param) { | ||||
megdnn_assert(param.filter_meta.group == 1); | megdnn_assert(param.filter_meta.group == 1); | ||||
if (algo->type() == sm_fallback_deconv_algo_type) { | |||||
if (algo->handle_type() == Handle::HandleType::FALLBACK) { | |||||
return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param); | return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param); | ||||
} | } | ||||
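With the old tag pointers gone, the guard above asks whether the algorithm was registered by this backend before downcasting, and otherwise falls through to the naive reference algorithm, as the asserted branch indicates. A compact standalone analogue of that consuming side (all names below are illustrative, only handle_type() mirrors the diff):

#include <cassert>

enum class HandleType { NAIVE, FALLBACK };

struct Algorithm {
    virtual ~Algorithm() = default;
    virtual HandleType handle_type() const = 0;
};
struct FallbackAlgo final : Algorithm {
    HandleType handle_type() const override { return HandleType::FALLBACK; }
    int dispatch_kern() const { return 1; }  // stands in for the real kernel getter
};
struct NaiveAlgo final : Algorithm {
    HandleType handle_type() const override { return HandleType::NAIVE; }
};

NaiveAlgo g_naive;  // plays the role of the file-scope naive reference algorithm

int dispatch(const Algorithm* algo) {
    if (algo->handle_type() == HandleType::FALLBACK) {
        // Downcast is safe in this sketch because only this backend's
        // algorithms carry the FALLBACK tag at the call site.
        return static_cast<const FallbackAlgo*>(algo)->dispatch_kern();
    }
    assert(algo == &g_naive);  // anything else must be the naive reference algo
    return 0;
}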
@@ -177,8 +177,6 @@ public: | |||||
} | } | ||||
}; | }; | ||||
static void* const sm_fallback_conv_algo_type; | |||||
/** | /** | ||||
* \brief Kernel run time id, This information is used for getting the | * \brief Kernel run time id, This information is used for getting the | ||||
* work data | * work data | ||||
@@ -197,6 +195,9 @@ public: | |||||
class AlgoBase : public Algorithm { | class AlgoBase : public Algorithm { | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { | |||||
m_handle_type = Handle::HandleType::FALLBACK; | |||||
} | |||||
virtual ~AlgoBase() = default; | virtual ~AlgoBase() = default; | ||||
virtual bool usable(const NCBKernSizeParam& param, | virtual bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy) const = 0; | AlgoSelectionStrategy) const = 0; | ||||
@@ -407,13 +408,14 @@ protected: | |||||
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
bool reproducible = false); | bool reproducible = false); | ||||
static void* const sm_fallback_deconv_algo_type; | |||||
class AlgoBase : public Algorithm { | class AlgoBase : public Algorithm { | ||||
protected: | protected: | ||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { | |||||
m_handle_type = Handle::HandleType::FALLBACK; | |||||
} | |||||
virtual bool usable(ConvolutionBackwardDataImpl* opr, | virtual bool usable(ConvolutionBackwardDataImpl* opr, | ||||
const NCBKernSizeParam& param) const = 0; | const NCBKernSizeParam& param) const = 0; | ||||
virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, | virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, | ||||
@@ -103,6 +103,7 @@ public: | |||||
} | } | ||||
public: | public: | ||||
AlgoBase() { m_handle_type = Handle::HandleType::FALLBACK; } | |||||
enum class AlgoSet : uint32_t { | enum class AlgoSet : uint32_t { | ||||
ALGO_TYPE_GEMM = 0, | ALGO_TYPE_GEMM = 0, | ||||
ALGO_TYPE_GEMV = 1, | ALGO_TYPE_GEMV = 1, | ||||
@@ -6,10 +6,11 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "./opr_impl.h" | #include "./opr_impl.h" | ||||
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "src/common/utils.cuh" | #include "src/common/utils.cuh" | ||||
#include "src/rocm/handle.h" | #include "src/rocm/handle.h" | ||||
@@ -92,8 +93,8 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
static_cast<const rocblas_half*>(A.raw_ptr), | static_cast<const rocblas_half*>(A.raw_ptr), | ||||
A.layout.stride[1], A.layout.stride[0], | A.layout.stride[1], A.layout.stride[0], | ||||
reinterpret_cast<const rocblas_half*>(zero_half), | reinterpret_cast<const rocblas_half*>(zero_half), | ||||
static_cast<rocblas_half*>(C.raw_ptr), | |||||
C.layout.stride[1], C.layout.stride[0], batch)); | |||||
static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[1], | |||||
C.layout.stride[0], batch)); | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -25,6 +25,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl* handle; | HandleImpl* handle; | ||||
CanonizedFilterMeta filter_meta; | CanonizedFilterMeta filter_meta; | ||||
@@ -26,6 +26,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl* handle; | HandleImpl* handle; | ||||
const TensorLayout *src_layout, *diff_layout; | const TensorLayout *src_layout, *diff_layout; | ||||
@@ -32,6 +32,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | |||||
struct SizeArgs : public convolution::ForwardSizeArgs { | struct SizeArgs : public convolution::ForwardSizeArgs { | ||||
ConvolutionForwardImpl* opr; | ConvolutionForwardImpl* opr; | ||||
@@ -47,8 +47,6 @@ public: | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
void* type() const override; | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
@@ -84,8 +82,6 @@ public: | |||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
void* type() const override; | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
} | } | ||||
@@ -103,7 +99,6 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
void* type() const override; | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
}; | }; | ||||
@@ -119,7 +114,6 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
void* type() const override; | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
}; | }; | ||||
@@ -161,7 +155,6 @@ public: | |||||
}; | }; | ||||
return {{kern, {1_z, 1_z, 1_z}}}; | return {{kern, {1_z, 1_z, 1_z}}}; | ||||
} | } | ||||
void* type() const override; | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | ||||
@@ -32,7 +32,6 @@ public: | |||||
const NCBKernSizeParam& param) const override { | const NCBKernSizeParam& param) const override { | ||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
void* type() const override; | |||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
@@ -57,7 +56,6 @@ public: | |||||
const NCBKernSizeParam& param) const override { | const NCBKernSizeParam& param) const override { | ||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
void* type() const override; | |||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
@@ -82,7 +80,6 @@ public: | |||||
const NCBKernSizeParam& param) const override { | const NCBKernSizeParam& param) const override { | ||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
void* type() const override; | |||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
@@ -107,7 +104,6 @@ public: | |||||
const NCBKernSizeParam& param) const override { | const NCBKernSizeParam& param) const override { | ||||
return get_kimpls(param); | return get_kimpls(param); | ||||
} | } | ||||
void* type() const override; | |||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
@@ -148,7 +144,6 @@ public: | |||||
}; | }; | ||||
return {{kern, {group, n, 1_z}}}; | return {{kern, {group, n, 1_z}}}; | ||||
} | } | ||||
void* type() const override; | |||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
@@ -179,8 +174,6 @@ public: | |||||
//! select matmul to the highest preference | //! select matmul to the highest preference | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
void* type() const override; | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | ||||
} | } | ||||
@@ -22,54 +22,14 @@ | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace x86; | using namespace x86; | ||||
namespace { | namespace { | ||||
uint8_t x86_algo_type_storage; | |||||
void* x86_algo_type = &x86_algo_type_storage; | |||||
} // anonymous namespace | |||||
#if MEGDNN_X86_WITH_MKL_DNN | |||||
void* ConvBiasImpl::AlgoMkldnnQint8::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoMkldnnMatmulQint8::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoMkldnnConv::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
#endif | |||||
void* ConvBiasImpl::AlgoDirect::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoDirectStride2::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoDirectAvx2Stride1Int8::type() const { | |||||
return x86_algo_type; | |||||
bool is_fallback_or_naive(const detail::Algorithm* algo) { | |||||
return algo->handle_type() == Handle::HandleType::NAIVE || | |||||
algo->handle_type() == Handle::HandleType::FALLBACK; | |||||
} | } | ||||
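This helper replaces the old algo->type() == nullptr test used to skip generic matmul algorithms when deciding which candidates get wrapped into Winograd conv-bias algorithms; the AlgoPack hunk further down calls it from the construction loop. A self-contained sketch of that filtering step, with illustrative types:

#include <vector>

enum class HandleType { NAIVE, FALLBACK, X86 };

struct MatmulAlgo {
    HandleType tag;
    HandleType handle_type() const { return tag; }
};

bool is_fallback_or_naive(const MatmulAlgo* algo) {
    return algo->handle_type() == HandleType::NAIVE ||
           algo->handle_type() == HandleType::FALLBACK;
}

// Keep only backend-tuned matmuls; these are the candidates the AlgoPack
// constructor wraps into Winograd algorithms (for tile sizes 8/16/24 in the diff).
std::vector<const MatmulAlgo*> winograd_candidates(const std::vector<MatmulAlgo>& pack) {
    std::vector<const MatmulAlgo*> kept;
    for (auto&& algo : pack)
        if (!is_fallback_or_naive(&algo))
            kept.push_back(&algo);
    return kept;
}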
void* ConvBiasImpl::AlgoFP32WinogradF63_8x8::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoFP32WinogradF23_8x8::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoAVX2DirectConvStride2::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const { | |||||
return x86_algo_type; | |||||
} | |||||
} // anonymous namespace | |||||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
AlgoDirect stride1_direct; | AlgoDirect stride1_direct; | ||||
@@ -88,8 +48,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
//! FIXME: preference to use mkldnn algo on VNNI devices | |||||
//! But now mkldnn algo preference issue with NCHW->NHWC->NCHW | |||||
//! FIXME: preference to use mkldnn algo on VNNI devices | |||||
//! But now mkldnn algo preference issue with NCHW->NHWC->NCHW | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
//! Create the mkldnn algo | //! Create the mkldnn algo | ||||
all_algos.emplace_back(&mkldnn_conv_fp32); | all_algos.emplace_back(&mkldnn_conv_fp32); | ||||
@@ -108,7 +68,7 @@ public: | |||||
auto&& matmul_algos = | auto&& matmul_algos = | ||||
static_cast<MatrixMulImpl*>(matmul_opr)->algo_pack(); | static_cast<MatrixMulImpl*>(matmul_opr)->algo_pack(); | ||||
for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
if (algo->type() == nullptr) | |||||
if (is_fallback_or_naive(algo)) | |||||
continue; | continue; | ||||
for (uint32_t tile_size : {8, 16, 24}) { | for (uint32_t tile_size : {8, 16, 24}) { | ||||
refhold.emplace_back(new AlgoFP32WinogradF63_8x8( | refhold.emplace_back(new AlgoFP32WinogradF63_8x8( | ||||
@@ -126,7 +86,7 @@ public: | |||||
SmallVector<AlgoBase*> winograd_algos; | SmallVector<AlgoBase*> winograd_algos; | ||||
}; | }; | ||||
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
static AlgoPack sl_algo_pack; | static AlgoPack sl_algo_pack; | ||||
auto&& algos = fallback::ConvBiasImpl::algo_pack(); | auto&& algos = fallback::ConvBiasImpl::algo_pack(); | ||||
algos.insert(algos.begin(), sl_algo_pack.all_algos.begin(), | algos.insert(algos.begin(), sl_algo_pack.all_algos.begin(), | ||||
@@ -176,8 +136,8 @@ bool ConvBiasImpl::is_matmul_quantized_prefer( | |||||
!chanwise_avx2_stride2_qint8_usable_preferred(param)); | !chanwise_avx2_stride2_qint8_usable_preferred(param)); | ||||
} | } | ||||
SmallVector<AlgoCategory> | |||||
ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const { | |||||
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order( | |||||
const NCBKernSizeParam& param) const { | |||||
auto IC = param.filter_meta.icpg; | auto IC = param.filter_meta.icpg; | ||||
auto OC = param.filter_meta.ocpg; | auto OC = param.filter_meta.ocpg; | ||||
auto FH = param.filter_meta.spatial[0]; | auto FH = param.filter_meta.spatial[0]; | ||||
@@ -20,10 +20,15 @@ namespace x86 { | |||||
class ConvBiasImpl : public fallback::ConvBiasImpl { | class ConvBiasImpl : public fallback::ConvBiasImpl { | ||||
public: | public: | ||||
using fallback::ConvBiasImpl::ConvBiasImpl; | using fallback::ConvBiasImpl::ConvBiasImpl; | ||||
using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
class AlgoBase : public fallback::ConvBiasImpl::AlgoBase { | |||||
public: | |||||
AlgoBase() : fallback::ConvBiasImpl::AlgoBase() { | |||||
m_handle_type = Handle::HandleType::X86; | |||||
} | |||||
}; | |||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
SmallVector<AlgoBase*> algo_pack() override; | |||||
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override; | |||||
SmallVector<AlgoCategory> suggest_algo_category_order( | SmallVector<AlgoCategory> suggest_algo_category_order( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
@@ -25,7 +25,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | size_t get_workspace(const KernSizeParam&) const override { return 0; } | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | ||||
}; | }; | ||||
@@ -38,7 +37,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | size_t get_workspace(const KernSizeParam&) const override { return 0; } | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
PackMode packmode() const override { return PackMode::ONLY_PACKA; } | PackMode packmode() const override { return PackMode::ONLY_PACKA; } | ||||
kern_naked_t get_kern_naked(const KernSizeParam&) const override; | kern_naked_t get_kern_naked(const KernSizeParam&) const override; | ||||
void pack_A(const KernParam& kern_param, void* out, size_t index, | void pack_A(const KernParam& kern_param, void* out, size_t index, | ||||
@@ -60,7 +58,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -71,7 +68,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -86,7 +82,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -102,7 +97,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -114,7 +108,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
@@ -125,7 +118,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8) | ||||
}; | }; | ||||
@@ -138,7 +130,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
#endif | #endif | ||||
@@ -151,7 +142,6 @@ public: | |||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | size_t get_workspace(const KernSizeParam&) const override { return 0; } | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | ||||
}; | }; | ||||
@@ -16,12 +16,6 @@ | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace x86; | using namespace x86; | ||||
namespace { | |||||
uint8_t x86_algo_type_storage; | |||||
} // anonymous namespace | |||||
void* const MatrixMulImpl::sm_x86_algo_type = &x86_algo_type_storage; | |||||
class MatrixMulImpl::AlgoPack : NonCopyableObj { | class MatrixMulImpl::AlgoPack : NonCopyableObj { | ||||
AlgoF32Blas f32blas; | AlgoF32Blas f32blas; | ||||
@@ -62,10 +56,10 @@ public: | |||||
all_algos.emplace_back(&f32mkl_packa); | all_algos.emplace_back(&f32mkl_packa); | ||||
#endif | #endif | ||||
} | } | ||||
SmallVector<AlgoBase*> all_algos; | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos; | |||||
}; | }; | ||||
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
static AlgoPack s_algo_pack; | static AlgoPack s_algo_pack; | ||||
auto&& algos = fallback::MatrixMulImpl::algo_pack(); | auto&& algos = fallback::MatrixMulImpl::algo_pack(); | ||||
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), | ||||
@@ -33,13 +33,18 @@ namespace x86 { | |||||
class MatrixMulImpl : public fallback::MatrixMulImpl { | class MatrixMulImpl : public fallback::MatrixMulImpl { | ||||
public: | public: | ||||
using fallback::MatrixMulImpl::MatrixMulImpl; | using fallback::MatrixMulImpl::MatrixMulImpl; | ||||
class AlgoBase : public fallback::MatrixMulImpl::AlgoBase { | |||||
public: | |||||
AlgoBase() : fallback::MatrixMulImpl::AlgoBase() { | |||||
m_handle_type = Handle::HandleType::X86; | |||||
} | |||||
}; | |||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
SmallVector<AlgoBase*> algo_pack() override; | |||||
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override; | |||||
protected: | protected: | ||||
static void* const sm_x86_algo_type; | |||||
class AlgoF32Blas; | class AlgoF32Blas; | ||||
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | ||||
class AlgoF32MKLPackA; | class AlgoF32MKLPackA; | ||||