You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. /**
  2. * \file dnn/src/cuda/convolution/opr_impl.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/oprs/nn.h"
  14. #include "src/common/utils.h"
  15. namespace megdnn {
  16. namespace cuda {
  17. class ConvolutionForwardImpl : public ConvolutionForward {
  18. public:
  19. using ConvolutionForward::ConvolutionForward;
  20. void exec(
  21. _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
  22. const PreprocessedFilter* preprocessed_filter,
  23. _megdnn_workspace workspace) override;
  24. size_t get_workspace_in_bytes(
  25. const TensorLayout& src, const TensorLayout& filter,
  26. const TensorLayout& dst,
  27. const PreprocessedFilter* preprocessed_filter) override;
  28. const char* get_algorithm_set_name() const override;
  29. SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
  30. const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
  31. return {};
  32. }
  33. size_t get_preprocess_workspace_in_bytes(
  34. const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
  35. return 0;
  36. }
  37. void exec_preprocess(
  38. const TensorLayout&, _megdnn_tensor_in, const TensorLayout&,
  39. PreprocessedFilter*, _megdnn_workspace) override {
  40. megdnn_throw("cuda exec_preprocess has not implemeted yet");
  41. }
  42. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  43. class AlgoBase;
  44. class AlgoDefault;
  45. class AlgoPack;
  46. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  47. protected:
  48. std::vector<Algorithm*> get_all_algorithms(
  49. const TensorLayout& src, const TensorLayout& filter,
  50. const TensorLayout& dst) override;
  51. std::vector<Algorithm*> get_all_algorithms_safe(
  52. const TensorLayout& src, const TensorLayout& filter,
  53. const TensorLayout& dst) override;
  54. Algorithm* get_algorithm_heuristic(
  55. const TensorLayout& src, const TensorLayout& filter,
  56. const TensorLayout& dst, size_t workspace_limit_in_bytes,
  57. const AlgoAttribute& positive_attr,
  58. const AlgoAttribute& negative_attr) override;
  59. private:
  60. static AlgoPack sm_algo_pack;
  61. };
  62. class ConvolutionBackwardDataImpl : public ConvolutionBackwardData {
  63. public:
  64. using ConvolutionBackwardData::ConvolutionBackwardData;
  65. void exec(
  66. _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  67. _megdnn_workspace workspace) override;
  68. AlgorithmInfo get_algorithm_info_heuristic(
  69. const TensorLayout& filter, const TensorLayout& diff,
  70. const TensorLayout& grad, size_t workspace_limit_in_bytes,
  71. const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
  72. return get_algorithm_heuristic(
  73. filter, diff, grad, workspace_limit_in_bytes, positive_attr,
  74. negative_attr)
  75. ->info();
  76. }
  77. size_t get_workspace_in_bytes(
  78. const TensorLayout& filter, const TensorLayout& diff,
  79. const TensorLayout& grad) override;
  80. const char* get_algorithm_set_name() const override;
  81. class AlgoBase;
  82. class AlgoCUDNN;
  83. class AlgoMatmul;
  84. class AlgoChanwise;
  85. class AlgoChanwiseSmall;
  86. class AlgoDepthwiseLargeFilter;
  87. class AlgoGroupConvGeneral;
  88. class AlgoBFloat16;
  89. class AlgoInt8NCHW4DotProdImplicitGemm;
  90. class AlgoInt8NCHWDotProdImplicitGemm;
  91. class AlgoInt8NHWCIMMAImplicitGemm;
  92. class AlgoFloat32NCHWFMAImplicitBatchedGemm;
  93. class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
  94. class AlgoPack;
  95. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  96. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  97. protected:
  98. std::vector<Algorithm*> get_all_algorithms(
  99. const TensorLayout& filter, const TensorLayout& diff,
  100. const TensorLayout& grad) override;
  101. std::vector<Algorithm*> get_all_algorithms_safe(
  102. const TensorLayout& filter, const TensorLayout& diff,
  103. const TensorLayout& grad) override;
  104. Algorithm* get_algorithm_heuristic(
  105. const TensorLayout& filter, const TensorLayout& diff,
  106. const TensorLayout& grad, size_t workspace_limit_in_bytes,
  107. const AlgoAttribute& positive_attr,
  108. const AlgoAttribute& negative_attr) override;
  109. private:
  110. static AlgoPack sm_algo_pack;
  111. };
  112. class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter {
  113. public:
  114. using ConvolutionBackwardFilter::ConvolutionBackwardFilter;
  115. void exec(
  116. _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
  117. _megdnn_workspace workspace) override;
  118. size_t get_workspace_in_bytes(
  119. const TensorLayout& src, const TensorLayout& diff,
  120. const TensorLayout& grad) override;
  121. AlgorithmInfo get_algorithm_info_heuristic(
  122. const TensorLayout& filter, const TensorLayout& diff,
  123. const TensorLayout& grad, size_t workspace_limit_in_bytes,
  124. const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
  125. return get_algorithm_heuristic(
  126. filter, diff, grad, workspace_limit_in_bytes, positive_attr,
  127. negative_attr)
  128. ->info();
  129. }
  130. const char* get_algorithm_set_name() const override;
  131. class AlgoBase;
  132. class AlgoCUDNN;
  133. class AlgoMatmul;
  134. class AlgoChanwise;
  135. class AlgoGroupConvGeneral;
  136. class AlgoBFloat16;
  137. class AlgoFloat32NCHWFMAImplicitBatchedGemm;
  138. class AlgoFloat16NCHWHMMAImplicitBatchedGemm;
  139. class AlgoPack;
  140. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  141. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  142. protected:
  143. std::vector<Algorithm*> get_all_algorithms(
  144. const TensorLayout& src, const TensorLayout& diff,
  145. const TensorLayout& grad) override;
  146. std::vector<Algorithm*> get_all_algorithms_safe(
  147. const TensorLayout& src, const TensorLayout& diff,
  148. const TensorLayout& grad) override;
  149. Algorithm* get_algorithm_heuristic(
  150. const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
  151. size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
  152. const AlgoAttribute& negative_attr) override;
  153. private:
  154. static AlgoPack sm_algo_pack;
  155. };
  156. } // namespace cuda
  157. } // namespace megdnn
  158. // vim: syntax=cpp.doxygen