You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_trait.h 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * \file dnn/src/common/opr_trait.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "megdnn/oprs.h"
  14. #include <cstddef>
  15. namespace megdnn {
  16. template <typename Opr>
  17. struct OprTrait {};
  18. #define DEF(Name, Arity, HasWorkspace, CanDeduceLayout) \
  19. template <> \
  20. struct OprTrait<Name> { \
  21. static const size_t arity = Arity; \
  22. static const bool has_workspace = HasWorkspace; \
  23. static const bool can_deduce_layout = CanDeduceLayout; \
  24. }
  25. DEF(Padding, 2, false, true);
  26. DEF(PaddingBackward, 2, false, false);
  27. DEF(ConvolutionForward, 3, true, true);
  28. DEF(Convolution3DForward, 3, true, true);
  29. DEF(ConvolutionBackwardData, 3, true, false);
  30. DEF(ConvolutionBackwardFilter, 3, true, false);
  31. DEF(Convolution3DBackwardData, 3, true, false);
  32. DEF(Convolution3DBackwardFilter, 3, true, false);
  33. DEF(ConvPoolingForward, 4, true, true);
  34. DEF(ConvBiasForward, 5, true, true);
  35. DEF(SeparableConvForward, 4, true, true);
  36. DEF(SeparableFilterForward, 4, true, true);
  37. DEF(Images2NeibsForward, 2, true, true);
  38. DEF(Images2NeibsBackward, 2, true, false);
  39. DEF(SlidingWindowTransposeForward, 2, true, true);
  40. DEF(SlidingWindowTransposeBackward, 2, true, false);
  41. DEF(PoolingForward, 2, true, true);
  42. DEF(PoolingBackward, 4, true, false);
  43. DEF(AdaptivePoolingForward, 2, true, false);
  44. DEF(AdaptivePoolingBackward, 4, true, false);
  45. DEF(LocalForward, 3, true, true);
  46. DEF(LocalBackwardData, 3, true, false);
  47. DEF(LocalBackwardFilter, 3, true, false);
  48. DEF(GroupLocalForward, 3, true, true);
  49. DEF(GroupLocalBackwardData, 3, true, false);
  50. DEF(GroupLocalBackwardFilter, 3, true, false);
  51. DEF(LRNForward, 2, true, true);
  52. DEF(LRNBackward, 4, true, false);
  53. DEF(BNForward, 9, true, true);
  54. DEF(BNBackward, 9, true, false);
  55. DEF(ROIPoolingForward, 4, true, false);
  56. DEF(ROIPoolingBackward, 5, true, false);
  57. DEF(CorrelationForward, 3, true, true);
  58. DEF(CorrelationBackwardData1, 4, true, true);
  59. DEF(CorrelationBackwardData2, 4, true, true);
  60. DEF(WarpPerspectiveForward, 3, true, false);
  61. DEF(WarpPerspectiveBackwardData, 3, true, false);
  62. DEF(WarpPerspectiveBackwardMat, 4, true, false);
  63. DEF(AddUpdateForward, 2, false, false);
  64. DEF(DotForward, 3, true, true);
  65. DEF(MatrixMulForward, 3, true, true);
  66. DEF(BatchedMatrixMulForward, 3, true, true);
  67. DEF(MatrixInverse, 2, true, true);
  68. DEF(SVDForward, 4, true, true);
  69. DEF(ReduceForward, 2, true, true);
  70. DEF(CumsumForward, 2, true, true);
  71. DEF(ArgmaxForward, 2, true, true);
  72. DEF(ArgminForward, 2, true, true);
  73. DEF(TransposeForward, 2, true, true);
  74. DEF(RelayoutForward, 2, false, false);
  75. DEF(TileForward, 2, true, true);
  76. DEF(TileBackward, 2, true, false);
  77. DEF(RepeatForward, 2, true, true);
  78. DEF(RepeatBackward, 2, true, false);
  79. DEF(ArgsortForward, 3, true, true);
  80. DEF(ArgsortBackward, 3, true, false);
  81. DEF(TypeCvtForward, 2, false, false);
  82. DEF(IndexingRemapForward, 3, true, true);
  83. DEF(IndexingRemapBackward, 3, true, false);
  84. DEF(Linspace, 1, true, false);
  85. DEF(Eye, 1, true, false);
  86. DEF(Flip, 2, true, true);
  87. DEF(ROICopy, 2, true, true);
  88. DEF(Rotate, 2, true, true);
  89. DEF(CvtColor, 2, true, true);
  90. DEF(WarpAffine, 3, true, false);
  91. DEF(GaussianBlur, 2, true, true);
  92. DEF(Resize, 2, true, false);
  93. DEF(ResizeBackward, 2, true, false);
  94. DEF(IndexingOneHot, 3, true, true);
  95. DEF(IndexingSetOneHot, 3, true, false);
  96. DEF(MaskConvolution, 4, true, true);
  97. DEF(MaskPropagate, 2, true, true);
  98. DEF(RelayoutFormat, 2, true, true);
  99. DEF(MaxTensorDiff, 2, true, false);
  100. DEF(LocalShareForward, 3, true, true);
  101. DEF(LocalShareBackwardData, 3, true, false);
  102. DEF(LocalShareBackwardFilter, 3, true, false);
  103. DEF(ROIAlignForward, 4, true, true);
  104. DEF(ROIAlignBackward, 4, true, false);
  105. DEF(DeformableConvForward, 5, true, true);
  106. DEF(DeformableConvBackwardFilter, 5, true, false);
  107. DEF(DeformableConvBackwardData, 8, true, false);
  108. DEF(DeformablePSROIPoolingForward, 5, true, true);
  109. DEF(DeformablePSROIPoolingBackward, 7, true, false);
  110. DEF(BatchConvBiasForward, 5, true, true);
  111. DEF(Remap, 3, true, true);
  112. DEF(RemapBackwardData, 3, true, false);
  113. DEF(RemapBackwardMat, 4, true, false);
  114. DEF(DctChannelSelectForward, 4, true, true);
  115. DEF(FakeQuantForward, 4, true, true);
  116. DEF(FakeQuantBackward, 5, true, false);
  117. DEF(TQTForward, 3, true, true);
  118. DEF(TQTBackward, 5, true, false);
  119. DEF(PowC, 2, false, true);
  120. DEF(UniformRNG, 1, true, true);
  121. DEF(GaussianRNG, 1, true, true);
  122. DEF(GammaRNG, 3, true, true);
  123. DEF(BetaRNG, 3, true, true);
  124. DEF(PoissonRNG, 2, true, true);
  125. DEF(PermutationRNG, 1, true, true);
  126. DEF(ShuffleRNGForward, 3, true, true);
  127. DEF(ShuffleRNGBackward, 3, true, false);
  128. DEF(ChecksumForward, 1, true, false);
  129. DEF(CheckNonFinite, 2, true, true);
  130. DEF(LSQForward, 5, true, true);
  131. DEF(LSQBackward, 7, true, false);
  132. DEF(Fill, 1, true, false);
  133. DEF(LayerNormForward, 6, true, true);
  134. DEF(LayerNormBackward, 8, true, true);
  135. DEF(DropoutForward, 3, true, true);
  136. DEF(DropoutBackward, 3, true, true);
  137. DEF(RNNCellForward, 7, true, true);
  138. DEF(RNNForward, 6, true, true);
  139. DEF(RNNBackward, 10, true, true);
  140. DEF(LSTMCellForward, 10, true, true);
  141. DEF(LSTMForward, 8, true, true);
  142. DEF(LSTMBackward, 13, true, true);
  143. DEF(SoftmaxForward, 2, true, true);
  144. DEF(SoftmaxBackward, 3, true, false);
  145. } // namespace megdnn
  146. // vim: syntax=cpp.doxygen