You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. /**
  2. * \file dnn/src/cuda/deformable_conv/opr_impl.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/oprs/nn.h"
  14. namespace megdnn {
  15. namespace cuda {
  16. class DeformableConvForwardImpl : public DeformableConvForward {
  17. public:
  18. using DeformableConvForward::DeformableConvForward;
  19. void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
  20. _megdnn_tensor_in offset, _megdnn_tensor_in mask,
  21. _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
  22. size_t get_workspace_in_bytes(const TensorLayout& im,
  23. const TensorLayout& filter,
  24. const TensorLayout& offset,
  25. const TensorLayout& mask,
  26. const TensorLayout& dst) override;
  27. Algorithm* get_algorithm_heuristic(const TensorLayout& im,
  28. const CanonizedFilterMeta& filter,
  29. const TensorLayout& offset,
  30. const TensorLayout& mask,
  31. const TensorLayout& dst,
  32. size_t workspace_limit_in_bytes,
  33. const AlgoAttribute& positive_attr,
  34. const AlgoAttribute& negative_attr);
  35. const char* get_algorithm_set_name() const override;
  36. class AlgoBase;
  37. class AlgoMatmul;
  38. class AlgoPack;
  39. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  40. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  41. protected:
  42. std::vector<Algorithm*> get_all_algorithms(
  43. const TensorLayout& im, const TensorLayout& filter,
  44. const TensorLayout& offset, const TensorLayout& mask,
  45. const TensorLayout& dst) override;
  46. std::vector<Algorithm*> get_all_algorithms_safe(
  47. const TensorLayout& im, const TensorLayout& filter,
  48. const TensorLayout& offset, const TensorLayout& mask,
  49. const TensorLayout& dst) override;
  50. Algorithm* get_algorithm_heuristic(
  51. const TensorLayout& im, const TensorLayout& filter,
  52. const TensorLayout& offset, const TensorLayout& mask,
  53. const TensorLayout& dst, size_t workspace_limit_in_bytes,
  54. const AlgoAttribute& positive_attr,
  55. const AlgoAttribute& negative_attr) override;
  56. private:
  57. static AlgoPack sm_algo_pack;
  58. };
  59. class DeformableConvBackwardFilterImpl : public DeformableConvBackwardFilter {
  60. public:
  61. using DeformableConvBackwardFilter::DeformableConvBackwardFilter;
  62. void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset,
  63. _megdnn_tensor_in mask, _megdnn_tensor_in out_grad,
  64. _megdnn_tensor_out filter_grad,
  65. _megdnn_workspace workspace) override;
  66. Algorithm* get_algorithm_heuristic(const TensorLayout& im,
  67. const TensorLayout& offset,
  68. const TensorLayout& mask,
  69. const TensorLayout& out_grad,
  70. const CanonizedFilterMeta& filter_grad,
  71. size_t workspace_limit_in_bytes,
  72. const AlgoAttribute& positive_attr,
  73. const AlgoAttribute& negative_attr);
  74. size_t get_workspace_in_bytes(const TensorLayout& im,
  75. const TensorLayout& offset,
  76. const TensorLayout& mask,
  77. const TensorLayout& out_grad,
  78. const TensorLayout& filter_grad) override;
  79. const char* get_algorithm_set_name() const override;
  80. class AlgoBase;
  81. class AlgoMatmul;
  82. class AlgoPack;
  83. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  84. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  85. protected:
  86. std::vector<Algorithm*> get_all_algorithms(
  87. const TensorLayout& im, const TensorLayout& offset,
  88. const TensorLayout& mask, const TensorLayout& out_grad,
  89. const TensorLayout& filter_grad) override;
  90. std::vector<Algorithm*> get_all_algorithms_safe(
  91. const TensorLayout& im, const TensorLayout& offset,
  92. const TensorLayout& mask, const TensorLayout& out_grad,
  93. const TensorLayout& filter_grad) override;
  94. Algorithm* get_algorithm_heuristic(
  95. const TensorLayout& im, const TensorLayout& offset,
  96. const TensorLayout& mask, const TensorLayout& out_grad,
  97. const TensorLayout& filter_grad, size_t workspace_limit_in_bytes,
  98. const AlgoAttribute& positive_attr,
  99. const AlgoAttribute& negative_attr) override;
  100. private:
  101. static AlgoPack sm_algo_pack;
  102. };
  103. class DeformableConvBackwardDataImpl : public DeformableConvBackwardData {
  104. public:
  105. using DeformableConvBackwardData::DeformableConvBackwardData;
  106. void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter,
  107. _megdnn_tensor_in offset, _megdnn_tensor_in mask,
  108. _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad,
  109. _megdnn_tensor_out offset_grad, _megdnn_tensor_out mask_grad,
  110. _megdnn_workspace workspace) override;
  111. Algorithm* get_algorithm_heuristic(
  112. const TensorLayout& im, const CanonizedFilterMeta& filter,
  113. const TensorLayout& offset, const TensorLayout& mask,
  114. const TensorLayout& out_grad, const TensorLayout& im_grad,
  115. const TensorLayout& offset_grad, const TensorLayout& mask_grad,
  116. size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
  117. const AlgoAttribute& negative_attr);
  118. size_t get_workspace_in_bytes(const TensorLayout& im,
  119. const TensorLayout& filter,
  120. const TensorLayout& offset,
  121. const TensorLayout& mask,
  122. const TensorLayout& out_grad,
  123. const TensorLayout& im_grad,
  124. const TensorLayout& offset_grad,
  125. const TensorLayout& mask_grad) override;
  126. const char* get_algorithm_set_name() const override;
  127. class AlgoBase;
  128. class AlgoMatmul;
  129. class AlgoPack;
  130. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  131. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  132. protected:
  133. std::vector<Algorithm*> get_all_algorithms(
  134. const TensorLayout& im, const TensorLayout& filter,
  135. const TensorLayout& offset, const TensorLayout& mask,
  136. const TensorLayout& out_grad, const TensorLayout& im_grad,
  137. const TensorLayout& offset_grad,
  138. const TensorLayout& mask_grad) override;
  139. std::vector<Algorithm*> get_all_algorithms_safe(
  140. const TensorLayout& im, const TensorLayout& filter,
  141. const TensorLayout& offset, const TensorLayout& mask,
  142. const TensorLayout& out_grad, const TensorLayout& im_grad,
  143. const TensorLayout& offset_grad,
  144. const TensorLayout& mask_grad) override;
  145. Algorithm* get_algorithm_heuristic(
  146. const TensorLayout& im, const TensorLayout& filter,
  147. const TensorLayout& offset, const TensorLayout& mask,
  148. const TensorLayout& out_grad, const TensorLayout& im_grad,
  149. const TensorLayout& offset_grad, const TensorLayout& mask_grad,
  150. size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
  151. const AlgoAttribute& negative_attr) override;
  152. private:
  153. static AlgoPack sm_algo_pack;
  154. };
  155. } // namespace cuda
  156. } // namespace megdnn
  157. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台。