You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algos.h 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. /**
  2. * \file dnn/src/arm_common/matrix_mul/algos.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "src/arm_common/matrix_mul/opr_impl.h"
  13. #include "src/fallback/matrix_mul/gemm_common.h"
  14. namespace megdnn {
  15. namespace arm_common {
  // Matrix-multiplication algorithm entry nested in MatrixMulImpl and derived
  // from AlgoBase; declared final. Its name() string is "ARM_COMMON_INT8X8X16".
  // usable(), get_workspace() and get_kern() are only declared in this header;
  // their definitions live elsewhere.
  16. class MatrixMulImpl::AlgoInt8x8x16 final : public AlgoBase {
  17. public:
  // Always reports the REPRODUCIBLE attribute.
  18. AlgoAttribute attribute() const override {
  19. return AlgoAttribute::REPRODUCIBLE;
  20. }
  // Identifier string returned for this algorithm.
  21. const char* name() const override { return "ARM_COMMON_INT8X8X16"; }
  22. bool usable(const KernSizeParam&) const override;
  // Unlike the GEMV classes in this file, the workspace size is not an inline
  // constant here; it is computed by an out-of-line definition.
  23. size_t get_workspace(const KernSizeParam&) const override;
  24. kern_t get_kern(const KernSizeParam&) const override;
  // Always reports PackMode::NO_PACK.
  25. PackMode packmode() const override { return PackMode::NO_PACK; }
  // NOTE(review): both macros are defined outside this file. Presumably the
  // first attaches a matmul description (four integers, data type, format) and
  // the second tags the class with its algorithm type - confirm against the
  // macro definitions.
  26. MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
  27. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X16)
  28. };
  // Algorithm entry nested in MatrixMulImpl and derived from AlgoBase whose
  // name() string is "ARM_COMMON_INT8X8X32_GEMV". It reports the GEMV
  // algorithm set, NO_PACK pack mode and a zero-byte workspace. usable(),
  // preferred() and get_kern() are only declared here.
  29. class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
  30. public:
  // Always reports the REPRODUCIBLE attribute.
  31. AlgoAttribute attribute() const override {
  32. return AlgoAttribute::REPRODUCIBLE;
  33. }
  34. const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; }
  35. bool usable(const KernSizeParam&) const override;
  36. bool preferred(const KernSizeParam&) const override;
  // No workspace is requested: the result is 0 regardless of the argument.
  37. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  38. kern_t get_kern(const KernSizeParam&) const override;
  // Places this algorithm in the ALGO_TYPE_GEMV set.
  39. AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
  40. PackMode packmode() const override { return PackMode::NO_PACK; }
  // NOTE(review): macros defined outside this file; presumably they attach the
  // matmul description (QINT8X8X32 data, DEFAULT format) and the algorithm
  // type tag - confirm against the macro definitions.
  41. MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
  42. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV)
  43. };
  // Algorithm entry nested in MatrixMulImpl and derived from AlgoBase whose
  // name() string is "ARM_COMMON_INT8X8X32_GEMV_MK4". Its member list mirrors
  // AlgoInt8x8x32Gemv; the visible differences are the name string, the MK4
  // format argument and the algorithm type tag.
  44. class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
  45. public:
  // Always reports the REPRODUCIBLE attribute.
  46. AlgoAttribute attribute() const override {
  47. return AlgoAttribute::REPRODUCIBLE;
  48. }
  49. const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
  // usable(), preferred() and get_kern() are declared only; their definitions
  // are not in this header.
  50. bool usable(const KernSizeParam&) const override;
  51. bool preferred(const KernSizeParam&) const override;
  // No workspace is requested: the result is 0 regardless of the argument.
  52. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  53. kern_t get_kern(const KernSizeParam&) const override;
  // Places this algorithm in the ALGO_TYPE_GEMV set.
  54. AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
  55. PackMode packmode() const override { return PackMode::NO_PACK; }
  // NOTE(review): macros defined outside this file; presumably MK4 names the
  // matrix format this algorithm accepts - confirm against the macro
  // definitions.
  56. MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
  57. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4)
  58. };
  59. #if __ARM_FEATURE_DOTPROD
  // Algorithm entry nested in MatrixMulImpl and derived from AlgoBase whose
  // name() string is "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT". The class is only
  // compiled when __ARM_FEATURE_DOTPROD evaluates to non-zero (see the
  // enclosing #if / #endif pair).
  60. class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
  61. public:
  // Always reports the REPRODUCIBLE attribute.
  62. AlgoAttribute attribute() const override {
  63. return AlgoAttribute::REPRODUCIBLE;
  64. }
  65. const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
  // usable(), preferred() and get_kern() are declared only; their definitions
  // are not in this header.
  66. bool usable(const KernSizeParam&) const override;
  67. bool preferred(const KernSizeParam&) const override;
  // No workspace is requested: the result is 0 regardless of the argument.
  68. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  69. kern_t get_kern(const KernSizeParam&) const override;
  // Places this algorithm in the ALGO_TYPE_GEMV set.
  70. AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
  71. PackMode packmode() const override { return PackMode::NO_PACK; }
  // NOTE(review): macros defined outside this file; presumably MK4_DOT names
  // the matrix format used with the dot-product kernels - confirm against the
  // macro definitions.
  72. MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT)
  73. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4_DOT)
  74. };
  75. #endif
  // Algorithm entry nested in MatrixMulImpl and derived from AlgoBase whose
  // name() string is "ARM_COMMON_F32_GEMV". It reports the GEMV algorithm set,
  // NO_PACK pack mode and a zero-byte workspace.
  // NOTE(review): unlike the other classes in this file, there is no
  // MEGDNN_DECL_ALGO_TYPE line here; together with the protected destructor
  // this suggests the class is meant to be subclassed, with the subclass
  // supplying the type tag - confirm before adding one.
  76. class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
  77. protected:
  // Protected, defaulted destructor: only this class and classes derived from
  // it can destroy an AlgoF32Gemv object directly.
  78. ~AlgoF32Gemv() = default;
  79. public:
  // Always reports the REPRODUCIBLE attribute.
  80. AlgoAttribute attribute() const override {
  81. return AlgoAttribute::REPRODUCIBLE;
  82. }
  83. const char* name() const override { return "ARM_COMMON_F32_GEMV"; }
  // usable(), preferred() and get_kern() are declared only; their definitions
  // are not in this header.
  84. bool usable(const KernSizeParam&) const override;
  85. bool preferred(const KernSizeParam&) const override;
  // No workspace is requested: the result is 0 regardless of the argument.
  86. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  87. kern_t get_kern(const KernSizeParam&) const override;
  // Places this algorithm in the ALGO_TYPE_GEMV set.
  88. AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
  89. PackMode packmode() const override { return PackMode::NO_PACK; }
  // NOTE(review): macro defined outside this file; presumably it attaches the
  // matmul description (FLOAT32 data, DEFAULT format) - confirm against the
  // macro definition.
  90. MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
  91. };
  // Algorithm entry nested in MatrixMulImpl and derived from AlgoBase whose
  // name() string is "ARM_COMMON_F32_GEMV_MK4". It reports the GEMV algorithm
  // set, NO_PACK pack mode and a zero-byte workspace.
  92. class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
  93. public:
  // Always reports the REPRODUCIBLE attribute.
  94. AlgoAttribute attribute() const override {
  95. return AlgoAttribute::REPRODUCIBLE;
  96. }
  97. const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
  // usable(), preferred() and get_kern() are declared only; their definitions
  // are not in this header.
  98. bool usable(const KernSizeParam&) const override;
  99. bool preferred(const KernSizeParam&) const override;
  // No workspace is requested: the result is 0 regardless of the argument.
  100. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  101. kern_t get_kern(const KernSizeParam&) const override;
  // Places this algorithm in the ALGO_TYPE_GEMV set.
  102. AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
  103. PackMode packmode() const override { return PackMode::NO_PACK; }
  // NOTE(review): the leading integer arguments (4, 1, 1, 4) differ from the
  // (8, 16, 1, ...) used by most other classes in this file; their meaning is
  // set by the macro definition, which is not visible here - confirm the
  // values are intended.
  104. MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
  105. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F32_GEMV_MK4)
  106. };
  107. #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
  // Algorithm entry nested in MatrixMulImpl and derived from AlgoBase whose
  // name() string is "ARM_COMMON_F16_GEMV". The class is only compiled when
  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC evaluates to non-zero (see the
  // enclosing #if / #endif pair).
  108. class MatrixMulImpl::AlgoF16Gemv : public AlgoBase {
  109. public:
  // Always reports the REPRODUCIBLE attribute.
  110. AlgoAttribute attribute() const override {
  111. return AlgoAttribute::REPRODUCIBLE;
  112. }
  113. const char* name() const override { return "ARM_COMMON_F16_GEMV"; }
  // usable(), preferred() and get_kern() are declared only; their definitions
  // are not in this header.
  114. bool usable(const KernSizeParam&) const override;
  115. bool preferred(const KernSizeParam&) const override;
  // No workspace is requested: the result is 0 regardless of the argument.
  116. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  117. kern_t get_kern(const KernSizeParam&) const override;
  // Places this algorithm in the ALGO_TYPE_GEMV set.
  118. AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
  119. PackMode packmode() const override { return PackMode::NO_PACK; }
  // NOTE(review): macros defined outside this file; presumably they attach the
  // matmul description (FLOAT16 data, DEFAULT format) and the algorithm type
  // tag - confirm against the macro definitions.
  120. MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT)
  121. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_F16_GEMV)
  122. };
  123. #endif
  // Algorithm entry nested in MatrixMulImpl and derived from AlgoBase whose
  // name() string is "ARM_COMMON_GEVM". It reports NO_PACK pack mode and a
  // zero-byte workspace. Although the class is named Gevm, algoset() returns
  // ALGO_TYPE_GEMV, the same set as the Gemv classes in this file.
  124. class MatrixMulImpl::AlgoGevm : public AlgoBase {
  125. public:
  // Always reports the REPRODUCIBLE attribute.
  126. AlgoAttribute attribute() const override {
  127. return AlgoAttribute::REPRODUCIBLE;
  128. }
  129. const char* name() const override { return "ARM_COMMON_GEVM"; }
  // usable(), preferred() and get_kern() are declared only; their definitions
  // are not in this header.
  130. bool usable(const KernSizeParam&) const override;
  131. bool preferred(const KernSizeParam&) const override;
  // No workspace is requested: the result is 0 regardless of the argument.
  132. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  133. kern_t get_kern(const KernSizeParam&) const override;
  134. AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
  135. PackMode packmode() const override { return PackMode::NO_PACK; }
  // The data-type argument below is the bitwise OR of the FLOAT16, FLOAT32 and
  // QINT8X8X32 enumerators, combined as uint32_t values and cast back to
  // AlgoDataType.
  // NOTE(review): presumably AlgoDataType is a bit-flag enum and the OR means
  // this one algorithm covers all three data types - confirm against the enum
  // and macro definitions.
  136. MEGDNN_OVERRIDE_MATMUL_DESC(
  137. 1, 1, 1, 4,
  138. static_cast<AlgoDataType>(
  139. static_cast<uint32_t>(AlgoDataType::FLOAT16) |
  140. static_cast<uint32_t>(AlgoDataType::FLOAT32) |
  141. static_cast<uint32_t>(AlgoDataType::QINT8X8X32)),
  142. DEFAULT)
  143. MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_GEVM)
  144. };
  145. } // namespace arm_common
  146. } // namespace megdnn
  147. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。