
handle_impl.h 6.1 kB

/**
 * \file dnn/src/common/handle_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megdnn/handle.h"
#include "megdnn/oprs.h"
#include "src/common/utils.h"

#include <mutex>

namespace megdnn {

class HandleImplHelper : public Handle {
public:
    using Handle::Handle;

    //! global matmul opr
    virtual MatrixMul* matmul_opr() {
        megdnn_throw("Unimplemented matmul opr.\n");
    }
    //! global matmul opr with first operand transposed
    virtual MatrixMul* matmul_aT_opr() {
        megdnn_throw("Unimplemented matmul_aT opr.\n");
    }
    //! global matmul opr with second operand transposed
    virtual MatrixMul* matmul_bT_opr() {
        megdnn_throw("Unimplemented matmul_bT opr.\n");
    }
    //! global matmul opr with both operands transposed
    virtual MatrixMul* matmul_aT_bT_opr() {
        megdnn_throw("Unimplemented matmul_aT_bT opr.\n");
    }
    //! global relayout opr
    virtual Relayout* relayout_opr() {
        megdnn_throw("Unimplemented Relayout opr.\n");
    }
    virtual Checksum* checksum_opr() {
        megdnn_throw("Unimplemented Checksum opr.\n");
    }
    virtual MaxTensorDiff* max_tensor_diff_opr() {
        megdnn_throw("Unimplemented MaxTensorDiff opr.\n");
    }

protected:
    static constexpr size_t NR_HELPER_OPRS = 7;
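
    //! Lazily create and cache the helper operator stored at slot \p idx;
    //! double-checked locking keeps the operator from being constructed more
    //! than once when multiple threads race here.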
    template <class Opr, size_t idx, class Self>
    static Opr* get_helper_opr(Self self, const typename Opr::Param& param = {}) {
        static_assert(idx < NR_HELPER_OPRS, "invalid idx");
        if (!self->m_helper_oprs[idx]) {
            std::lock_guard<std::mutex> lg{self->m_helper_oprs_mtx};
            if (!self->m_helper_oprs[idx]) {
                self->m_helper_oprs[idx] = self->template create_operator<Opr>();
                auto ret = static_cast<Opr*>(self->m_helper_oprs[idx].get());
                ret->param() = param;
                megdnn_assert(ret->is_thread_safe());
                return ret;
            }
        }
        return static_cast<Opr*>(self->m_helper_oprs[idx].get());
    }

private:
    std::array<std::unique_ptr<OperatorBase>, NR_HELPER_OPRS> m_helper_oprs;
    std::mutex m_helper_oprs_mtx;
};

}  // namespace megdnn
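
/*
 * Illustrative sketch (not part of this header): a backend-specific handle
 * typically overrides the helper-opr getters above and forwards to
 * get_helper_opr<>(), picking a distinct cache slot per helper, e.g.
 *
 *     MatrixMul* matmul_opr() override {
 *         return get_helper_opr<MatrixMul, 0>(this);
 *     }
 *
 * The slot index shown here is only for illustration; the requirement is
 * that each helper uses a unique index smaller than NR_HELPER_OPRS.
 */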

/*!
 * \brief iterate through each operator class name; useful for explicit
 * instantiation of create_operator<> templates
 */
#define MEGDNN_FOREACH_OPR_CLASS(cb) \
    cb(ConvolutionForward) \
    cb(ConvolutionBackwardData) \
    cb(ConvolutionBackwardFilter) \
    cb(ConvPoolingForward) \
    cb(ConvBiasForward) \
    cb(Images2NeibsForward) \
    cb(Images2NeibsBackward) \
    cb(ElemwiseForward) \
    cb(ElemwiseMultiType) \
    cb(AddUpdateForward) \
    cb(RelayoutForward) \
    cb(PoolingForward) \
    cb(PoolingBackward) \
    cb(LocalForward) \
    cb(LocalBackwardData) \
    cb(LocalBackwardFilter) \
    cb(LRNForward) \
    cb(LRNBackward) \
    cb(ROIPoolingForward) \
    cb(ROIPoolingBackward) \
    cb(WarpPerspectiveForward) \
    cb(WarpPerspectiveBackwardData) \
    cb(WarpPerspectiveBackwardMat) \
    cb(DotForward) \
    cb(MatrixInverse) \
    cb(MatrixMulForward) \
    cb(BatchedMatrixMulForward) \
    cb(SVDForward) \
    cb(ReduceForward) \
    cb(CondTake) \
    cb(CumsumForward) \
    cb(ArgmaxForward) \
    cb(ArgminForward) \
    cb(TransposeForward) \
    cb(ConcatForward) \
    cb(SplitForward) \
    cb(TileForward) \
    cb(TileBackward) \
    cb(RepeatForward) \
    cb(RepeatBackward) \
    cb(ArgsortForward) \
    cb(ArgsortBackward) \
    cb(TypeCvt) \
    cb(IndexingRemapForward) \
    cb(IndexingRemapBackward) \
    cb(ChecksumForward) \
    cb(IndexingOneHotForward) \
    cb(IndexingSetOneHotForward) \
    cb(IndexingMultiAxisVec) \
    cb(IndexingSetMultiAxisVec) \
    cb(IndexingIncrMultiAxisVec) \
    cb(MeshIndexing) \
    cb(IncrMeshIndexing) \
    cb(SetMeshIndexing) \
    cb(BatchedMeshIndexing) \
    cb(BatchedIncrMeshIndexing) \
    cb(BatchedSetMeshIndexing) \
    cb(Linspace) \
    cb(Eye) \
    cb(SleepForward) \
    cb(UniformRNG) \
    cb(GaussianRNG) \
    cb(SeparableConvForward) \
    cb(SeparableFilterForward) \
    cb(BNForward) \
    cb(BNBackward) \
    cb(GroupLocalForward) \
    cb(GroupLocalBackwardData) \
    cb(GroupLocalBackwardFilter) \
    cb(Flip) \
    cb(Rotate) \
    cb(ROICopy) \
    cb(CvtColor) \
    cb(WarpAffine) \
    cb(GaussianBlur) \
    cb(Resize) \
    cb(ResizeBackward) \
    cb(ParamPackConcat) \
    cb(MaxTensorDiff) \
    cb(MaskConvForward) \
    cb(MaskPropagate) \
    cb(Convolution3DForward) \
    cb(Convolution3DBackwardData) \
    cb(Convolution3DBackwardFilter) \
    cb(DeformableConvForward) \
    cb(DeformableConvBackwardFilter) \
    cb(DeformableConvBackwardData) \
    cb(DeformablePSROIPoolingForward) \
    cb(DeformablePSROIPoolingBackward) \
    cb(RelayoutFormat) \
    cb(TopK) \
    cb(PowC) \
    cb(WinogradFilterPreprocess) \
    cb(LocalShareForward) \
    cb(LocalShareBackwardData) \
    cb(LocalShareBackwardFilter) \
    cb(ROIAlignForward) \
    cb(ROIAlignBackward) \
    cb(BatchConvBiasForward) \
    cb(Remap)

/*!
 * \brief specialize HandleImpl::create_operator for a single opr type;
 * implemented by <opr>Impl class
 */
#define MEGDNN_SPECIALIZE_CREATE_OPERATOR(opr)                   \
    template <>                                                  \
    std::unique_ptr<megdnn::opr> HandleImpl::create_operator() { \
        return megdnn::make_unique<opr##Impl>(this);             \
    }

/*!
 * \brief explicit instantiation of HandleImpl::create_operator methods
 */
#define MEGDNN_INST_CREATE_OPERATOR(opr) \
    template std::unique_ptr<megdnn::opr> HandleImpl::create_operator();
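
/*
 * Sketch of how these macros are typically combined in a backend's handle
 * implementation file (the exact code per backend may differ):
 *
 *     // emit one create_operator<Opr> specialization per operator class,
 *     // each returning the backend-specific <Opr>Impl
 *     MEGDNN_FOREACH_OPR_CLASS(MEGDNN_SPECIALIZE_CREATE_OPERATOR)
 *
 *     // or, when a generic template definition is used instead, force its
 *     // explicit instantiation for every operator class
 *     MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR)
 */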

// vim: syntax=cpp.doxygen
