You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_operation.py 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447
  1. #
  2. # \file generator.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import os.path
  8. import shutil
  9. import functools
  10. import operator
  11. from library import *
  12. ###################################################################################################
  13. #
  14. # Data structure modeling a GEMM operation
  15. #
  16. ###################################################################################################
  17. #
  18. class GemmOperation:
  19. #
  20. def __init__(
  21. self,
  22. gemm_kind,
  23. arch,
  24. tile_description,
  25. A,
  26. B,
  27. C,
  28. element_epilogue,
  29. epilogue_functor=EpilogueFunctor.LinearCombination,
  30. swizzling_functor=SwizzlingFunctor.Identity8,
  31. required_cuda_ver_major=9,
  32. required_cuda_ver_minor=2,
  33. ):
  34. self.operation_kind = OperationKind.Gemm
  35. self.arch = arch
  36. self.tile_description = tile_description
  37. self.gemm_kind = gemm_kind
  38. self.A = A
  39. self.B = B
  40. self.C = C
  41. self.element_epilogue = element_epilogue
  42. self.epilogue_functor = epilogue_functor
  43. self.swizzling_functor = swizzling_functor
  44. self.required_cuda_ver_major = required_cuda_ver_major
  45. self.required_cuda_ver_minor = required_cuda_ver_minor
  46. #
  47. def is_complex(self):
  48. complex_operators = [
  49. MathOperation.multiply_add_complex,
  50. MathOperation.multiply_add_complex_gaussian,
  51. ]
  52. return (
  53. self.tile_description.math_instruction.math_operation in complex_operators
  54. )
  55. #
  56. def is_split_k_parallel(self):
  57. return self.gemm_kind == GemmKind.SplitKParallel
  58. #
  59. def is_planar_complex(self):
  60. return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  61. #
  62. def accumulator_type(self):
  63. accum = self.tile_description.math_instruction.element_accumulator
  64. if self.is_complex():
  65. return get_complex_from_real(accum)
  66. return accum
  67. #
  68. def short_math_name(self):
  69. if (
  70. self.tile_description.math_instruction.math_operation
  71. == MathOperation.multiply_add_complex_gaussian
  72. ):
  73. return "g%s" % ShortDataTypeNames[self.accumulator_type()]
  74. return ShortDataTypeNames[self.accumulator_type()]
  75. #
  76. def core_name(self):
  77. """ The basic operation kind is prefixed with a letter indicating the accumulation type. """
  78. inst_shape = ""
  79. inst_operation = ""
  80. intermediate_type = ""
  81. math_operations_map = {MathOperation.xor_popc: "xor"}
  82. if (
  83. self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp
  84. or self.tile_description.math_instruction.opcode_class
  85. == OpcodeClass.WmmaTensorOp
  86. ):
  87. math_op = self.tile_description.math_instruction.math_operation
  88. math_op_string = (
  89. math_operations_map[math_op]
  90. if math_op in math_operations_map.keys()
  91. else ""
  92. )
  93. inst_shape = "%d%d%d" % tuple(
  94. self.tile_description.math_instruction.instruction_shape
  95. )
  96. inst_shape += math_op_string
  97. if (
  98. self.tile_description.math_instruction.element_a != self.A.element
  99. and self.tile_description.math_instruction.element_a
  100. != self.tile_description.math_instruction.element_accumulator
  101. ):
  102. intermediate_type = DataTypeNames[
  103. self.tile_description.math_instruction.element_a
  104. ]
  105. return "%s%s%s%s" % (
  106. self.short_math_name(),
  107. inst_shape,
  108. intermediate_type,
  109. GemmKindNames[self.gemm_kind],
  110. )
  111. #
  112. def extended_name(self):
  113. """ Append data types if they differ from compute type. """
  114. if self.is_complex():
  115. extended_name = "${core_name}"
  116. else:
  117. if (
  118. self.C.element
  119. != self.tile_description.math_instruction.element_accumulator
  120. and self.A.element
  121. != self.tile_description.math_instruction.element_accumulator
  122. ):
  123. extended_name = "${element_c}_${core_name}_${element_a}"
  124. elif (
  125. self.C.element
  126. == self.tile_description.math_instruction.element_accumulator
  127. and self.A.element
  128. != self.tile_description.math_instruction.element_accumulator
  129. ):
  130. extended_name = "${core_name}_${element_a}"
  131. else:
  132. extended_name = "${core_name}"
  133. extended_name = SubstituteTemplate(
  134. extended_name,
  135. {
  136. "element_a": DataTypeNames[self.A.element],
  137. "element_c": DataTypeNames[self.C.element],
  138. "core_name": self.core_name(),
  139. },
  140. )
  141. return extended_name
  142. #
  143. def layout_name(self):
  144. if self.is_complex() or self.is_planar_complex():
  145. return "%s%s" % (
  146. ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
  147. ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
  148. )
  149. return "%s%s" % (
  150. ShortLayoutTypeNames[self.A.layout],
  151. ShortLayoutTypeNames[self.B.layout],
  152. )
  153. #
  154. def procedural_name(self):
  155. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  156. threadblock = self.tile_description.procedural_name()
  157. opcode_class_name = OpcodeClassNames[
  158. self.tile_description.math_instruction.opcode_class
  159. ]
  160. alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
  161. return SubstituteTemplate(
  162. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
  163. {
  164. "opcode_class": opcode_class_name,
  165. "extended_name": self.extended_name(),
  166. "threadblock": threadblock,
  167. "layout": self.layout_name(),
  168. "alignment": "%d" % self.A.alignment,
  169. },
  170. )
  171. #
  172. def configuration_name(self):
  173. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  174. return self.procedural_name()
  175. ###################################################################################################
  176. #
  177. # Data structure modeling a GEMV Batched Strided operation
  178. #
  179. ###################################################################################################
  180. #
  181. class GemvBatchedStridedOperation:
  182. #
  183. def __init__(
  184. self,
  185. gemm_kind,
  186. arch,
  187. math_inst,
  188. threadblock_shape,
  189. thread_shape,
  190. A,
  191. B,
  192. C,
  193. required_cuda_ver_major=9,
  194. required_cuda_ver_minor=2,
  195. ):
  196. self.operation_kind = OperationKind.Gemm
  197. self.arch = arch
  198. self.gemm_kind = gemm_kind
  199. self.math_instruction = math_inst
  200. self.threadblock_shape = threadblock_shape
  201. self.thread_shape = thread_shape
  202. self.A = A
  203. self.B = B
  204. self.C = C
  205. self.required_cuda_ver_major = required_cuda_ver_major
  206. self.required_cuda_ver_minor = required_cuda_ver_minor
  207. #
  208. def accumulator_type(self):
  209. accum = self.math_instruction.element_accumulator
  210. return accum
  211. #
  212. def short_math_name(self):
  213. return ShortDataTypeNames[self.accumulator_type()]
  214. #
  215. def core_name(self):
  216. """ The basic operation kind is prefixed with a letter indicating the accumulation type. """
  217. return "%s%s" % (self.short_math_name(), GemmKindNames[self.gemm_kind])
  218. #
  219. def extended_name(self):
  220. """ Append data types if they differ from compute type. """
  221. if (
  222. self.C.element != self.math_instruction.element_accumulator
  223. and self.A.element != self.math_instruction.element_accumulator
  224. ):
  225. extended_name = "${element_c}_${core_name}_${element_a}"
  226. elif (
  227. self.C.element == self.math_instruction.element_accumulator
  228. and self.A.element != self.math_instruction.element_accumulator
  229. ):
  230. extended_name = "${core_name}_${element_a}"
  231. else:
  232. extended_name = "${core_name}"
  233. extended_name = SubstituteTemplate(
  234. extended_name,
  235. {
  236. "element_a": DataTypeNames[self.A.element],
  237. "element_c": DataTypeNames[self.C.element],
  238. "core_name": self.core_name(),
  239. },
  240. )
  241. return extended_name
  242. #
  243. def layout_name(self):
  244. return "%s%s" % (
  245. ShortLayoutTypeNames[self.A.layout],
  246. ShortLayoutTypeNames[self.B.layout],
  247. )
  248. #
  249. def procedural_name(self):
  250. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  251. threadblock = "%dx%d_%d" % (
  252. self.threadblock_shape[0],
  253. self.threadblock_shape[1],
  254. self.threadblock_shape[2],
  255. )
  256. opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class]
  257. alignment_a = self.A.alignment
  258. alignment_b = self.B.alignment
  259. return SubstituteTemplate(
  260. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}",
  261. {
  262. "opcode_class": opcode_class_name,
  263. "extended_name": self.extended_name(),
  264. "threadblock": threadblock,
  265. "layout": self.layout_name(),
  266. "alignment_a": "%d" % alignment_a,
  267. "alignment_b": "%d" % alignment_b,
  268. },
  269. )
  270. #
  271. def configuration_name(self):
  272. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  273. return self.procedural_name()
  274. #
  275. def GeneratesGemm(
  276. tile,
  277. data_type,
  278. layout_a,
  279. layout_b,
  280. layout_c,
  281. min_cc,
  282. align_a=32,
  283. align_b=32,
  284. align_c=32,
  285. required_cuda_ver_major=9,
  286. required_cuda_ver_minor=2,
  287. ):
  288. operations = []
  289. swizzling_functor = SwizzlingFunctor.Identity1
  290. element_a, element_b, element_c, element_epilogue = data_type
  291. if tile.math_instruction.element_accumulator == DataType.s32:
  292. epilogues = [EpilogueFunctor.LinearCombinationClamp]
  293. else:
  294. assert (
  295. tile.math_instruction.element_accumulator == DataType.f32
  296. or tile.math_instruction.element_accumulator == DataType.f16
  297. )
  298. epilogues = [EpilogueFunctor.LinearCombination]
  299. for epilogue in epilogues:
  300. A = TensorDescription(
  301. element_a, layout_a, int(align_a // DataTypeSize[element_a])
  302. )
  303. B = TensorDescription(
  304. element_b, layout_b, int(align_b // DataTypeSize[element_b])
  305. )
  306. C = TensorDescription(
  307. element_c, layout_c, int(align_c // DataTypeSize[element_c])
  308. )
  309. operations.append(
  310. GemmOperation(
  311. GemmKind.Gemm,
  312. min_cc,
  313. tile,
  314. A,
  315. B,
  316. C,
  317. element_epilogue,
  318. epilogue,
  319. swizzling_functor,
  320. required_cuda_ver_major,
  321. required_cuda_ver_minor,
  322. )
  323. )
  324. operations.append(
  325. GemmOperation(
  326. GemmKind.SplitKParallel,
  327. min_cc,
  328. tile,
  329. A,
  330. B,
  331. C,
  332. element_epilogue,
  333. epilogue,
  334. swizzling_functor,
  335. required_cuda_ver_major,
  336. required_cuda_ver_minor,
  337. )
  338. )
  339. return operations
  340. def GeneratesGemv(
  341. math_inst,
  342. threadblock_shape,
  343. thread_shape,
  344. data_type,
  345. layout_a,
  346. layout_b,
  347. layout_c,
  348. min_cc,
  349. align_a=32,
  350. align_b=32,
  351. align_c=32,
  352. required_cuda_ver_major=9,
  353. required_cuda_ver_minor=2,
  354. ):
  355. element_a, element_b, element_c, element_epilogue = data_type
  356. A = TensorDescription(element_a, layout_a, int(align_a // DataTypeSize[element_a]))
  357. B = TensorDescription(element_b, layout_b, int(align_b // DataTypeSize[element_b]))
  358. C = TensorDescription(element_c, layout_c, int(align_c // DataTypeSize[element_c]))
  359. return GemvBatchedStridedOperation(
  360. GemmKind.GemvBatchedStrided,
  361. min_cc,
  362. math_inst,
  363. threadblock_shape,
  364. thread_shape,
  365. A,
  366. B,
  367. C,
  368. required_cuda_ver_major,
  369. required_cuda_ver_minor,
  370. )
  371. ###################################################################################################
  372. #
  373. # Emits single instances of a CUTLASS device-wide operator
  374. #
  375. ###################################################################################################
  376. #
  377. class EmitGemmInstance:
  378. """ Responsible for emitting a CUTLASS template definition"""
  379. def __init__(self):
  380. self.gemm_template = """
  381. // Gemm operator ${operation_name}
  382. using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  383. ${element_a}, ${layout_a},
  384. ${element_b}, ${layout_b},
  385. ${element_c}, ${layout_c},
  386. ${element_accumulator},
  387. ${opcode_class},
  388. ${arch},
  389. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  390. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  391. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  392. ${epilogue_functor}<
  393. ${element_c},
  394. ${epilogue_vector_length},
  395. ${element_accumulator},
  396. ${element_epilogue}
  397. >,
  398. ${swizzling_functor},
  399. ${stages},
  400. ${align_a},
  401. ${align_b},
  402. false,
  403. ${math_operation}
  404. ${residual}
  405. >;
  406. """
  407. self.gemm_complex_template = """
  408. // Gemm operator ${operation_name}
  409. using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  410. ${element_a}, ${layout_a},
  411. ${element_b}, ${layout_b},
  412. ${element_c}, ${layout_c},
  413. ${element_accumulator},
  414. ${opcode_class},
  415. ${arch},
  416. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  417. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  418. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  419. ${epilogue_functor}<
  420. ${element_c},
  421. ${epilogue_vector_length},
  422. ${element_accumulator},
  423. ${element_epilogue}
  424. >,
  425. ${swizzling_functor},
  426. ${stages},
  427. ${transform_a},
  428. ${transform_b},
  429. ${math_operation}
  430. ${residual}
  431. >;
  432. """
  433. def emit(self, operation):
  434. warp_shape = [
  435. operation.tile_description.threadblock_shape[idx]
  436. // operation.tile_description.warp_count[idx]
  437. for idx in range(3)
  438. ]
  439. epilogue_vector_length = int(
  440. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  441. / DataTypeSize[operation.C.element]
  442. )
  443. residual = ""
  444. values = {
  445. "operation_name": operation.procedural_name(),
  446. "element_a": DataTypeTag[operation.A.element],
  447. "layout_a": LayoutTag[operation.A.layout],
  448. "element_b": DataTypeTag[operation.B.element],
  449. "layout_b": LayoutTag[operation.B.layout],
  450. "element_c": DataTypeTag[operation.C.element],
  451. "layout_c": LayoutTag[operation.C.layout],
  452. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  453. "opcode_class": OpcodeClassTag[
  454. operation.tile_description.math_instruction.opcode_class
  455. ],
  456. "arch": "cutlass::arch::Sm%d" % operation.arch,
  457. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  458. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  459. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  460. "warp_shape_m": str(warp_shape[0]),
  461. "warp_shape_n": str(warp_shape[1]),
  462. "warp_shape_k": str(warp_shape[2]),
  463. "instruction_shape_m": str(
  464. operation.tile_description.math_instruction.instruction_shape[0]
  465. ),
  466. "instruction_shape_n": str(
  467. operation.tile_description.math_instruction.instruction_shape[1]
  468. ),
  469. "instruction_shape_k": str(
  470. operation.tile_description.math_instruction.instruction_shape[2]
  471. ),
  472. "epilogue_vector_length": str(epilogue_vector_length),
  473. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  474. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  475. "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
  476. "stages": str(operation.tile_description.stages),
  477. "align_a": str(operation.A.alignment),
  478. "align_b": str(operation.B.alignment),
  479. "transform_a": ComplexTransformTag[operation.A.complex_transform],
  480. "transform_b": ComplexTransformTag[operation.B.complex_transform],
  481. "math_operation": MathOperationTag[
  482. operation.tile_description.math_instruction.math_operation
  483. ],
  484. "residual": residual,
  485. }
  486. template = (
  487. self.gemm_complex_template if operation.is_complex() else self.gemm_template
  488. )
  489. return SubstituteTemplate(template, values)
  490. #
  491. class EmitGemvBatchedStridedInstance:
  492. """ Responsible for emitting a CUTLASS template definition"""
  493. def __init__(self):
  494. self.template = """
  495. // Gemm operator ${operation_name}
  496. using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv<
  497. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  498. cutlass::gemm::GemmShape<${thread_shape_m}, ${thread_shape_n}, ${thread_shape_k}>,
  499. ${element_a}, ${layout_a},
  500. ${element_b}, ${layout_b},
  501. ${element_c}, ${layout_c}
  502. >;
  503. """
  504. def emit(self, operation):
  505. values = {
  506. "operation_name": operation.procedural_name(),
  507. "element_a": DataTypeTag[operation.A.element],
  508. "layout_a": LayoutTag[operation.A.layout],
  509. "element_b": DataTypeTag[operation.B.element],
  510. "layout_b": LayoutTag[operation.B.layout],
  511. "element_c": DataTypeTag[operation.C.element],
  512. "layout_c": LayoutTag[operation.C.layout],
  513. "threadblock_shape_m": str(operation.threadblock_shape[0]),
  514. "threadblock_shape_n": str(operation.threadblock_shape[1]),
  515. "threadblock_shape_k": str(operation.threadblock_shape[2]),
  516. "thread_shape_m": str(operation.thread_shape[0]),
  517. "thread_shape_n": str(operation.thread_shape[1]),
  518. "thread_shape_k": str(operation.thread_shape[2]),
  519. }
  520. return SubstituteTemplate(self.template, values)
  521. ###################################################################################################
  522. class EmitSparseGemmInstance:
  523. """ Responsible for emitting a CUTLASS template definition"""
  524. def __init__(self):
  525. self.gemm_template = """
  526. // Gemm operator ${operation_name}
  527. using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  528. ${element_a}, ${layout_a},
  529. ${element_b}, ${layout_b},
  530. ${element_c}, ${layout_c},
  531. ${element_accumulator},
  532. ${opcode_class},
  533. ${arch},
  534. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  535. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  536. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  537. ${epilogue_functor}<
  538. ${element_c},
  539. ${epilogue_vector_length},
  540. ${element_accumulator},
  541. ${element_epilogue}
  542. >,
  543. ${swizzling_functor},
  544. ${stages},
  545. ${align_a},
  546. ${align_b},
  547. false,
  548. ${math_operation}
  549. ${residual}
  550. >;
  551. """
  552. def emit(self, operation):
  553. warp_shape = [
  554. operation.tile_description.threadblock_shape[idx]
  555. // operation.tile_description.warp_count[idx]
  556. for idx in range(3)
  557. ]
  558. epilogue_vector_length = int(
  559. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  560. / DataTypeSize[operation.C.element]
  561. )
  562. residual = ""
  563. values = {
  564. "operation_name": operation.procedural_name(),
  565. "element_a": DataTypeTag[operation.A.element],
  566. "layout_a": LayoutTag[operation.A.layout],
  567. "element_b": DataTypeTag[operation.B.element],
  568. "layout_b": LayoutTag[operation.B.layout],
  569. "element_c": DataTypeTag[operation.C.element],
  570. "layout_c": LayoutTag[operation.C.layout],
  571. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  572. "opcode_class": OpcodeClassTag[
  573. operation.tile_description.math_instruction.opcode_class
  574. ],
  575. "arch": "cutlass::arch::Sm%d" % operation.arch,
  576. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  577. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  578. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  579. "warp_shape_m": str(warp_shape[0]),
  580. "warp_shape_n": str(warp_shape[1]),
  581. "warp_shape_k": str(warp_shape[2]),
  582. "instruction_shape_m": str(
  583. operation.tile_description.math_instruction.instruction_shape[0]
  584. ),
  585. "instruction_shape_n": str(
  586. operation.tile_description.math_instruction.instruction_shape[1]
  587. ),
  588. "instruction_shape_k": str(
  589. operation.tile_description.math_instruction.instruction_shape[2]
  590. ),
  591. "epilogue_vector_length": str(epilogue_vector_length),
  592. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  593. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  594. "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
  595. "stages": str(operation.tile_description.stages),
  596. "align_a": str(operation.A.alignment),
  597. "align_b": str(operation.B.alignment),
  598. "transform_a": ComplexTransformTag[operation.A.complex_transform],
  599. "transform_b": ComplexTransformTag[operation.B.complex_transform],
  600. "math_operation": MathOperationTag[
  601. operation.tile_description.math_instruction.math_operation
  602. ],
  603. "residual": residual,
  604. }
  605. template = self.gemm_template
  606. return SubstituteTemplate(template, values)
  607. ###################################################################################################
  608. #
  609. class EmitGemmUniversalInstance:
  610. """ Responsible for emitting a CUTLASS template definition"""
  611. def __init__(self):
  612. self.gemm_template = """
  613. // Gemm operator ${operation_name}
  614. using ${operation_name}_base =
  615. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  616. ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
  617. ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
  618. ${element_c}, ${layout_c},
  619. ${element_accumulator},
  620. ${opcode_class},
  621. ${arch},
  622. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  623. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  624. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  625. ${epilogue_functor}<
  626. ${element_c},
  627. ${epilogue_vector_length},
  628. ${element_accumulator},
  629. ${element_epilogue}
  630. >,
  631. ${swizzling_functor},
  632. ${stages},
  633. ${math_operation}
  634. >::GemmKernel;
  635. // Define named type
  636. struct ${operation_name} :
  637. public ${operation_name}_base { };
  638. """
  639. self.gemm_template_interleaved = """
  640. // Gemm operator ${operation_name}
  641. using ${operation_name}_base =
  642. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  643. ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
  644. ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
  645. ${element_c}, ${layout_c},
  646. ${element_accumulator},
  647. ${opcode_class},
  648. ${arch},
  649. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  650. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  651. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  652. ${epilogue_functor}<
  653. ${element_c},
  654. ${epilogue_vector_length},
  655. ${element_accumulator},
  656. ${element_epilogue}
  657. >,
  658. ${swizzling_functor},
  659. ${stages},
  660. ${math_operation}
  661. >::GemmKernel;
  662. // Define named type
  663. struct ${operation_name} :
  664. public ${operation_name}_base { };
  665. """
  666. def emit(self, operation):
  667. threadblock_shape = operation.tile_description.threadblock_shape
  668. warp_count = operation.tile_description.warp_count
  669. warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  670. epilogue_vector_length = int(
  671. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  672. / DataTypeSize[operation.C.element]
  673. )
  674. transpose_layouts = {
  675. LayoutType.ColumnMajor: LayoutType.RowMajor,
  676. LayoutType.RowMajor: LayoutType.ColumnMajor,
  677. }
  678. if (
  679. operation.A.layout in transpose_layouts.keys()
  680. and operation.B.layout in transpose_layouts.keys()
  681. and operation.C.layout in transpose_layouts.keys()
  682. ):
  683. instance_layout_A = transpose_layouts[operation.A.layout]
  684. instance_layout_B = transpose_layouts[operation.B.layout]
  685. instance_layout_C = transpose_layouts[operation.C.layout]
  686. gemm_template = self.gemm_template
  687. else:
  688. instance_layout_A, instance_layout_B, instance_layout_C = (
  689. operation.A.layout,
  690. operation.B.layout,
  691. operation.C.layout,
  692. )
  693. gemm_template = self.gemm_template_interleaved
  694. #
  695. values = {
  696. "operation_name": operation.procedural_name(),
  697. "element_a": DataTypeTag[operation.A.element],
  698. "layout_a": LayoutTag[instance_layout_A],
  699. "element_b": DataTypeTag[operation.B.element],
  700. "layout_b": LayoutTag[instance_layout_B],
  701. "element_c": DataTypeTag[operation.C.element],
  702. "layout_c": LayoutTag[instance_layout_C],
  703. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  704. "opcode_class": OpcodeClassTag[
  705. operation.tile_description.math_instruction.opcode_class
  706. ],
  707. "arch": "cutlass::arch::Sm%d" % operation.arch,
  708. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  709. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  710. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  711. "warp_shape_m": str(warp_shape[0]),
  712. "warp_shape_n": str(warp_shape[1]),
  713. "warp_shape_k": str(warp_shape[2]),
  714. "instruction_shape_m": str(
  715. operation.tile_description.math_instruction.instruction_shape[0]
  716. ),
  717. "instruction_shape_n": str(
  718. operation.tile_description.math_instruction.instruction_shape[1]
  719. ),
  720. "instruction_shape_k": str(
  721. operation.tile_description.math_instruction.instruction_shape[2]
  722. ),
  723. "epilogue_vector_length": str(epilogue_vector_length),
  724. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  725. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  726. "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
  727. "stages": str(operation.tile_description.stages),
  728. "align_a": str(operation.A.alignment),
  729. "align_b": str(operation.B.alignment),
  730. "transform_a": ComplexTransformTag[operation.A.complex_transform],
  731. "transform_b": ComplexTransformTag[operation.B.complex_transform],
  732. "math_operation": MathOperationTag[
  733. operation.tile_description.math_instruction.math_operation
  734. ],
  735. }
  736. return SubstituteTemplate(gemm_template, values)
  737. ###################################################################################################
  738. #
  739. class EmitGemmPlanarComplexInstance:
  740. """ Responsible for emitting a CUTLASS template definition"""
  741. def __init__(self):
  742. self.template = """
  743. // Gemm operator ${operation_name}
  744. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  745. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  746. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  747. ${element_c}, cutlass::layout::RowMajor,
  748. ${element_accumulator},
  749. ${opcode_class},
  750. ${arch},
  751. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  752. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  753. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  754. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  755. ${element_c},
  756. ${alignment_c},
  757. ${element_accumulator},
  758. ${element_epilogue}
  759. >,
  760. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  761. ${stages},
  762. ${math_operator}
  763. >::GemmKernel;
  764. struct ${operation_name} :
  765. public Operation_${operation_name} { };
  766. """
  767. def emit(self, operation):
  768. warp_shape = [
  769. operation.tile_description.threadblock_shape[idx]
  770. // operation.tile_description.warp_count[idx]
  771. for idx in range(3)
  772. ]
  773. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  774. transposed_layout_A = TransposedLayout[operation.A.layout]
  775. transposed_layout_B = TransposedLayout[operation.B.layout]
  776. values = {
  777. "operation_name": operation.procedural_name(),
  778. "element_a": DataTypeTag[operation.B.element],
  779. "layout_a": LayoutTag[transposed_layout_B],
  780. "transform_a": ComplexTransformTag[operation.B.complex_transform],
  781. "alignment_a": str(operation.B.alignment),
  782. "element_b": DataTypeTag[operation.A.element],
  783. "layout_b": LayoutTag[transposed_layout_A],
  784. "transform_b": ComplexTransformTag[operation.A.complex_transform],
  785. "alignment_b": str(operation.A.alignment),
  786. "element_c": DataTypeTag[operation.C.element],
  787. "layout_c": LayoutTag[operation.C.layout],
  788. "element_accumulator": DataTypeTag[
  789. operation.tile_description.math_instruction.element_accumulator
  790. ],
  791. "opcode_class": OpcodeClassTag[
  792. operation.tile_description.math_instruction.opcode_class
  793. ],
  794. "arch": "cutlass::arch::Sm%d" % operation.arch,
  795. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  796. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  797. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  798. "warp_shape_m": str(warp_shape[0]),
  799. "warp_shape_n": str(warp_shape[1]),
  800. "warp_shape_k": str(warp_shape[2]),
  801. "instruction_shape_m": str(
  802. operation.tile_description.math_instruction.instruction_shape[0]
  803. ),
  804. "instruction_shape_n": str(
  805. operation.tile_description.math_instruction.instruction_shape[1]
  806. ),
  807. "instruction_shape_k": str(
  808. operation.tile_description.math_instruction.instruction_shape[2]
  809. ),
  810. "alignment_c": str(operation.C.alignment),
  811. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  812. "stages": str(operation.tile_description.stages),
  813. "math_operator": "cutlass::arch::OpMultiplyAdd",
  814. }
  815. return SubstituteTemplate(self.template, values)
  816. ###################################################################################################
  817. #
  818. class EmitGemmPlanarComplexArrayInstance:
  819. """ Responsible for emitting a CUTLASS template definition"""
  820. def __init__(self):
  821. self.template = """
  822. // Gemm operator ${operation_name}
  823. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  824. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  825. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  826. ${element_c}, cutlass::layout::RowMajor,
  827. ${element_accumulator},
  828. ${opcode_class},
  829. ${arch},
  830. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  831. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  832. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  833. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  834. ${element_c},
  835. ${alignment_c},
  836. ${element_accumulator},
  837. ${element_epilogue}
  838. >,
  839. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  840. ${stages},
  841. ${math_operator}
  842. >::GemmArrayKernel;
  843. struct ${operation_name} : public Operation_${operation_name} { };
  844. """
  845. def emit(self, operation):
  846. warp_shape = [
  847. operation.tile_description.threadblock_shape[idx]
  848. // operation.tile_description.warp_count[idx]
  849. for idx in range(3)
  850. ]
  851. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  852. transposed_layout_A = TransposedLayout[operation.A.layout]
  853. transposed_layout_B = TransposedLayout[operation.B.layout]
  854. values = {
  855. "operation_name": operation.procedural_name(),
  856. "element_a": DataTypeTag[operation.B.element],
  857. "layout_a": LayoutTag[transposed_layout_B],
  858. "transform_a": ComplexTransformTag[operation.B.complex_transform],
  859. "alignment_a": str(operation.B.alignment),
  860. "element_b": DataTypeTag[operation.A.element],
  861. "layout_b": LayoutTag[transposed_layout_A],
  862. "transform_b": ComplexTransformTag[operation.A.complex_transform],
  863. "alignment_b": str(operation.A.alignment),
  864. "element_c": DataTypeTag[operation.C.element],
  865. "layout_c": LayoutTag[operation.C.layout],
  866. "element_accumulator": DataTypeTag[
  867. operation.tile_description.math_instruction.element_accumulator
  868. ],
  869. "opcode_class": OpcodeClassTag[
  870. operation.tile_description.math_instruction.opcode_class
  871. ],
  872. "arch": "cutlass::arch::Sm%d" % operation.arch,
  873. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  874. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  875. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  876. "warp_shape_m": str(warp_shape[0]),
  877. "warp_shape_n": str(warp_shape[1]),
  878. "warp_shape_k": str(warp_shape[2]),
  879. "instruction_shape_m": str(
  880. operation.tile_description.math_instruction.instruction_shape[0]
  881. ),
  882. "instruction_shape_n": str(
  883. operation.tile_description.math_instruction.instruction_shape[1]
  884. ),
  885. "instruction_shape_k": str(
  886. operation.tile_description.math_instruction.instruction_shape[2]
  887. ),
  888. "alignment_c": str(operation.C.alignment),
  889. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  890. "stages": str(operation.tile_description.stages),
  891. "math_operator": "cutlass::arch::OpMultiplyAdd",
  892. }
  893. return SubstituteTemplate(self.template, values)
  894. #
  895. class EmitGemmSplitKParallelInstance:
  896. """ Responsible for emitting a CUTLASS template definition"""
  897. def __init__(self):
  898. self.template = """
  899. // Gemm operator ${operation_name}
  900. using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel<
  901. ${element_a}, ${layout_a},
  902. ${element_b}, ${layout_b},
  903. ${element_c}, ${layout_c},
  904. ${element_accumulator},
  905. ${opcode_class},
  906. ${arch},
  907. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  908. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  909. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  910. ${epilogue_functor}<
  911. ${element_c},
  912. ${epilogue_vector_length},
  913. ${element_accumulator},
  914. ${element_epilogue}
  915. >,
  916. cutlass::epilogue::thread::Convert<
  917. ${element_accumulator},
  918. ${epilogue_vector_length},
  919. ${element_accumulator}
  920. >,
  921. cutlass::reduction::thread::ReduceAdd<
  922. ${element_accumulator},
  923. ${element_accumulator},
  924. ${epilogue_vector_length}
  925. >,
  926. cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
  927. ${stages},
  928. ${align_a},
  929. ${align_b},
  930. ${math_operation}
  931. >;
  932. """
  933. def emit(self, operation):
  934. warp_shape = [
  935. operation.tile_description.threadblock_shape[idx]
  936. // operation.tile_description.warp_count[idx]
  937. for idx in range(3)
  938. ]
  939. epilogue_vector_length = int(
  940. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  941. / DataTypeSize[operation.C.element]
  942. )
  943. values = {
  944. "operation_name": operation.procedural_name(),
  945. "element_a": DataTypeTag[operation.A.element],
  946. "layout_a": LayoutTag[operation.A.layout],
  947. "element_b": DataTypeTag[operation.B.element],
  948. "layout_b": LayoutTag[operation.B.layout],
  949. "element_c": DataTypeTag[operation.C.element],
  950. "layout_c": LayoutTag[operation.C.layout],
  951. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  952. "opcode_class": OpcodeClassTag[
  953. operation.tile_description.math_instruction.opcode_class
  954. ],
  955. "arch": "cutlass::arch::Sm%d" % operation.arch,
  956. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  957. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  958. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  959. "warp_shape_m": str(warp_shape[0]),
  960. "warp_shape_n": str(warp_shape[1]),
  961. "warp_shape_k": str(warp_shape[2]),
  962. "instruction_shape_m": str(
  963. operation.tile_description.math_instruction.instruction_shape[0]
  964. ),
  965. "instruction_shape_n": str(
  966. operation.tile_description.math_instruction.instruction_shape[1]
  967. ),
  968. "instruction_shape_k": str(
  969. operation.tile_description.math_instruction.instruction_shape[2]
  970. ),
  971. "epilogue_vector_length": str(epilogue_vector_length),
  972. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  973. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  974. "stages": str(operation.tile_description.stages),
  975. "math_operation": MathOperationTag[
  976. operation.tile_description.math_instruction.math_operation
  977. ],
  978. "align_a": str(operation.A.alignment),
  979. "align_b": str(operation.B.alignment),
  980. }
  981. return SubstituteTemplate(self.template, values)
  982. ###################################################################################################
  983. ###################################################################################################
  984. #
  985. # Emitters functions for all targets
  986. #
  987. ###################################################################################################
  988. class EmitGemmConfigurationLibrary:
  989. def __init__(self, operation_path, configuration_name):
  990. self.configuration_name = configuration_name
  991. self.configuration_path = os.path.join(
  992. operation_path, "%s.cu" % configuration_name
  993. ).replace("\\", "/")
  994. self.instance_emitter = {
  995. GemmKind.Gemm: EmitGemmInstance,
  996. GemmKind.Sparse: EmitSparseGemmInstance,
  997. GemmKind.Universal: EmitGemmUniversalInstance,
  998. GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
  999. GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
  1000. }
  1001. self.gemm_kind_wrappers = {
  1002. GemmKind.Gemm: "GemmOperation",
  1003. GemmKind.Sparse: "GemmSparseOperation",
  1004. GemmKind.Universal: "GemmUniversalOperation",
  1005. GemmKind.PlanarComplex: "GemmPlanarComplexOperation",
  1006. GemmKind.PlanarComplexArray: "GemmPlanarComplexArrayOperation",
  1007. }
  1008. self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
  1009. self.instance_template = {
  1010. GemmKind.Gemm: """
  1011. ${compile_guard_start}
  1012. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  1013. ${compile_guard_end}
  1014. """,
  1015. GemmKind.Sparse: """
  1016. ${compile_guard_start}
  1017. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  1018. ${compile_guard_end}
  1019. """,
  1020. GemmKind.Universal: """
  1021. ${compile_guard_start}
  1022. manifest.append(new ${gemm_kind}<
  1023. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  1024. >("${operation_name}"));
  1025. ${compile_guard_end}
  1026. """,
  1027. GemmKind.PlanarComplex: """
  1028. ${compile_guard_start}
  1029. manifest.append(new ${gemm_kind}<
  1030. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  1031. >("${operation_name}"));
  1032. ${compile_guard_end}
  1033. """,
  1034. GemmKind.PlanarComplexArray: """
  1035. ${compile_guard_start}
  1036. manifest.append(new ${gemm_kind}<
  1037. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  1038. >("${operation_name}"));
  1039. ${compile_guard_end}
  1040. """,
  1041. }
  1042. self.header_template = """
  1043. /*
  1044. Generated by gemm_operation.py - Do not edit.
  1045. */
  1046. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1047. #include "cutlass/arch/wmma.h"
  1048. #include "cutlass/cutlass.h"
  1049. #include "cutlass/library/library.h"
  1050. #include "cutlass/library/manifest.h"
  1051. #include "library_internal.h"
  1052. #include "gemm_operation.h"
  1053. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1054. """
  1055. self.initialize_function_template = """
  1056. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1057. namespace cutlass {
  1058. namespace library {
  1059. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1060. void initialize_${configuration_name}(Manifest &manifest) {
  1061. """
  1062. self.epilogue_template = """
  1063. }
  1064. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1065. } // namespace library
  1066. } // namespace cutlass
  1067. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1068. """
  1069. def __enter__(self):
  1070. self.configuration_file = open(self.configuration_path, "w")
  1071. self.configuration_file.write(self.header_template)
  1072. self.instance_definitions = []
  1073. self.instance_wrappers = []
  1074. self.operations = []
  1075. return self
  1076. def emit(self, operation):
  1077. emitter = self.instance_emitter[operation.gemm_kind]()
  1078. self.operations.append(operation)
  1079. self.instance_definitions.append(emitter.emit(operation))
  1080. self.instance_wrappers.append(
  1081. SubstituteTemplate(
  1082. self.instance_template[operation.gemm_kind],
  1083. {
  1084. "configuration_name": self.configuration_name,
  1085. "operation_name": operation.procedural_name(),
  1086. "gemm_kind": self.gemm_kind_wrappers[operation.gemm_kind],
  1087. "compile_guard_start": SubstituteTemplate(
  1088. self.wmma_guard_start, {"sm_number": str(operation.arch)}
  1089. )
  1090. if operation.tile_description.math_instruction.opcode_class
  1091. == OpcodeClass.WmmaTensorOp
  1092. else "",
  1093. "compile_guard_end": "#endif"
  1094. if operation.tile_description.math_instruction.opcode_class
  1095. == OpcodeClass.WmmaTensorOp
  1096. else "",
  1097. },
  1098. )
  1099. )
  1100. def __exit__(self, exception_type, exception_value, traceback):
  1101. # Write instance definitions in top-level namespace
  1102. for instance_definition in self.instance_definitions:
  1103. self.configuration_file.write(instance_definition)
  1104. # Add wrapper objects within initialize() function
  1105. self.configuration_file.write(
  1106. SubstituteTemplate(
  1107. self.initialize_function_template,
  1108. {"configuration_name": self.configuration_name},
  1109. )
  1110. )
  1111. for instance_wrapper in self.instance_wrappers:
  1112. self.configuration_file.write(instance_wrapper)
  1113. self.configuration_file.write(self.epilogue_template)
  1114. self.configuration_file.close()
  1115. ###################################################################################################
  1116. ###################################################################################################
  1117. class EmitGemmSingleKernelWrapper:
  1118. def __init__(self, kernel_path, gemm_operation, short_path=False):
  1119. self.short_path = short_path
  1120. self.kernel_path = kernel_path
  1121. self.operation = gemm_operation
  1122. instance_emitters = {
  1123. GemmKind.Gemm: EmitGemmInstance(),
  1124. GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
  1125. }
  1126. self.instance_emitter = instance_emitters[self.operation.gemm_kind]
  1127. self.header_template = """
  1128. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  1129. // ignore warning of cutlass
  1130. #pragma GCC diagnostic push
  1131. #pragma GCC diagnostic ignored "-Wunused-parameter"
  1132. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  1133. #pragma GCC diagnostic ignored "-Wuninitialized"
  1134. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  1135. #include "cutlass/gemm/device/gemm.h"
  1136. #include "cutlass/gemm/device/gemm_splitk_parallel.h"
  1137. #include "src/cuda/cutlass/manifest.h"
  1138. #include "src/cuda/cutlass/gemm_operation.h"
  1139. """
  1140. self.instance_template = """
  1141. ${operation_instance}
  1142. """
  1143. self.manifest_template = """
  1144. namespace cutlass {
  1145. namespace library {
  1146. void initialize_${operation_name}(Manifest &manifest) {
  1147. manifest.append(new GemmOperation<
  1148. Operation_${operation_name}
  1149. >("${operation_name}"));
  1150. }
  1151. } // namespace library
  1152. } // namespace cutlass
  1153. """
  1154. self.epilogue_template = """
  1155. #pragma GCC diagnostic pop
  1156. #endif
  1157. """
  1158. #
  1159. def __enter__(self):
  1160. if self.short_path:
  1161. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  1162. GlobalCnt.cnt += 1
  1163. else:
  1164. self.kernel_path = os.path.join(
  1165. self.kernel_path, "%s.cu" % self.operation.procedural_name()
  1166. )
  1167. self.kernel_file = open(self.kernel_path, "w")
  1168. return self
  1169. #
  1170. def emit(self):
  1171. self.kernel_file.write(
  1172. SubstituteTemplate(
  1173. self.instance_template,
  1174. {"operation_instance": self.instance_emitter.emit(self.operation)},
  1175. )
  1176. )
  1177. # emit manifest helper
  1178. manifest = SubstituteTemplate(
  1179. self.manifest_template, {"operation_name": self.operation.procedural_name()}
  1180. )
  1181. self.kernel_file.write(manifest)
  1182. #
  1183. def __exit__(self, exception_type, exception_value, traceback):
  1184. self.kernel_file.close()
  1185. ###################################################################################################
  1186. ###################################################################################################
  1187. class EmitGemvSingleKernelWrapper:
  1188. def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
  1189. self.kernel_path = kernel_path
  1190. self.wrapper_path = wrapper_path
  1191. self.operation = gemm_operation
  1192. self.short_path = short_path
  1193. self.wrapper_template = """
  1194. template void megdnn::cuda::cutlass_wrapper::
  1195. cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
  1196. BatchedGemmCoord const& problem_size,
  1197. const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
  1198. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
  1199. typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
  1200. cudaStream_t stream);
  1201. """
  1202. self.instance_emitter = EmitGemvBatchedStridedInstance()
  1203. self.header_template = """
  1204. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  1205. // ignore warning of cutlass
  1206. #pragma GCC diagnostic push
  1207. #pragma GCC diagnostic ignored "-Wunused-parameter"
  1208. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  1209. #pragma GCC diagnostic ignored "-Wuninitialized"
  1210. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  1211. #include "${wrapper_path}"
  1212. """
  1213. self.instance_template = """
  1214. ${operation_instance}
  1215. """
  1216. self.epilogue_template = """
  1217. #pragma GCC diagnostic pop
  1218. #endif
  1219. """
  1220. #
  1221. def __enter__(self):
  1222. if self.short_path:
  1223. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  1224. GlobalCnt.cnt += 1
  1225. else:
  1226. self.kernel_path = os.path.join(
  1227. self.kernel_path, "%s.cu" % self.operation.procedural_name()
  1228. )
  1229. self.kernel_file = open(self.kernel_path, "w")
  1230. return self
  1231. #
  1232. def emit(self):
  1233. self.kernel_file.write(
  1234. SubstituteTemplate(
  1235. self.instance_template,
  1236. {"operation_instance": self.instance_emitter.emit(self.operation)},
  1237. )
  1238. )
  1239. # emit wrapper
  1240. wrapper = SubstituteTemplate(
  1241. self.wrapper_template, {"operation_name": self.operation.procedural_name()}
  1242. )
  1243. self.kernel_file.write(wrapper)
  1244. #
  1245. def __exit__(self, exception_type, exception_value, traceback):
  1246. self.kernel_file.close()
  1247. ###################################################################################################
  1248. ###################################################################################################