

/**
 * \file dnn/src/arm_common/pooling/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "src/arm_common/pooling/opr_impl.h"
#include "src/arm_common/pooling/algo.h"
#include "src/common/metahelper.h"
#include "src/common/algo_chooser.h"

using namespace megdnn;
using namespace arm_common;

//! Container for one instance of every ARM pooling algorithm, plus a
//! desc -> algorithm map used for lookup by descriptor.
class PoolingImpl::AlgoPack : NonCopyableObj {
private:
    AlgoBase::Mapper m_all_algos_map;

    AlgoFilterxModexStride1 algo_filterx_modex_stride1;
    AlgoFilter2ModexStride2 algo_filter2_modex_stride2;
    AlgoFilter3MaxStride2 algo_filter3_max_stride2;
    AlgoFilter3AverageStride2 algo_filter3_average_stride2;
    AlgoFilter4MaxStride2 algo_filter4_max_stride2;
    AlgoFilter5MaxStride2 algo_filter5_max_stride2;
    AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2;
    AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2;
    AlgoFilter2ModexStridexNCHW44 algo_filter2_modex_stridex_nchw4;
    AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4;
    AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4;
    AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4;
    AlgoFp32ModexStridexNCHW44 algo_fp32_modex_stridex_nchw44;
    AlgoFallback algo_fallback;

public:
    AlgoPack() {
        all_algos.emplace_back(&algo_filterx_modex_stride1);
        all_algos.emplace_back(&algo_filter2_modex_stride2);
        all_algos.emplace_back(&algo_filter3_max_stride2);
        all_algos.emplace_back(&algo_filter3_average_stride2);
        all_algos.emplace_back(&algo_filter4_max_stride2);
        all_algos.emplace_back(&algo_filter5_max_stride2);
        all_algos.emplace_back(&algo_int8_filter2_max_stride2);
        all_algos.emplace_back(&algo_int8_filter3_max_stride2);
        all_algos.emplace_back(&algo_filter3_modex_stridex_nchw4);
        all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4);
        all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4);
        all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4);
        all_algos.emplace_back(&algo_fp32_modex_stridex_nchw44);
        all_algos.emplace_back(&algo_fallback);

        for (auto&& algo : all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    SmallVector<AlgoBase*> all_algos;

    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;

PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param(
        fallback::PoolingImpl* opr, const TensorLayout& src,
        const TensorLayout& dst) {
    auto safe_u32 = [](size_t v) -> uint32_t {
        megdnn_assert(v <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", v);
        return v;
    };
    return {safe_u32(src.shape[0]),
            safe_u32(src.shape[1]),
            {{safe_u32(src.shape[2]), safe_u32(src.shape[3])}},
            {{safe_u32(dst.shape[2]), safe_u32(dst.shape[3])}},
            {{safe_u32(opr->param().pad_h), safe_u32(opr->param().pad_w)}},
            {{safe_u32(opr->param().window_h),
              safe_u32(opr->param().window_w)}},
            {{safe_u32(opr->param().stride_h),
              safe_u32(opr->param().stride_w)}},
            src.dtype,
            dst.dtype,
            opr->handle(),
            opr->param().format,
            opr->param().mode};
}

PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param(
        fallback::PoolingImpl* opr, _megdnn_tensor_in src,
        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    PoolingKernParam ret;
    static_cast<PoolingKernSizeParam&>(ret) =
            make_pooling_kern_szie_param(opr, src.layout, dst.layout);
    ret.src_ptr = src.raw_ptr;
    ret.dst_ptr = dst.raw_ptr;
    ret.workspace_ptr = workspace.raw_ptr;
    ret.workspace_size = workspace.size;
    return ret;
}

size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
                                           const TensorLayout& dst) {
    //! Fast path: if the heuristic cache already has an entry for this key,
    //! reuse the workspace size recorded there.
    TensorLayoutArray layouts{src, dst};
    HeuristicCache::Key key{this->handle(), this->get_opr_type(),
                            layouts.data(), layouts.size(), &this->param(),
                            sizeof(this->param())};
    auto rst = HeuristicCache::instance().get(key);
    if (rst.policy.algo.valid()) {
        return rst.workspace;
    }

    auto param = make_pooling_kern_szie_param(this, src, dst);
    auto algo = get_algorithm(this, src, dst);
    if (!is_fallback_algo(algo)) {
        size_t arm_common_workspace = 0;

        //! In multi-threaded execution, every thread has its own workspace.
        size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                    ->megcore_dispatcher()
                                    ->nr_threads();

        if ((param.src_type.category() == DTypeCategory::FLOAT ||
             param.src_type == dtype::Int8{} ||
             param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
            param.filter[0] == param.filter[1] &&
            (param.filter[0] == 3 || param.filter[0] == 5) &&
            param.format == Param::Format::NCHW &&
            (param.mode == Mode::MAX ||
             (param.mode == Mode::AVERAGE && param.filter[0] == 3)) &&
            param.stride[0] == 2 && param.stride[1] == 2 &&
            param.isz[0] >= 2 && param.isz[1] >= 2) {
            WorkspaceBundle ws = get_bundle(param);
            arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
        }

        if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Int8) &&
            (param.format == param::Pooling::Format::NCHW44)) {
            WorkspaceBundle ws = get_bundle_nchw44(param);
            arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
        }

        return arm_common_workspace;
    } else {
        auto fallback_workspace =
                fallback::PoolingImpl::get_workspace_in_bytes(src, dst);
        return fallback_workspace;
    }
}

void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                       _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    auto param = make_pooling_kern_param(this, src, dst, workspace);
    auto algo = get_algorithm(this, src.layout, dst.layout);
    if (!is_fallback_algo(algo)) {
        algo->exec(param);
    } else {
        fallback::PoolingImpl::exec(src, dst, workspace);
    }
}

MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingImpl);

std::vector<Algorithm*> PoolingImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& dst) {
    auto param = make_pooling_kern_szie_param(this, src, dst);
    std::vector<Algorithm*> ret;
    ret.reserve(algo_pack().all_algos.size());
    for (auto i : algo_pack().all_algos) {
        if (i->usable(param)) {
            ret.push_back(i);
        }
    }
    megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm");
    return ret;
}

Algorithm* PoolingImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& dst,
        size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);
    auto param = make_pooling_kern_szie_param(this, src, dst);
    //! Return the first usable algorithm that satisfies the attribute
    //! requirements; the generic fallback is registered last, so specialized
    //! kernels are preferred when applicable.
    for (auto&& iter : sm_algo_pack.all_algos) {
        if (iter->is_available_attribute(param, positive_attr, negative_attr)) {
            return iter;
        }
    }
    megdnn_throw(
            ssprintf("require algorithm with attribute(%s) and without "
                     "attribute(%s), but can't get suitable algo.\n",
                     Algorithm::attribute_str(positive_attr).c_str(),
                     Algorithm::attribute_str(negative_attr).c_str()));
    return nullptr;
}

// vim: syntax=cpp.doxygen

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit MegStudio.
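As a quick sanity check of such a setup, the sketch below is a minimal example (assuming a recent megengine pip package, whose Python API includes is_cuda_available, set_default_device and functional.max_pool2d): it picks a device and runs a 3x3, stride-2 max pooling, the kind of case the specialized ARM kernels above are written for when the op executes on an ARM CPU build.

import numpy as np
import megengine as mge
import megengine.functional as F

# One build serves CPU and GPU: the package bundles the CUDA runtime, so only
# the GPU hardware and driver need to be present for the GPU path to work.
device = "gpu0" if mge.is_cuda_available() else "cpu0"
mge.set_default_device(device)

# NCHW input, 3x3 window, stride 2 max pooling.
x = mge.tensor(np.random.randn(1, 16, 32, 32).astype(np.float32))
y = F.max_pool2d(x, kernel_size=3, stride=2)
print(device, y.shape)

On a machine without a GPU the same script simply stays on cpu0, which matches the note above that a single package covers both cases.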