You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
/**
 * \file dnn/src/arm_common/pooling/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
  12. #pragma once
#include <array>
#include <cstring>
#include <unordered_map>
#include <vector>

#include "megdnn/oprs/base.h"
#include "src/fallback/pooling/opr_impl.h"
  16. namespace megdnn {
  17. namespace arm_common {
  18. class PoolingImpl final : public fallback::PoolingImpl {
  19. private:
  20. class AlgoFilterxModexStride1;
  21. class AlgoFilter2ModexStride2;
  22. class AlgoFilter3MaxStride2;
  23. class AlgoFilter3AverageStride2;
  24. class AlgoFilter4MaxStride2;
  25. class AlgoFilter5MaxStride2;
  26. class AlgoInt8Filter2MaxStride2;
  27. class AlgoInt8Filter3MaxStride2;
  28. class AlgoFilter2ModexStridexNCHW44;
  29. class AlgoFilter3ModexStridexNCHW44;
  30. class AlgoFilter4ModexStridexNCHW44;
  31. class AlgoFilter5ModexStridexNCHW44;
  32. class AlgoFp32ModexStridexNCHW44;
  33. class AlgoFallback;
  34. class AlgoPack;
  35. static AlgoPack sm_algo_pack;
  36. public:
  37. using fallback::PoolingImpl::PoolingImpl;
  38. void exec(
  39. _megdnn_tensor_in src, _megdnn_tensor_out dst,
  40. _megdnn_workspace workspace) override;
  41. size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override;
  42. static size_t constexpr MAX_SPATIAL_DIM = 2;
  43. struct PoolingKernSizeParam {
  44. uint32_t n, ic;
  45. std::array<uint32_t, MAX_SPATIAL_DIM> isz, osz;
  46. std::array<uint32_t, MAX_SPATIAL_DIM> padding, filter, stride;
  47. DType src_type, dst_type;
  48. Handle* handle;
  49. Param::Format format;
  50. Mode mode;
  51. };
  52. struct PoolingKernParam : public PoolingKernSizeParam {
  53. RefPtr src_ptr;
  54. RefPtr dst_ptr;
  55. void* workspace_ptr;
  56. size_t workspace_size;
  57. template <typename T>
  58. const T* src() const {
  59. src_type.assert_is_compatible_ctype<T>();
  60. return static_cast<const T*>(src_ptr.get_ptr());
  61. }
  62. template <typename T>
  63. T* dst() const {
  64. dst_type.assert_is_compatible_ctype<T>();
  65. return static_cast<T*>(dst_ptr.get_ptr());
  66. }
  67. template <typename T>
  68. T* workspace() const {
  69. return static_cast<T*>(workspace_ptr);
  70. }
  71. };
  72. PoolingKernSizeParam make_pooling_kern_szie_param(
  73. fallback::PoolingImpl* opr, const TensorLayout& src,
  74. const TensorLayout& dst);
  75. PoolingKernParam make_pooling_kern_param(
  76. fallback::PoolingImpl* opr, _megdnn_tensor_in src, _megdnn_tensor_out dst,
  77. _megdnn_workspace workspace);
  78. class AlgoBase : public detail::Algorithm {
  79. public:
  80. enum class AlgoType : uint32_t {
  81. ARM_FilterxModexStride1,
  82. ARM_Filter2ModexStride2,
  83. ARM_Filter3MaxStride2,
  84. ARM_Filter3AverageStride2,
  85. ARM_Filter4MaxStride2,
  86. ARM_Filter5MaxStride2,
  87. ARM_Int8Filter2MaxStride2,
  88. ARM_Int8Filter3MaxStride2,
  89. ARM_Filter2ModexStridexNCHW44,
  90. ARM_Filter3ModexStridexNCHW44,
  91. ARM_Filter4ModexStridexNCHW44,
  92. ARM_Filter5ModexStridexNCHW44,
  93. ARM_Fp32ModexStridexNCHW44,
  94. ARM_Fallback
  95. };
  96. using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
  97. AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ARM_COMMON; }
  98. virtual ~AlgoBase() = default;
  99. virtual bool usable(const PoolingKernSizeParam& param) const = 0;
  100. virtual void exec(const PoolingKernParam& param) const = 0;
  101. uint32_t type() const override { return INVALID_ALGO_TYPE; };
  102. bool is_available_attribute(
  103. const PoolingKernSizeParam& param,
  104. const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE,
  105. const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) {
  106. return contain_attribute_all(positive_attr) &&
  107. !contain_attribute_any(negative_attr) && usable(param);
  108. }
  109. };
  110. const char* get_algorithm_set_name() const override {
  111. return "ARM_POOLING_FORWARD";
  112. }
  113. Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override;
  114. std::vector<Algorithm*> get_all_algorithms(
  115. const TensorLayout& src, const TensorLayout& dst) override;
  116. std::vector<Algorithm*> get_all_algorithms_safe(
  117. const TensorLayout& src, const TensorLayout& dst) override;
  118. Algorithm* get_algorithm_heuristic(
  119. const TensorLayout& src, const TensorLayout& dst,
  120. size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
  121. const AlgoAttribute& negative_attr) override;
  122. AlgorithmInfo get_algorithm_info_heuristic(
  123. const TensorLayout& src, const TensorLayout& dst,
  124. size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
  125. const AlgoAttribute& negative_attr) {
  126. return get_algorithm_heuristic(
  127. src, dst, workspace_limit_in_bytes, positive_attr, negative_attr)
  128. ->info();
  129. }
  130. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  131. bool is_fallback_algo(Algorithm* algo) {
  132. return strcmp(algo->name(), "FALLBACK_POOLING") == 0;
  133. }
  134. };
  135. } // namespace arm_common
  136. } // namespace megdnn
  137. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。