
opr_impl_helper.cpp

/**
 * \file dnn/src/common/elemwise/opr_impl_helper.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "./opr_impl_helper.h"
#include "src/common/utils.h"

using namespace megdnn;

template <int arity>
ElemwiseOpParamN<arity> ElemwiseLayoutHelper::make_elemwise_op_param(
        void* opr,
        void (*check_layout_and_broadcast)(void*, const TensorLayoutPtrArray&,
                                           const TensorLayout&),
        const TensorNDArray& src, const TensorND& dst) {
    megdnn_assert(src.size() == static_cast<size_t>(arity));
    ElemwiseOpParamN<arity> ret;
    TensorLayoutPtrArray src_layouts(arity);
    for (int i = 0; i < arity; ++i) {
        ret.param[i] = src[i];
        src_layouts[i] = &ret.param[i].layout;
    }
    check_layout_and_broadcast(opr, src_layouts, dst.layout);
    ret.init_from_given_tensor();
    return ret;
}
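Two details make this helper safe to call from subclasses: each source TensorND is copied into ret.param before anything is mutated, and the checker callback receives pointers into those copies, so it can broadcast the layouts against dst.layout in place without touching the caller's tensors. Below is a minimal standalone sketch of that copy-then-rewrite pattern, using toy types rather than the real MegDNN API:

    #include <array>
    #include <cstddef>
    #include <vector>

    struct ToyLayout {
        std::size_t ndim = 0;
        std::size_t shape[4] = {0};
    };
    struct ToyTensor {
        ToyLayout layout;
    };

    // Hypothetical checker standing in for check_layout_and_broadcast: it
    // rewrites every source layout in place to match dst.
    static void toy_broadcast_to(const std::vector<ToyLayout*>& srcs,
                                 const ToyLayout& dst) {
        for (ToyLayout* s : srcs)
            *s = dst;  // the real checker validates compatibility first
    }

    template <int arity>
    std::array<ToyTensor, arity> toy_make_param(
            const std::vector<ToyTensor>& src, const ToyTensor& dst) {
        std::array<ToyTensor, arity> ret;
        std::vector<ToyLayout*> layouts(arity);
        for (int i = 0; i < arity; ++i) {
            ret[i] = src[i];              // copy first ...
            layouts[i] = &ret[i].layout;  // ... then point at the copy
        }
        toy_broadcast_to(layouts, dst.layout);
        return ret;
    }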
// explicit instantiation so subclasses can call this method
#define INST(n)                                                               \
    template ElemwiseOpParamN<n>                                              \
    ElemwiseLayoutHelper::make_elemwise_op_param<n>(                          \
            void*,                                                            \
            void (*)(void*, const TensorLayoutPtrArray&, const TensorLayout&), \
            const TensorNDArray&, const TensorND&)
INST(1);
INST(2);
INST(3);
INST(4);
INST(5);
INST(6);
#undef INST
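The INST() block exists because the template's definition lives in this .cpp file: without explicit instantiation, a subclass calling make_elemwise_op_param<n>() from another translation unit would compile against the declaration in the header but fail to link. A minimal standalone sketch of the technique, with toy names rather than MegEngine code:

    // scale.h -- declaration only; callers never see the definition.
    template <int n>
    int scale_by(int x);

    // scale.cpp -- definition plus explicit instantiations. Only these two
    // specializations are emitted into the object file, so other translation
    // units may call scale_by<2> and scale_by<3> but nothing else.
    template <int n>
    int scale_by(int x) {
        return n * x;
    }
    template int scale_by<2>(int);
    template int scale_by<3>(int);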
void ElemwiseForwardImplHelper::prepare_fma3(ElemwiseOpParamN<3>& param,
                                             bool& c_is_scalar) {
    c_is_scalar = is_broadcasted_scalar(m_src->at(2).layout);
    param = make_elemwise_op_param<3>();

    if (!c_is_scalar && !param[2].layout.eq_layout(param[0].layout)) {
        megdnn_assert_eq_layout(param[2].layout, param[1].layout);
        std::swap(param[0], param[1]);
    }
    if (c_is_scalar && param[2].layout.eq_layout(param[0].layout)) {
        std::swap(param[0], param[1]);
    }
}
void ElemwiseForwardImplHelper::prepare_fma4(ElemwiseOpParamN<4>& param) {
    param = make_elemwise_op_param<4>();
    if (!param[0].layout.eq_layout(param[2].layout))
        std::swap(param[0], param[1]);
    megdnn_assert_eq_layout(param[0].layout, param[2].layout);
    megdnn_assert_eq_layout(param[1].layout, param[3].layout);
}
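Both prepare_ helpers normalize operand order rather than making the kernels handle every layout combination: prepare_fma3 swaps the two multiplicands so that, on exit, the addend's layout matches param[0] (unless it is a broadcasted scalar), and prepare_fma4 does the same so that param[0]/param[2] and param[1]/param[3] have matching layouts. A toy sketch of the same normalize-then-assert idea, with a hypothetical stand-in for layout comparison:

    #include <cassert>
    #include <utility>

    struct ToyOperand {
        int layout_id;  // stands in for a full TensorLayout comparison
    };

    // Normalize (a, b, c) for a*b + c so that c's layout matches a's;
    // the kernel then only needs to handle one operand order.
    static void toy_prepare_fma3(ToyOperand& a, ToyOperand& b,
                                 const ToyOperand& c) {
        if (c.layout_id != a.layout_id) {
            assert(c.layout_id == b.layout_id);  // c must match one of them
            std::swap(a, b);
        }
    }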
bool ElemwiseLayoutHelper::is_broadcasted_scalar(const TensorLayout& layout) {
    if (layout.format.type() != TensorFormat::Type::DEFAULT)
        return false;
    for (size_t i = 0; i < layout.ndim; ++i) {
        if (layout.shape[i] != 1 && layout.stride[i] != 0)
            return false;
    }
    return true;
}
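This test relies on the stride-0 broadcasting convention: a layout refers to a single scalar exactly when every axis either has extent 1 or stride 0, so every index tuple maps to offset 0. A runnable standalone sketch (toy struct, not the MegDNN TensorLayout):

    #include <cstddef>
    #include <cstdio>

    struct ToyLayout {
        std::size_t ndim;
        std::size_t shape[4];
        std::ptrdiff_t stride[4];
    };

    static bool toy_is_broadcasted_scalar(const ToyLayout& ly) {
        for (std::size_t i = 0; i < ly.ndim; ++i)
            if (ly.shape[i] != 1 && ly.stride[i] != 0)
                return false;
        return true;
    }

    int main() {
        ToyLayout scalar{3, {2, 3, 4}, {0, 0, 0}};  // one value seen as (2,3,4)
        ToyLayout dense{2, {3, 4}, {4, 1}};         // contiguous (3,4) matrix
        std::printf("%d %d\n", toy_is_broadcasted_scalar(scalar),
                    toy_is_broadcasted_scalar(dense));  // prints "1 0"
    }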
template <size_t slice_size>
bool ElemwiseLayoutHelper::is_broadcastedx_channel_like(
        const TensorLayout& layout, BroadcastChannelInfo& info) {
    if (layout.format.type() == TensorFormat::Type::DEFAULT &&
        layout.ndim == 3 && layout.stride[0] == slice_size &&
        layout.stride[1] == 0 && layout.stride[2] == 1) {
        info.x = layout.shape[0];
        info.y = layout.shape[1];
        info.z = layout.shape[2];
        return true;
    } else if (layout.format.type() == TensorFormat::Type::DEFAULT &&
               layout.ndim == 4 && layout.stride[0] == 0 &&
               layout.stride[1] == slice_size && layout.stride[2] == 0 &&
               layout.stride[3] == 1) {
        info.x = layout.shape[1];
        info.y = layout.shape[2];
        info.z = layout.shape[3];
        return true;
    }
    return false;
}

#define INST(n)                                                           \
    template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like<n>(  \
            const TensorLayout& layout, BroadcastChannelInfo& info)
INST(4);
INST(8);
#undef INST
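Reading the first branch with slice_size = 4: shape (x, y, z) with strides (4, 0, 1) describes x packed 4-element slices, each repeated y times along the middle axis, which is how a per-channel tensor looks once channels have been packed into groups of 4 (or 8, hence the two instantiations); the second branch accepts the same pattern with an extra broadcast leading axis. A worked check of that reading, reusing the toy layout struct from above:

    #include <cassert>
    #include <cstddef>

    struct ToyLayout {
        std::size_t ndim;
        std::size_t shape[4];
        std::ptrdiff_t stride[4];
    };

    template <std::size_t slice_size>
    static bool toy_bcastx_channel_like(const ToyLayout& ly) {
        return ly.ndim == 3 &&
               ly.stride[0] == static_cast<std::ptrdiff_t>(slice_size) &&
               ly.stride[1] == 0 && ly.stride[2] == 1;
    }

    int main() {
        // shape (2, 5, 4), strides (4, 0, 1): two 4-element slices, each
        // broadcast 5 times along the middle axis -- the first branch above.
        ToyLayout ly{3, {2, 5, 4}, {4, 0, 1}};
        assert(toy_bcastx_channel_like<4>(ly));
    }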
bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
        const TensorLayout& layout, BroadcastChannelInfo& info) {
    if (layout.format.type() == TensorFormat::Type::DEFAULT) {
        if (layout.ndim == 3 && layout.stride[0] == 0 &&
            layout.stride[2] == 0 && layout.stride[1] == 1) {
            info.x = layout.shape[0];
            info.y = layout.shape[1];
            info.z = layout.shape[2];
            return true;
        } else if (layout.ndim == 2 && layout.stride[1] == 0 &&
                   layout.stride[0] == 1) {
            info.x = 1;
            info.y = layout.shape[0];
            info.z = layout.shape[1];
            return true;
        }
    } else {
        if (Image2DPack4TensorFormat::is_valid_image(layout)) {
            auto align_axis = layout.format.as_impl<Image2DPack4TensorFormat>()
                                      .align_axis();
            if (layout.ndim == 4 && align_axis == 1 &&
                (layout.stride[0] == 0 || layout.shape[0] == 1) &&
                layout.stride[1] == 4 && layout.stride[2] == 0 &&
                layout.stride[3] == 1) {
                info.x = 1;
                info.y = 1;
                info.z = layout.shape[2];
                return true;
            } else if (layout.ndim == 3 && align_axis == 1 &&
                       (layout.stride[0] == 0 || layout.shape[0] == 1) &&
                       layout.stride[1] == 0 && layout.shape[2] == 4 &&
                       layout.stride[2] == 1) {
                //! [1, 1, 1, 1, 4] + [N, H, 1, W, 4]
                info.x = 1;
                info.y = 1;
                info.z = layout.shape[1];
                return true;
            }
            return false;
        }
    }
    return false;
}
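In the default-format path the pattern being recognized is a per-channel vector: strides (0, 1, 0) over shape (x, y, z) mean a length-y vector (a per-channel bias, say) broadcast along both the outer and inner axes, so element (n, c, i) always reads offset c. A short worked check of that offset arithmetic, again with a toy struct:

    #include <cassert>
    #include <cstddef>

    struct ToyLayout {
        std::size_t ndim;
        std::size_t shape[4];
        std::ptrdiff_t stride[4];
    };

    int main() {
        // A length-3 channel vector broadcast to shape (2, 3, 8): only the
        // middle axis advances through memory -- the ndim == 3 case above.
        ToyLayout ly{3, {2, 3, 8}, {0, 1, 0}};
        for (std::size_t n = 0; n < ly.shape[0]; ++n)
            for (std::size_t c = 0; c < ly.shape[1]; ++c)
                for (std::size_t i = 0; i < ly.shape[2]; ++i) {
                    std::ptrdiff_t offset =
                            static_cast<std::ptrdiff_t>(n) * ly.stride[0] +
                            static_cast<std::ptrdiff_t>(c) * ly.stride[1] +
                            static_cast<std::ptrdiff_t>(i) * ly.stride[2];
                    assert(offset == static_cast<std::ptrdiff_t>(c));
                }
    }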
bool ElemwiseLayoutHelper::is_broadcasted_1x(const TensorLayout& layout,
                                             Broadcast1xInfo& binfo) {
    if (layout.ndim == 2 && layout.stride[0] == 0 && layout.stride[1] == 1) {
        binfo.x = layout[0];
        binfo.y = layout[1];
        return true;
    }
    if (layout.ndim == 1 && layout.stride[0] == 1) {
        binfo.x = 1;
        binfo.y = layout[0];
        return true;
    }
    return false;
}
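is_broadcasted_1x is the row-vector analogue of the channel case: strides (0, 1) over shape (x, y) read the same contiguous y-element row for every value of the leading index. A very short check:

    #include <cassert>
    #include <cstddef>

    int main() {
        // Row vector of length 4 broadcast to 3 rows: shape (3, 4),
        // strides (0, 1). Every row maps to offsets 0..3.
        const std::ptrdiff_t stride[2] = {0, 1};
        for (std::ptrdiff_t r = 0; r < 3; ++r)
            for (std::ptrdiff_t c = 0; c < 4; ++c)
                assert(r * stride[0] + c * stride[1] == c);
    }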
// vim: syntax=cpp.doxygen

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.