
common.h 9.8 kB

/**
 * \file dnn/src/fallback/conv_bias/common.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include <stdint.h>
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {

using NonlineMode = ConvBias::Param::NonlineMode;
using BiasMode = ConvBiasForward::BiasMode;

#define DISPATCH_GEMM_NONLINE(_gemm, _gemm_midout_enum, _bias, \
                              _bias_midout_enum) \
    switch (param.nonlineMode) { \
        case param::ConvBias::NonlineMode::IDENTITY: { \
            DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
                                   _bias_midout_enum, identity, 0); \
            break; \
        } \
        case param::ConvBias::NonlineMode::RELU: { \
            DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
                                   _bias_midout_enum, relu, 1); \
            break; \
        } \
        case param::ConvBias::NonlineMode::H_SWISH: { \
            DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \
                                   _bias_midout_enum, hswish, 2); \
            break; \
        } \
        default: \
            megdnn_assert(0); \
            break; \
    }

#define DISPATCH_GEMM_BIAS(_gemm, _gemm_midout_enum) \
    switch (param.bias_mode) { \
        case BiasMode::NO_BIAS: \
            DISPATCH_GEMM_NONLINE(_gemm, _gemm_midout_enum, nobias, 0) \
            break; \
        case BiasMode::BROADCAST_CHANNEL_BIAS: \
            DISPATCH_GEMM_NONLINE(_gemm, _gemm_midout_enum, bias_channel, 1) \
            break; \
        default: \
            megdnn_assert(0); \
            break; \
    }

#define DISPATCH_CONV_NONLINE(i, midout_tag, stride, _conv, BIAS_MODE, \
                              dst_type) \
    switch (param.nonlineMode) { \
        case param::ConvBias::NonlineMode::IDENTITY: { \
            DISPATCH_CONV_STRATEGY(i, midout_tag, stride, _conv, BIAS_MODE, \
                                   TypeCvtOp<dt_qint32 MEGDNN_COMMA dst_type>, \
                                   0); \
            break; \
        } \
        case param::ConvBias::NonlineMode::RELU: { \
            DISPATCH_CONV_STRATEGY(i, midout_tag, stride, _conv, BIAS_MODE, \
                                   ReluOp<dt_qint32 MEGDNN_COMMA dst_type>, \
                                   1); \
            break; \
        } \
        case param::ConvBias::NonlineMode::H_SWISH: { \
            DISPATCH_CONV_STRATEGY(i, midout_tag, stride, _conv, BIAS_MODE, \
                                   HSwishOp<dt_qint32 MEGDNN_COMMA dst_type>, \
                                   2); \
            break; \
        } \
        default: \
            megdnn_assert(0); \
            break; \
    }

#define DISPATCH_CONV_BIAS(i, midout_tag, stride, _conv, dst_type) \
    switch (param.bias_mode) { \
        case BiasMode::NO_BIAS: \
            DISPATCH_CONV_NONLINE(i, midout_tag, stride, _conv, \
                                  BiasMode::NO_BIAS, dst_type) \
            break; \
        case BiasMode::BROADCAST_CHANNEL_BIAS: \
            DISPATCH_CONV_NONLINE(i, midout_tag, stride, _conv, \
                                  BiasMode::BROADCAST_CHANNEL_BIAS, dst_type) \
            break; \
        default: \
            megdnn_assert(0); \
            break; \
    }

#define DISPATCH_CONV_STRATEGY(i, midout_tag, stride, conv, BIAS_MODE, Op, \
                               _nonline_midout_enum) \
    MIDOUT_BEGIN(midout_tag, i, stride, midout_iv(BIAS_MODE), \
                 _nonline_midout_enum) { \
        return {{conv<i, BIAS_MODE, Op>, {1_z, 1_z, 1_z}}}; \
    } \
    MIDOUT_END()

#define DISPATCH_FILTER(filter, kern, arg...) \
    switch (filter) { \
        case 2: \
            kern(2, ##arg); \
            break; \
        case 3: \
            kern(3, ##arg); \
            break; \
        case 5: \
            kern(5, ##arg); \
            break; \
        case 7: \
            kern(7, ##arg); \
            break; \
        default: \
            megdnn_assert(0); \
            break; \
    }

#define DISPATCH_FILTER_CHANNEL_WISE(filter, kern, arg...) \
    switch (filter) { \
        case 2: \
            kern(2, ##arg); \
            break; \
        case 3: \
            kern(3, ##arg); \
            break; \
        case 5: \
            kern(5, ##arg); \
            break; \
        default: \
            megdnn_assert(0); \
            break; \
    }

#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE() \
    bool is_reproducible() const override { return true; } \
    bool usable(const NCBKernSizeParam& param, \
                AlgoSelectionStrategy algo_selection_strategy) const override; \
    size_t get_workspace(const NCBKernSizeParam& param) const override; \
    virtual SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) \
            const override; \
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout( \
            const NCBKernSizeParam& param) const override; \
    size_t get_preprocess_workspace(const NCBKernSizeParam& param) \
            const override; \
    virtual SmallVector<NCBKern> dispatch_preprocess_kerns( \
            const NCBKernSizeParam& param) const override; \
\
private: \
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \
    mutable std::string m_name; \
    uint32_t m_tile_size;

enum class PostprocessMode : uint8_t {
    FLOAT = 0,   ///< support all bias modes and no_nonlinemode
    NO_PROCESS,  ///< support no bias and identity nonline mode
    QUANTIZED,   ///< support NO_BIAS, BROADCAST_CHANNEL_BIAS and
                 ///< relu/hswish/identity nonline modes
};

}  // namespace megdnn

// vim: syntax=cpp.doxygen
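
All of the DISPATCH_* macros above follow one pattern: switch on a runtime value (nonlinearity, bias mode, or filter size) and, in each case, instantiate a kernel with that value baked in as a compile-time template or token argument. DISPATCH_CONV_BIAS nests DISPATCH_CONV_NONLINE and DISPATCH_CONV_STRATEGY, so the innermost MIDOUT_BEGIN block returns a kernel specialized for one concrete (bias mode, nonline mode) pair. The standalone sketch below is not MegEngine code (demo_kern, DEMO_KERN, and DEMO_DISPATCH_FILTER are hypothetical names); it only illustrates the runtime-to-compile-time dispatch in the shape of DISPATCH_FILTER, using the same GNU named-variadic-macro extension (arg..., ##arg) that the header relies on.

// Standalone sketch of the dispatch pattern (hypothetical names, not MegEngine
// code): a runtime filter size selects a kernel specialized at compile time.
#include <cstdio>
#include <cstdlib>

// Stand-in for a convolution kernel template specialized per filter size.
template <int filter_size>
void demo_kern(const char* tag) {
    std::printf("%s: kernel specialized for %dx%d filter\n", tag, filter_size,
                filter_size);
}

// Adapter macro: turns the literal filter size into a template argument.
#define DEMO_KERN(filter, tag) demo_kern<filter>(tag)

// Same shape as DISPATCH_FILTER above (GNU named variadic macro, arg...).
#define DEMO_DISPATCH_FILTER(filter, kern, arg...) \
    switch (filter) { \
        case 2: \
            kern(2, ##arg); \
            break; \
        case 3: \
            kern(3, ##arg); \
            break; \
        case 5: \
            kern(5, ##arg); \
            break; \
        case 7: \
            kern(7, ##arg); \
            break; \
        default: \
            std::abort(); \
    }

int main() {
    int runtime_filter = 3;  // value known only at runtime
    DEMO_DISPATCH_FILTER(runtime_filter, DEMO_KERN, "demo");
    return 0;
}

Compiled with g++ or clang++ (the arg... form is a GNU extension), the call with runtime_filter = 3 expands to demo_kern<3>("demo"); this is how DISPATCH_FILTER lets a runtime filter size select a fully specialized convolution kernel.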

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.