
generator.py

#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#
import enum
import os.path
import shutil
import argparse
import platform

from library import *
from manifest import *

###################################################################################################

#
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch=0):
    # by default, use the latest CUDA Toolkit version
    cuda_version = [11, 0, 132]

    # Update cuda_version based on parsed string
    if semantic_ver_string != "":
        for i, x in enumerate([int(x) for x in semantic_ver_string.split(".")]):
            if i < len(cuda_version):
                cuda_version[i] = x
            else:
                cuda_version.append(x)
    return cuda_version >= [major, minor, patch]
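
# Worked example of the lexicographic list comparison above:
# CudaToolkitVersionSatisfies("10.2", 10, 1) parses cuda_version to
# [10, 2, 132], and [10, 2, 132] >= [10, 1, 0] is True in Python.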
###################################################################################################
###################################################################################################

#
def CreateGemmOperator(
    manifest,
    layouts,
    tile_descriptions,
    data_type,
    alignment_constraints,
    complex_transforms=None,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    swizzling_functor=SwizzlingFunctor.Identity8,
):
    if complex_transforms is None:
        complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    operations = []

    # by default, only generate the largest tile and largest alignment
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    for layout in layouts:
        for tile_description in tile_descriptions:
            for alignment in alignment_constraints:
                for complex_transform in complex_transforms:
                    alignment_c = min(8, alignment)

                    A = TensorDescription(
                        element_a, layout[0], alignment, complex_transform[0]
                    )
                    B = TensorDescription(
                        element_b, layout[1], alignment, complex_transform[1]
                    )
                    C = TensorDescription(element_c, layout[2], alignment_c)

                    new_operation = GemmOperation(
                        GemmKind.Universal,
                        tile_description.minimum_compute_capability,
                        tile_description,
                        A,
                        B,
                        C,
                        element_epilogue,
                        epilogue_functor,
                        swizzling_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

    return operations
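
# Note: when manifest.args.kernels is empty, the loops above only see
# tile_descriptions[0] and alignment_constraints[0], so callers should order
# those lists largest-first to get the intended default kernel.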
###########################################################################################################
# ConvolutionOperator support variations
# ____________________________________________________________________
#   ConvolutionOperator |      Analytic       |      Optimized
# ____________________________________________________________________
#   |      Fprop        |     (strided)       |     (strided)
#   |      Dgrad        |  (strided, unity*)  |      (unity)
#   |      Wgrad        |     (strided)       |     (strided)
# ____________________________________________________________________
#
# Note : Operators marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################
# Convolution for 2D operations
def CreateConv2dOperator(
    manifest,
    layout,
    tile_descriptions,
    data_type,
    alignment,
    conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
    epilogue_functor=EpilogueFunctor.LinearCombination,
):
    element_a, element_b, element_c, element_epilogue = data_type

    # one exceptional case
    alignment_c = min(8, alignment)

    # iterator algorithm (analytic and optimized)
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    # by default, only generate the largest tile size
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]

    operations = []

    for tile in tile_descriptions:
        for conv_kind in conv_kinds:
            for iterator_algorithm in iterator_algorithms:
                A = TensorDescription(element_a, layout[0], alignment)
                B = TensorDescription(element_b, layout[1], alignment)
                C = TensorDescription(element_c, layout[2], alignment_c)

                # unity stride only for Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    new_operation = Conv2dOperation(
                        conv_kind,
                        iterator_algorithm,
                        tile.minimum_compute_capability,
                        tile,
                        A,
                        B,
                        C,
                        element_epilogue,
                        StrideSupport.Unity,
                        epilogue_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

                # strided dgrad is not supported by Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    continue

                # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
                new_operation = Conv2dOperation(
                    conv_kind,
                    iterator_algorithm,
                    tile.minimum_compute_capability,
                    tile,
                    A,
                    B,
                    C,
                    element_epilogue,
                    StrideSupport.Strided,
                    epilogue_functor,
                )

                manifest.append(new_operation)
                operations.append(new_operation)

    return operations
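
# Note the Optimized + Dgrad special case above: it emits the unity-stride
# kernel and then hits the matching `continue`, so no strided Optimized Dgrad
# instance is ever generated, consistent with the support table.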
###################################################################################################
###################################################################################################

def GenerateConv2d_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [
        LayoutType.TensorNC4HW4,
        LayoutType.TensorNC32HW32,
        LayoutType.TensorNHWC,
        LayoutType.TensorNHWC,
        LayoutType.TensorNCHW,
    ]

    dst_types = [DataType.s8, DataType.s8, DataType.u4, DataType.s4, DataType.f32]

    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_type == DataType.s4 or dst_type == DataType.u4:
                    min_cc = 75
                    use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt
                else:
                    min_cc = 61
                    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
                tile_descriptions = [
                    TileDescription(
                        [128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    if (
                        dst_layout == LayoutType.TensorNC32HW32
                        and tile.threadblock_shape[0] > 32
                    ):
                        continue
                    if (
                        dst_layout == LayoutType.TensorNCHW
                        or dst_layout == LayoutType.TensorNHWC
                    ) and tile.threadblock_shape[0] > 16:
                        continue
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        32,
                        32,
                        32,
                        use_special_optimization,
                    )
    return operations
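
# The threadblock-M guards above encode per-layout limits: NC32HW32 outputs
# only accept tiles with M <= 32, and NCHW/NHWC outputs only tiles with
# M <= 16, which is why the small [16, N, K] tiles are in the list.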

def GenerateConv2d_TensorOp_8816(args):
    operations = []

    layouts = [(LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32)]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNC32HW32, LayoutType.TensorNC4HW4]

    dst_types = [DataType.s8, DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_layout == LayoutType.TensorNC32HW32:
                    tile_descriptions = [
                        TileDescription(
                            [128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        True,
                        cuda_major,
                        cuda_minor,
                    )
                else:
                    assert dst_layout == LayoutType.TensorNC4HW4
                    tile_descriptions = [
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmNT,
                        False,
                        cuda_major,
                        cuda_minor,
                    )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 16
                        or tile.threadblock_shape[1] == 32
                    ):
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    out_dtypes = [DataType.s4, DataType.u4, DataType.f32]

    # INT8x8x4 and INT8x8x32
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                for out_dtype in out_dtypes:
                    tile_descriptions = [
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    for tile in tile_descriptions:
                        dst_align = (
                            4 * DataTypeSize[out_dtype]
                            if tile.threadblock_shape[1] == 16
                            or out_dtype == DataType.f32
                            else 8 * DataTypeSize[out_dtype]
                        )
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            out_dtype,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
                        if tile.threadblock_shape[1] == 16 or (
                            tile.threadblock_shape[1] == 32
                            and out_dtype != DataType.f32
                        ):
                            operations += GenerateConv2d(
                                ConvType.Convolution,
                                ConvKind.Fprop,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                out_dtype,
                                min_cc,
                                layout[2],
                                layout[2],
                                dst_align,
                                use_special_optimization,
                                ImplicitGemmMode.GemmTN,
                                True,
                                cuda_major,
                                cuda_minor,
                            )
    return operations
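
# dst_align above is expressed in bits. For example, with out_dtype ==
# DataType.s4 (4 bits) and threadblock N == 32, the else-branch yields
# 8 * 4 = 32; a threadblock N of 16 or an f32 output takes the
# 4 * DataTypeSize[out_dtype] branch instead.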

def GenerateConv2d_TensorOp_8832(args):
    operations = []

    layouts = [(LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64)]

    math_instructions = [
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.s4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.u4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
    ]

    dst_layouts = [LayoutType.TensorNC64HW64]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_layout in dst_layouts:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Fprop,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    128,
                    128,
                    64,
                    use_special_optimization,
                    ImplicitGemmMode.GemmTN,
                    True,
                    cuda_major,
                    cuda_minor,
                )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 16 if tile.threadblock_shape[1] == 16 else 32
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    # INT4x4x8
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        DataType.s8,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 64 if tile.threadblock_shape[1] == 32 else 128
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            DataType.s8,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )
    return operations
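
# Since dst_type = math_inst.element_b here, the s4 x s4 instruction appears
# to write s4 outputs and the s4 x u4 instruction u4 outputs; only the final
# INT4x4x8 block forces the destination type to DataType.s8.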

def GenerateDeconv_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [LayoutType.TensorNC4HW4]

    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 61
    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Dgrad,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    32,
                    32,
                    32,
                    use_special_optimization,
                )
    return operations
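
# Deconvolution is generated as ConvKind.Dgrad with the DeconvDoubleUpsampling
# special optimization rather than as a separate operator kind.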

def GenerateDeconv_TensorOp_8816(args):
    operations = []

    layouts = [
        (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNHWC]

    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Dgrad,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
    return operations

################################################################################
# parameters
# Edge - for tiles, the edges represent the length of one side
# Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
# MaxEdge - maximum length of each edge
# Min/Max - minimum/maximum of the product of edge lengths
################################################################################
warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
warpsPerThreadblockRatio = 2
warpsPerThreadblockMax = 16
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases
warpShapeEdges = [8, 16, 32, 64, 128, 256]
warpShapeRatio = 4
warpShapeMax = 64 * 64
warpShapeMin = 8 * 8
threadblockEdgeMax = 256

# char, type bits/elem, max tile, L0 threadblock tiles
precisions = {
    "c": ["cutlass::complex<float>", 64, 64 * 128, [[64, 128], [64, 32]]],
    "d": ["double", 64, 64 * 64, [[64, 64], [32, 32]]],
    "h": ["cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]],
    "i": ["int", 32, 128 * 128, [[128, 64], [16, 32]]],
    "s": ["float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]],
    "z": ["cutlass::complex<double>", 128, 64 * 64, [[32, 64], [16, 32]]],
}
# L1 will have a single kernel for every unique shape
# L2 will have everything else
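
# For instance, under these constants a warp arrangement of [4, 2] passes all
# three warpsPerThreadblock filters (ratios 2 and 0.5, product 8 <= 16),
# while [16, 2] fails both the ratio test (16 / 2 = 8 > 2) and the product
# test (32 > 16).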

def GenerateGemm_Simt(args):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[
        "s"
    ]

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            tile_descriptions = [
                TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            ]
            for warpsPerThreadblock in warpsPerThreadblocks:
                for warpShape in warpShapes:
                    warpThreadsM = 0
                    if warpShape[0] > warpShape[1]:
                        warpThreadsM = 8
                    else:
                        warpThreadsM = 4
                    warpThreadsN = warpNumThreads // warpThreadsM

                    # skip shapes with conflicting rectangularity
                    # they are unlikely to be fastest
                    blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                    blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                    warpG = warpShape[0] > warpShape[1]
                    warpL = warpShape[0] < warpShape[1]

                    blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                    blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                    warpG2 = warpShape[0] > warpShape[1] * 2
                    warpL2 = warpShape[0] * 2 < warpShape[1]

                    if blockG2 and warpL:
                        continue
                    if blockL2 and warpG:
                        continue
                    if warpG2 and blockL:
                        continue
                    if warpL2 and blockG:
                        continue

                    # check threadblock ratios and max
                    threadblockTile = [
                        warpShape[0] * warpsPerThreadblock[0],
                        warpShape[1] * warpsPerThreadblock[1],
                    ]
                    if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                        continue
                    if threadblockTile[0] > threadblockEdgeMax:
                        continue
                    if threadblockTile[1] > threadblockEdgeMax:
                        continue
                    totalThreads = (
                        warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                    )

                    # calculate unroll
                    # ensure that every iteration at least a full load of A,B are done
                    unrollMin = 8
                    unrollMin0 = totalThreads // threadblockTile[0]
                    unrollMin1 = totalThreads // threadblockTile[1]
                    unroll = max(unrollMin, unrollMin0, unrollMin1)

                    threadTileM = warpShape[0] // warpThreadsM
                    threadTileN = warpShape[1] // warpThreadsN
                    if threadTileM < 2 or threadTileN < 2:
                        continue
                    if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                        continue

                    # epilogue currently only supports N < WarpNumThreads
                    if threadblockTile[1] < warpNumThreads:
                        continue

                    # limit smem
                    smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                    smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                    smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                    if smemKBytes > 48:
                        continue

                    tile = TileDescription(
                        [threadblockTile[0], threadblockTile[1], unroll],
                        2,
                        [
                            threadblockTile[0] // warpShape[0],
                            threadblockTile[1] // warpShape[1],
                            1,
                        ],
                        math_inst,
                        min_cc,
                        max_cc,
                    )

                    def filter(t: TileDescription) -> bool:
                        nonlocal tile
                        return (
                            t.threadblock_shape[0] == tile.threadblock_shape[0]
                            and t.threadblock_shape[1] == tile.threadblock_shape[1]
                            and t.threadblock_shape[2] == tile.threadblock_shape[2]
                            and t.warp_count[0] == tile.warp_count[0]
                            and t.warp_count[1] == tile.warp_count[1]
                            and t.warp_count[2] == tile.warp_count[2]
                            and t.stages == tile.stages
                        )

                    if not any(t for t in tile_descriptions if filter(t)):
                        continue

                    operations += GeneratesGemm(
                        tile, data_type, layout[0], layout[1], layout[2], min_cc
                    )
    return operations
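
# A worked pass through the pruning above, taking warpShape [32, 64] and
# warpsPerThreadblock [4, 2]: threadblockTile = [128, 128], totalThreads =
# 256, unroll = max(8, 2, 2) = 8, threadTileM = threadTileN = 8, and
# smemKBytes = (128 * 8 * 2 * 32) * 2 / 8 / 1024 = 16 <= 48; the resulting
# [128, 128, 8] tile with warp_count [4, 2, 1] matches an entry in
# tile_descriptions, so it is emitted.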

def GenerateDwconv2dFprop_Simt(args):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    precisionType, precisionBits, threadblockMaxElements, threadblockTilesL0 = precisions[
        "s"
    ]

    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    dst_layouts = [LayoutType.TensorNCHW]

    dst_types = [DataType.f32]

    alignment_constraints = [128, 32]

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 8], 2, [1, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        for warpsPerThreadblock in warpsPerThreadblocks:
            for warpShape in warpShapes:
                warpThreadsM = 0
                if warpShape[0] > warpShape[1]:
                    warpThreadsM = 8
                else:
                    warpThreadsM = 4
                warpThreadsN = warpNumThreads // warpThreadsM

                # skip shapes with conflicting rectangularity
                # they are unlikely to be fastest
                blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                warpG = warpShape[0] > warpShape[1]
                warpL = warpShape[0] < warpShape[1]

                blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                warpG2 = warpShape[0] > warpShape[1] * 2
                warpL2 = warpShape[0] * 2 < warpShape[1]

                if blockG2 and warpL:
                    continue
                if blockL2 and warpG:
                    continue
                if warpG2 and blockL:
                    continue
                if warpL2 and blockG:
                    continue

                # check threadblock ratios and max
                threadblockTile = [
                    warpShape[0] * warpsPerThreadblock[0],
                    warpShape[1] * warpsPerThreadblock[1],
                ]
                if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                    continue
                if threadblockTile[0] > threadblockEdgeMax:
                    continue
                if threadblockTile[1] > threadblockEdgeMax:
                    continue
                totalThreads = (
                    warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                )

                # calculate unroll
                # ensure that every iteration at least a full load of A,B are done
                unrollMin = 8
                unrollMin0 = totalThreads // threadblockTile[0]
                unrollMin1 = totalThreads // threadblockTile[1]
                unroll = max(unrollMin, unrollMin0, unrollMin1)

                threadTileM = warpShape[0] // warpThreadsM
                threadTileN = warpShape[1] // warpThreadsN
                if threadTileM < 2 or threadTileN < 2:
                    continue
                if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                    continue

                # epilogue currently only supports N < WarpNumThreads
                if threadblockTile[1] < warpNumThreads:
                    continue

                # limit smem
                smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                if smemKBytes > 48:
                    continue

                tile = TileDescription(
                    [threadblockTile[0], threadblockTile[1], unroll],
                    2,
                    [
                        threadblockTile[0] // warpShape[0],
                        threadblockTile[1] // warpShape[1],
                        1,
                    ],
                    math_inst,
                    min_cc,
                    max_cc,
                )

                def filter(t: TileDescription) -> bool:
                    nonlocal tile
                    return (
                        t.threadblock_shape[0] == tile.threadblock_shape[0]
                        and t.threadblock_shape[1] == tile.threadblock_shape[1]
                        and t.threadblock_shape[2] == tile.threadblock_shape[2]
                        and t.warp_count[0] == tile.warp_count[0]
                        and t.warp_count[1] == tile.warp_count[1]
                        and t.warp_count[2] == tile.warp_count[2]
                        and t.stages == tile.stages
                    )

                if not any(t for t in tile_descriptions if filter(t)):
                    continue

                for layout in layouts:
                    for dst_type, dst_layout in zip(dst_types, dst_layouts):
                        for alignment_src in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                ConvKind.Fprop,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                32,
                                32,
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                ImplicitGemmMode.GemmTN,
                            )
    return operations
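
# This function repeats the SIMT GEMM pruning heuristics above; the
# differences are the NCHW layouts, ConvType.DepthwiseConvolution, and the
# remaining alignment arguments being fixed at 32 (which, matching the other
# GenerateConv2d call sites, appear to be filter/destination alignments).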

#
def GenerateDwconv2dFprop_TensorOp_884(args):
    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    dst_layouts = [LayoutType.TensorNCHW]

    dst_types = [DataType.f16]

    alignment_constraints = [128, 32, 16]

    cuda_major = 10
    cuda_minor = 2

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 32], 2, [4, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        ]
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                for alignment_src in alignment_constraints:
                    operations += GenerateConv2d(
                        ConvType.DepthwiseConvolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        alignment_src,
                        16,
                        16,
                        SpecialOptimizeDesc.NoneSpecialOpt,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
    return operations

#
def GenerateGemv_Simt(args):
    threadBlockShape_N = [128, 64, 32]
    ldgBits_A = [128, 64, 32]
    ldgBits_B = [128, 64, 32]

    layouts = [(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            for threadblock_shape_n in threadBlockShape_N:
                for align_a in ldgBits_A:
                    for align_b in ldgBits_B:
                        ldg_elements_a = align_a // DataTypeSize[math_inst.element_a]
                        ldg_elements_b = align_b // DataTypeSize[math_inst.element_b]
                        threadblock_shape_k = (256 * ldg_elements_a) // (
                            threadblock_shape_n // ldg_elements_b
                        )
                        threadblock_shape = [
                            1,
                            threadblock_shape_n,
                            threadblock_shape_k,
                        ]
                        thread_shape = [1, ldg_elements_b, ldg_elements_a]
                        operations.append(
                            GeneratesGemv(
                                math_inst,
                                threadblock_shape,
                                thread_shape,
                                data_type,
                                layout[0],
                                layout[1],
                                layout[2],
                                min_cc,
                                align_a,
                                align_b,
                            )
                        )
    return operations
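
# Example of the shape arithmetic above: with f32 operands (32 bits),
# align_a = align_b = 128 gives ldg_elements_a = ldg_elements_b = 4, and for
# threadblock_shape_n = 128 the K extent is (256 * 4) // (128 // 4) = 32,
# i.e. a [1, 128, 32] threadblock with a [1, 4, 4] thread shape.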

#
def GeneratesGemm_TensorOp_1688(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 75
    max_cc = 1024

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 2

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## comment some configuration to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations
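
# The align * 16 factors appear to convert an element-count alignment into
# bits: with 16-bit f16 elements, align = 8 becomes 128-bit (maximally
# vectorized) accesses.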

#
def GeneratesGemm_TensorOp_884(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 2

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## comment some configuration to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations

#
def GenerateConv2dOperations(args):
    if args.type == "simt":
        return GenerateConv2d_Simt(args)
    elif args.type == "tensorop8816":
        return GenerateConv2d_TensorOp_8816(args)
    else:
        assert args.type == "tensorop8832", (
            "operation conv2d only supports "
            "simt, tensorop8816 and tensorop8832. (got: {})".format(args.type)
        )
        return GenerateConv2d_TensorOp_8832(args)


def GenerateDeconvOperations(args):
    if args.type == "simt":
        return GenerateDeconv_Simt(args)
    else:
        assert args.type == "tensorop8816", (
            "operation deconv only supports "
            "simt and tensorop8816. (got: {})".format(args.type)
        )
        return GenerateDeconv_TensorOp_8816(args)


def GenerateDwconv2dFpropOperations(args):
    if args.type == "simt":
        return GenerateDwconv2dFprop_Simt(args)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d fprop only supports "
            "simt and tensorop884. (got: {})".format(args.type)
        )
        return GenerateDwconv2dFprop_TensorOp_884(args)


def GenerateGemmOperations(args):
    if args.type == "tensorop884":
        return GeneratesGemm_TensorOp_884(args)
    elif args.type == "tensorop1688":
        return GeneratesGemm_TensorOp_1688(args)
    else:
        assert (
            args.type == "simt"
        ), "operation gemm only supports simt. (got: {})".format(args.type)
        return GenerateGemm_Simt(args)


def GenerateGemvOperations(args):
    assert args.type == "simt", "operation gemv only supports simt. (got: {})".format(
        args.type
    )
    return GenerateGemv_Simt(args)

###################################################################################################
###################################################################################################

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generates device kernel registration code for CUTLASS Kernels"
    )
    parser.add_argument(
        "--operations",
        type=str,
        choices=[
            "gemm",
            "gemv",
            "conv2d",
            "deconv",
            "dwconv2d_fprop",
            "dwconv2d_dgrad",
            "dwconv2d_wgrad",
        ],
        required=True,
        help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
    )
    parser.add_argument(
        "output", type=str, help="output directory for CUTLASS kernel files"
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["simt", "tensorop8816", "tensorop8832", "tensorop884", "tensorop1688"],
        default="simt",
        help="kernel type of CUTLASS kernel generator",
    )
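
    # Example invocation (the output directory here is hypothetical):
    #     python3 generator.py --operations conv2d --type tensorop8816 /tmp/kernels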

    gemv_wrapper_path = (
        "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
    )
    short_path = (
        platform.system() == "Windows" or platform.system().find("NT") >= 0
    ) and ("true" != os.getenv("CUTLASS_WITH_LONG_PATH", default="False").lower())

    args = parser.parse_args()

    if args.operations == "gemm":
        operations = GenerateGemmOperations(args)
    elif args.operations == "gemv":
        operations = GenerateGemvOperations(args)
    elif args.operations == "conv2d":
        operations = GenerateConv2dOperations(args)
    elif args.operations == "deconv":
        operations = GenerateDeconvOperations(args)
    elif args.operations == "dwconv2d_fprop":
        operations = GenerateDwconv2dFpropOperations(args)
    elif args.operations == "dwconv2d_dgrad":
        # dwconv2d dgrad generation is not implemented yet; emit no kernels
        operations = []
    elif args.operations == "dwconv2d_wgrad":
        # dwconv2d wgrad generation is not implemented yet; emit no kernels
        operations = []

    if (
        args.operations == "conv2d"
        or args.operations == "deconv"
        or args.operations == "dwconv2d_fprop"
        or args.operations == "dwconv2d_dgrad"
        or args.operations == "dwconv2d_wgrad"
    ):
        for operation in operations:
            with EmitConvSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
    elif args.operations == "gemm":
        for operation in operations:
            with EmitGemmSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
    elif args.operations == "gemv":
        for operation in operations:
            with EmitGemvSingleKernelWrapper(
                args.output, operation, gemv_wrapper_path, short_path
            ) as emitter:
                emitter.emit()

    if args.operations != "gemv":
        GenerateManifest(args, operations, args.output)

#
###################################################################################################