GitOrigin-RevId: 86dead0a11
tags/v1.3.0
@@ -99,6 +99,27 @@ enum class AlgoDataType : uint32_t { | |||||
class Algorithm { | class Algorithm { | ||||
public: | public: | ||||
static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1); | static constexpr uint32_t INVALID_ALGO_TYPE = static_cast<uint32_t>(-1); | ||||
/** | |||||
* \brief the attribute of the algo, such as REPRODUCIBLE, NAIVE | |||||
* | |||||
*/ | |||||
enum class Attribute : uint32_t { | |||||
/** | |||||
* \brief whether the execution result is | |||||
* reproducible across multiple runs. | |||||
*/ | |||||
REPRODUCIBLE = 1 << 0, | |||||
/** | |||||
* \brief whether the algo is naive | |||||
* Mark algorithms with simple implementation as NAIVE, so we can filter | |||||
* these algorithms to speed up fastrun. | |||||
* */ | |||||
NAIVE = 1 << 1, | |||||
}; | |||||
/** | /** | ||||
* \brief Algorithm information, we can get real algo from | * \brief Algorithm information, we can get real algo from | ||||
* AlgorithmInfo::Info::Desc | * AlgorithmInfo::Info::Desc | ||||
@@ -121,7 +142,7 @@ public: | |||||
} desc; | } desc; | ||||
//! algorithm name | //! algorithm name | ||||
std::string name; | std::string name; | ||||
bool is_reproducible; | |||||
Attribute attribute; | |||||
bool valid() const { return desc.valid(); } | bool valid() const { return desc.valid(); } | ||||
void reset() { desc.reset(); } | void reset() { desc.reset(); } | ||||
//! desc denotes the algo | //! desc denotes the algo | ||||
@@ -131,18 +152,20 @@ public: | |||||
virtual ~Algorithm() = default; | virtual ~Algorithm() = default; | ||||
/** | /** | ||||
* \brief whether the execution result is | |||||
* reproducible across multiple runs. | |||||
* \brief get the attribute of the algo | |||||
*/ | */ | ||||
virtual bool is_reproducible() const = 0; | |||||
virtual Attribute attribute() const = 0; | |||||
virtual const char* name() const = 0; | virtual const char* name() const = 0; | ||||
//! serialized param | //! serialized param | ||||
virtual std::string param() const { return {}; } | virtual std::string param() const { return {}; } | ||||
virtual uint32_t type() const = 0; | virtual uint32_t type() const = 0; | ||||
bool contain_attribute(const Attribute& attr) const; | |||||
Handle::HandleType handle_type() const { return m_handle_type; } | Handle::HandleType handle_type() const { return m_handle_type; } | ||||
Info info() const { | Info info() const { | ||||
return {{handle_type(), type(), param()}, name(), is_reproducible()}; | |||||
return {{handle_type(), type(), param()}, name(), attribute()}; | |||||
} | } | ||||
Info::Desc desc() const { return {handle_type(), type(), param()}; } | Info::Desc desc() const { return {handle_type(), type(), param()}; } | ||||
@@ -524,6 +547,7 @@ protected: | |||||
} // namespace detail | } // namespace detail | ||||
using Algorithm = detail::Algorithm; | using Algorithm = detail::Algorithm; | ||||
using AlgoAttribute = Algorithm::Attribute; | |||||
using ExecutionPolicy = detail::ExecutionPolicy; | using ExecutionPolicy = detail::ExecutionPolicy; | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -19,7 +19,9 @@ namespace aarch64 { | |||||
class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV8F16STRD2"; } | const char* name() const override { return "ARMV8F16STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -23,7 +23,9 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV8F32STRD2"; } | const char* name() const override { return "ARMV8F32STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -25,7 +25,9 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { | |||||
static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8MATMUL"; } | const char* name() const override { return "S8MATMUL"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -25,7 +25,9 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { | |||||
static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "QU8MATMUL"; } | const char* name() const override { return "QU8MATMUL"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -21,7 +21,9 @@ namespace aarch64 { | |||||
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_F32K8X12X1"; } | const char* name() const override { return "AARCH64_F32K8X12X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -32,7 +34,9 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -43,7 +47,9 @@ public: | |||||
class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_F32K4X16X1"; } | const char* name() const override { return "AARCH64_F32K4X16X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -54,7 +60,9 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -76,7 +84,9 @@ public: | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_F16_K8X24X1"; } | const char* name() const override { return "AARCH64_F16_K8X24X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -87,7 +97,9 @@ public: | |||||
class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -102,7 +114,9 @@ public: | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; | return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; | ||||
} | } | ||||
@@ -115,7 +129,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; | return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; | ||||
} | } | ||||
@@ -129,7 +145,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -143,7 +161,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -156,7 +176,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -169,7 +191,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -182,7 +206,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -194,7 +220,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -207,7 +235,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH64_INT8X8X16_MK4_16X12X4"; | return "AARCH64_INT8X8X16_MK4_16X12X4"; | ||||
} | } | ||||
@@ -223,7 +253,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH64_INT8X8X16_MK4_K8X8X8"; | return "AARCH64_INT8X8X16_MK4_K8X8X8"; | ||||
} | } | ||||
@@ -239,7 +271,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -253,7 +287,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -265,7 +301,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -278,7 +316,9 @@ public: | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH64_QUINT8_K8X8X4_DOTPROD"; | return "AARCH64_QUINT8_K8X8X4_DOTPROD"; | ||||
} | } | ||||
@@ -291,7 +331,9 @@ public: | |||||
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -306,7 +348,9 @@ public: | |||||
class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -29,6 +29,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_FP16) | ||||
}; | }; | ||||
@@ -45,6 +48,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP16) | ||||
}; | }; | ||||
@@ -60,7 +66,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP16) | ||||
}; | }; | ||||
@@ -76,6 +84,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_FP16) | ||||
}; | }; | ||||
@@ -84,7 +95,9 @@ class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | |||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F16DIRECT"; } | const char* name() const override { return "F16DIRECT"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -104,7 +117,9 @@ class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | |||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F16STRD1"; } | const char* name() const override { return "F16STRD1"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -29,6 +29,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_FP32) | ||||
}; | }; | ||||
@@ -45,6 +48,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_FP32) | ||||
}; | }; | ||||
@@ -61,6 +67,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_FP32) | ||||
}; | }; | ||||
@@ -77,6 +86,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F54_FP32) | ||||
}; | }; | ||||
@@ -93,6 +105,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F45_FP32) | ||||
}; | }; | ||||
@@ -111,6 +126,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_4X4_NCHW44_F32) | ||||
}; | }; | ||||
@@ -128,6 +146,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F63_4X4_NCHW44_F32) | ||||
}; | }; | ||||
@@ -145,6 +166,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F73_4X4_NCHW44_F32) | ||||
}; | }; | ||||
@@ -154,7 +178,9 @@ class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { | |||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F32DIRECT"; } | const char* name() const override { return "F32DIRECT"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -172,7 +198,9 @@ class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | |||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F32STRD1"; } | const char* name() const override { return "F32STRD1"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -190,7 +218,9 @@ class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F32STRD2"; } | const char* name() const override { return "F32STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -209,7 +239,9 @@ class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoF32DirectNCHW44() {} | AlgoF32DirectNCHW44() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; } | const char* name() const override { return "F32_CONV_NCHW44_DIRECT"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -228,7 +260,9 @@ class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoF32DirectNCHWNCHW44() {} | AlgoF32DirectNCHWNCHW44() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F32_CONV_NCHW_NCHW44"; } | const char* name() const override { return "F32_CONV_NCHW_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -246,7 +280,9 @@ class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | |||||
SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "F32_CHANNEL_WISE_NCHW44"; } | const char* name() const override { return "F32_CHANNEL_WISE_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -20,7 +20,9 @@ namespace arm_common { | |||||
class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8STRD1"; } | const char* name() const override { return "S8STRD1"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -39,7 +41,9 @@ public: | |||||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8STRD2"; } | const char* name() const override { return "S8STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -56,7 +60,9 @@ public: | |||||
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoS8DirectNCHW44() {} | AlgoS8DirectNCHW44() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8_NCHW44_DIRECT"; } | const char* name() const override { return "S8_NCHW44_DIRECT"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -73,7 +79,9 @@ public: | |||||
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoS8DirectNCHWNCHW44() {} | AlgoS8DirectNCHWNCHW44() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8_CONV_NCHW_NCHW44"; } | const char* name() const override { return "S8_CONV_NCHW_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -89,7 +97,9 @@ public: | |||||
class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8_CHAN_WISE_STRD1_NCHW44"; } | const char* name() const override { return "S8_CHAN_WISE_STRD1_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -104,7 +114,9 @@ public: | |||||
class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8_CHAN_WISE_STRD2_NCHW44"; } | const char* name() const override { return "S8_CHAN_WISE_STRD2_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -121,7 +133,9 @@ public: | |||||
class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; } | const char* name() const override { return "ARMDOTS8_NCHW_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam&, | bool usable(const NCBKernSizeParam&, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -138,7 +152,9 @@ public: | |||||
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMDOTS8STRD1"; } | const char* name() const override { return "ARMDOTS8STRD1"; } | ||||
bool usable(const NCBKernSizeParam&, | bool usable(const NCBKernSizeParam&, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -155,7 +171,9 @@ public: | |||||
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMDOTS8STRD2"; } | const char* name() const override { return "ARMDOTS8STRD2"; } | ||||
bool usable(const NCBKernSizeParam&, | bool usable(const NCBKernSizeParam&, | ||||
@@ -174,7 +192,9 @@ class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoDotS8Direct_NCHW44() {} | AlgoDotS8Direct_NCHW44() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMDOTS8DIRECT_NCHW44"; } | const char* name() const override { return "ARMDOTS8DIRECT_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam&, | bool usable(const NCBKernSizeParam&, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -205,6 +225,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_S8) | ||||
}; | }; | ||||
@@ -223,6 +246,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8CF32) | ||||
}; | }; | ||||
@@ -241,7 +267,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | ||||
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8) | MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_WINOGRAD_F23_8X8_NCHW44_S8) | ||||
}; | }; | ||||
@@ -29,7 +29,9 @@ class ConvBiasImpl::AlgoI8x8x16Direct final : public AlgoBase { | |||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "I8816DIRECT"; } | const char* name() const override { return "I8816DIRECT"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -45,7 +47,9 @@ public: | |||||
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoS8x8x16DirectNCHW44() {} | AlgoS8x8x16DirectNCHW44() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; } | const char* name() const override { return "S8x8x16_NCHW44_DIRECT"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -71,7 +75,9 @@ class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | |||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "I8816STRD2"; } | const char* name() const override { return "I8816STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -87,7 +93,9 @@ public: | |||||
class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "I8816STRD2F2"; } | const char* name() const override { return "I8816STRD2F2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -105,10 +113,10 @@ public: | |||||
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final | class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final | ||||
: public AlgoBase { | : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { | |||||
return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | } | ||||
const char* name() const override { return "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"; } | |||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
size_t get_workspace( | size_t get_workspace( | ||||
@@ -126,7 +134,9 @@ class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoI8x8x16DirectNCHWNCHW44() {} | AlgoI8x8x16DirectNCHWNCHW44() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "I8816_CONV_NCHW_NCHW44"; } | const char* name() const override { return "I8816_CONV_NCHW_NCHW44"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -20,7 +20,9 @@ namespace arm_common { | |||||
class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoQU8DirectStride1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "QU8STRD1"; } | const char* name() const override { return "QU8STRD1"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -38,7 +40,9 @@ public: | |||||
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "QU8STRD2"; } | const char* name() const override { return "QU8STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -55,7 +59,9 @@ public: | |||||
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMDOTU8STRD1"; } | const char* name() const override { return "ARMDOTU8STRD1"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -73,7 +79,9 @@ public: | |||||
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMDOTU8STRD2"; } | const char* name() const override { return "ARMDOTU8STRD2"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -23,7 +23,9 @@ namespace arm_common { | |||||
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final | class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final | ||||
: public AlgoBase { | : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH32_I8x8x32_DECONV_STRIDE1"; | return "AARCH32_I8x8x32_DECONV_STRIDE1"; | ||||
} | } | ||||
@@ -42,7 +44,9 @@ public: | |||||
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final | class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final | ||||
: public AlgoBase { | : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH32_I8x8x32_DECONV_STRIDE2"; | return "AARCH32_I8x8x32_DECONV_STRIDE2"; | ||||
} | } | ||||
@@ -22,7 +22,9 @@ namespace arm_common { | |||||
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final | class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final | ||||
: public AlgoBase { | : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; | return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; | ||||
} | } | ||||
@@ -42,7 +44,9 @@ public: | |||||
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final | class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final | ||||
: public AlgoBase { | : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; | return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; | ||||
} | } | ||||
@@ -18,7 +18,9 @@ namespace arm_common { | |||||
class ElemwiseImpl::AlgoBinary##case final \ | class ElemwiseImpl::AlgoBinary##case final \ | ||||
: public ElemwiseImpl::AlgoBase { \ | : public ElemwiseImpl::AlgoBase { \ | ||||
mutable std::string m_name; \ | mutable std::string m_name; \ | ||||
bool is_reproducible() const override { return true; } \ | |||||
AlgoAttribute attribute() const override { \ | |||||
return AlgoAttribute::REPRODUCIBLE; \ | |||||
} \ | |||||
const char* name() const override { \ | const char* name() const override { \ | ||||
if (m_name.empty()) { \ | if (m_name.empty()) { \ | ||||
m_name = megdnn_mangle( \ | m_name = megdnn_mangle( \ | ||||
@@ -11,8 +11,8 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "src/fallback/elemwise/opr_impl.h" | #include "src/fallback/elemwise/opr_impl.h" | ||||
#include "src/arm_common/elemwise_op.h" | #include "src/arm_common/elemwise_op.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
class ElemwiseImpl final : public fallback::ElemwiseImpl { | class ElemwiseImpl final : public fallback::ElemwiseImpl { | ||||
@@ -18,7 +18,9 @@ namespace arm_common { | |||||
class ElemwiseImpl::AlgoTernaryFma3##case final \ | class ElemwiseImpl::AlgoTernaryFma3##case final \ | ||||
: public ElemwiseImpl::AlgoBase { \ | : public ElemwiseImpl::AlgoBase { \ | ||||
mutable std::string m_name; \ | mutable std::string m_name; \ | ||||
bool is_reproducible() const override { return true; } \ | |||||
AlgoAttribute attribute() const override { \ | |||||
return AlgoAttribute::REPRODUCIBLE; \ | |||||
} \ | |||||
const char* name() const override { \ | const char* name() const override { \ | ||||
if (m_name.empty()) { \ | if (m_name.empty()) { \ | ||||
m_name = megdnn_mangle( \ | m_name = megdnn_mangle( \ | ||||
@@ -16,7 +16,9 @@ namespace arm_common { | |||||
class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase { | class ElemwiseImpl::AlgoUnary final : public ElemwiseImpl::AlgoBase { | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = megdnn_mangle(ssprintf("Elemwise::AlgoUnary")); | m_name = megdnn_mangle(ssprintf("Elemwise::AlgoUnary")); | ||||
@@ -19,7 +19,9 @@ namespace arm_common { | |||||
class MatrixMulImpl::AlgoInt8x8x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_INT8X8X16"; } | const char* name() const override { return "ARM_COMMON_INT8X8X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -31,7 +33,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -45,7 +49,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -60,7 +66,9 @@ public: | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -78,7 +86,9 @@ protected: | |||||
~AlgoF32Gemv() = default; | ~AlgoF32Gemv() = default; | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_F32_GEMV"; } | const char* name() const override { return "ARM_COMMON_F32_GEMV"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -91,7 +101,9 @@ public: | |||||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -106,7 +118,9 @@ public: | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | class MatrixMulImpl::AlgoF16Gemv : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_F16_GEMV"; } | const char* name() const override { return "ARM_COMMON_F16_GEMV"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -121,7 +135,9 @@ public: | |||||
class MatrixMulImpl::AlgoGevm : public AlgoBase { | class MatrixMulImpl::AlgoGevm : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARM_COMMON_GEVM"; } | const char* name() const override { return "ARM_COMMON_GEVM"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -22,7 +22,9 @@ using AlgoBase = PoolingImpl::AlgoBase; | |||||
class PoolingImpl::AlgoFilterxModexStride1 final : public AlgoBase { | class PoolingImpl::AlgoFilterxModexStride1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
}; | |||||
const char* name() const override { return "ARM_POOLING_STRIDE1"; } | const char* name() const override { return "ARM_POOLING_STRIDE1"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -30,14 +32,18 @@ public: | |||||
class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
}; | |||||
const char* name() const override { return "ARM_POOLING_STRIDE2"; } | const char* name() const override { return "ARM_POOLING_STRIDE2"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
}; | |||||
const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } | const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -45,7 +51,9 @@ public: | |||||
class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
}; | |||||
const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } | const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -53,7 +61,9 @@ public: | |||||
class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } | const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -61,7 +71,9 @@ public: | |||||
class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } | const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -69,7 +81,9 @@ public: | |||||
class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } | const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -77,7 +91,9 @@ public: | |||||
class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { | class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } | const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -85,7 +101,9 @@ public: | |||||
class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -93,7 +111,9 @@ public: | |||||
class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -101,7 +121,9 @@ public: | |||||
class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -109,14 +131,18 @@ public: | |||||
class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
}; | }; | ||||
class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { | class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } | const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } | ||||
bool usable(const PoolingKernSizeParam& param) const override; | bool usable(const PoolingKernSizeParam& param) const override; | ||||
void exec(const PoolingKernParam& param) const override; | void exec(const PoolingKernParam& param) const override; | ||||
@@ -24,7 +24,9 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { | |||||
static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "S8MATMUL"; } | const char* name() const override { return "S8MATMUL"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -24,7 +24,9 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { | |||||
static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "QU8MATMUL"; } | const char* name() const override { return "QU8MATMUL"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
@@ -21,7 +21,9 @@ namespace armv7 { | |||||
class MatrixMulImpl::AlgoF32 final : public AlgoBase { | class MatrixMulImpl::AlgoF32 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_F32"; } | const char* name() const override { return "ARMV7_F32"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -32,7 +34,9 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -43,7 +47,9 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_4x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_F32_MK4_4x8"; } | const char* name() const override { return "ARMV7_F32_MK4_4x8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -56,7 +62,9 @@ public: | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
class MatrixMulImpl::AlgoF16K4x16x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF16K4x16x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH32_F16_K4X16X1"; } | const char* name() const override { return "AARCH32_F16_K4X16X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -66,7 +74,9 @@ public: | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH32_F16_MK8_4X8"; } | const char* name() const override { return "AARCH32_F16_MK8_4X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -79,7 +89,9 @@ public: | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH32_INT8_K6X8X4"; } | const char* name() const override { return "AARCH32_INT8_K6X8X4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -90,7 +102,9 @@ public: | |||||
class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8DotK4x8x4 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "AARCH32_QUINT8_K4X8X4"; } | const char* name() const override { return "AARCH32_QUINT8_K4X8X4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -101,7 +115,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH32_INT8_MK4_8X4X4_DOTPROD"; | return "AARCH32_INT8_MK4_8X4X4_DOTPROD"; | ||||
} | } | ||||
@@ -124,7 +140,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x2x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT8X8X32_K4X2X16"; } | const char* name() const override { return "ARMV7_INT8X8X32_K4X2X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -136,7 +154,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT8X8X32_K4X8X8"; } | const char* name() const override { return "ARMV7_INT8X8X32_K4X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -148,7 +168,9 @@ public: | |||||
class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K4x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_QUINT8_K4X8X8"; } | const char* name() const override { return "ARMV7_QUINT8_K4X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -159,7 +181,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x2x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT8X8X16_K4X2X16"; } | const char* name() const override { return "ARMV7_INT8X8X16_K4X2X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -171,7 +195,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT8X8X16_K4X8X8"; } | const char* name() const override { return "ARMV7_INT8X8X16_K4X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -183,7 +209,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16K8x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K8x8x4 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT8X8X16_K8X8X4"; } | const char* name() const override { return "ARMV7_INT8X8X16_K8X8X4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -195,7 +223,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -207,7 +237,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x4x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT16X16X32_K12X4X1"; } | const char* name() const override { return "ARMV7_INT16X16X32_K12X4X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -219,7 +251,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32MK8_4x8 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT16X16X32_MK8_4X8"; } | const char* name() const override { return "ARMV7_INT16X16X32_MK8_4X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -231,7 +265,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -0,0 +1,22 @@ | |||||
/** | |||||
* \file dnn/src/common/algo_base.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/common/algo_base.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
bool Algorithm::contain_attribute(const Attribute& attr) const { | |||||
return bool(attribute() & attr); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -21,6 +21,8 @@ | |||||
namespace megdnn { | namespace megdnn { | ||||
MEGDNN_DEF_ENUM_CLASS_BIT_OPR(AlgoAttribute) | |||||
#define MEGDNN_DECL_ALGO_TYPE(_type) \ | #define MEGDNN_DECL_ALGO_TYPE(_type) \ | ||||
uint32_t type() const override { \ | uint32_t type() const override { \ | ||||
return static_cast<std::underlying_type<AlgoType>::type>( \ | return static_cast<std::underlying_type<AlgoType>::type>( \ | ||||
@@ -82,7 +82,7 @@ template <typename Opr> | |||||
typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo, | typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo, | ||||
bool reproducible) { | bool reproducible) { | ||||
if (reproducible) { | if (reproducible) { | ||||
if (algo->is_reproducible()) { | |||||
if (algo->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
return algo; | return algo; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -113,7 +113,7 @@ typename Opr::Algorithm* get_reproducible_algo( | |||||
} | } | ||||
} | } | ||||
if (i->is_available(args)) { | if (i->is_available(args)) { | ||||
if (!i->is_reproducible()) | |||||
if (!i->contain_attribute(AlgoAttribute::REPRODUCIBLE)) | |||||
available_but_not_reproducible = true; | available_but_not_reproducible = true; | ||||
} | } | ||||
} | } | ||||
@@ -54,6 +54,7 @@ | |||||
#include <mutex> | #include <mutex> | ||||
#include <string> | #include <string> | ||||
#include <thread> | #include <thread> | ||||
#include <type_traits> | |||||
#if defined(_WIN32) | #if defined(_WIN32) | ||||
#include <windows.h> | #include <windows.h> | ||||
@@ -683,6 +684,62 @@ inline void* get_origin_ptr(const TensorND* tensor, void* ptr) { | |||||
return static_cast<void*>(static_cast<dt_byte*>(ptr) - | return static_cast<void*>(static_cast<dt_byte*>(ptr) - | ||||
tensor->layout.span().low_byte); | tensor->layout.span().low_byte); | ||||
} | } | ||||
/*!
 * \brief thin wrapper that gives an enum type bitwise operators
 *
 * The wrapper stores the enumerator as its underlying integer. It converts
 * implicitly from and to \p T, and explicitly to bool (true when any bit is
 * set), so the result of a bit operation can be used directly in a condition
 * or assigned back to the enum.
 *
 * \tparam T the enumeration type whose values are used as bit flags
 */
template <typename T>
class EnumClassBit {
    static_assert(std::is_enum<T>::value,
                  "EnumClassBit can only wrap enumeration types");

    using Underlying = std::underlying_type_t<T>;

    Underlying m_val;

    //! private: only the operators below build a wrapper from a raw integer
    constexpr EnumClassBit(Underlying v) : m_val(v) {}

public:
    constexpr EnumClassBit(T v) : m_val(static_cast<Underlying>(v)) {}

    constexpr operator T() const { return static_cast<T>(m_val); }

    constexpr explicit operator bool() const { return m_val != 0; }

    // The operators are spelled out instead of generated by a helper macro,
    // so this header neither defines nor undefines a generically named macro.
    // The casts undo the integer promotion applied to small underlying types.
    constexpr EnumClassBit operator&(const EnumClassBit& rhs) const {
        return EnumClassBit(static_cast<Underlying>(m_val & rhs.m_val));
    }

    constexpr EnumClassBit operator|(const EnumClassBit& rhs) const {
        return EnumClassBit(static_cast<Underlying>(m_val | rhs.m_val));
    }

    constexpr EnumClassBit operator^(const EnumClassBit& rhs) const {
        return EnumClassBit(static_cast<Underlying>(m_val ^ rhs.m_val));
    }

    constexpr EnumClassBit operator~() const {
        return EnumClassBit(static_cast<Underlying>(~m_val));
    }
};
//! helper: define binary operator \p op for (cls, cls) and
//! (EnumClassBit<cls>, cls); both overloads return EnumClassBit<cls>
#define _MEGDNN_DECBO_SINGLE_OPR(cls, op)                                    \
    inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls lhs,       \
                                                             cls rhs) {     \
        return ::megdnn::EnumClassBit<cls>(lhs)                             \
                op ::megdnn::EnumClassBit<cls>(rhs);                        \
    }                                                                        \
    inline constexpr ::megdnn::EnumClassBit<cls> operator op(               \
            ::megdnn::EnumClassBit<cls> lhs, cls rhs) {                     \
        return lhs op ::megdnn::EnumClassBit<cls>(rhs);                     \
    }

//! helper: define the compound assignment form (op followed by =) of
//! \p op for enum type \p cls
#define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op)                  \
    inline constexpr cls& operator op##=(cls& lhs, cls rhs) {     \
        return lhs = lhs op ::megdnn::EnumClassBit<cls>(rhs);     \
    }

//! define operators &, |, ^, &=, |=, ^= and ~ for the enum type \p cls
#define MEGDNN_DEF_ENUM_CLASS_BIT_OPR(cls)                              \
    _MEGDNN_DECBO_SINGLE_OPR(cls, &)                                    \
    _MEGDNN_DECBO_SINGLE_OPR(cls, |)                                    \
    _MEGDNN_DECBO_SINGLE_OPR(cls, ^)                                    \
    _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, &)                             \
    _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, |)                             \
    _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, ^)                             \
    inline constexpr ::megdnn::EnumClassBit<cls> operator~(cls val) {   \
        return ~::megdnn::EnumClassBit<cls>(val);                       \
    }
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -14,6 +14,7 @@ | |||||
#include <unordered_map> | #include <unordered_map> | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "megdnn/oprs/base.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/cuda/batch_conv_bias/opr_impl.h" | #include "src/cuda/batch_conv_bias/opr_impl.h" | ||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
@@ -67,7 +68,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
@@ -89,7 +91,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; | return "BATCH_CONV_BIAS_INT8_NCHW4_GEMM_DOTPROD"; | ||||
@@ -104,7 +108,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; | return "BATCH_CONV_BIAS_INT8_NCHW4_IMPLICIT_GEMM_PRECOMP_DOTPROD"; | ||||
@@ -71,7 +71,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -94,7 +95,9 @@ public: | |||||
bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | ||||
void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
bool is_reproducible() const override { return true; } | |||||
//! attribute flags of this algo: only REPRODUCIBLE is reported
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE;
}
const char* name() const override { return "BRUTE_FORCE"; } | const char* name() const override { return "BRUTE_FORCE"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | MEGDNN_DECL_ALGO_TYPE(CUDA_BRUTE_FORCE) | ||||
@@ -109,7 +112,9 @@ public: | |||||
bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | ||||
void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "CUBLAS"; } | const char* name() const override { return "CUBLAS"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | ||||
}; | }; | ||||
@@ -120,7 +125,9 @@ public: | |||||
bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | ||||
void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "CUBLAS_LT"; } | const char* name() const override { return "CUBLAS_LT"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | ||||
}; | }; | ||||
@@ -132,7 +139,9 @@ public: | |||||
bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | ||||
void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "INT8x8x32"; } | const char* name() const override { return "INT8x8x32"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32) | MEGDNN_DECL_ALGO_TYPE(CUDA_INT8X8X32) | ||||
}; | }; | ||||
@@ -130,7 +130,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
@@ -165,7 +166,13 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_attr.is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } | cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; } | ||||
@@ -198,7 +205,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | ||||
@@ -219,8 +228,10 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
@@ -238,8 +249,10 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_INT8X8X32) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
@@ -260,7 +273,13 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_attr.is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
@@ -298,8 +317,10 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
@@ -327,8 +348,10 @@ public: | |||||
std::vector<SearchItem> get_subopr_list( | std::vector<SearchItem> get_subopr_list( | ||||
const TensorLayoutArray& layouts, | const TensorLayoutArray& layouts, | ||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
@@ -347,8 +370,10 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32) | MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL_INT8X8X32) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
bool need_src_unroll(const SizeArgs& args) const; | bool need_src_unroll(const SizeArgs& args) const; | ||||
@@ -378,7 +403,10 @@ public: | |||||
const TensorLayoutArray& layouts, | const TensorLayoutArray& layouts, | ||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | ||||
private: | private: | ||||
@@ -397,7 +425,13 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | ||||
TensorLayout& dst_pg, TensorLayout& bias_pg); | TensorLayout& dst_pg, TensorLayout& bias_pg); | ||||
@@ -423,7 +457,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "QUINT4x4x32_WMMA"; } | const char* name() const override { return "QUINT4x4x32_WMMA"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
@@ -444,7 +480,9 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"; | return "INT8_CHWN4_DOTPROD_IMPLICIT_GEMM"; | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
template <typename BiasVisitor> | template <typename BiasVisitor> | ||||
static void dispatch_nonlinear_mode( | static void dispatch_nonlinear_mode( | ||||
const int8_t* d_src, const int8_t* d_filter, | const int8_t* d_src, const int8_t* d_filter, | ||||
@@ -486,7 +524,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
size_t get_preprocess_workspace_in_bytes( | size_t get_preprocess_workspace_in_bytes( | ||||
const SizeArgs& args) const override; | const SizeArgs& args) const override; | ||||
SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | SmallVector<TensorLayout> deduce_preprocessed_filter_layout( | ||||
@@ -524,7 +564,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
template <typename BiasVisitor> | template <typename BiasVisitor> | ||||
static void dispatch_nonlinear_mode( | static void dispatch_nonlinear_mode( | ||||
const int8_t* d_src, const int8_t* d_filter, | const int8_t* d_src, const int8_t* d_filter, | ||||
@@ -561,7 +603,6 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8) | MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_NCHW4_IMMA_INT8) | ||||
std::string param() const override { | std::string param() const override { | ||||
@@ -569,6 +610,9 @@ public: | |||||
serialize_write_pod(m_mma_tile_size, ret); | serialize_write_pod(m_mma_tile_size, ret); | ||||
return ret; | return ret; | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, | ||||
@@ -590,7 +634,6 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8) | MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_REORDER_FILTER_CHWN4_IMMA_INT8) | ||||
std::string param() const override { | std::string param() const override { | ||||
@@ -598,6 +641,9 @@ public: | |||||
serialize_write_pod(m_mma_tile_size, ret); | serialize_write_pod(m_mma_tile_size, ret); | ||||
return ret; | return ret; | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
MMATileSize m_mma_tile_size; | MMATileSize m_mma_tile_size; | ||||
@@ -617,7 +663,6 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8) | MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM_UNROLL_WIDTH_CHWN4_IMMA_INT8) | ||||
std::string param() const override { | std::string param() const override { | ||||
@@ -625,6 +670,9 @@ public: | |||||
serialize_write_pod(m_mma_tile_size, ret); | serialize_write_pod(m_mma_tile_size, ret); | ||||
return ret; | return ret; | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
MMATileSize m_mma_tile_size; | MMATileSize m_mma_tile_size; | ||||
@@ -655,7 +703,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
static std::string to_string(AlgoParam algo_param); | static std::string to_string(AlgoParam algo_param); | ||||
size_t get_preprocess_workspace_in_bytes( | size_t get_preprocess_workspace_in_bytes( | ||||
const SizeArgs& args) const override; | const SizeArgs& args) const override; | ||||
@@ -690,7 +740,10 @@ public: | |||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
const char* name() const override { return "CONVBIAS_BFLOAT16"; } | const char* name() const override { return "CONVBIAS_BFLOAT16"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | ||||
private: | private: | ||||
@@ -82,7 +82,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
@@ -115,10 +116,14 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
const char* name() const override { return m_attr.name.c_str(); } | const char* name() const override { return m_attr.name.c_str(); } | ||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_attr.is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | ||||
bool is_cudnn() const override { return true; } | bool is_cudnn() const override { return true; } | ||||
@@ -146,8 +151,10 @@ public: | |||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
const char* name() const override { return "MATMUL"; } | const char* name() const override { return "MATMUL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { | ||||
@@ -157,8 +164,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final : public AlgoBase { | ||||
@@ -168,8 +177,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE_SMALL"; } | const char* name() const override { return "CHANNEL_WISE_SMALL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase { | ||||
@@ -185,7 +196,10 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
return "CONVOLUTION_BACKWARD_DATD_BFLOAT16"; | return "CONVOLUTION_BACKWARD_DATD_BFLOAT16"; | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
@@ -207,11 +221,17 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | ||||
TensorLayout& grad_pg); | TensorLayout& grad_pg); | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
std::string param() const override { | std::string param() const override { | ||||
std::string ret; | std::string ret; | ||||
@@ -81,7 +81,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
@@ -114,9 +115,14 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
const char* name() const override { return m_attr.name.c_str(); } | const char* name() const override { return m_attr.name.c_str(); } | ||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_attr.is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | cudnnConvolutionBwdFilterAlgo_t cudnn_enum() const { return m_cudnn_enum; } | ||||
@@ -145,8 +151,10 @@ public: | |||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
const char* name() const override { return "MATMUL"; } | const char* name() const override { return "MATMUL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | ||||
@@ -156,8 +164,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase { | ||||
@@ -173,7 +183,11 @@ public: | |||||
const char* name() const override { | const char* name() const override { | ||||
return "CONVOLUTION_BACKWARD_FILTER_BFLOAT16"; | return "CONVOLUTION_BACKWARD_FILTER_BFLOAT16"; | ||||
} | } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | MEGDNN_DECL_ALGO_TYPE(CUDA_BFLOAT16) | ||||
private: | private: | ||||
@@ -195,12 +209,17 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | ||||
TensorLayout& diff_pg); | TensorLayout& diff_pg); | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
std::string param() const override { | std::string param() const override { | ||||
std::string ret; | std::string ret; | ||||
@@ -31,6 +31,7 @@ protected: | |||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
enum class AlgoType : uint32_t { | enum class AlgoType : uint32_t { | ||||
CUDA_DEFAULT, | CUDA_DEFAULT, | ||||
}; | }; | ||||
@@ -65,7 +66,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -86,7 +88,10 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; | size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; | ||||
const char* name() const override { return "DEFAULT"; } | const char* name() const override { return "DEFAULT"; } | ||||
void exec(const ExecArgs&) const override; | void exec(const ExecArgs&) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
std::vector<SearchItem> get_subopr_list( | std::vector<SearchItem> get_subopr_list( | ||||
const TensorLayoutArray& layouts, | const TensorLayoutArray& layouts, | ||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
@@ -38,7 +38,6 @@ public: | |||||
}; | }; | ||||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl* handle; | HandleImpl* handle; | ||||
@@ -79,7 +78,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -111,9 +111,14 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
const char* name() const override { return m_attr.name.c_str(); } | const char* name() const override { return m_attr.name.c_str(); } | ||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_attr.is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | ||||
@@ -135,8 +140,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
//! implement group conv by another algo | //! implement group conv by another algo | ||||
@@ -154,10 +161,15 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& diff_pg, | ||||
TensorLayout& grad_pg); | TensorLayout& grad_pg); | ||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
std::string param() const override { | std::string param() const override { | ||||
@@ -72,7 +72,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -104,7 +105,13 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_attr.is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
const char* name() const override { return m_attr.name.c_str(); } | const char* name() const override { return m_attr.name.c_str(); } | ||||
@@ -128,7 +135,9 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "INPLACE_MATMUL"; } | const char* name() const override { return "INPLACE_MATMUL"; } | ||||
bool is_reproducible() const override { return false; } | |||||
AlgoAttribute attribute() const override { | |||||
return static_cast<AlgoAttribute>(0); | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | ||||
}; | }; | ||||
@@ -139,7 +148,9 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | ||||
}; | }; | ||||
@@ -158,7 +169,13 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | ||||
TensorLayout& diff_pg); | TensorLayout& diff_pg); | ||||
@@ -201,3 +218,4 @@ public: | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -77,7 +77,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -102,7 +103,9 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "1x1x1"; } | const char* name() const override { return "1x1x1"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) | MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) | ||||
}; | }; | ||||
@@ -120,8 +123,13 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool is_reproducible() const override { return m_impl->is_reproducible(); } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_impl->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | ||||
TensorLayout& dst_pg); | TensorLayout& dst_pg); | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
@@ -147,7 +155,13 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_attr.is_reproducible; } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_attr.is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
const char* name() const override { return m_attr.name.c_str(); } | const char* name() const override { return m_attr.name.c_str(); } | ||||
@@ -172,7 +186,9 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "INPLACE_MATMUL"; } | const char* name() const override { return "INPLACE_MATMUL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_INPLACE_MATMUL) | ||||
}; | }; | ||||
@@ -183,7 +199,9 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE) | ||||
}; | }; | ||||
@@ -218,3 +236,4 @@ public: | |||||
} // namespace megdnn | } // namespace megdnn | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
@@ -83,7 +83,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -107,7 +108,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
std::vector<SearchItem> get_subopr_list( | std::vector<SearchItem> get_subopr_list( | ||||
const TensorLayoutArray& layouts, | const TensorLayoutArray& layouts, | ||||
@@ -76,7 +76,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -99,7 +100,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
std::vector<SearchItem> get_subopr_list( | std::vector<SearchItem> get_subopr_list( | ||||
const TensorLayoutArray& layouts, | const TensorLayoutArray& layouts, | ||||
@@ -71,7 +71,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -94,7 +95,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
std::vector<SearchItem> get_subopr_list( | std::vector<SearchItem> get_subopr_list( | ||||
const TensorLayoutArray& layouts, | const TensorLayoutArray& layouts, | ||||
@@ -35,7 +35,6 @@ public: | |||||
}; | }; | ||||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; } | ||||
struct SizeArgs { | struct SizeArgs { | ||||
LocalShareBackwardDataImpl* opr; | LocalShareBackwardDataImpl* opr; | ||||
@@ -63,7 +62,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -83,7 +83,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "LOCAL_SHARE_IMPLICIT_GEMM"; | return "LOCAL_SHARE_IMPLICIT_GEMM"; | ||||
@@ -100,7 +102,9 @@ public: | |||||
const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "LOCAL_SHARE_BATCHED_MATMUL"; | return "LOCAL_SHARE_BATCHED_MATMUL"; | ||||
@@ -62,7 +62,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -82,7 +83,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } | const char* name() const override { return "LOCAL_SHARE_IMPLICIT_GEMM"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | MEGDNN_DECL_ALGO_TYPE(CUDA_IMPLICIT_GEMM) | ||||
@@ -96,7 +99,9 @@ public: | |||||
const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | ||||
@@ -63,7 +63,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -85,7 +86,9 @@ public: | |||||
const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; | return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE"; | ||||
@@ -102,7 +105,9 @@ public: | |||||
const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; | return "LOCAL_SHARE_CHWN_BATCH_SIZE_AWARE_SMALL_IMAGE"; | ||||
@@ -118,7 +123,9 @@ public: | |||||
const SizeArgs& args) const; | const SizeArgs& args) const; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | const char* name() const override { return "LOCAL_SHARE_BATCHED_MATMUL"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | ||||
@@ -86,7 +86,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -109,8 +110,10 @@ public: | |||||
} | } | ||||
const char* name() const override { return "CUBLAS"; } | const char* name() const override { return "CUBLAS"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
#if CUDA_VERSION >= 10000 | #if CUDA_VERSION >= 10000 | ||||
@@ -121,8 +124,10 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
const char* name() const override { return "UINT4x4x32_WMMA"; } | const char* name() const override { return "UINT4x4x32_WMMA"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | MEGDNN_DECL_ALGO_TYPE(CUDA_WMMA_UINT4X4X32) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
#endif | #endif | ||||
#if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
@@ -132,8 +137,10 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
const char* name() const override { return "CUBLAS_LT"; } | const char* name() const override { return "CUBLAS_LT"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -146,8 +153,10 @@ public: | |||||
} | } | ||||
const char* name() const override { return "NAIVE"; } | const char* name() const override { return "NAIVE"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) | MEGDNN_DECL_ALGO_TYPE(CUDA_NAIVE) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
@@ -163,7 +172,10 @@ public: | |||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
const char* name() const override { return "MATMUL_BFLOAT16"; } | const char* name() const override { return "MATMUL_BFLOAT16"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
private: | private: | ||||
WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | ||||
@@ -189,7 +201,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT) | MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT) | ||||
std::string param() const override { | std::string param() const override { | ||||
@@ -214,7 +228,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | ||||
std::string param() const override { | std::string param() const override { | ||||
@@ -239,7 +255,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED) | MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED) | ||||
std::string param() const override { | std::string param() const override { | ||||
@@ -66,7 +66,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -87,7 +88,9 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; | size_t get_workspace_in_bytes(const SizeArgs& /* args */) const override; | ||||
const char* name() const override { return "DEFAULT"; } | const char* name() const override { return "DEFAULT"; } | ||||
virtual void exec(const ExecArgs&) const override; | virtual void exec(const ExecArgs&) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(fallback_BLAS) | MEGDNN_DECL_ALGO_TYPE(fallback_BLAS) | ||||
}; | }; | ||||
@@ -20,7 +20,9 @@ namespace fallback { | |||||
class ConvBiasImpl::AlgoNaive final : public AlgoBase { | class ConvBiasImpl::AlgoNaive final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override{ | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { return "FALLBACK_NAIVE"; } | const char* name() const override { return "FALLBACK_NAIVE"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -43,7 +45,9 @@ class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoWinogradF32(MatrixMulImpl::AlgoBase* matmul_algo) | AlgoWinogradF32(MatrixMulImpl::AlgoBase* matmul_algo) | ||||
: m_matmul_algo{matmul_algo} {} | : m_matmul_algo{matmul_algo} {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
@@ -77,7 +81,9 @@ class ConvBiasImpl::AlgoWinogradF32_4x4 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoWinogradF32_4x4(MatrixMulImpl::AlgoBase* matmul_algo) | AlgoWinogradF32_4x4(MatrixMulImpl::AlgoBase* matmul_algo) | ||||
: m_matmul_algo{matmul_algo} {} | : m_matmul_algo{matmul_algo} {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
@@ -111,7 +117,9 @@ class ConvBiasImpl::AlgoWinogradQS8 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoWinogradQS8(MatrixMulImpl::AlgoBase* matmul_algo) | AlgoWinogradQS8(MatrixMulImpl::AlgoBase* matmul_algo) | ||||
: m_matmul_algo{matmul_algo} {} | : m_matmul_algo{matmul_algo} {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
@@ -145,7 +153,9 @@ class ConvBiasImpl::AlgoWinogradQS8_8x8 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoWinogradQS8_8x8(MatrixMulImpl::AlgoBase* matmul_algo) | AlgoWinogradQS8_8x8(MatrixMulImpl::AlgoBase* matmul_algo) | ||||
: m_matmul_algo{matmul_algo} {} | : m_matmul_algo{matmul_algo} {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>( | ||||
@@ -141,7 +141,6 @@ using BiasMode = ConvBiasForward::BiasMode; | |||||
} | } | ||||
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type) \ | #define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type) \ | ||||
bool is_reproducible() const override { return true; } \ | |||||
bool usable(const NCBKernSizeParam& param, \ | bool usable(const NCBKernSizeParam& param, \ | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; \ | AlgoSelectionStrategy algo_selection_strategy) const override; \ | ||||
size_t get_workspace(const NCBKernSizeParam& param) const override; \ | size_t get_workspace(const NCBKernSizeParam& param) const override; \ | ||||
@@ -29,7 +29,9 @@ public: | |||||
AlgoConv1x1(MatrixMulImpl::AlgoBase* matmul_algo, size_t oc_block_size) | AlgoConv1x1(MatrixMulImpl::AlgoBase* matmul_algo, size_t oc_block_size) | ||||
: m_matmul_algo(matmul_algo), m_oc_block_size(oc_block_size) {} | : m_matmul_algo(matmul_algo), m_oc_block_size(oc_block_size) {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return m_matmul_algo->attribute(); | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
@@ -22,7 +22,9 @@ class ConvBiasImpl::AlgoConv1x1Gemv final : public AlgoBase { | |||||
public: | public: | ||||
AlgoConv1x1Gemv() = default; | AlgoConv1x1Gemv() = default; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "CONV1x1_GEMV"; } | const char* name() const override { return "CONV1x1_GEMV"; } | ||||
@@ -27,7 +27,9 @@ public: | |||||
: m_matmul_algo(matmul_algo), | : m_matmul_algo(matmul_algo), | ||||
m_ohw_tile_size(ohw_tile_size) {} | m_ohw_tile_size(ohw_tile_size) {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return m_matmul_algo->attribute(); | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
if (m_name.empty()) { | if (m_name.empty()) { | ||||
m_name = ssprintf("IM2COLMATMUL:%s:%zu", m_matmul_algo->name(), | m_name = ssprintf("IM2COLMATMUL:%s:%zu", m_matmul_algo->name(), | ||||
@@ -320,10 +320,12 @@ public: | |||||
virtual bool is_preferred(const NCBKernSizeParam&) const { | virtual bool is_preferred(const NCBKernSizeParam&) const { | ||||
return false; | return false; | ||||
} | } | ||||
bool usable_reproducible(const NCBKernSizeParam& param, | bool usable_reproducible(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy, | AlgoSelectionStrategy algo_selection_strategy, | ||||
bool reproducible = true) const { | bool reproducible = true) const { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
usable(param, algo_selection_strategy); | usable(param, algo_selection_strategy); | ||||
} | } | ||||
@@ -75,7 +75,6 @@ void kern_naive(const ConvolutionBackwardDataImpl::NCBKernParam& p) { | |||||
class ConvolutionImpl::AlgoFallback final : public AlgoBase { | class ConvolutionImpl::AlgoFallback final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "FALLBACK_ALGO"; } | const char* name() const override { return "FALLBACK_ALGO"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -85,6 +84,10 @@ public: | |||||
SmallVector<NCBKern> dispatch_kern( | SmallVector<NCBKern> dispatch_kern( | ||||
const NCBKernSizeParam& /*param*/) const override; | const NCBKernSizeParam& /*param*/) const override; | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE}; | return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE}; | ||||
} | } | ||||
@@ -93,7 +96,6 @@ public: | |||||
class ConvolutionImpl::AlgoNaive final : public AlgoBase { | class ConvolutionImpl::AlgoNaive final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "NAIVE_ALGO"; } | const char* name() const override { return "NAIVE_ALGO"; } | ||||
bool usable(const NCBKernSizeParam& /*param*/, | bool usable(const NCBKernSizeParam& /*param*/, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -103,6 +105,9 @@ public: | |||||
SmallVector<NCBKern> dispatch_kern( | SmallVector<NCBKern> dispatch_kern( | ||||
const NCBKernSizeParam& /*param*/) const override; | const NCBKernSizeParam& /*param*/) const override; | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
ConvAlgoTypePack get_algo_type() const override { | ConvAlgoTypePack get_algo_type() const override { | ||||
auto support_data_type = static_cast<AlgoDataType>( | auto support_data_type = static_cast<AlgoDataType>( | ||||
static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | ||||
@@ -122,7 +127,6 @@ class ConvolutionImpl::AlgoDefault final : public AlgoBase { | |||||
public: | public: | ||||
AlgoDefault(ConvBiasImpl::AlgoBase*); | AlgoDefault(ConvBiasImpl::AlgoBase*); | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
@@ -144,6 +148,10 @@ public: | |||||
return get_kimpl(m_algorithm, param); | return get_kimpl(m_algorithm, param); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return m_algorithm->attribute(); | |||||
} | |||||
//! select matmul to the highest preference | //! select matmul to the highest preference | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
@@ -169,7 +177,6 @@ private: | |||||
////////////////////////// convolutionbackwarddata //////////////////////// | ////////////////////////// convolutionbackwarddata //////////////////////// | ||||
class ConvolutionBackwardDataImpl::AlgoNaive final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoNaive final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DeconvNaive"; } | const char* name() const override { return "DeconvNaive"; } | ||||
bool usable(ConvolutionBackwardDataImpl* opr, | bool usable(ConvolutionBackwardDataImpl* opr, | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
@@ -178,12 +185,14 @@ public: | |||||
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ||||
const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
bool is_naive() const override { return true; } | bool is_naive() const override { return true; } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) | MEGDNN_DECL_ALGO_TYPE(FB_NAIVE) | ||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DeconvDirect"; } | const char* name() const override { return "DeconvDirect"; } | ||||
bool usable(ConvolutionBackwardDataImpl* opr, | bool usable(ConvolutionBackwardDataImpl* opr, | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
@@ -191,12 +200,14 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ||||
const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(FB_DIRECT) | MEGDNN_DECL_ALGO_TYPE(FB_DIRECT) | ||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DeconvMatmul"; } | const char* name() const override { return "DeconvMatmul"; } | ||||
bool usable(ConvolutionBackwardDataImpl* opr, | bool usable(ConvolutionBackwardDataImpl* opr, | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
@@ -205,6 +216,9 @@ public: | |||||
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ||||
const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(FB_MATMUL) | MEGDNN_DECL_ALGO_TYPE(FB_MATMUL) | ||||
}; | }; | ||||
@@ -736,7 +736,7 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( | |||||
for (auto i : ncb_1g_get_all_algorithms(param)) { | for (auto i : ncb_1g_get_all_algorithms(param)) { | ||||
if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) { | if (ncb_1g_get_workspace(i, param) <= workspace_limit_in_bytes) { | ||||
if (reproducible) { | if (reproducible) { | ||||
if (i->is_reproducible()) { | |||||
if (i->contain_attribute(AlgoAttribute::REPRODUCIBLE)) { | |||||
return i; | return i; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -237,10 +237,12 @@ public: | |||||
virtual bool is_preferred(const NCBKernSizeParam&) const { | virtual bool is_preferred(const NCBKernSizeParam&) const { | ||||
return false; | return false; | ||||
} | } | ||||
bool usable_reproducible(const NCBKernSizeParam& param, | bool usable_reproducible(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy algo_selection_strategy, | AlgoSelectionStrategy algo_selection_strategy, | ||||
bool reproducible = true) const { | bool reproducible = true) const { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
usable(param, algo_selection_strategy); | usable(param, algo_selection_strategy); | ||||
} | } | ||||
@@ -422,7 +424,9 @@ protected: | |||||
bool usable_reproducible(ConvolutionBackwardDataImpl* opr, | bool usable_reproducible(ConvolutionBackwardDataImpl* opr, | ||||
const NCBKernSizeParam& param, | const NCBKernSizeParam& param, | ||||
bool reproducible = true) const { | bool reproducible = true) const { | ||||
return (!reproducible || is_reproducible()) && usable(opr, param); | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
usable(opr, param); | |||||
} | } | ||||
virtual bool is_preferred(const NCBKernSizeParam&) const { | virtual bool is_preferred(const NCBKernSizeParam&) const { | ||||
return false; | return false; | ||||
@@ -21,18 +21,19 @@ namespace fallback { | |||||
class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "FB_F32_K8X12X1"; } | const char* name() const override { return "FB_F32_K8X12X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(FB_F32K8x12x1) | MEGDNN_DECL_ALGO_TYPE(FB_F32K8x12x1) | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); | ||||
}; | }; | ||||
class MatrixMulImpl::AlgoGemv final : public AlgoBase { | class MatrixMulImpl::AlgoGemv final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "FB_GEMV"; } | const char* name() const override { return "FB_GEMV"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -40,6 +41,9 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(FB_GEMV) | MEGDNN_DECL_ALGO_TYPE(FB_GEMV) | ||||
MEGDNN_OVERRIDE_MATMUL_DESC( | MEGDNN_OVERRIDE_MATMUL_DESC( | ||||
8, 16, 1, 4, | 8, 16, 1, 4, | ||||
@@ -54,7 +58,9 @@ public: | |||||
class MatrixMulImpl::AlgoNaive final : public AlgoBase { | class MatrixMulImpl::AlgoNaive final : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { return "FB_NAIVE"; } | const char* name() const override { return "FB_NAIVE"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
@@ -225,7 +225,9 @@ public: | |||||
}; | }; | ||||
bool preferred_reproducible(const KernSizeParam& param, | bool preferred_reproducible(const KernSizeParam& param, | ||||
bool reproducible = true) { | bool reproducible = true) { | ||||
return (!reproducible || is_reproducible()) && preferred(param); | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
preferred(param); | |||||
}; | }; | ||||
virtual MatmulDescription matmul_description() const = 0; | virtual MatmulDescription matmul_description() const = 0; | ||||
@@ -129,7 +129,7 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( | |||||
auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
->default_batch_conv_bias_fwd_algo(); | ->default_batch_conv_bias_fwd_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -250,7 +250,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
auto algo = | auto algo = | ||||
static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -11,39 +11,50 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/common/algo_base.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace naive { | namespace naive { | ||||
class DefaultConvolutionForwardAlgorithm final | class DefaultConvolutionForwardAlgorithm final | ||||
: public megdnn::ConvolutionForward::Algorithm { | : public megdnn::ConvolutionForward::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
class DefaultConvolutionBackwardDataAlgorithm final | class DefaultConvolutionBackwardDataAlgorithm final | ||||
: public megdnn::ConvolutionBackwardData::Algorithm { | : public megdnn::ConvolutionBackwardData::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
class DefaultConvolutionBackwardFilterAlgorithm final | class DefaultConvolutionBackwardFilterAlgorithm final | ||||
: public megdnn::ConvolutionBackwardFilter::Algorithm { | : public megdnn::ConvolutionBackwardFilter::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
class DefaultConvBiasForwardAlgorithm final | class DefaultConvBiasForwardAlgorithm final | ||||
: public megdnn::ConvBiasForward::Algorithm { | : public megdnn::ConvBiasForward::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
class DefaultBatchConvBiasForwardAlgorithm final | class DefaultBatchConvBiasForwardAlgorithm final | ||||
: public megdnn::BatchConvBiasForward::Algorithm { | : public megdnn::BatchConvBiasForward::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
} // namespace naive | } // namespace naive | ||||
@@ -276,7 +276,7 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( | |||||
auto algo = | auto algo = | ||||
static_cast<HandleImpl*>(handle())->default_conv_fwd_algo(); | static_cast<HandleImpl*>(handle())->default_conv_fwd_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -308,7 +308,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||||
auto algo = | auto algo = | ||||
static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); | static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -341,7 +341,7 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
auto algo = | auto algo = | ||||
static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); | static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -10,25 +10,32 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/common/algo_base.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace naive { | namespace naive { | ||||
class DefaultConvolution3DForwardAlgorithm final | class DefaultConvolution3DForwardAlgorithm final | ||||
: public megdnn::Convolution3DForward::Algorithm { | : public megdnn::Convolution3DForward::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { return "DEFAULT"; } | const char* name() const override { return "DEFAULT"; } | ||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
}; | }; | ||||
class DefaultConvolution3DBackwardDataAlgorithm final | class DefaultConvolution3DBackwardDataAlgorithm final | ||||
: public megdnn::Convolution3DBackwardData::Algorithm { | : public megdnn::Convolution3DBackwardData::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { return "DEFAULT"; } | const char* name() const override { return "DEFAULT"; } | ||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
}; | }; | ||||
class DefaultConvolution3DBackwardFilterAlgorithm final | class DefaultConvolution3DBackwardFilterAlgorithm final | ||||
: public megdnn::Convolution3DBackwardFilter::Algorithm { | : public megdnn::Convolution3DBackwardFilter::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
const char* name() const override { return "DEFAULT"; } | const char* name() const override { return "DEFAULT"; } | ||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
}; | }; | ||||
@@ -123,7 +123,7 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
bool reproducible) { | bool reproducible) { | ||||
auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); | auto algo = static_cast<HandleImpl*>(handle())->default_conv3d_fwd_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -156,7 +156,7 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( | |||||
auto algo = | auto algo = | ||||
static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); | static_cast<HandleImpl*>(handle())->default_conv3d_bwd_data_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -191,7 +191,7 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( | |||||
auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
->default_conv3d_bwd_filter_algo(); | ->default_conv3d_bwd_filter_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -11,6 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "megdnn/oprs/base.h" | |||||
#include "src/common/handle_impl.h" | #include "src/common/handle_impl.h" | ||||
#include "src/naive/convolution/algorithms.h" | #include "src/naive/convolution/algorithms.h" | ||||
#include "src/naive/matrix_mul/algorithms.h" | #include "src/naive/matrix_mul/algorithms.h" | ||||
@@ -11,27 +11,34 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
#include "src/common/algo_base.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace naive { | namespace naive { | ||||
class DefaultLocalShareForwardAlgorithm final | class DefaultLocalShareForwardAlgorithm final | ||||
: public megdnn::LocalShareForward::Algorithm { | : public megdnn::LocalShareForward::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
class DefaultLocalShareBackwardDataAlgorithm final | class DefaultLocalShareBackwardDataAlgorithm final | ||||
: public megdnn::LocalShareBackwardData::Algorithm { | : public megdnn::LocalShareBackwardData::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
class DefaultLocalShareBackwardFilterAlgorithm final | class DefaultLocalShareBackwardFilterAlgorithm final | ||||
: public megdnn::LocalShareBackwardFilter::Algorithm { | : public megdnn::LocalShareBackwardFilter::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "DEFAULT"; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE; | |||||
} | |||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
const char* name() const override { return "DEFAULT"; } | |||||
}; | }; | ||||
} // namespace naive | } // namespace naive | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -166,7 +166,7 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( | |||||
auto algo = | auto algo = | ||||
static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo(); | static_cast<HandleImpl*>(handle())->default_local_share_fwd_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -200,7 +200,7 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic( | |||||
auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
->default_local_share_bwd_data_algo(); | ->default_local_share_bwd_data_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -234,7 +234,7 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic( | |||||
auto algo = static_cast<HandleImpl*>(handle()) | auto algo = static_cast<HandleImpl*>(handle()) | ||||
->default_local_share_bwd_filter_algo(); | ->default_local_share_bwd_filter_algo(); | ||||
if (reproducible) { | if (reproducible) { | ||||
megdnn_assert(algo->is_reproducible(), | |||||
megdnn_assert(algo->contain_attribute(AlgoAttribute::REPRODUCIBLE), | |||||
"require reproducible algorithm, but heuristic " | "require reproducible algorithm, but heuristic " | ||||
"algorithm(%s) is not " | "algorithm(%s) is not " | ||||
"reproducible", | "reproducible", | ||||
@@ -17,14 +17,18 @@ namespace naive { | |||||
class DefaultMatrixMulAlgorithm final | class DefaultMatrixMulAlgorithm final | ||||
: public megdnn::MatrixMulForward::Algorithm { | : public megdnn::MatrixMulForward::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "DEFAULT"; } | const char* name() const override { return "DEFAULT"; } | ||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
}; | }; | ||||
class DefaultBatchedMatrixMulAlgorithm final | class DefaultBatchedMatrixMulAlgorithm final | ||||
: public megdnn::BatchedMatrixMulForward::Algorithm { | : public megdnn::BatchedMatrixMulForward::Algorithm { | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "DEFAULT"; } | const char* name() const override { return "DEFAULT"; } | ||||
uint32_t type() const override { return 0; } | uint32_t type() const override { return 0; } | ||||
}; | }; | ||||
@@ -73,7 +73,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -96,7 +97,9 @@ public: | |||||
} | } | ||||
const char* name() const override { return "BLAS"; } | const char* name() const override { return "BLAS"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | ||||
}; | }; | ||||
@@ -77,7 +77,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
@@ -107,8 +108,13 @@ public: | |||||
bool is_available(const SizeArgs& args) const override; | bool is_available(const SizeArgs& args) const override; | ||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_is_reproducible; } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "MIOpenConvolutionBackwardData"; | return "MIOpenConvolutionBackwardData"; | ||||
@@ -137,8 +143,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "MATMUL"; } | const char* name() const override { return "MATMUL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) | MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase { | ||||
@@ -148,8 +156,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | ||||
@@ -34,6 +34,7 @@ public: | |||||
ROCM_CHANWISE | ROCM_CHANWISE | ||||
}; | }; | ||||
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>; | ||||
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; } | ||||
struct SizeArgs { | struct SizeArgs { | ||||
HandleImpl* handle; | HandleImpl* handle; | ||||
@@ -73,7 +74,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
@@ -104,8 +106,13 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_is_reproducible; } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "MIOpenConvolutionBackwardFilter"; | return "MIOpenConvolutionBackwardFilter"; | ||||
} | } | ||||
@@ -133,8 +140,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "MATMUL"; } | const char* name() const override { return "MATMUL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) | MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { | ||||
@@ -144,8 +153,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { | class ConvolutionBackwardFilterImpl::AlgoPack : NonCopyableObj { | ||||
@@ -33,7 +33,6 @@ namespace rocm { | |||||
class ConvolutionForwardImpl::AlgoBase : public Algorithm { | class ConvolutionForwardImpl::AlgoBase : public Algorithm { | ||||
protected: | protected: | ||||
~AlgoBase() = default; | ~AlgoBase() = default; | ||||
public: | public: | ||||
enum class AlgoType : uint32_t { | enum class AlgoType : uint32_t { | ||||
ROCM_MIOPEN, | ROCM_MIOPEN, | ||||
@@ -77,7 +76,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) { | size_t limit = std::numeric_limits<size_t>::max()) { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
@@ -107,7 +107,13 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& args) const override; | size_t get_workspace_in_bytes(const SizeArgs& args) const override; | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return m_is_reproducible; } | |||||
AlgoAttribute attribute() const override { | |||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
if (m_is_reproducible) { | |||||
ret |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
return ret; | |||||
} | |||||
const char* name() const override { return "MIOpenConvolutionForward"; } | const char* name() const override { return "MIOpenConvolutionForward"; } | ||||
@@ -134,7 +140,9 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "MATMUL"; } | const char* name() const override { return "MATMUL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) | MEGDNN_DECL_ALGO_TYPE(ROCM_MATMUL) | ||||
}; | }; | ||||
@@ -146,8 +154,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "INPLACE_MATMUL"; } | const char* name() const override { return "INPLACE_MATMUL"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_INPLACE_MATMUL) | MEGDNN_DECL_ALGO_TYPE(ROCM_INPLACE_MATMUL) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
//! optimized 1x1 conv | //! optimized 1x1 conv | ||||
@@ -161,8 +171,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "1x1"; } | const char* name() const override { return "1x1"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_1X1) | MEGDNN_DECL_ALGO_TYPE(ROCM_1X1) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
//! optimized 1x1 conv when input data batchsize is larger than 32 | //! optimized 1x1 conv when input data batchsize is larger than 32 | ||||
@@ -176,8 +188,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "LARGE_BATCH_1x1"; } | const char* name() const override { return "LARGE_BATCH_1x1"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_1X1_LARGE_BATCH) | MEGDNN_DECL_ALGO_TYPE(ROCM_1X1_LARGE_BATCH) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionForwardImpl::AlgoChanwise final : public AlgoBase { | class ConvolutionForwardImpl::AlgoChanwise final : public AlgoBase { | ||||
@@ -187,8 +201,10 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
const char* name() const override { return "CHANNEL_WISE"; } | const char* name() const override { return "CHANNEL_WISE"; } | ||||
bool is_reproducible() const override { return true; } | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) | MEGDNN_DECL_ALGO_TYPE(ROCM_CHANWISE) | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
}; | }; | ||||
class ConvolutionForwardImpl::AlgoPack : NonCopyableObj { | class ConvolutionForwardImpl::AlgoPack : NonCopyableObj { | ||||
@@ -73,7 +73,8 @@ public: | |||||
bool is_available_reproducible( | bool is_available_reproducible( | ||||
const SizeArgs& args, bool reproducible = true, | const SizeArgs& args, bool reproducible = true, | ||||
size_t limit = std::numeric_limits<size_t>::max()) const { | size_t limit = std::numeric_limits<size_t>::max()) const { | ||||
return (!reproducible || is_reproducible()) && | |||||
return (!reproducible || | |||||
contain_attribute(AlgoAttribute::REPRODUCIBLE)) && | |||||
is_available_wk(args, limit); | is_available_wk(args, limit); | ||||
} | } | ||||
AlgoBase& check_workspace(const SizeArgs& args, | AlgoBase& check_workspace(const SizeArgs& args, | ||||
@@ -96,7 +97,9 @@ public: | |||||
} | } | ||||
const char* name() const override { return "BLAS"; } | const char* name() const override { return "BLAS"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | ||||
}; | }; | ||||
@@ -32,7 +32,9 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase { | |||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP"; | ||||
} | } | ||||
@@ -68,7 +70,9 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase { | |||||
const CpuNDRange& workspace_ids); | const CpuNDRange& workspace_ids); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"; | return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP"; | ||||
} | } | ||||
@@ -101,6 +105,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F63_8x8_F32) | MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F63_8x8_F32) | ||||
}; | }; | ||||
@@ -117,6 +124,9 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | ||||
MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F23_8x8_F32) | MEGDNN_DECL_ALGO_TYPE(X86_WINOGRAD_F23_8x8_F32) | ||||
}; | }; | ||||
@@ -128,7 +138,9 @@ class ConvBiasImpl::AlgoMkldnnConv final : public AlgoBase { | |||||
public: | public: | ||||
AlgoMkldnnConv() {} | AlgoMkldnnConv() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "MKLDNN_CONV_FP32"; } | const char* name() const override { return "MKLDNN_CONV_FP32"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy) const override { | AlgoSelectionStrategy) const override { | ||||
@@ -21,7 +21,9 @@ class ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8 final : public AlgoBase { | |||||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"; | return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1"; | ||||
} | } | ||||
@@ -46,7 +48,9 @@ class ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8 final : public AlgoBase { | |||||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"; | return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2"; | ||||
} | } | ||||
@@ -71,7 +75,9 @@ class ConvBiasImpl::AlgoDirectAvx2Stride1Int8 final : public AlgoBase { | |||||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE1"; | return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE1"; | ||||
} | } | ||||
@@ -96,7 +102,9 @@ class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase { | |||||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { | const char* name() const override { | ||||
return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE2"; | return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE2"; | ||||
} | } | ||||
@@ -124,7 +132,9 @@ class ConvBiasImpl::AlgoMkldnnQint8 final : public AlgoBase { | |||||
public: | public: | ||||
AlgoMkldnnQint8() {} | AlgoMkldnnQint8() {} | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "MKLDNN_INT8"; } | const char* name() const override { return "MKLDNN_INT8"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy) const override; | AlgoSelectionStrategy) const override; | ||||
@@ -163,7 +173,9 @@ class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase { | |||||
static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | static WorkspaceBundle get_bundle(const NCBKernSizeParam& param); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "MKLDNN_MATMUL_INT8"; } | const char* name() const override { return "MKLDNN_MATMUL_INT8"; } | ||||
bool usable(const NCBKernSizeParam& param, | bool usable(const NCBKernSizeParam& param, | ||||
AlgoSelectionStrategy) const override; | AlgoSelectionStrategy) const override; | ||||
@@ -20,11 +20,13 @@ namespace x86 { | |||||
class MatrixMulImpl::AlgoF32Blas : public AlgoBase { | class MatrixMulImpl::AlgoF32Blas : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
const char* name() const override { return "X86_F32_BLAS"; } | const char* name() const override { return "X86_F32_BLAS"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | size_t get_workspace(const KernSizeParam&) const override { return 0; } | ||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | ||||
MEGDNN_DECL_ALGO_TYPE(X86_F32_BLAS) | MEGDNN_DECL_ALGO_TYPE(X86_F32_BLAS) | ||||
@@ -33,7 +35,9 @@ public: | |||||
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | ||||
class MatrixMulImpl::AlgoF32MKLPackA : public AlgoBase { | class MatrixMulImpl::AlgoF32MKLPackA : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_F32_MKL_PACKA"; } | const char* name() const override { return "X86_F32_MKL_PACKA"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | size_t get_workspace(const KernSizeParam&) const override { return 0; } | ||||
@@ -55,7 +59,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16 : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_INT8X8X32_AVX2_2X4X16"; } | const char* name() const override { return "X86_INT8X8X32_AVX2_2X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -66,7 +72,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_INT8X8X32_AVX2_4X16X2"; } | const char* name() const override { return "X86_INT8X8X32_AVX2_4X16X2"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -81,7 +89,9 @@ private: | |||||
const MatrixMulImpl::KernParam& kern_param); | const MatrixMulImpl::KernParam& kern_param); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_INT8X8X16_AVX2"; } | const char* name() const override { return "X86_INT8X8X16_AVX2"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -97,7 +107,9 @@ private: | |||||
const MatrixMulImpl::KernParam& kern_param); | const MatrixMulImpl::KernParam& kern_param); | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_INT8X8X16_SSE"; } | const char* name() const override { return "X86_INT8X8X16_SSE"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -109,7 +121,9 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_INT8X8X32_SSE_4X8X2"; } | const char* name() const override { return "X86_INT8X8X32_SSE_4X8X2"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -120,7 +134,9 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_F32MK8_8X8"; } | const char* name() const override { return "X86_F32MK8_8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -133,7 +149,9 @@ public: | |||||
#if MEGDNN_X86_WITH_VNNI | #if MEGDNN_X86_WITH_VNNI | ||||
class MatrixMulImpl::AlgoInt8x8x32Vnni : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Vnni : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_INT8X8X32_VNNI"; } | const char* name() const override { return "X86_INT8X8X32_VNNI"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
@@ -146,7 +164,9 @@ public: | |||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
class MatrixMulImpl::AlgoInt8x8x32Mkldnn : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Mkldnn : public AlgoBase { | ||||
public: | public: | ||||
bool is_reproducible() const override { return true; } | |||||
AlgoAttribute attribute() const override { | |||||
return AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return "X86_INT8X8X32_MKLDNN"; } | const char* name() const override { return "X86_INT8X8X32_MKLDNN"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
size_t get_workspace(const KernSizeParam&) const override { return 0; } | size_t get_workspace(const KernSizeParam&) const override { return 0; } | ||||
@@ -420,7 +420,9 @@ size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, | |||||
mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
ret.append("): algo=" + std::string(palgo->name())); | ret.append("): algo=" + std::string(palgo->name())); | ||||
ret.append(ssprintf(" workspace=%.2fMiB reproducible=%d", | ret.append(ssprintf(" workspace=%.2fMiB reproducible=%d", | ||||
workspace / (1024 * 1024.0), palgo->is_reproducible())); | |||||
workspace / (1024 * 1024.0), | |||||
palgo->contain_attribute( | |||||
megdnn::AlgoAttribute::REPRODUCIBLE))); | |||||
mgb_log_debug("%s", ret.c_str()); | mgb_log_debug("%s", ret.c_str()); | ||||
megdnn_opr->execution_policy() = policy; | megdnn_opr->execution_policy() = policy; | ||||
@@ -715,8 +717,10 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo( | |||||
if (!rst.valid()) | if (!rst.valid()) | ||||
return None; | return None; | ||||
return AlgoChooserProfileCache::ResultEntry{ | return AlgoChooserProfileCache::ResultEntry{ | ||||
palgo->name(), palgo->is_reproducible(), rst.val().time, | |||||
param.workspace}; | |||||
palgo->name(), | |||||
palgo->contain_attribute( | |||||
megdnn::AlgoAttribute::REPRODUCIBLE), | |||||
rst.val().time, param.workspace}; | |||||
} | } | ||||
template <typename Opr> | template <typename Opr> | ||||
@@ -2127,7 +2127,8 @@ TEST(TestOprDNN, HeuristicReproducible) { | |||||
megdnn_opr->get_algorithm_from_desc(algo); | megdnn_opr->get_algorithm_from_desc(algo); | ||||
mgb_assert(palgo, "Unknown algo description"); | mgb_assert(palgo, "Unknown algo description"); | ||||
if (strategy == S::HEURISTIC_REPRODUCIBLE) { | if (strategy == S::HEURISTIC_REPRODUCIBLE) { | ||||
EXPECT_TRUE(palgo->is_reproducible()); | |||||
EXPECT_TRUE(palgo->contain_attribute( | |||||
megdnn::AlgoAttribute::REPRODUCIBLE)); | |||||
} | } | ||||
algo_name0 = palgo->name(); | algo_name0 = palgo->name(); | ||||
} | } | ||||
@@ -2338,7 +2339,9 @@ class MockAlgorithm : public megdnn::detail::Algorithm { | |||||
public: | public: | ||||
MockAlgorithm(const char* name = "NotImportant") : m_name(name) {} | MockAlgorithm(const char* name = "NotImportant") : m_name(name) {} | ||||
bool is_reproducible() const override { return true; } | |||||
Attribute attribute() const override { | |||||
return Attribute::REPRODUCIBLE; | |||||
} | |||||
const char* name() const override { return m_name; } | const char* name() const override { return m_name; } | ||||
uint32_t type() const override { | uint32_t type() const override { | ||||
return megdnn::detail::Algorithm::INVALID_ALGO_TYPE; | return megdnn::detail::Algorithm::INVALID_ALGO_TYPE; | ||||