You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ops.td 13 kB


  1. /**
  2. * \file src/core/include/megbrain/ir/ops.td
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #ifndef MGB_OPS
  13. #define MGB_OPS
  14. include "base.td"
  15. include "param_defs.td"
  16. include "mlir/Interfaces/SideEffectInterfaces.td"
  // Element-wise op: consumes a variadic list of operands and yields a single
  // result. Configuration comes from ElemwiseParam; the op is NoSideEffect.
  def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
    let inputs  = (ins Variadic<AnyType>:$input);
    let results = (outs AnyType);

    // Display name is the stringified `mode`
    // (presumably the ElemwiseParam mode field — confirm).
    let nameFunction = [{
      return to_string($_self.mode);
    }];
  }
  // Reduction; all configuration comes from ReduceParam.
  def Reduce : MgbHashableOp<"Reduce", [ReduceParam]>;

  // Type conversion. No DNN param struct is attached (empty param list);
  // the conversion is described only by the extra arguments below.
  def TypeCvt : MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
    // NOTE(review): the single operand is named `$inputs` (plural), unlike
    // Elemwise's `$input`. It is kept as-is because generated accessors may
    // depend on the name.
    let inputs = (ins AnyType:$inputs);

    // `idtype` is an MLIR TypeAttr and `dtype` an MgbDTypeAttr
    // (presumably source and target dtypes — confirm).
    let extraArguments = (ins TypeAttr:$idtype, MgbDTypeAttr:$dtype);

    let results = (outs AnyType);
  }
  // ---- Linear algebra -------------------------------------------------------

  def MatrixInverse : MgbHashableOp<"MatrixInverse", [EmptyParam]>;

  // Matrix multiply; params are MatrixMulParam plus ExecutionPolicyParamBase<"policy">.
  def MatrixMul : MgbHashableOp<
      "MatrixMul",
      [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;

  // NOTE(review): the mnemonic "BatchedMatmul" differs from the def name
  // BatchedMatrixMul. It is left unchanged in case serialization or lookup
  // keys on the mnemonic.
  def BatchedMatrixMul : MgbHashableOp<
      "BatchedMatmul",
      [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;

  def Dot : MgbHashableOp<"Dot", [EmptyParam]>;

  def SVD : MgbHashableOp<"SVD", [SVDParam]>;
  // ---- Convolution family ---------------------------------------------------
  // Most entries pair a conv param struct with ExecutionPolicyParamBase<"policy">.

  def Convolution : MgbHashableOp<
      "Convolution",
      [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

  // Backward-data variant; additionally carries an explicit `dtype` argument.
  def ConvolutionBackwardData : MgbHashableOp<
      "ConvolutionBackwardData",
      [ConvolutionParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }

  def Convolution3D : MgbHashableOp<
      "Convolution3D",
      [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

  def Convolution3DBackwardData : MgbHashableOp<
      "Convolution3DBackwardData",
      [Convolution3DParam, ExecutionPolicyParamBase<"policy">]>;

  def DeformableConv : MgbHashableOp<
      "DeformableConv",
      [ConvolutionParam, ExecutionPolicyParamBase<"policy">]>;

  // Reuses ConvolutionParam but, unlike the entries above, has no execution policy.
  def GroupLocal : MgbHashableOp<"GroupLocal", [ConvolutionParam]>;
  // ---- Pooling family -------------------------------------------------------

  // Only plain Pooling carries an execution-policy param.
  def Pooling : MgbHashableOp<
      "Pooling",
      [PoolingParam, ExecutionPolicyParamBase<"policy">]>;

  def AdaptivePooling        : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>;
  def ROIPooling             : MgbHashableOp<"ROIPooling", [ROIPoolingParam]>;
  def DeformablePSROIPooling : MgbHashableOp<"DeformablePSROIPooling",
                                             [DeformablePSROIPoolingParam]>;
  // Fused conv+bias ops. Both carry an execution policy and an explicit `dtype`
  // argument (presumably the output dtype — confirm).

  def ConvBias : MgbHashableOp<
      "ConvBias",
      [ConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }

  def BatchConvBias : MgbHashableOp<
      "BatchConvBias",
      [BatchConvBiasParam, ExecutionPolicyParamBase<"policy">]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);
  }
  // Single-param ops: each is fully configured by one param struct.
  def Images2Neibs           : MgbHashableOp<"Images2Neibs", [Images2NeibsParam]>;
  def SlidingWindowTranspose : MgbHashableOp<"SlidingWindowTranspose",
                                             [SlidingWindowTransposeParam]>;
  // Note: the param struct is named BNParam, not BatchNormParam.
  def BatchNorm              : MgbHashableOp<"BatchNorm", [BNParam]>;
  def ROIAlign               : MgbHashableOp<"ROIAlign", [ROIAlignParam]>;
  def Correlation            : MgbHashableOp<"Correlation", [CorrelationParam]>;

  // Geometric warps and resampling.
  def WarpPerspective : MgbHashableOp<"WarpPerspective", [WarpPerspectiveParam]>;
  def WarpAffine      : MgbHashableOp<"WarpAffine", [WarpAffineParam]>;
  def Remap           : MgbHashableOp<"Remap", [RemapParam]>;
  def Resize          : MgbHashableOp<"Resize", [ResizeParam]>;
  // One-hot indexing ops (read and set), both configured by AxisParam.
  def IndexingOneHot    : MgbHashableOp<"IndexingOneHot", [AxisParam]>;
  def IndexingSetOneHot : MgbHashableOp<"IndexingSetOneHot", [AxisParam]>;

  // Copy: no param struct; carries a `comp_node` extra argument.
  def Copy : MgbHashableOp<"Copy"> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }
  // Sorting / selection ops.
  def Argsort : MgbHashableOp<"Argsort", [ArgsortParam]>;
  // Argmax and Argmin share AxisParam.
  def Argmax  : MgbHashableOp<"Argmax", [AxisParam]>;
  def Argmin  : MgbHashableOp<"Argmin", [AxisParam]>;
  // CondTake takes neither a param struct nor extra arguments.
  def CondTake : MgbHashableOp<"CondTake">;
  def TopK     : MgbHashableOp<"TopK", [TopKParam]>;
  def NvOf     : MgbHashableOp<"NvOf", [NvOfParam]>;
  // Uniform RNG op with an extra `handle` argument (type MgbSizeTAddr).
  // Hash and equality use only the op type, `handle` and `dtype`; other
  // UniformRNGParam fields are not included.
  def UniformRNG : MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    // Same value as hash_pair_combine(type, hash_pair_combine(handle, dtype)).
    let hashFunction = [{
      const auto handle_and_dtype = mgb::hash_pair_combine(
          mgb::hash($_self.handle), mgb::hash($_self.dtype.enumv()));
      return mgb::hash_pair_combine(
          mgb::hash($_self.dyn_typeinfo()), handle_and_dtype);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle &&
             $0.dtype == $1.dtype;
    }];
  }
  // Gaussian RNG op with an extra `handle` argument. Hash and equality cover
  // the op type, `handle`, `mean`, `std` and `dtype`.
  def GaussianRNG : MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    // Built innermost-first. The result equals
    // combine(type, combine(handle, combine(mean, combine(std, dtype)))).
    let hashFunction = [{
      const auto std_and_dtype = mgb::hash_pair_combine(
          mgb::hash($_self.std), mgb::hash($_self.dtype.enumv()));
      const auto mean_and_rest = mgb::hash_pair_combine(
          mgb::hash($_self.mean), std_and_dtype);
      const auto handle_and_rest = mgb::hash_pair_combine(
          mgb::hash($_self.handle), mean_and_rest);
      return mgb::hash_pair_combine(
          mgb::hash($_self.dyn_typeinfo()), handle_and_rest);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle &&
             $0.mean == $1.mean &&
             $0.std == $1.std &&
             $0.dtype == $1.dtype;
    }];
  }
  // Gamma RNG op with an extra `handle` argument. Hash and equality use only
  // the op type and `handle`; GammaRNGParam fields are not included.
  def GammaRNG : MgbHashableOp<"GammaRNG", [GammaRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    let hashFunction = [{
      const auto type_hash = mgb::hash($_self.dyn_typeinfo());
      const auto handle_hash = mgb::hash($_self.handle);
      return mgb::hash_pair_combine(type_hash, handle_hash);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }
  // Poisson RNG op with an extra `handle` argument. Hash and equality use only
  // the op type and `handle`; PoissonRNGParam fields are not included.
  def PoissonRNG : MgbHashableOp<"PoissonRNG", [PoissonRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    let hashFunction = [{
      const auto type_hash = mgb::hash($_self.dyn_typeinfo());
      const auto handle_hash = mgb::hash($_self.handle);
      return mgb::hash_pair_combine(type_hash, handle_hash);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }
  // Beta RNG op with an extra `handle` argument. Hash and equality use only
  // the op type and `handle`; BetaRNGParam fields are not included.
  def BetaRNG : MgbHashableOp<"BetaRNG", [BetaRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    let hashFunction = [{
      const auto type_hash = mgb::hash($_self.dyn_typeinfo());
      const auto handle_hash = mgb::hash($_self.handle);
      return mgb::hash_pair_combine(type_hash, handle_hash);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }
  // Permutation RNG op with an extra `handle` argument. Hash and equality use
  // the op type, `handle` and `dtype`.
  def PermutationRNG : MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    // Same value as hash_pair_combine(type, hash_pair_combine(handle, dtype)).
    let hashFunction = [{
      const auto handle_and_dtype = mgb::hash_pair_combine(
          mgb::hash($_self.handle), mgb::hash($_self.dtype.enumv()));
      return mgb::hash_pair_combine(
          mgb::hash($_self.dyn_typeinfo()), handle_and_dtype);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle &&
             $0.dtype == $1.dtype;
    }];
  }
  // Shuffle RNG op with an extra `handle` argument. Hash and equality use only
  // the op type and `handle`; ShuffleRNGParam fields are not included.
  def ShuffleRNG : MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    let hashFunction = [{
      const auto type_hash = mgb::hash($_self.dyn_typeinfo());
      const auto handle_hash = mgb::hash($_self.handle);
      return mgb::hash_pair_combine(type_hash, handle_hash);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle;
    }];
  }
  // Tensor-creation ops. Each pairs its param struct with a `comp_node`
  // extra argument.

  def Linspace : MgbHashableOp<"Linspace", [LinspaceParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }

  def Eye : MgbHashableOp<"Eye", [EyeParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }
  // Shape query, configured by OptionalAxisV1Param.
  def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;

  // Concatenation: AxisParam plus a `comp_node` extra argument.
  def Concat : MgbHashableOp<"Concat", [AxisParam]> {
    let extraArguments = (ins MgbCompNodeAttr:$comp_node);
  }

  def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>;

  // Identity: no param struct and no extra arguments.
  def Identity : MgbHashableOp<"Identity">;
  // Collective communication op: CollectiveCommParam plus the connection and
  // topology fields below.
  // NOTE(review): `comp_node` is an MgbStringAttr here, whereas Copy, Concat,
  // Linspace and Eye use MgbCompNodeAttr. It is kept as-is because callers
  // presumably pass a string.
  def CollectiveComm : MgbHashableOp<"CollectiveComm", [CollectiveCommParam]> {
    let extraArguments = (ins
      // group identity / membership
      MgbStringAttr:$key,
      MgbUI32Attr:$nr_devices,
      MgbUI32Attr:$rank,
      MgbBoolAttr:$is_root,
      MgbBoolAttr:$local_grad,
      // server endpoint
      MgbStringAttr:$addr,
      MgbUI32Attr:$port,
      // data / transport
      MgbDTypeAttr:$dtype,
      MgbStringAttr:$backend,
      MgbStringAttr:$comp_node
    );
  }
  // Point-to-point transfer ops. Neither uses a param struct; both are
  // configured entirely by extra arguments.

  // Sending side: identified by `key`, with server endpoint `addr:port`,
  // destination `rank_to` and `backend`.
  def RemoteSend : MgbHashableOp<"RemoteSend"> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbStringAttr:$addr,
      MgbUI32Attr:$port,
      MgbUI32Attr:$rank_to,
      MgbStringAttr:$backend
    );
  }

  // Receiving side: mirrors RemoteSend and additionally declares the
  // comp node (`cn`), `shape` and `dtype` of the received value.
  def RemoteRecv : MgbHashableOp<"RemoteRecv"> {
    let extraArguments = (ins
      MgbStringAttr:$key,
      MgbStringAttr:$addr,
      MgbUI32Attr:$port,
      MgbUI32Attr:$rank_from,
      MgbCompNodeAttr:$cn,
      MgbTensorShapeAttr:$shape,
      MgbDTypeAttr:$dtype,
      MgbStringAttr:$backend
    );
  }
  // Non-maximum suppression: float `iou_thresh` and unsigned `max_output`
  // extra arguments; no param struct.
  def NMSKeep : MgbHashableOp<"NMSKeep"> {
    let extraArguments = (ins MgbF32Attr:$iou_thresh, MgbUI32Attr:$max_output);
  }
  // Parameter packing ops. Both take an i32 `offsets` array.

  // Split additionally needs one size array per output in `shapes`.
  def ParamPackSplit : MgbHashableOp<"ParamPackSplit"> {
    let extraArguments = (ins
      MgbArrayAttr<MgbI32Attr>:$offsets,
      MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$shapes
    );
  }

  def ParamPackConcat : MgbHashableOp<"ParamPackConcat"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$offsets);
  }
  // ---- Shape manipulation ---------------------------------------------------

  // Axis permutation given by an i32 `pattern` array.
  // NOTE(review): operand and result are declared AnyMemRef, while Elemwise
  // and TypeCvt use AnyType. They are kept as-is because MLIR verification
  // depends on this choice.
  def Dimshuffle : MgbHashableOp<"Dimshuffle"> {
    let inputs = (ins AnyMemRef:$input);
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$pattern);
    let results = (outs AnyMemRef);
  }

  def Reshape : MgbHashableOp<"Reshape", [OptionalAxisV1Param]>;

  // TODO: merge Add/Remove Axis into AxisAddRemove as megbrain?

  // Both take an i32 array of axes.
  def AddAxis : MgbHashableOp<"AddAxis"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
  }

  def RemoveAxis : MgbHashableOp<"RemoveAxis"> {
    let extraArguments = (ins MgbArrayAttr<MgbI32Attr>:$axis);
  }
  // Shared base for subtensor / multi-axis / mesh indexing ops. Each element
  // of `items` is a tuple (i8, bool, bool, bool, bool).
  // NOTE(review): presumably (axis, has_begin, has_end, has_step, has_idx) —
  // confirm against the C++ consumer.
  class FancyIndexingBase<string name> : MgbHashableOp<name> {
    let extraArguments = (ins
      MgbArrayAttr<
        MgbTupleAttr<[MgbI8Attr,
                      MgbBoolAttr, MgbBoolAttr, MgbBoolAttr, MgbBoolAttr]>
      >:$items
    );
  }

  // Subtensor: read / set / increment.
  def Subtensor     : FancyIndexingBase<"Subtensor">;
  def SetSubtensor  : FancyIndexingBase<"SetSubtensor">;
  def IncrSubtensor : FancyIndexingBase<"IncrSubtensor">;

  // Multi-axis vector indexing: read / set / increment.
  def IndexingMultiAxisVec     : FancyIndexingBase<"IndexingMultiAxisVec">;
  def IndexingSetMultiAxisVec  : FancyIndexingBase<"IndexingSetMultiAxisVec">;
  def IndexingIncrMultiAxisVec : FancyIndexingBase<"IndexingIncrMultiAxisVec">;

  // Mesh indexing: read / increment / set.
  def MeshIndexing     : FancyIndexingBase<"MeshIndexing">;
  def IncrMeshIndexing : FancyIndexingBase<"IncrMeshIndexing">;
  def SetMeshIndexing  : FancyIndexingBase<"SetMeshIndexing">;

  // Batched mesh indexing: read / increment / set.
  def BatchedMeshIndexing     : FancyIndexingBase<"BatchedMeshIndexing">;
  def BatchedIncrMeshIndexing : FancyIndexingBase<"BatchedIncrMeshIndexing">;
  def BatchedSetMeshIndexing  : FancyIndexingBase<"BatchedSetMeshIndexing">;
  // Quantization-related and assertion ops; each is fully configured by its
  // param struct.
  def FakeQuant   : MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
  def AssertEqual : MgbHashableOp<"AssertEqual", [AssertEqualParam]>;
  def TQT         : MgbHashableOp<"TQT", [TQTParam]>;
  def LSQ         : MgbHashableOp<"LSQ", [LSQParam]>;
  // Element-wise op with an explicit `dtype` argument in addition to
  // ElemwiseMultiTypeParam.
  def ElemwiseMultiType : MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
    let extraArguments = (ins MgbDTypeAttr:$dtype);

    // Display name is the stringified `mode`
    // (presumably the ElemwiseMultiTypeParam mode field — confirm).
    let nameFunction = [{
      return to_string($_self.mode);
    }];
  }

  def InplaceAdd : MgbHashableOp<"InplaceAdd", [EmptyParam]>;
  // Third-party runtime wrappers. Each takes a string `buf` and a `buf_size`
  // (presumably a serialized model blob and its length — confirm).

  def TensorRTRuntime : MgbHashableOp<"TensorRTRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }

  def AtlasRuntime : MgbHashableOp<"AtlasRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }

  // Cambricon additionally takes a `symbol` string and a
  // `tensor_dim_mutable` flag.
  def CambriconRuntime : MgbHashableOp<"CambriconRuntime"> {
    let extraArguments = (ins
      MgbStringAttr:$buf,
      MgbSizeTAddr:$buf_size,
      MgbStringAttr:$symbol,
      MgbBoolAttr:$tensor_dim_mutable
    );
  }

  def MagicMindRuntime : MgbHashableOp<"MagicMindRuntime"> {
    let extraArguments = (ins MgbStringAttr:$buf, MgbSizeTAddr:$buf_size);
  }
  // Miscellaneous ops.
  def CvtColor       : MgbHashableOp<"CvtColor", [CvtColorParam]>;
  def CheckNonFinite : MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;
  // No param struct and no extra arguments.
  def FastpathCopy   : MgbHashableOp<"FastpathCopy">;
  // Externally implemented operator, described by `name`, an opaque `data`
  // string of length `data_len`, and per-output shapes and dtypes.
  // The custom hash covers only the op type, `name` and `data`. No
  // cmpFunction is given here, so equality uses the base class default
  // (presumably comparing all fields — confirm in base.td).
  def ExternOpr : MgbHashableOp<"ExternOpr"> {
    let extraArguments = (ins
      MgbArrayAttr<MgbArrayAttr<MgbSizeTAddr>>:$output_shapes,
      MgbStringAttr:$name,
      MgbStringAttr:$data,
      MgbSizeTAddr:$data_len,
      MgbArrayAttr<MgbDTypeAttr>:$output_dtypes
    );

    // Same value as hash_pair_combine(type, hash_pair_combine(name, data)).
    let hashFunction = [{
      const auto name_and_data = mgb::hash_pair_combine(
          mgb::hash($_self.name), mgb::hash($_self.data));
      return mgb::hash_pair_combine(
          mgb::hash($_self.dyn_typeinfo()), name_and_data);
    }];
  }
  def Cumsum : MgbHashableOp<"Cumsum", [CumsumParam]>;

  // Split: EmptyParam plus an explicit i32 `axis` extra argument.
  def Split : MgbHashableOp<"Split", [EmptyParam]> {
    let extraArguments = (ins MgbI32Attr:$axis);
  }

  // Single-param normalization / padding ops.
  def Padding   : MgbHashableOp<"Padding", [PaddingParam]>;
  def LRN       : MgbHashableOp<"LRN", [LRNParam]>;
  def LayerNorm : MgbHashableOp<"LayerNorm", [LayerNormParam]>;
  // Dropout with an extra `handle` argument. Hash and equality use the op
  // type, `drop_prob` and `handle`; other DropoutParam fields are not included.
  def Dropout : MgbHashableOp<"Dropout", [DropoutParam]> {
    let extraArguments = (ins MgbSizeTAddr:$handle);

    // Same value as hash_pair_combine(type, hash_pair_combine(drop_prob, handle)).
    let hashFunction = [{
      const auto prob_and_handle = mgb::hash_pair_combine(
          mgb::hash($_self.drop_prob), mgb::hash($_self.handle));
      return mgb::hash_pair_combine(
          mgb::hash($_self.dyn_typeinfo()), prob_and_handle);
    }];

    let cmpFunction = [{
      return $0.handle == $1.handle &&
             $0.drop_prob == $1.drop_prob;
    }];
  }
  370. #endif // MGB_OPS