Browse Source

refactor(dnn): refactor algorithm type interface

GitOrigin-RevId: 843d885f82
release-1.1
Megvii Engine Team 4 years ago
parent
commit
f7b2bdae1a
53 changed files with 230 additions and 277 deletions
  1. +3
    -2
      dnn/include/megdnn/oprs/base.h
  2. +1
    -1
      dnn/src/aarch64/conv_bias/opr_impl.cpp
  3. +7
    -2
      dnn/src/aarch64/conv_bias/opr_impl.h
  4. +0
    -21
      dnn/src/aarch64/matrix_mul/algos.h
  5. +2
    -2
      dnn/src/aarch64/matrix_mul/opr_impl.cpp
  6. +8
    -2
      dnn/src/aarch64/matrix_mul/opr_impl.h
  7. +16
    -11
      dnn/src/arm_common/conv_bias/opr_impl.cpp
  8. +10
    -8
      dnn/src/arm_common/conv_bias/opr_impl.h
  9. +10
    -6
      dnn/src/arm_common/convolution/int8x8x32/algos.cpp
  10. +18
    -15
      dnn/src/arm_common/convolution/int8x8x32/algos.h
  11. +20
    -25
      dnn/src/arm_common/convolution/opr_impl.cpp
  12. +12
    -9
      dnn/src/arm_common/convolution/opr_impl.h
  13. +10
    -6
      dnn/src/arm_common/convolution/quint8/algos.cpp
  14. +18
    -13
      dnn/src/arm_common/convolution/quint8/algos.h
  15. +0
    -8
      dnn/src/arm_common/matrix_mul/algos.h
  16. +2
    -9
      dnn/src/arm_common/matrix_mul/opr_impl.cpp
  17. +8
    -3
      dnn/src/arm_common/matrix_mul/opr_impl.h
  18. +1
    -1
      dnn/src/armv7/conv_bias/opr_impl.cpp
  19. +7
    -2
      dnn/src/armv7/conv_bias/opr_impl.h
  20. +0
    -14
      dnn/src/armv7/matrix_mul/algos.h
  21. +2
    -2
      dnn/src/armv7/matrix_mul/opr_impl.cpp
  22. +8
    -1
      dnn/src/armv7/matrix_mul/opr_impl.h
  23. +1
    -0
      dnn/src/cuda/batch_conv_bias/algo.h
  24. +1
    -0
      dnn/src/cuda/batched_matrix_mul/algo.h
  25. +1
    -0
      dnn/src/cuda/conv_bias/algo.h
  26. +1
    -0
      dnn/src/cuda/convolution/backward_data/algo.h
  27. +1
    -0
      dnn/src/cuda/convolution/backward_filter/algo.h
  28. +1
    -0
      dnn/src/cuda/convolution3d/backward_data/algo.h
  29. +3
    -2
      dnn/src/cuda/convolution3d/backward_filter/algo.h
  30. +3
    -2
      dnn/src/cuda/convolution3d/forward/algo.h
  31. +1
    -0
      dnn/src/cuda/deformable_conv/bwd_data/algo.h
  32. +1
    -0
      dnn/src/cuda/deformable_conv/bwd_flt/algo.h
  33. +1
    -0
      dnn/src/cuda/deformable_conv/fwd/algo.h
  34. +1
    -0
      dnn/src/cuda/local_share/backward_data/algo.h
  35. +1
    -0
      dnn/src/cuda/local_share/backward_filter/algo.h
  36. +1
    -0
      dnn/src/cuda/local_share/forward/algo.h
  37. +3
    -2
      dnn/src/cuda/matrix_mul/algos.h
  38. +3
    -0
      dnn/src/fallback/conv_bias/opr_impl.h
  39. +0
    -4
      dnn/src/fallback/convolution/algos.h
  40. +2
    -10
      dnn/src/fallback/convolution/opr_impl.cpp
  41. +6
    -4
      dnn/src/fallback/convolution/opr_impl.h
  42. +1
    -0
      dnn/src/fallback/matrix_mul/opr_impl.h
  43. +5
    -4
      dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
  44. +1
    -0
      dnn/src/rocm/convolution/backward_data/algo.h
  45. +1
    -0
      dnn/src/rocm/convolution/backward_filter/algo.h
  46. +1
    -0
      dnn/src/rocm/convolution/forward/algo.h
  47. +0
    -7
      dnn/src/x86/conv_bias/f32/algos.h
  48. +0
    -7
      dnn/src/x86/conv_bias/int8/algos.h
  49. +10
    -50
      dnn/src/x86/conv_bias/opr_impl.cpp
  50. +7
    -2
      dnn/src/x86/conv_bias/opr_impl.h
  51. +0
    -10
      dnn/src/x86/matrix_mul/algos.h
  52. +2
    -8
      dnn/src/x86/matrix_mul/opr_impl.cpp
  53. +7
    -2
      dnn/src/x86/matrix_mul/opr_impl.h

+ 3
- 2
dnn/include/megdnn/oprs/base.h View File

@@ -11,6 +11,7 @@
#pragma once #pragma once


#include "megdnn/basic_types.h" #include "megdnn/basic_types.h"
#include "megdnn/handle.h"


#include "megdnn/internal/visibility_prologue.h" #include "megdnn/internal/visibility_prologue.h"
namespace megdnn { namespace megdnn {
@@ -105,11 +106,11 @@ public:
virtual bool is_reproducible() const = 0; virtual bool is_reproducible() const = 0;
virtual const char* name() const = 0; virtual const char* name() const = 0;


//! a pointer to represent class type
virtual void* type() const { return nullptr; }
Handle::HandleType handle_type() const { return m_handle_type; }


protected: protected:
~Algorithm() = default; ~Algorithm() = default;
Handle::HandleType m_handle_type = Handle::HandleType::NAIVE;
}; };


/*! /*!


+ 1
- 1
dnn/src/aarch64/conv_bias/opr_impl.cpp View File

@@ -45,7 +45,7 @@ public:
SmallVector<AlgoBase*> matmul_algos; SmallVector<AlgoBase*> matmul_algos;
}; };


SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack; static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); auto&& algos = arm_common::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),


+ 7
- 2
dnn/src/aarch64/conv_bias/opr_impl.h View File

@@ -18,11 +18,16 @@ namespace aarch64 {
class ConvBiasImpl : public arm_common::ConvBiasImpl { class ConvBiasImpl : public arm_common::ConvBiasImpl {
public: public:
using arm_common::ConvBiasImpl::ConvBiasImpl; using arm_common::ConvBiasImpl::ConvBiasImpl;
class AlgoBase : public arm_common::ConvBiasImpl::AlgoBase {
public:
AlgoBase() : arm_common::ConvBiasImpl::AlgoBase() {
m_handle_type = Handle::HandleType::AARCH64;
}
};


SmallVector<AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;


protected: protected:

const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


private: private:


+ 0
- 21
dnn/src/aarch64/matrix_mul/algos.h View File

@@ -26,7 +26,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -37,7 +36,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -48,7 +46,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -59,7 +56,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4)
}; };
@@ -75,7 +71,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -86,7 +81,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
}; };
@@ -103,7 +97,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -116,7 +109,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
#else #else
@@ -129,7 +121,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -143,7 +134,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
@@ -156,7 +146,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
#endif #endif
@@ -169,7 +158,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
@@ -182,7 +170,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -196,7 +183,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -212,7 +198,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -226,7 +211,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::DEFAULT; } PackMode packmode() const override { return PackMode::DEFAULT; }


MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
@@ -240,7 +224,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -251,7 +234,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
}; };
@@ -266,7 +248,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -278,7 +259,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT)
@@ -292,7 +272,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
#endif #endif


+ 2
- 2
dnn/src/aarch64/matrix_mul/opr_impl.cpp View File

@@ -52,7 +52,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif #endif


public: public:
SmallVector<MatrixMulImpl::AlgoBase*> all_algos;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;


AlgoPack() { AlgoPack() {
all_algos.emplace_back(&f32_gemv); all_algos.emplace_back(&f32_gemv);
@@ -89,7 +89,7 @@ public:
} }
}; };


SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack; static AlgoPack s_algo_pack;
auto&& algos = arm_common::MatrixMulImpl::algo_pack(); auto&& algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),


+ 8
- 2
dnn/src/aarch64/matrix_mul/opr_impl.h View File

@@ -18,8 +18,14 @@ namespace aarch64 {
class MatrixMulImpl : public arm_common::MatrixMulImpl { class MatrixMulImpl : public arm_common::MatrixMulImpl {
public: public:
using arm_common::MatrixMulImpl::MatrixMulImpl; using arm_common::MatrixMulImpl::MatrixMulImpl;
class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase {
public:
AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() {
m_handle_type = Handle::HandleType::AARCH64;
}
};


SmallVector<AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;


private: private:
class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1 class AlgoF32K8x12x1; // Aarch64 F32 Kernel 8X12X1
@@ -57,7 +63,7 @@ private:
#else #else
class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8
#endif #endif
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel MK4 8x8x8
class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel MK4 8x8x8


class AlgoPack; class AlgoPack;
}; };


+ 16
- 11
dnn/src/arm_common/conv_bias/opr_impl.cpp View File

@@ -11,6 +11,7 @@
*/ */


#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/base.h"
#include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h" #include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/quint8/algos.h" #include "src/arm_common/conv_bias/quint8/algos.h"
@@ -18,6 +19,7 @@
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/common/metahelper.h" #include "src/common/metahelper.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"


#include "src/arm_common/convolution/opr_impl.h" #include "src/arm_common/convolution/opr_impl.h"
@@ -37,7 +39,12 @@ using namespace megdnn;
using namespace arm_common; using namespace arm_common;


namespace { namespace {
uint8_t arm_common_algo_type_storage;

bool is_fallback_or_naive(const detail::Algorithm* algo) {
return algo->handle_type() == Handle::HandleType::NAIVE ||
algo->handle_type() == Handle::HandleType::FALLBACK;
}

} // anonymous namespace } // anonymous namespace


class ConvBiasImpl::AlgoPack : NonCopyableObj { class ConvBiasImpl::AlgoPack : NonCopyableObj {
@@ -50,7 +57,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8DirectStride1 s8_direct_stride1; AlgoS8DirectStride1 s8_direct_stride1;
AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44; AlgoS8ChanWiseStride1NCHW44 s8_channel_wise_stride1_nchw44;
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44; AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44;
AlgoS8x8x16ChanWiseStride1Stride2NCHW44
s8x8x16_channel_wise_stride1_stride2_nchw44;


#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectStride1 ds8_direct_stride1; AlgoDotS8DirectStride1 ds8_direct_stride1;
@@ -129,7 +137,7 @@ public:
->select_algo_type( ->select_algo_type(
{AlgoDataType::FLOAT32, MatmulFormat::MK4}); {AlgoDataType::FLOAT32, MatmulFormat::MK4});
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
if (is_fallback_or_naive(algo))
continue; continue;
for (uint32_t tile_size : {16, 8, 24, 32}) { for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( refhold.emplace_back(new AlgoFP32WinogradF23_4x4(
@@ -166,7 +174,7 @@ public:
->select_algo_type({AlgoDataType::FLOAT32, ->select_algo_type({AlgoDataType::FLOAT32,
MatmulFormat::DEFAULT}); MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
if (is_fallback_or_naive(algo))
continue; continue;
for (uint32_t tile_size : {16, 8, 24, 32}) { for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP32WinogradF63( refhold.emplace_back(new AlgoFP32WinogradF63(
@@ -189,7 +197,7 @@ public:
->select_algo_type({AlgoDataType::FLOAT16, ->select_algo_type({AlgoDataType::FLOAT16,
MatmulFormat::DEFAULT}); MatmulFormat::DEFAULT});
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
if (is_fallback_or_naive(algo))
continue; continue;
for (uint32_t tile_size : {16, 8, 24, 32}) { for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP16WinogradF23( refhold.emplace_back(new AlgoFP16WinogradF23(
@@ -210,7 +218,7 @@ public:
->select_algo_type({AlgoDataType::FLOAT16, ->select_algo_type({AlgoDataType::FLOAT16,
MatmulFormat::MK8}); MatmulFormat::MK8});
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
if (is_fallback_or_naive(algo))
continue; continue;
for (uint32_t tile_size : {16, 8, 24, 32}) { for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoFP16WinogradF23_8x8( refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
@@ -224,7 +232,7 @@ public:
->select_algo_type({AlgoDataType::INT16X16X32, ->select_algo_type({AlgoDataType::INT16X16X32,
MatmulFormat::MK8}); MatmulFormat::MK8});
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
if (is_fallback_or_naive(algo))
continue; continue;
for (uint32_t tile_size : {16, 8, 24, 32}) { for (uint32_t tile_size : {16, 8, 24, 32}) {
refhold.emplace_back(new AlgoS8WinogradF23_8x8( refhold.emplace_back(new AlgoS8WinogradF23_8x8(
@@ -242,7 +250,7 @@ public:
SmallVector<AlgoBase*> winograd_algos; SmallVector<AlgoBase*> winograd_algos;
}; };


SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack; static AlgoPack sl_algo_pack;
auto&& algos = fallback::ConvBiasImpl::algo_pack(); auto&& algos = fallback::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
@@ -252,9 +260,6 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
return std::move(algos); return std::move(algos);
} }


void* const ConvBiasImpl::sm_arm_common_algo_type =
&arm_common_algo_type_storage;

bool ConvBiasImpl::is_matmul_quantized_prefer( bool ConvBiasImpl::is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& param) const { const ConvBiasImpl::NCBKernSizeParam& param) const {
fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param( fallback::ConvBiasImpl::NCBKernSizeParam conv_ncb_param(


+ 10
- 8
dnn/src/arm_common/conv_bias/opr_impl.h View File

@@ -19,23 +19,25 @@ namespace arm_common {
class ConvBiasImpl : public fallback::ConvBiasImpl { class ConvBiasImpl : public fallback::ConvBiasImpl {
public: public:
using fallback::ConvBiasImpl::ConvBiasImpl; using fallback::ConvBiasImpl::ConvBiasImpl;
using FallbackConvBiasImpl = fallback::ConvBiasImpl;
using NCBKernIndex = fallback::ConvBiasImpl::NCBKernIndex;

bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }
class AlgoBase : public fallback::ConvBiasImpl::AlgoBase {
public:
AlgoBase() : fallback::ConvBiasImpl::AlgoBase() {
m_handle_type = Handle::HandleType::ARM_COMMON;
}
};


SmallVector<AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;


bool is_matmul_quantized_prefer( bool is_matmul_quantized_prefer(
const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override;
const fallback::ConvBiasImpl::NCBKernSizeParam& ncb_param)
const override;


SmallVector<AlgoCategory> suggest_algo_category_order( SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;
class AlgoPack; class AlgoPack;


protected: protected:
static void* const sm_arm_common_algo_type;

const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


private: private:
@@ -93,7 +95,7 @@ private:
class AlgoF16Direct; class AlgoF16Direct;
class AlgoF16DirectStride1; class AlgoF16DirectStride1;
#endif #endif
};
};


} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn


+ 10
- 6
dnn/src/arm_common/convolution/int8x8x32/algos.cpp View File

@@ -26,12 +26,14 @@ using namespace arm_common;
/* ===================== ConvolutionBackwardData ===================== */ /* ===================== ConvolutionBackwardData ===================== */
/* ===================== direct stride 1 algo ===================== */ /* ===================== direct stride 1 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
return deconv::can_stride1_int8x8x32_dot(param); return deconv::can_stride1_int8x8x32_dot(param);
} }


size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace( size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) { midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param); return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
@@ -42,7 +44,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(


ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) { midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_int8x8x32_dot; return deconv::stride1_int8x8x32_dot;
@@ -53,12 +55,14 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(


/* ===================== direct stride 2 algo ===================== */ /* ===================== direct stride 2 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
return deconv::can_stride2_int8x8x32_dot(param); return deconv::can_stride2_int8x8x32_dot(param);
} }


size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace( size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) { midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param); return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
@@ -69,7 +73,7 @@ size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(


ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern( ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) { midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_int8x8x32_dot; return deconv::stride2_int8x8x32_dot;


+ 18
- 15
dnn/src/arm_common/convolution/int8x8x32/algos.h View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once
@@ -19,38 +20,40 @@ namespace arm_common {
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
/* ===================== ConvolutionBackwardData ===================== */ /* ===================== ConvolutionBackwardData ===================== */


class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final : public AlgoBase {
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final
: public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE1"; }
const char* name() const override {
return "AARCH32_I8x8x32_DECONV_STRIDE1";
}


bool usable(ConvolutionBackwardDataImpl*,
bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


size_t get_workspace(ConvolutionBackwardDataImpl*,
size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;

void* type() const override { return sm_arm_common_algo_type; }
}; };


class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final : public AlgoBase {
class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2 final
: public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "AARCH32_I8x8x32_DECONV_STRIDE2"; }
const char* name() const override {
return "AARCH32_I8x8x32_DECONV_STRIDE2";
}


bool usable(ConvolutionBackwardDataImpl*,
bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


size_t get_workspace(ConvolutionBackwardDataImpl*,
size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;

void* type() const override { return sm_arm_common_algo_type; }
}; };


#endif #endif


+ 20
- 25
dnn/src/arm_common/convolution/opr_impl.cpp View File

@@ -21,9 +21,6 @@
using namespace megdnn; using namespace megdnn;
using namespace arm_common; using namespace arm_common;


namespace {
uint8_t arm_common_algo_type_storage;
} // anonymous namespace


/* ===================== ConvolutionBackwardData ===================== */ /* ===================== ConvolutionBackwardData ===================== */
struct ConvolutionBackwardDataImpl::AlgoPack { struct ConvolutionBackwardDataImpl::AlgoPack {
@@ -36,46 +33,44 @@ struct ConvolutionBackwardDataImpl::AlgoPack {
}; };
ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;


void* const ConvolutionBackwardDataImpl::sm_arm_common_algo_type =
&arm_common_algo_type_storage;

ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
Algorithm* algo, const NCBKernSizeParam& param) { Algorithm* algo, const NCBKernSizeParam& param) {
if (algo->type() == sm_arm_common_algo_type) {
if (algo->handle_type() == Handle::HandleType::ARM_COMMON) {
return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param); return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
} }
return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo, param);
return fallback::ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(algo,
param);
} }


size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(Algorithm* algo,
const NCBKernSizeParam& param) {
if (algo->type() == sm_arm_common_algo_type) {
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
Algorithm* algo, const NCBKernSizeParam& param) {
if (algo->handle_type() == Handle::HandleType::ARM_COMMON) {
return static_cast<AlgoBase*>(algo)->get_workspace(this, param); return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
} }
return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo, param);
return fallback::ConvolutionBackwardDataImpl::ncb_1g_get_workspace(algo,
param);
} }


std::vector<ConvolutionBackwardDataImpl::Algorithm*> std::vector<ConvolutionBackwardDataImpl::Algorithm*>
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(const NCBKernSizeParam& param) {

auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(param);
ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
const NCBKernSizeParam& param) {
auto ret = fallback::ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms(
param);


#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
if((param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32)) {

if ((param.filter_type.enumv() == DTypeEnum::QuantizedS8 ||
param.filter_type.enumv() == DTypeEnum::Int8) &&
(param.grad_type.enumv() == DTypeEnum::QuantizedS32 ||
param.grad_type.enumv() == DTypeEnum::Int32)) {
if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) { if (sm_algo_pack.i8x8x32_direct_stride1_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot); ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride1_sdot);
} }
if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) { if (sm_algo_pack.i8x8x32_direct_stride2_sdot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot); ret.insert(ret.begin(), &sm_algo_pack.i8x8x32_direct_stride2_sdot);
} }
}
else if(param.filter_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.grad_type.enumv() == DTypeEnum::QuantizedS32) {

} else if (param.filter_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.grad_type.enumv() == DTypeEnum::QuantizedS32) {
if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) { if (sm_algo_pack.quint8_direct_stride1_udot.usable(this, param)) {
ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot); ret.insert(ret.begin(), &sm_algo_pack.quint8_direct_stride1_udot);
} }


+ 12
- 9
dnn/src/arm_common/convolution/opr_impl.h View File

@@ -18,24 +18,27 @@ namespace arm_common {


class ConvBiasImpl; class ConvBiasImpl;


class ConvolutionBackwardDataImpl : public fallback::ConvolutionBackwardDataImpl {
class ConvolutionBackwardDataImpl
: public fallback::ConvolutionBackwardDataImpl {
public: public:
using fallback::ConvolutionBackwardDataImpl::ConvolutionBackwardDataImpl; using fallback::ConvolutionBackwardDataImpl::ConvolutionBackwardDataImpl;


protected: protected:
static void* const sm_arm_common_algo_type;

class AlgoBase : public Algorithm {
class AlgoBase : public fallback::ConvolutionBackwardDataImpl::AlgoBase {
protected: protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
virtual bool usable(ConvolutionBackwardDataImpl* opr,
AlgoBase() : fallback::ConvolutionBackwardDataImpl::AlgoBase() {
m_handle_type = Handle::HandleType::ARM_COMMON;
}
virtual bool usable(fallback::ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0; const NCBKernSizeParam& param) const = 0;
virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr,
virtual size_t get_workspace(fallback::ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0; const NCBKernSizeParam& param) const = 0;
virtual ncb_kern_t dispatch_kern( virtual ncb_kern_t dispatch_kern(
ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param) const = 0;
fallback::ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0;
}; };


ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo, ncb_kern_t ncb_1g_dispatch_kern(Algorithm* algo,
@@ -49,7 +52,7 @@ protected:


const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


private:
private:
#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
class AlgoSdot8DirectStride1; class AlgoSdot8DirectStride1;
class AlgoSdot8DirectStride2; class AlgoSdot8DirectStride2;
@@ -62,4 +65,4 @@ protected:


} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen

+ 10
- 6
dnn/src/arm_common/convolution/quint8/algos.cpp View File

@@ -27,12 +27,14 @@ using namespace arm_common;


/* ===================== direct stride 1 algo ===================== */ /* ===================== direct stride 1 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
return deconv::can_stride1_quint8_dot(param); return deconv::can_stride1_quint8_dot(param);
} }


size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace( size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) { midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param); return deconv::get_workspace_in_bytes_stride1_quint8_dot(param);
@@ -43,7 +45,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(


ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) { midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_quint8_dot; return deconv::stride1_quint8_dot;
@@ -54,12 +56,14 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(


/* ===================== direct stride 2 algo ===================== */ /* ===================== direct stride 2 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
return deconv::can_stride2_quint8_dot(param); return deconv::can_stride2_quint8_dot(param);
} }


size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace( size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) { midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param); return deconv::get_workspace_in_bytes_stride2_quint8_dot(param);
@@ -70,7 +74,7 @@ size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(


ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern( ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl, MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) { midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_quint8_dot; return deconv::stride2_quint8_dot;


+ 18
- 13
dnn/src/arm_common/convolution/quint8/algos.h View File

@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */


#pragma once #pragma once
@@ -18,38 +19,42 @@ namespace arm_common {


#if __ARM_FEATURE_DOTPROD #if __ARM_FEATURE_DOTPROD
/* ===================== ConvolutionBackwardData ===================== */ /* ===================== ConvolutionBackwardData ===================== */
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final : public AlgoBase {
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final
: public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1"; }
const char* name() const override {
return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE1";
}


bool usable(ConvolutionBackwardDataImpl*,
bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


size_t get_workspace(ConvolutionBackwardDataImpl*,
size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;


void* type() const override { return sm_arm_common_algo_type; }
}; };


class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final : public AlgoBase {
class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2 final
: public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2"; }
const char* name() const override {
return "ARM_COMMON_QUINT8_DIRECT_DECONV_STRIDE2";
}


bool usable(ConvolutionBackwardDataImpl*,
bool usable(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


size_t get_workspace(ConvolutionBackwardDataImpl*,
size_t get_workspace(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;


ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
ncb_kern_t dispatch_kern(fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;


void* type() const override { return sm_arm_common_algo_type; }
}; };
#endif #endif
} // namespace arm_common } // namespace arm_common


+ 0
- 8
dnn/src/arm_common/matrix_mul/algos.h View File

@@ -24,7 +24,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
}; };
@@ -37,7 +36,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
@@ -51,7 +49,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
@@ -66,7 +63,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT)
@@ -84,7 +80,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
@@ -98,7 +93,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
@@ -113,7 +107,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT)
@@ -128,7 +121,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC( MEGDNN_OVERRIDE_MATMUL_DESC(


+ 2
- 9
dnn/src/arm_common/matrix_mul/opr_impl.cpp View File

@@ -15,13 +15,6 @@
using namespace megdnn; using namespace megdnn;
using namespace arm_common; using namespace arm_common;


namespace {
uint8_t arm_common_algo_type_storage;
} // anonymous namespace

void* const MatrixMulImpl::sm_arm_common_algo_type =
&arm_common_algo_type_storage;

class MatrixMulImpl::AlgoPack : NonCopyableObj { class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x16 int8x8x16; AlgoInt8x8x16 int8x8x16;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -49,10 +42,10 @@ public:
all_algos.emplace_back(&f32_gemv_mk4); all_algos.emplace_back(&f32_gemv_mk4);
all_algos.emplace_back(&gevm); all_algos.emplace_back(&gevm);
} }
SmallVector<AlgoBase*> all_algos;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
}; };


SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack; static AlgoPack s_algo_pack;
auto&& algos = fallback::MatrixMulImpl::algo_pack(); auto&& algos = fallback::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),


+ 8
- 3
dnn/src/arm_common/matrix_mul/opr_impl.h View File

@@ -18,13 +18,18 @@ namespace arm_common {
class MatrixMulImpl : public fallback::MatrixMulImpl { class MatrixMulImpl : public fallback::MatrixMulImpl {
public: public:
using fallback::MatrixMulImpl::MatrixMulImpl; using fallback::MatrixMulImpl::MatrixMulImpl;

bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }


SmallVector<AlgoBase*> algo_pack() override;
class AlgoBase : public fallback::MatrixMulImpl::AlgoBase {
public:
AlgoBase() : fallback::MatrixMulImpl::AlgoBase() {
m_handle_type = Handle::HandleType::ARM_COMMON;
}
};

SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;


protected: protected:
static void* const sm_arm_common_algo_type;
class AlgoF32Gemv; // Arm_common F32 Gemv class AlgoF32Gemv; // Arm_common F32 Gemv
class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44 class AlgoF32GemvMK4; // Arm_common F32 Gemv NCHW44
class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv class AlgoInt8x8x32Gemv; // Arm_common Int8x8x32 Gemv


+ 1
- 1
dnn/src/armv7/conv_bias/opr_impl.cpp View File

@@ -32,7 +32,7 @@ public:
SmallVector<AlgoBase*> all_algos; SmallVector<AlgoBase*> all_algos;
}; };


SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack; static AlgoPack sl_algo_pack;
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); auto&& algos = arm_common::ConvBiasImpl::algo_pack();
//! TODO fused matmul bias is slower than matmul + elemwise in armv7 now, //! TODO fused matmul bias is slower than matmul + elemwise in armv7 now,


+ 7
- 2
dnn/src/armv7/conv_bias/opr_impl.h View File

@@ -18,11 +18,16 @@ namespace armv7 {
class ConvBiasImpl : public arm_common::ConvBiasImpl { class ConvBiasImpl : public arm_common::ConvBiasImpl {
public: public:
using arm_common::ConvBiasImpl::ConvBiasImpl; using arm_common::ConvBiasImpl::ConvBiasImpl;
class AlgoBase : public arm_common::ConvBiasImpl::AlgoBase {
public:
AlgoBase() : arm_common::ConvBiasImpl::AlgoBase() {
m_handle_type = Handle::HandleType::ARMV7;
}
};


SmallVector<AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;


protected: protected:

const char* get_algorithm_set_name() const override; const char* get_algorithm_set_name() const override;


private: private:


+ 0
- 14
dnn/src/armv7/matrix_mul/algos.h View File

@@ -26,7 +26,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -37,7 +36,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -48,7 +46,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4)
}; };
@@ -61,7 +58,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase { class MatrixMulImpl::AlgoF16MK8_4x8 final : public AlgoBase {
@@ -71,7 +67,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
}; };
@@ -121,7 +116,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -133,7 +127,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -144,7 +137,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -156,7 +148,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -168,7 +159,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -180,7 +170,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -192,7 +181,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -203,7 +191,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
}; };
@@ -216,7 +203,6 @@ public:
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_arm_common_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };




+ 2
- 2
dnn/src/armv7/matrix_mul/opr_impl.cpp View File

@@ -44,7 +44,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8; AlgoInt16x16x32MK8_4x8 int16x16x32_mk8_4x8;


public: public:
SmallVector<MatrixMulImpl::AlgoBase*> all_algos;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;


AlgoPack() { AlgoPack() {
all_algos.emplace_back(&f32_gemv); all_algos.emplace_back(&f32_gemv);
@@ -73,7 +73,7 @@ public:
} }
}; };


SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack; static AlgoPack s_algo_pack;
auto algos = arm_common::MatrixMulImpl::algo_pack(); auto algos = arm_common::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),


+ 8
- 1
dnn/src/armv7/matrix_mul/opr_impl.h View File

@@ -18,7 +18,14 @@ namespace armv7 {
class MatrixMulImpl : public arm_common::MatrixMulImpl { class MatrixMulImpl : public arm_common::MatrixMulImpl {
public: public:
using arm_common::MatrixMulImpl::MatrixMulImpl; using arm_common::MatrixMulImpl::MatrixMulImpl;
SmallVector<AlgoBase*> algo_pack() override;
class AlgoBase : public arm_common::MatrixMulImpl::AlgoBase {
public:
AlgoBase() : arm_common::MatrixMulImpl::AlgoBase() {
m_handle_type = Handle::HandleType::ARMV7;
}
};

SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;


private: private:
class AlgoF32; // Armv7 F32 class AlgoF32; // Armv7 F32


+ 1
- 0
dnn/src/cuda/batch_conv_bias/algo.h View File

@@ -26,6 +26,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
BatchConvBiasForwardImpl* opr; BatchConvBiasForwardImpl* opr;
TensorLayout src_layout, filter_layout, bias_layout, z_layout, TensorLayout src_layout, filter_layout, bias_layout, z_layout,


+ 1
- 0
dnn/src/cuda/batched_matrix_mul/algo.h View File

@@ -28,6 +28,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
BatchedMatrixMulForwardImpl* opr; BatchedMatrixMulForwardImpl* opr;
TensorLayout layout_a, layout_b, layout_c; TensorLayout layout_a, layout_b, layout_c;


+ 1
- 0
dnn/src/cuda/conv_bias/algo.h View File

@@ -38,6 +38,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs : public conv_bias::BiasForwardSizeArgs { struct SizeArgs : public conv_bias::BiasForwardSizeArgs {
ConvBiasForwardImpl* opr; ConvBiasForwardImpl* opr;




+ 1
- 0
dnn/src/cuda/convolution/backward_data/algo.h View File

@@ -28,6 +28,7 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm {
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
HandleImpl *handle; HandleImpl *handle;
CanonizedFilterMeta filter_meta; CanonizedFilterMeta filter_meta;


+ 1
- 0
dnn/src/cuda/convolution/backward_filter/algo.h View File

@@ -28,6 +28,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm {
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
HandleImpl *handle; HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout, *grad_layout; const TensorLayout *src_layout, *diff_layout, *grad_layout;


+ 1
- 0
dnn/src/cuda/convolution3d/backward_data/algo.h View File

@@ -28,6 +28,7 @@ class Convolution3DBackwardDataImpl::AlgoBase: public Algorithm {
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
HandleImpl *handle; HandleImpl *handle;
CanonizedFilterMeta filter_meta; CanonizedFilterMeta filter_meta;


+ 3
- 2
dnn/src/cuda/convolution3d/backward_filter/algo.h View File

@@ -22,6 +22,7 @@ class Convolution3DBackwardFilterImpl::AlgoBase: public Algorithm {
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
HandleImpl *handle; HandleImpl *handle;
const TensorLayout *src_layout, *diff_layout; const TensorLayout *src_layout, *diff_layout;
@@ -128,8 +129,8 @@ class Convolution3DBackwardFilterImpl::AlgoInplaceMatmul final: public AlgoBase
const char* name() const override { const char* name() const override {
return "INPLACE_MATMUL"; return "INPLACE_MATMUL";
} }
bool is_reproducible() const override {
return false;
bool is_reproducible() const override {
return false;
} }
}; };




+ 3
- 2
dnn/src/cuda/convolution3d/forward/algo.h View File

@@ -34,6 +34,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs: public convolution3d::ForwardSizeArgs { struct SizeArgs: public convolution3d::ForwardSizeArgs {
Convolution3DForwardImpl *opr; Convolution3DForwardImpl *opr;


@@ -42,11 +43,11 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm {
desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); desc.set(*src_layout, filter_meta, *dst_layout, opr->param());
} }
SizeArgs(Convolution3DForwardImpl *opr, SizeArgs(Convolution3DForwardImpl *opr,
const TensorLayout &src,
const TensorLayout &src,
const TensorLayout &filter, const TensorLayout &filter,
const TensorLayout &dst); const TensorLayout &dst);
SizeArgs(Convolution3DForwardImpl *opr, SizeArgs(Convolution3DForwardImpl *opr,
const TensorLayout &src,
const TensorLayout &src,
const CanonizedFilterMeta &filter, const CanonizedFilterMeta &filter,
const TensorLayout &dst); const TensorLayout &dst);
}; };


+ 1
- 0
dnn/src/cuda/deformable_conv/bwd_data/algo.h View File

@@ -26,6 +26,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
DeformableConvBackwardDataImpl* opr; DeformableConvBackwardDataImpl* opr;
HandleImpl* handle; HandleImpl* handle;


+ 1
- 0
dnn/src/cuda/deformable_conv/bwd_flt/algo.h View File

@@ -26,6 +26,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
DeformableConvBackwardFilterImpl* opr; DeformableConvBackwardFilterImpl* opr;
HandleImpl* handle; HandleImpl* handle;


+ 1
- 0
dnn/src/cuda/deformable_conv/fwd/algo.h View File

@@ -24,6 +24,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
DeformableConvForwardImpl* opr; DeformableConvForwardImpl* opr;
HandleImpl* handle; HandleImpl* handle;


+ 1
- 0
dnn/src/cuda/local_share/backward_data/algo.h View File

@@ -25,6 +25,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
LocalShareBackwardDataImpl* opr; LocalShareBackwardDataImpl* opr;
TensorLayout filter_layout, diff_layout, grad_layout; TensorLayout filter_layout, diff_layout, grad_layout;


+ 1
- 0
dnn/src/cuda/local_share/backward_filter/algo.h View File

@@ -25,6 +25,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
LocalShareBackwardFilterImpl* opr; LocalShareBackwardFilterImpl* opr;
TensorLayout src_layout, diff_layout, grad_layout; TensorLayout src_layout, diff_layout, grad_layout;


+ 1
- 0
dnn/src/cuda/local_share/forward/algo.h View File

@@ -25,6 +25,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
LocalShareForwardImpl* opr; LocalShareForwardImpl* opr;
TensorLayout src_layout, filter_layout, dst_layout; TensorLayout src_layout, filter_layout, dst_layout;


+ 3
- 2
dnn/src/cuda/matrix_mul/algos.h View File

@@ -32,13 +32,14 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::CUDA; }
struct SizeArgs { struct SizeArgs {
MatrixMulForwardImpl* opr; MatrixMulForwardImpl* opr;
TensorLayout layout_a, layout_b, layout_c; TensorLayout layout_a, layout_b, layout_c;


std::string to_string() const; std::string to_string() const;
SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A, const TensorLayout& B,
const TensorLayout& C);
SizeArgs(MatrixMulForwardImpl* opr, const TensorLayout& A,
const TensorLayout& B, const TensorLayout& C);


bool can_be_treated_as_int8x8x32() const { bool can_be_treated_as_int8x8x32() const {
return layout_a.dtype.enumv() == layout_b.dtype.enumv() && return layout_a.dtype.enumv() == layout_b.dtype.enumv() &&


+ 3
- 0
dnn/src/fallback/conv_bias/opr_impl.h View File

@@ -213,6 +213,9 @@ public:


class AlgoBase : public Algorithm { class AlgoBase : public Algorithm {
public: public:
AlgoBase() : Algorithm() {
m_handle_type = Handle::HandleType::FALLBACK;
}
virtual ~AlgoBase() = default; virtual ~AlgoBase() = default;
virtual bool usable( virtual bool usable(
const NCBKernSizeParam& param, const NCBKernSizeParam& param,


+ 0
- 4
dnn/src/fallback/convolution/algos.h View File

@@ -141,8 +141,6 @@ public:
return get_kimpl(m_algorithm, param); return get_kimpl(m_algorithm, param);
} }


void* type() const override { return sm_fallback_conv_algo_type; }

//! select matmul to the highest preference //! select matmul to the highest preference
bool is_preferred(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override;


@@ -168,7 +166,6 @@ public:
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;
void* type() const override { return sm_fallback_deconv_algo_type; }
}; };


class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase { class ConvolutionBackwardDataImpl::AlgoMatrixMul final : public AlgoBase {
@@ -181,7 +178,6 @@ public:
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;
ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*,
const NCBKernSizeParam&) const override; const NCBKernSizeParam&) const override;
void* type() const override { return sm_fallback_deconv_algo_type; }
}; };


} // namespace fallback } // namespace fallback


+ 2
- 10
dnn/src/fallback/convolution/opr_impl.cpp View File

@@ -37,8 +37,6 @@ class NaiveConvolutionBackwardData final
const char* name() const override { return "NCBD"; } const char* name() const override { return "NCBD"; }
}; };
NaiveConvolutionBackwardData naive_conv_backward_data; NaiveConvolutionBackwardData naive_conv_backward_data;
uint8_t fallback_deconv_algo_type_storage;
uint8_t fallback_conv_algo_type_storage;


template <typename T> template <typename T>
void incr_ptr(T*& dst, ptrdiff_t delta) { void incr_ptr(T*& dst, ptrdiff_t delta) {
@@ -69,9 +67,6 @@ public:
SmallVector<AlgoBase*> all_algos; SmallVector<AlgoBase*> all_algos;
}; };


void* const ConvolutionImpl::sm_fallback_conv_algo_type =
&fallback_conv_algo_type_storage;

SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() { SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
static AlgoPack sl_algo_pack; static AlgoPack sl_algo_pack;
return sl_algo_pack.all_algos; return sl_algo_pack.all_algos;
@@ -412,9 +407,6 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {


/* ===================== ConvolutionBackwardData ===================== */ /* ===================== ConvolutionBackwardData ===================== */


void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =
&fallback_deconv_algo_type_storage;

struct ConvolutionBackwardDataImpl::AlgoPack { struct ConvolutionBackwardDataImpl::AlgoPack {
AlgoDirect direct; AlgoDirect direct;
AlgoMatrixMul matmul; AlgoMatrixMul matmul;
@@ -630,7 +622,7 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic_with_ncb(
size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace(
Algorithm* algo, const NCBKernSizeParam& param) { Algorithm* algo, const NCBKernSizeParam& param) {
megdnn_assert(param.filter_meta.group == 1); megdnn_assert(param.filter_meta.group == 1);
if (algo->type() == sm_fallback_deconv_algo_type) {
if (algo->handle_type() == Handle::HandleType::FALLBACK) {
return static_cast<AlgoBase*>(algo)->get_workspace(this, param); return static_cast<AlgoBase*>(algo)->get_workspace(this, param);
} }
megdnn_assert(algo == &naive_conv_backward_data); megdnn_assert(algo == &naive_conv_backward_data);
@@ -642,7 +634,7 @@ ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern(
Algorithm* algo, const NCBKernSizeParam& param) { Algorithm* algo, const NCBKernSizeParam& param) {
megdnn_assert(param.filter_meta.group == 1); megdnn_assert(param.filter_meta.group == 1);


if (algo->type() == sm_fallback_deconv_algo_type) {
if (algo->handle_type() == Handle::HandleType::FALLBACK) {
return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param); return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param);
} }




+ 6
- 4
dnn/src/fallback/convolution/opr_impl.h View File

@@ -177,8 +177,6 @@ public:
} }
}; };


static void* const sm_fallback_conv_algo_type;

/** /**
* \brief Kernel run time id, This information is used for getting the * \brief Kernel run time id, This information is used for getting the
* work data * work data
@@ -197,6 +195,9 @@ public:


class AlgoBase : public Algorithm { class AlgoBase : public Algorithm {
public: public:
AlgoBase() : Algorithm() {
m_handle_type = Handle::HandleType::FALLBACK;
}
virtual ~AlgoBase() = default; virtual ~AlgoBase() = default;
virtual bool usable(const NCBKernSizeParam& param, virtual bool usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const = 0; AlgoSelectionStrategy) const = 0;
@@ -407,13 +408,14 @@ protected:
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
bool reproducible = false); bool reproducible = false);


static void* const sm_fallback_deconv_algo_type;

class AlgoBase : public Algorithm { class AlgoBase : public Algorithm {
protected: protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() {
m_handle_type = Handle::HandleType::FALLBACK;
}
virtual bool usable(ConvolutionBackwardDataImpl* opr, virtual bool usable(ConvolutionBackwardDataImpl* opr,
const NCBKernSizeParam& param) const = 0; const NCBKernSizeParam& param) const = 0;
virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr, virtual size_t get_workspace(ConvolutionBackwardDataImpl* opr,


+ 1
- 0
dnn/src/fallback/matrix_mul/opr_impl.h View File

@@ -103,6 +103,7 @@ public:
} }


public: public:
AlgoBase() { m_handle_type = Handle::HandleType::FALLBACK; }
enum class AlgoSet : uint32_t { enum class AlgoSet : uint32_t {
ALGO_TYPE_GEMM = 0, ALGO_TYPE_GEMM = 0,
ALGO_TYPE_GEMV = 1, ALGO_TYPE_GEMV = 1,


+ 5
- 4
dnn/src/rocm/batched_matrix_mul/opr_impl.cpp View File

@@ -6,10 +6,11 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "hcc_detail/hcc_defs_prologue.h"
#include "./opr_impl.h" #include "./opr_impl.h"
#include "hcc_detail/hcc_defs_prologue.h"


#include "src/common/utils.cuh" #include "src/common/utils.cuh"
#include "src/rocm/handle.h" #include "src/rocm/handle.h"
@@ -92,8 +93,8 @@ void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
static_cast<const rocblas_half*>(A.raw_ptr), static_cast<const rocblas_half*>(A.raw_ptr),
A.layout.stride[1], A.layout.stride[0], A.layout.stride[1], A.layout.stride[0],
reinterpret_cast<const rocblas_half*>(zero_half), reinterpret_cast<const rocblas_half*>(zero_half),
static_cast<rocblas_half*>(C.raw_ptr),
C.layout.stride[1], C.layout.stride[0], batch));
static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[1],
C.layout.stride[0], batch));


}; };
#endif #endif


+ 1
- 0
dnn/src/rocm/convolution/backward_data/algo.h View File

@@ -25,6 +25,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs { struct SizeArgs {
HandleImpl* handle; HandleImpl* handle;
CanonizedFilterMeta filter_meta; CanonizedFilterMeta filter_meta;


+ 1
- 0
dnn/src/rocm/convolution/backward_filter/algo.h View File

@@ -26,6 +26,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs { struct SizeArgs {
HandleImpl* handle; HandleImpl* handle;
const TensorLayout *src_layout, *diff_layout; const TensorLayout *src_layout, *diff_layout;


+ 1
- 0
dnn/src/rocm/convolution/forward/algo.h View File

@@ -32,6 +32,7 @@ protected:
~AlgoBase() = default; ~AlgoBase() = default;


public: public:
AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ROCM; }
struct SizeArgs : public convolution::ForwardSizeArgs { struct SizeArgs : public convolution::ForwardSizeArgs {
ConvolutionForwardImpl* opr; ConvolutionForwardImpl* opr;




+ 0
- 7
dnn/src/x86/conv_bias/f32/algos.h View File

@@ -47,8 +47,6 @@ public:
return get_kimpls(param); return get_kimpls(param);
} }


void* type() const override;

ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
@@ -84,8 +82,6 @@ public:
return get_kimpls(param); return get_kimpls(param);
} }


void* type() const override;

ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
} }
@@ -103,7 +99,6 @@ public:
} }
return m_name.c_str(); return m_name.c_str();
} }
void* type() const override;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
}; };


@@ -119,7 +114,6 @@ public:
} }
return m_name.c_str(); return m_name.c_str();
} }
void* type() const override;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
}; };


@@ -161,7 +155,6 @@ public:
}; };
return {{kern, {1_z, 1_z, 1_z}}}; return {{kern, {1_z, 1_z, 1_z}}};
} }
void* type() const override;


ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};


+ 0
- 7
dnn/src/x86/conv_bias/int8/algos.h View File

@@ -32,7 +32,6 @@ public:
const NCBKernSizeParam& param) const override { const NCBKernSizeParam& param) const override {
return get_kimpls(param); return get_kimpls(param);
} }
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override;


ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
@@ -57,7 +56,6 @@ public:
const NCBKernSizeParam& param) const override { const NCBKernSizeParam& param) const override {
return get_kimpls(param); return get_kimpls(param);
} }
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override;


ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
@@ -82,7 +80,6 @@ public:
const NCBKernSizeParam& param) const override { const NCBKernSizeParam& param) const override {
return get_kimpls(param); return get_kimpls(param);
} }
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override;


ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
@@ -107,7 +104,6 @@ public:
const NCBKernSizeParam& param) const override { const NCBKernSizeParam& param) const override {
return get_kimpls(param); return get_kimpls(param);
} }
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override;


ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
@@ -148,7 +144,6 @@ public:
}; };
return {{kern, {group, n, 1_z}}}; return {{kern, {group, n, 1_z}}};
} }
void* type() const override;
bool is_preferred(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override;


ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
@@ -179,8 +174,6 @@ public:
//! select matmul to the highest preference //! select matmul to the highest preference
bool is_preferred(const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override;


void* type() const override;

ConvAlgoTypePack get_algo_type() const override { ConvAlgoTypePack get_algo_type() const override {
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
} }


+ 10
- 50
dnn/src/x86/conv_bias/opr_impl.cpp View File

@@ -22,54 +22,14 @@


using namespace megdnn; using namespace megdnn;
using namespace x86; using namespace x86;

namespace { namespace {
uint8_t x86_algo_type_storage;
void* x86_algo_type = &x86_algo_type_storage;
} // anonymous namespace
#if MEGDNN_X86_WITH_MKL_DNN
void* ConvBiasImpl::AlgoMkldnnQint8::type() const {
return x86_algo_type;
}
void* ConvBiasImpl::AlgoMkldnnMatmulQint8::type() const {
return x86_algo_type;
}
void* ConvBiasImpl::AlgoMkldnnConv::type() const {
return x86_algo_type;
}
#endif

void* ConvBiasImpl::AlgoDirect::type() const {
return x86_algo_type;
}

void* ConvBiasImpl::AlgoDirectStride2::type() const {
return x86_algo_type;
}


void* ConvBiasImpl::AlgoDirectAvx2Stride1Int8::type() const {
return x86_algo_type;
bool is_fallback_or_naive(const detail::Algorithm* algo) {
return algo->handle_type() == Handle::HandleType::NAIVE ||
algo->handle_type() == Handle::HandleType::FALLBACK;
} }


void* ConvBiasImpl::AlgoFP32WinogradF63_8x8::type() const {
return x86_algo_type;
}

void* ConvBiasImpl::AlgoFP32WinogradF23_8x8::type() const {
return x86_algo_type;
}

void* ConvBiasImpl::AlgoAVX2DirectConvStride2::type() const {
return x86_algo_type;
}

void* ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8::type() const {
return x86_algo_type;
}

void* ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8::type() const {
return x86_algo_type;
}
} // anonymous namespace


class ConvBiasImpl::AlgoPack : NonCopyableObj { class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDirect stride1_direct; AlgoDirect stride1_direct;
@@ -88,8 +48,8 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {


public: public:
AlgoPack() { AlgoPack() {
//! FIXME: preference to use mkldnn algo on VNNI devices
//! But now mkldnn algo preference issue with NCHW->NHWC->NCHW
//! FIXME: preference to use mkldnn algo on VNNI devices
//! But now mkldnn algo preference issue with NCHW->NHWC->NCHW
#if MEGDNN_X86_WITH_MKL_DNN #if MEGDNN_X86_WITH_MKL_DNN
//! Create the mkldnn algo //! Create the mkldnn algo
all_algos.emplace_back(&mkldnn_conv_fp32); all_algos.emplace_back(&mkldnn_conv_fp32);
@@ -108,7 +68,7 @@ public:
auto&& matmul_algos = auto&& matmul_algos =
static_cast<MatrixMulImpl*>(matmul_opr)->algo_pack(); static_cast<MatrixMulImpl*>(matmul_opr)->algo_pack();
for (auto&& algo : matmul_algos) { for (auto&& algo : matmul_algos) {
if (algo->type() == nullptr)
if (is_fallback_or_naive(algo))
continue; continue;
for (uint32_t tile_size : {8, 16, 24}) { for (uint32_t tile_size : {8, 16, 24}) {
refhold.emplace_back(new AlgoFP32WinogradF63_8x8( refhold.emplace_back(new AlgoFP32WinogradF63_8x8(
@@ -126,7 +86,7 @@ public:
SmallVector<AlgoBase*> winograd_algos; SmallVector<AlgoBase*> winograd_algos;
}; };


SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static AlgoPack sl_algo_pack; static AlgoPack sl_algo_pack;
auto&& algos = fallback::ConvBiasImpl::algo_pack(); auto&& algos = fallback::ConvBiasImpl::algo_pack();
algos.insert(algos.begin(), sl_algo_pack.all_algos.begin(), algos.insert(algos.begin(), sl_algo_pack.all_algos.begin(),
@@ -176,8 +136,8 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
!chanwise_avx2_stride2_qint8_usable_preferred(param)); !chanwise_avx2_stride2_qint8_usable_preferred(param));
} }


SmallVector<AlgoCategory>
ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const {
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
const NCBKernSizeParam& param) const {
auto IC = param.filter_meta.icpg; auto IC = param.filter_meta.icpg;
auto OC = param.filter_meta.ocpg; auto OC = param.filter_meta.ocpg;
auto FH = param.filter_meta.spatial[0]; auto FH = param.filter_meta.spatial[0];


+ 7
- 2
dnn/src/x86/conv_bias/opr_impl.h View File

@@ -20,10 +20,15 @@ namespace x86 {
class ConvBiasImpl : public fallback::ConvBiasImpl { class ConvBiasImpl : public fallback::ConvBiasImpl {
public: public:
using fallback::ConvBiasImpl::ConvBiasImpl; using fallback::ConvBiasImpl::ConvBiasImpl;
using FallbackConvBiasImpl = fallback::ConvBiasImpl;
class AlgoBase : public fallback::ConvBiasImpl::AlgoBase {
public:
AlgoBase() : fallback::ConvBiasImpl::AlgoBase() {
m_handle_type = Handle::HandleType::X86;
}
};


bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }
SmallVector<AlgoBase*> algo_pack() override;
SmallVector<fallback::ConvBiasImpl::AlgoBase*> algo_pack() override;
SmallVector<AlgoCategory> suggest_algo_category_order( SmallVector<AlgoCategory> suggest_algo_category_order(
const NCBKernSizeParam& param) const override; const NCBKernSizeParam& param) const override;




+ 0
- 10
dnn/src/x86/matrix_mul/algos.h View File

@@ -25,7 +25,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
}; };
@@ -38,7 +37,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::ONLY_PACKA; } PackMode packmode() const override { return PackMode::ONLY_PACKA; }
kern_naked_t get_kern_naked(const KernSizeParam&) const override; kern_naked_t get_kern_naked(const KernSizeParam&) const override;
void pack_A(const KernParam& kern_param, void* out, size_t index, void pack_A(const KernParam& kern_param, void* out, size_t index,
@@ -60,7 +58,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -71,7 +68,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -86,7 +82,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
@@ -102,7 +97,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
bool preferred(const KernSizeParam&) const override; bool preferred(const KernSizeParam&) const override;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
@@ -114,7 +108,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };


@@ -125,7 +118,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8) MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8)
}; };
@@ -138,7 +130,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
}; };
#endif #endif
@@ -151,7 +142,6 @@ public:
bool usable(const KernSizeParam&) const override; bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override { return 0; } size_t get_workspace(const KernSizeParam&) const override { return 0; }
kern_t get_kern(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override;
void* type() const override { return sm_x86_algo_type; }
PackMode packmode() const override { return PackMode::NO_PACK; } PackMode packmode() const override { return PackMode::NO_PACK; }
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
}; };


+ 2
- 8
dnn/src/x86/matrix_mul/opr_impl.cpp View File

@@ -16,12 +16,6 @@
using namespace megdnn; using namespace megdnn;
using namespace x86; using namespace x86;


namespace {
uint8_t x86_algo_type_storage;
} // anonymous namespace

void* const MatrixMulImpl::sm_x86_algo_type = &x86_algo_type_storage;

class MatrixMulImpl::AlgoPack : NonCopyableObj { class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoF32Blas f32blas; AlgoF32Blas f32blas;


@@ -62,10 +56,10 @@ public:
all_algos.emplace_back(&f32mkl_packa); all_algos.emplace_back(&f32mkl_packa);
#endif #endif
} }
SmallVector<AlgoBase*> all_algos;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> all_algos;
}; };


SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
SmallVector<fallback::MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
static AlgoPack s_algo_pack; static AlgoPack s_algo_pack;
auto&& algos = fallback::MatrixMulImpl::algo_pack(); auto&& algos = fallback::MatrixMulImpl::algo_pack();
algos.insert(algos.begin(), s_algo_pack.all_algos.begin(), algos.insert(algos.begin(), s_algo_pack.all_algos.begin(),


+ 7
- 2
dnn/src/x86/matrix_mul/opr_impl.h View File

@@ -33,13 +33,18 @@ namespace x86 {
class MatrixMulImpl : public fallback::MatrixMulImpl { class MatrixMulImpl : public fallback::MatrixMulImpl {
public: public:
using fallback::MatrixMulImpl::MatrixMulImpl; using fallback::MatrixMulImpl::MatrixMulImpl;
class AlgoBase : public fallback::MatrixMulImpl::AlgoBase {
public:
AlgoBase() : fallback::MatrixMulImpl::AlgoBase() {
m_handle_type = Handle::HandleType::X86;
}
};


bool is_thread_safe() const override { return true; } bool is_thread_safe() const override { return true; }


SmallVector<AlgoBase*> algo_pack() override;
SmallVector<fallback::MatrixMulImpl::AlgoBase*> algo_pack() override;


protected: protected:
static void* const sm_x86_algo_type;
class AlgoF32Blas; class AlgoF32Blas;
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
class AlgoF32MKLPackA; class AlgoF32MKLPackA;


Loading…
Cancel
Save