generator.py

#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#
import enum
import os.path
import shutil
import argparse
import platform

from library import *
from manifest import *

###################################################################################################
#
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):

  # by default, use the latest CUDA Toolkit version
  cuda_version = [11, 0, 132]

  # Update cuda_version based on parsed string
  if semantic_ver_string != '':
    for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]):
      if i < len(cuda_version):
        cuda_version[i] = x
      else:
        cuda_version.append(x)

  return cuda_version >= [major, minor, patch]
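# Illustrative example (not from the original source): parsing '10.2' overwrites
# the first two components of the default, giving [10, 2, 132], so
# CudaToolkitVersionSatisfies('10.2', 10, 1) is True while
# CudaToolkitVersionSatisfies('10.2', 11, 0) is False, because Python compares
# the version lists lexicographically.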
###################################################################################################
###################################################################################################

#
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
  alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
  swizzling_functor = SwizzlingFunctor.Identity8):

  if complex_transforms is None:
    complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]

  element_a, element_b, element_c, element_epilogue = data_type

  operations = []

  # by default, only generate the largest tile and largest alignment
  if manifest.args.kernels == '':
    tile_descriptions = [tile_descriptions[0],]
    alignment_constraints = [alignment_constraints[0],]

  for layout in layouts:
    for tile_description in tile_descriptions:
      for alignment in alignment_constraints:
        for complex_transform in complex_transforms:

          alignment_c = min(8, alignment)

          A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
          B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
          C = TensorDescription(element_c, layout[2], alignment_c)

          new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \
            tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor)

          manifest.append(new_operation)
          operations.append(new_operation)

  return operations
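# Hypothetical usage sketch (arguments assumed, not taken from this file): given
# a populated manifest and a list of f32 SIMT tile descriptions, a call such as
#   CreateGemmOperator(manifest,
#                      [(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor)],
#                      tile_descriptions,
#                      [DataType.f32, DataType.f32, DataType.f32, DataType.f32],
#                      [4, 1])
# registers one GemmKind.Universal operation per (layout, tile, alignment,
# complex_transform) combination and returns the list of created operations.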
###########################################################################################################
#
# ConvolutionOperator support variations
#        ____________________________________________________________________
#         ConvolutionalOperator |      Analytic          |    Optimized
#        ____________________________________________________________________
#        |       Fprop          |     (strided)          |    (strided)
#        |       Dgrad          |     (strided, unity*)  |    (unity)
#        |       Wgrad          |     (strided)          |    (strided)
#        ____________________________________________________________________
#
# Note : Operators marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################

# Convolution for 2D operations
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
  conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):

  element_a, element_b, element_c, element_epilogue = data_type

  # one exceptional case
  alignment_c = min(8, alignment)

  # iterator algorithm (analytic and optimized)
  iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

  # by default, only generate the largest tile size
  if manifest.args.kernels == '':
    tile_descriptions = [tile_descriptions[0],]

  operations = []

  for tile in tile_descriptions:
    for conv_kind in conv_kinds:
      for iterator_algorithm in iterator_algorithms:
        A = TensorDescription(element_a, layout[0], alignment)
        B = TensorDescription(element_b, layout[1], alignment)
        C = TensorDescription(element_c, layout[2], alignment_c)

        # unity stride only for Optimized Dgrad; strided dgrad is not supported
        # by Optimized Dgrad, so skip the strided variant below
        if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad):
          new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile, \
            A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor)

          manifest.append(new_operation)
          operations.append(new_operation)
          continue

        # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic/Optimized)
        new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile, \
          A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor)

        manifest.append(new_operation)
        operations.append(new_operation)

  return operations
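# Illustrative walk-through (assumed arguments, not from this file): with
# conv_kinds = [ConvKind.Dgrad] and a single tile, the loop above emits a
# StrideSupport.Unity kernel for IteratorAlgorithm.Optimized and a
# StrideSupport.Strided kernel for IteratorAlgorithm.Analytic, matching the
# support table: Optimized Dgrad is unity-stride only.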
###################################################################################################
###################################################################################################

def GenerateConv2d_Simt(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4),
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 4], \
      DataType.s8, DataType.s8, DataType.s32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  dst_layouts = [
    LayoutType.TensorNC4HW4,
    LayoutType.TensorNC32HW32,
    LayoutType.TensorNHWC,
    LayoutType.TensorNHWC,
    LayoutType.TensorNCHW,
  ]

  dst_types = [
    DataType.s8,
    DataType.s8,
    DataType.u4,
    DataType.s4,
    DataType.f32,
  ]

  max_cc = 1024

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        if dst_type == DataType.s4 or dst_type == DataType.u4:
          min_cc = 75
          use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt
        else:
          min_cc = 61
          use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

        tile_descriptions = [
          TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([ 32,  64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 64,  32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 16,  64,  8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]

        for tile in tile_descriptions:
          if dst_layout == LayoutType.TensorNC32HW32 and tile.threadblock_shape[0] > 32:
            continue
          if (dst_layout == LayoutType.TensorNCHW or dst_layout == LayoutType.TensorNHWC) \
              and tile.threadblock_shape[0] > 16:
            continue
          operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
                                       dst_layout, dst_type, min_cc, 32, 32, 32,
                                       use_special_optimization)
  return operations

def GenerateConv2d_TensorOp_8816(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32),
  ]

  math_instructions = [
    MathInstruction( \
      [8, 8, 16], \
      DataType.s8, DataType.s8, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_saturate),
  ]

  dst_layouts = [
    LayoutType.TensorNC32HW32,
    LayoutType.TensorNC4HW4,
  ]

  dst_types = [
    DataType.s8,
    DataType.s8,
  ]

  use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

  min_cc = 75
  max_cc = 1024

  cuda_major = 10
  cuda_minor = 2

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        if dst_layout == LayoutType.TensorNC32HW32:
          tile_descriptions = [
            TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
          ]
          operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                       dst_layout, dst_type, min_cc, 128, 128, 64, use_special_optimization,
                                       ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor)
        else:
          assert dst_layout == LayoutType.TensorNC4HW4
          tile_descriptions = [
            TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
          ]
          operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                       dst_layout, dst_type, min_cc, 128, 128, 64, use_special_optimization,
                                       ImplicitGemmMode.GemmNT, False, cuda_major, cuda_minor)

  layouts_nhwc = [
    (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32),
    (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64),
    (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128),
  ]

  dst_layouts_nhwc = [
    LayoutType.TensorNHWC,
  ]

  for math_inst in math_instructions:
    for layout in layouts_nhwc:
      for dst_layout in dst_layouts_nhwc:
        dst_type = math_inst.element_b
        tile_descriptions = [
          TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        for tile in tile_descriptions:
          dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
          operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                       dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                       ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor)
          if tile.threadblock_shape[1] == 16 or tile.threadblock_shape[1] == 32:
            operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                         dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                         ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor)

  # INT8x8x4 and INT8x8x32
  out_dtypes = [DataType.s4, DataType.u4, DataType.f32]

  for math_inst in math_instructions:
    for layout in layouts_nhwc:
      for dst_layout in dst_layouts_nhwc:
        for out_dtype in out_dtypes:
          tile_descriptions = [
            TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
          ]
          for tile in tile_descriptions:
            dst_align = 4 * DataTypeSize[out_dtype] if tile.threadblock_shape[1] == 16 or out_dtype == DataType.f32 \
                        else 8 * DataTypeSize[out_dtype]
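            # Worked example (illustrative, assuming DataTypeSize is in bits and
            # DataTypeSize[DataType.s4] == 4): for out_dtype s4 with
            # threadblock_shape[1] == 16, dst_align is 4 * 4 = 16; with
            # threadblock_shape[1] == 32 it is 8 * 4 = 32. f32 output always
            # takes the 4-element branch.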
            operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                         out_dtype, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                         ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor)
            if tile.threadblock_shape[1] == 16 or (tile.threadblock_shape[1] == 32 and out_dtype != DataType.f32):
              operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                           out_dtype, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                           ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor)
  return operations

def GenerateConv2d_TensorOp_8832(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64),
  ]

  math_instructions = [
    MathInstruction( \
      [8, 8, 32], \
      DataType.s4, DataType.s4, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_saturate),
    MathInstruction( \
      [8, 8, 32], \
      DataType.s4, DataType.u4, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_saturate),
  ]

  dst_layouts = [
    LayoutType.TensorNC64HW64,
  ]

  use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

  min_cc = 75
  max_cc = 1024

  cuda_major = 10
  cuda_minor = 2

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_layout in dst_layouts:
        dst_type = math_inst.element_b
        tile_descriptions = [
          TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  64,  64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
        ]
        operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                     dst_layout, dst_type, min_cc, 128, 128, 64, use_special_optimization,
                                     ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor)

  layouts_nhwc = [
    (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
    (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
    (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
  ]

  dst_layouts_nhwc = [
    LayoutType.TensorNHWC,
  ]

  for math_inst in math_instructions:
    for layout in layouts_nhwc:
      for dst_layout in dst_layouts_nhwc:
        dst_type = math_inst.element_b
        tile_descriptions = [
          TileDescription([128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
        ]
        for tile in tile_descriptions:
          dst_align = 16 if tile.threadblock_shape[1] == 16 else 32
          operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                       dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                       ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor)
          if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64:
            dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
            operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                         dst_type, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                         ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor)

  # INT4x4x8
  for math_inst in math_instructions:
    for layout in layouts_nhwc:
      for dst_layout in dst_layouts_nhwc:
        tile_descriptions = [
          TileDescription([128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
        ]
        for tile in tile_descriptions:
          dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
          operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                       DataType.s8, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                       ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor)
          if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64:
            dst_align = 64 if tile.threadblock_shape[1] == 32 else 128
            operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], dst_layout,
                                         DataType.s8, min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                         ImplicitGemmMode.GemmTN, True, cuda_major, cuda_minor)
  return operations

def GenerateDeconv_Simt(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4),
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 4], \
      DataType.s8, DataType.s8, DataType.s32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  dst_layouts = [
    LayoutType.TensorNC4HW4,
  ]

  dst_types = [
    DataType.s8,
  ]

  use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

  min_cc = 61
  max_cc = 1024

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        tile_descriptions = [
          TileDescription([32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([16,  64,  8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        operations += GenerateConv2d(ConvKind.Dgrad, tile_descriptions, layout[0], layout[1],
                                     dst_layout, dst_type, min_cc, 32, 32, 32,
                                     use_special_optimization)
  return operations

def GenerateDeconv_TensorOp_8816(args):
  operations = []

  layouts = [
    (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
    (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
    (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
  ]

  math_instructions = [
    MathInstruction( \
      [8, 8, 16], \
      DataType.s8, DataType.s8, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_saturate),
  ]

  dst_layouts = [
    LayoutType.TensorNHWC,
  ]

  dst_types = [
    DataType.s8,
  ]

  use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

  min_cc = 75
  max_cc = 1024

  cuda_major = 10
  cuda_minor = 2

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        tile_descriptions = [
          TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        for tile in tile_descriptions:
          dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
          operations += GenerateConv2d(ConvKind.Dgrad, [tile], layout[0], layout[1], dst_layout, dst_type,
                                       min_cc, layout[2], layout[2], dst_align, use_special_optimization,
                                       ImplicitGemmMode.GemmTN, False, cuda_major, cuda_minor)
  return operations

################################################################################
# parameters
# Edge - for tiles, the edges represent the length of one side
# Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
# MaxEdge - maximum length of each edge
# Min/Max - minimum/maximum of the product of edge lengths
################################################################################
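# Illustrative example of the constraints above (not from the original source):
# with warpsPerThreadblockRatio = 2, the pair [2, 4] is accepted (ratio 2)
# while [2, 8] is rejected (ratio 4); warpsPerThreadblockMax = 16 also rejects
# [8, 4], since 8 * 4 = 32 > 16.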
warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
warpsPerThreadblockRatio = 2
warpsPerThreadblockMax = 16
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases
warpShapeEdges = [8, 16, 32, 64, 128, 256]
warpShapeRatio = 4
warpShapeMax = 64*64
warpShapeMin = 8*8
threadblockEdgeMax = 256

#        char, type             bits/elem, max tile,  L0 threadblock tiles
precisions = {
    "c" : [ "cutlass::complex<float>",  64, 64*128,  [ [ 64, 128], [ 64,  32] ] ],
    "d" : [ "double",                   64, 64*64,   [ [ 64,  64], [ 32,  32] ] ],
    "h" : [ "cutlass::half_t",          16, 128*256, [ [256, 128], [ 64, 128], [ 64,  32] ] ],
    "i" : [ "int",                      32, 128*128, [ [128,  64], [ 16,  32] ] ],
    "s" : [ "float",                    32, 128*128, [ [128, 256], [128, 128], [ 64,  64] ] ],
    "z" : [ "cutlass::complex<double>", 128, 64*64,  [ [ 32,  64], [ 16,  32] ] ],
}
# L1 will have a single kernel for every unique shape
# L2 will have everything else

def GenerateGemm_Simt(args):
  ################################################################################
  # warps per threadblock
  ################################################################################
  warpsPerThreadblocks = []
  for warpsPerThreadblock0 in warpsPerThreadblockEdge:
    for warpsPerThreadblock1 in warpsPerThreadblockEdge:
      if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio \
          and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio \
          and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax:
        warpsPerThreadblocks.append([warpsPerThreadblock0,
                                     warpsPerThreadblock1])

  ################################################################################
  # warp shapes
  ################################################################################
  warpNumThreads = 32
  warpShapes = []
  for warp0 in warpShapeEdges:
    for warp1 in warpShapeEdges:
      if warp0 / warp1 <= warpShapeRatio \
          and warp1 / warp0 <= warpShapeRatio \
          and warp0 * warp1 <= warpShapeMax \
          and warp0 * warp1 > warpShapeMin:
        warpShapes.append([warp0, warp1])

  # sgemm
  precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions["s"]

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
    (LayoutType.ColumnMajor, LayoutType.RowMajor,    LayoutType.RowMajor),  # nt
    (LayoutType.RowMajor,    LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
    (LayoutType.RowMajor,    LayoutType.RowMajor,    LayoutType.RowMajor),  # tt
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 1], \
      DataType.f32, DataType.f32, DataType.f32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  min_cc = 50
  max_cc = 1024

  operations = []
  for math_inst in math_instructions:
    for layout in layouts:
      data_type = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_accumulator,
        math_inst.element_accumulator,
      ]
      tile_descriptions = [
        TileDescription([ 64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
        TileDescription([256,  64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
        TileDescription([256,  32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([128,  64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([128,  32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([ 64,  64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32,  64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 64,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([  8,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 16,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 16,  64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
      ]
      for warpsPerThreadblock in warpsPerThreadblocks:
        for warpShape in warpShapes:
          warpThreadsM = 0
          if warpShape[0] > warpShape[1]:
            warpThreadsM = 8
          else:
            warpThreadsM = 4
          # integer division is exact here: warpThreadsM is 4 or 8 and warpNumThreads is 32
          warpThreadsN = warpNumThreads // warpThreadsM

          # skip shapes with conflicting rectangularity
          # they are unlikely to be fastest
          blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
          blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
          warpG = warpShape[0] > warpShape[1]
          warpL = warpShape[0] < warpShape[1]

          blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2
          blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1]
          warpG2 = warpShape[0] > warpShape[1]*2
          warpL2 = warpShape[0]*2 < warpShape[1]

          if blockG2 and warpL: continue
          if blockL2 and warpG: continue
          if warpG2 and blockL: continue
          if warpL2 and blockG: continue

          # check threadblock ratios and max
          threadblockTile = [warpShape[0]*warpsPerThreadblock[0],
                             warpShape[1]*warpsPerThreadblock[1]]
          if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue
          if threadblockTile[0] > threadblockEdgeMax: continue
          if threadblockTile[1] > threadblockEdgeMax: continue
          totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1]

          # calculate unroll
          # ensure that every iteration at least a full load of A,B are done
          unrollMin = 8
          unrollMin0 = totalThreads // threadblockTile[0]
          unrollMin1 = totalThreads // threadblockTile[1]
          unroll = max(unrollMin, unrollMin0, unrollMin1)
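          # Worked example (illustrative): a 128x128 threadblock tile built from
          # [2, 2] warps has totalThreads = 32 * 2 * 2 = 128, so
          # unrollMin0 = unrollMin1 = 128 // 128 = 1 and unroll = max(8, 1, 1) = 8.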
          threadTileM = warpShape[0] // warpThreadsM
          threadTileN = warpShape[1] // warpThreadsN
          if threadTileM < 2 or threadTileN < 2: continue
          if threadTileM*threadTileN*precisionBits > 8*8*32: continue

          # epilogue currently only supports N < WarpNumThreads
          if threadblockTile[1] < warpNumThreads: continue

          # limit smem
          smemBitsA = threadblockTile[0]*unroll*2*precisionBits
          smemBitsB = threadblockTile[1]*unroll*2*precisionBits
          smemKBytes = (smemBitsA+smemBitsB)/8/1024
          if (smemKBytes > 48): continue
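          # Worked example (illustrative): for a 128x128 tile with unroll = 8 and
          # 32-bit elements, smemBitsA = smemBitsB = 128 * 8 * 2 * 32 = 65536 bits,
          # so smemKBytes = (65536 + 65536) / 8 / 1024 = 16 KB, well under the
          # 48 KB limit.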
          tile = TileDescription([threadblockTile[0], threadblockTile[1], unroll], \
                                 2, \
                                 [threadblockTile[0]//warpShape[0], threadblockTile[1]//warpShape[1], 1], \
                                 math_inst, min_cc, max_cc)

          def filter(t: TileDescription) -> bool:
            nonlocal tile
            return t.threadblock_shape[0] == tile.threadblock_shape[0] and \
                   t.threadblock_shape[1] == tile.threadblock_shape[1] and \
                   t.threadblock_shape[2] == tile.threadblock_shape[2] and \
                   t.warp_count[0] == tile.warp_count[0] and \
                   t.warp_count[1] == tile.warp_count[1] and \
                   t.warp_count[2] == tile.warp_count[2] and \
                   t.stages == tile.stages

          if not any(t for t in tile_descriptions if filter(t)): continue

          operations += GeneratesGemm(tile, data_type, layout[0], layout[1], layout[2], min_cc)
  return operations

#
def GenerateGemv_Simt(args):
  threadBlockShape_N = [128, 64, 32]
  ldgBits_A = [128, 64, 32]
  ldgBits_B = [128, 64, 32]

  layouts = [
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 1], \
      DataType.f32, DataType.f32, DataType.f32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  min_cc = 50

  operations = []
  for math_inst in math_instructions:
    for layout in layouts:
      data_type = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_accumulator,
        math_inst.element_accumulator,
      ]
      for threadblock_shape_n in threadBlockShape_N:
        for align_a in ldgBits_A:
          for align_b in ldgBits_B:
            ldg_elements_a = align_a // DataTypeSize[math_inst.element_a]
            ldg_elements_b = align_b // DataTypeSize[math_inst.element_b]
            threadblock_shape_k = (256 * ldg_elements_a) // (threadblock_shape_n // ldg_elements_b)
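            # Worked example (illustrative): with f32 data (32 bits) and
            # align_a = align_b = 128, ldg_elements_a = ldg_elements_b = 4;
            # for threadblock_shape_n = 128 this gives
            # threadblock_shape_k = (256 * 4) // (128 // 4) = 32.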
            threadblock_shape = [1, threadblock_shape_n, threadblock_shape_k]
            thread_shape = [1, ldg_elements_b, ldg_elements_a]

            operations.append(GeneratesGemv(math_inst, \
                                            threadblock_shape, \
                                            thread_shape, \
                                            data_type, \
                                            layout[0], \
                                            layout[1], \
                                            layout[2], \
                                            min_cc, \
                                            align_a, \
                                            align_b))
  return operations

#
def GeneratesGemm_TensorOp_1688(args):
  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
    (LayoutType.ColumnMajor, LayoutType.RowMajor,    LayoutType.RowMajor),  # nt
    (LayoutType.RowMajor,    LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
    (LayoutType.RowMajor,    LayoutType.RowMajor,    LayoutType.RowMajor),  # tt
  ]

  math_instructions = [
    MathInstruction( \
      [16, 8, 8], \
      DataType.f16, DataType.f16, DataType.f32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add),
    MathInstruction( \
      [16, 8, 8], \
      DataType.f16, DataType.f16, DataType.f16, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add),
  ]

  min_cc = 75
  max_cc = 1024

  alignment_constraints = [8, 4, 2,
                           # 1
                           ]

  cuda_major = 10
  cuda_minor = 2

  operations = []
  for math_inst in math_instructions:
    for layout in layouts:
      for align in alignment_constraints:
        tile_descriptions = [
          TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          ## comment some configurations to reduce compilation time and binary size
          # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        ]
        data_type = [
          math_inst.element_a,
          math_inst.element_b,
          math_inst.element_a,
          math_inst.element_accumulator,
        ]
        for tile in tile_descriptions:
          operations += GeneratesGemm(tile, \
                                      data_type, \
                                      layout[0], \
                                      layout[1], \
                                      layout[2], \
                                      min_cc, \
                                      align * 16, \
                                      align * 16, \
                                      align * 16, \
                                      cuda_major, \
                                      cuda_minor)
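          # Note (illustrative assumption, not stated in this file): the
          # alignment constraints appear to be in f16 elements while
          # GeneratesGemm takes alignments in bits, so align * 16 converts
          # elements to bits; align = 8 would then request 128-bit loads.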
  return operations

#
def GeneratesGemm_TensorOp_884(args):
  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
    (LayoutType.ColumnMajor, LayoutType.RowMajor,    LayoutType.RowMajor),  # nt
    (LayoutType.RowMajor,    LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
    (LayoutType.RowMajor,    LayoutType.RowMajor,    LayoutType.RowMajor),  # tt
  ]

  math_instructions = [
    MathInstruction( \
      [8, 8, 4], \
      DataType.f16, DataType.f16, DataType.f32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add),
    MathInstruction( \
      [8, 8, 4], \
      DataType.f16, DataType.f16, DataType.f16, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add),
  ]

  min_cc = 70
  max_cc = 75

  alignment_constraints = [8, 4, 2,
                           # 1
                           ]

  cuda_major = 10
  cuda_minor = 2

  operations = []
  for math_inst in math_instructions:
    for layout in layouts:
      for align in alignment_constraints:
        tile_descriptions = [
          TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          ## comment some configurations to reduce compilation time and binary size
          # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        ]
        data_type = [
          math_inst.element_a,
          math_inst.element_b,
          math_inst.element_a,
          math_inst.element_accumulator,
        ]
        for tile in tile_descriptions:
          operations += GeneratesGemm(tile, \
                                      data_type, \
                                      layout[0], \
                                      layout[1], \
                                      layout[2], \
                                      min_cc, \
                                      align * 16, \
                                      align * 16, \
                                      align * 16, \
                                      cuda_major, \
                                      cuda_minor)
  return operations

#
def GenerateConv2dOperations(args):
  if args.type == "simt":
    return GenerateConv2d_Simt(args)
  elif args.type == "tensorop8816":
    return GenerateConv2d_TensorOp_8816(args)
  else:
    assert args.type == "tensorop8832", "operation conv2d only supports " \
        "simt, tensorop8816 and tensorop8832. (got: {})".format(args.type)
    return GenerateConv2d_TensorOp_8832(args)

def GenerateDeconvOperations(args):
  if args.type == "simt":
    return GenerateDeconv_Simt(args)
  else:
    assert args.type == "tensorop8816", "operation deconv only supports " \
        "simt and tensorop8816. (got: {})".format(args.type)
    return GenerateDeconv_TensorOp_8816(args)

def GenerateGemmOperations(args):
  if args.type == "tensorop884":
    return GeneratesGemm_TensorOp_884(args)
  elif args.type == "tensorop1688":
    return GeneratesGemm_TensorOp_1688(args)
  else:
    assert args.type == "simt", "operation gemm only supports " \
        "simt. (got: {})".format(args.type)
    return GenerateGemm_Simt(args)

def GenerateGemvOperations(args):
  assert args.type == "simt", "operation gemv only supports " \
      "simt. (got: {})".format(args.type)
  return GenerateGemv_Simt(args)
###################################################################################################
###################################################################################################

if __name__ == "__main__":
  parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels")
  parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'],
                      required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)")
  parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files")
  parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832', 'tensorop884', 'tensorop1688'],
                      default='simt', help="kernel type of CUTLASS kernel generator")

  gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"

  short_path = (platform.system() == "Windows" or platform.system().find('NT') >= 0) \
      and ('true' != os.getenv("CUTLASS_WITH_LONG_PATH", default='False').lower())
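  # Behavior note (illustrative, not from the original source): on Windows,
  # short_path is True unless the environment variable CUTLASS_WITH_LONG_PATH
  # is set to 'true' (case-insensitive), e.g.
  #   set CUTLASS_WITH_LONG_PATH=true
  # disables the shortened kernel file names.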

  args = parser.parse_args()

  if args.operations == "gemm":
    operations = GenerateGemmOperations(args)
  elif args.operations == "gemv":
    operations = GenerateGemvOperations(args)
  elif args.operations == "conv2d":
    operations = GenerateConv2dOperations(args)
  elif args.operations == "deconv":
    operations = GenerateDeconvOperations(args)

  if args.operations == "conv2d" or args.operations == "deconv":
    for operation in operations:
      with EmitConvSingleKernelWrapper(args.output, operation, short_path) as emitter:
        emitter.emit()
  elif args.operations == "gemm":
    for operation in operations:
      with EmitGemmSingleKernelWrapper(args.output, operation, short_path) as emitter:
        emitter.emit()
  elif args.operations == "gemv":
    for operation in operations:
      with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path, short_path) as emitter:
        emitter.emit()

  if args.operations != "gemv":
    GenerateManifest(args, operations, args.output)

#
###################################################################################################