You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_operation.py 52 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446
  1. #
  2. # \file generator.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import functools
  8. import operator
  9. import os.path
  10. import shutil
  11. from library import *
  12. ###################################################################################################
  13. #
  14. # Data structure modeling a GEMM operation
  15. #
  16. ###################################################################################################
  17. #
  18. class GemmOperation:
  19. #
  20. def __init__(
  21. self,
  22. gemm_kind,
  23. arch,
  24. tile_description,
  25. A,
  26. B,
  27. C,
  28. element_epilogue,
  29. epilogue_functor=EpilogueFunctor.LinearCombination,
  30. swizzling_functor=SwizzlingFunctor.Identity8,
  31. required_cuda_ver_major=9,
  32. required_cuda_ver_minor=2,
  33. ):
  34. self.operation_kind = OperationKind.Gemm
  35. self.arch = arch
  36. self.tile_description = tile_description
  37. self.gemm_kind = gemm_kind
  38. self.A = A
  39. self.B = B
  40. self.C = C
  41. self.element_epilogue = element_epilogue
  42. self.epilogue_functor = epilogue_functor
  43. self.swizzling_functor = swizzling_functor
  44. self.required_cuda_ver_major = required_cuda_ver_major
  45. self.required_cuda_ver_minor = required_cuda_ver_minor
  46. #
  47. def is_complex(self):
  48. complex_operators = [
  49. MathOperation.multiply_add_complex,
  50. MathOperation.multiply_add_complex_gaussian,
  51. ]
  52. return (
  53. self.tile_description.math_instruction.math_operation in complex_operators
  54. )
  55. #
  56. def is_split_k_parallel(self):
  57. return self.gemm_kind == GemmKind.SplitKParallel
  58. #
  59. def is_planar_complex(self):
  60. return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  61. #
  62. def accumulator_type(self):
  63. accum = self.tile_description.math_instruction.element_accumulator
  64. if self.is_complex():
  65. return get_complex_from_real(accum)
  66. return accum
  67. #
  68. def short_math_name(self):
  69. if (
  70. self.tile_description.math_instruction.math_operation
  71. == MathOperation.multiply_add_complex_gaussian
  72. ):
  73. return "g%s" % ShortDataTypeNames[self.accumulator_type()]
  74. return ShortDataTypeNames[self.accumulator_type()]
  75. #
  76. def core_name(self):
  77. """ The basic operation kind is prefixed with a letter indicating the accumulation type. """
  78. inst_shape = ""
  79. inst_operation = ""
  80. intermediate_type = ""
  81. math_operations_map = {MathOperation.xor_popc: "xor"}
  82. if (
  83. self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp
  84. or self.tile_description.math_instruction.opcode_class
  85. == OpcodeClass.WmmaTensorOp
  86. ):
  87. math_op = self.tile_description.math_instruction.math_operation
  88. math_op_string = (
  89. math_operations_map[math_op]
  90. if math_op in math_operations_map.keys()
  91. else ""
  92. )
  93. inst_shape = "%d%d%d" % tuple(
  94. self.tile_description.math_instruction.instruction_shape
  95. )
  96. inst_shape += math_op_string
  97. if (
  98. self.tile_description.math_instruction.element_a != self.A.element
  99. and self.tile_description.math_instruction.element_a
  100. != self.tile_description.math_instruction.element_accumulator
  101. ):
  102. intermediate_type = DataTypeNames[
  103. self.tile_description.math_instruction.element_a
  104. ]
  105. return "%s%s%s%s" % (
  106. self.short_math_name(),
  107. inst_shape,
  108. intermediate_type,
  109. GemmKindNames[self.gemm_kind],
  110. )
  111. #
  112. def extended_name(self):
  113. """ Append data types if they differ from compute type. """
  114. if self.is_complex():
  115. extended_name = "${core_name}"
  116. else:
  117. if (
  118. self.C.element
  119. != self.tile_description.math_instruction.element_accumulator
  120. and self.A.element
  121. != self.tile_description.math_instruction.element_accumulator
  122. ):
  123. extended_name = "${element_c}_${core_name}_${element_a}"
  124. elif (
  125. self.C.element
  126. == self.tile_description.math_instruction.element_accumulator
  127. and self.A.element
  128. != self.tile_description.math_instruction.element_accumulator
  129. ):
  130. extended_name = "${core_name}_${element_a}"
  131. else:
  132. extended_name = "${core_name}"
  133. extended_name = SubstituteTemplate(
  134. extended_name,
  135. {
  136. "element_a": DataTypeNames[self.A.element],
  137. "element_c": DataTypeNames[self.C.element],
  138. "core_name": self.core_name(),
  139. },
  140. )
  141. return extended_name
  142. #
  143. def layout_name(self):
  144. if self.is_complex() or self.is_planar_complex():
  145. return "%s%s" % (
  146. ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
  147. ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
  148. )
  149. return "%s%s" % (
  150. ShortLayoutTypeNames[self.A.layout],
  151. ShortLayoutTypeNames[self.B.layout],
  152. )
  153. #
  154. def procedural_name(self):
  155. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  156. threadblock = self.tile_description.procedural_name()
  157. opcode_class_name = OpcodeClassNames[
  158. self.tile_description.math_instruction.opcode_class
  159. ]
  160. alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
  161. return SubstituteTemplate(
  162. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
  163. {
  164. "opcode_class": opcode_class_name,
  165. "extended_name": self.extended_name(),
  166. "threadblock": threadblock,
  167. "layout": self.layout_name(),
  168. "alignment": "%d" % self.A.alignment,
  169. },
  170. )
  171. #
  172. def configuration_name(self):
  173. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  174. return self.procedural_name()
  175. ###################################################################################################
  176. #
  177. # Data structure modeling a GEMV Batched Strided operation
  178. #
  179. ###################################################################################################
  180. #
  181. class GemvBatchedStridedOperation:
  182. #
  183. def __init__(
  184. self,
  185. gemm_kind,
  186. arch,
  187. math_inst,
  188. threadblock_shape,
  189. thread_shape,
  190. A,
  191. B,
  192. C,
  193. required_cuda_ver_major=9,
  194. required_cuda_ver_minor=2,
  195. ):
  196. self.operation_kind = OperationKind.Gemm
  197. self.arch = arch
  198. self.gemm_kind = gemm_kind
  199. self.math_instruction = math_inst
  200. self.threadblock_shape = threadblock_shape
  201. self.thread_shape = thread_shape
  202. self.A = A
  203. self.B = B
  204. self.C = C
  205. self.required_cuda_ver_major = required_cuda_ver_major
  206. self.required_cuda_ver_minor = required_cuda_ver_minor
  207. #
  208. def accumulator_type(self):
  209. accum = self.math_instruction.element_accumulator
  210. return accum
  211. #
  212. def short_math_name(self):
  213. return ShortDataTypeNames[self.accumulator_type()]
  214. #
  215. def core_name(self):
  216. """ The basic operation kind is prefixed with a letter indicating the accumulation type. """
  217. return "%s%s" % (self.short_math_name(), GemmKindNames[self.gemm_kind])
  218. #
  219. def extended_name(self):
  220. """ Append data types if they differ from compute type. """
  221. if (
  222. self.C.element != self.math_instruction.element_accumulator
  223. and self.A.element != self.math_instruction.element_accumulator
  224. ):
  225. extended_name = "${element_c}_${core_name}_${element_a}"
  226. elif (
  227. self.C.element == self.math_instruction.element_accumulator
  228. and self.A.element != self.math_instruction.element_accumulator
  229. ):
  230. extended_name = "${core_name}_${element_a}"
  231. else:
  232. extended_name = "${core_name}"
  233. extended_name = SubstituteTemplate(
  234. extended_name,
  235. {
  236. "element_a": DataTypeNames[self.A.element],
  237. "element_c": DataTypeNames[self.C.element],
  238. "core_name": self.core_name(),
  239. },
  240. )
  241. return extended_name
  242. #
  243. def layout_name(self):
  244. return "%s%s" % (
  245. ShortLayoutTypeNames[self.A.layout],
  246. ShortLayoutTypeNames[self.B.layout],
  247. )
  248. #
  249. def procedural_name(self):
  250. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  251. threadblock = "%dx%d_%d" % (
  252. self.threadblock_shape[0],
  253. self.threadblock_shape[1],
  254. self.threadblock_shape[2],
  255. )
  256. opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class]
  257. alignment_a = self.A.alignment
  258. alignment_b = self.B.alignment
  259. return SubstituteTemplate(
  260. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}",
  261. {
  262. "opcode_class": opcode_class_name,
  263. "extended_name": self.extended_name(),
  264. "threadblock": threadblock,
  265. "layout": self.layout_name(),
  266. "alignment_a": "%d" % alignment_a,
  267. "alignment_b": "%d" % alignment_b,
  268. },
  269. )
  270. #
  271. def configuration_name(self):
  272. """ The full procedural name indicates architecture, extended name, tile size, and layout. """
  273. return self.procedural_name()
  274. #
  275. def GeneratesGemm(
  276. tile,
  277. data_type,
  278. layout_a,
  279. layout_b,
  280. layout_c,
  281. min_cc,
  282. align_a=32,
  283. align_b=32,
  284. align_c=32,
  285. required_cuda_ver_major=9,
  286. required_cuda_ver_minor=2,
  287. ):
  288. operations = []
  289. swizzling_functor = SwizzlingFunctor.Identity1
  290. element_a, element_b, element_c, element_epilogue = data_type
  291. if tile.math_instruction.element_accumulator == DataType.s32:
  292. epilogues = [EpilogueFunctor.LinearCombinationClamp]
  293. else:
  294. assert (
  295. tile.math_instruction.element_accumulator == DataType.f32
  296. or tile.math_instruction.element_accumulator == DataType.f16
  297. )
  298. epilogues = [EpilogueFunctor.LinearCombination]
  299. for epilogue in epilogues:
  300. A = TensorDescription(
  301. element_a, layout_a, int(align_a // DataTypeSize[element_a])
  302. )
  303. B = TensorDescription(
  304. element_b, layout_b, int(align_b // DataTypeSize[element_b])
  305. )
  306. C = TensorDescription(
  307. element_c, layout_c, int(align_c // DataTypeSize[element_c])
  308. )
  309. operations.append(
  310. GemmOperation(
  311. GemmKind.Gemm,
  312. min_cc,
  313. tile,
  314. A,
  315. B,
  316. C,
  317. element_epilogue,
  318. epilogue,
  319. swizzling_functor,
  320. required_cuda_ver_major,
  321. required_cuda_ver_minor,
  322. )
  323. )
  324. operations.append(
  325. GemmOperation(
  326. GemmKind.SplitKParallel,
  327. min_cc,
  328. tile,
  329. A,
  330. B,
  331. C,
  332. element_epilogue,
  333. epilogue,
  334. swizzling_functor,
  335. required_cuda_ver_major,
  336. required_cuda_ver_minor,
  337. )
  338. )
  339. return operations
  340. def GeneratesGemv(
  341. math_inst,
  342. threadblock_shape,
  343. thread_shape,
  344. data_type,
  345. layout_a,
  346. layout_b,
  347. layout_c,
  348. min_cc,
  349. align_a=32,
  350. align_b=32,
  351. align_c=32,
  352. required_cuda_ver_major=9,
  353. required_cuda_ver_minor=2,
  354. ):
  355. element_a, element_b, element_c, element_epilogue = data_type
  356. A = TensorDescription(element_a, layout_a, int(align_a // DataTypeSize[element_a]))
  357. B = TensorDescription(element_b, layout_b, int(align_b // DataTypeSize[element_b]))
  358. C = TensorDescription(element_c, layout_c, int(align_c // DataTypeSize[element_c]))
  359. return GemvBatchedStridedOperation(
  360. GemmKind.GemvBatchedStrided,
  361. min_cc,
  362. math_inst,
  363. threadblock_shape,
  364. thread_shape,
  365. A,
  366. B,
  367. C,
  368. required_cuda_ver_major,
  369. required_cuda_ver_minor,
  370. )
  371. ###################################################################################################
  372. #
  373. # Emits single instances of a CUTLASS device-wide operator
  374. #
  375. ###################################################################################################
  376. #
  377. class EmitGemmInstance:
  378. """ Responsible for emitting a CUTLASS template definition"""
  379. def __init__(self):
  380. self.gemm_template = """
  381. // Gemm operator ${operation_name}
  382. using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  383. ${element_a}, ${layout_a},
  384. ${element_b}, ${layout_b},
  385. ${element_c}, ${layout_c},
  386. ${element_accumulator},
  387. ${opcode_class},
  388. ${arch},
  389. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  390. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  391. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  392. ${epilogue_functor}<
  393. ${element_c},
  394. ${epilogue_vector_length},
  395. ${element_accumulator},
  396. ${element_epilogue}
  397. >,
  398. ${swizzling_functor},
  399. ${stages},
  400. ${align_a},
  401. ${align_b},
  402. false,
  403. ${math_operation}
  404. ${residual}
  405. >;
  406. """
  407. self.gemm_complex_template = """
  408. // Gemm operator ${operation_name}
  409. using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  410. ${element_a}, ${layout_a},
  411. ${element_b}, ${layout_b},
  412. ${element_c}, ${layout_c},
  413. ${element_accumulator},
  414. ${opcode_class},
  415. ${arch},
  416. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  417. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  418. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  419. ${epilogue_functor}<
  420. ${element_c},
  421. ${epilogue_vector_length},
  422. ${element_accumulator},
  423. ${element_epilogue}
  424. >,
  425. ${swizzling_functor},
  426. ${stages},
  427. ${transform_a},
  428. ${transform_b},
  429. ${math_operation}
  430. ${residual}
  431. >;
  432. """
  433. def emit(self, operation):
  434. warp_shape = [
  435. operation.tile_description.threadblock_shape[idx]
  436. // operation.tile_description.warp_count[idx]
  437. for idx in range(3)
  438. ]
  439. epilogue_vector_length = int(
  440. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  441. / DataTypeSize[operation.C.element]
  442. )
  443. residual = ""
  444. values = {
  445. "operation_name": operation.procedural_name(),
  446. "element_a": DataTypeTag[operation.A.element],
  447. "layout_a": LayoutTag[operation.A.layout],
  448. "element_b": DataTypeTag[operation.B.element],
  449. "layout_b": LayoutTag[operation.B.layout],
  450. "element_c": DataTypeTag[operation.C.element],
  451. "layout_c": LayoutTag[operation.C.layout],
  452. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  453. "opcode_class": OpcodeClassTag[
  454. operation.tile_description.math_instruction.opcode_class
  455. ],
  456. "arch": "cutlass::arch::Sm%d" % operation.arch,
  457. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  458. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  459. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  460. "warp_shape_m": str(warp_shape[0]),
  461. "warp_shape_n": str(warp_shape[1]),
  462. "warp_shape_k": str(warp_shape[2]),
  463. "instruction_shape_m": str(
  464. operation.tile_description.math_instruction.instruction_shape[0]
  465. ),
  466. "instruction_shape_n": str(
  467. operation.tile_description.math_instruction.instruction_shape[1]
  468. ),
  469. "instruction_shape_k": str(
  470. operation.tile_description.math_instruction.instruction_shape[2]
  471. ),
  472. "epilogue_vector_length": str(epilogue_vector_length),
  473. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  474. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  475. "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
  476. "stages": str(operation.tile_description.stages),
  477. "align_a": str(operation.A.alignment),
  478. "align_b": str(operation.B.alignment),
  479. "transform_a": ComplexTransformTag[operation.A.complex_transform],
  480. "transform_b": ComplexTransformTag[operation.B.complex_transform],
  481. "math_operation": MathOperationTag[
  482. operation.tile_description.math_instruction.math_operation
  483. ],
  484. "residual": residual,
  485. }
  486. template = (
  487. self.gemm_complex_template if operation.is_complex() else self.gemm_template
  488. )
  489. return SubstituteTemplate(template, values)
  490. #
  491. class EmitGemvBatchedStridedInstance:
  492. """ Responsible for emitting a CUTLASS template definition"""
  493. def __init__(self):
  494. self.template = """
  495. // Gemm operator ${operation_name}
  496. using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv<
  497. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  498. cutlass::gemm::GemmShape<${thread_shape_m}, ${thread_shape_n}, ${thread_shape_k}>,
  499. ${element_a}, ${layout_a},
  500. ${element_b}, ${layout_b},
  501. ${element_c}, ${layout_c}
  502. >;
  503. """
  504. def emit(self, operation):
  505. values = {
  506. "operation_name": operation.procedural_name(),
  507. "element_a": DataTypeTag[operation.A.element],
  508. "layout_a": LayoutTag[operation.A.layout],
  509. "element_b": DataTypeTag[operation.B.element],
  510. "layout_b": LayoutTag[operation.B.layout],
  511. "element_c": DataTypeTag[operation.C.element],
  512. "layout_c": LayoutTag[operation.C.layout],
  513. "threadblock_shape_m": str(operation.threadblock_shape[0]),
  514. "threadblock_shape_n": str(operation.threadblock_shape[1]),
  515. "threadblock_shape_k": str(operation.threadblock_shape[2]),
  516. "thread_shape_m": str(operation.thread_shape[0]),
  517. "thread_shape_n": str(operation.thread_shape[1]),
  518. "thread_shape_k": str(operation.thread_shape[2]),
  519. }
  520. return SubstituteTemplate(self.template, values)
  521. ###################################################################################################
  522. class EmitSparseGemmInstance:
  523. """ Responsible for emitting a CUTLASS template definition"""
  524. def __init__(self):
  525. self.gemm_template = """
  526. // Gemm operator ${operation_name}
  527. using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  528. ${element_a}, ${layout_a},
  529. ${element_b}, ${layout_b},
  530. ${element_c}, ${layout_c},
  531. ${element_accumulator},
  532. ${opcode_class},
  533. ${arch},
  534. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  535. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  536. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  537. ${epilogue_functor}<
  538. ${element_c},
  539. ${epilogue_vector_length},
  540. ${element_accumulator},
  541. ${element_epilogue}
  542. >,
  543. ${swizzling_functor},
  544. ${stages},
  545. ${align_a},
  546. ${align_b},
  547. false,
  548. ${math_operation}
  549. ${residual}
  550. >;
  551. """
  552. def emit(self, operation):
  553. warp_shape = [
  554. operation.tile_description.threadblock_shape[idx]
  555. // operation.tile_description.warp_count[idx]
  556. for idx in range(3)
  557. ]
  558. epilogue_vector_length = int(
  559. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  560. / DataTypeSize[operation.C.element]
  561. )
  562. residual = ""
  563. values = {
  564. "operation_name": operation.procedural_name(),
  565. "element_a": DataTypeTag[operation.A.element],
  566. "layout_a": LayoutTag[operation.A.layout],
  567. "element_b": DataTypeTag[operation.B.element],
  568. "layout_b": LayoutTag[operation.B.layout],
  569. "element_c": DataTypeTag[operation.C.element],
  570. "layout_c": LayoutTag[operation.C.layout],
  571. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  572. "opcode_class": OpcodeClassTag[
  573. operation.tile_description.math_instruction.opcode_class
  574. ],
  575. "arch": "cutlass::arch::Sm%d" % operation.arch,
  576. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  577. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  578. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  579. "warp_shape_m": str(warp_shape[0]),
  580. "warp_shape_n": str(warp_shape[1]),
  581. "warp_shape_k": str(warp_shape[2]),
  582. "instruction_shape_m": str(
  583. operation.tile_description.math_instruction.instruction_shape[0]
  584. ),
  585. "instruction_shape_n": str(
  586. operation.tile_description.math_instruction.instruction_shape[1]
  587. ),
  588. "instruction_shape_k": str(
  589. operation.tile_description.math_instruction.instruction_shape[2]
  590. ),
  591. "epilogue_vector_length": str(epilogue_vector_length),
  592. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  593. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  594. "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
  595. "stages": str(operation.tile_description.stages),
  596. "align_a": str(operation.A.alignment),
  597. "align_b": str(operation.B.alignment),
  598. "transform_a": ComplexTransformTag[operation.A.complex_transform],
  599. "transform_b": ComplexTransformTag[operation.B.complex_transform],
  600. "math_operation": MathOperationTag[
  601. operation.tile_description.math_instruction.math_operation
  602. ],
  603. "residual": residual,
  604. }
  605. template = self.gemm_template
  606. return SubstituteTemplate(template, values)
  607. ###################################################################################################
  608. #
  609. class EmitGemmUniversalInstance:
  610. """ Responsible for emitting a CUTLASS template definition"""
  611. def __init__(self):
  612. self.gemm_template = """
  613. // Gemm operator ${operation_name}
  614. using ${operation_name}_base =
  615. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  616. ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
  617. ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
  618. ${element_c}, ${layout_c},
  619. ${element_accumulator},
  620. ${opcode_class},
  621. ${arch},
  622. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  623. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  624. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  625. ${epilogue_functor}<
  626. ${element_c},
  627. ${epilogue_vector_length},
  628. ${element_accumulator},
  629. ${element_epilogue}
  630. >,
  631. ${swizzling_functor},
  632. ${stages},
  633. ${math_operation}
  634. >::GemmKernel;
  635. // Define named type
  636. struct ${operation_name} :
  637. public ${operation_name}_base { };
  638. """
  639. self.gemm_template_interleaved = """
  640. // Gemm operator ${operation_name}
  641. using ${operation_name}_base =
  642. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  643. ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
  644. ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
  645. ${element_c}, ${layout_c},
  646. ${element_accumulator},
  647. ${opcode_class},
  648. ${arch},
  649. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  650. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  651. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  652. ${epilogue_functor}<
  653. ${element_c},
  654. ${epilogue_vector_length},
  655. ${element_accumulator},
  656. ${element_epilogue}
  657. >,
  658. ${swizzling_functor},
  659. ${stages},
  660. ${math_operation}
  661. >::GemmKernel;
  662. // Define named type
  663. struct ${operation_name} :
  664. public ${operation_name}_base { };
  665. """
  666. def emit(self, operation):
  667. threadblock_shape = operation.tile_description.threadblock_shape
  668. warp_count = operation.tile_description.warp_count
  669. warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  670. epilogue_vector_length = int(
  671. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  672. / DataTypeSize[operation.C.element]
  673. )
  674. transpose_layouts = {
  675. LayoutType.ColumnMajor: LayoutType.RowMajor,
  676. LayoutType.RowMajor: LayoutType.ColumnMajor,
  677. }
  678. if (
  679. operation.A.layout in transpose_layouts.keys()
  680. and operation.B.layout in transpose_layouts.keys()
  681. and operation.C.layout in transpose_layouts.keys()
  682. ):
  683. instance_layout_A = transpose_layouts[operation.A.layout]
  684. instance_layout_B = transpose_layouts[operation.B.layout]
  685. instance_layout_C = transpose_layouts[operation.C.layout]
  686. gemm_template = self.gemm_template
  687. else:
  688. instance_layout_A, instance_layout_B, instance_layout_C = (
  689. operation.A.layout,
  690. operation.B.layout,
  691. operation.C.layout,
  692. )
  693. gemm_template = self.gemm_template_interleaved
  694. #
  695. values = {
  696. "operation_name": operation.procedural_name(),
  697. "element_a": DataTypeTag[operation.A.element],
  698. "layout_a": LayoutTag[instance_layout_A],
  699. "element_b": DataTypeTag[operation.B.element],
  700. "layout_b": LayoutTag[instance_layout_B],
  701. "element_c": DataTypeTag[operation.C.element],
  702. "layout_c": LayoutTag[instance_layout_C],
  703. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  704. "opcode_class": OpcodeClassTag[
  705. operation.tile_description.math_instruction.opcode_class
  706. ],
  707. "arch": "cutlass::arch::Sm%d" % operation.arch,
  708. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  709. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  710. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  711. "warp_shape_m": str(warp_shape[0]),
  712. "warp_shape_n": str(warp_shape[1]),
  713. "warp_shape_k": str(warp_shape[2]),
  714. "instruction_shape_m": str(
  715. operation.tile_description.math_instruction.instruction_shape[0]
  716. ),
  717. "instruction_shape_n": str(
  718. operation.tile_description.math_instruction.instruction_shape[1]
  719. ),
  720. "instruction_shape_k": str(
  721. operation.tile_description.math_instruction.instruction_shape[2]
  722. ),
  723. "epilogue_vector_length": str(epilogue_vector_length),
  724. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  725. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  726. "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
  727. "stages": str(operation.tile_description.stages),
  728. "align_a": str(operation.A.alignment),
  729. "align_b": str(operation.B.alignment),
  730. "transform_a": ComplexTransformTag[operation.A.complex_transform],
  731. "transform_b": ComplexTransformTag[operation.B.complex_transform],
  732. "math_operation": MathOperationTag[
  733. operation.tile_description.math_instruction.math_operation
  734. ],
  735. }
  736. return SubstituteTemplate(gemm_template, values)
  737. ###################################################################################################
  738. #
  739. class EmitGemmPlanarComplexInstance:
  740. """ Responsible for emitting a CUTLASS template definition"""
  741. def __init__(self):
  742. self.template = """
  743. // Gemm operator ${operation_name}
  744. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  745. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  746. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  747. ${element_c}, cutlass::layout::RowMajor,
  748. ${element_accumulator},
  749. ${opcode_class},
  750. ${arch},
  751. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  752. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  753. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  754. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  755. ${element_c},
  756. ${alignment_c},
  757. ${element_accumulator},
  758. ${element_epilogue}
  759. >,
  760. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  761. ${stages},
  762. ${math_operator}
  763. >::GemmKernel;
  764. struct ${operation_name} :
  765. public Operation_${operation_name} { };
  766. """
  767. def emit(self, operation):
  768. warp_shape = [
  769. operation.tile_description.threadblock_shape[idx]
  770. // operation.tile_description.warp_count[idx]
  771. for idx in range(3)
  772. ]
  773. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  774. transposed_layout_A = TransposedLayout[operation.A.layout]
  775. transposed_layout_B = TransposedLayout[operation.B.layout]
  776. values = {
  777. "operation_name": operation.procedural_name(),
  778. "element_a": DataTypeTag[operation.B.element],
  779. "layout_a": LayoutTag[transposed_layout_B],
  780. "transform_a": ComplexTransformTag[operation.B.complex_transform],
  781. "alignment_a": str(operation.B.alignment),
  782. "element_b": DataTypeTag[operation.A.element],
  783. "layout_b": LayoutTag[transposed_layout_A],
  784. "transform_b": ComplexTransformTag[operation.A.complex_transform],
  785. "alignment_b": str(operation.A.alignment),
  786. "element_c": DataTypeTag[operation.C.element],
  787. "layout_c": LayoutTag[operation.C.layout],
  788. "element_accumulator": DataTypeTag[
  789. operation.tile_description.math_instruction.element_accumulator
  790. ],
  791. "opcode_class": OpcodeClassTag[
  792. operation.tile_description.math_instruction.opcode_class
  793. ],
  794. "arch": "cutlass::arch::Sm%d" % operation.arch,
  795. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  796. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  797. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  798. "warp_shape_m": str(warp_shape[0]),
  799. "warp_shape_n": str(warp_shape[1]),
  800. "warp_shape_k": str(warp_shape[2]),
  801. "instruction_shape_m": str(
  802. operation.tile_description.math_instruction.instruction_shape[0]
  803. ),
  804. "instruction_shape_n": str(
  805. operation.tile_description.math_instruction.instruction_shape[1]
  806. ),
  807. "instruction_shape_k": str(
  808. operation.tile_description.math_instruction.instruction_shape[2]
  809. ),
  810. "alignment_c": str(operation.C.alignment),
  811. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  812. "stages": str(operation.tile_description.stages),
  813. "math_operator": "cutlass::arch::OpMultiplyAdd",
  814. }
  815. return SubstituteTemplate(self.template, values)
  816. ###################################################################################################
  817. #
  818. class EmitGemmPlanarComplexArrayInstance:
  819. """ Responsible for emitting a CUTLASS template definition"""
  820. def __init__(self):
  821. self.template = """
  822. // Gemm operator ${operation_name}
  823. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  824. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  825. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  826. ${element_c}, cutlass::layout::RowMajor,
  827. ${element_accumulator},
  828. ${opcode_class},
  829. ${arch},
  830. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  831. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  832. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  833. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  834. ${element_c},
  835. ${alignment_c},
  836. ${element_accumulator},
  837. ${element_epilogue}
  838. >,
  839. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  840. ${stages},
  841. ${math_operator}
  842. >::GemmArrayKernel;
  843. struct ${operation_name} : public Operation_${operation_name} { };
  844. """
  845. def emit(self, operation):
  846. warp_shape = [
  847. operation.tile_description.threadblock_shape[idx]
  848. // operation.tile_description.warp_count[idx]
  849. for idx in range(3)
  850. ]
  851. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  852. transposed_layout_A = TransposedLayout[operation.A.layout]
  853. transposed_layout_B = TransposedLayout[operation.B.layout]
  854. values = {
  855. "operation_name": operation.procedural_name(),
  856. "element_a": DataTypeTag[operation.B.element],
  857. "layout_a": LayoutTag[transposed_layout_B],
  858. "transform_a": ComplexTransformTag[operation.B.complex_transform],
  859. "alignment_a": str(operation.B.alignment),
  860. "element_b": DataTypeTag[operation.A.element],
  861. "layout_b": LayoutTag[transposed_layout_A],
  862. "transform_b": ComplexTransformTag[operation.A.complex_transform],
  863. "alignment_b": str(operation.A.alignment),
  864. "element_c": DataTypeTag[operation.C.element],
  865. "layout_c": LayoutTag[operation.C.layout],
  866. "element_accumulator": DataTypeTag[
  867. operation.tile_description.math_instruction.element_accumulator
  868. ],
  869. "opcode_class": OpcodeClassTag[
  870. operation.tile_description.math_instruction.opcode_class
  871. ],
  872. "arch": "cutlass::arch::Sm%d" % operation.arch,
  873. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  874. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  875. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  876. "warp_shape_m": str(warp_shape[0]),
  877. "warp_shape_n": str(warp_shape[1]),
  878. "warp_shape_k": str(warp_shape[2]),
  879. "instruction_shape_m": str(
  880. operation.tile_description.math_instruction.instruction_shape[0]
  881. ),
  882. "instruction_shape_n": str(
  883. operation.tile_description.math_instruction.instruction_shape[1]
  884. ),
  885. "instruction_shape_k": str(
  886. operation.tile_description.math_instruction.instruction_shape[2]
  887. ),
  888. "alignment_c": str(operation.C.alignment),
  889. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  890. "stages": str(operation.tile_description.stages),
  891. "math_operator": "cutlass::arch::OpMultiplyAdd",
  892. }
  893. return SubstituteTemplate(self.template, values)
  894. #
  895. class EmitGemmSplitKParallelInstance:
  896. """ Responsible for emitting a CUTLASS template definition"""
  897. def __init__(self):
  898. self.template = """
  899. // Gemm operator ${operation_name}
  900. using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel<
  901. ${element_a}, ${layout_a},
  902. ${element_b}, ${layout_b},
  903. ${element_c}, ${layout_c},
  904. ${element_accumulator},
  905. ${opcode_class},
  906. ${arch},
  907. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  908. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  909. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  910. ${epilogue_functor}<
  911. ${element_c},
  912. ${epilogue_vector_length},
  913. ${element_accumulator},
  914. ${element_epilogue}
  915. >,
  916. cutlass::epilogue::thread::Convert<
  917. ${element_accumulator},
  918. ${epilogue_vector_length},
  919. ${element_accumulator}
  920. >,
  921. cutlass::reduction::thread::ReduceAdd<
  922. ${element_accumulator},
  923. ${element_accumulator},
  924. ${epilogue_vector_length}
  925. >,
  926. cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
  927. ${stages},
  928. ${align_a},
  929. ${align_b},
  930. ${math_operation}
  931. >;
  932. """
  933. def emit(self, operation):
  934. warp_shape = [
  935. operation.tile_description.threadblock_shape[idx]
  936. // operation.tile_description.warp_count[idx]
  937. for idx in range(3)
  938. ]
  939. epilogue_vector_length = int(
  940. min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
  941. / DataTypeSize[operation.C.element]
  942. )
  943. values = {
  944. "operation_name": operation.procedural_name(),
  945. "element_a": DataTypeTag[operation.A.element],
  946. "layout_a": LayoutTag[operation.A.layout],
  947. "element_b": DataTypeTag[operation.B.element],
  948. "layout_b": LayoutTag[operation.B.layout],
  949. "element_c": DataTypeTag[operation.C.element],
  950. "layout_c": LayoutTag[operation.C.layout],
  951. "element_accumulator": DataTypeTag[operation.accumulator_type()],
  952. "opcode_class": OpcodeClassTag[
  953. operation.tile_description.math_instruction.opcode_class
  954. ],
  955. "arch": "cutlass::arch::Sm%d" % operation.arch,
  956. "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
  957. "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
  958. "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
  959. "warp_shape_m": str(warp_shape[0]),
  960. "warp_shape_n": str(warp_shape[1]),
  961. "warp_shape_k": str(warp_shape[2]),
  962. "instruction_shape_m": str(
  963. operation.tile_description.math_instruction.instruction_shape[0]
  964. ),
  965. "instruction_shape_n": str(
  966. operation.tile_description.math_instruction.instruction_shape[1]
  967. ),
  968. "instruction_shape_k": str(
  969. operation.tile_description.math_instruction.instruction_shape[2]
  970. ),
  971. "epilogue_vector_length": str(epilogue_vector_length),
  972. "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
  973. "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
  974. "stages": str(operation.tile_description.stages),
  975. "math_operation": MathOperationTag[
  976. operation.tile_description.math_instruction.math_operation
  977. ],
  978. "align_a": str(operation.A.alignment),
  979. "align_b": str(operation.B.alignment),
  980. }
  981. return SubstituteTemplate(self.template, values)
  982. ###################################################################################################
  983. ###################################################################################################
  984. #
  985. # Emitters functions for all targets
  986. #
  987. ###################################################################################################
  988. class EmitGemmConfigurationLibrary:
  989. def __init__(self, operation_path, configuration_name):
  990. self.configuration_name = configuration_name
  991. self.configuration_path = os.path.join(
  992. operation_path, "%s.cu" % configuration_name
  993. ).replace("\\", "/")
  994. self.instance_emitter = {
  995. GemmKind.Gemm: EmitGemmInstance,
  996. GemmKind.Sparse: EmitSparseGemmInstance,
  997. GemmKind.Universal: EmitGemmUniversalInstance,
  998. GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
  999. GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
  1000. }
  1001. self.gemm_kind_wrappers = {
  1002. GemmKind.Gemm: "GemmOperation",
  1003. GemmKind.Sparse: "GemmSparseOperation",
  1004. GemmKind.Universal: "GemmUniversalOperation",
  1005. GemmKind.PlanarComplex: "GemmPlanarComplexOperation",
  1006. GemmKind.PlanarComplexArray: "GemmPlanarComplexArrayOperation",
  1007. }
  1008. self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
  1009. self.instance_template = {
  1010. GemmKind.Gemm: """
  1011. ${compile_guard_start}
  1012. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  1013. ${compile_guard_end}
  1014. """,
  1015. GemmKind.Sparse: """
  1016. ${compile_guard_start}
  1017. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  1018. ${compile_guard_end}
  1019. """,
  1020. GemmKind.Universal: """
  1021. ${compile_guard_start}
  1022. manifest.append(new ${gemm_kind}<
  1023. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  1024. >("${operation_name}"));
  1025. ${compile_guard_end}
  1026. """,
  1027. GemmKind.PlanarComplex: """
  1028. ${compile_guard_start}
  1029. manifest.append(new ${gemm_kind}<
  1030. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  1031. >("${operation_name}"));
  1032. ${compile_guard_end}
  1033. """,
  1034. GemmKind.PlanarComplexArray: """
  1035. ${compile_guard_start}
  1036. manifest.append(new ${gemm_kind}<
  1037. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  1038. >("${operation_name}"));
  1039. ${compile_guard_end}
  1040. """,
  1041. }
  1042. self.header_template = """
  1043. /*
  1044. Generated by gemm_operation.py - Do not edit.
  1045. */
  1046. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1047. #include "cutlass/arch/wmma.h"
  1048. #include "cutlass/cutlass.h"
  1049. #include "cutlass/library/library.h"
  1050. #include "cutlass/library/manifest.h"
  1051. #include "library_internal.h"
  1052. #include "gemm_operation.h"
  1053. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1054. """
  1055. self.initialize_function_template = """
  1056. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1057. namespace cutlass {
  1058. namespace library {
  1059. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1060. void initialize_${configuration_name}(Manifest &manifest) {
  1061. """
  1062. self.epilogue_template = """
  1063. }
  1064. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1065. } // namespace library
  1066. } // namespace cutlass
  1067. ///////////////////////////////////////////////////////////////////////////////////////////////////
  1068. """
  1069. def __enter__(self):
  1070. self.configuration_file = open(self.configuration_path, "w")
  1071. self.configuration_file.write(self.header_template)
  1072. self.instance_definitions = []
  1073. self.instance_wrappers = []
  1074. self.operations = []
  1075. return self
  1076. def emit(self, operation):
  1077. emitter = self.instance_emitter[operation.gemm_kind]()
  1078. self.operations.append(operation)
  1079. self.instance_definitions.append(emitter.emit(operation))
  1080. self.instance_wrappers.append(
  1081. SubstituteTemplate(
  1082. self.instance_template[operation.gemm_kind],
  1083. {
  1084. "configuration_name": self.configuration_name,
  1085. "operation_name": operation.procedural_name(),
  1086. "gemm_kind": self.gemm_kind_wrappers[operation.gemm_kind],
  1087. "compile_guard_start": SubstituteTemplate(
  1088. self.wmma_guard_start, {"sm_number": str(operation.arch)}
  1089. )
  1090. if operation.tile_description.math_instruction.opcode_class
  1091. == OpcodeClass.WmmaTensorOp
  1092. else "",
  1093. "compile_guard_end": "#endif"
  1094. if operation.tile_description.math_instruction.opcode_class
  1095. == OpcodeClass.WmmaTensorOp
  1096. else "",
  1097. },
  1098. )
  1099. )
  1100. def __exit__(self, exception_type, exception_value, traceback):
  1101. # Write instance definitions in top-level namespace
  1102. for instance_definition in self.instance_definitions:
  1103. self.configuration_file.write(instance_definition)
  1104. # Add wrapper objects within initialize() function
  1105. self.configuration_file.write(
  1106. SubstituteTemplate(
  1107. self.initialize_function_template,
  1108. {"configuration_name": self.configuration_name},
  1109. )
  1110. )
  1111. for instance_wrapper in self.instance_wrappers:
  1112. self.configuration_file.write(instance_wrapper)
  1113. self.configuration_file.write(self.epilogue_template)
  1114. self.configuration_file.close()
  1115. ###################################################################################################
  1116. ###################################################################################################
  1117. class EmitGemmSingleKernelWrapper:
  1118. def __init__(self, kernel_path, gemm_operation, short_path=False):
  1119. self.short_path = short_path
  1120. self.kernel_path = kernel_path
  1121. self.operation = gemm_operation
  1122. instance_emitters = {
  1123. GemmKind.Gemm: EmitGemmInstance(),
  1124. GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
  1125. }
  1126. self.instance_emitter = instance_emitters[self.operation.gemm_kind]
  1127. self.header_template = """
  1128. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  1129. // ignore warning of cutlass
  1130. #pragma GCC diagnostic push
  1131. #pragma GCC diagnostic ignored "-Wunused-parameter"
  1132. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  1133. #pragma GCC diagnostic ignored "-Wuninitialized"
  1134. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  1135. #include "cutlass/gemm/device/gemm.h"
  1136. #include "cutlass/gemm/device/gemm_splitk_parallel.h"
  1137. #include "src/cuda/cutlass/manifest.h"
  1138. #include "src/cuda/cutlass/gemm_operation.h"
  1139. """
  1140. self.instance_template = """
  1141. ${operation_instance}
  1142. """
  1143. self.manifest_template = """
  1144. namespace cutlass {
  1145. namespace library {
  1146. void initialize_${operation_name}(Manifest &manifest) {
  1147. manifest.append(new GemmOperation<
  1148. Operation_${operation_name}
  1149. >("${operation_name}"));
  1150. }
  1151. } // namespace library
  1152. } // namespace cutlass
  1153. """
  1154. self.epilogue_template = """
  1155. #pragma GCC diagnostic pop
  1156. #endif
  1157. """
  1158. #
  1159. def __enter__(self):
  1160. if self.short_path:
  1161. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  1162. GlobalCnt.cnt += 1
  1163. else:
  1164. self.kernel_path = os.path.join(
  1165. self.kernel_path, "%s.cu" % self.operation.procedural_name()
  1166. )
  1167. self.kernel_file = open(self.kernel_path, "w")
  1168. return self
  1169. #
  1170. def emit(self):
  1171. self.kernel_file.write(
  1172. SubstituteTemplate(
  1173. self.instance_template,
  1174. {"operation_instance": self.instance_emitter.emit(self.operation)},
  1175. )
  1176. )
  1177. # emit manifest helper
  1178. manifest = SubstituteTemplate(
  1179. self.manifest_template, {"operation_name": self.operation.procedural_name()}
  1180. )
  1181. self.kernel_file.write(manifest)
  1182. #
  1183. def __exit__(self, exception_type, exception_value, traceback):
  1184. self.kernel_file.close()
  1185. ###################################################################################################
  1186. ###################################################################################################
  1187. class EmitGemvSingleKernelWrapper:
  1188. def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
  1189. self.kernel_path = kernel_path
  1190. self.wrapper_path = wrapper_path
  1191. self.operation = gemm_operation
  1192. self.short_path = short_path
  1193. self.wrapper_template = """
  1194. template void megdnn::cuda::cutlass_wrapper::
  1195. cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
  1196. BatchedGemmCoord const& problem_size,
  1197. const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
  1198. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
  1199. typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
  1200. cudaStream_t stream);
  1201. """
  1202. self.instance_emitter = EmitGemvBatchedStridedInstance()
  1203. self.header_template = """
  1204. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  1205. // ignore warning of cutlass
  1206. #pragma GCC diagnostic push
  1207. #pragma GCC diagnostic ignored "-Wunused-parameter"
  1208. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  1209. #pragma GCC diagnostic ignored "-Wuninitialized"
  1210. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  1211. #include "${wrapper_path}"
  1212. """
  1213. self.instance_template = """
  1214. ${operation_instance}
  1215. """
  1216. self.epilogue_template = """
  1217. #pragma GCC diagnostic pop
  1218. #endif
  1219. """
  1220. #
  1221. def __enter__(self):
  1222. if self.short_path:
  1223. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  1224. GlobalCnt.cnt += 1
  1225. else:
  1226. self.kernel_path = os.path.join(
  1227. self.kernel_path, "%s.cu" % self.operation.procedural_name()
  1228. )
  1229. self.kernel_file = open(self.kernel_path, "w")
  1230. return self
  1231. #
  1232. def emit(self):
  1233. self.kernel_file.write(
  1234. SubstituteTemplate(
  1235. self.instance_template,
  1236. {"operation_instance": self.instance_emitter.emit(self.operation)},
  1237. )
  1238. )
  1239. # emit wrapper
  1240. wrapper = SubstituteTemplate(
  1241. self.wrapper_template, {"operation_name": self.operation.procedural_name()}
  1242. )
  1243. self.kernel_file.write(wrapper)
  1244. #
  1245. def __exit__(self, exception_type, exception_value, traceback):
  1246. self.kernel_file.close()
  1247. ###################################################################################################
  1248. ###################################################################################################