You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_trait.h 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. /**
  2. * \file dnn/test/common/opr_trait.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #pragma once
  12. #include "megdnn/oprs.h"
  13. #include <cstddef>
  14. namespace megdnn {
  15. namespace test {
  16. template <typename Opr>
  17. struct OprTrait {};
  18. #define DEF(Name, Arity, HasWorkspace, CanDeduceLayout) \
  19. template <> \
  20. struct OprTrait<Name> { \
  21. static const size_t arity = Arity; \
  22. static const bool has_workspace = HasWorkspace; \
  23. static const bool can_deduce_layout = CanDeduceLayout; \
  24. }
  25. DEF(ConvolutionForward, 3, true, true);
  26. DEF(Convolution3DForward, 3, true, true);
  27. DEF(ConvolutionBackwardData, 3, true, false);
  28. DEF(ConvolutionBackwardFilter, 3, true, false);
  29. DEF(Convolution3DBackwardData, 3, true, false);
  30. DEF(Convolution3DBackwardFilter, 3, true, false);
  31. DEF(ConvPoolingForward, 4, true, true);
  32. DEF(ConvBiasForward, 5, true, true);
  33. DEF(SeparableConvForward, 4, true, true);
  34. DEF(SeparableFilterForward, 4, true, true);
  35. DEF(Images2NeibsForward, 2, true, true);
  36. DEF(Images2NeibsBackward, 2, true, false);
  37. DEF(PoolingForward, 2, true, true);
  38. DEF(PoolingBackward, 4, true, false);
  39. DEF(AdaptivePoolingForward, 2, true, false);
  40. DEF(AdaptivePoolingBackward, 4, true, false);
  41. DEF(LocalForward, 3, true, true);
  42. DEF(LocalBackwardData, 3, true, false);
  43. DEF(LocalBackwardFilter, 3, true, false);
  44. DEF(GroupLocalForward, 3, true, true);
  45. DEF(GroupLocalBackwardData, 3, true, false);
  46. DEF(GroupLocalBackwardFilter, 3, true, false);
  47. DEF(LRNForward, 2, true, true);
  48. DEF(LRNBackward, 4, true, false);
  49. DEF(BNForward, 8, true, true);
  50. DEF(BNBackward, 8, true, false);
  51. DEF(ROIPoolingForward, 4, true, false);
  52. DEF(ROIPoolingBackward, 5, true, false);
  53. DEF(WarpPerspectiveForward, 3, true, false);
  54. DEF(WarpPerspectiveBackwardData, 3, true, false);
  55. DEF(WarpPerspectiveBackwardMat, 4, true, false);
  56. DEF(AddUpdateForward, 2, false, false);
  57. DEF(DotForward, 3, true, true);
  58. DEF(MatrixMulForward, 3, true, true);
  59. DEF(BatchedMatrixMulForward, 3, true, true);
  60. DEF(MatrixInverse, 2, true, true);
  61. DEF(SVDForward, 4, true, true);
  62. DEF(ReduceForward, 2, true, true);
  63. DEF(CumsumForward, 2, true, true);
  64. DEF(ArgmaxForward, 2, true, true);
  65. DEF(ArgminForward, 2, true, true);
  66. DEF(TransposeForward, 2, true, true);
  67. DEF(RelayoutForward, 2, false, false);
  68. DEF(TileForward, 2, true, true);
  69. DEF(TileBackward, 2, true, false);
  70. DEF(RepeatForward, 2, true, true);
  71. DEF(RepeatBackward, 2, true, false);
  72. DEF(ArgsortForward, 3, true, true);
  73. DEF(ArgsortBackward, 3, true, false);
  74. DEF(TypeCvtForward, 2, false, false);
  75. DEF(IndexingRemapForward, 3, true, true);
  76. DEF(IndexingRemapBackward, 3, true, false);
  77. DEF(Linspace, 1, true, false);
  78. DEF(Eye, 1, true, false);
  79. DEF(Flip, 2, true, true);
  80. DEF(ROICopy, 2, true, true);
  81. DEF(Rotate, 2, true, true);
  82. DEF(CvtColor, 2, true, true);
  83. DEF(WarpAffine, 3, true, false);
  84. DEF(GaussianBlur, 2, true, true);
  85. DEF(Resize, 2, true, false);
  86. DEF(ResizeBackward, 2, true, false);
  87. DEF(IndexingOneHot, 3, true, true);
  88. DEF(IndexingSetOneHot, 3, true, false);
  89. DEF(MaskConvolution, 4, true, true);
  90. DEF(MaskPropagate, 2, true, true);
  91. DEF(RelayoutFormat, 2, true, true);
  92. DEF(MaxTensorDiff, 2, true, false);
  93. DEF(WinogradFilterPreprocess, 2, true, true);
  94. DEF(LocalShareForward, 3, true, true);
  95. DEF(LocalShareBackwardData, 3, true, false);
  96. DEF(LocalShareBackwardFilter, 3, true, false);
  97. DEF(ROIAlignForward, 4, true, true);
  98. DEF(ROIAlignBackward, 4, true, false);
  99. DEF(DeformableConvForward, 5, true, true);
  100. DEF(DeformableConvBackwardFilter, 5, true, false);
  101. DEF(DeformableConvBackwardData, 8, true, false);
  102. DEF(DeformablePSROIPoolingForward, 5, true, true);
  103. DEF(DeformablePSROIPoolingBackward, 7, true, false);
  104. DEF(BatchConvBiasForward, 5, true, true);
  105. DEF(Remap, 3, true, true);
  106. DEF(RemapBackwardData, 3, true, false);
  107. DEF(RemapBackwardMat, 4, true, false);
  108. } // namespace test
  109. } // namespace megdnn
  110. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。