
opr_param_defs.py 59 kB

pdef('Empty')

pdef('Axis').add_fields('int32', 'axis', 0)
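
# Note on the pdef DSL used throughout this file (an illustrative summary drawn
# from the call sites below, not from the generator itself): pdef('Name', ...)
# declares a parameter struct; add_fields(dtype, name, default, ...) appends
# scalar fields; add_enum(...) declares an enum and add_enum_alias(...) reuses
# an enum already declared by another (possibly legacy) param; Doc(name, text)
# attaches documentation to a field or enum member. For example, the Axis param
# above is a struct with a single int32 field ``axis`` defaulting to 0.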

(pdef('Convolution', version=0, is_legacy=True).
    add_enum('Mode', 'CROSS_CORRELATION = 0', 'CONVOLUTION = 1').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1
    ).
    add_enum('DataType',
        Doc('FLOAT = 0', 'input/output both float32/float16'),
        'INT8x8x16 = 1',
        'INT8x8x32 = 2',
        Doc('FLOAT_IO16xC32 = 3', 'input/output both float16, the internal '
            'compute is float32'),
        Doc('QUINT8x8x32 = 4', 'input QuantizedAsymm8, output QuantizedS32'),
        Doc('INT8x8xX = 5', 'input int8, output specified by tensor DType'),
        Doc('QUINT4x4x32 = 6', 'input QuantizedAsymm4, output QuantizedS32'),
        name_field='data_type').
    add_enum('Sparse',
        Doc('DENSE = 0', 'dense convolution: filter shape should be '
            '[oc, ic, spatial...] if format is NCHW, '
            '[oc, spatial..., ic] if format is NHWC'),
        Doc('GROUP = 1', 'group convolution: filter shape should be '
            '[group, oc_per_group, ic_per_group, spatial...] if format is NCHW, '
            '[group, oc_per_group, spatial..., ic_per_group] if format is NHWC')
    ).
    add_enum(Doc('Format', 'convolution data/filter/output format; see '
            ':class:`RelayoutFormat` for more details'),
        'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
        'NCHW44 = 7', 'NCHW44_DOT = 8',
        Doc('NCHW_WINOGRAD = 9', 'NCHW layout with weights transformed by winograd'),
        Doc('NCHW88_WINOGRAD = 10',
            'NCHW88 layout with weights transformed by winograd'),
        Doc('NCHW44_WINOGRAD = 11',
            'NCHW44 layout with weights transformed by winograd'),
        Doc('NCHW4_NCHW32 = 12',
            'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
        Doc('NCHW32_NCHW4 = 13',
            'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
        Doc('NCHW4_NCHW = 14',
            'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
        Doc('NHWC_NCHW = 15', 'NHWC_NCHW means input tensors are nhwc layout, '
            'output tensor is nchw layout'),
        Doc('NHWC_NCHW4_IC_SMALL = 16', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
            'output tensor is nchw4 layout, padding c=4'),
        Doc('NCHW_NCHW4_IC_SMALL = 17', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
            'output tensor is nchw4 layout, padding c=4'),
        Doc('CHWN4 = 18', 'CHWN4 is currently only used on the Nvidia platform for a fast implementation '
            'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'),
        Doc('NCHW4_NHWC = 19', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'))
)
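
# Illustrative note (standard convolution arithmetic, not part of the generated
# params): with the fields above, the output spatial extent on the height
# dimension is conventionally
#
#     out_h = (in_h + 2 * pad_h - dilate_h * (kernel_h - 1) - 1) // stride_h + 1
#
# and analogously for the width dimension; dilate == 1 recovers an ordinary
# (non-dilated) convolution.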

(pdef('Convolution', version=1, is_legacy=True).
    add_enum_alias('Mode', 'ConvolutionV0').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1
    ).
    add_enum_alias('Sparse', 'ConvolutionV0').
    add_enum_alias('Format', 'ConvolutionV0').
    add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. '
            'different combinations of intermediate result '
            'data types.'),
        Doc('DEFAULT = 0', 'No special requirements on the precision of '
            'intermediate results.'),
        Doc('FLOAT32 = 1', 'Use Float32 accumulator and intermediate result. '
            'Only supported when input and output are Float16.'),
        name_field='compute_mode')
)

(pdef('Convolution', version=2).
    add_enum_alias('Mode', 'ConvolutionV0').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1
    ).
    add_enum_alias('Sparse', 'ConvolutionV0').
    add_enum(Doc('Format', 'convolution data/filter/output format; see '
            ':class:`RelayoutFormat` for more details'),
        'NCHW = 0', 'NHWC = 1', 'NHWCD4 = 2', 'NCHW4 = 3', 'NCHW8 = 4', 'NCHW32 = 5', 'NCHW88 = 6',
        'NCHW44 = 7', 'NCHW44_DOT = 8',
        Doc('NCHW4_NCHW32 = 9',
            'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
        Doc('NCHW32_NCHW4 = 10',
            'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
        Doc('NCHW4_NCHW = 11',
            'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
        Doc('NHWC_NCHW = 12', 'NHWC_NCHW means input tensors are nhwc layout, '
            'output tensor is nchw layout'),
        Doc('NHWC_NCHW4_IC_SMALL = 13', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, '
            'output tensor is nchw4 layout, padding c=4'),
        Doc('NCHW_NCHW4_IC_SMALL = 14', 'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
            'output tensor is nchw4 layout, padding c=4'),
        Doc('CHWN4 = 15', 'CHWN4 is currently only used on the Nvidia platform for a fast implementation '
            'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'),
        Doc('NCHW64 = 16', 'NCHW64 is designed for convolution implementations that utilize TensorCore '
            'instructions for 4-bit integers on Nvidia platforms'),
        Doc('NCHW4_NHWC = 17', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout')).
    add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
)

(pdef('MaskPropagate').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('kernel_h', 'kernel height'), 1,
        Doc('kernel_w', 'kernel width'), 1,
        Doc('dilate_h', 'dilate height'), 1,
        Doc('dilate_w', 'dilate width'), 1)
)

(pdef('ConvPooling').
    add_enum('Method', 'WITH_TEXTURE_OBJ = 0', 'WITH_SHARED_MEM = 1').
    add_enum_alias('ConvMode', 'ConvolutionV0', 'Mode').
    add_enum('PoolMode', 'AVERAGE = 0', 'MAX = 1').
    add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2').
    add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1,
        'pool_pad_h', 0, 'pool_pad_w', 0, 'conv_stride_h', 1, 'conv_stride_w', 1, 'conv_pad_h', 0, 'conv_pad_w', 0))

(pdef('ConvBias', 'legacy conv_bias', version=0, is_legacy=True).
    add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2', 'H_SWISH = 3').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1))

(pdef('ConvBias', 'active(conv(x, w) + bias)', version=1, is_legacy=True).
    add_enum_alias('NonlineMode', 'ConvBiasV0').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_enum_alias('DataType', 'ConvolutionV0', name_field='data_type').
    add_enum_alias('Sparse', 'ConvolutionV0').
    add_enum_alias('Format', 'ConvolutionV0').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1)
)

(pdef('ConvBias', 'active(conv(x, w) + bias)', version=2, is_legacy=True).
    add_enum_alias('NonlineMode', 'ConvBiasV0').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_enum_alias('Sparse', 'ConvolutionV0').
    add_enum_alias('Format', 'ConvolutionV0').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1).
    add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
)

(pdef('ConvBias', 'active(conv(x, w) + bias)', version=3, is_legacy=True).
    add_enum_alias('NonlineMode', 'ConvBiasV0').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_enum_alias('Sparse', 'ConvolutionV0').
    add_enum_alias('Format', 'ConvolutionV0').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1,
        Doc('output_block_size', 'for the detailed meaning, see winograd in conv bias'), 0).
    add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
)

(pdef('ConvBias', 'active(conv(x, w) + bias)', version=4).
    add_enum_alias('NonlineMode', 'ConvBiasV0').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_enum_alias('Sparse', 'ConvolutionV0').
    add_enum_alias('Format', 'Convolution').
    add_fields(
        'uint32',
        Doc('pad_h', 'padding on one side on the first dimension'), 0,
        Doc('pad_w', 'padding on one side on the second dimension'), 0,
        Doc('stride_h', 'kernel stride on the first dimension'), 1,
        Doc('stride_w', 'kernel stride on the second dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1).
    add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
)
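
# Note on the versioning pattern (an observation about how these definitions
# are written, not a statement of the generator's internals): params marked
# is_legacy=True stay declared, presumably so that older serialized models keep
# working, while newer versions re-export their enums via add_enum_alias; e.g.
# ConvBias version=4 above reuses NonlineMode from ConvBiasV0 and Format from
# the current Convolution (version=2) definition, so each enum is declared once
# and shared across param versions.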

(pdef('SeparableConv').
    add_enum_alias('Mode', 'ConvolutionV0').
    add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
        'BORDER_REFLECT_101 = 2', 'BORDER_WRAP = 3',
        'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5', 'BORDER_ISOLATED = 6').
    add_fields('bool', 'is_symm_kernel', 'true').
    add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
        'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))

(pdef('Images2Neibs').
    add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
        'dilate_h', 1, 'dilate_w', 1, 'window_h', 3, 'window_w', 3))

(pdef('SlidingWindowTranspose').
    add_fields('uint32', 'out_h', 0, 'out_w', 0, 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
        'dilate_h', 1, 'dilate_w', 1, 'window_h', 3, 'window_w', 3))

(pdef('Pooling', version=0, is_legacy=True).
    add_enum(
        'Mode',
        Doc('MAX = 0', 'maximum value inside pooling window'),
        Doc('AVERAGE = 1',
            'arithmetic mean of all values inside pooling window. Padding values '
            'are taken into account and are viewed as zero'),
        Doc('AVERAGE_COUNT_EXCLUDE_PADDING = 2',
            'arithmetic mean of all values inside pooling window. No padding is '
            'used.')
    ).
    add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 2, 'stride_w', 2,
        'window_h', 2, 'window_w', 2).
    add_enum_alias('Format', 'ConvolutionV0')
)

(pdef('Pooling', version=1).
    add_enum_alias('Mode', 'PoolingV0').
    add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 2, 'stride_w', 2,
        'window_h', 2, 'window_w', 2).
    add_enum_alias('Format', 'Convolution')
)

(pdef('AdaptivePooling', version=0, is_legacy=True).
    add_enum_alias('Mode', 'PoolingV0').
    add_enum_alias('Format', 'ConvolutionV0')
)

(pdef('AdaptivePooling', version=1).
    add_enum_alias('Mode', 'PoolingV0').
    add_enum_alias('Format', 'Convolution')
)

(pdef('LRN',
    'see ImageNet Classification with Deep Convolutional Neural Networks for'
    ' meaning of the fields').
    add_fields('uint32', Doc('n', 'must be odd'), 5).
    add_fields('float32', 'k', '2.f', 'alpha', '1e-4f', 'beta', '0.75f')
)

(pdef('BN').
    add_enum(
        'ParamDim',
        Doc('DIM_11HW = 0', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'),
        Doc('DIM_1CHW = 1', 'Dim of params (Sigma, Mu) is 1 x C x H x W'),
        Doc('DIM_1C11 = 2', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'),
        Doc('DIM_111C = 3', 'Dim of params (Sigma, Mu) is 1 x 1 x 1 x C'),
        name_field='param_dim'
    ).
    add_enum(
        'FwdMode',
        Doc('TRAINING = 0', 'Training phase.'),
        Doc('INFERENCE = 1', 'Inference phase.'),
        name_field='fwd_mode'
    ).
    add_fields('float64', 'epsilon', '1e-4f').
    add_fields('float64', 'avg_factor', '1.f').
    add_fields('float32', 'scale', '1.f').
    add_fields('float32', 'bias', '0.f')
)
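
# Illustrative note (standard batch-normalization arithmetic, not extracted
# from this file): the normalized output is
#
#     y = scale * (x - mean) / sqrt(var + epsilon) + bias
#
# where mean/var are batch statistics in TRAINING mode and stored running
# statistics in INFERENCE mode; avg_factor controls how batch statistics are
# blended into the running statistics during training.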

(pdef('ROIPooling').
    add_enum(
        'Mode',
        Doc('MAX = 0', 'maximum value inside pooling window; pooling result would '
            'be 0 if pooling window is empty'),
        Doc('AVERAGE = 1',
            'arithmetic mean of all values inside pooling window; pooling result '
            'would be 0 if pooling window is empty')
    ).
    add_fields('float32', 'scale', '1.f'))

INTERP_MODES = ['NEAREST = 0', 'LINEAR = 1',
                'AREA = 2', 'CUBIC = 3', 'LANCZOS4 = 4']
BORDER_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
                Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
                Doc('REFLECT_101 = 2', 'gfedcb|abcdefgh|gfedcba'),
                Doc('WRAP = 3', 'cdefgh|abcdefgh|abcdefg'),
                Doc('CONSTANT = 4', 'iiiiii|abcdefgh|iiiiiii'),
                Doc('TRANSPARENT = 5', ''),
                Doc('ISOLATED = 6', '')]
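
# Note on the BORDER_MODES notation above (illustrative): 'abcdefgh' stands for
# the pixels of one image row and the characters on either side of the bars
# show how values outside the image are synthesized; e.g. REPLICATE repeats the
# edge pixel, while CONSTANT fills the border with a user-supplied value 'i'
# (the border_val field of the warp params below).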

(pdef('WarpPerspective', version=1, is_legacy=True).
    add_enum('InterpolationMode', *INTERP_MODES,
        name_field='imode', default=1,
        member_alias=[(i, 'INTER_{}'.format(i)) for i in INTERP_MODES]
    ).
    add_enum('BorderMode', *BORDER_MODES,
        name_field='bmode',
        member_alias=[(i, 'BORDER_{}'.format(i)) for i in BORDER_MODES]
    ).
    add_enum_alias('Format', 'ConvolutionV0').
    add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))

(pdef('WarpPerspective', version=2).
    add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field="imode").
    add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field="bmode").
    add_enum_alias('Format', 'Convolution').
    add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))

pdef('SpatialTfGridGenerator').add_enum('Mode', 'AFFINE = 0')

pdef('SpatialTfSampler').add_enum('Mode', 'BILINEAR = 0')

pdef('AddUpdate').add_fields(
    'float32', 'alpha', '1.f', 'beta', '1.f', 'bias', '0.f')

pdef('Elemwise').add_enum(
    'Mode',
    Doc('RELU = 0', 'unary: max(x, 0)'),
    Doc('ABS = 1', 'unary: abs(x)'),
    Doc('ACOS = 2', 'unary: acos(x)'),
    Doc('ASIN = 3', 'unary: asin(x)'),
    Doc('CEIL = 4', 'unary: ceil(x)'),
    Doc('COS = 5', 'unary: cos(x)'),
    Doc('EXP = 6', 'unary: exp(x)'),
    Doc('EXPM1 = 7', 'unary: numerically stable exp(x)-1'),
    Doc('FLOOR = 8', 'unary: floor(x)'),
    Doc('LOG = 9', 'unary: natural logarithm, log(x)'),
    Doc('LOG1P = 10', 'unary: numerically stable log(x+1)'),
    Doc('NEGATE = 11', 'unary: -x'),
    Doc('SIGMOID = 12', 'unary: 1/(1+exp(-x))'),
    Doc('SIN = 13', 'unary: sin(x)'),
    Doc('TANH = 14', 'unary: tanh(x)'),
    Doc('ABS_GRAD = 15', 'binary: x > 0 ? y : -y'),
    Doc('ADD = 16', 'binary: x + y'),
    Doc('FLOOR_DIV = 17', 'binary: floor(x / y)'),
    Doc('MAX = 18', 'binary: max(x, y)'),
    Doc('MIN = 19', 'binary: min(x, y)'),
    Doc('MOD = 20', 'binary: x % y or fmodf(x, y)'),
    Doc('MUL = 21', 'binary: x * y'),
    Doc('POW = 22', 'binary: pow(x, y)'),
    Doc('SIGMOID_GRAD = 23', 'binary: x * (1 - x) * y'),
    Doc('SUB = 24', 'binary: x - y'),
    Doc('SWITCH_GT0 = 25', 'binary: (x > 0) * y'),
    Doc('TANH_GRAD = 26', 'binary: (1 - x * x) * y'),
    Doc('TRUE_DIV = 27', 'binary: x / y'),
    Doc('LOG_SUM_EXP = 28', 'binary: numerically stable log(exp(x) + exp(y))'),
    Doc('LT = 29', 'binary: x < y'),
    Doc('LEQ = 30', 'binary: x <= y'),
    Doc('EQ = 31', 'binary: x == y'),
    Doc('SHL = 32', 'bitwise binary: x << y. '
        'Note that result is undefined if y < 0 or y >= bitwidth. Logical '
        'shift is performed for unsigned integers, and arithmetic shift for '
        'signed ones.'),
    Doc('SHR = 33', 'bitwise binary: x >> y; see SHL mode for more details'),
    Doc('COND_LEQ_MOV = 34', 'ternary: x <= y ? z : 0'),
    Doc('FUSE_MUL_ADD3 = 35',
        'compute ``a * b + c`` where c must either have same layout as '
        'a or b, or be a scalar'),
    Doc('FUSE_MUL_ADD4 = 36',
        'compute ``a * A + b * B`` where a and b must have equal layout, '
        'and A and B must have equal layout. In the inputs ``b`` and ``B`` '
        'can be swapped'),
    Doc('FUSE_ADD_RELU = 37', 'binary: max(x+y, 0)'),
    Doc('FUSE_ADD_SIGMOID = 38', 'binary: 1/(1+exp(-(x+y)))'),
    Doc('FUSE_ADD_TANH = 39', 'binary: tanh(x+y)'),
    Doc('FAST_TANH = 40', 'unary: rational approximation of tanh(x)'),
    Doc('FAST_TANH_GRAD = 41', 'binary: grad of the rational approximation of tanh(x)'),
    Doc('ROUND = 42', 'unary: round(x), the nearest integer value to x, rounding '
        'halfway cases away from zero. Float only.'),
    Doc('RMULH = 43', 'binary: rounded higher l bits of x * y, where l is the bit '
        'length of x.'),
    Doc('ATAN2 = 44', 'binary: atan2(y,x)'),
    Doc('ERF = 45', 'unary: erf(x)'),
    Doc('ERFINV = 46', 'unary: inverse function of erf(x)'),
    Doc('ERFC = 47', 'unary: erfc(x)'),
    Doc('ERFCINV = 48', 'unary: inverse function of erfc(x)'),
    Doc('H_SWISH = 49', 'unary: x * clip(x + 3, 0, 6) / 6'),
    Doc('H_SWISH_GRAD = 50', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'),
    Doc('FUSE_ADD_H_SWISH = 51', 'binary: hswish(x+y)'),
    Doc('NOT = 52', 'unary: !x'),
    Doc('AND = 53', 'binary: x && y'),
    Doc('OR = 54', 'binary: x || y'),
    Doc('XOR = 55', 'binary: x ^ y'),
    Doc('SILU = 56', 'unary: x / (1 + exp(-x))'),
    Doc('SILU_GRAD = 57', 'binary: grad(x / (1 + exp(-x)))'),
    Doc('GELU = 58', 'unary: x Phi(x)'),
    Doc('GELU_GRAD = 59', 'binary: grad(x Phi(x))'),
)
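
# Illustrative example (drawn from the mode docs above, not from kernel code):
# FUSE_MUL_ADD3 computes ``a * b + c`` elementwise in one pass; the documented
# constraint is that ``c`` must either share the layout of ``a`` or ``b`` or be
# a scalar, so e.g. a and b of shape (N, C, H, W) with a scalar c is valid,
# while an arbitrarily broadcast c is not.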

pdef('ElemwiseMultiType').add_enum(
    'Mode',
    Doc('FUSE_MUL_ADD3_INT16x32x32x32 = 0',
        'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and '
        '``c`` int32, and the result is int32. This mode is optimized for '
        'the channel-broadcasted case, i.e. ``a`` has shape (A, B, C) and '
        '``b`` and ``c`` have shape (1, C, 1)'),
    Doc('FUSE_MUL_ADD3_IXxF32xF32xI8 = 1',
        'compute ``a * b + c`` where the input ``a`` is an integer type and '
        '``b`` and ``c`` are both ``float32``, the result is '
        '``int8``. This is currently only optimized for ``(1, x)`` '
        'broadcast for ``b`` and ``c``. Computation is carried in floating '
        'points and results are rounded towards zero with saturated cast to '
        'int.'),
    Doc('ROUND_SHR_SATURATE_IXxI8xI8 = 2',
        'Compute ``a >> b``, round the result according to lower ``b`` bits '
        'of ``a`` and make a saturating conversion to int8. Where ``a`` should'
        ' be an integer tensor and ``b`` should be an int8 scalar.'),
    Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8 = 3',
        'Fused operation of an int16 elemwise add, an int16 rounding multiply '
        'high and an int16 to int8 rounding right shift with saturation.'),
    Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8 = 4',
        'Fused operation of an int32 elemwise add, an int32 rounding multiply '
        'high and an int32 to int8 rounding right shift with saturation.'),
    Doc('ROUND_SHR_SATURATE_IXxI8xI16 = 5',
        'Compute ``a >> b``, round the result according to lower ``b`` bits of '
        '``a`` and make a saturating conversion to int16. Where ``a`` should'
        ' be an integer tensor and ``b`` should be an int8 scalar.'),
    Doc('QADD = 6', 'Fused elemwise add two quantized int8 with specified '
        'output quantized dtype'),
    Doc('QFUSE_ADD_RELU = 7', 'Fused elemwise add two quantized int8 followed'
        ' by ReLU and typecvt to specified dtype'),
    Doc('QMUL = 8', 'Fused elemwise multiply two quantized int8 with specified '
        'output quantized dtype'),
    Doc('QMIN = 9', 'Fused elemwise min two quantized int8 with specified '
        'output quantized dtype'),
    Doc('QMAX = 10', 'quantized: max(x, y), with specified output quantized dtype'),
    Doc('QSUB = 11', 'quantized: x - y'),
    Doc('QTRUE_DIV = 12', 'quantized: x / y'),
    Doc('QFUSE_ADD_SIGMOID = 13', 'quantized: sigmoid(x + y)'),
    Doc('QFUSE_ADD_TANH = 14', 'quantized: tanh(x + y)'),
    Doc('QRELU = 15', 'quantized: x > 0 ? x : 0'),
    Doc('QABS = 16', 'quantized: x > 0 ? x : -x'),
    Doc('QSIGMOID = 17', 'quantized: sigmoid(x)'),
    Doc('QEXP = 18', 'quantized: exp(x)'),
    Doc('QTANH = 19', 'quantized: tanh(x)'),
    Doc('QFUSE_MUL_ADD3 = 20', 'quantized: x * y + z'),
    Doc('QFAST_TANH = 21', 'quantized: fast_tanh(x)'),
    Doc('QNEGATE = 22', 'quantized: -x'),
    Doc('QACOS = 23', 'quantized: acos(x)'),
    Doc('QASIN = 24', 'quantized: asin(x)'),
    Doc('QCEIL = 25', 'quantized: ceil(x)'),
    Doc('QCOS = 26', 'quantized: cos(x)'),
    Doc('QEXPM1 = 27', 'quantized: expm1(x)'),
    Doc('QFLOOR = 28', 'quantized: floor(x)'),
    Doc('QLOG = 29', 'quantized: log(x)'),
    Doc('QLOG1P = 30', 'quantized: log1p(x)'),
    Doc('QSIN = 31', 'quantized: sin(x)'),
    Doc('QROUND = 32', 'quantized: round(x)'),
    Doc('QERF = 33', 'quantized: erf(x)'),
    Doc('QERFINV = 34', 'quantized: erfinv(x)'),
    Doc('QERFC = 35', 'quantized: erfc(x)'),
    Doc('QERFCINV = 36', 'quantized: erfcinv(x)'),
    Doc('QABS_GRAD = 37', 'quantized: abs_grad'),
    Doc('QFLOOR_DIV = 38', 'quantized floor_div'),
    Doc('QMOD = 39', 'quantized mod'),
    Doc('QSIGMOID_GRAD = 40', 'quantized sigmoid_grad'),
    Doc('QSWITCH_GT0 = 41', 'quantized switch_gt0'),
    Doc('QTANH_GRAD = 42', 'quantized tanh_grad'),
    Doc('QLT = 43', 'quantized lt'),
    Doc('QLEQ = 44', 'quantized leq'),
    Doc('QEQ = 45', 'quantized eq'),
    Doc('QPOW = 46', 'quantized pow'),
    Doc('QLOG_SUM_EXP = 47', 'quantized log_sum_exp'),
    Doc('QFAST_TANH_GRAD = 48', 'quantized fast_tanh_grad'),
    Doc('QATAN2 = 49', 'quantized atan2'),
    Doc('QCOND_LEQ_MOV = 50', 'quantized cond_leq_mov'),
    Doc('QH_SWISH = 51', 'quantized h_swish'),
    Doc('QFUSE_ADD_H_SWISH = 52', 'quantized h_swish(x+y)'),
    Doc('QH_SWISH_GRAD = 53', 'quantized h_swish_grad'),
    Doc('FUSE_MUL_ADD3_INT16xF32xF32xF32 = 54',
        'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and '
        '``c`` float32, and the result is float32.'),
    Doc('MUL_INT16xF32xF32 = 55',
        'compute ``a * b`` requiring that ``a`` be int16 and ``b`` float32, '
        'and the result is float32.'),
    Doc('FUSE_MUL_ADD3_UINT8xF32xF32xF32 = 56',
        'compute ``a * b + c`` requiring that ``a`` be uint8 and ``b`` and '
        '``c`` float32, and the result is float32.')
)

pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)

(pdef('DctChannelSelect', '2d discrete cosine transform', version=0, is_legacy=True).add_enum_alias('Format', 'ConvolutionV0').
    add_enum('FastImpl', 'NONE = 0', 'FIX_32_MASK = 1').add_fields('int32', 'dct_block_size', 8))

(pdef('DctChannelSelect', '2d discrete cosine transform', version=1).add_enum_alias('Format', 'Convolution').
    add_enum_alias('FastImpl', 'DctChannelSelectV0').add_fields('int32', 'dct_block_size', 8))

(pdef('MatrixMul', version=0, is_legacy=True).
    add_fields('bool', 'transposeA', 'false', 'transposeB', 'false').
    add_enum('DataType',
        Doc('FLOAT = 0', 'input/output both float32/float16'),
        'INT8x8x16 = 1',
        'INT8x8x32 = 2',
        Doc('FLOAT_IO16xC32 = 3', 'input/output both float16, the internal compute is '
            'float32'),
        Doc('QUINT8x8x32 = 4', 'input QuantizedAsymm8, output QuantizedS32'),
        Doc('QUINT4x4x32 = 5', 'input QuantizedAsymm4, output QuantizedS32'),
        name_field='data_type'))

(pdef('MatrixMul', version=1, is_legacy=True).
    add_fields('bool', 'transposeA', 'false', 'transposeB', 'false').
    add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. '
            'different combinations of intermediate result '
            'data types.'),
        Doc('DEFAULT = 0', 'No special requirements on the precision of '
            'intermediate results.'),
        Doc('FLOAT32 = 1', 'Use Float32 accumulator and intermediate result. '
            'Only supported when input and output are Float16.'),
        name_field='compute_mode'))

(pdef('MatrixMul', version=2).
    add_fields('bool', 'transposeA', 'false', 'transposeB', 'false').
    add_enum_alias('ComputeMode', 'MatrixMulV1', name_field='compute_mode').
    add_enum('Format',
        Doc('DEFAULT = 0', 'Normal matrix mul: (M, K) x (K, N) = (M, N)'),
        Doc('MK4 = 1', 'Split 4 from M and K, better for neon compute: '
            '(M/4, K/4, 4(k), 4(m)) x (K/4, N, 4(k)). If transposeA, the '
            'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'),
        Doc('MK8 = 2', 'Split 8 from M and K, better for neon compute: '
            '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). If transposeA, the '
            'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
        Doc('MK4_DOT = 3', 'Split 4 from M and K, better for neon dotprod: '
            '(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). If transposeA, the '
            'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'))
)
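
# Illustrative shape example for the MK4 format (derived from the docstrings
# above, not from the kernel code): a dense (M, K) x (K, N) multiply with
# M = 8, K = 8, N = 5 would store the operands as
#
#     A: (M/4, K/4, 4(k), 4(m)) = (2, 2, 4, 4)
#     B: (K/4, N, 4(k))         = (2, 5, 4)
#
# keeping 4-element k/m blocks contiguous so they can be loaded directly into
# NEON vector registers.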

(pdef('SVD').
    add_fields('bool',
        Doc('full_matrices',
            'Whether to compute the full-sized u and v or only the leading'
            ' min(m, n) singular vectors. Ignored if compute_uv is '
            'false.'),
        'false',
        Doc('compute_uv',
            'Whether the left (u) and right (v) singular vectors will be '
            'computed and outputted.'),
        'true'))

(pdef('Reduce', 'legacy reduce', version=0, is_legacy=True).
    add_enum('Mode',
        'SUM = 0',
        Doc('SUM_SQR = 1', 'sum of x * x for each element x'),
        'PRODUCT = 2', 'MIN = 3', 'MAX = 4').
    add_fields('int32',
        Doc('axis',
            'axis along which reduction is performed; if -1 is given, '
            'reduce to given target shape (only used in megbrain)'),
        -1))

(pdef('Reduce', 'reduce along given axis', version=1, is_legacy=True).
    add_enum('Mode',
        'SUM = 0',
        Doc('SUM_SQR = 1', 'sum of x * x for each element x'),
        'PRODUCT = 2', 'MIN = 3', 'MAX = 4', 'MEAN = 5').
    add_fields('int32',
        Doc('axis',
            'axis along which reduction is performed; if -1 is given, '
            'reduce to given target shape (only used in megbrain)'),
        -1).
    add_enum('DataType',
        Doc('DEFAULT = 0',
'''
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode.
Currently, ``DEFAULT`` mode means:
+--------------------+-----------------------------------+-------------------+
| Input/Output DType | Mode                              | Computation DType |
+====================+===================================+===================+
| FLOAT32            | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | FLOAT32           |
+--------------------+-----------------------------------+-------------------+
| FLOAT16            | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | FLOAT16           |
+--------------------+-----------------------------------+-------------------+
| INT32              | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | INT32             |
+--------------------+-----------------------------------+-------------------+
| INT8               | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | INT8              |
+--------------------+-----------------------------------+-------------------+
| QuantizedS8        | MIN/MAX                           | QuantizedS8       |
+--------------------+-----------------------------------+-------------------+
| QuantizedS8        | MEAN/SUM                          | QuantizedS32      |
+--------------------+-----------------------------------+-------------------+
| Quantized8Asymm    | MIN/MAX                           | Quantized8Asymm   |
+--------------------+-----------------------------------+-------------------+
| Quantized8Asymm    | MEAN/SUM                          | QuantizedS32      |
+--------------------+-----------------------------------+-------------------+
'''
        ),
        Doc('FLOAT_IO16xC32 = 1', 'Deprecated. This was replaced by '
            'FLOAT_O16xC32, and the input\'s dtype is decided by the actual input tensor.'),
        Doc('FLOAT_O32xC32 = 2', 'compute/output both are float32'),
        Doc('FLOAT_O16xC32 = 3', 'compute are float32, output float16'),
        Doc('QUINT_I8xO32 = 4', 'input quint8, compute and output are qint32'),
        Doc('QINT_I8xO32 = 5', 'input qint8, compute and output are qint32'),
        name_field='data_type'))

(pdef('Reduce', 'reduce along given axis', version=2).
    add_enum('Mode',
        'SUM = 0',
        Doc('SUM_SQR = 1', 'sum of x * x for each element x'),
        'PRODUCT = 2', 'MIN = 3', 'MAX = 4', 'MEAN = 5').
    add_fields('int32',
        Doc('axis',
            'axis along which reduction is performed; if INT_MAX is given, '
            'reduce to given target shape (only used in megbrain)'),
        (1 << 31)-1).
    add_enum('DataType',
        Doc('DEFAULT = 0',
'''
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode.
Currently, ``DEFAULT`` mode means:
+--------------------+-----------------------------------+-------------------+
| Input/Output DType | Mode                              | Computation DType |
+====================+===================================+===================+
| FLOAT32            | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | FLOAT32           |
+--------------------+-----------------------------------+-------------------+
| FLOAT16            | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | FLOAT16           |
+--------------------+-----------------------------------+-------------------+
| INT32              | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | INT32             |
+--------------------+-----------------------------------+-------------------+
| INT8               | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT  | INT8              |
+--------------------+-----------------------------------+-------------------+
| QuantizedS8        | MIN/MAX                           | QuantizedS8       |
+--------------------+-----------------------------------+-------------------+
| QuantizedS8        | MEAN/SUM                          | QuantizedS32      |
+--------------------+-----------------------------------+-------------------+
| Quantized8Asymm    | MIN/MAX                           | Quantized8Asymm   |
+--------------------+-----------------------------------+-------------------+
| Quantized8Asymm    | MEAN/SUM                          | QuantizedS32      |
+--------------------+-----------------------------------+-------------------+
'''
        ),
        Doc('FLOAT_IO16xC32 = 1', 'Deprecated. This was replaced by '
            'FLOAT_O16xC32, and the input\'s dtype is decided by the actual input tensor.'),
        Doc('FLOAT_O32xC32 = 2', 'compute/output both are float32'),
        Doc('FLOAT_O16xC32 = 3', 'compute are float32, output float16'),
        Doc('QUINT_I8xO32 = 4', 'input quint8, compute and output are qint32'),
        Doc('QINT_I8xO32 = 5', 'input qint8, compute and output are qint32'),
        name_field='data_type'))
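
# Note (drawn from the axis docs above): (1 << 31) - 1 is INT32_MAX, used here
# as the sentinel meaning "no single reduction axis; reduce to the given target
# shape" (megbrain-only behaviour), replacing the -1 sentinel of the legacy
# version=1 definition.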

(pdef('Cumsum', 'calculate accumulated sum along given axis', version=0, is_legacy=True).
    add_fields('int32',
        Doc('axis',
            'axis along which cumsum is performed'),
        -1).
    add_fields('bool',
        Doc('exclusive',
            'whether the current element is taken into account'),
        'true').
    add_fields('bool',
        Doc('reverse',
            'whether the cumsum is forward or backward'),
        'false'))

(pdef('Cumsum', 'calculate accumulated sum along given axis', version=1).
    add_fields('int32',
        Doc('axis',
            'axis along which cumsum is performed; defaults to INT_MAX'),
        (1 << 31)-1).
    add_fields('bool',
        Doc('exclusive',
            'whether the current element is taken into account'),
        'true').
    add_fields('bool',
        Doc('reverse',
            'whether the cumsum is forward or backward'),
        'false'))

(pdef('CondTake').
    add_enum('Mode',
        Doc('EQ = 0', 'take if ``abs(data-val)<eps``'),
        Doc('NEQ = 1', 'take if ``abs(data-val)>=eps``'),
        Doc('LT = 2', 'take if ``data<val``'),
        Doc('LEQ = 3', 'take if ``data<=val``'),
        Doc('GT = 4', 'take if ``data>val``'),
        Doc('GEQ = 5', 'take if ``data>=val``')).
    add_fields('float32',
        Doc('val', 'the value to be compared with; note that for integer '
            'data, val is also converted to int'), 0).
    add_fields('float32', Doc('eps', 'used for float equality comparison'),
        1e-6))

pdef('Argsort').add_enum('Order', 'ASCENDING = 0', 'DESCENDING = 1')

(pdef('IndexingRemap').
    add_fields('bool',
        Doc('is_non_overlapping',
            'Whether no two dst elements map to the same src element. '
            'Enabling this option can accelerate the gradient operator since'
            ' atomic adding operations could be avoided.'),
        'false'))

pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)

(pdef('Linspace').
    add_fields('bool',
        Doc('endpoint',
            'Whether stop is included in the generated tensor'),
        'true'))

(pdef('LinspaceFull').
    add_fields('float64',
        Doc('start', 'The first val.'),
        0).
    add_fields('float64',
        Doc('stop', 'The last val.'),
        1).
    add_fields('bool',
        Doc('endpoint',
            'Whether stop is included in the generated tensor'),
        'true'))

(pdef('Eye').
    add_fields(
        'int32',
        Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
            'diagonal, a positive value refers to an upper diagonal, and a '
            'negative value to a lower diagonal.'),
        0).
    add_fields(
        'dtype', Doc('dtype', 'data type of output value'),
        'DTypeEnum::Float32'))

(pdef('UniformRNG', version=0, is_legacy=True).
    add_fields('uint64', 'seed', 0))

(pdef('UniformRNG', version=1).
    add_fields('uint64', 'seed', 0).
    add_fields(
        'dtype', Doc(
            'dtype', 'The dtype of output Tensor. Only support Float32.'),
        'DTypeEnum::Float32'))

(pdef('GaussianRNG', version=0, is_legacy=True).
    add_fields('uint64', 'seed', 0).
    add_fields('float32', 'mean', 0, 'std', 1))

(pdef('GaussianRNG', version=1).
    add_fields('uint64', 'seed', 0).
    add_fields('float32', 'mean', 0, 'std', 1).
    add_fields(
        'dtype', Doc(
            'dtype', 'The dtype of output Tensor. Only support Float32.'),
        'DTypeEnum::Float32'))

(pdef('GammaRNG').
    add_fields('uint64', 'seed', 0))

(pdef('BetaRNG').
    add_fields('uint64', 'seed', 0))

(pdef('PoissonRNG').
    add_fields('uint64', 'seed', 0))

(pdef('PermutationRNG').
    add_fields('uint64', 'seed', 0).
    add_fields(
        'dtype', Doc('dtype', 'The dtype of output Tensor. Int32, Int16 and '
            'Float32 are supported.'),
        'DTypeEnum::Int32'))

(pdef('ShuffleRNG').
    add_fields('uint64', 'seed', 0))

(pdef('Flip').
    add_fields('bool', 'vertical', 'false', 'horizontal', 'false'))

(pdef('Rotate')
    .add_fields('bool', 'clockwise', 'true'))

(pdef('ROICopy')
    .add_fields('uint32', 'row_from', 0, 'row_to', 0, 'col_from', 0, 'col_to', 0))

(pdef('CvtColor')
    .add_enum('Mode', 'RGB2GRAY = 0', 'RGB2YUV = 1', 'YUV2RGB = 2', 'GRAY2RGB = 3', 'RGBA2RGB = 4',
        'RGBA2BGR = 5', 'RGBA2GRAY = 6', 'RGB2BGR = 7', 'BGR2GRAY = 8', 'BGR2RGB = 9',
        Doc('YUV2GRAY_NV21 = 10', 'For historical reasons, referred to as YCC by opencv'),
        'YUV2RGB_NV21 = 11', 'YUV2BGR_NV21 = 12', 'YUV2GRAY_NV12 = 13', 'YUV2RGB_NV12 = 14',
        'YUV2BGR_NV12 = 15', 'YUV2GRAY_YV12 = 16', 'YUV2RGB_YV12 = 17', 'YUV2BGR_YV12 = 18',
        'YUV2GRAY_YU12 = 19', 'YUV2RGB_YU12 = 20', 'YUV2BGR_YU12 = 21',
        'YCrCb2RGB = 22', 'YCrCb2BGR = 23',
        Doc('BT601_YUV2RGB_NV21 = 24', 'BT601 yuv format, referred to as YUV by opencv'),
        'BT601_YUV2BGR_NV21 = 25', 'BT601_YUV2RGB_NV12 = 26', 'BT601_YUV2BGR_NV12 = 27',
        'BT601_YUV2RGB_YV12 = 28', 'BT601_YUV2BGR_YV12 = 29', 'BT601_YUV2RGB_YU12 = 30',
        'BT601_YUV2BGR_YU12 = 31',
        member_alias=[('YUV2GRAY_NV21', 'BT601_YUV2GRAY_NV21'),
                      ('YUV2GRAY_NV12', 'BT601_YUV2GRAY_NV12'),
                      ('YUV2GRAY_YV12', 'BT601_YUV2GRAY_YV12'),
                      ('YUV2GRAY_YU12', 'BT601_YUV2GRAY_YU12')],
        name_field='mode'))

(pdef('WarpAffine', version=0, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
    .add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f'))

(pdef('WarpAffine', version=1, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
    .add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')
    .add_enum_alias('Format', 'ConvolutionV0', default=1))

(pdef('WarpAffine', version=2)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
    .add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')
    .add_enum_alias('Format', 'Convolution', default=1))

(pdef('GaussianBlur')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_mode')
    .add_fields('uint32', 'kernel_height', 0, 'kernel_width', 0)
    .add_fields('float32', 'sigma_x', '0.f', 'sigma_y', '0.f'))

(pdef('Resize', version=0, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode'))

(pdef('Resize', version=1, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('Format', 'ConvolutionV0', default=1))

(pdef('Resize', version=2)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('Format', 'Convolution', default=1))

(pdef('Remap', version=0, is_legacy=True)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_type')
    .add_enum_alias('Format', 'ConvolutionV0', default=1)
    .add_fields('float32', 'scalar', '0.f'))

(pdef('Remap', version=1)
    .add_enum_alias('InterpolationMode', 'WarpPerspectiveV1', name_field='imode')
    .add_enum_alias('BorderMode', 'WarpPerspectiveV1', name_field='border_type')
    .add_enum_alias('Format', 'Convolution', default=1)
    .add_fields('float32', 'scalar', '0.f'))

(pdef('Convolution3D').
    add_enum('Mode', 'CROSS_CORRELATION = 0', 'CONVOLUTION = 1').
    add_fields(
        'uint32',
        Doc('pad_d', 'padding on one side on the first dimension'), 0,
        Doc('pad_h', 'padding on one side on the second dimension'), 0,
        Doc('pad_w', 'padding on one side on the third dimension'), 0,
        Doc('stride_d', 'kernel stride on the first dimension'), 1,
        Doc('stride_h', 'kernel stride on the second dimension'), 1,
        Doc('stride_w', 'kernel stride on the third dimension'), 1,
        Doc('dilate_d', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the first dimension'), 1,
        Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the second dimension'), 1,
        Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
            'on the third dimension'), 1
    ).
    add_enum('Sparse',
        Doc('DENSE = 0', 'dense convolution: filter shape should be '
            '[oc, ic, spatial...] if format is NCDHW, '
            '[oc, spatial..., ic] if format is NDHWC'),
        Doc('GROUP = 1', 'group convolution: filter shape should be '
            '[group, oc_per_group, ic_per_group, spatial...] if format is NCDHW, '
            '[group, oc_per_group, spatial..., ic_per_group] if format is NDHWC')
    ).
    add_enum('DataType',
        Doc('FLOAT = 0', 'input/output both float32/float16'),
        Doc('FLOAT_IO16xC32 = 1', 'input/output both float16, the internal '
            'compute is float32'),
        name_field='data_type').
    add_enum('Format', 'NCDHW = 0', 'NDHWC = 1')
)

(pdef('Conv3DBias').
    add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'SIGMOID = 2').
    add_enum_alias('Mode', 'Convolution3D').
    add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0,
        'stride_d', 1, 'stride_h', 1, 'stride_w', 0))

(pdef('SeparableConv3D').
    add_enum_alias('Mode', 'Convolution3D').
    add_enum('BorderMode', 'BORDER_REPLICATE = 0', 'BORDER_REFLECT = 1',
        'BORDER_REFLECT_101 = 2', 'BORDER_WRAP = 3',
        'BORDER_CONSTANT = 4', 'BORDER_TRANSPARENT = 5', 'BORDER_ISOLATED = 6').
    add_fields('bool', 'is_symm_kernel', 'true').
    add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0,
        'stride_d', 0, 'stride_h', 1, 'stride_w', 1,
        'ksize_d', 0, 'ksize_h', 3, 'ksize_w', 3,
        'anchor_d', 0, 'anchor_h', 1, 'anchor_w', 1))

(pdef('TopK').
    add_enum(
        'Mode',
        Doc('KTH_ONLY = 0', "only the value of the k'th element would be computed"),
        Doc('VALUE_IDX_NOSORT = 1',
            'all the top-k values and corresponding indices would be computed; '
            'no order is guaranteed'),
        Doc('VALUE_IDX_SORTED = 2',
            'all the top-k values and corresponding indices sorted'))
)

RELAYOUT_FORMAT_MODE_DOC = """
Relayout mode.

**Naming conventions**

1. ``A_B`` means change from layout format ``A`` to ``B``.
2. ``INTER_WEIGHT_xx`` means relayout the weight for faster processing by
   :attr:`Convolution.Format.NHWCD4` convolutions.
3. A suffix of ``I`` means ``Image2DPack4TensorFormat`` tensor format is used
   for faster processing on GPUs.

**Layout definitions**

* ``NCHW`` layout: ``{N, C, H, W}``
* ``NHWC`` layout: ``{N, H, W, C}``
* ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
* ``NHWCD4I`` layout: with ``align_axis = 2``
* ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
* ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
* ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
* ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

**Float weight transformation definitions**

+---------------+---------------------------------+--------------------+--------------------------------------+------+
| Sparsity Type | Input Layout                    | Input Req          | Output Layout                        | Axis |
+===============+=================================+====================+======================================+======+
| DENSE         | ``{OC, IC, FH, FW}``            | ``OC % 4 == 0``    | ``{OC/4, FH, FW, IC, 4}``            | 3    |
+---------------+---------------------------------+--------------------+--------------------------------------+------+
| GROUP         | ``{GROUP, OCPG, ICPG, FH, FW}`` | ``OCPG % 4 == 0``  | ``{GROUP, OCPG/4, FH, FW, ICPG, 4}`` | 4    |
|               |                                 | ``ICPG % 4 == 0``  |                                      |      |
+---------------+---------------------------------+--------------------+--------------------------------------+------+
| CHAN          | ``{GROUP, 1, 1, FH, FW}``       | ``GROUP % 4 == 0`` | ``{GROUP / 4, 1, FH, FW, 4}``        | 1    |
+---------------+---------------------------------+--------------------+--------------------------------------+------+

**Float weight transformation nchw88 definitions**

+---------------+---------------------------------+--------------------+----------------------------------------+
| Sparsity Type | Input Layout                    | Input Req          | Output Layout                          |
+===============+=================================+====================+========================================+
| DENSE         | ``{OC, IC, FH, FW}``            | ``OC % 8 == 0``    | ``{OC/8, IC/8, FH, FW, 8(IC), 8(OC)}`` |
|               |                                 | ``IC % 8 == 0``    |                                        |
+---------------+---------------------------------+--------------------+----------------------------------------+
| GROUP         | ``{GROUP, OCPG, ICPG, FH, FW}`` | ``OCPG % 8 == 0``  | ``{GROUP, OCPG/8, ICPG/8, FH, FW,      |
|               |                                 | ``ICPG % 8 == 0``  | 8(ICPG), 8(OCPG)}``                    |
+---------------+---------------------------------+--------------------+----------------------------------------+
| CHAN          | ``{GROUP, 1, 1, FH, FW}``       | ``GROUP % 8 == 0`` | ``{GROUP / 8, 1, FH, FW, 8}``          |
+---------------+---------------------------------+--------------------+----------------------------------------+

**Int8(DOT) weight transformation definitions**

+---------------+---------------------------------+--------------------+------------------------------------------+------+
| Sparsity Type | Input Layout                    | Input Req          | Output Layout                            | Axis |
+===============+=================================+====================+==========================================+======+
| DENSE         | ``{OC, IC, FH, FW}``            | ``OC % 4 == 0``    | ``{OC/4, FH, FW, IC/4, 4, 4}``           | 3    |
+---------------+---------------------------------+--------------------+------------------------------------------+------+
| GROUP         | ``{GROUP, OCPG, ICPG, FH, FW}`` | ``OCPG % 4 == 0``  | ``{GROUP, OCPG/4, FH, FW, ICPG/4, 4, 4}``| 4    |
|               |                                 | ``ICPG % 4 == 0``  |                                          |      |
+---------------+---------------------------------+--------------------+------------------------------------------+------+

Note: the axis column means the corresponding ``align_axis`` for image format
when the ``I`` suffix is present.

Note: NCHW_NCHW4_WEIGHT will auto-pad oc and ic; you should remove the padded
oc in a later opr by setting the group and oc params with NCHW4_NCHW.
"""

(pdef('RelayoutFormat', 'Change the tensor layout format', version=0, is_legacy=True).
    add_enum(
        Doc('Mode', RELAYOUT_FORMAT_MODE_DOC),
        'NHWC_NHWCD4 = 0',
        'NHWCD4_NHWC = 1',
        'NHWC_NHWCD4I = 2',
        'NCHW_NHWCD4 = 3',
        'NCHW_NHWCD4I = 4',
        'NHWCD4I_NCHW = 5',
        'NHWCD4_NCHW = 6',
        'INTER_WEIGHT_DENSE = 7',
        'INTER_WEIGHT_DENSEI = 8',
        'INTER_WEIGHT_GROUP = 9',
        'INTER_WEIGHT_GROUPI = 10',
        'INTER_WEIGHT_CHAN = 11',
        'INTER_WEIGHT_CHANI = 12',
        'INTER_WEIGHT_DENSEI_DOT = 13',
        'INTER_WEIGHT_GROUPI_DOT = 14',
        'NCHW4_CHWN4 = 15',
        'CHWN4_NCHW4 = 16',
        'NCHW_NCHW88_CONV_DENSE_WEIGHT = 17',
        'NCHW_NCHW88_CONV_CHAN_WEIGHT = 18',
        'NCHW_NCHW88_CONV_GROUP_WEIGHT = 19',
        'NCHW_NCHW88 = 20',
        'NCHW88_NCHW = 21',
        'NCHW_NCHW4_IC_SMALL = 22',
        'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT = 23',
        'NCHW_NCHW4 = 24',
        'NCHW4_NCHW = 25',
        'NCHW_NCHW4_WEIGHT = 26',
        'NCHW_NCHW64 = 27',
        'NCHW64_NCHW = 28',
        'NCHW_NHWC = 29',
        'NHWC_NCHW = 30',
    )
)
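
# Illustrative shape example (derived from the layout definitions in
# RELAYOUT_FORMAT_MODE_DOC above, assuming C is divisible by 4): the
# NCHW_NCHW4 mode relayouts an {N, C, H, W} tensor into {N, C/4, H, W, 4},
# and NCHW4_NCHW is the inverse transform; the *_IC_SMALL modes cover the
# c < 4 case by padding channels up to 4.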
  942. (pdef('RelayoutFormat', 'Change the tensor layout format', version=1).
  943. add_enum_alias('Mode', 'RelayoutFormatV0').
  944. add_fields('uint32', 'oc', '0').
  945. add_fields('uint32', 'group', '1')
  946. )
  947. (pdef('SeparableFilter', version=0, is_legacy=True).
  948. add_enum_alias('Format', 'ConvolutionV0').
  949. add_enum_alias('BorderMode', 'WarpPerspectiveV1').
  950. add_fields('bool', 'is_symm_kernel', 'true').
  951. add_fields('uint32', 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
  952. (pdef('SeparableFilter', version=1).
  953. add_enum_alias('Format', 'Convolution').
  954. add_enum_alias('BorderMode', 'WarpPerspectiveV1').
  955. add_fields('bool', 'is_symm_kernel', 'true').
  956. add_fields('uint32', 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1))
  957. (pdef('LocalShare', 'Local share convolution', version=0, is_legacy=True).
  958. add_enum_alias('Mode', 'ConvolutionV0').
  959. add_fields(
  960. 'uint32',
  961. Doc('pad_h', 'padding on one side on the first dimension'), 0,
  962. Doc('pad_w', 'padding on one side on the second dimension'), 0,
  963. Doc('stride_h', 'kernel stride on the first dimension'), 1,
  964. Doc('stride_w', 'kernel stride on the second dimension'), 1,
  965. Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
  966. 'on the second dimension'), 1,
  967. Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
  968. 'on the second dimension'), 1,
  969. Doc('spatial_groups_h', 'spatial groups on the first dimension'), 1,
  970. Doc('spatial_groups_w', 'spatial groups on the second dimension'), 1
  971. ).
  972. add_enum_alias('Sparse', 'ConvolutionV0').
  973. add_enum_alias('Format', 'ConvolutionV0').
  974. add_enum_alias('ComputeMode', 'ConvolutionV1')
  975. )
(pdef('LocalShare', 'Local share convolution', version=1).
 add_enum_alias('Mode', 'ConvolutionV0').
 add_fields(
     'uint32',
     Doc('pad_h', 'padding on one side on the first dimension'), 0,
     Doc('pad_w', 'padding on one side on the second dimension'), 0,
     Doc('stride_h', 'kernel stride on the first dimension'), 1,
     Doc('stride_w', 'kernel stride on the second dimension'), 1,
     Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
         'on the first dimension'), 1,
     Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
         'on the second dimension'), 1,
     Doc('spatial_groups_h', 'spatial groups on the first dimension'), 1,
     Doc('spatial_groups_w', 'spatial groups on the second dimension'), 1
 ).
 add_enum_alias('Sparse', 'ConvolutionV0').
 add_enum_alias('Format', 'Convolution').
 add_enum_alias('ComputeMode', 'ConvolutionV1')
 )
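# NOTE: in ROIAlign, pooled_height/pooled_width define the output bin grid,
# while sample_height/sample_width are presumably the number of bilinear
# sampling points taken per bin; offset appears to be an additive correction
# applied to the ROI coordinates together with spatial_scale.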
(pdef('ROIAlign', version=0, is_legacy=True).
 add_enum('Mode', 'MAX = 0', 'AVERAGE = 1', name_field='mode').
 add_enum_alias('Format', 'ConvolutionV0').
 add_fields('float32', 'spatial_scale', '1.0').
 add_fields('float32', 'offset', '0.0').
 add_fields('uint32',
            'pooled_height', '1',
            'pooled_width', '1',
            'sample_height', '2',
            'sample_width', '2')
 )
(pdef('ROIAlign', version=1).
 add_enum_alias('Mode', 'ROIAlignV0', name_field='mode').
 add_enum_alias('Format', 'Convolution').
 add_fields('float32', 'spatial_scale', '1.0').
 add_fields('float32', 'offset', '0.0').
 add_fields('uint32',
            'pooled_height', '1',
            'pooled_width', '1',
            'sample_height', '2',
            'sample_width', '2')
 )
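# NOTE: Correlation appears to follow the FlowNet-style correlation layer:
# kernel_size and max_displacement define the matching window, stride1/stride2
# stride the two inputs, pad_size zero-pads them, and is_multiply presumably
# switches between a multiplicative (dot-product) and a subtractive cost volume.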
(pdef('Correlation').
 add_enum_alias('Format', 'ConvolutionV0').
 add_fields('uint32', 'kernel_size', '1').
 add_fields('uint32', 'max_displacement', '1').
 add_fields('uint32', 'stride1', '1').
 add_fields('uint32', 'stride2', '1').
 add_fields('uint32', 'pad_size', '0').
 add_fields('bool', 'is_multiply', 'true')
 )
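# NOTE: for DeformablePSROIPooling, no_trans=true presumably disables the
# learned offset ("trans") branch, reducing it to plain position-sensitive
# ROI pooling; trans_std scales the predicted part offsets.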
(pdef('DeformablePSROIPooling').
 add_fields('bool', 'no_trans', 'true').
 add_fields('float32', 'spatial_scale', 1,
            'trans_std', 1).
 add_fields('uint32',
            Doc('pooled_h', 'height of pooling output'), 1,
            Doc('pooled_w', 'width of pooling output'), 1,
            Doc('part_size', 'size of each deformable part'), 1,
            Doc('sample_per_part', 'sample count of each bbox'), 1))
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)', version=0, is_legacy=True).
 add_enum_alias('NonlineMode', 'ConvBiasV0').
 add_enum_alias('Mode', 'ConvolutionV0').
 add_fields(
     'uint32',
     Doc('pad_h', 'padding on one side on the first dimension'), 0,
     Doc('pad_w', 'padding on one side on the second dimension'), 0,
     Doc('stride_h', 'kernel stride on the first dimension'), 1,
     Doc('stride_w', 'kernel stride on the second dimension'), 1,
     Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
         'on the first dimension'), 1,
     Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
         'on the second dimension'), 1,
 ).
 add_enum_alias('Sparse', 'ConvolutionV0').
 add_enum_alias('Format', 'ConvolutionV0').
 add_enum_alias('ComputeMode', 'ConvolutionV1', name_field="compute_mode")
 )
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)', version=1).
 add_enum_alias('NonlineMode', 'ConvBiasV0').
 add_enum_alias('Mode', 'ConvolutionV0').
 add_fields(
     'uint32',
     Doc('pad_h', 'padding on one side on the first dimension'), 0,
     Doc('pad_w', 'padding on one side on the second dimension'), 0,
     Doc('stride_h', 'kernel stride on the first dimension'), 1,
     Doc('stride_w', 'kernel stride on the second dimension'), 1,
     Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) '
         'on the first dimension'), 1,
     Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) '
         'on the second dimension'), 1,
 ).
 add_enum_alias('Sparse', 'ConvolutionV0').
 add_enum_alias('Format', 'Convolution').
 add_enum_alias('ComputeMode', 'ConvolutionV1', name_field="compute_mode")
 )
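# Quantization-aware-training oprs: qmin/qmax bound the simulated quantized
# value range. The defaults are INT32_MIN/INT32_MAX, i.e. effectively
# unclamped until the caller narrows the range (e.g. -128/127 for int8).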
(pdef('FakeQuant').
 add_fields('int32', 'qmin', '-2147483648').
 add_fields('int32', 'qmax', '2147483647')
 )

(pdef('TQT').
 add_fields('int32', 'qmin', '-2147483648').
 add_fields('int32', 'qmax', '2147483647')
 )

(pdef('LSQ').
 add_fields('int32', 'qmin', '-2147483648').
 add_fields('int32', 'qmax', '2147483647')
 )
pdef('Fill').add_fields('float32', 'value', '0')

pdef('CheckNonFinite').add_fields('float32', 'scale', '1.0')
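# Border-extension modes for Padding, documented with the usual OpenCV-style
# diagrams (|abcdefgh| is the original data, the flanks show the generated
# border). Offsets are given per dimension for dim 0-6, presumably matching
# the maximum supported tensor rank of 7.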
PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
                 Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
                 Doc('CONSTANT = 2', 'iiiiii|abcdefgh|iiiiiii')]

(pdef('Padding').
 add_fields('uint32', Doc('front_offset_dim0', 'offset in dim 0'), 0).
 add_fields('uint32', Doc('front_offset_dim1', 'offset in dim 1'), 0).
 add_fields('uint32', Doc('front_offset_dim2', 'offset in dim 2'), 0).
 add_fields('uint32', Doc('front_offset_dim3', 'offset in dim 3'), 0).
 add_fields('uint32', Doc('front_offset_dim4', 'offset in dim 4'), 0).
 add_fields('uint32', Doc('front_offset_dim5', 'offset in dim 5'), 0).
 add_fields('uint32', Doc('front_offset_dim6', 'offset in dim 6'), 0).
 add_fields('uint32', Doc('back_offset_dim0', 'back offset in dim 0'), 0).
 add_fields('uint32', Doc('back_offset_dim1', 'back offset in dim 1'), 0).
 add_fields('uint32', Doc('back_offset_dim2', 'back offset in dim 2'), 0).
 add_fields('uint32', Doc('back_offset_dim3', 'back offset in dim 3'), 0).
 add_fields('uint32', Doc('back_offset_dim4', 'back offset in dim 4'), 0).
 add_fields('uint32', Doc('back_offset_dim5', 'back offset in dim 5'), 0).
 add_fields('uint32', Doc('back_offset_dim6', 'back offset in dim 6'), 0).
 add_fields('float32', Doc('padding_val', 'value filled into the padded area (used by CONSTANT mode)'), 0).
 add_enum('PaddingMode', *PADDING_MODES,
          name_field='padding_mode', default=2,
          member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES]
          )
 )
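# NOTE: LayerNorm normalizes over the trailing normalized_dim dimensions
# (normalized_size is presumably their total element count), adding eps to
# the variance for numerical stability; affine toggles the learnable
# scale/shift. Dropout records drop_prob (probability of zeroing an element)
# and a seed for its RNG.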
(pdef('LayerNorm')
 .add_fields('bool', 'affine', 'true')
 .add_fields('float32', 'eps', '1e-5f')
 .add_fields('uint64', 'normalized_dim', '1')
 .add_fields('uint64', 'normalized_size', '1')
 )
(pdef('Dropout')
 .add_fields('float32', 'drop_prob', '0')
 .add_fields('uint64', 'seed', '0')
 )
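# Recurrent oprs. RNNCell only fixes the nonlinearity; RNN and LSTM share the
# layer-level fields (num_layers, bidirectional, bias, hidden_size, proj_size,
# dropout). FwdMode is aliased from BN, presumably reusing its
# training/inference forward-mode distinction.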
(pdef('RNNCell').
 add_enum('NonlineMode', 'IDENTITY = 0', 'RELU = 1', 'TANH = 2')
 )

(pdef('RNN').
 add_fields('uint32', 'num_layers', '1').
 add_fields('bool', 'bidirectional', 'false').
 add_fields('bool', 'bias', 'true').
 add_fields('uint32', 'hidden_size', '128').
 add_fields('uint32', 'proj_size', '0').
 add_fields('float32', 'dropout', '0.f').
 add_enum_alias('NonlineMode', 'RNNCell').
 add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
 )
(pdef('LSTM').
 add_fields('uint32', 'num_layers', '1').
 add_fields('bool', 'bidirectional', 'false').
 add_fields('bool', 'bias', 'true').
 add_fields('uint32', 'hidden_size', '128').
 add_fields('uint32', 'proj_size', '0').
 add_fields('float32', 'dropout', '0.f').
 add_enum_alias('FwdMode', 'BN', name_field='fwd_mode')
 )