
generator.py

#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#
import enum
import os.path
import shutil
import argparse

from library import *
from manifest import *

###################################################################################################
#
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):

  # by default, use the latest CUDA Toolkit version
  cuda_version = [11, 0, 132]

  # Update cuda_version based on parsed string
  if semantic_ver_string != '':
    for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]):
      if i < len(cuda_version):
        cuda_version[i] = x
      else:
        cuda_version.append(x)

  return cuda_version >= [major, minor, patch]
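# Illustrative examples (not part of the original source):
#   CudaToolkitVersionSatisfies('11.2', 11, 1)  -> True   ([11, 2, 132] >= [11, 1, 0])
#   CudaToolkitVersionSatisfies('10.2', 11, 0)  -> False  ([10, 2, 132] <  [11, 0, 0])
# The comparison relies on Python's lexicographic ordering of lists.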
###################################################################################################
###################################################################################################

#
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
  alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
  swizzling_functor = SwizzlingFunctor.Identity8):

  if complex_transforms is None:
    complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]

  element_a, element_b, element_c, element_epilogue = data_type

  operations = []

  # by default, only generate the largest tile and largest alignment
  if manifest.args.kernels == '':
    tile_descriptions = [tile_descriptions[0],]
    alignment_constraints = [alignment_constraints[0],]

  for layout in layouts:
    for tile_description in tile_descriptions:
      for alignment in alignment_constraints:
        for complex_transform in complex_transforms:

          alignment_c = min(8, alignment)

          A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
          B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
          C = TensorDescription(element_c, layout[2], alignment_c)

          new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \
            tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor)

          manifest.append(new_operation)
          operations.append(new_operation)

  return operations
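# Illustrative note (not in the original source): with the default
# manifest.args.kernels == '', the loops above collapse to the first tile and
# first alignment, so each (layout, complex_transform) pair emits exactly one
# GemmOperation; e.g. 4 layouts x 1 transform -> 4 operations.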
###########################################################################################################
#   ConvolutionOperator support variations
#        ____________________________________________________________________
#         ConvolutionalOperator |      Analytic          |    Optimized
#        ____________________________________________________________________
#        |       Fprop          |     (strided)          |    (strided)
#        |       Dgrad          |     (strided, unity*)  |    (unity)
#        |       Wgrad          |     (strided)          |    (strided)
#        ____________________________________________________________________
#
# Note : Operators marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################
# Convolution for 2D operations
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
  conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination):

  element_a, element_b, element_c, element_epilogue = data_type

  # one exceptional case
  alignment_c = min(8, alignment)

  # iterator algorithm (analytic and optimized)
  iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

  # by default, only generate the largest tile size
  if manifest.args.kernels == '':
    tile_descriptions = [tile_descriptions[0],]

  operations = []

  for tile in tile_descriptions:
    for conv_kind in conv_kinds:
      for iterator_algorithm in iterator_algorithms:

        A = TensorDescription(element_a, layout[0], alignment)
        B = TensorDescription(element_b, layout[1], alignment)
        C = TensorDescription(element_c, layout[2], alignment_c)

        # unity stride only for Optimized Dgrad
        if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad):
          new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile, \
            A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor)

          manifest.append(new_operation)
          operations.append(new_operation)

        # strided dgrad is not supported by Optimized Dgrad
        if (iterator_algorithm == IteratorAlgorithm.Optimized) and (conv_kind == ConvKind.Dgrad):
          continue

        # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
        new_operation = Conv2dOperation(conv_kind, iterator_algorithm, tile.minimum_compute_capability, tile, \
          A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor)

        manifest.append(new_operation)
        operations.append(new_operation)

  return operations
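# Illustrative trace for conv_kind == ConvKind.Dgrad (not in the original source):
#   - Analytic:  passes both guards and emits a Strided kernel; its unity-stride
#                variant is the (*) case in the table above.
#   - Optimized: emits the Unity kernel, then `continue` skips the Strided
#                emission, so Optimized Dgrad is unity-stride only.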
###################################################################################################
###################################################################################################

def GenerateConv2d_Simt(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4),
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 4], \
      DataType.s8, DataType.s8, DataType.s32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  dst_layouts = [
    LayoutType.TensorNC4HW4,
    LayoutType.TensorNC32HW32,
    LayoutType.TensorNHWC,
    LayoutType.TensorNHWC,
    LayoutType.TensorNCHW
  ]

  dst_types = [
    DataType.s8,
    DataType.s8,
    DataType.u4,
    DataType.s4,
    DataType.f32,
  ]

  max_cc = 1024

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        if dst_type == DataType.s4 or dst_type == DataType.u4:
          min_cc = 75
          skip_unity_kernel = True
        else:
          min_cc = 61
          skip_unity_kernel = False

        tile_descriptions = [
          TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([ 32,  64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 64,  32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([ 16,  64,  8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]

        operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                     dst_layout, dst_type, min_cc, 32, 32, 32,
                                     skip_unity_kernel)
  return operations
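# Illustrative note (not in the original source): zip(dst_types, dst_layouts)
# pairs (s8, NC4HW4), (s8, NC32HW32), (u4, NHWC), (s4, NHWC), (f32, NCHW);
# the u4/s4 pairs require min_cc = 75, the rest min_cc = 61.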
def GenerateConv2d_TensorOp_8816(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32),
  ]

  math_instructions = [
    MathInstruction( \
      [8, 8, 16], \
      DataType.s8, DataType.s8, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_saturate),
  ]

  dst_layouts = [
    LayoutType.TensorNC32HW32,
    LayoutType.TensorNC4HW4,
  ]

  dst_types = [
    DataType.s8,
    DataType.s8,
  ]

  min_cc = 75
  max_cc = 1024

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        if dst_layout == LayoutType.TensorNC32HW32:
          tile_descriptions = [
            TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
          ]
          operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                       dst_layout, dst_type, min_cc, 128, 128, 64,
                                       False, ImplicitGemmMode.GemmTN, True)
        else:
          assert dst_layout == LayoutType.TensorNC4HW4
          tile_descriptions = [
            TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128,  32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
          ]
          operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                       dst_layout, dst_type, min_cc, 128, 128, 64,
                                       False)
  return operations
def GenerateConv2d_TensorOp_8832(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64),
  ]

  math_instructions = [
    MathInstruction( \
      [8, 8, 32], \
      DataType.s4, DataType.s4, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_saturate), \
    MathInstruction( \
      [8, 8, 32], \
      DataType.s4, DataType.u4, DataType.s32, \
      OpcodeClass.TensorOp, \
      MathOperation.multiply_add_saturate)
  ]

  dst_layouts = [
    LayoutType.TensorNC64HW64,
  ]

  min_cc = 75
  max_cc = 1024

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_layout in dst_layouts:
        dst_type = math_inst.element_b
        tile_descriptions = [
          TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([128,  64,  64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
        ]
        operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                     dst_layout, dst_type, min_cc, 128, 128, 64,
                                     False, ImplicitGemmMode.GemmTN, True)

  layouts_nhwc = [
    (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
    (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
    (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
  ]

  dst_layouts_nhwc = [
    LayoutType.TensorNHWC,
  ]

  for math_inst in math_instructions:
    for layout in layouts_nhwc:
      for dst_layout in dst_layouts_nhwc:
        dst_type = math_inst.element_b
        tile_descriptions = [
          TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
        ]
        for tile in tile_descriptions:
          operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
                                       dst_layout, dst_type, min_cc, layout[2], layout[2], 32,
                                       False, ImplicitGemmMode.GemmTN, False)
          if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64:
            dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
            operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
                                         dst_layout, dst_type, min_cc, layout[2], layout[2], dst_align,
                                         False, ImplicitGemmMode.GemmTN, True)
  return operations
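# Illustrative trace of the NHWC branch above (not in the original source):
# tile [128, 32, 64] has threadblock_shape[1] == 32, so a second kernel with
# dst_align = 32 is generated; tile [128, 64, 64] gets dst_align = 64.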
def GenerateDeconv_Simt(args):
  operations = []

  layouts = [
    (LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4),
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 4], \
      DataType.s8, DataType.s8, DataType.s32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  dst_layouts = [
    LayoutType.TensorNC4HW4,
  ]

  dst_types = [
    DataType.s8,
  ]

  min_cc = 61
  max_cc = 1024

  for math_inst in math_instructions:
    for layout in layouts:
      for dst_type, dst_layout in zip(dst_types, dst_layouts):
        tile_descriptions = [
          TileDescription([64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
          TileDescription([32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc),
          TileDescription([16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc),
          TileDescription([16,  64,  8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        operations += GenerateConv2d(ConvKind.Dgrad, tile_descriptions, layout[0], layout[1],
                                     dst_layout, dst_type, min_cc, 32, 32, 32,
                                     True)
  return operations
################################################################################
# parameters
#   Edge - for tiles, the edges represent the length of one side
#   Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
#   MaxEdge - maximum length of each edge
#   Min/Max - minimum/maximum of the product of edge lengths
################################################################################

warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
warpsPerThreadblockRatio = 2
warpsPerThreadblockMax = 16
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases

warpShapeEdges = [8, 16, 32, 64, 128, 256]
warpShapeRatio = 4
warpShapeMax = 64*64
warpShapeMin = 8*8
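# Illustrative check of the warp-shape filter in GenerateGemm_Simt below
# (not in the original source): [64, 16] passes (ratio 64/16 = 4 <= warpShapeRatio,
# product 1024 <= warpShapeMax, 1024 > warpShapeMin), while [8, 8] is rejected
# because its product 64 is not strictly greater than warpShapeMin.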
threadblockEdgeMax = 256

#          char,  type                      bits/elem, max tile, L0 threadblock tiles
precisions = {
  "c" : [ "cutlass::complex<float>",   64,  64*128,  [ [ 64, 128], [ 64,  32] ] ],
  "d" : [ "double",                    64,   64*64,  [ [ 64,  64], [ 32,  32] ] ],
  "h" : [ "cutlass::half_t",           16, 128*256,  [ [256, 128], [ 64, 128], [ 64,  32] ] ],
  "i" : [ "int",                       32, 128*128,  [ [128,  64], [ 16,  32] ] ],
  "s" : [ "float",                     32, 128*128,  [ [128, 256], [128, 128], [ 64,  64] ] ],
  "z" : [ "cutlass::complex<double>", 128,   64*64,  [ [ 32,  64], [ 16,  32] ] ],
}

# L1 will have a single kernel for every unique shape
# L2 will have everything else
def GenerateGemm_Simt(args):
  ################################################################################
  # warps per threadblock
  ################################################################################
  warpsPerThreadblocks = []
  for warpsPerThreadblock0 in warpsPerThreadblockEdge:
    for warpsPerThreadblock1 in warpsPerThreadblockEdge:
      if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio \
          and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio \
          and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax:
        warpsPerThreadblocks.append([warpsPerThreadblock0,
                                     warpsPerThreadblock1])

  ################################################################################
  # warp shapes
  ################################################################################
  warpNumThreads = 32
  warpShapes = []
  for warp0 in warpShapeEdges:
    for warp1 in warpShapeEdges:
      if warp0 / warp1 <= warpShapeRatio \
          and warp1 / warp0 <= warpShapeRatio \
          and warp0 * warp1 <= warpShapeMax \
          and warp0 * warp1 > warpShapeMin:
        warpShapes.append([warp0, warp1])

  # sgemm
  precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions["s"]

  layouts = [
    (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
    (LayoutType.ColumnMajor, LayoutType.RowMajor,    LayoutType.RowMajor),  # nt
    (LayoutType.RowMajor,    LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
    (LayoutType.RowMajor,    LayoutType.RowMajor,    LayoutType.RowMajor),  # tt
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 1], \
      DataType.f32, DataType.f32, DataType.f32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  min_cc = 50
  max_cc = 1024

  operations = []
  for math_inst in math_instructions:
    for layout in layouts:
      data_type = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_accumulator,
        math_inst.element_accumulator,
      ]
      tile_descriptions = [
        TileDescription([ 64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
        TileDescription([256,  64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
        TileDescription([256,  32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([128,  64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([128,  32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
        TileDescription([ 64,  64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32,  64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 64,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 32,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([  8,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 16,  32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 16,  64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        TileDescription([ 16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
      ]
      for warpsPerThreadblock in warpsPerThreadblocks:
        for warpShape in warpShapes:
          warpThreadsM = 0
          if warpShape[0] > warpShape[1]:
            warpThreadsM = 8
          else:
            warpThreadsM = 4
          warpThreadsN = warpNumThreads // warpThreadsM

          # skip shapes with conflicting rectangularity
          # they are unlikely to be fastest
          blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
          blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
          warpG = warpShape[0] > warpShape[1]
          warpL = warpShape[0] < warpShape[1]

          blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
          blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
          warpG2 = warpShape[0] > warpShape[1] * 2
          warpL2 = warpShape[0] * 2 < warpShape[1]

          if blockG2 and warpL: continue
          if blockL2 and warpG: continue
          if warpG2 and blockL: continue
          if warpL2 and blockG: continue

          # check threadblock ratios and max
          threadblockTile = [warpShape[0] * warpsPerThreadblock[0],
                             warpShape[1] * warpsPerThreadblock[1]]
          if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue
          if threadblockTile[0] > threadblockEdgeMax: continue
          if threadblockTile[1] > threadblockEdgeMax: continue
          totalThreads = warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]

          # calculate unroll
          # ensure that every iteration does at least one full load of A and B
          unrollMin = 8
          unrollMin0 = totalThreads // threadblockTile[0]
          unrollMin1 = totalThreads // threadblockTile[1]
          unroll = max(unrollMin, unrollMin0, unrollMin1)

          threadTileM = warpShape[0] // warpThreadsM
          threadTileN = warpShape[1] // warpThreadsN
          if threadTileM < 2 or threadTileN < 2: continue
          if threadTileM * threadTileN * precisionBits > 8 * 8 * 32: continue

          # epilogue currently only supports N < WarpNumThreads
          if threadblockTile[1] < warpNumThreads: continue

          # limit smem
          smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
          smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
          smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
          if (smemKBytes > 48): continue

          tile = TileDescription([threadblockTile[0], threadblockTile[1], unroll], \
                                 2, \
                                 [threadblockTile[0] // warpShape[0], threadblockTile[1] // warpShape[1], 1], \
                                 math_inst, min_cc, max_cc)

          def filter(t: TileDescription) -> bool:
            nonlocal tile
            return t.threadblock_shape[0] == tile.threadblock_shape[0] and \
                   t.threadblock_shape[1] == tile.threadblock_shape[1] and \
                   t.threadblock_shape[2] == tile.threadblock_shape[2] and \
                   t.warp_count[0] == tile.warp_count[0] and \
                   t.warp_count[1] == tile.warp_count[1] and \
                   t.warp_count[2] == tile.warp_count[2] and \
                   t.stages == tile.stages

          if not any(t for t in tile_descriptions if filter(t)): continue

          operations += GeneratesGemm(tile, data_type, layout[0], layout[1], layout[2], min_cc)

  return operations
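# Worked example of the sizing rules above (illustrative, not in the original
# source): warpShape [32, 64] with warpsPerThreadblock [4, 2] gives
# threadblockTile [128, 128] and totalThreads = 32*4*2 = 256;
# unroll = max(8, 256//128, 256//128) = 8;
# smemBitsA = smemBitsB = 128*8*2*32 = 65536 bits, so
# smemKBytes = (65536+65536)/8/1024 = 16 KB, within the 48 KB limit; the
# resulting tile matches TileDescription([128, 128, 8], 2, [4, 2, 1], ...).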
#
def GenerateGemv_Simt(args):
  threadBlockShape_N = [128, 64, 32]
  ldgBits_A = [128, 64, 32]
  ldgBits_B = [128, 64, 32]

  layouts = [
    (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
  ]

  math_instructions = [
    MathInstruction( \
      [1, 1, 1], \
      DataType.f32, DataType.f32, DataType.f32, \
      OpcodeClass.Simt, \
      MathOperation.multiply_add),
  ]

  min_cc = 50

  operations = []
  for math_inst in math_instructions:
    for layout in layouts:
      data_type = [
        math_inst.element_a,
        math_inst.element_b,
        math_inst.element_accumulator,
        math_inst.element_accumulator,
      ]
      for threadblock_shape_n in threadBlockShape_N:
        for align_a in ldgBits_A:
          for align_b in ldgBits_B:
            ldg_elements_a = align_a // DataTypeSize[math_inst.element_a]
            ldg_elements_b = align_b // DataTypeSize[math_inst.element_b]
            threadblock_shape_k = (256 * ldg_elements_a) // (threadblock_shape_n // ldg_elements_b)
            threadblock_shape = [1, threadblock_shape_n, threadblock_shape_k]
            thread_shape = [1, ldg_elements_b, ldg_elements_a]

            operations.append(GeneratesGemv(math_inst, \
                                            threadblock_shape, \
                                            thread_shape, \
                                            data_type, \
                                            layout[0], \
                                            layout[1], \
                                            layout[2], \
                                            min_cc, \
                                            align_a, \
                                            align_b))
  return operations
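# Worked example (illustrative, not in the original source): with f32
# (32-bit elements), align_a = 128 gives ldg_elements_a = 128//32 = 4 and
# align_b = 128 gives ldg_elements_b = 4; for threadblock_shape_n = 128,
# threadblock_shape_k = (256*4) // (128//4) = 32, i.e. a [1, 128, 32]
# threadblock with a [1, 4, 4] thread shape.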
#
def GenerateConv2dOperations(args):
  if args.type == "simt":
    return GenerateConv2d_Simt(args)
  elif args.type == "tensorop8816":
    return GenerateConv2d_TensorOp_8816(args)
  else:
    assert args.type == "tensorop8832", "operation conv2d only supports " \
        "simt, tensorop8816 and tensorop8832. (got: {})".format(args.type)
    return GenerateConv2d_TensorOp_8832(args)

def GenerateDeconvOperations(args):
  assert args.type == "simt", "operation deconv only supports " \
      "simt. (got: {})".format(args.type)
  return GenerateDeconv_Simt(args)

def GenerateGemmOperations(args):
  assert args.type == "simt", "operation gemm only supports " \
      "simt. (got: {})".format(args.type)
  return GenerateGemm_Simt(args)

def GenerateGemvOperations(args):
  assert args.type == "simt", "operation gemv only supports " \
      "simt. (got: {})".format(args.type)
  return GenerateGemv_Simt(args)
###################################################################################################
###################################################################################################

if __name__ == "__main__":

  parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels")
  parser.add_argument("--operations", type=str, choices=['gemm', 'gemv', 'conv2d', 'deconv'],
                      required=True, help="Specifies the operation to generate (gemm, gemv, conv2d, deconv)")
  parser.add_argument("output", type=str, help="output directory for CUTLASS kernel files")
  parser.add_argument("--type", type=str, choices=['simt', 'tensorop8816', 'tensorop8832'],
                      default='simt', help="kernel type of CUTLASS kernel generator")

  gemv_wrapper_path = "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"

  args = parser.parse_args()

  if args.operations == "gemm":
    operations = GenerateGemmOperations(args)
  elif args.operations == "gemv":
    operations = GenerateGemvOperations(args)
  elif args.operations == "conv2d":
    operations = GenerateConv2dOperations(args)
  elif args.operations == "deconv":
    operations = GenerateDeconvOperations(args)

  if args.operations == "conv2d" or args.operations == "deconv":
    for operation in operations:
      with EmitConvSingleKernelWrapper(args.output, operation) as emitter:
        emitter.emit()
  elif args.operations == "gemm":
    for operation in operations:
      with EmitGemmSingleKernelWrapper(args.output, operation) as emitter:
        emitter.emit()
  elif args.operations == "gemv":
    for operation in operations:
      with EmitGemvSingleKernelWrapper(args.output, operation, gemv_wrapper_path) as emitter:
        emitter.emit()

  if args.operations != "gemv":
    GenerateManifest(args, operations, args.output)

#
###################################################################################################
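# Example invocation, based on the argparse definitions above (the output
# path is hypothetical):
#   python3 generator.py --operations conv2d --type tensorop8816 path/to/output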

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.