
generator.py 70 kB

#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#
import enum
import os.path
import shutil
import argparse
import platform
import string

from library import *
from manifest import *
###################################################################################################

#
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch=0):
    # by default, use the latest CUDA Toolkit version
    cuda_version = [11, 0, 132]

    # update cuda_version based on the parsed string
    if semantic_ver_string != "":
        for i, x in enumerate([int(x) for x in semantic_ver_string.split(".")]):
            if i < len(cuda_version):
                cuda_version[i] = x
            else:
                cuda_version.append(x)
    return cuda_version >= [major, minor, patch]
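
# Usage sketch (illustrative, matching the list-comparison semantics above):
#   CudaToolkitVersionSatisfies("10.2", 10, 1)  -> True  ([10, 2, 132] >= [10, 1, 0])
#   CudaToolkitVersionSatisfies("10.2", 11, 0)  -> False ([10, 2, 132] <  [11, 0, 0])
#   CudaToolkitVersionSatisfies("", 11, 0)      -> True  (defaults to [11, 0, 132])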
###################################################################################################
###################################################################################################

#
def CreateGemmOperator(
    manifest,
    layouts,
    tile_descriptions,
    data_type,
    alignment_constraints,
    complex_transforms=None,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    swizzling_functor=SwizzlingFunctor.Identity8,
):
    if complex_transforms is None:
        complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    operations = []

    # by default, only generate the largest tile and largest alignment
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    for layout in layouts:
        for tile_description in tile_descriptions:
            for alignment in alignment_constraints:
                for complex_transform in complex_transforms:
                    alignment_c = min(8, alignment)

                    A = TensorDescription(
                        element_a, layout[0], alignment, complex_transform[0]
                    )
                    B = TensorDescription(
                        element_b, layout[1], alignment, complex_transform[1]
                    )
                    C = TensorDescription(element_c, layout[2], alignment_c)

                    new_operation = GemmOperation(
                        GemmKind.Universal,
                        tile_description.minimum_compute_capability,
                        tile_description,
                        A,
                        B,
                        C,
                        element_epilogue,
                        epilogue_functor,
                        swizzling_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

    return operations
###########################################################################################################
# ConvolutionOperator support variations
#        ____________________________________________________________________
#        ConvolutionalOperator |      Analytic         |      Optimized
#        ____________________________________________________________________
#        |       Fprop         |     (strided)         |      (strided)
#        |       Dgrad         |  (strided, unity*)    |       (unity)
#        |       Wgrad         |     (strided)         |      (strided)
#        ____________________________________________________________________
#
# Note: Operators marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################

# Convolution for 2D operations
def CreateConv2dOperator(
    manifest,
    layout,
    tile_descriptions,
    data_type,
    alignment,
    conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
    epilogue_functor=EpilogueFunctor.LinearCombination,
):
    element_a, element_b, element_c, element_epilogue = data_type

    # one exceptional case
    alignment_c = min(8, alignment)

    # iterator algorithm (analytic and optimized)
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    # by default, only generate the largest tile size
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]

    operations = []

    for tile in tile_descriptions:
        for conv_kind in conv_kinds:
            for iterator_algorithm in iterator_algorithms:
                A = TensorDescription(element_a, layout[0], alignment)
                B = TensorDescription(element_b, layout[1], alignment)
                C = TensorDescription(element_c, layout[2], alignment_c)

                # unity stride only for Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    new_operation = Conv2dOperation(
                        conv_kind,
                        iterator_algorithm,
                        tile.minimum_compute_capability,
                        tile,
                        A,
                        B,
                        C,
                        element_epilogue,
                        StrideSupport.Unity,
                        epilogue_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

                # strided dgrad is not supported by Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    continue

                # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
                new_operation = Conv2dOperation(
                    conv_kind,
                    iterator_algorithm,
                    tile.minimum_compute_capability,
                    tile,
                    A,
                    B,
                    C,
                    element_epilogue,
                    StrideSupport.Strided,
                    epilogue_functor,
                )

                manifest.append(new_operation)
                operations.append(new_operation)

    return operations
###################################################################################################
###################################################################################################

def GenerateConv2d_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [
        LayoutType.TensorNC4HW4,
        LayoutType.TensorNC32HW32,
        LayoutType.TensorNHWC,
        LayoutType.TensorNHWC,
        LayoutType.TensorNCHW,
    ]

    dst_types = [DataType.s8, DataType.s8, DataType.u4, DataType.s4, DataType.f32]

    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_type == DataType.s4 or dst_type == DataType.u4:
                    min_cc = 75
                    use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt
                else:
                    min_cc = 61
                    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
                tile_descriptions = [
                    TileDescription(
                        [128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    if (
                        dst_layout == LayoutType.TensorNC32HW32
                        and tile.threadblock_shape[0] > 32
                    ):
                        continue
                    if (
                        dst_layout == LayoutType.TensorNCHW
                        or dst_layout == LayoutType.TensorNHWC
                    ) and tile.threadblock_shape[0] > 16:
                        continue
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        32,
                        32,
                        32,
                        use_special_optimization,
                    )
    return operations

def GenerateConv2d_TensorOp_8816(args):
    operations = []

    layouts = [(LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32)]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNC32HW32, LayoutType.TensorNC4HW4]
    dst_types = [DataType.s8, DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024
    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_layout == LayoutType.TensorNC32HW32:
                    tile_descriptions = [
                        TileDescription(
                            [128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        True,
                        cuda_major,
                        cuda_minor,
                    )
                else:
                    assert dst_layout == LayoutType.TensorNC4HW4
                    tile_descriptions = [
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmNT,
                        False,
                        cuda_major,
                        cuda_minor,
                    )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 16
                        or tile.threadblock_shape[1] == 32
                    ):
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    out_dtypes = [DataType.s4, DataType.u4, DataType.f32]

    # INT8x8x4 and INT8x8x32
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                for out_dtype in out_dtypes:
                    tile_descriptions = [
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    for tile in tile_descriptions:
                        dst_align = (
                            4 * DataTypeSize[out_dtype]
                            if tile.threadblock_shape[1] == 16
                            or out_dtype == DataType.f32
                            else 8 * DataTypeSize[out_dtype]
                        )
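                        # Illustrative arithmetic, assuming DataTypeSize maps a type
                        # to its width in bits (e.g. s4 -> 4, f32 -> 32): for s4 on a
                        # [128, 32, 32] tile, dst_align = 8 * 4 = 32 bits; for f32,
                        # dst_align = 4 * 32 = 128 bits regardless of the tile shape.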
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            out_dtype,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
                        if tile.threadblock_shape[1] == 16 or (
                            tile.threadblock_shape[1] == 32
                            and out_dtype != DataType.f32
                        ):
                            operations += GenerateConv2d(
                                ConvType.Convolution,
                                ConvKind.Fprop,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                out_dtype,
                                min_cc,
                                layout[2],
                                layout[2],
                                dst_align,
                                use_special_optimization,
                                ImplicitGemmMode.GemmTN,
                                True,
                                cuda_major,
                                cuda_minor,
                            )
    return operations

def GenerateConv2d_TensorOp_8832(args):
    operations = []

    layouts = [(LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64)]

    math_instructions = [
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.s4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.u4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
    ]

    dst_layouts = [LayoutType.TensorNC64HW64]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024
    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_layout in dst_layouts:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Fprop,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    128,
                    128,
                    64,
                    use_special_optimization,
                    ImplicitGemmMode.GemmTN,
                    True,
                    cuda_major,
                    cuda_minor,
                )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 16 if tile.threadblock_shape[1] == 16 else 32
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    # INT4x4x8
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        DataType.s8,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 64 if tile.threadblock_shape[1] == 32 else 128
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            DataType.s8,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )
    return operations

def GenerateDeconv_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [LayoutType.TensorNC4HW4]
    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 61
    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Dgrad,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    32,
                    32,
                    32,
                    use_special_optimization,
                )
    return operations

def GenerateDeconv_TensorOp_8816(args):
    operations = []

    layouts = [
        (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNHWC]
    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 75
    max_cc = 1024
    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Dgrad,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
    return operations

################################################################################
# parameters
#   Edge    - for tiles, the edges represent the length of one side
#   Ratio   - the maximum ratio between 2 edges, limits the skinniness of tiles
#   MaxEdge - maximum length of each edge
#   Min/Max - minimum/maximum of the product of edge lengths
################################################################################

warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
warpsPerThreadblockRatio = 2
warpsPerThreadblockMax = 16
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases
warpShapeEdges = [8, 16, 32, 64, 128, 256]
warpShapeRatio = 4
warpShapeMax = 64 * 64
warpShapeMin = 8 * 8
threadblockEdgeMax = 256
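
# Illustrative examples of the constraints above: a warp shape of (64, 32) is kept
# (ratio 2 <= warpShapeRatio, area 2048 <= warpShapeMax); (8, 64) is rejected
# (ratio 8 > warpShapeRatio); (8, 8) is rejected (area 64 is not > warpShapeMin).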

# char, type bits/elem, max tile, L0 threadblock tiles
precisions = {
    "c": ["cutlass::complex<float>", 64, 64 * 128, [[64, 128], [64, 32]]],
    "d": ["double", 64, 64 * 64, [[64, 64], [32, 32]]],
    "h": ["cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]],
    "i": ["int", 32, 128 * 128, [[128, 64], [16, 32]]],
    "s": ["float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]],
    "z": ["cutlass::complex<double>", 128, 64 * 64, [[32, 64], [16, 32]]],
}
# L1 will have a single kernel for every unique shape
# L2 will have everything else

def GenerateGemm_Simt(args):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions["s"]

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            tile_descriptions = [
                TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            ]
            for warpsPerThreadblock in warpsPerThreadblocks:
                for warpShape in warpShapes:
                    warpThreadsM = 0
                    if warpShape[0] > warpShape[1]:
                        warpThreadsM = 8
                    else:
                        warpThreadsM = 4
                    warpThreadsN = warpNumThreads / warpThreadsM

                    # skip shapes with conflicting rectangularity
                    # they are unlikely to be fastest
                    blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                    blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                    warpG = warpShape[0] > warpShape[1]
                    warpL = warpShape[0] < warpShape[1]

                    blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                    blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                    warpG2 = warpShape[0] > warpShape[1] * 2
                    warpL2 = warpShape[0] * 2 < warpShape[1]

                    if blockG2 and warpL:
                        continue
                    if blockL2 and warpG:
                        continue
                    if warpG2 and blockL:
                        continue
                    if warpL2 and blockG:
                        continue

                    # check threadblock ratios and max
                    threadblockTile = [
                        warpShape[0] * warpsPerThreadblock[0],
                        warpShape[1] * warpsPerThreadblock[1],
                    ]
                    if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                        continue
                    if threadblockTile[0] > threadblockEdgeMax:
                        continue
                    if threadblockTile[1] > threadblockEdgeMax:
                        continue
                    totalThreads = (
                        warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                    )

                    # calculate unroll
                    # ensure that every iteration at least a full load of A,B are done
                    unrollMin = 8
                    unrollMin0 = totalThreads // threadblockTile[0]
                    unrollMin1 = totalThreads // threadblockTile[1]
                    unroll = max(unrollMin, unrollMin0, unrollMin1)
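                    # Worked example: warpsPerThreadblock = [2, 4] with warpShape =
                    # [64, 32] gives threadblockTile = [128, 128] and totalThreads =
                    # 32 * 2 * 4 = 256, so unrollMin0 = unrollMin1 = 256 // 128 = 2
                    # and unroll = max(8, 2, 2) = 8.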
                    threadTileM = warpShape[0] // warpThreadsM
                    threadTileN = warpShape[1] // warpThreadsN
                    if threadTileM < 2 or threadTileN < 2:
                        continue
                    if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                        continue

                    # epilogue currently only supports N < WarpNumThreads
                    if threadblockTile[1] < warpNumThreads:
                        continue

                    # limit smem
                    smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                    smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                    smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                    if smemKBytes > 48:
                        continue
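                    # Illustrative arithmetic: for threadblockTile = [128, 128] with
                    # unroll = 8 and precisionBits = 32 (sgemm), smemBitsA = smemBitsB
                    # = 128 * 8 * 2 * 32 = 65536 bits, so smemKBytes =
                    # 131072 / 8 / 1024 = 16 KB, well within the 48 KB limit.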
                    tile = TileDescription(
                        [threadblockTile[0], threadblockTile[1], unroll],
                        2,
                        [
                            threadblockTile[0] // warpShape[0],
                            threadblockTile[1] // warpShape[1],
                            1,
                        ],
                        math_inst,
                        min_cc,
                        max_cc,
                    )

                    def filter(t: TileDescription) -> bool:
                        nonlocal tile
                        return (
                            t.threadblock_shape[0] == tile.threadblock_shape[0]
                            and t.threadblock_shape[1] == tile.threadblock_shape[1]
                            and t.threadblock_shape[2] == tile.threadblock_shape[2]
                            and t.warp_count[0] == tile.warp_count[0]
                            and t.warp_count[1] == tile.warp_count[1]
                            and t.warp_count[2] == tile.warp_count[2]
                            and t.stages == tile.stages
                        )

                    if not any(t for t in tile_descriptions if filter(t)):
                        continue

                    operations += GeneratesGemm(
                        tile, data_type, layout[0], layout[1], layout[2], min_cc
                    )
    return operations

#
def GenerateDwconv2d_Simt(args, conv_kind):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions["s"]

    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    dst_layouts = [LayoutType.TensorNCHW]
    dst_types = [DataType.f32]
    if conv_kind == ConvKind.Wgrad:
        alignment_constraints = [32]
    else:
        alignment_constraints = [128, 32]

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        for warpsPerThreadblock in warpsPerThreadblocks:
            for warpShape in warpShapes:
                warpThreadsM = 0
                if warpShape[0] > warpShape[1]:
                    warpThreadsM = 8
                else:
                    warpThreadsM = 4
                warpThreadsN = warpNumThreads / warpThreadsM

                # skip shapes with conflicting rectangularity
                # they are unlikely to be fastest
                blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                warpG = warpShape[0] > warpShape[1]
                warpL = warpShape[0] < warpShape[1]

                blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                warpG2 = warpShape[0] > warpShape[1] * 2
                warpL2 = warpShape[0] * 2 < warpShape[1]

                if blockG2 and warpL:
                    continue
                if blockL2 and warpG:
                    continue
                if warpG2 and blockL:
                    continue
                if warpL2 and blockG:
                    continue

                # check threadblock ratios and max
                threadblockTile = [
                    warpShape[0] * warpsPerThreadblock[0],
                    warpShape[1] * warpsPerThreadblock[1],
                ]
                if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                    continue
                if threadblockTile[0] > threadblockEdgeMax:
                    continue
                if threadblockTile[1] > threadblockEdgeMax:
                    continue
                totalThreads = (
                    warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                )

                # calculate unroll
                # ensure that every iteration at least a full load of A,B are done
                unrollMin = 8
                unrollMin0 = totalThreads // threadblockTile[0]
                unrollMin1 = totalThreads // threadblockTile[1]
                unroll = max(unrollMin, unrollMin0, unrollMin1)

                threadTileM = warpShape[0] // warpThreadsM
                threadTileN = warpShape[1] // warpThreadsN
                if threadTileM < 2 or threadTileN < 2:
                    continue
                if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                    continue

                # epilogue currently only supports N < WarpNumThreads
                if threadblockTile[1] < warpNumThreads:
                    continue

                # limit smem
                smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                if smemKBytes > 48:
                    continue

                tile = TileDescription(
                    [threadblockTile[0], threadblockTile[1], unroll],
                    2,
                    [
                        threadblockTile[0] // warpShape[0],
                        threadblockTile[1] // warpShape[1],
                        1,
                    ],
                    math_inst,
                    min_cc,
                    max_cc,
                )

                def filter(t: TileDescription) -> bool:
                    nonlocal tile
                    return (
                        t.threadblock_shape[0] == tile.threadblock_shape[0]
                        and t.threadblock_shape[1] == tile.threadblock_shape[1]
                        and t.threadblock_shape[2] == tile.threadblock_shape[2]
                        and t.warp_count[0] == tile.warp_count[0]
                        and t.warp_count[1] == tile.warp_count[1]
                        and t.warp_count[2] == tile.warp_count[2]
                        and t.stages == tile.stages
                    )

                if not any(t for t in tile_descriptions if filter(t)):
                    continue

                for layout in layouts:
                    for dst_type, dst_layout in zip(dst_types, dst_layouts):
                        for alignment_src in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                conv_kind,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                32,
                                32,
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                ImplicitGemmMode.GemmNT
                                if conv_kind == ConvKind.Wgrad
                                else ImplicitGemmMode.GemmTN,
                            )
    return operations

#
def GenerateDwconv2d_TensorOp_884(args, conv_kind):
    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    dst_layouts = [LayoutType.TensorNCHW]
    if conv_kind == ConvKind.Wgrad:
        dst_types = [DataType.f32]
    else:
        dst_types = [DataType.f16]
    alignment_constraints = [128, 32, 16]

    cuda_major = 10
    cuda_minor = 1

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 32], 2, [4, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        ]
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                for alignment_src in alignment_constraints:
                    if conv_kind == ConvKind.Wgrad:
                        # skip io16xc16
                        if math_inst.element_accumulator == DataType.f16:
                            continue
                        for alignment_diff in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                conv_kind,
                                tile_descriptions,
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                alignment_diff,
                                32,  # always f32 output
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                ImplicitGemmMode.GemmNT,
                                False,
                                cuda_major,
                                cuda_minor,
                            )
                    else:
                        operations += GenerateConv2d(
                            ConvType.DepthwiseConvolution,
                            conv_kind,
                            tile_descriptions,
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            alignment_src,
                            16,
                            16,
                            SpecialOptimizeDesc.NoneSpecialOpt,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
    return operations

#
def GenerateGemv_Simt(args):
    threadBlockShape_N = [128, 64, 32]
    ldgBits_A = [128, 64, 32]
    ldgBits_B = [128, 64, 32]

    layouts = [(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            for threadblock_shape_n in threadBlockShape_N:
                for align_a in ldgBits_A:
                    for align_b in ldgBits_B:
                        ldg_elements_a = align_a // DataTypeSize[math_inst.element_a]
                        ldg_elements_b = align_b // DataTypeSize[math_inst.element_b]
                        threadblock_shape_k = (256 * ldg_elements_a) // (
                            threadblock_shape_n // ldg_elements_b
                        )
                        threadblock_shape = [
                            1,
                            threadblock_shape_n,
                            threadblock_shape_k,
                        ]
                        thread_shape = [1, ldg_elements_b, ldg_elements_a]
                        operations.append(
                            GeneratesGemv(
                                math_inst,
                                threadblock_shape,
                                thread_shape,
                                data_type,
                                layout[0],
                                layout[1],
                                layout[2],
                                min_cc,
                                align_a,
                                align_b,
                            )
                        )
    return operations

#
def GeneratesGemm_TensorOp_1688(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 75
    max_cc = 1024

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 2

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## comment some configurations to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations

#
def GeneratesGemm_TensorOp_884(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 1

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## comment some configurations to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations

#
def GenerateConv2dOperations(args):
    if args.type == "simt":
        return GenerateConv2d_Simt(args)
    elif args.type == "tensorop8816":
        return GenerateConv2d_TensorOp_8816(args)
    else:
        assert args.type == "tensorop8832", (
            "operation conv2d only supports "
            "simt, tensorop8816 and tensorop8832 (got: {})".format(args.type)
        )
        return GenerateConv2d_TensorOp_8832(args)


def GenerateDeconvOperations(args):
    if args.type == "simt":
        return GenerateDeconv_Simt(args)
    else:
        assert args.type == "tensorop8816", (
            "operation deconv only supports "
            "simt and tensorop8816 (got: {})".format(args.type)
        )
        return GenerateDeconv_TensorOp_8816(args)


def GenerateDwconv2dFpropOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Fprop)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d fprop only supports "
            "simt and tensorop884 (got: {})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Fprop)


def GenerateDwconv2dDgradOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Dgrad)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d dgrad only supports "
            "simt and tensorop884 (got: {})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad)


def GenerateDwconv2dWgradOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Wgrad)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d wgrad only supports "
            "simt and tensorop884 (got: {})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)


def GenerateGemmOperations(args):
    if args.type == "tensorop884":
        return GeneratesGemm_TensorOp_884(args)
    elif args.type == "tensorop1688":
        return GeneratesGemm_TensorOp_1688(args)
    else:
        assert args.type == "simt", (
            "operation gemm only supports simt (got: {})".format(args.type)
        )
        return GenerateGemm_Simt(args)


def GenerateGemvOperations(args):
    assert args.type == "simt", (
        "operation gemv only supports simt (got: {})".format(args.type)
    )
    return GenerateGemv_Simt(args)

def concat_file(
    file_path: str,
    file_name_first: str,
    file_name_last: str,
    head: str,
    required_cuda_ver_major: str,
    required_cuda_ver_minor: str,
    epilogue: str,
    wrapper_path=None,
):
    # Merge the generated per-kernel .cu files under file_path into two
    # concatenated source files to reduce the number of compilation units;
    # the per-kernel files are deleted after merging.
    merge_file_dir = file_path
    filenames = os.listdir(merge_file_dir)
    file1 = open(
        file_path + "/{}_{}_1.cu".format(file_name_first, file_name_last), "w"
    )
    file2 = open(
        file_path + "/{}_{}_2.cu".format(file_name_first, file_name_last), "w"
    )
    substitute_keys = {
        "required_cuda_ver_major": str(required_cuda_ver_major),
        "required_cuda_ver_minor": str(required_cuda_ver_minor),
    }
    if wrapper_path is not None:
        substitute_keys["wrapper_path"] = wrapper_path
    file1.write(SubstituteTemplate(head, substitute_keys))
    file2.write(SubstituteTemplate(head, substitute_keys))
    flag = 0
    if "tensorop" in file_name_last:
        sub_string_1 = "tensorop"
        sub_string_2 = file_name_last[8:]
    else:
        sub_string_1 = sub_string_2 = "simt"
    if "dwconv2d_" in file_name_first:
        # e.g. "dwconv2d_fprop" -> "dwfprop"
        file_name_first = file_name_first[:2] + file_name_first[9:]
    elif ("conv2d" in file_name_first) or ("deconv" in file_name_first):
        file_name_first = "cutlass"
    for filename in filenames:
        if (
            (file_name_first in filename)
            and (sub_string_1 in filename)
            and (sub_string_2 in filename)
            and ("all_" not in filename)
        ):
            flag += 1
            filepath = merge_file_dir + "/" + filename
            if flag <= len(filenames) / 2:
                for line in open(filepath):
                    file1.writelines(line)
            else:
                for line in open(filepath):
                    file2.writelines(line)
            os.remove(filepath)
            file1.write("\n")
            file2.write("\n")
        elif filename[0].isdigit() and ("all_" not in filename):
            flag += 1
            filepath = merge_file_dir + "/" + filename
            if flag <= len(filenames) / 2:
                for line in open(filepath):
                    file1.writelines(line)
            else:
                for line in open(filepath):
                    file2.writelines(line)
            os.remove(filepath)
            file1.write("\n")
            file2.write("\n")
    file1.write(epilogue)
    file2.write(epilogue)
    file1.close()
    file2.close()

###################################################################################################
###################################################################################################

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generates device kernel registration code for CUTLASS Kernels"
    )
    parser.add_argument(
        "--operations",
        type=str,
        choices=[
            "gemm",
            "gemv",
            "conv2d",
            "deconv",
            "dwconv2d_fprop",
            "dwconv2d_dgrad",
            "dwconv2d_wgrad",
        ],
        required=True,
        help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
    )
    parser.add_argument(
        "output", type=str, help="output directory for CUTLASS kernel files"
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["simt", "tensorop8816", "tensorop8832", "tensorop884", "tensorop1688"],
        default="simt",
        help="kernel type of CUTLASS kernel generator",
    )
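
    # Example invocations (illustrative; the output directory is a placeholder):
    #   python3 generator.py --operations gemm --type simt <output_dir>
    #   python3 generator.py --operations conv2d --type tensorop8816 <output_dir>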

    gemv_wrapper_path = (
        "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
    )

    short_path = (
        platform.system() == "Windows" or platform.system().find("NT") >= 0
    ) and ("true" != os.getenv("CUTLASS_WITH_LONG_PATH", default="False").lower())

    args = parser.parse_args()

    if args.operations == "gemm":
        operations = GenerateGemmOperations(args)
    elif args.operations == "gemv":
        operations = GenerateGemvOperations(args)
    elif args.operations == "conv2d":
        operations = GenerateConv2dOperations(args)
    elif args.operations == "deconv":
        operations = GenerateDeconvOperations(args)
    elif args.operations == "dwconv2d_fprop":
        operations = GenerateDwconv2dFpropOperations(args)
    elif args.operations == "dwconv2d_dgrad":
        operations = GenerateDwconv2dDgradOperations(args)
    else:
        assert args.operations == "dwconv2d_wgrad", "invalid operation"
        operations = GenerateDwconv2dWgradOperations(args)

    if (
        args.operations == "conv2d"
        or args.operations == "deconv"
        or args.operations == "dwconv2d_fprop"
        or args.operations == "dwconv2d_dgrad"
        or args.operations == "dwconv2d_wgrad"
    ):
        for operation in operations:
            with EmitConvSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
        head = EmitConvSingleKernelWrapper(
            args.output, operations[0], short_path
        ).header_template
        required_cuda_ver_major = operations[0].required_cuda_ver_major
        required_cuda_ver_minor = operations[0].required_cuda_ver_minor
        epilogue = EmitConvSingleKernelWrapper(
            args.output, operations[0], short_path
        ).epilogue_template
        concat_file(
            args.output,
            args.operations,
            args.type,
            head,
            required_cuda_ver_major,
            required_cuda_ver_minor,
            epilogue,
        )
    elif args.operations == "gemm":
        for operation in operations:
            with EmitGemmSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
        head = EmitGemmSingleKernelWrapper(
            args.output, operations[0], short_path
        ).header_template
        required_cuda_ver_major = operations[0].required_cuda_ver_major
        required_cuda_ver_minor = operations[0].required_cuda_ver_minor
        epilogue = EmitGemmSingleKernelWrapper(
            args.output, operations[0], short_path
        ).epilogue_template
        concat_file(
            args.output,
            args.operations,
            args.type,
            head,
            required_cuda_ver_major,
            required_cuda_ver_minor,
            epilogue,
        )
    elif args.operations == "gemv":
        for operation in operations:
            with EmitGemvSingleKernelWrapper(
                args.output, operation, gemv_wrapper_path, short_path
            ) as emitter:
                emitter.emit()
        head = EmitGemvSingleKernelWrapper(
            args.output, operations[0], gemv_wrapper_path, short_path
        ).header_template
        required_cuda_ver_major = operations[0].required_cuda_ver_major
        required_cuda_ver_minor = operations[0].required_cuda_ver_minor
        epilogue = EmitGemvSingleKernelWrapper(
            args.output, operations[0], gemv_wrapper_path, short_path
        ).epilogue_template
        concat_file(
            args.output,
            args.operations,
            args.type,
            head,
            required_cuda_ver_major,
            required_cuda_ver_minor,
            epilogue,
            wrapper_path=gemv_wrapper_path,
        )

    if args.operations != "gemv":
        GenerateManifest(args, operations, args.output)

#
###################################################################################################