You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

algos.h 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #pragma once
  2. #include "megdnn/thin/small_vector.h"
  3. #include "src/fallback/conv_bias/opr_impl.h"
  4. #include "src/fallback/matrix_mul/opr_impl.h"
  5. namespace megdnn {
  6. namespace fallback {
  7. class ConvBiasImpl::AlgoNaive final : public AlgoBase {
  8. public:
  9. AlgoAttribute attribute() const override {
  10. return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
  11. }
  12. const char* name() const override { return "FALLBACK_NAIVE"; }
  13. bool usable(
  14. const NCBKernSizeParam& param,
  15. AlgoSelectionStrategy algo_selection_strategy) const override;
  16. size_t get_workspace(const NCBKernSizeParam& param) const override;
  17. SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
  18. ConvAlgoTypePack get_algo_type() const override {
  19. auto support_data_type = static_cast<AlgoDataType>(
  20. static_cast<uint32_t>(AlgoDataType::FLOAT16) |
  21. static_cast<uint32_t>(AlgoDataType::FLOAT32) |
  22. static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
  23. static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
  24. static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
  25. return {support_data_type, AlgoCategory::NAIVE};
  26. }
  27. MEGDNN_DECL_ALGO_TYPE(FB_NAIVE)
  28. };
  29. class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase {
  30. public:
  31. AlgoWinogradF32(MatrixMulImpl::AlgoBase* matmul_algo)
  32. : m_matmul_algo{matmul_algo} {}
  33. AlgoAttribute attribute() const override {
  34. return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
  35. }
  36. const char* name() const override {
  37. if (m_name.empty()) {
  38. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  39. ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()),
  40. {1, 2, UNIT_TILE_SIZE, 3});
  41. }
  42. return m_name.c_str();
  43. }
  44. bool usable(
  45. const NCBKernSizeParam& param,
  46. AlgoSelectionStrategy algo_selection_strategy) const override;
  47. size_t get_workspace(const NCBKernSizeParam& param) const override;
  48. SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
  49. ConvAlgoTypePack get_algo_type() const override {
  50. return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
  51. }
  52. MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_F32)
  53. private:
  54. MatrixMulImpl::AlgoBase* m_matmul_algo;
  55. mutable std::string m_name;
  56. constexpr size_t static UNIT_TILE_SIZE = 32;
  57. };
  58. class ConvBiasImpl::AlgoWinogradF32_4x4 final : public AlgoBase {
  59. public:
  60. AlgoWinogradF32_4x4(MatrixMulImpl::AlgoBase* matmul_algo)
  61. : m_matmul_algo{matmul_algo} {}
  62. AlgoAttribute attribute() const override {
  63. return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
  64. }
  65. const char* name() const override {
  66. if (m_name.empty()) {
  67. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  68. ssprintf("FALLBACK_WINOGRAD_F32-%s", m_matmul_algo->name()),
  69. {4, 2, UNIT_TILE_SIZE, 3});
  70. }
  71. return m_name.c_str();
  72. }
  73. bool usable(
  74. const NCBKernSizeParam& param,
  75. AlgoSelectionStrategy algo_selection_strategy) const override;
  76. size_t get_workspace(const NCBKernSizeParam& param) const override;
  77. SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
  78. ConvAlgoTypePack get_algo_type() const override {
  79. return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
  80. }
  81. MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_4X4_F32)
  82. private:
  83. MatrixMulImpl::AlgoBase* m_matmul_algo;
  84. mutable std::string m_name;
  85. constexpr size_t static UNIT_TILE_SIZE = 32;
  86. };
  87. class ConvBiasImpl::AlgoWinogradQS8 final : public AlgoBase {
  88. public:
  89. AlgoWinogradQS8(MatrixMulImpl::AlgoBase* matmul_algo)
  90. : m_matmul_algo{matmul_algo} {}
  91. AlgoAttribute attribute() const override {
  92. return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
  93. }
  94. const char* name() const override {
  95. if (m_name.empty()) {
  96. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  97. ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()),
  98. {1, 2, UNIT_TILE_SIZE, 3});
  99. }
  100. return m_name.c_str();
  101. }
  102. bool usable(
  103. const NCBKernSizeParam& param,
  104. AlgoSelectionStrategy algo_selection_strategy) const override;
  105. size_t get_workspace(const NCBKernSizeParam& param) const override;
  106. SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
  107. ConvAlgoTypePack get_algo_type() const override {
  108. return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
  109. }
  110. MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_QS8)
  111. private:
  112. MatrixMulImpl::AlgoBase* m_matmul_algo;
  113. mutable std::string m_name;
  114. constexpr size_t static UNIT_TILE_SIZE = 32;
  115. };
  116. class ConvBiasImpl::AlgoWinogradQS8_8x8 final : public AlgoBase {
  117. public:
  118. AlgoWinogradQS8_8x8(MatrixMulImpl::AlgoBase* matmul_algo)
  119. : m_matmul_algo{matmul_algo} {}
  120. AlgoAttribute attribute() const override {
  121. return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::NAIVE;
  122. }
  123. const char* name() const override {
  124. if (m_name.empty()) {
  125. m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
  126. ssprintf("FALLBACK_WINOGRAD_QS8-%s", m_matmul_algo->name()),
  127. {8, 2, UNIT_TILE_SIZE, 3});
  128. }
  129. return m_name.c_str();
  130. }
  131. bool usable(
  132. const NCBKernSizeParam& param,
  133. AlgoSelectionStrategy algo_selection_strategy) const override;
  134. size_t get_workspace(const NCBKernSizeParam& param) const override;
  135. SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
  136. ConvAlgoTypePack get_algo_type() const override {
  137. return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
  138. }
  139. MEGDNN_DECL_ALGO_TYPE(FB_WINOGRAD_8X8_QS8)
  140. private:
  141. MatrixMulImpl::AlgoBase* m_matmul_algo;
  142. mutable std::string m_name;
  143. constexpr size_t static UNIT_TILE_SIZE = 32;
  144. };
  145. } // namespace fallback
  146. } // namespace megdnn
  147. // vim: syntax=cpp.doxygen