You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_trait.h 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #pragma once
  2. #include "megdnn/oprs.h"
  3. #include <cstddef>
  4. namespace megdnn {
  5. template <typename Opr>
  6. struct OprTrait {};
  7. #define DEF(Name, Arity, HasWorkspace, CanDeduceLayout) \
  8. template <> \
  9. struct OprTrait<Name> { \
  10. static const size_t arity = Arity; \
  11. static const bool has_workspace = HasWorkspace; \
  12. static const bool can_deduce_layout = CanDeduceLayout; \
  13. }
  14. DEF(Norm, 2, true, true);
  15. DEF(Padding, 2, false, true);
  16. DEF(PaddingBackward, 2, false, false);
  17. DEF(ConvolutionForward, 3, true, true);
  18. DEF(Convolution3DForward, 3, true, true);
  19. DEF(ConvolutionBackwardData, 3, true, false);
  20. DEF(ConvolutionBackwardFilter, 3, true, false);
  21. DEF(Convolution3DBackwardData, 3, true, false);
  22. DEF(Convolution3DBackwardFilter, 3, true, false);
  23. DEF(ConvPoolingForward, 4, true, true);
  24. DEF(ConvBiasForward, 5, true, true);
  25. DEF(SeparableConvForward, 4, true, true);
  26. DEF(SeparableFilterForward, 4, true, true);
  27. DEF(Images2NeibsForward, 2, true, true);
  28. DEF(Images2NeibsBackward, 2, true, false);
  29. DEF(SlidingWindowTransposeForward, 2, true, true);
  30. DEF(SlidingWindowTransposeBackward, 2, true, false);
  31. DEF(PoolingForward, 2, true, true);
  32. DEF(PoolingBackward, 4, true, false);
  33. DEF(AdaptivePoolingForward, 2, true, false);
  34. DEF(AdaptivePoolingBackward, 4, true, false);
  35. DEF(LocalForward, 3, true, true);
  36. DEF(LocalBackwardData, 3, true, false);
  37. DEF(LocalBackwardFilter, 3, true, false);
  38. DEF(GroupLocalForward, 3, true, true);
  39. DEF(GroupLocalBackwardData, 3, true, false);
  40. DEF(GroupLocalBackwardFilter, 3, true, false);
  41. DEF(LRNForward, 2, true, true);
  42. DEF(LRNBackward, 4, true, false);
  43. DEF(BNForward, 9, true, true);
  44. DEF(BNBackward, 9, true, false);
  45. DEF(ROIPoolingForward, 4, true, false);
  46. DEF(ROIPoolingBackward, 5, true, false);
  47. DEF(CorrelationForward, 3, true, true);
  48. DEF(CorrelationBackwardData1, 4, true, true);
  49. DEF(CorrelationBackwardData2, 4, true, true);
  50. DEF(WarpPerspectiveForward, 3, true, false);
  51. DEF(WarpPerspectiveBackwardData, 3, true, false);
  52. DEF(WarpPerspectiveBackwardMat, 4, true, false);
  53. DEF(AddUpdateForward, 2, false, false);
  54. DEF(DotForward, 3, true, true);
  55. DEF(MatrixMulForward, 3, true, true);
  56. DEF(BatchedMatrixMulForward, 3, true, true);
  57. DEF(MatrixInverse, 2, true, true);
  58. DEF(SVDForward, 4, true, true);
  59. DEF(ReduceForward, 2, true, true);
  60. DEF(CumsumForward, 2, true, true);
  61. DEF(ArgmaxForward, 2, true, true);
  62. DEF(ArgminForward, 2, true, true);
  63. DEF(TransposeForward, 2, true, true);
  64. DEF(RelayoutForward, 2, false, false);
  65. DEF(TileForward, 2, true, true);
  66. DEF(TileBackward, 2, true, false);
  67. DEF(RepeatForward, 2, true, true);
  68. DEF(RepeatBackward, 2, true, false);
  69. DEF(ArgsortForward, 3, true, true);
  70. DEF(ArgsortBackward, 3, true, false);
  71. DEF(TypeCvtForward, 2, false, false);
  72. DEF(IndexingRemapForward, 3, true, true);
  73. DEF(IndexingRemapBackward, 3, true, false);
  74. DEF(Linspace, 1, true, false);
  75. DEF(Eye, 1, true, false);
  76. DEF(Diag, 2, true, true);
  77. DEF(Flip, 2, true, true);
  78. DEF(ROICopy, 2, true, true);
  79. DEF(Rotate, 2, true, true);
  80. DEF(CvtColor, 2, true, true);
  81. DEF(WarpAffine, 3, true, false);
  82. DEF(GaussianBlur, 2, true, true);
  83. DEF(Resize, 2, true, false);
  84. DEF(ResizeBackward, 2, true, false);
  85. DEF(IndexingOneHot, 3, true, true);
  86. DEF(IndexingSetOneHot, 3, true, false);
  87. DEF(MaskConvolution, 4, true, true);
  88. DEF(MaskPropagate, 2, true, true);
  89. DEF(RelayoutFormat, 2, true, true);
  90. DEF(MaxTensorDiff, 2, true, false);
  91. DEF(LocalShareForward, 3, true, true);
  92. DEF(LocalShareBackwardData, 3, true, false);
  93. DEF(LocalShareBackwardFilter, 3, true, false);
  94. DEF(ROIAlignForward, 4, true, true);
  95. DEF(ROIAlignBackward, 4, true, false);
  96. DEF(DeformableConvForward, 5, true, true);
  97. DEF(DeformableConvBackwardFilter, 5, true, false);
  98. DEF(DeformableConvBackwardData, 8, true, false);
  99. DEF(DeformablePSROIPoolingForward, 5, true, true);
  100. DEF(DeformablePSROIPoolingBackward, 7, true, false);
  101. DEF(BatchConvBiasForward, 5, true, true);
  102. DEF(Remap, 3, true, true);
  103. DEF(RemapBackwardData, 3, true, false);
  104. DEF(RemapBackwardMat, 4, true, false);
  105. DEF(DctChannelSelectForward, 4, true, true);
  106. DEF(FakeQuantForward, 4, true, true);
  107. DEF(FakeQuantBackward, 5, true, false);
  108. DEF(TQTForward, 3, true, true);
  109. DEF(TQTBackward, 5, true, false);
  110. DEF(PowC, 2, false, true);
  111. DEF(UniformRNG, 1, true, true);
  112. DEF(GaussianRNG, 1, true, true);
  113. DEF(GammaRNG, 3, true, true);
  114. DEF(BetaRNG, 3, true, true);
  115. DEF(PoissonRNG, 2, true, true);
  116. DEF(PermutationRNG, 1, true, true);
  117. DEF(ShuffleRNGForward, 3, true, true);
  118. DEF(ShuffleRNGBackward, 3, true, false);
  119. DEF(ChecksumForward, 1, true, false);
  120. DEF(CheckNonFinite, 2, true, true);
  121. DEF(LSQForward, 5, true, true);
  122. DEF(LSQBackward, 7, true, false);
  123. DEF(Fill, 1, true, false);
  124. DEF(LayerNormForward, 6, true, true);
  125. DEF(LayerNormBackward, 8, true, true);
  126. DEF(LAMBUpdate, 7, true, true);
  127. DEF(DropoutForward, 3, true, true);
  128. DEF(DropoutBackward, 3, true, true);
  129. DEF(RNNCellForward, 7, true, true);
  130. DEF(RNNForward, 6, true, true);
  131. DEF(RNNBackward, 10, true, true);
  132. DEF(LSTMCellForward, 10, true, true);
  133. DEF(LSTMForward, 8, true, true);
  134. DEF(LSTMBackward, 13, true, true);
  135. DEF(SoftmaxForward, 2, true, true);
  136. DEF(SoftmaxBackward, 3, true, false);
  137. } // namespace megdnn
  138. // vim: syntax=cpp.doxygen