You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.td 8.8 kB

Lines 1–308
  1. /**
  2. * \file src/core/include/megbrain/ir/ops.td
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #ifndef MGB_OPS
  13. #define MGB_OPS
  14. include "base.td"
  15. include "param_defs.td"
  16. include "mlir/Interfaces/SideEffectInterfaces.td"
  // Elemwise: element-wise op whose attributes come from ElemwiseParam.
  // Takes a variadic list of inputs, yields a single result, and carries
  // the NoSideEffect trait.
  def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
    let inputs = (ins Variadic<AnyType>:$input);
    let results = (outs AnyType);

    // Display name is to_string() of the `mode` attribute (presumably the
    // mode field of ElemwiseParam -- confirm in param_defs.td).
    let nameFunction = [{
      return to_string($_self.mode);
    }];
  }
  // Reduce: attributes come solely from ReduceParam.
  def Reduce : MgbHashableOp<"Reduce", [ReduceParam]>;
  // TypeCvt: one input, one result, NoSideEffect. It declares no param
  // struct; its attributes are the two extra arguments below.
  def TypeCvt : MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
    let inputs = (ins AnyType:$inputs);
    let results = (outs AnyType);

    // $idtype is a TypeAttr and $dtype an MgbDTypeAttr; presumably the
    // input type and the target dtype respectively -- TODO confirm.
    let extraArguments = (ins TypeAttr:$idtype, MgbDTypeAttr:$dtype);
  }
  // Ops defined purely by their param structs (no extra arguments). Entries
  // listing ExecutionPolicyParamBase<"policy"> also mix in an execution-policy
  // param instantiated with the string "policy" (see the included .td files
  // for what that string controls).
  def MatrixInverse : MgbHashableOp<"MatrixInverse", [EmptyParam]>;
  def MatrixMul : MgbHashableOp<"MatrixMul",
      [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
  // NOTE(review): mnemonic "BatchedMatmul" differs from the record name
  // BatchedMatrixMul, unlike the other ops in this file. Left as-is since the
  // mnemonic may be externally visible -- confirm it is intended.
  def BatchedMatrixMul : MgbHashableOp<"BatchedMatmul",
      [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
  def Dot : MgbHashableOp<"Dot", [EmptyParam]>;
  def SVD : MgbHashableOp<"SVD", [SVDParam]>;

  def Convolution : MgbHashableOp<"Convolution",
      [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
  def ConvolutionBackwardData : MgbHashableOp<"ConvolutionBackwardData",
      [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
  def Convolution3D : MgbHashableOp<"Convolution3D",
      [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
  def Convolution3DBackwardData : MgbHashableOp<"Convolution3DBackwardData",
      [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;
  def DeformableConv : MgbHashableOp<"DeformableConv",
      [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;
  // GroupLocal reuses ConvolutionParam but has no execution-policy param.
  def GroupLocal : MgbHashableOp<"GroupLocal", [ConvolutionParam]>;

  def Pooling : MgbHashableOp<"Pooling", [PoolingParam]>;
  def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>;
  def ROIPooling : MgbHashableOp<"ROIPooling", [ROIPoolingParam]>;
  def DeformablePSROIPooling : MgbHashableOp<"DeformablePSROIPooling",
      [DeformablePSROIPoolingParam]>;
  // ConvBias: ConvBiasParam plus an execution-policy param, and one extra
  // argument $dtype (MgbDTypeAttr; presumably the output dtype -- confirm).
  def ConvBias : MgbHashableOp<"ConvBias",
      [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }
  // BatchConvBias: BatchConvBiasParam plus an execution-policy param, and one
  // extra argument $dtype (MgbDTypeAttr; presumably the output dtype --
  // confirm).
  def BatchConvBias : MgbHashableOp<"BatchConvBias",
      [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }
  // Ops defined purely by their param structs, with no extra arguments.
  def BatchNorm : MgbHashableOp<"BatchNorm", [BNParam]>;
  def ROIAlign : MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
  def Correlation : MgbHashableOp<"Correlation", [CorrelationParam]>;
  def WarpPerspective : MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
  def WarpAffine : MgbHashableOp<"WarpAffine", [WarpAffineParam]>;
  def Remap : MgbHashableOp<"Remap", [RemapParam]>;
  def Resize : MgbHashableOp<"Resize", [ResizeParam]>;

  // The one-hot indexing pair shares AxisParam.
  def IndexingOneHot : MgbHashableOp<"IndexingOneHot", [AxisParam]>;
  def IndexingSetOneHot : MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;
  // Copy: no param struct; its only extra argument is $comp_node
  // (MgbCompNodeAttr; presumably the destination device -- confirm).
  def Copy : MgbHashableOp<"Copy"> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }
  // Ops defined by their param structs, with no extra arguments.
  // CondTake is the exception that declares no param struct at all.
  def Argsort : MgbHashableOp<"Argsort", [ArgsortParam]>;
  def Argmax : MgbHashableOp<"Argmax", [AxisParam]>;
  def Argmin : MgbHashableOp<"Argmin", [AxisParam]>;
  def CondTake : MgbHashableOp<"CondTake">;
  def TopK : MgbHashableOp<"TopK", [TopKParam]>;
  def NvOf : MgbHashableOp<"NvOf", [NvOfParam]>;
  // UniformRNG: UniformRNGParam plus an extra $handle (MgbSizeTAddr).
  // The custom hash combines only the dynamic type and $handle, and the
  // custom equality compares only $handle; no UniformRNGParam field is
  // consulted by either.
  def UniformRNG : MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    let hashFunction = [{
      return mgb::hash_pair_combine(
          mgb::hash($_self.dyn_typeinfo()),
          mgb::hash($_self.handle));
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }
  // GaussianRNG: GaussianRNGParam plus an extra $handle (MgbSizeTAddr).
  // Hashing and equality take $handle, `mean` and `std` into account
  // (mean/std presumably being GaussianRNGParam fields -- confirm).
  def GaussianRNG : MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    // hash = combine(type, combine(handle, combine(mean, std)))
    let hashFunction = [{
      return mgb::hash_pair_combine(
          mgb::hash($_self.dyn_typeinfo()),
          mgb::hash_pair_combine(
              mgb::hash($_self.handle),
              mgb::hash_pair_combine(
                  mgb::hash($_self.mean),
                  mgb::hash($_self.std))));
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle &&
             $0.mean == $1.mean &&
             $0.std == $1.std;
    }];
  }
  // Linspace: LinspaceParam plus an extra $comp_node (MgbCompNodeAttr).
  def Linspace : MgbHashableOp<"Linspace", [LinspaceParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }
  // Eye: EyeParam plus an extra $comp_node (MgbCompNodeAttr).
  def Eye : MgbHashableOp<"Eye", [EyeParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }
  // GetVarShape: attributes come solely from OptionalAxisV1Param.
  def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;
  // Concat: AxisParam plus an extra $comp_node (MgbCompNodeAttr).
  def Concat : MgbHashableOp<"Concat", [AxisParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }
  // Broadcast uses the empty param struct; Identity declares none at all.
  def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;
  def Identity : MgbHashableOp<"Identity">;
  // CollectiveComm: CollectiveCommParam plus ten extra arguments, kept in the
  // original order (presumably it fixes the generated constructor's argument
  // order -- confirm before reordering).
  // NOTE(review): $comp_node is an MgbStringAttr here, whereas Copy, Linspace,
  // Eye and Concat use MgbCompNodeAttr -- confirm this is intended.
  def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbUI32Attr:$nr_devices, MgbUI32Attr:$rank,
      MgbBoolAttr:$is_root, MgbBoolAttr:$local_grad,
      MgbStringAttr:$addr, MgbUI32Attr:$port,
      MgbDTypeAttr:$dtype,
      MgbStringAttr:$backend,
      MgbStringAttr:$comp_node
    );
  }
  // RemoteSend: no param struct; five extra arguments.
  def RemoteSend : MgbHashableOp<"RemoteSend"> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbStringAttr:$addr, MgbUI32Attr:$port,
      MgbUI32Attr:$rank_to,
      MgbStringAttr:$backend
    );
  }
  // RemoteRecv: no param struct; eight extra arguments, including a
  // $shape (MgbTensorShapeAttr) and $dtype (MgbDTypeAttr) -- presumably those
  // of the tensor being received; confirm.
  // NOTE(review): the comp-node argument is named $cn here, while other ops
  // in this file name it $comp_node.
  def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbStringAttr:$addr, MgbUI32Attr:$port,
      MgbUI32Attr:$rank_from,
      MgbCompNodeAttr:$cn,
      MgbTensorShapeAttr:$shape, MgbDTypeAttr:$dtype,
      MgbStringAttr:$backend
    );
  }
  // NMSKeep: no param struct; extra arguments are an f32 $iou_thresh and a
  // u32 $max_output.
  def NMSKeep : MgbHashableOp<"NMSKeep"> {
    let extraArguments = (ins MgbF32Attr:$iou_thresh, MgbUI32Attr:$max_output);
  }
  // ParamPackSplit: no param struct. $offsets is an array of i32; $shapes is
  // an array of MgbSizeTAddr arrays (presumably one shape per split output --
  // confirm).
  def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> {
    let extraArguments = (ins
      MgbArrayAttr<MgbI32Attr>:$offsets,
      MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$shapes
    );
  }
  // ParamPackConcat: no param struct; one extra i32-array argument $offsets.
  def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$offsets);
  }
  // Dimshuffle: one input, one result, and an i32-array extra argument
  // $pattern.
  // NOTE(review): input and result are AnyMemRef, whereas Elemwise and
  // TypeCvt in this file use AnyType -- confirm this is intended.
  def Dimshuffle : MgbHashableOp<"Dimshuffle"> {
    let inputs = (ins AnyMemRef:$input);
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$pattern);
    let results = (outs AnyMemRef);
  }
  // Reshape: attributes come solely from OptionalAxisV1Param.
  def Reshape : MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;
  // TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?
  // AddAxis and RemoveAxis declare no param struct; each takes a single
  // extra argument, an i32 array $axis.
  def AddAxis : MgbHashableOp<"AddAxis"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
  }
  def RemoveAxis : MgbHashableOp<"RemoveAxis"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
  }
  // Shared base for the indexing ops: no param struct, and one extra argument
  // $items, an array of (i8, bool, bool, bool, bool) tuples. Presumably one
  // tuple per indexed axis, holding the axis plus presence flags -- confirm
  // against the op implementations.
  class FancyIndexingBase<string name> : MgbHashableOp<name> {
    let extraArguments = (ins
      MgbArrayAttr<MgbTupleAttr<
        [MgbI8Attr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]
      >>:$items
    );
  }
  // Indexing ops; each inherits the $items extra argument from
  // FancyIndexingBase and passes its own name as the mnemonic.
  def Subtensor : FancyIndexingBase<"Subtensor">;
  def SetSubtensor : FancyIndexingBase<"SetSubtensor">;
  def IncrSubtensor : FancyIndexingBase<"IncrSubtensor">;

  def IndexingMultiAxisVec : FancyIndexingBase<"IndexingMultiAxisVec">;
  def IndexingSetMultiAxisVec : FancyIndexingBase<"IndexingSetMultiAxisVec">;
  def IndexingIncrMultiAxisVec : FancyIndexingBase<"IndexingIncrMultiAxisVec">;

  def MeshIndexing : FancyIndexingBase<"MeshIndexing">;
  def IncrMeshIndexing : FancyIndexingBase<"IncrMeshIndexing">;
  def SetMeshIndexing : FancyIndexingBase<"SetMeshIndexing">;

  def BatchedMeshIndexing : FancyIndexingBase<"BatchedMeshIndexing">;
  def BatchedIncrMeshIndexing : FancyIndexingBase<"BatchedIncrMeshIndexing">;
  def BatchedSetMeshIndexing : FancyIndexingBase<"BatchedSetMeshIndexing">;
  // Ops defined purely by their param structs, with no extra arguments.
  def FakeQuant : MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
  def AssertEqual : MgbHashableOp<"AssertEqual", [AssertEqualParam]>;
  def TQT : MgbHashableOp<"TQT", [TQTParam]>;
  // ElemwiseMultiType: ElemwiseMultiTypeParam plus one extra argument $dtype
  // (MgbDTypeAttr; presumably the output dtype -- confirm).
  def ElemwiseMultiType : MgbHashableOp<"ElemwiseMultiType",
      [ElemwiseMultiTypeParam]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);

    // Display name is to_string() of the `mode` attribute (presumably the
    // mode field of ElemwiseMultiTypeParam -- confirm).
    let nameFunction = [{
      return to_string($_self.mode);
    }];
  }
  // InplaceAdd: uses the empty param struct; no extra arguments.
  def InplaceAdd : MgbHashableOp<"InplaceAdd", [EmptyParam]>;
  // TensorRTRuntime: no param struct; carries a $buf (MgbStringAttr) and its
  // $buf_size (MgbSizeTAddr). Presumably a serialized engine -- confirm.
  def TensorRTRuntime : MgbHashableOp<"TensorRTRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }
  // AtlasRuntime: no param struct; carries a $buf (MgbStringAttr) and its
  // $buf_size (MgbSizeTAddr). Presumably a serialized model -- confirm.
  def AtlasRuntime : MgbHashableOp<"AtlasRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }
  // CambriconRuntime: no param struct; carries a $buf (MgbStringAttr) with
  // its $buf_size (MgbSizeTAddr), a $symbol string, and a
  // $tensor_dim_mutable flag.
  def CambriconRuntime : MgbHashableOp<"CambriconRuntime"> {
    let extraArguments = (ins
      MgbStringAttr:$buf, MgbSizeTAddr:$buf_size,
      MgbStringAttr:$symbol,
      MgbBoolAttr:$tensor_dim_mutable
    );
  }
  // CvtColor: attributes come solely from CvtColorParam.
  def CvtColor : MgbHashableOp<"CvtColor", [CvtColorParam]>;
  244. #endif // MGB_OPS

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。