

#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#
import argparse
import enum
import os.path
import platform
import string

from library import *
from manifest import *

###################################################################################################
#
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch=0):
    # by default, use the latest CUDA Toolkit version
    cuda_version = [11, 0, 132]

    # Update cuda_version based on parsed string
    if semantic_ver_string != "":
        for i, x in enumerate([int(x) for x in semantic_ver_string.split(".")]):
            if i < len(cuda_version):
                cuda_version[i] = x
            else:
                cuda_version.append(x)
    return cuda_version >= [major, minor, patch]
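
# Example: parsing "11.2" above yields cuda_version == [11, 2, 132], so
# CudaToolkitVersionSatisfies("11.2", 11, 0) returns True -- Python compares
# the version lists lexicographically ([11, 2, 132] >= [11, 0, 0]).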
###################################################################################################
###################################################################################################

#
def CreateGemmOperator(
    manifest,
    layouts,
    tile_descriptions,
    data_type,
    alignment_constraints,
    complex_transforms=None,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    swizzling_functor=SwizzlingFunctor.Identity8,
):
    if complex_transforms is None:
        complex_transforms = [(ComplexTransform.none, ComplexTransform.none)]

    element_a, element_b, element_c, element_epilogue = data_type

    operations = []

    # by default, only generate the largest tile and largest alignment
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]
        alignment_constraints = [alignment_constraints[0]]

    for layout in layouts:
        for tile_description in tile_descriptions:
            for alignment in alignment_constraints:
                for complex_transform in complex_transforms:
                    alignment_c = min(8, alignment)

                    A = TensorDescription(
                        element_a, layout[0], alignment, complex_transform[0]
                    )
                    B = TensorDescription(
                        element_b, layout[1], alignment, complex_transform[1]
                    )
                    C = TensorDescription(element_c, layout[2], alignment_c)

                    new_operation = GemmOperation(
                        GemmKind.Universal,
                        tile_description.minimum_compute_capability,
                        tile_description,
                        A,
                        B,
                        C,
                        element_epilogue,
                        epilogue_functor,
                        swizzling_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

    return operations
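
# A minimal usage sketch (hypothetical values, for illustration only):
#
#   ops = CreateGemmOperator(
#       manifest,
#       layouts=[(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor)],
#       tile_descriptions=tile_descriptions,
#       data_type=[DataType.f32, DataType.f32, DataType.f32, DataType.f32],
#       alignment_constraints=[1],
#   )
#
# Each generated GemmOperation is appended to the manifest and also returned.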
###########################################################################################################
# ConvolutionOperator support variations
#        ____________________________________________________________________
#        ConvolutionalOperator |      Analytic          |      Optimized
#        ____________________________________________________________________
#        |       Fprop         |     (strided)          |     (strided)
#        |       Dgrad         |     (strided, unity*)  |     (unity)
#        |       Wgrad         |     (strided)          |     (strided)
#        ____________________________________________________________________
#
# Note : Operators marked (*) are supported but not generated to keep the instantiated kernel count low
###########################################################################################################
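
# In code terms, the table above maps to the StrideSupport argument passed to
# Conv2dOperation below: StrideSupport.Unity is used only for Optimized Dgrad,
# and StrideSupport.Strided for every other generated combination.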
# Convolution for 2D operations
def CreateConv2dOperator(
    manifest,
    layout,
    tile_descriptions,
    data_type,
    alignment,
    conv_kinds=[ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad],
    epilogue_functor=EpilogueFunctor.LinearCombination,
):
    element_a, element_b, element_c, element_epilogue = data_type

    # one exceptional case
    alignment_c = min(8, alignment)

    # iterator algorithm (analytic and optimized)
    iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

    # by default, only generate the largest tile size
    if manifest.args.kernels == "":
        tile_descriptions = [tile_descriptions[0]]

    operations = []

    for tile in tile_descriptions:
        for conv_kind in conv_kinds:
            for iterator_algorithm in iterator_algorithms:
                A = TensorDescription(element_a, layout[0], alignment)
                B = TensorDescription(element_b, layout[1], alignment)
                C = TensorDescription(element_c, layout[2], alignment_c)

                # unity stride only for Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    new_operation = Conv2dOperation(
                        conv_kind,
                        iterator_algorithm,
                        tile.minimum_compute_capability,
                        tile,
                        A,
                        B,
                        C,
                        element_epilogue,
                        StrideSupport.Unity,
                        epilogue_functor,
                    )

                    manifest.append(new_operation)
                    operations.append(new_operation)

                # strided dgrad is not supported by Optimized Dgrad
                if (iterator_algorithm == IteratorAlgorithm.Optimized) and (
                    conv_kind == ConvKind.Dgrad
                ):
                    continue

                # strided support for Fprop (Analytic/Optimized), Dgrad (Analytic), and Wgrad (Analytic)
                new_operation = Conv2dOperation(
                    conv_kind,
                    iterator_algorithm,
                    tile.minimum_compute_capability,
                    tile,
                    A,
                    B,
                    C,
                    element_epilogue,
                    StrideSupport.Strided,
                    epilogue_functor,
                )

                manifest.append(new_operation)
                operations.append(new_operation)

    return operations
###################################################################################################
###################################################################################################


def GenerateConv2d_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [
        LayoutType.TensorNC4HW4,
        LayoutType.TensorNC32HW32,
        LayoutType.TensorNHWC,
        LayoutType.TensorNHWC,
        LayoutType.TensorNCHW,
    ]

    dst_types = [DataType.s8, DataType.s8, DataType.u4, DataType.s4, DataType.f32]

    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_type == DataType.s4 or dst_type == DataType.u4:
                    min_cc = 75
                    use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt
                else:
                    min_cc = 61
                    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity
                tile_descriptions = [
                    TileDescription(
                        [128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    if (
                        dst_layout == LayoutType.TensorNC32HW32
                        and tile.threadblock_shape[0] > 32
                    ):
                        continue
                    if (
                        dst_layout == LayoutType.TensorNCHW
                        or dst_layout == LayoutType.TensorNHWC
                    ) and tile.threadblock_shape[0] > 16:
                        continue
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        32,
                        32,
                        32,
                        use_special_optimization,
                    )
    return operations
def GenerateConv2d_TensorOp_8816(args):
    operations = []

    layouts = [(LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32)]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNC32HW32, LayoutType.TensorNC4HW4]

    dst_types = [DataType.s8, DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                if dst_layout == LayoutType.TensorNC32HW32:
                    tile_descriptions = [
                        TileDescription(
                            [128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        True,
                        cuda_major,
                        cuda_minor,
                    )
                else:
                    assert dst_layout == LayoutType.TensorNC4HW4
                    tile_descriptions = [
                        TileDescription(
                            [64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        tile_descriptions,
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        128,
                        128,
                        64,
                        use_special_optimization,
                        ImplicitGemmMode.GemmNT,
                        False,
                        cuda_major,
                        cuda_minor,
                    )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC4HW4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 16
                        or tile.threadblock_shape[1] == 32
                    ):
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    out_dtypes = [DataType.s4, DataType.u4, DataType.f32]

    # INT8x8x4 and INT8x8x32
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                for out_dtype in out_dtypes:
                    tile_descriptions = [
                        TileDescription(
                            [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                        ),
                        TileDescription(
                            [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                        ),
                    ]
                    for tile in tile_descriptions:
                        dst_align = (
                            4 * DataTypeSize[out_dtype]
                            if tile.threadblock_shape[1] == 16
                            or out_dtype == DataType.f32
                            else 8 * DataTypeSize[out_dtype]
                        )
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            out_dtype,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
                        if tile.threadblock_shape[1] == 16 or (
                            tile.threadblock_shape[1] == 32
                            and out_dtype != DataType.f32
                        ):
                            operations += GenerateConv2d(
                                ConvType.Convolution,
                                ConvKind.Fprop,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                out_dtype,
                                min_cc,
                                layout[2],
                                layout[2],
                                dst_align,
                                use_special_optimization,
                                ImplicitGemmMode.GemmTN,
                                True,
                                cuda_major,
                                cuda_minor,
                            )
    return operations
def GenerateConv2d_TensorOp_8832(args):
    operations = []

    layouts = [(LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64)]

    math_instructions = [
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.s4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
        MathInstruction(
            [8, 8, 32],
            DataType.s4,
            DataType.u4,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        ),
    ]

    dst_layouts = [LayoutType.TensorNC64HW64]

    use_special_optimization = SpecialOptimizeDesc.ConvFilterUnity

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_layout in dst_layouts:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Fprop,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    128,
                    128,
                    64,
                    use_special_optimization,
                    ImplicitGemmMode.GemmTN,
                    True,
                    cuda_major,
                    cuda_minor,
                )

    layouts_nhwc = [
        (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorNC16HW16, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorNC32HW32, 128),
    ]

    dst_layouts_nhwc = [LayoutType.TensorNHWC]

    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                dst_type = math_inst.element_b
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 16 if tile.threadblock_shape[1] == 16 else 32
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )

    # INT4x4x8
    for math_inst in math_instructions:
        for layout in layouts_nhwc:
            for dst_layout in dst_layouts_nhwc:
                tile_descriptions = [
                    TileDescription(
                        [128, 16, 64], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Fprop,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        DataType.s8,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
                    if (
                        tile.threadblock_shape[1] == 32
                        or tile.threadblock_shape[1] == 64
                    ):
                        dst_align = 64 if tile.threadblock_shape[1] == 32 else 128
                        operations += GenerateConv2d(
                            ConvType.Convolution,
                            ConvKind.Fprop,
                            [tile],
                            layout[0],
                            layout[1],
                            dst_layout,
                            DataType.s8,
                            min_cc,
                            layout[2],
                            layout[2],
                            dst_align,
                            use_special_optimization,
                            ImplicitGemmMode.GemmTN,
                            True,
                            cuda_major,
                            cuda_minor,
                        )
    return operations
def GenerateDeconv_Simt(args):
    operations = []

    layouts = [(LayoutType.TensorNC4HW4, LayoutType.TensorK4RSC4)]

    math_instructions = [
        MathInstruction(
            [1, 1, 4],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    dst_layouts = [LayoutType.TensorNC4HW4]

    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 61
    max_cc = 1024

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 2, [1, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                operations += GenerateConv2d(
                    ConvType.Convolution,
                    ConvKind.Dgrad,
                    tile_descriptions,
                    layout[0],
                    layout[1],
                    dst_layout,
                    dst_type,
                    min_cc,
                    32,
                    32,
                    32,
                    use_special_optimization,
                )
    return operations
def GenerateDeconv_TensorOp_8816(args):
    operations = []

    layouts = [
        (LayoutType.TensorNHWC, LayoutType.TensorCK4RS4, 32),
        (LayoutType.TensorNHWC, LayoutType.TensorCK8RS8, 64),
        (LayoutType.TensorNHWC, LayoutType.TensorCK16RS16, 128),
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 16],
            DataType.s8,
            DataType.s8,
            DataType.s32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add_saturate,
        )
    ]

    dst_layouts = [LayoutType.TensorNHWC]

    dst_types = [DataType.s8]

    use_special_optimization = SpecialOptimizeDesc.DeconvDoubleUpsampling

    min_cc = 75
    max_cc = 1024

    cuda_major = 10
    cuda_minor = 2

    for math_inst in math_instructions:
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                tile_descriptions = [
                    TileDescription(
                        [128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [64, 16, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc
                    ),
                ]
                for tile in tile_descriptions:
                    dst_align = 32 if tile.threadblock_shape[1] == 16 else 64
                    operations += GenerateConv2d(
                        ConvType.Convolution,
                        ConvKind.Dgrad,
                        [tile],
                        layout[0],
                        layout[1],
                        dst_layout,
                        dst_type,
                        min_cc,
                        layout[2],
                        layout[2],
                        dst_align,
                        use_special_optimization,
                        ImplicitGemmMode.GemmTN,
                        False,
                        cuda_major,
                        cuda_minor,
                    )
    return operations
################################################################################
# parameters
# Edge - for tiles, the edges represent the length of one side
# Ratio - the maximum ratio between 2 edges, limits the skinniness of tiles
# MaxEdge - maximum length of each edge
# Min/Max - minimum/maximum of the product of edge lengths
################################################################################
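
# Illustrative example of the Ratio constraint: with warpsPerThreadblockRatio == 2,
# an [8, 2] warp arrangement is rejected (8 / 2 > 2) while [4, 2] is kept
# (4 / 2 <= 2, 2 / 4 <= 2, and 4 * 2 <= warpsPerThreadblockMax).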
warpsPerThreadblockEdge = [1, 2, 4, 8, 16]
warpsPerThreadblockRatio = 2
warpsPerThreadblockMax = 16
# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases
warpShapeEdges = [8, 16, 32, 64, 128, 256]
warpShapeRatio = 4
warpShapeMax = 64 * 64
warpShapeMin = 8 * 8
threadblockEdgeMax = 256

# char, type bits/elem, max tile, L0 threadblock tiles
precisions = {
    "c": ["cutlass::complex<float>", 64, 64 * 128, [[64, 128], [64, 32]]],
    "d": ["double", 64, 64 * 64, [[64, 64], [32, 32]]],
    "h": ["cutlass::half_t", 16, 128 * 256, [[256, 128], [64, 128], [64, 32]]],
    "i": ["int", 32, 128 * 128, [[128, 64], [16, 32]]],
    "s": ["float", 32, 128 * 128, [[128, 256], [128, 128], [64, 64]]],
    "z": ["cutlass::complex<double>", 128, 64 * 64, [[32, 64], [16, 32]]],
}
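
# For example, precisions["s"] unpacks (as done in GenerateGemm_Simt below) to
# precisionType == "float", precisionBits == 32, threadblockMaxElements == 128 * 128,
# and threadblockTilesL0 == [[128, 256], [128, 128], [64, 64]].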
# L1 will have a single kernel for every unique shape
# L2 will have everything else


def GenerateGemm_Simt(args):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    (
        precisionType,
        precisionBits,
        threadblockMaxElements,
        threadblockTilesL0,
    ) = precisions["s"]

    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            tile_descriptions = [
                TileDescription([64, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 64, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 256, 8], 2, [2, 4, 1], math_inst, min_cc, max_cc),
                TileDescription([256, 32, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([8, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
                TileDescription([16, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            ]
            for warpsPerThreadblock in warpsPerThreadblocks:
                for warpShape in warpShapes:
                    warpThreadsM = 0
                    if warpShape[0] > warpShape[1]:
                        warpThreadsM = 8
                    else:
                        warpThreadsM = 4
                    warpThreadsN = warpNumThreads / warpThreadsM

                    # skip shapes with conflicting rectangularity
                    # they are unlikely to be fastest
                    blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                    blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                    warpG = warpShape[0] > warpShape[1]
                    warpL = warpShape[0] < warpShape[1]

                    blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                    blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                    warpG2 = warpShape[0] > warpShape[1] * 2
                    warpL2 = warpShape[0] * 2 < warpShape[1]

                    if blockG2 and warpL:
                        continue
                    if blockL2 and warpG:
                        continue
                    if warpG2 and blockL:
                        continue
                    if warpL2 and blockG:
                        continue

                    # check threadblock ratios and max
                    threadblockTile = [
                        warpShape[0] * warpsPerThreadblock[0],
                        warpShape[1] * warpsPerThreadblock[1],
                    ]
                    if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                        continue
                    if threadblockTile[0] > threadblockEdgeMax:
                        continue
                    if threadblockTile[1] > threadblockEdgeMax:
                        continue
                    totalThreads = (
                        warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                    )

                    # calculate unroll
                    # ensure that at least a full load of A and B is done every iteration
                    unrollMin = 8
                    unrollMin0 = totalThreads // threadblockTile[0]
                    unrollMin1 = totalThreads // threadblockTile[1]
                    unroll = max(unrollMin, unrollMin0, unrollMin1)
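                    # Worked numbers (illustrative): a 128x128 threadblock built from
                    # 2x2 warps of shape 64x64 has totalThreads == 128, so
                    # unrollMin0 == unrollMin1 == 1 and unroll == unrollMin == 8.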
                    threadTileM = warpShape[0] // warpThreadsM
                    threadTileN = warpShape[1] // warpThreadsN
                    if threadTileM < 2 or threadTileN < 2:
                        continue
                    if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                        continue

                    # epilogue currently requires N >= warpNumThreads
                    if threadblockTile[1] < warpNumThreads:
                        continue

                    # limit smem
                    smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                    smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                    smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                    if smemKBytes > 48:
                        continue
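                    # Worked numbers (illustrative): for a 128x128 tile with unroll == 8
                    # and 32-bit elements, smemBitsA == smemBitsB == 128 * 8 * 2 * 32
                    # == 65536 bits, so smemKBytes == 16, well under the 48 KB limit.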
                    tile = TileDescription(
                        [threadblockTile[0], threadblockTile[1], unroll],
                        2,
                        [
                            threadblockTile[0] // warpShape[0],
                            threadblockTile[1] // warpShape[1],
                            1,
                        ],
                        math_inst,
                        min_cc,
                        max_cc,
                    )

                    def filter(t: TileDescription) -> bool:
                        nonlocal tile
                        return (
                            t.threadblock_shape[0] == tile.threadblock_shape[0]
                            and t.threadblock_shape[1] == tile.threadblock_shape[1]
                            and t.threadblock_shape[2] == tile.threadblock_shape[2]
                            and t.warp_count[0] == tile.warp_count[0]
                            and t.warp_count[1] == tile.warp_count[1]
                            and t.warp_count[2] == tile.warp_count[2]
                            and t.stages == tile.stages
                        )

                    if not any(t for t in tile_descriptions if filter(t)):
                        continue

                    operations += GeneratesGemm(
                        tile, data_type, layout[0], layout[1], layout[2], min_cc
                    )
    return operations
#
def GenerateDwconv2d_Simt(args, conv_kind):
    ################################################################################
    # warps per threadblock
    ################################################################################
    warpsPerThreadblocks = []
    for warpsPerThreadblock0 in warpsPerThreadblockEdge:
        for warpsPerThreadblock1 in warpsPerThreadblockEdge:
            if (
                warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio
                and warpsPerThreadblock1 / warpsPerThreadblock0
                <= warpsPerThreadblockRatio
                and warpsPerThreadblock0 * warpsPerThreadblock1
                <= warpsPerThreadblockMax
            ):
                warpsPerThreadblocks.append(
                    [warpsPerThreadblock0, warpsPerThreadblock1]
                )

    ################################################################################
    # warp shapes
    ################################################################################
    warpNumThreads = 32
    warpShapes = []
    for warp0 in warpShapeEdges:
        for warp1 in warpShapeEdges:
            if (
                warp0 / warp1 <= warpShapeRatio
                and warp1 / warp0 <= warpShapeRatio
                and warp0 * warp1 <= warpShapeMax
                and warp0 * warp1 > warpShapeMin
            ):
                warpShapes.append([warp0, warp1])

    # sgemm
    (
        precisionType,
        precisionBits,
        threadblockMaxElements,
        threadblockTilesL0,
    ) = precisions["s"]

    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50
    max_cc = 1024

    dst_layouts = [LayoutType.TensorNCHW]

    dst_types = [DataType.f32]

    if conv_kind == ConvKind.Wgrad:
        alignment_constraints = [32]
    else:
        alignment_constraints = [128, 32]

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
            TileDescription([32, 32, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
        ]
        for warpsPerThreadblock in warpsPerThreadblocks:
            for warpShape in warpShapes:
                warpThreadsM = 0
                if warpShape[0] > warpShape[1]:
                    warpThreadsM = 8
                else:
                    warpThreadsM = 4
                warpThreadsN = warpNumThreads / warpThreadsM

                # skip shapes with conflicting rectangularity
                # they are unlikely to be fastest
                blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1]
                blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1]
                warpG = warpShape[0] > warpShape[1]
                warpL = warpShape[0] < warpShape[1]

                blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1] * 2
                blockL2 = warpsPerThreadblock[0] * 2 < warpsPerThreadblock[1]
                warpG2 = warpShape[0] > warpShape[1] * 2
                warpL2 = warpShape[0] * 2 < warpShape[1]

                if blockG2 and warpL:
                    continue
                if blockL2 and warpG:
                    continue
                if warpG2 and blockL:
                    continue
                if warpL2 and blockG:
                    continue

                # check threadblock ratios and max
                threadblockTile = [
                    warpShape[0] * warpsPerThreadblock[0],
                    warpShape[1] * warpsPerThreadblock[1],
                ]
                if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements:
                    continue
                if threadblockTile[0] > threadblockEdgeMax:
                    continue
                if threadblockTile[1] > threadblockEdgeMax:
                    continue
                totalThreads = (
                    warpNumThreads * warpsPerThreadblock[0] * warpsPerThreadblock[1]
                )

                # calculate unroll
                # ensure that at least a full load of A and B is done every iteration
                unrollMin = 8
                unrollMin0 = totalThreads // threadblockTile[0]
                unrollMin1 = totalThreads // threadblockTile[1]
                unroll = max(unrollMin, unrollMin0, unrollMin1)

                threadTileM = warpShape[0] // warpThreadsM
                threadTileN = warpShape[1] // warpThreadsN
                if threadTileM < 2 or threadTileN < 2:
                    continue
                if threadTileM * threadTileN * precisionBits > 8 * 8 * 32:
                    continue

                # epilogue currently requires N >= warpNumThreads
                if threadblockTile[1] < warpNumThreads:
                    continue

                # limit smem
                smemBitsA = threadblockTile[0] * unroll * 2 * precisionBits
                smemBitsB = threadblockTile[1] * unroll * 2 * precisionBits
                smemKBytes = (smemBitsA + smemBitsB) / 8 / 1024
                if smemKBytes > 48:
                    continue

                tile = TileDescription(
                    [threadblockTile[0], threadblockTile[1], unroll],
                    2,
                    [
                        threadblockTile[0] // warpShape[0],
                        threadblockTile[1] // warpShape[1],
                        1,
                    ],
                    math_inst,
                    min_cc,
                    max_cc,
                )

                def filter(t: TileDescription) -> bool:
                    nonlocal tile
                    return (
                        t.threadblock_shape[0] == tile.threadblock_shape[0]
                        and t.threadblock_shape[1] == tile.threadblock_shape[1]
                        and t.threadblock_shape[2] == tile.threadblock_shape[2]
                        and t.warp_count[0] == tile.warp_count[0]
                        and t.warp_count[1] == tile.warp_count[1]
                        and t.warp_count[2] == tile.warp_count[2]
                        and t.stages == tile.stages
                    )

                if not any(t for t in tile_descriptions if filter(t)):
                    continue

                for layout in layouts:
                    for dst_type, dst_layout in zip(dst_types, dst_layouts):
                        for alignment_src in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                conv_kind,
                                [tile],
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                32,
                                32,
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                ImplicitGemmMode.GemmNT
                                if conv_kind == ConvKind.Wgrad
                                else ImplicitGemmMode.GemmTN,
                            )
    return operations
#
def GenerateDwconv2d_TensorOp_884(args, conv_kind):
    layouts = [(LayoutType.TensorNCHW, LayoutType.TensorNCHW)]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    dst_layouts = [LayoutType.TensorNCHW]

    if conv_kind == ConvKind.Wgrad:
        dst_types = [DataType.f32]
    else:
        dst_types = [DataType.f16]

    alignment_constraints = [128, 32, 16]

    cuda_major = 10
    cuda_minor = 1

    operations = []
    for math_inst in math_instructions:
        tile_descriptions = [
            TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 128, 32], 2, [4, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
            TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc),
            TileDescription([64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
        ]
        for layout in layouts:
            for dst_type, dst_layout in zip(dst_types, dst_layouts):
                for alignment_src in alignment_constraints:
                    if conv_kind == ConvKind.Wgrad:
                        # skip io16xc16
                        if math_inst.element_accumulator == DataType.f16:
                            continue
                        for alignment_diff in alignment_constraints:
                            operations += GenerateConv2d(
                                ConvType.DepthwiseConvolution,
                                conv_kind,
                                tile_descriptions,
                                layout[0],
                                layout[1],
                                dst_layout,
                                dst_type,
                                min_cc,
                                alignment_src,
                                alignment_diff,
                                32,  # always f32 output
                                SpecialOptimizeDesc.NoneSpecialOpt,
                                ImplicitGemmMode.GemmNT,
                                False,
                                cuda_major,
                                cuda_minor,
                            )
                    else:
                        operations += GenerateConv2d(
                            ConvType.DepthwiseConvolution,
                            conv_kind,
                            tile_descriptions,
                            layout[0],
                            layout[1],
                            dst_layout,
                            dst_type,
                            min_cc,
                            alignment_src,
                            16,
                            16,
                            SpecialOptimizeDesc.NoneSpecialOpt,
                            ImplicitGemmMode.GemmTN,
                            False,
                            cuda_major,
                            cuda_minor,
                        )
    return operations
#
def GenerateGemv_Simt(args):
    threadBlockShape_N = [128, 64, 32]
    ldgBits_A = [128, 64, 32]
    ldgBits_B = [128, 64, 32]

    layouts = [(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor)]

    math_instructions = [
        MathInstruction(
            [1, 1, 1],
            DataType.f32,
            DataType.f32,
            DataType.f32,
            OpcodeClass.Simt,
            MathOperation.multiply_add,
        )
    ]

    min_cc = 50

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            data_type = [
                math_inst.element_a,
                math_inst.element_b,
                math_inst.element_accumulator,
                math_inst.element_accumulator,
            ]
            for threadblock_shape_n in threadBlockShape_N:
                for align_a in ldgBits_A:
                    for align_b in ldgBits_B:
                        ldg_elements_a = align_a // DataTypeSize[math_inst.element_a]
                        ldg_elements_b = align_b // DataTypeSize[math_inst.element_b]
                        threadblock_shape_k = (256 * ldg_elements_a) // (
                            threadblock_shape_n // ldg_elements_b
                        )
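                        # Worked example (illustrative): with f32 operands and
                        # align_a == align_b == 128 bits, ldg_elements_a ==
                        # ldg_elements_b == 4; for threadblock_shape_n == 128 this
                        # gives threadblock_shape_k == (256 * 4) // (128 // 4) == 32.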
                        threadblock_shape = [
                            1,
                            threadblock_shape_n,
                            threadblock_shape_k,
                        ]
                        thread_shape = [1, ldg_elements_b, ldg_elements_a]
                        operations.append(
                            GeneratesGemv(
                                math_inst,
                                threadblock_shape,
                                thread_shape,
                                data_type,
                                layout[0],
                                layout[1],
                                layout[2],
                                min_cc,
                                align_a,
                                align_b,
                            )
                        )
    return operations
#
def GeneratesGemm_TensorOp_1688(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [16, 8, 8],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 75
    max_cc = 1024

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 2

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## some configurations are commented out to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations
#
def GeneratesGemm_TensorOp_884(args):
    layouts = [
        (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # nn
        (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # nt
        (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor),  # tn
        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),  # tt
    ]

    math_instructions = [
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f32,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
        MathInstruction(
            [8, 8, 4],
            DataType.f16,
            DataType.f16,
            DataType.f16,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add,
        ),
    ]

    min_cc = 70
    max_cc = 75

    alignment_constraints = [
        8,
        4,
        2,
        # 1
    ]

    cuda_major = 10
    cuda_minor = 1

    operations = []
    for math_inst in math_instructions:
        for layout in layouts:
            for align in alignment_constraints:
                tile_descriptions = [
                    TileDescription(
                        [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
                    ),
                    TileDescription(
                        [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
                    ),
                    ## some configurations are commented out to reduce compilation time and binary size
                    # TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                    # TileDescription([ 64,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
                ]
                data_type = [
                    math_inst.element_a,
                    math_inst.element_b,
                    math_inst.element_a,
                    math_inst.element_accumulator,
                ]
                for tile in tile_descriptions:
                    operations += GeneratesGemm(
                        tile,
                        data_type,
                        layout[0],
                        layout[1],
                        layout[2],
                        min_cc,
                        align * 16,
                        align * 16,
                        align * 16,
                        cuda_major,
                        cuda_minor,
                    )
    return operations
#
def GenerateConv2dOperations(args):
    if args.type == "simt":
        return GenerateConv2d_Simt(args)
    elif args.type == "tensorop8816":
        return GenerateConv2d_TensorOp_8816(args)
    else:
        assert args.type == "tensorop8832", (
            "operation conv2d only supports "
            "simt, tensorop8816 and tensorop8832. (got:{})".format(args.type)
        )
        return GenerateConv2d_TensorOp_8832(args)


def GenerateDeconvOperations(args):
    if args.type == "simt":
        return GenerateDeconv_Simt(args)
    else:
        assert args.type == "tensorop8816", (
            "operation deconv only supports "
            "simt and tensorop8816. (got:{})".format(args.type)
        )
        return GenerateDeconv_TensorOp_8816(args)


def GenerateDwconv2dFpropOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Fprop)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d fprop only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Fprop)


def GenerateDwconv2dDgradOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Dgrad)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d dgrad only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Dgrad)


def GenerateDwconv2dWgradOperations(args):
    if args.type == "simt":
        return GenerateDwconv2d_Simt(args, ConvKind.Wgrad)
    else:
        assert args.type == "tensorop884", (
            "operation dwconv2d wgrad only supports "
            "simt and tensorop884. (got:{})".format(args.type)
        )
        return GenerateDwconv2d_TensorOp_884(args, ConvKind.Wgrad)


def GenerateGemmOperations(args):
    if args.type == "tensorop884":
        return GeneratesGemm_TensorOp_884(args)
    elif args.type == "tensorop1688":
        return GeneratesGemm_TensorOp_1688(args)
    else:
        assert (
            args.type == "simt"
        ), "operation gemm only supports simt. (got:{})".format(args.type)
        return GenerateGemm_Simt(args)


def GenerateGemvOperations(args):
    assert args.type == "simt", "operation gemv only supports simt. (got:{})".format(
        args.type
    )
    return GenerateGemv_Simt(args)
################################################################################
# parameters
# split_number - the concatenated output will be divided into split_number parts
# file_path - the path of the files that need to be concatenated
# operations - args.operations
# type - args.type
# head - the head in the file
# required_cuda_ver_major - required cuda major version
# required_cuda_ver_minor - required cuda minor version
# epilogue - the epilogue in the file
# wrapper_path - wrapper path
################################################################################
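
# Illustrative example: ConcatFile(2, "./kernels", "gemm", "simt", head, "11", "0",
# epilogue) merges the generated gemm/simt .cu files under ./kernels into
# gemm_simt_0.cu and gemm_simt_1.cu, since the output name is
# "{operations}_{type}_{flag_2}.cu".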
def ConcatFile(
    split_number: int,
    file_path: str,
    operations: str,
    type: str,
    head: str,
    required_cuda_ver_major: str,
    required_cuda_ver_minor: str,
    epilogue: str,
    wrapper_path=None,
):
    merge_file_dir = file_path
    filenames = os.listdir(merge_file_dir)

    # derive the substrings used to filter the generated kernel files
    if "tensorop" in type:
        sub_string_1 = "tensorop"
        sub_string_2 = type[8:]  # e.g. "8816" from "tensorop8816"
    else:
        sub_string_1 = sub_string_2 = "simt"
    if "dwconv2d_" in operations:
        # e.g. "dwconv2d_fprop" -> "dwfprop"
        filtered_operations = operations[:2] + operations[9:]
    elif ("conv2d" in operations) or ("deconv" in operations):
        filtered_operations = "cutlass"
    else:
        filtered_operations = operations

    def matches(filename):
        # a generated kernel file belonging to this operation/type pair
        return (
            (filtered_operations in filename)
            and (sub_string_1 in filename)
            and (sub_string_2 in filename)
            and ("all_" not in filename)
        )

    def write_head(file):
        # write the header template at the top of the concatenated file
        values = {
            "required_cuda_ver_major": str(required_cuda_ver_major),
            "required_cuda_ver_minor": str(required_cuda_ver_minor),
        }
        if wrapper_path is not None:
            values["wrapper_path"] = wrapper_path
        file.write(SubstituteTemplate(head, values))

    def append_kernel(file, filepath):
        # append one generated kernel file, delete it, then write the epilogue
        with open(filepath) as src:
            for line in src:
                file.write(line)
        os.remove(filepath)
        file.write("\n")
        file.write(epilogue)

    # count the files to be concatenated
    file_list = {}
    file_list[operations + type] = 0
    for filename in filenames:
        if matches(filename):
            file_list[operations + type] += 1

    flag_1 = 0  # index of the current source file
    flag_2 = 0  # index of the current output split
    for filename in filenames:
        # concatenate files emitted with the long (Linux) naming scheme
        if matches(filename):
            flag_1 += 1
            filepath = merge_file_dir + "/" + filename
            if (flag_1 >= flag_2 * (file_list[operations + type] / split_number)) and (
                flag_1 <= (flag_2 + 1) * (file_list[operations + type] / split_number)
            ):
                file = open(
                    file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a"
                )
                write_head(file)
                append_kernel(file, filepath)
            else:
                # note: relies on `file` opened in a previous iteration
                write_head(file)
                append_kernel(file, filepath)
                file.close()
                flag_2 += 1
        # concatenate files emitted with the short (Windows) naming scheme
        elif filename[0].isdigit() and ("all_" not in filename):
            flag_1 += 1
            filepath = merge_file_dir + "/" + filename
            if (flag_1 >= flag_2 * (len(filenames) / split_number)) and (
                flag_1 <= (flag_2 + 1) * (len(filenames) / split_number)
            ):
                file = open(
                    file_path + "/{}_{}_{}.cu".format(operations, type, flag_2), "a"
                )
                write_head(file)
                append_kernel(file, filepath)
            else:
                # note: relies on `file` opened in a previous iteration
                write_head(file)
                append_kernel(file, filepath)
                file.close()
                flag_2 += 1
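
# A minimal usage sketch of ConcatFile (hypothetical values): merge the
# generated conv2d tensorop8816 kernels under "./generated" into 4 .cu files
# named conv2d_tensorop8816_0.cu .. conv2d_tensorop8816_3.cu.
#
#     ConcatFile(
#         4,
#         "./generated",
#         "conv2d",
#         "tensorop8816",
#         head,      # header template taken from an emitter, as in __main__ below
#         "10",
#         "1",
#         epilogue,  # epilogue template taken from an emitter, as in __main__ below
#     )
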
###################################################################################################
###################################################################################################
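
# For reference, SubstituteTemplate (available earlier in this script) expands
# placeholders in the head/epilogue templates. A sketch, assuming the
# CUTLASS-style ${...} placeholder syntax:
#
#     SubstituteTemplate(
#         "#if __CUDACC_VER_MAJOR__ >= ${required_cuda_ver_major}",
#         {"required_cuda_ver_major": "10"},
#     )
#     # -> "#if __CUDACC_VER_MAJOR__ >= 10"
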
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generates device kernel registration code for CUTLASS Kernels"
    )
    parser.add_argument(
        "--operations",
        type=str,
        choices=[
            "gemm",
            "gemv",
            "conv2d",
            "deconv",
            "dwconv2d_fprop",
            "dwconv2d_dgrad",
            "dwconv2d_wgrad",
        ],
        required=True,
        help="Specifies the operation to generate (gemm, gemv, conv2d, deconv, dwconv2d_fprop, dwconv2d_dgrad, dwconv2d_wgrad)",
    )
    parser.add_argument(
        "output", type=str, help="output directory for CUTLASS kernel files"
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["simt", "tensorop8816", "tensorop8832", "tensorop884", "tensorop1688"],
        default="simt",
        help="kernel type of CUTLASS kernel generator",
    )
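
    # Example invocation (hypothetical output directory):
    #
    #     python generator.py --operations conv2d --type tensorop8816 ./generated
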
    gemv_wrapper_path = (
        "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
    )
    short_path = (
        platform.system() == "Windows" or platform.system().find("NT") >= 0
    ) and ("true" != os.getenv("CUTLASS_WITH_LONG_PATH", default="False").lower())
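    # `short_path` selects the shortened kernel file names on Windows
    # (presumably to stay within path-length limits); setting the environment
    # variable CUTLASS_WITH_LONG_PATH=true keeps the long names, e.g.:
    #
    #     CUTLASS_WITH_LONG_PATH=true python generator.py --operations gemm ./generated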
    args = parser.parse_args()

    if args.operations == "gemm":
        operations = GenerateGemmOperations(args)
    elif args.operations == "gemv":
        operations = GenerateGemvOperations(args)
    elif args.operations == "conv2d":
        operations = GenerateConv2dOperations(args)
    elif args.operations == "deconv":
        operations = GenerateDeconvOperations(args)
    elif args.operations == "dwconv2d_fprop":
        operations = GenerateDwconv2dFpropOperations(args)
    elif args.operations == "dwconv2d_dgrad":
        operations = GenerateDwconv2dDgradOperations(args)
    else:
        assert args.operations == "dwconv2d_wgrad", "invalid operation"
        operations = GenerateDwconv2dWgradOperations(args)

    if args.operations in (
        "conv2d",
        "deconv",
        "dwconv2d_fprop",
        "dwconv2d_dgrad",
        "dwconv2d_wgrad",
    ):
        for operation in operations:
            with EmitConvSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
        emitter = EmitConvSingleKernelWrapper(args.output, operations[0], short_path)
        head = emitter.header_template
        epilogue = emitter.epilogue_template
        required_cuda_ver_major = operations[0].required_cuda_ver_major
        required_cuda_ver_minor = operations[0].required_cuda_ver_minor
        # tensor-op kernel sets are split across more output files
        split_number = 4 if "tensorop" in args.type else 2
        ConcatFile(
            split_number,
            args.output,
            args.operations,
            args.type,
            head,
            required_cuda_ver_major,
            required_cuda_ver_minor,
            epilogue,
        )
    elif args.operations == "gemm":
        for operation in operations:
            with EmitGemmSingleKernelWrapper(
                args.output, operation, short_path
            ) as emitter:
                emitter.emit()
        emitter = EmitGemmSingleKernelWrapper(args.output, operations[0], short_path)
        head = emitter.header_template
        epilogue = emitter.epilogue_template
        required_cuda_ver_major = operations[0].required_cuda_ver_major
        required_cuda_ver_minor = operations[0].required_cuda_ver_minor
        split_number = 30 if args.type == "tensorop884" else 2
        ConcatFile(
            split_number,
            args.output,
            args.operations,
            args.type,
            head,
            required_cuda_ver_major,
            required_cuda_ver_minor,
            epilogue,
        )
    elif args.operations == "gemv":
        for operation in operations:
            with EmitGemvSingleKernelWrapper(
                args.output, operation, gemv_wrapper_path, short_path
            ) as emitter:
                emitter.emit()
        emitter = EmitGemvSingleKernelWrapper(
            args.output, operations[0], gemv_wrapper_path, short_path
        )
        head = emitter.header_template
        epilogue = emitter.epilogue_template
        required_cuda_ver_major = operations[0].required_cuda_ver_major
        required_cuda_ver_minor = operations[0].required_cuda_ver_minor
        ConcatFile(
            2,
            args.output,
            args.operations,
            args.type,
            head,
            required_cuda_ver_major,
            required_cuda_ver_minor,
            epilogue,
            wrapper_path=gemv_wrapper_path,
        )

    if args.operations != "gemv":
        GenerateManifest(args, operations, args.output)

#
###################################################################################################