You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opr_impl.cpp 8.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. /**
  2. * \file dnn/src/arm_common/elemwise/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/arm_common/elemwise/opr_impl.h"
  13. #include "src/arm_common/elemwise/binary/algo.h"
  14. #include "src/arm_common/elemwise/ternary/algo.h"
  15. #include "src/arm_common/elemwise/unary/algo.h"
  16. #include "src/arm_common/elemwise_op.h"
  17. #include "src/common/metahelper.h"
  18. #include "src/common/utils.h"
  19. #include "src/naive/handle.h"
  20. using namespace megdnn;
  21. using namespace arm_common;
  22. class ElemwiseImpl::AlgoPack {
  23. AlgoUnary algo_unary;
  24. AlgoBinaryVecVec algo_binary_vec_vec;
  25. AlgoBinaryVecScalar algo_binary_vec_sca;
  26. AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
  27. AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
  28. AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
  29. AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
  30. AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
  31. AlgoTernaryFma3Bcast101xXVecBcast101xX
  32. algo_ternaryfma3_bcast101xX_vec_bcast101xX;
  33. AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
  34. AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec;
  35. AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
  36. AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
  37. public:
  38. AlgoPack() {
  39. all_algos.emplace_back(&algo_unary);
  40. all_algos.emplace_back(&algo_binary_vec_vec);
  41. all_algos.emplace_back(&algo_binary_vec_sca);
  42. all_algos.emplace_back(&algo_binary_vec_bcast101);
  43. all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
  44. all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
  45. all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
  46. all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
  47. all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX);
  48. all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
  49. all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec);
  50. all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
  51. all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
  52. }
  53. SmallVector<AlgoBase*> all_algos;
  54. };
  55. void ElemwiseImpl::exec(const TensorNDArray& srcs, _megdnn_tensor_out dst) {
  56. m_src = &srcs;
  57. m_dst = &dst;
  58. if (!dst.layout.is_contiguous()) {
  59. return fallback::ElemwiseImpl::exec(srcs, dst);
  60. }
  61. if (m_dst->layout.dtype == dtype::Float32() ||
  62. DNN_FLOAT16_SELECT(m_dst->layout.dtype == dtype::Float16(), false) ||
  63. m_dst->layout.dtype == dtype::Int32() ||
  64. m_dst->layout.dtype == dtype::Int16() ||
  65. m_dst->layout.dtype == dtype::Int8()) {
  66. auto kern_param = make_kern_param(this);
  67. kern_param.m_dst = &dst;
  68. static AlgoPack m_algo_pack;
  69. for (auto& m_algo : m_algo_pack.all_algos) {
  70. if (m_algo->is_available(kern_param)) {
  71. m_algo->exec(kern_param);
  72. return;
  73. }
  74. }
  75. }
  76. fallback::ElemwiseImpl::exec(srcs, dst);
  77. }
/**
 * \brief Classify the inputs' broadcast pattern into a KernParam.
 *
 * Inspects the source layouts of \p opr and fills in
 * KernParam::broad_cast_type so that the algorithm objects can match on it.
 * The if-chains below are ordered by priority: the first pattern that
 * matches wins, so the relative order of the checks is significant.
 * If nothing matches, broad_cast_type stays UNKNOWN_BCAST_TYPE and the
 * caller falls back to the generic implementation.
 */
ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
    KernParam kern_param;
    kern_param.broad_cast_type = BcastType::UNKNOWN_BCAST_TYPE;
    kern_param.mode = opr->param().mode;
    kern_param.handle = opr->handle();

    // --- ternary case: only FUSE_MUL_ADD3 (d = a * b + c) is specialized ---
    if ((opr->m_src->size() == 3) &&
        (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
        kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
        bool c_is_scalar;
        // prepare_fma3 canonicalizes the operands and reports whether the
        // third operand is a (broadcasted) scalar.
        opr->prepare_fma3(kern_param.ternary_elparam, c_is_scalar);
        auto &src0 = kern_param.ternary_elparam[0],
             &src1 = kern_param.ternary_elparam[1],
             &src2 = kern_param.ternary_elparam[2];
        BroadcastChannelInfo binfo;

        // all three operands are plain contiguous vectors
        if (is_vector(src0.layout) && is_vector(src1.layout) &&
            is_vector(src2.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_VEC_VEC;
            return kern_param;
        }
        // vec * vec + scalar
        if (is_vector(src0.layout) && is_vector(src1.layout) && c_is_scalar) {
            kern_param.broad_cast_type = BcastType::VEC_VEC_SCALAR;
            return kern_param;
        }
        // channel-broadcast * vec + channel-broadcast, with src0 and src2
        // sharing the same layout (e.g. NCHW with 1x C x 1x1 operands)
        if (is_vector(src1.layout) &&
            is_broadcasted_channel_like(src0.layout, binfo) &&
            src0.layout.eq_layout(src2.layout)) {
            kern_param.broad_cast_type = BcastType::BCAST101_VEC_BCAST101;
            return kern_param;
        }
        // same as above but for channel-packed layouts (NCHW44 / NCHW88,
        // hence the <4> and <8> instantiations)
        if (is_vector(src1.layout) &&
            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
            src0.layout.eq_layout(src2.layout)) {
            kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
            return kern_param;
        }
        // vec * channel-broadcast + vec, src0 and src2 identical layouts
        if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
            is_broadcasted_channel_like(src1.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101_VEC;
            return kern_param;
        }
        // vec * packed-channel-broadcast + vec (NCHW44 / NCHW88)
        if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
            return kern_param;
        }
        // vec * scalar + vec
        if (is_vector(src0.layout) && is_vector(src2.layout) &&
            is_broadcasted_scalar(src1.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_SCALAR_VEC;
            return kern_param;
        }
        // vec * scalar + scalar
        if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout) &&
            is_broadcasted_scalar(src2.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_SCALAR_SCALAR;
            return kern_param;
        }

    // --- binary case ---
    } else if (opr->m_src->size() == 2) {
        kern_param.binary_elparam = opr->make_elemwise_op_param<2>();
        auto &src0 = kern_param.binary_elparam[0],
             &src1 = kern_param.binary_elparam[1];
        BroadcastChannelInfo binfo;

        // vec (op) vec
        if (is_vector(src0.layout) && is_vector(src1.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_VEC;
            return kern_param;
        }
        // vec (op) scalar and scalar (op) vec
        if (is_vector(src0.layout) && is_broadcasted_scalar(src1.layout)) {
            kern_param.broad_cast_type = BcastType::VEC_SCALAR;
            return kern_param;
        }
        if (is_vector(src1.layout) && is_broadcasted_scalar(src0.layout)) {
            kern_param.broad_cast_type = BcastType::SCALAR_VEC;
            return kern_param;
        }
        // vec (op) channel-broadcast and its mirror
        if (is_vector(src0.layout) &&
            is_broadcasted_channel_like(src1.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101;
            return kern_param;
        }
        if (is_vector(src1.layout) &&
            is_broadcasted_channel_like(src0.layout, binfo)) {
            kern_param.broad_cast_type = BcastType::BCAST101_VEC;
            return kern_param;
        }
        // vec (op) packed-channel-broadcast (NCHW44 / NCHW88) and its mirror
        if (is_vector(src0.layout) &&
            (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
            kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
            return kern_param;
        }
        if (is_vector(src1.layout) &&
            (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
             is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
            kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
            return kern_param;
        }

    // --- unary case: always a plain vector ---
    } else if (opr->m_src->size() == 1) {
        kern_param.broad_cast_type = BcastType::VEC;
        kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
        return kern_param;
    }

    // no specialized pattern matched; caller will use the fallback path
    return kern_param;
}
  181. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台