
algo.cpp 20 kB

/**
 * \file dnn/src/arm_common/elemwise/binary/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/elemwise/binary/algo.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_elemwise_binary)
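// midout tags each kernel instantiation below (via MIDOUT_BEGIN/MIDOUT_END),
// so that instantiations never recorded as used can be pruned from trimmed
// builds to reduce binary size.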
using namespace megdnn;
using namespace arm_common;

namespace {
static inline bool is_available_common(Elemwise::Mode mode) {
    /**
     * Fused sigmoid & tanh may be slower than the naive algo, because the
     * time used by the neon function `exp_ps_f32` depends on the input.
     */
    if (mode == Elemwise::Mode::FUSE_ADD_SIGMOID ||
        mode == Elemwise::Mode::FUSE_ADD_TANH) {
        return false;
    }
    return true;
}
}  // anonymous namespace
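// The DISPATCH_MODE_FLOAT/DISPATCH_MODE_INT macros below are expanded by
// DISPATCH_TYPE once per supported dtype. They are defined twice in this
// file: first for the is_available() stage, where they merely report whether
// a mode is supported (return true), and again after an #undef for the
// exec() stage, where they dispatch to the actual SIMD kernels.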
#if MEGDNN_AARCH64
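// TRUE_DIV is only claimed on AArch64: AArch32 NEON has no vector
// float-division instruction, so there is no vectorized kernel for it there.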
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)             \
    auto mode = kern_param.mode;                                       \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
        mode == Mode::TRUE_DIV || mode == Mode::FUSE_ADD_RELU ||       \
        mode == Mode::FUSE_ADD_H_SWISH)                                \
        return true;
#else
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)             \
    auto mode = kern_param.mode;                                       \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD || \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::POW || \
        mode == Mode::FUSE_ADD_RELU || mode == Mode::FUSE_ADD_H_SWISH) \
        return true;
#endif
#define DISPATCH_MODE_INT(_case, _type, _type_midout_id)                 \
    auto mode = kern_param.mode;                                         \
    if (mode == Mode::MIN || mode == Mode::MAX || mode == Mode::ADD ||   \
        mode == Mode::SUB || mode == Mode::MUL || mode == Mode::RMULH || \
        mode == Mode::FUSE_ADD_RELU)                                     \
        return true;
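// Integer kernels support RMULH (rounding high-half multiply) instead of
// POW and TRUE_DIV, which only have float implementations here.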
bool ElemwiseImpl::AlgoBinaryVecVec::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        (BcastType::VEC_VEC != kern_param.broad_cast_type))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    //! exactly match [x, y] + [x, y]
    DISPATCH_TYPE("AlgoBinaryVecVec::is_available"_hash);
    return false;
}
bool ElemwiseImpl::AlgoBinaryVecScalar::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_SCALAR != kern_param.broad_cast_type) &&
         (BcastType::SCALAR_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    DISPATCH_TYPE("AlgoBinaryVecScalar::is_available"_hash);
    return false;
}
bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCAST101 != kern_param.broad_cast_type) &&
         (BcastType::BCAST101_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
    DISPATCH_TYPE("AlgoBinaryVecBcast101::is_available"_hash);
    return false;
}
bool ElemwiseImpl::AlgoBinaryVecBcast101x4::is_available(
        const KernParam& kern_param) const {
    if (!is_available_common(kern_param.mode) ||
        ((BcastType::VEC_BCAST101x4 != kern_param.broad_cast_type) &&
         (BcastType::BCAST101x4_VEC != kern_param.broad_cast_type)))
        return false;
    auto& elparam = kern_param.binary_elparam;
    auto& src0 = elparam[0];
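    // No fp16 kernel exists for the 101x4 broadcast pattern, so Float16
    // inputs are rejected here and left to another algorithm.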
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    if (DNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, false)) {
        return false;
    }
#endif
    DISPATCH_TYPE("AlgoBinaryVecBcast101x::is_available"_hash);
    return false;
}
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
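// Second definition of the dispatch macros, for the exec() stage: from here
// on, DISPATCH_TYPE selects the mode and launches the matching SIMD kernel
// through DISPATCH_BINARY, which each exec() defines for its own broadcast
// pattern.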
#if MEGDNN_AARCH64
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)                   \
    switch (kern_param.mode) {                                               \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp);          \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp);          \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp);          \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp);          \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp);          \
        DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp);          \
        DISPATCH_BINARY(TRUE_DIV, _case, _type, _type_midout_id, TrueDivOp); \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id,        \
                        FuseAddReluOp);                                      \
        DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id,     \
                        FuseAddHSwishOp);                                    \
        default:                                                             \
            megdnn_throw(ssprintf("No available algo found for: %d",         \
                                  static_cast<int>(kern_param.mode)));       \
    }
#else
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)               \
    switch (kern_param.mode) {                                           \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp);      \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp);      \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp);      \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp);      \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp);      \
        DISPATCH_BINARY(POW, _case, _type, _type_midout_id, PowOp);      \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id,    \
                        FuseAddReluOp);                                  \
        DISPATCH_BINARY(FUSE_ADD_H_SWISH, _case, _type, _type_midout_id, \
                        FuseAddHSwishOp);                                \
        default:                                                         \
            megdnn_throw(ssprintf("No available algo found for: %d",     \
                                  static_cast<int>(kern_param.mode)));   \
    }
#endif
#define DISPATCH_MODE_INT(_case, _type, _type_midout_id)                \
    switch (kern_param.mode) {                                          \
        DISPATCH_BINARY(MIN, _case, _type, _type_midout_id, MinOp);     \
        DISPATCH_BINARY(MAX, _case, _type, _type_midout_id, MaxOp);     \
        DISPATCH_BINARY(ADD, _case, _type, _type_midout_id, AddOp);     \
        DISPATCH_BINARY(SUB, _case, _type, _type_midout_id, SubOp);     \
        DISPATCH_BINARY(MUL, _case, _type, _type_midout_id, MulOp);     \
        DISPATCH_BINARY(RMULH, _case, _type, _type_midout_id, RmulhOp); \
        DISPATCH_BINARY(FUSE_ADD_RELU, _case, _type, _type_midout_id,   \
                        FuseAddReluOp);                                 \
        default:                                                        \
            megdnn_throw(ssprintf("No available algo found for: %d",    \
                                  static_cast<int>(kern_param.mode)));  \
    }
void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    //! exactly match [x, y] + [x, y]
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t)>                        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_VEC>::run;           \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype,                 \
                        src0.layout.total_nr_elems()));                      \
        }                                                                    \
        MIDOUT_END();                                                        \
        return
    auto&& dst = *(kern_param.m_dst);
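    // Note: `dst` above is declared after the macro definition but before
    // DISPATCH_TYPE below; the macro body only needs `dst` to be in scope
    // at its expansion point.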
    DISPATCH_TYPE("AlgoBinaryVecVec::exec"_hash);
#undef DISPATCH_BINARY
    return;
}
void ElemwiseImpl::AlgoBinaryVecScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    // Case 2: vector + scalar
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type, _type*, DType,     \
                               DType, DType, size_t)>                        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_SCALAR>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr)[0],          \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype,                 \
                        src0.layout.total_nr_elems()));                      \
        }                                                                    \
        MIDOUT_END();                                                        \
        return
    if (BcastType::VEC_SCALAR == kern_param.broad_cast_type) {
        DISPATCH_TYPE("AlgoBinaryVecScalar::exec_vec_sca"_hash);
    }
#undef DISPATCH_BINARY
    // scalar + vector
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type, const _type*, _type*, DType,     \
                               DType, DType, size_t)>                        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::SCALAR_VEC>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr)[0],          \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype,                 \
                        src1.layout.total_nr_elems()));                      \
        }                                                                    \
        MIDOUT_END();                                                        \
        return
    if (BcastType::SCALAR_VEC == kern_param.broad_cast_type) {
        DISPATCH_TYPE("AlgoBinaryVecScalar::exec_sca_vec"_hash);
    }
#undef DISPATCH_BINARY
    return;
}
void ElemwiseImpl::AlgoBinaryVecBcast101::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;
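    // For a channel-like broadcast the vector operand is viewed as
    // (x, y, z) = (batch, channel, spatial) extents, with the other operand
    // broadcast along x and z; binfo records those three extents.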
    // Case 3: BcastType::VEC + BCAST_101
    if (BcastType::VEC_BCAST101 == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t)>        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_BCAST101>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x,        \
                        binfo.y, binfo.z));                                  \
        }                                                                    \
        MIDOUT_END();                                                        \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }
    // BCAST_101 + BcastType::VEC
    if (BcastType::BCAST101_VEC == kern_param.broad_cast_type &&
        is_broadcasted_channel_like(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t)>        \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::BCAST101_VEC>::run;      \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, binfo.x,        \
                        binfo.y, binfo.z));                                  \
        }                                                                    \
        MIDOUT_END();                                                        \
        return
        DISPATCH_TYPE("AlgoBinaryVecBcast101::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}
void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.binary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1];
    auto&& dst = *(kern_param.m_dst);
    BroadcastChannelInfo binfo;
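    // The 101x4 pattern is the blocked-channel (e.g. NCHW44) analogue of
    // 101: each channel occupies a block of 4 contiguous lanes, and the
    // kernel additionally receives a batch extent.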
    // BcastType::VEC + BCAST_101x
    if (BcastType::VEC_BCAST101x4 == kern_param.broad_cast_type &&
        is_broadcastedx_channel_like<4>(src1.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t,         \
                               size_t)>                                      \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::VEC_BCAST101x4>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, batch_size,     \
                        binfo.x, binfo.y, binfo.z));                         \
        }                                                                    \
        MIDOUT_END();                                                        \
        return
        size_t batch_size =
                src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_vec_b"_hash);
#undef DISPATCH_BINARY
    }
    // BCAST_101x + BcastType::VEC
    if (BcastType::BCAST101x4_VEC == kern_param.broad_cast_type &&
        is_broadcastedx_channel_like<4>(src0.layout, binfo)) {
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)           \
    case Mode::_mode:                                                        \
        MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case),    \
                     midout_iv(Mode::_mode), _type_midout_id) {              \
            thin_function<void(const _type*, const _type*, _type*, DType,    \
                               DType, DType, size_t, size_t, size_t,         \
                               size_t)>                                      \
                    run = OpCallerBinary<_op<_type, _type>,                  \
                                         BcastType::BCAST101x4_VEC>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                        \
                    static_cast<naive::HandleImpl*>(kern_param.handle),      \
                    run(static_cast<const _type*>(src0.raw_ptr),             \
                        static_cast<const _type*>(src1.raw_ptr),             \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
                        src1.layout.dtype, dst.layout.dtype, batch_size,     \
                        binfo.x, binfo.y, binfo.z));                         \
        }                                                                    \
        MIDOUT_END();                                                        \
        return
        size_t batch_size =
                src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
        DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_b_vec"_hash);
#undef DISPATCH_BINARY
    }
    return;
}
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// vim: syntax=cpp.doxygen
