You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

gemm_operation.py 45 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141
  1. #
  2. # \file gemm_operation.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import os.path
  8. import shutil
  9. import functools
  10. import operator
  11. from library import *
  12. ###################################################################################################
  13. #
  14. # Data structure modeling a GEMM operation
  15. #
  16. ###################################################################################################
  17. #
  18. class GemmOperation:
  19. #
  20. def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
  21. epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8):
  22. self.operation_kind = OperationKind.Gemm
  23. self.arch = arch
  24. self.tile_description = tile_description
  25. self.gemm_kind = gemm_kind
  26. self.A = A
  27. self.B = B
  28. self.C = C
  29. self.element_epilogue = element_epilogue
  30. self.epilogue_functor = epilogue_functor
  31. self.swizzling_functor = swizzling_functor
  32. #
  33. def is_complex(self):
  34. complex_operators = [
  35. MathOperation.multiply_add_complex,
  36. MathOperation.multiply_add_complex_gaussian
  37. ]
  38. return self.tile_description.math_instruction.math_operation in complex_operators
  39. #
  40. def is_split_k_parallel(self):
  41. return self.gemm_kind == GemmKind.SplitKParallel
  42. #
  43. def is_planar_complex(self):
  44. return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  45. #
  46. def accumulator_type(self):
  47. accum = self.tile_description.math_instruction.element_accumulator
  48. if self.is_complex():
  49. return get_complex_from_real(accum)
  50. return accum
  51. #
  52. def short_math_name(self):
  53. if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
  54. return "g%s" % ShortDataTypeNames[self.accumulator_type()]
  55. return ShortDataTypeNames[self.accumulator_type()]
  56. #
  57. def core_name(self):
  58. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  59. inst_shape = ''
  60. inst_operation = ''
  61. intermediate_type = ''
  62. math_operations_map = {
  63. MathOperation.xor_popc: 'xor',
  64. }
  65. if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
  66. self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
  67. math_op = self.tile_description.math_instruction.math_operation
  68. math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
  69. inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
  70. inst_shape += math_op_string
  71. if self.tile_description.math_instruction.element_a != self.A.element and \
  72. self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
  73. intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
  74. return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
  75. #
  76. def extended_name(self):
  77. ''' Append data types if they differ from compute type. '''
  78. if self.is_complex():
  79. extended_name = "${core_name}"
  80. else:
  81. if self.C.element != self.tile_description.math_instruction.element_accumulator and \
  82. self.A.element != self.tile_description.math_instruction.element_accumulator:
  83. extended_name = "${element_c}_${core_name}_${element_a}"
  84. elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
  85. self.A.element != self.tile_description.math_instruction.element_accumulator:
  86. extended_name = "${core_name}_${element_a}"
  87. else:
  88. extended_name = "${core_name}"
  89. extended_name = SubstituteTemplate(extended_name, {
  90. 'element_a': DataTypeNames[self.A.element],
  91. 'element_c': DataTypeNames[self.C.element],
  92. 'core_name': self.core_name()
  93. })
  94. return extended_name
  95. #
  96. def layout_name(self):
  97. if self.is_complex() or self.is_planar_complex():
  98. return "%s%s" % (
  99. ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
  100. ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
  101. )
  102. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  103. #
  104. def procedural_name(self):
  105. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  106. threadblock = self.tile_description.procedural_name()
  107. opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
  108. alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
  109. return SubstituteTemplate(
  110. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
  111. {
  112. 'opcode_class': opcode_class_name,
  113. 'extended_name': self.extended_name(),
  114. 'threadblock': threadblock,
  115. 'layout': self.layout_name(),
  116. 'alignment': "%d" % self.A.alignment,
  117. }
  118. )
  119. #
  120. def configuration_name(self):
  121. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  122. return self.procedural_name()
  123. ###################################################################################################
  124. #
  125. # Data structure modeling a GEMV Batched Strided operation
  126. #
  127. ###################################################################################################
  128. #
  129. class GemvBatchedStridedOperation:
  130. #
  131. def __init__(self, gemm_kind, arch, math_inst, threadblock_shape, thread_shape, A, B, C):
  132. self.operation_kind = OperationKind.Gemm
  133. self.arch = arch
  134. self.gemm_kind = gemm_kind
  135. self.math_instruction = math_inst
  136. self.threadblock_shape = threadblock_shape
  137. self.thread_shape = thread_shape
  138. self.A = A
  139. self.B = B
  140. self.C = C
  141. #
  142. def accumulator_type(self):
  143. accum = self.math_instruction.element_accumulator
  144. return accum
  145. #
  146. def short_math_name(self):
  147. return ShortDataTypeNames[self.accumulator_type()]
  148. #
  149. def core_name(self):
  150. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  151. return "%s%s" % (self.short_math_name(), \
  152. GemmKindNames[self.gemm_kind])
  153. #
  154. def extended_name(self):
  155. ''' Append data types if they differ from compute type. '''
  156. if self.C.element != self.math_instruction.element_accumulator and \
  157. self.A.element != self.math_instruction.element_accumulator:
  158. extended_name = "${element_c}_${core_name}_${element_a}"
  159. elif self.C.element == self.math_instruction.element_accumulator and \
  160. self.A.element != self.math_instruction.element_accumulator:
  161. extended_name = "${core_name}_${element_a}"
  162. else:
  163. extended_name = "${core_name}"
  164. extended_name = SubstituteTemplate(extended_name, {
  165. 'element_a': DataTypeNames[self.A.element],
  166. 'element_c': DataTypeNames[self.C.element],
  167. 'core_name': self.core_name()
  168. })
  169. return extended_name
  170. #
  171. def layout_name(self):
  172. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  173. #
  174. def procedural_name(self):
  175. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  176. threadblock = "%dx%d_%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2])
  177. opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class]
  178. alignment_a = self.A.alignment
  179. alignment_b = self.B.alignment
  180. return SubstituteTemplate(
  181. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}",
  182. {
  183. 'opcode_class': opcode_class_name,
  184. 'extended_name': self.extended_name(),
  185. 'threadblock': threadblock,
  186. 'layout': self.layout_name(),
  187. 'alignment_a': "%d" % alignment_a,
  188. 'alignment_b': "%d" % alignment_b,
  189. }
  190. )
  191. #
  192. def configuration_name(self):
  193. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  194. return self.procedural_name()
  195. #
  196. def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a = 32, align_b = 32, align_c = 32):
  197. operations = []
  198. swizzling_functor = SwizzlingFunctor.Identity1
  199. element_a, element_b, element_c, element_epilogue = data_type
  200. if tile.math_instruction.element_accumulator == DataType.s32:
  201. epilogues = [EpilogueFunctor.LinearCombinationClamp]
  202. else:
  203. assert tile.math_instruction.element_accumulator == DataType.f32
  204. epilogues = [EpilogueFunctor.LinearCombination]
  205. for epilogue in epilogues:
  206. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  207. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  208. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  209. operations.append(GemmOperation(GemmKind.Gemm, min_cc, tile, A, B, C, \
  210. element_epilogue, epilogue, swizzling_functor))
  211. operations.append(GemmOperation(GemmKind.SplitKParallel, min_cc, tile, A, B, C, \
  212. element_epilogue, epilogue, swizzling_functor))
  213. return operations
  214. def GeneratesGemv(math_inst, threadblock_shape, thread_shape, data_type, layout_a, layout_b, layout_c, min_cc, \
  215. align_a = 32, align_b = 32, align_c = 32):
  216. element_a, element_b, element_c, element_epilogue = data_type
  217. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  218. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  219. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  220. return GemvBatchedStridedOperation(GemmKind.GemvBatchedStrided, min_cc, math_inst, threadblock_shape, thread_shape, \
  221. A, B, C)
  222. ###################################################################################################
  223. #
  224. # Emits single instances of a CUTLASS device-wide operator
  225. #
  226. ###################################################################################################
  227. #
  228. class EmitGemmInstance:
  229. ''' Responsible for emitting a CUTLASS template definition'''
  230. def __init__(self):
  231. self.gemm_template = """
  232. // Gemm operator ${operation_name}
  233. using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  234. ${element_a}, ${layout_a},
  235. ${element_b}, ${layout_b},
  236. ${element_c}, ${layout_c},
  237. ${element_accumulator},
  238. ${opcode_class},
  239. ${arch},
  240. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  241. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  242. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  243. ${epilogue_functor}<
  244. ${element_c},
  245. ${epilogue_vector_length},
  246. ${element_accumulator},
  247. ${element_epilogue}
  248. >,
  249. ${swizzling_functor},
  250. ${stages},
  251. ${align_a},
  252. ${align_b},
  253. false,
  254. ${math_operation}
  255. ${residual}
  256. >;
  257. """
  258. self.gemm_complex_template = """
  259. // Gemm operator ${operation_name}
  260. using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  261. ${element_a}, ${layout_a},
  262. ${element_b}, ${layout_b},
  263. ${element_c}, ${layout_c},
  264. ${element_accumulator},
  265. ${opcode_class},
  266. ${arch},
  267. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  268. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  269. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  270. ${epilogue_functor}<
  271. ${element_c},
  272. ${epilogue_vector_length},
  273. ${element_accumulator},
  274. ${element_epilogue}
  275. >,
  276. ${swizzling_functor},
  277. ${stages},
  278. ${transform_a},
  279. ${transform_b},
  280. ${math_operation}
  281. ${residual}
  282. >;
  283. """
  284. def emit(self, operation):
  285. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  286. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  287. residual = ''
  288. values = {
  289. 'operation_name': operation.procedural_name(),
  290. 'element_a': DataTypeTag[operation.A.element],
  291. 'layout_a': LayoutTag[operation.A.layout],
  292. 'element_b': DataTypeTag[operation.B.element],
  293. 'layout_b': LayoutTag[operation.B.layout],
  294. 'element_c': DataTypeTag[operation.C.element],
  295. 'layout_c': LayoutTag[operation.C.layout],
  296. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  297. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  298. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  299. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  300. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  301. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  302. 'warp_shape_m': str(warp_shape[0]),
  303. 'warp_shape_n': str(warp_shape[1]),
  304. 'warp_shape_k': str(warp_shape[2]),
  305. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  306. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  307. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  308. 'epilogue_vector_length': str(epilogue_vector_length),
  309. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  310. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  311. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  312. 'stages': str(operation.tile_description.stages),
  313. 'align_a': str(operation.A.alignment),
  314. 'align_b': str(operation.B.alignment),
  315. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  316. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  317. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  318. 'residual': residual
  319. }
  320. template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
  321. return SubstituteTemplate(template, values)
  322. #
  323. class EmitGemvBatchedStridedInstance:
  324. ''' Responsible for emitting a CUTLASS template definition'''
  325. def __init__(self):
  326. self.template = """
  327. // Gemm operator ${operation_name}
  328. using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv<
  329. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  330. cutlass::gemm::GemmShape<${thread_shape_m}, ${thread_shape_n}, ${thread_shape_k}>,
  331. ${element_a}, ${layout_a},
  332. ${element_b}, ${layout_b},
  333. ${element_c}, ${layout_c}
  334. >;
  335. """
  336. def emit(self, operation):
  337. values = {
  338. 'operation_name': operation.procedural_name(),
  339. 'element_a': DataTypeTag[operation.A.element],
  340. 'layout_a': LayoutTag[operation.A.layout],
  341. 'element_b': DataTypeTag[operation.B.element],
  342. 'layout_b': LayoutTag[operation.B.layout],
  343. 'element_c': DataTypeTag[operation.C.element],
  344. 'layout_c': LayoutTag[operation.C.layout],
  345. 'threadblock_shape_m': str(operation.threadblock_shape[0]),
  346. 'threadblock_shape_n': str(operation.threadblock_shape[1]),
  347. 'threadblock_shape_k': str(operation.threadblock_shape[2]),
  348. 'thread_shape_m': str(operation.thread_shape[0]),
  349. 'thread_shape_n': str(operation.thread_shape[1]),
  350. 'thread_shape_k': str(operation.thread_shape[2]),
  351. }
  352. return SubstituteTemplate(self.template, values)
  353. ###################################################################################################
  354. class EmitSparseGemmInstance:
  355. ''' Responsible for emitting a CUTLASS template definition'''
  356. def __init__(self):
  357. self.gemm_template = """
  358. // Gemm operator ${operation_name}
  359. using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  360. ${element_a}, ${layout_a},
  361. ${element_b}, ${layout_b},
  362. ${element_c}, ${layout_c},
  363. ${element_accumulator},
  364. ${opcode_class},
  365. ${arch},
  366. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  367. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  368. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  369. ${epilogue_functor}<
  370. ${element_c},
  371. ${epilogue_vector_length},
  372. ${element_accumulator},
  373. ${element_epilogue}
  374. >,
  375. ${swizzling_functor},
  376. ${stages},
  377. ${align_a},
  378. ${align_b},
  379. false,
  380. ${math_operation}
  381. ${residual}
  382. >;
  383. """
  384. def emit(self, operation):
  385. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  386. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  387. residual = ''
  388. values = {
  389. 'operation_name': operation.procedural_name(),
  390. 'element_a': DataTypeTag[operation.A.element],
  391. 'layout_a': LayoutTag[operation.A.layout],
  392. 'element_b': DataTypeTag[operation.B.element],
  393. 'layout_b': LayoutTag[operation.B.layout],
  394. 'element_c': DataTypeTag[operation.C.element],
  395. 'layout_c': LayoutTag[operation.C.layout],
  396. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  397. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  398. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  399. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  400. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  401. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  402. 'warp_shape_m': str(warp_shape[0]),
  403. 'warp_shape_n': str(warp_shape[1]),
  404. 'warp_shape_k': str(warp_shape[2]),
  405. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  406. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  407. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  408. 'epilogue_vector_length': str(epilogue_vector_length),
  409. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  410. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  411. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  412. 'stages': str(operation.tile_description.stages),
  413. 'align_a': str(operation.A.alignment),
  414. 'align_b': str(operation.B.alignment),
  415. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  416. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  417. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  418. 'residual': residual
  419. }
  420. template = self.gemm_template
  421. return SubstituteTemplate(template, values)
  422. ###################################################################################################
  423. #
  424. class EmitGemmUniversalInstance:
  425. ''' Responsible for emitting a CUTLASS template definition'''
  426. def __init__(self):
  427. self.gemm_template = """
  428. // Gemm operator ${operation_name}
  429. using ${operation_name}_base =
  430. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  431. ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
  432. ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
  433. ${element_c}, ${layout_c},
  434. ${element_accumulator},
  435. ${opcode_class},
  436. ${arch},
  437. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  438. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  439. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  440. ${epilogue_functor}<
  441. ${element_c},
  442. ${epilogue_vector_length},
  443. ${element_accumulator},
  444. ${element_epilogue}
  445. >,
  446. ${swizzling_functor},
  447. ${stages},
  448. ${math_operation}
  449. >::GemmKernel;
  450. // Define named type
  451. struct ${operation_name} :
  452. public ${operation_name}_base { };
  453. """
  454. self.gemm_template_interleaved = """
  455. // Gemm operator ${operation_name}
  456. using ${operation_name}_base =
  457. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  458. ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
  459. ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
  460. ${element_c}, ${layout_c},
  461. ${element_accumulator},
  462. ${opcode_class},
  463. ${arch},
  464. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  465. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  466. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  467. ${epilogue_functor}<
  468. ${element_c},
  469. ${epilogue_vector_length},
  470. ${element_accumulator},
  471. ${element_epilogue}
  472. >,
  473. ${swizzling_functor},
  474. ${stages},
  475. ${math_operation}
  476. >::GemmKernel;
  477. // Define named type
  478. struct ${operation_name} :
  479. public ${operation_name}_base { };
  480. """
  481. def emit(self, operation):
  482. threadblock_shape = operation.tile_description.threadblock_shape
  483. warp_count = operation.tile_description.warp_count
  484. warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  485. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  486. transpose_layouts = {
  487. LayoutType.ColumnMajor: LayoutType.RowMajor,
  488. LayoutType.RowMajor: LayoutType.ColumnMajor
  489. }
  490. if operation.A.layout in transpose_layouts.keys() and \
  491. operation.B.layout in transpose_layouts.keys() and \
  492. operation.C.layout in transpose_layouts.keys():
  493. instance_layout_A = transpose_layouts[operation.A.layout]
  494. instance_layout_B = transpose_layouts[operation.B.layout]
  495. instance_layout_C = transpose_layouts[operation.C.layout]
  496. gemm_template = self.gemm_template
  497. else:
  498. instance_layout_A, instance_layout_B, instance_layout_C = \
  499. (operation.A.layout, operation.B.layout, operation.C.layout)
  500. gemm_template = self.gemm_template_interleaved
  501. #
  502. values = {
  503. 'operation_name': operation.procedural_name(),
  504. 'element_a': DataTypeTag[operation.A.element],
  505. 'layout_a': LayoutTag[instance_layout_A],
  506. 'element_b': DataTypeTag[operation.B.element],
  507. 'layout_b': LayoutTag[instance_layout_B],
  508. 'element_c': DataTypeTag[operation.C.element],
  509. 'layout_c': LayoutTag[instance_layout_C],
  510. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  511. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  512. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  513. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  514. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  515. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  516. 'warp_shape_m': str(warp_shape[0]),
  517. 'warp_shape_n': str(warp_shape[1]),
  518. 'warp_shape_k': str(warp_shape[2]),
  519. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  520. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  521. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  522. 'epilogue_vector_length': str(epilogue_vector_length),
  523. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  524. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  525. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  526. 'stages': str(operation.tile_description.stages),
  527. 'align_a': str(operation.A.alignment),
  528. 'align_b': str(operation.B.alignment),
  529. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  530. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  531. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
  532. }
  533. return SubstituteTemplate(gemm_template, values)
  534. ###################################################################################################
  535. #
  536. class EmitGemmPlanarComplexInstance:
  537. ''' Responsible for emitting a CUTLASS template definition'''
  538. def __init__(self):
  539. self.template = """
  540. // Gemm operator ${operation_name}
  541. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  542. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  543. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  544. ${element_c}, cutlass::layout::RowMajor,
  545. ${element_accumulator},
  546. ${opcode_class},
  547. ${arch},
  548. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  549. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  550. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  551. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  552. ${element_c},
  553. ${alignment_c},
  554. ${element_accumulator},
  555. ${element_epilogue}
  556. >,
  557. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  558. ${stages},
  559. ${math_operator}
  560. >::GemmKernel;
  561. struct ${operation_name} :
  562. public Operation_${operation_name} { };
  563. """
  def emit(self, operation):
    ''' Fills self.template with the parameters of `operation` and returns the resulting text '''
    tile = operation.tile_description
    math_inst = tile.math_instruction

    # Warp tile extent per axis: threadblock extent integer-divided by the warp count.
    warp_extent = [tile.threadblock_shape[axis] // tile.warp_count[axis] for axis in range(3)]

    # C is row-major in the emitted kernel, so A and B trade places and each
    # one is emitted with its transposed layout.
    flipped_a = TransposedLayout[operation.A.layout]
    flipped_b = TransposedLayout[operation.B.layout]

    values = {'operation_name': operation.procedural_name()}

    # Template slot 'a' is filled from operand B, slot 'b' from operand A.
    for slot, operand, layout in (('a', operation.B, flipped_b), ('b', operation.A, flipped_a)):
      values['element_' + slot] = DataTypeTag[operand.element]
      values['layout_' + slot] = LayoutTag[layout]
      values['transform_' + slot] = ComplexTransformTag[operand.complex_transform]
      values['alignment_' + slot] = str(operand.alignment)

    values.update({
      'element_c': DataTypeTag[operation.C.element],
      'layout_c': LayoutTag[operation.C.layout],
      'element_accumulator': DataTypeTag[math_inst.element_accumulator],
      'opcode_class': OpcodeClassTag[math_inst.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
    })

    # Expands to threadblock_shape_m/n/k, warp_shape_m/n/k and instruction_shape_m/n/k.
    shape_groups = (
      ('threadblock_shape', tile.threadblock_shape),
      ('warp_shape', warp_extent),
      ('instruction_shape', math_inst.instruction_shape),
    )
    for prefix, shape in shape_groups:
      for axis, suffix in enumerate('mnk'):
        values['%s_%s' % (prefix, suffix)] = str(shape[axis])

    values.update({
      'alignment_c': str(operation.C.alignment),
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'stages': str(tile.stages),
      'math_operator': 'cutlass::arch::OpMultiplyAdd',
    })

    return SubstituteTemplate(self.template, values)
  599. ###################################################################################################
  600. #
  601. class EmitGemmPlanarComplexArrayInstance:
  602. ''' Responsible for emitting a CUTLASS template definition'''
  603. def __init__(self):
  604. self.template = """
  605. // Gemm operator ${operation_name}
  606. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  607. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  608. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  609. ${element_c}, cutlass::layout::RowMajor,
  610. ${element_accumulator},
  611. ${opcode_class},
  612. ${arch},
  613. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  614. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  615. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  616. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  617. ${element_c},
  618. ${alignment_c},
  619. ${element_accumulator},
  620. ${element_epilogue}
  621. >,
  622. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  623. ${stages},
  624. ${math_operator}
  625. >::GemmArrayKernel;
  626. struct ${operation_name} : public Operation_${operation_name} { };
  627. """
  628. def emit(self, operation):
  629. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  630. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  631. transposed_layout_A = TransposedLayout[operation.A.layout]
  632. transposed_layout_B = TransposedLayout[operation.B.layout]
  633. values = {
  634. 'operation_name': operation.procedural_name(),
  635. 'element_a': DataTypeTag[operation.B.element],
  636. 'layout_a': LayoutTag[transposed_layout_B],
  637. 'transform_a': ComplexTransformTag[operation.B.complex_transform],
  638. 'alignment_a': str(operation.B.alignment),
  639. 'element_b': DataTypeTag[operation.A.element],
  640. 'layout_b': LayoutTag[transposed_layout_A],
  641. 'transform_b': ComplexTransformTag[operation.A.complex_transform],
  642. 'alignment_b': str(operation.A.alignment),
  643. 'element_c': DataTypeTag[operation.C.element],
  644. 'layout_c': LayoutTag[operation.C.layout],
  645. 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
  646. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  647. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  648. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  649. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  650. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  651. 'warp_shape_m': str(warp_shape[0]),
  652. 'warp_shape_n': str(warp_shape[1]),
  653. 'warp_shape_k': str(warp_shape[2]),
  654. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  655. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  656. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  657. 'alignment_c': str(operation.C.alignment),
  658. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  659. 'stages': str(operation.tile_description.stages),
  660. 'math_operator': 'cutlass::arch::OpMultiplyAdd'
  661. }
  662. return SubstituteTemplate(self.template, values)
  663. #
  664. class EmitGemmSplitKParallelInstance:
  665. ''' Responsible for emitting a CUTLASS template definition'''
  666. def __init__(self):
  667. self.template = """
  668. // Gemm operator ${operation_name}
  669. using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel<
  670. ${element_a}, ${layout_a},
  671. ${element_b}, ${layout_b},
  672. ${element_c}, ${layout_c},
  673. ${element_accumulator},
  674. ${opcode_class},
  675. ${arch},
  676. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  677. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  678. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  679. ${epilogue_functor}<
  680. ${element_c},
  681. ${epilogue_vector_length},
  682. ${element_accumulator},
  683. ${element_epilogue}
  684. >
  685. >;
  686. """
  687. def emit(self, operation):
  688. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  689. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  690. values = {
  691. 'operation_name': operation.procedural_name(),
  692. 'element_a': DataTypeTag[operation.A.element],
  693. 'layout_a': LayoutTag[operation.A.layout],
  694. 'element_b': DataTypeTag[operation.B.element],
  695. 'layout_b': LayoutTag[operation.B.layout],
  696. 'element_c': DataTypeTag[operation.C.element],
  697. 'layout_c': LayoutTag[operation.C.layout],
  698. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  699. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  700. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  701. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  702. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  703. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  704. 'warp_shape_m': str(warp_shape[0]),
  705. 'warp_shape_n': str(warp_shape[1]),
  706. 'warp_shape_k': str(warp_shape[2]),
  707. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  708. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  709. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  710. 'epilogue_vector_length': str(epilogue_vector_length),
  711. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  712. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  713. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  714. }
  715. return SubstituteTemplate(self.template, values)
  716. ###################################################################################################
  717. ###################################################################################################
  718. #
  719. # Emitters functions for all targets
  720. #
  721. ###################################################################################################
  722. class EmitGemmConfigurationLibrary:
  723. def __init__(self, operation_path, configuration_name):
  724. self.configuration_name = configuration_name
  725. self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
  726. self.instance_emitter = {
  727. GemmKind.Gemm: EmitGemmInstance,
  728. GemmKind.Sparse: EmitSparseGemmInstance,
  729. GemmKind.Universal: EmitGemmUniversalInstance,
  730. GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
  731. GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
  732. }
  733. self.gemm_kind_wrappers = {
  734. GemmKind.Gemm: 'GemmOperation',
  735. GemmKind.Sparse: 'GemmSparseOperation',
  736. GemmKind.Universal: 'GemmUniversalOperation',
  737. GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
  738. GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
  739. }
  740. self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
  741. self.instance_template = {
  742. GemmKind.Gemm: """
  743. ${compile_guard_start}
  744. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  745. ${compile_guard_end}
  746. """,
  747. GemmKind.Sparse: """
  748. ${compile_guard_start}
  749. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  750. ${compile_guard_end}
  751. """,
  752. GemmKind.Universal: """
  753. ${compile_guard_start}
  754. manifest.append(new ${gemm_kind}<
  755. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  756. >("${operation_name}"));
  757. ${compile_guard_end}
  758. """,
  759. GemmKind.PlanarComplex: """
  760. ${compile_guard_start}
  761. manifest.append(new ${gemm_kind}<
  762. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  763. >("${operation_name}"));
  764. ${compile_guard_end}
  765. """,
  766. GemmKind.PlanarComplexArray: """
  767. ${compile_guard_start}
  768. manifest.append(new ${gemm_kind}<
  769. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  770. >("${operation_name}"));
  771. ${compile_guard_end}
  772. """
  773. }
  774. self.header_template = """
  775. /*
  776. Generated by gemm_operation.py - Do not edit.
  777. */
  778. ///////////////////////////////////////////////////////////////////////////////////////////////////
  779. #include "cutlass/arch/wmma.h"
  780. #include "cutlass/cutlass.h"
  781. #include "cutlass/library/library.h"
  782. #include "cutlass/library/manifest.h"
  783. #include "library_internal.h"
  784. #include "gemm_operation.h"
  785. ///////////////////////////////////////////////////////////////////////////////////////////////////
  786. """
  787. self.initialize_function_template = """
  788. ///////////////////////////////////////////////////////////////////////////////////////////////////
  789. namespace cutlass {
  790. namespace library {
  791. ///////////////////////////////////////////////////////////////////////////////////////////////////
  792. void initialize_${configuration_name}(Manifest &manifest) {
  793. """
  794. self.epilogue_template = """
  795. }
  796. ///////////////////////////////////////////////////////////////////////////////////////////////////
  797. } // namespace library
  798. } // namespace cutlass
  799. ///////////////////////////////////////////////////////////////////////////////////////////////////
  800. """
  801. def __enter__(self):
  802. self.configuration_file = open(self.configuration_path, "w")
  803. self.configuration_file.write(self.header_template)
  804. self.instance_definitions = []
  805. self.instance_wrappers = []
  806. self.operations = []
  807. return self
  808. def emit(self, operation):
  809. emitter = self.instance_emitter[operation.gemm_kind]()
  810. self.operations.append(operation)
  811. self.instance_definitions.append(emitter.emit(operation))
  812. self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
  813. 'configuration_name': self.configuration_name,
  814. 'operation_name': operation.procedural_name(),
  815. 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
  816. 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
  817. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
  818. 'compile_guard_end': "#endif" \
  819. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
  820. }))
  821. def __exit__(self, exception_type, exception_value, traceback):
  822. # Write instance definitions in top-level namespace
  823. for instance_definition in self.instance_definitions:
  824. self.configuration_file.write(instance_definition)
  825. # Add wrapper objects within initialize() function
  826. self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
  827. 'configuration_name': self.configuration_name
  828. }))
  829. for instance_wrapper in self.instance_wrappers:
  830. self.configuration_file.write(instance_wrapper)
  831. self.configuration_file.write(self.epilogue_template)
  832. self.configuration_file.close()
  833. ###################################################################################################
  834. ###################################################################################################
  835. class EmitGemmSingleKernelWrapper:
  836. def __init__(self, kernel_path, gemm_operation, short_path=False):
  837. self.short_path = short_path
  838. self.kernel_path = kernel_path
  839. self.operation = gemm_operation
  840. instance_emitters = {
  841. GemmKind.Gemm: EmitGemmInstance(),
  842. GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
  843. }
  844. self.instance_emitter = instance_emitters[self.operation.gemm_kind]
  845. self.header_template = """
  846. #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
  847. // ignore warning of cutlass
  848. #pragma GCC diagnostic push
  849. #pragma GCC diagnostic ignored "-Wunused-parameter"
  850. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  851. #pragma GCC diagnostic ignored "-Wuninitialized"
  852. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  853. #include "cutlass/gemm/device/gemm.h"
  854. #include "cutlass/gemm/device/gemm_splitk_parallel.h"
  855. #include "src/cuda/cutlass/manifest.h"
  856. #include "src/cuda/cutlass/gemm_operation.h"
  857. """
  858. self.instance_template = """
  859. ${operation_instance}
  860. """
  861. self.manifest_template = """
  862. namespace cutlass {
  863. namespace library {
  864. void initialize_${operation_name}(Manifest &manifest) {
  865. manifest.append(new GemmOperation<
  866. Operation_${operation_name}
  867. >("${operation_name}"));
  868. }
  869. } // namespace library
  870. } // namespace cutlass
  871. """
  872. self.epilogue_template = """
  873. #pragma GCC diagnostic pop
  874. #endif
  875. """
  876. #
  877. def __enter__(self):
  878. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  879. self.kernel_file = open(self.kernel_path, "w")
  880. self.kernel_file.write(self.header_template)
  881. return self
  882. #
  883. def emit(self):
  884. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  885. 'operation_instance': self.instance_emitter.emit(self.operation),
  886. }))
  887. # emit manifest helper
  888. manifest = SubstituteTemplate(self.manifest_template, {
  889. 'operation_name': self.operation.procedural_name(),
  890. })
  891. self.kernel_file.write(manifest)
  892. #
  893. def __exit__(self, exception_type, exception_value, traceback):
  894. self.kernel_file.write(self.epilogue_template)
  895. self.kernel_file.close()
  896. ###################################################################################################
  897. ###################################################################################################
  898. class EmitGemvSingleKernelWrapper:
  899. def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
  900. self.kernel_path = kernel_path
  901. self.wrapper_path = wrapper_path
  902. self.operation = gemm_operation
  903. self.short_path = short_path
  904. self.wrapper_template = """
  905. template void megdnn::cuda::cutlass_wrapper::
  906. cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
  907. BatchedGemmCoord const& problem_size,
  908. const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
  909. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
  910. typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
  911. cudaStream_t stream);
  912. """
  913. self.instance_emitter = EmitGemvBatchedStridedInstance()
  914. self.header_template = """
  915. #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
  916. // ignore warning of cutlass
  917. #pragma GCC diagnostic push
  918. #pragma GCC diagnostic ignored "-Wunused-parameter"
  919. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  920. #pragma GCC diagnostic ignored "-Wuninitialized"
  921. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  922. #include "${wrapper_path}"
  923. """
  924. self.instance_template = """
  925. ${operation_instance}
  926. """
  927. self.epilogue_template = """
  928. #pragma GCC diagnostic pop
  929. #endif
  930. """
  931. #
  932. def __enter__(self):
  933. if self.short_path:
  934. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  935. GlobalCnt.cnt += 1
  936. else:
  937. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  938. self.kernel_file = open(self.kernel_path, "w")
  939. self.kernel_file.write(SubstituteTemplate(self.header_template, {
  940. 'wrapper_path': self.wrapper_path,
  941. }))
  942. return self
  943. #
  944. def emit(self):
  945. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  946. 'operation_instance': self.instance_emitter.emit(self.operation),
  947. }))
  948. # emit wrapper
  949. wrapper = SubstituteTemplate(self.wrapper_template, {
  950. 'operation_name': self.operation.procedural_name(),
  951. })
  952. self.kernel_file.write(wrapper)
  953. #
  954. def __exit__(self, exception_type, exception_value, traceback):
  955. self.kernel_file.write(self.epilogue_template)
  956. self.kernel_file.close()
  957. ###################################################################################################
  958. ###################################################################################################

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台