You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

algo_fp32_pooling_nchw44.cpp 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. /**
  2. * \file dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/opr_param_defs.h"
  13. #include "src/arm_common/pooling/algo.h"
  14. #include "src/arm_common/pooling/kern_fp32_pooling_nchw44.h"
  15. #include "midout.h"
  16. MIDOUT_DECL(megdnn_arm_common_fp32_pooling_nchw44)
  17. namespace megdnn {
  18. namespace arm_common {
  19. bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable(
  20. const PoolingKernSizeParam& param) const {
  21. uint32_t sh = param.stride[0];
  22. uint32_t sw = param.stride[1];
  23. uint32_t fh = param.filter[0];
  24. uint32_t fw = param.filter[1];
  25. bool avaible = param.src_type.enumv() == DTypeEnum::Float32 &&
  26. param.format == Param::Format::NCHW44 &&
  27. (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
  28. fh == fw && sh == sw &&
  29. (fh == 2 || fh == 3 || fh == 4 || fh == 5) &&
  30. (sh == 1 || sh == 2);
  31. return avaible;
  32. }
  33. void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
  34. const PoolingKernParam& param) const {
  35. int ih = param.isz[0];
  36. int iw = param.isz[1];
  37. int oh = param.osz[0];
  38. int ow = param.osz[1];
  39. int n = param.n;
  40. int ic = param.ic;
  41. int ph = param.padding[0];
  42. int pw = param.padding[1];
  43. int sh = param.stride[0];
  44. int fh = param.filter[0];
  45. void* src_ptr = param.src_ptr;
  46. void* dst_ptr = param.dst_ptr;
  47. #define DISPATCH_FUNC(filter, stride, mode) \
  48. MIDOUT_BEGIN(megdnn_arm_common_fp32_pooling_nchw44, midout_iv(0), \
  49. midout_iv(#filter #stride #mode##_hash)) { \
  50. auto run = [ih, iw, oh, ow, ph, pw, src_ptr, dst_ptr](size_t index, \
  51. size_t) { \
  52. const int c_idx = index; \
  53. pooling_fp32_nchw44<filter, stride, mode>( \
  54. static_cast<const float*>(src_ptr) + c_idx * ih * iw * 4, \
  55. static_cast<float*>(dst_ptr) + c_idx * oh * ow * 4, ih, \
  56. iw, oh, ow, ph, pw); \
  57. }; \
  58. MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
  59. static_cast<::megdnn::naive::HandleImpl*>(param.handle), \
  60. n* ic, run); \
  61. } \
  62. MIDOUT_END();
  63. #define DISPATCH_MODE(filter, stride) \
  64. switch (param.mode) { \
  65. case PoolingBase::Mode::MAX: \
  66. DISPATCH_FUNC(filter, stride, PoolingBase::Mode::MAX); \
  67. break; \
  68. case PoolingBase::Mode::AVERAGE: \
  69. DISPATCH_FUNC(filter, stride, PoolingBase::Mode::AVERAGE); \
  70. break; \
  71. default: \
  72. megdnn_assert(0, "invalid mode %u", \
  73. static_cast<uint32_t>(param.mode)); \
  74. }
  75. #define DISPATCH_STRIDE(filter) \
  76. switch (sh) { \
  77. case 1: \
  78. DISPATCH_MODE(filter, 1); \
  79. break; \
  80. case 2: \
  81. DISPATCH_MODE(filter, 2); \
  82. break; \
  83. default: \
  84. megdnn_assert(0, "invalid stride %d", sh); \
  85. }
  86. #define DISPATCH_FILTER() \
  87. switch (fh) { \
  88. case 2: \
  89. DISPATCH_STRIDE(2); \
  90. break; \
  91. case 3: \
  92. DISPATCH_STRIDE(3); \
  93. break; \
  94. case 4: \
  95. DISPATCH_STRIDE(4); \
  96. break; \
  97. case 5: \
  98. DISPATCH_STRIDE(5); \
  99. break; \
  100. default: \
  101. megdnn_assert(0, "invalid filter %d", fh); \
  102. }
  103. DISPATCH_FILTER()
  104. #undef DISPATCH_FILTER
  105. #undef DISPATCH_STRIDE
  106. #undef DISPATCH_MODE
  107. #undef DISPATCH_FUNC
  108. }
  109. } // namespace arm_common
  110. } // namespace megdnn
  111. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。