#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#
import enum
import os.path
import shutil
import argparse
import platform

from library import *
from manifest import *

###################################################################################################
#
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch=0):
    # by default, use the latest CUDA Toolkit version
    cuda_version = [11, 0, 132]

    # update cuda_version based on the parsed string
    if semantic_ver_string != "":
        for i, x in enumerate([int(x) for x in semantic_ver_string.split(".")]):
            if i < len(cuda_version):
                cuda_version[i] = x
            else:
                cuda_version.append(x)
    return cuda_version >= [major, minor, patch]
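
# Illustrative sketch (assumed inputs): parsing "10.2" yields cuda_version = [10, 2, 132],
# so CudaToolkitVersionSatisfies("10.2", 11, 0) compares [10, 2, 132] >= [11, 0, 0]
# lexicographically and returns False, while CudaToolkitVersionSatisfies("10.2", 10, 1)
# compares [10, 2, 132] >= [10, 1, 0] and returns True.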

###################################################################################################
###################################################################################################

#
def CreateGemmOperator(
    manifest,
    layouts,
    tile_descriptions,
    data_type,
    alignment_constraints,
    complex_transforms=None,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    swizzling_functor=SwizzlingFunctor.Identity8,
):
    if complex_transforms is None:
        complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    operations = []

    # by default, only generate the largest tile and largest alignment
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    for layout in layouts:
        for tile_description in tile_descriptions:
            for alignment in alignment_constraints:
                for complex_transform in complex_transforms:
                    alignment_c = min(8, alignment)

                    A = TensorDescription(
                        element_a, layout[0], alignment, complex_transform[0]
                    )
                    B = TensorDescription(
                        element_b, layout[1], alignment, complex_transform[1]
                    )
                    C = TensorDescription(element_c, layout[2], alignment_c)

                    new_operation = GemmOperation(
                        GemmKind.Universal,
                        tile_description.minimum_compute_capability,
                        tile_description,
                        A,
                        B,
                        C,
                        element_epilogue,
                        epilogue_functor,
                        swizzling_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

    return operations

###########################################################################################################
#   ConvolutionOperator support variations
#        ____________________________________________________________________
#         ConvolutionalOperator |      Analytic         |      Optimized
#        ____________________________________________________________________
#        |       Fprop          |   (strided)           |    (strided)
#        |       Dgrad          |   (strided, unity*)   |    (unity)
#        |       Wgrad          |   (strided)           |    (strided)
#        ____________________________________________________________________
#
# Note : Operators marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################

# Convolution for 2D operations
def CreateConv2dOperator(
    manifest,
    layout,
    tile_descriptions,
    data_type,
    alignment,
    conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
    epilogue_functor=EpilogueFunctor.LinearCombination,
):
    element_a, element_b, element_c, element_epilogue = data_type

    # one exceptional case
    alignment_c = min(8, alignment)

    # iterator algorithm (analytic and optimized)
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    # by default, only generate the largest tile size
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]

    operations = []

    for tile in tile_descriptions:
        for conv_kind in conv_kinds:
            for iterator_algorithm in iterator_algorithms:
                A = TensorDescription(element_a, layout[0], alignment)
                B = TensorDescription(element_b, layout[1], alignment)
                C = TensorDescription(element_c, layout[2], alignment_c)

                # unity stride only for Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    new_operation = Conv2dOperation(
                        conv_kind,
                        iterator_algorithm,
                        tile.minimum_compute_capability,
                        tile,
                        A,
                        B,
                        C,
                        element_epilogue,
                        StrideSupport.Unity,
                        epilogue_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

                # strided dgrad is not supported by Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    continue

                # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
                new_operation = Conv2dOperation(
                    conv_kind,
                    iterator_algorithm,
                    tile.minimum_compute_capability,
                    tile,
                    A,
                    B,
                    C,
                    element_epilogue,
                    StrideSupport.Strided,
                    epilogue_functor,
                )

                manifest.append(new_operation)
                operations.append(new_operation)

    return operations

###################################################################################################
###################################################################################################


def GenerateConv2d_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [
        LayoutType.TensorNC4HW4,
        LayoutType.TensorNC32HW32,
        LayoutType.TensorNHWC,
        LayoutType.TensorNHWC,
        LayoutType.TensorNCHW,
    ]

    dst_types = [DataType.s8, DataType.s8, DataType.u4, DataType.s4, DataType.f32]

    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_type == DataType.s4 or dst_type == DataType.u4:
                    min_cc = 75
                    use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt
                else:
                    min_cc = 61
                    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
                tile_descriptions = [
                    TileDescription(
                        [128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    if (
                        dst_layout == LayoutType.TensorNC32HW32
                        and tile.threadblock_shape[0] > 32
                    ):
                        continue
                    if (
                        dst_layout == LayoutType.TensorNCHW
                        or dst_layout == LayoutType.TensorNHWC
                    ) and tile.threadblock_shape[0] > 16:
                        continue
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        32,
                        32,
                        32,
                        use_special_optimization,
                    )
    return operations


def GenerateConv2d_TensorOp_8816(args):
    operations = []

    layouts = [(LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32)]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNC32HW32, LayoutType.TensorNC4HW4]

    dst_types = [DataType.s8, DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_layout == LayoutType.TensorNC32HW32:
                    tile_descriptions = [
                        TileDescription(
                            [128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        True,
                        cuda_major,
                        cuda_minor,
                    )
                else:
                    assert dst_layout == LayoutType.TensorNC4HW4
                    tile_descriptions = [
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmNT,
                        False,
                        cuda_major,
                        cuda_minor,
                    )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 16
                        or tile.threadblock_shape[1] == 32
                    ):
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    out_dtypes = [DataType.s4, DataType.u4, DataType.f32]

    # INT8x8x4 and INT8x8x32
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                for out_dtype in out_dtypes:
                    tile_descriptions = [
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    for tile in tile_descriptions:
                        dst_align = (
                            4 * DataTypeSize[out_dtype]
                            if tile.threadblock_shape[1] == 16
                            or out_dtype == DataType.f32
                            else 8 * DataTypeSize[out_dtype]
                        )
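                        # Worked example (illustrative values): for DataType.s4
                        # (4 bits) with a [128, 32, 32] tile, the else branch gives
                        # dst_align = 8 * 4 = 32 bits; for DataType.f32 (32 bits)
                        # the first branch gives dst_align = 4 * 32 = 128 bits.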
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            out_dtype,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
                        if tile.threadblock_shape[1] == 16 or (
                            tile.threadblock_shape[1] == 32
                            and out_dtype != DataType.f32
                        ):
                            operations += GenerateConv2d(
                                ConvType.Convolution,
                                ConvKind.Fprop,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                out_dtype,
                                min_cc,
                                layout[2],
                                layout[2],
                                dst_align,
                                use_special_optimization,
                                ImplicitGemmMode.GemmTN,
                                True,
                                cuda_major,
                                cuda_minor,
                            )
    return operations


def GenerateConv2d_TensorOp_8832(args):
    operations = []

    layouts = [(LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64)]

    math_instructions = [
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.s4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.u4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
    ]

    dst_layouts = [LayoutType.TensorNC64HW64]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_layout in dst_layouts:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Fprop,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    128,
                    128,
                    64,
                    use_special_optimization,
                    ImplicitGemmMode.GemmTN,
                    True,
                    cuda_major,
                    cuda_minor,
                )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 16 if tile.threadblock_shape[1] == 16 else 32
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    # INT4x4x8
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        DataType.s8,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 64 if tile.threadblock_shape[1] == 32 else 128
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            DataType.s8,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )
    return operations


def GenerateDeconv_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [LayoutType.TensorNC4HW4]

    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 61
    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Dgrad,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    32,
                    32,
                    32,
                    use_special_optimization,
                )
    return operations


def GenerateDeconv_TensorOp_8816(args):
    operations = []

    layouts = [
        (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNHWC]

    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Dgrad,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
    return operations


################################################################################
# parameters
# Edge - for tiles, the edges represent the length of one side
# Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
# MaxEdge - maximum length of each edge
# Min/Max - minimum/maximum of the product of edge lengths
################################################################################

warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
warpsPerThreadblockRatio = 2
warpsPerThreadblockMax = 16
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases
warpShapeEdges = [8, 16, 32, 64, 128, 256]
warpShapeRatio = 4
warpShapeMax = 64 * 64
warpShapeMin = 8 * 8
threadblockEdgeMax = 256

# char: [type, bits/elem, max tile elements, L0 threadblock tiles]
precisions = {
    "c": ["cutlass::complex<float>", 64, 64 * 128, [[64, 128], [64, 32]]],
    "d": ["double", 64, 64 * 64, [[64, 64], [32, 32]]],
    "h": ["cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]],
    "i": ["int", 32, 128 * 128, [[128, 64], [16, 32]]],
    "s": ["float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]],
    "z": ["cutlass::complex<double>", 128, 64 * 64, [[32, 64], [16, 32]]],
}
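
# Example lookup (a sketch): precisions["s"] unpacks to precisionType = "float",
# precisionBits = 32, threadblockMaxElements = 128 * 128, and
# threadblockTilesL0 = [[128, 256], [128, 128], [64, 64]], which is how
# GenerateGemm_Simt below consumes these entries.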

# L1 will have a single kernel for every unique shape
# L2 will have everything else
def GenerateGemm_Simt(args):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[
        "s"
    ]

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            tile_descriptions = [
                TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            ]
            for warpsPerThreadblock in warpsPerThreadblocks:
                for warpShape in warpShapes:
                    warpThreadsM = 0
                    if warpShape[0] > warpShape[1]:
                        warpThreadsM = 8
                    else:
                        warpThreadsM = 4
                    warpThreadsN = warpNumThreads // warpThreadsM

                    # skip shapes with conflicting rectangularity
                    # they are unlikely to be fastest
                    blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                    blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                    warpG = warpShape[0] > warpShape[1]
                    warpL = warpShape[0] < warpShape[1]

                    blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                    blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                    warpG2 = warpShape[0] > warpShape[1] * 2
                    warpL2 = warpShape[0] * 2 < warpShape[1]

                    if blockG2 and warpL:
                        continue
                    if blockL2 and warpG:
                        continue
                    if warpG2 and blockL:
                        continue
                    if warpL2 and blockG:
                        continue

                    # check threadblock ratios and max
                    threadblockTile = [
                        warpShape[0] * warpsPerThreadblock[0],
                        warpShape[1] * warpsPerThreadblock[1],
                    ]
                    if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                        continue
                    if threadblockTile[0] > threadblockEdgeMax:
                        continue
                    if threadblockTile[1] > threadblockEdgeMax:
                        continue
                    totalThreads = (
                        warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                    )

                    # calculate unroll
                    # ensure that every iteration does at least a full load of A and B
                    unrollMin = 8
                    unrollMin0 = totalThreads // threadblockTile[0]
                    unrollMin1 = totalThreads // threadblockTile[1]
                    unroll = max(unrollMin, unrollMin0, unrollMin1)
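
                    # Worked example (illustrative values): a [64, 128] threadblock
                    # built from 2x2 warps has 128 threads, so unrollMin0 = 128 // 64 = 2
                    # and unrollMin1 = 128 // 128 = 1, and unroll stays at the minimum of 8.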
                    threadTileM = warpShape[0] // warpThreadsM
                    threadTileN = warpShape[1] // warpThreadsN
                    if threadTileM < 2 or threadTileN < 2:
                        continue
                    if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                        continue

                    # epilogue currently only supports N < WarpNumThreads
                    if threadblockTile[1] < warpNumThreads:
                        continue

                    # limit smem
                    smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                    smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                    smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                    if smemKBytes > 48:
                        continue
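
                    # e.g. (assumed values) a [64, 128] tile with unroll = 8, two
                    # stages, and 32-bit elements needs (64 + 128) * 8 * 2 * 32 bits
                    # = 12 KiB of shared memory, comfortably under the 48 KiB cap.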

                    tile = TileDescription(
                        [threadblockTile[0], threadblockTile[1], unroll],
                        2,
                        [
                            threadblockTile[0] // warpShape[0],
                            threadblockTile[1] // warpShape[1],
                            1,
                        ],
                        math_inst,
                        min_cc,
                        max_cc,
                    )

                    def filter(t: TileDescription) -> bool:
                        nonlocal tile
                        return (
                            t.threadblock_shape[0] == tile.threadblock_shape[0]
                            and t.threadblock_shape[1] == tile.threadblock_shape[1]
                            and t.threadblock_shape[2] == tile.threadblock_shape[2]
                            and t.warp_count[0] == tile.warp_count[0]
                            and t.warp_count[1] == tile.warp_count[1]
                            and t.warp_count[2] == tile.warp_count[2]
                            and t.stages == tile.stages
                        )

                    if not any(t for t in tile_descriptions if filter(t)):
                        continue

                    operations += GeneratesGemm(
                        tile, data_type, layout[0], layout[1], layout[2], min_cc
                    )
    return operations


#
def GenerateDwconv2d_Simt(args, conv_kind):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[
        "s"
    ]

    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    dst_layouts = [LayoutType.TensorNCHW]

    dst_types = [DataType.f32]

    if conv_kind == ConvKind.Wgrad:
        alignment_constraints = [32]
    else:
        alignment_constraints = [128, 32]

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        for warpsPerThreadblock in warpsPerThreadblocks:
            for warpShape in warpShapes:
                warpThreadsM = 0
                if warpShape[0] > warpShape[1]:
                    warpThreadsM = 8
                else:
                    warpThreadsM = 4
                warpThreadsN = warpNumThreads // warpThreadsM

                # skip shapes with conflicting rectangularity
                # they are unlikely to be fastest
                blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                warpG = warpShape[0] > warpShape[1]
                warpL = warpShape[0] < warpShape[1]

                blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                warpG2 = warpShape[0] > warpShape[1] * 2
                warpL2 = warpShape[0] * 2 < warpShape[1]

                if blockG2 and warpL:
                    continue
                if blockL2 and warpG:
                    continue
                if warpG2 and blockL:
                    continue
                if warpL2 and blockG:
                    continue

                # check threadblock ratios and max
                threadblockTile = [
                    warpShape[0] * warpsPerThreadblock[0],
                    warpShape[1] * warpsPerThreadblock[1],
                ]
                if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                    continue
                if threadblockTile[0] > threadblockEdgeMax:
                    continue
                if threadblockTile[1] > threadblockEdgeMax:
                    continue
                totalThreads = (
                    warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                )

                # calculate unroll
                # ensure that every iteration does at least a full load of A and B
                unrollMin = 8
                unrollMin0 = totalThreads // threadblockTile[0]
                unrollMin1 = totalThreads // threadblockTile[1]
                unroll = max(unrollMin, unrollMin0, unrollMin1)

                threadTileM = warpShape[0] // warpThreadsM
                threadTileN = warpShape[1] // warpThreadsN
                if threadTileM < 2 or threadTileN < 2:
                    continue
                if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                    continue

                # epilogue currently only supports N < WarpNumThreads
                if threadblockTile[1] < warpNumThreads:
                    continue

                # limit smem
                smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                if smemKBytes > 48:
                    continue

                tile = TileDescription(
                    [threadblockTile[0], threadblockTile[1], unroll],
                    2,
                    [
                        threadblockTile[0] // warpShape[0],
                        threadblockTile[1] // warpShape[1],
                        1,
                    ],
                    math_inst,
                    min_cc,
                    max_cc,
                )

                def filter(t: TileDescription) -> bool:
                    nonlocal tile
                    return (
                        t.threadblock_shape[0] == tile.threadblock_shape[0]
                        and t.threadblock_shape[1] == tile.threadblock_shape[1]
                        and t.threadblock_shape[2] == tile.threadblock_shape[2]
                        and t.warp_count[0] == tile.warp_count[0]
                        and t.warp_count[1] == tile.warp_count[1]
                        and t.warp_count[2] == tile.warp_count[2]
                        and t.stages == tile.stages
                    )

                if not any(t for t in tile_descriptions if filter(t)):
                    continue

                for layout in layouts:
                    for dst_type, dst_layout in zip(dst_types, dst_layouts):
                        for alignment_src in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                conv_kind,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                32,
                                32,
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                ImplicitGemmMode.GemmNT
                                if conv_kind == ConvKind.Wgrad
                                else ImplicitGemmMode.GemmTN,
                            )
    return operations


#
def GenerateDwconv2d_TensorOp_884(args, conv_kind):
    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    dst_layouts = [LayoutType.TensorNCHW]

    if conv_kind == ConvKind.Wgrad:
        dst_types = [DataType.f32]
    else:
        dst_types = [DataType.f16]

    alignment_constraints = [128, 32, 16]

    cuda_major = 10
    cuda_minor = 1

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 32], 2, [4, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        ]
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                for alignment_src in alignment_constraints:
                    if conv_kind == ConvKind.Wgrad:
                        # skip io16xc16
                        if math_inst.element_accumulator == DataType.f16:
                            continue
                        for alignment_diff in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                conv_kind,
                                tile_descriptions,
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                alignment_diff,
                                32,  # always f32 output
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                ImplicitGemmMode.GemmNT,
                                False,
                                cuda_major,
                                cuda_minor,
                            )
                    else:
                        operations += GenerateConv2d(
                            ConvType.DepthwiseConvolution,
                            conv_kind,
                            tile_descriptions,
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            alignment_src,
                            16,
                            16,
                            SpecialOptimizeDesc.NoneSpecialOpt,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
    return operations


#
def GenerateGemv_Simt(args):
    threadBlockShape_N = [128, 64, 32]

    ldgBits_A = [128, 64, 32]
    ldgBits_B = [128, 64, 32]

    layouts = [(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            for threadblock_shape_n in threadBlockShape_N:
                for align_a in ldgBits_A:
                    for align_b in ldgBits_B:
                        ldg_elements_a = align_a // DataTypeSize[math_inst.element_a]
                        ldg_elements_b = align_b // DataTypeSize[math_inst.element_b]
                        threadblock_shape_k = (256 * ldg_elements_a) // (
                            threadblock_shape_n // ldg_elements_b
                        )
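                        # Sketch with assumed values: f32 with align_a = align_b =
                        # 128 bits gives 4 elements per load, so threadblock_shape_n
                        # = 128 yields threadblock_shape_k = (256 * 4) // (128 // 4) = 32.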
                        threadblock_shape = [
                            1,
                            threadblock_shape_n,
                            threadblock_shape_k,
                        ]
                        thread_shape = [1, ldg_elements_b, ldg_elements_a]
                        operations.append(
                            GeneratesGemv(
                                math_inst,
                                threadblock_shape,
                                thread_shape,
                                data_type,
                                layout[0],
                                layout[1],
                                layout[2],
                                min_cc,
                                align_a,
                                align_b,
                            )
                        )
    return operations


#
def GeneratesGemm_TensorOp_1688(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 75
    max_cc = 1024

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 2

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## some configurations are commented out to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations


#
def GeneratesGemm_TensorOp_884(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 1

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## some configurations are commented out to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations


#
def GenerateConv2dOperations(args):
    if args.type == "simt":
        return GenerateConv2d_Simt(args)
    elif args.type == "tensorop8816":
        return GenerateConv2d_TensorOp_8816(args)
    else:
        assert args.type == "tensorop8832", (
            "operation conv2d only supports "
            "simt, tensorop8816 and tensorop8832. (got:{})".format(args.type)
        )
        return GenerateConv2d_TensorOp_8832(args)


def GenerateDeconvOperations(args):
    if args.type == "simt":
        return GenerateDeconv_Simt(args)
    else:
        assert args.type == "tensorop8816", (
            "operation deconv only supports "
            "simt and tensorop8816. (got:{})".format(args.type)
        )
        return GenerateDeconv_TensorOp_8816(args)


def GenerateDwconv2dFpropOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Fprop)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d fprop only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Fprop)


def GenerateDwconv2dDgradOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Dgrad)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d dgrad only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad)


def GenerateDwconv2dWgradOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Wgrad)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d wgrad only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)


def GenerateGemmOperations(args):
    if args.type == "tensorop884":
        return GeneratesGemm_TensorOp_884(args)
    elif args.type == "tensorop1688":
        return GeneratesGemm_TensorOp_1688(args)
    else:
        assert args.type == "simt", (
            "operation gemm only supports simt. (got:{})".format(args.type)
        )
        return GenerateGemm_Simt(args)


def GenerateGemvOperations(args):
    assert args.type == "simt", (
        "operation gemv only supports simt. (got:{})".format(args.type)
    )
    return GenerateGemv_Simt(args)


###################################################################################################
###################################################################################################

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generates device kernel registration code for CUTLASS Kernels"
    )
    parser.add_argument(
        "--operations",
        type=str,
        choices=[
            "gemm",
            "gemv",
            "conv2d",
            "deconv",
            "dwconv2d_fprop",
            "dwconv2d_dgrad",
            "dwconv2d_wgrad",
        ],
        required=True,
        help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
    )
    parser.add_argument(
        "output", type=str, help="output directory for CUTLASS kernel files"
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["simt", "tensorop8816", "tensorop8832", "tensorop884", "tensorop1688"],
        default="simt",
        help="kernel type of CUTLASS kernel generator",
    )
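
    # Example invocation (a sketch; the output path is hypothetical):
    #   python3 generator.py --operations gemm --type simt /path/to/output_dir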

    gemv_wrapper_path = (
        "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
    )

    short_path = (
        platform.system() == "Windows" or platform.system().find("NT") >= 0
    ) and ("true" != os.getenv("CUTLASS_WITH_LONG_PATH", default="False").lower())

    args = parser.parse_args()

    if args.operations == "gemm":
        operations = GenerateGemmOperations(args)
    elif args.operations == "gemv":
        operations = GenerateGemvOperations(args)
    elif args.operations == "conv2d":
        operations = GenerateConv2dOperations(args)
    elif args.operations == "deconv":
        operations = GenerateDeconvOperations(args)
    elif args.operations == "dwconv2d_fprop":
        operations = GenerateDwconv2dFpropOperations(args)
    elif args.operations == "dwconv2d_dgrad":
        operations = GenerateDwconv2dDgradOperations(args)
    else:
        assert args.operations == "dwconv2d_wgrad", "invalid operation"
        operations = GenerateDwconv2dWgradOperations(args)

    if (
        args.operations == "conv2d"
        or args.operations == "deconv"
        or args.operations == "dwconv2d_fprop"
        or args.operations == "dwconv2d_dgrad"
        or args.operations == "dwconv2d_wgrad"
    ):
        for operation in operations:
            with EmitConvSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
    elif args.operations == "gemm":
        for operation in operations:
            with EmitGemmSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
    elif args.operations == "gemv":
        for operation in operations:
            with EmitGemvSingleKernelWrapper(
                args.output, operation, gemv_wrapper_path, short_path
            ) as emitter:
                emitter.emit()

    if args.operations != "gemv":
        GenerateManifest(args, operations, args.output)

#
###################################################################################################