You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_operation.py 46 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160
  1. #
  2. # \file generator.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import os.path
  8. import shutil
  9. import functools
  10. import operator
  11. from library import *
  12. ###################################################################################################
  13. #
  14. # Data structure modeling a GEMM operation
  15. #
  16. ###################################################################################################
  17. #
  18. class GemmOperation:
  19. #
  20. def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
  21. epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8):
  22. self.operation_kind = OperationKind.Gemm
  23. self.arch = arch
  24. self.tile_description = tile_description
  25. self.gemm_kind = gemm_kind
  26. self.A = A
  27. self.B = B
  28. self.C = C
  29. self.element_epilogue = element_epilogue
  30. self.epilogue_functor = epilogue_functor
  31. self.swizzling_functor = swizzling_functor
  32. #
  33. def is_complex(self):
  34. complex_operators = [
  35. MathOperation.multiply_add_complex,
  36. MathOperation.multiply_add_complex_gaussian
  37. ]
  38. return self.tile_description.math_instruction.math_operation in complex_operators
  39. #
  40. def is_split_k_parallel(self):
  41. return self.gemm_kind == GemmKind.SplitKParallel
  42. #
  43. def is_planar_complex(self):
  44. return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  45. #
  46. def accumulator_type(self):
  47. accum = self.tile_description.math_instruction.element_accumulator
  48. if self.is_complex():
  49. return get_complex_from_real(accum)
  50. return accum
  51. #
  52. def short_math_name(self):
  53. if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
  54. return "g%s" % ShortDataTypeNames[self.accumulator_type()]
  55. return ShortDataTypeNames[self.accumulator_type()]
  56. #
  57. def core_name(self):
  58. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  59. inst_shape = ''
  60. inst_operation = ''
  61. intermediate_type = ''
  62. math_operations_map = {
  63. MathOperation.xor_popc: 'xor',
  64. }
  65. if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
  66. self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
  67. math_op = self.tile_description.math_instruction.math_operation
  68. math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
  69. inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
  70. inst_shape += math_op_string
  71. if self.tile_description.math_instruction.element_a != self.A.element and \
  72. self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
  73. intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
  74. return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
  75. #
  76. def extended_name(self):
  77. ''' Append data types if they differ from compute type. '''
  78. if self.is_complex():
  79. extended_name = "${core_name}"
  80. else:
  81. if self.C.element != self.tile_description.math_instruction.element_accumulator and \
  82. self.A.element != self.tile_description.math_instruction.element_accumulator:
  83. extended_name = "${element_c}_${core_name}_${element_a}"
  84. elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
  85. self.A.element != self.tile_description.math_instruction.element_accumulator:
  86. extended_name = "${core_name}_${element_a}"
  87. else:
  88. extended_name = "${core_name}"
  89. extended_name = SubstituteTemplate(extended_name, {
  90. 'element_a': DataTypeNames[self.A.element],
  91. 'element_c': DataTypeNames[self.C.element],
  92. 'core_name': self.core_name()
  93. })
  94. return extended_name
  95. #
  96. def layout_name(self):
  97. if self.is_complex() or self.is_planar_complex():
  98. return "%s%s" % (
  99. ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
  100. ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
  101. )
  102. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  103. #
  104. def procedural_name(self):
  105. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  106. threadblock = self.tile_description.procedural_name()
  107. opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
  108. alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
  109. return SubstituteTemplate(
  110. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
  111. {
  112. 'opcode_class': opcode_class_name,
  113. 'extended_name': self.extended_name(),
  114. 'threadblock': threadblock,
  115. 'layout': self.layout_name(),
  116. 'alignment': "%d" % self.A.alignment,
  117. }
  118. )
  119. #
  120. def configuration_name(self):
  121. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  122. return self.procedural_name()
  123. ###################################################################################################
  124. #
  125. # Data structure modeling a GEMV Batched Strided operation
  126. #
  127. ###################################################################################################
  128. #
  129. class GemvBatchedStridedOperation:
  130. #
  131. def __init__(self, gemm_kind, arch, math_inst, threadblock_shape, thread_shape, A, B, C):
  132. self.operation_kind = OperationKind.Gemm
  133. self.arch = arch
  134. self.gemm_kind = gemm_kind
  135. self.math_instruction = math_inst
  136. self.threadblock_shape = threadblock_shape
  137. self.thread_shape = thread_shape
  138. self.A = A
  139. self.B = B
  140. self.C = C
  141. #
  142. def accumulator_type(self):
  143. accum = self.math_instruction.element_accumulator
  144. return accum
  145. #
  146. def short_math_name(self):
  147. return ShortDataTypeNames[self.accumulator_type()]
  148. #
  149. def core_name(self):
  150. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  151. return "%s%s" % (self.short_math_name(), \
  152. GemmKindNames[self.gemm_kind])
  153. #
  154. def extended_name(self):
  155. ''' Append data types if they differ from compute type. '''
  156. if self.C.element != self.math_instruction.element_accumulator and \
  157. self.A.element != self.math_instruction.element_accumulator:
  158. extended_name = "${element_c}_${core_name}_${element_a}"
  159. elif self.C.element == self.math_instruction.element_accumulator and \
  160. self.A.element != self.math_instruction.element_accumulator:
  161. extended_name = "${core_name}_${element_a}"
  162. else:
  163. extended_name = "${core_name}"
  164. extended_name = SubstituteTemplate(extended_name, {
  165. 'element_a': DataTypeNames[self.A.element],
  166. 'element_c': DataTypeNames[self.C.element],
  167. 'core_name': self.core_name()
  168. })
  169. return extended_name
  170. #
  171. def layout_name(self):
  172. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  173. #
  174. def procedural_name(self):
  175. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  176. threadblock = "%dx%d_%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2])
  177. opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class]
  178. alignment_a = self.A.alignment
  179. alignment_b = self.B.alignment
  180. return SubstituteTemplate(
  181. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}",
  182. {
  183. 'opcode_class': opcode_class_name,
  184. 'extended_name': self.extended_name(),
  185. 'threadblock': threadblock,
  186. 'layout': self.layout_name(),
  187. 'alignment_a': "%d" % alignment_a,
  188. 'alignment_b': "%d" % alignment_b,
  189. }
  190. )
  191. #
  192. def configuration_name(self):
  193. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  194. return self.procedural_name()
  195. #
  196. def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a = 32, align_b = 32, align_c = 32):
  197. operations = []
  198. swizzling_functor = SwizzlingFunctor.Identity1
  199. element_a, element_b, element_c, element_epilogue = data_type
  200. if tile.math_instruction.element_accumulator == DataType.s32:
  201. epilogues = [EpilogueFunctor.LinearCombinationClamp]
  202. else:
  203. assert tile.math_instruction.element_accumulator == DataType.f32 or \
  204. tile.math_instruction.element_accumulator == DataType.f16
  205. epilogues = [EpilogueFunctor.LinearCombination]
  206. for epilogue in epilogues:
  207. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  208. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  209. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  210. operations.append(GemmOperation(GemmKind.Gemm, min_cc, tile, A, B, C, \
  211. element_epilogue, epilogue, swizzling_functor))
  212. operations.append(GemmOperation(GemmKind.SplitKParallel, min_cc, tile, A, B, C, \
  213. element_epilogue, epilogue, swizzling_functor))
  214. return operations
  215. def GeneratesGemv(math_inst, threadblock_shape, thread_shape, data_type, layout_a, layout_b, layout_c, min_cc, \
  216. align_a = 32, align_b = 32, align_c = 32):
  217. element_a, element_b, element_c, element_epilogue = data_type
  218. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  219. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  220. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  221. return GemvBatchedStridedOperation(GemmKind.GemvBatchedStrided, min_cc, math_inst, threadblock_shape, thread_shape, \
  222. A, B, C)
  223. ###################################################################################################
  224. #
  225. # Emits single instances of a CUTLASS device-wide operator
  226. #
  227. ###################################################################################################
  228. #
  229. class EmitGemmInstance:
  230. ''' Responsible for emitting a CUTLASS template definition'''
  231. def __init__(self):
  232. self.gemm_template = """
  233. // Gemm operator ${operation_name}
  234. using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  235. ${element_a}, ${layout_a},
  236. ${element_b}, ${layout_b},
  237. ${element_c}, ${layout_c},
  238. ${element_accumulator},
  239. ${opcode_class},
  240. ${arch},
  241. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  242. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  243. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  244. ${epilogue_functor}<
  245. ${element_c},
  246. ${epilogue_vector_length},
  247. ${element_accumulator},
  248. ${element_epilogue}
  249. >,
  250. ${swizzling_functor},
  251. ${stages},
  252. ${align_a},
  253. ${align_b},
  254. false,
  255. ${math_operation}
  256. ${residual}
  257. >;
  258. """
  259. self.gemm_complex_template = """
  260. // Gemm operator ${operation_name}
  261. using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  262. ${element_a}, ${layout_a},
  263. ${element_b}, ${layout_b},
  264. ${element_c}, ${layout_c},
  265. ${element_accumulator},
  266. ${opcode_class},
  267. ${arch},
  268. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  269. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  270. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  271. ${epilogue_functor}<
  272. ${element_c},
  273. ${epilogue_vector_length},
  274. ${element_accumulator},
  275. ${element_epilogue}
  276. >,
  277. ${swizzling_functor},
  278. ${stages},
  279. ${transform_a},
  280. ${transform_b},
  281. ${math_operation}
  282. ${residual}
  283. >;
  284. """
  285. def emit(self, operation):
  286. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  287. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  288. residual = ''
  289. values = {
  290. 'operation_name': operation.procedural_name(),
  291. 'element_a': DataTypeTag[operation.A.element],
  292. 'layout_a': LayoutTag[operation.A.layout],
  293. 'element_b': DataTypeTag[operation.B.element],
  294. 'layout_b': LayoutTag[operation.B.layout],
  295. 'element_c': DataTypeTag[operation.C.element],
  296. 'layout_c': LayoutTag[operation.C.layout],
  297. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  298. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  299. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  300. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  301. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  302. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  303. 'warp_shape_m': str(warp_shape[0]),
  304. 'warp_shape_n': str(warp_shape[1]),
  305. 'warp_shape_k': str(warp_shape[2]),
  306. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  307. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  308. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  309. 'epilogue_vector_length': str(epilogue_vector_length),
  310. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  311. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  312. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  313. 'stages': str(operation.tile_description.stages),
  314. 'align_a': str(operation.A.alignment),
  315. 'align_b': str(operation.B.alignment),
  316. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  317. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  318. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  319. 'residual': residual
  320. }
  321. template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
  322. return SubstituteTemplate(template, values)
  323. #
  324. class EmitGemvBatchedStridedInstance:
  325. ''' Responsible for emitting a CUTLASS template definition'''
  326. def __init__(self):
  327. self.template = """
  328. // Gemm operator ${operation_name}
  329. using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv<
  330. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  331. cutlass::gemm::GemmShape<${thread_shape_m}, ${thread_shape_n}, ${thread_shape_k}>,
  332. ${element_a}, ${layout_a},
  333. ${element_b}, ${layout_b},
  334. ${element_c}, ${layout_c}
  335. >;
  336. """
  337. def emit(self, operation):
  338. values = {
  339. 'operation_name': operation.procedural_name(),
  340. 'element_a': DataTypeTag[operation.A.element],
  341. 'layout_a': LayoutTag[operation.A.layout],
  342. 'element_b': DataTypeTag[operation.B.element],
  343. 'layout_b': LayoutTag[operation.B.layout],
  344. 'element_c': DataTypeTag[operation.C.element],
  345. 'layout_c': LayoutTag[operation.C.layout],
  346. 'threadblock_shape_m': str(operation.threadblock_shape[0]),
  347. 'threadblock_shape_n': str(operation.threadblock_shape[1]),
  348. 'threadblock_shape_k': str(operation.threadblock_shape[2]),
  349. 'thread_shape_m': str(operation.thread_shape[0]),
  350. 'thread_shape_n': str(operation.thread_shape[1]),
  351. 'thread_shape_k': str(operation.thread_shape[2]),
  352. }
  353. return SubstituteTemplate(self.template, values)
  354. ###################################################################################################
  355. class EmitSparseGemmInstance:
  356. ''' Responsible for emitting a CUTLASS template definition'''
  357. def __init__(self):
  358. self.gemm_template = """
  359. // Gemm operator ${operation_name}
  360. using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  361. ${element_a}, ${layout_a},
  362. ${element_b}, ${layout_b},
  363. ${element_c}, ${layout_c},
  364. ${element_accumulator},
  365. ${opcode_class},
  366. ${arch},
  367. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  368. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  369. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  370. ${epilogue_functor}<
  371. ${element_c},
  372. ${epilogue_vector_length},
  373. ${element_accumulator},
  374. ${element_epilogue}
  375. >,
  376. ${swizzling_functor},
  377. ${stages},
  378. ${align_a},
  379. ${align_b},
  380. false,
  381. ${math_operation}
  382. ${residual}
  383. >;
  384. """
  385. def emit(self, operation):
  386. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  387. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  388. residual = ''
  389. values = {
  390. 'operation_name': operation.procedural_name(),
  391. 'element_a': DataTypeTag[operation.A.element],
  392. 'layout_a': LayoutTag[operation.A.layout],
  393. 'element_b': DataTypeTag[operation.B.element],
  394. 'layout_b': LayoutTag[operation.B.layout],
  395. 'element_c': DataTypeTag[operation.C.element],
  396. 'layout_c': LayoutTag[operation.C.layout],
  397. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  398. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  399. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  400. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  401. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  402. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  403. 'warp_shape_m': str(warp_shape[0]),
  404. 'warp_shape_n': str(warp_shape[1]),
  405. 'warp_shape_k': str(warp_shape[2]),
  406. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  407. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  408. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  409. 'epilogue_vector_length': str(epilogue_vector_length),
  410. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  411. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  412. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  413. 'stages': str(operation.tile_description.stages),
  414. 'align_a': str(operation.A.alignment),
  415. 'align_b': str(operation.B.alignment),
  416. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  417. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  418. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  419. 'residual': residual
  420. }
  421. template = self.gemm_template
  422. return SubstituteTemplate(template, values)
  423. ###################################################################################################
  424. #
  425. class EmitGemmUniversalInstance:
  426. ''' Responsible for emitting a CUTLASS template definition'''
  427. def __init__(self):
  428. self.gemm_template = """
  429. // Gemm operator ${operation_name}
  430. using ${operation_name}_base =
  431. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  432. ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
  433. ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
  434. ${element_c}, ${layout_c},
  435. ${element_accumulator},
  436. ${opcode_class},
  437. ${arch},
  438. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  439. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  440. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  441. ${epilogue_functor}<
  442. ${element_c},
  443. ${epilogue_vector_length},
  444. ${element_accumulator},
  445. ${element_epilogue}
  446. >,
  447. ${swizzling_functor},
  448. ${stages},
  449. ${math_operation}
  450. >::GemmKernel;
  451. // Define named type
  452. struct ${operation_name} :
  453. public ${operation_name}_base { };
  454. """
  455. self.gemm_template_interleaved = """
  456. // Gemm operator ${operation_name}
  457. using ${operation_name}_base =
  458. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  459. ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
  460. ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
  461. ${element_c}, ${layout_c},
  462. ${element_accumulator},
  463. ${opcode_class},
  464. ${arch},
  465. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  466. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  467. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  468. ${epilogue_functor}<
  469. ${element_c},
  470. ${epilogue_vector_length},
  471. ${element_accumulator},
  472. ${element_epilogue}
  473. >,
  474. ${swizzling_functor},
  475. ${stages},
  476. ${math_operation}
  477. >::GemmKernel;
  478. // Define named type
  479. struct ${operation_name} :
  480. public ${operation_name}_base { };
  481. """
  482. def emit(self, operation):
  483. threadblock_shape = operation.tile_description.threadblock_shape
  484. warp_count = operation.tile_description.warp_count
  485. warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  486. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  487. transpose_layouts = {
  488. LayoutType.ColumnMajor: LayoutType.RowMajor,
  489. LayoutType.RowMajor: LayoutType.ColumnMajor
  490. }
  491. if operation.A.layout in transpose_layouts.keys() and \
  492. operation.B.layout in transpose_layouts.keys() and \
  493. operation.C.layout in transpose_layouts.keys():
  494. instance_layout_A = transpose_layouts[operation.A.layout]
  495. instance_layout_B = transpose_layouts[operation.B.layout]
  496. instance_layout_C = transpose_layouts[operation.C.layout]
  497. gemm_template = self.gemm_template
  498. else:
  499. instance_layout_A, instance_layout_B, instance_layout_C = \
  500. (operation.A.layout, operation.B.layout, operation.C.layout)
  501. gemm_template = self.gemm_template_interleaved
  502. #
  503. values = {
  504. 'operation_name': operation.procedural_name(),
  505. 'element_a': DataTypeTag[operation.A.element],
  506. 'layout_a': LayoutTag[instance_layout_A],
  507. 'element_b': DataTypeTag[operation.B.element],
  508. 'layout_b': LayoutTag[instance_layout_B],
  509. 'element_c': DataTypeTag[operation.C.element],
  510. 'layout_c': LayoutTag[instance_layout_C],
  511. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  512. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  513. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  514. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  515. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  516. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  517. 'warp_shape_m': str(warp_shape[0]),
  518. 'warp_shape_n': str(warp_shape[1]),
  519. 'warp_shape_k': str(warp_shape[2]),
  520. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  521. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  522. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  523. 'epilogue_vector_length': str(epilogue_vector_length),
  524. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  525. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  526. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  527. 'stages': str(operation.tile_description.stages),
  528. 'align_a': str(operation.A.alignment),
  529. 'align_b': str(operation.B.alignment),
  530. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  531. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  532. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
  533. }
  534. return SubstituteTemplate(gemm_template, values)
  535. ###################################################################################################
  536. #
  537. class EmitGemmPlanarComplexInstance:
  538. ''' Responsible for emitting a CUTLASS template definition'''
  539. def __init__(self):
  540. self.template = """
  541. // Gemm operator ${operation_name}
  542. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  543. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  544. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  545. ${element_c}, cutlass::layout::RowMajor,
  546. ${element_accumulator},
  547. ${opcode_class},
  548. ${arch},
  549. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  550. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  551. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  552. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  553. ${element_c},
  554. ${alignment_c},
  555. ${element_accumulator},
  556. ${element_epilogue}
  557. >,
  558. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  559. ${stages},
  560. ${math_operator}
  561. >::GemmKernel;
  562. struct ${operation_name} :
  563. public Operation_${operation_name} { };
  564. """
  def emit(self, operation):
    """Render the planar-complex kernel declaration for one GEMM operation.

    The emitted kernel hard-codes ``cutlass::layout::RowMajor`` for C, so the
    caller's B operand is substituted as the kernel's A operand (and vice
    versa), each with its layout replaced by ``TransposedLayout[...]``.

    Args:
      operation: GEMM operation description providing operands A/B/C, the
        tile description, ``arch`` and ``element_epilogue``.

    Returns:
      The C++ text produced by substituting the values into ``self.template``.
    """
    tile = operation.tile_description
    inst = tile.math_instruction
    tb_shape = tile.threadblock_shape
    # Per-warp tile extent = threadblock extent / number of warps, per axis.
    warp_shape = [tb_shape[dim] // tile.warp_count[dim] for dim in range(3)]

    # Swap the operands: kernel A <- caller B, kernel B <- caller A.
    a_layout_t = TransposedLayout[operation.A.layout]
    b_layout_t = TransposedLayout[operation.B.layout]
    kernel_a, kernel_b = operation.B, operation.A

    values = dict(
      operation_name=operation.procedural_name(),
      element_a=DataTypeTag[kernel_a.element],
      layout_a=LayoutTag[b_layout_t],
      transform_a=ComplexTransformTag[kernel_a.complex_transform],
      alignment_a=str(kernel_a.alignment),
      element_b=DataTypeTag[kernel_b.element],
      layout_b=LayoutTag[a_layout_t],
      transform_b=ComplexTransformTag[kernel_b.complex_transform],
      alignment_b=str(kernel_b.alignment),
      element_c=DataTypeTag[operation.C.element],
      layout_c=LayoutTag[operation.C.layout],
      element_accumulator=DataTypeTag[inst.element_accumulator],
      opcode_class=OpcodeClassTag[inst.opcode_class],
      arch="cutlass::arch::Sm%d" % operation.arch,
      threadblock_shape_m=str(tb_shape[0]),
      threadblock_shape_n=str(tb_shape[1]),
      threadblock_shape_k=str(tb_shape[2]),
      warp_shape_m=str(warp_shape[0]),
      warp_shape_n=str(warp_shape[1]),
      warp_shape_k=str(warp_shape[2]),
      instruction_shape_m=str(inst.instruction_shape[0]),
      instruction_shape_n=str(inst.instruction_shape[1]),
      instruction_shape_k=str(inst.instruction_shape[2]),
      alignment_c=str(operation.C.alignment),
      element_epilogue=str(DataTypeTag[operation.element_epilogue]),
      stages=str(tile.stages),
      math_operator='cutlass::arch::OpMultiplyAdd',
    )
    return SubstituteTemplate(self.template, values)
  600. ###################################################################################################
  601. #
  602. class EmitGemmPlanarComplexArrayInstance:
  603. ''' Responsible for emitting a CUTLASS template definition'''
  604. def __init__(self):
  605. self.template = """
  606. // Gemm operator ${operation_name}
  607. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  608. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  609. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  610. ${element_c}, cutlass::layout::RowMajor,
  611. ${element_accumulator},
  612. ${opcode_class},
  613. ${arch},
  614. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  615. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  616. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  617. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  618. ${element_c},
  619. ${alignment_c},
  620. ${element_accumulator},
  621. ${element_epilogue}
  622. >,
  623. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  624. ${stages},
  625. ${math_operator}
  626. >::GemmArrayKernel;
  627. struct ${operation_name} : public Operation_${operation_name} { };
  628. """
  629. def emit(self, operation):
  630. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  631. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  632. transposed_layout_A = TransposedLayout[operation.A.layout]
  633. transposed_layout_B = TransposedLayout[operation.B.layout]
  634. values = {
  635. 'operation_name': operation.procedural_name(),
  636. 'element_a': DataTypeTag[operation.B.element],
  637. 'layout_a': LayoutTag[transposed_layout_B],
  638. 'transform_a': ComplexTransformTag[operation.B.complex_transform],
  639. 'alignment_a': str(operation.B.alignment),
  640. 'element_b': DataTypeTag[operation.A.element],
  641. 'layout_b': LayoutTag[transposed_layout_A],
  642. 'transform_b': ComplexTransformTag[operation.A.complex_transform],
  643. 'alignment_b': str(operation.A.alignment),
  644. 'element_c': DataTypeTag[operation.C.element],
  645. 'layout_c': LayoutTag[operation.C.layout],
  646. 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
  647. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  648. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  649. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  650. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  651. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  652. 'warp_shape_m': str(warp_shape[0]),
  653. 'warp_shape_n': str(warp_shape[1]),
  654. 'warp_shape_k': str(warp_shape[2]),
  655. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  656. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  657. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  658. 'alignment_c': str(operation.C.alignment),
  659. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  660. 'stages': str(operation.tile_description.stages),
  661. 'math_operator': 'cutlass::arch::OpMultiplyAdd'
  662. }
  663. return SubstituteTemplate(self.template, values)
  664. #
  665. class EmitGemmSplitKParallelInstance:
  666. ''' Responsible for emitting a CUTLASS template definition'''
  667. def __init__(self):
  668. self.template = """
  669. // Gemm operator ${operation_name}
  670. using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel<
  671. ${element_a}, ${layout_a},
  672. ${element_b}, ${layout_b},
  673. ${element_c}, ${layout_c},
  674. ${element_accumulator},
  675. ${opcode_class},
  676. ${arch},
  677. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  678. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  679. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  680. ${epilogue_functor}<
  681. ${element_c},
  682. ${epilogue_vector_length},
  683. ${element_accumulator},
  684. ${element_epilogue}
  685. >,
  686. cutlass::epilogue::thread::Convert<
  687. ${element_accumulator},
  688. ${epilogue_vector_length},
  689. ${element_accumulator}
  690. >,
  691. cutlass::reduction::thread::ReduceAdd<
  692. ${element_accumulator},
  693. ${element_accumulator},
  694. ${epilogue_vector_length}
  695. >,
  696. cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
  697. ${stages},
  698. ${align_a},
  699. ${align_b},
  700. ${math_operation}
  701. >;
  702. """
  703. def emit(self, operation):
  704. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  705. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  706. values = {
  707. 'operation_name': operation.procedural_name(),
  708. 'element_a': DataTypeTag[operation.A.element],
  709. 'layout_a': LayoutTag[operation.A.layout],
  710. 'element_b': DataTypeTag[operation.B.element],
  711. 'layout_b': LayoutTag[operation.B.layout],
  712. 'element_c': DataTypeTag[operation.C.element],
  713. 'layout_c': LayoutTag[operation.C.layout],
  714. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  715. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  716. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  717. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  718. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  719. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  720. 'warp_shape_m': str(warp_shape[0]),
  721. 'warp_shape_n': str(warp_shape[1]),
  722. 'warp_shape_k': str(warp_shape[2]),
  723. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  724. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  725. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  726. 'epilogue_vector_length': str(epilogue_vector_length),
  727. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  728. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  729. 'stages': str(operation.tile_description.stages),
  730. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  731. 'align_a': str(operation.A.alignment),
  732. 'align_b': str(operation.B.alignment),
  733. }
  734. return SubstituteTemplate(self.template, values)
  735. ###################################################################################################
  736. ###################################################################################################
  737. #
  738. # Emitters functions for all targets
  739. #
  740. ###################################################################################################
  741. class EmitGemmConfigurationLibrary:
  742. def __init__(self, operation_path, configuration_name):
  743. self.configuration_name = configuration_name
  744. self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
  745. self.instance_emitter = {
  746. GemmKind.Gemm: EmitGemmInstance,
  747. GemmKind.Sparse: EmitSparseGemmInstance,
  748. GemmKind.Universal: EmitGemmUniversalInstance,
  749. GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
  750. GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
  751. }
  752. self.gemm_kind_wrappers = {
  753. GemmKind.Gemm: 'GemmOperation',
  754. GemmKind.Sparse: 'GemmSparseOperation',
  755. GemmKind.Universal: 'GemmUniversalOperation',
  756. GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
  757. GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
  758. }
  759. self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
  760. self.instance_template = {
  761. GemmKind.Gemm: """
  762. ${compile_guard_start}
  763. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  764. ${compile_guard_end}
  765. """,
  766. GemmKind.Sparse: """
  767. ${compile_guard_start}
  768. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  769. ${compile_guard_end}
  770. """,
  771. GemmKind.Universal: """
  772. ${compile_guard_start}
  773. manifest.append(new ${gemm_kind}<
  774. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  775. >("${operation_name}"));
  776. ${compile_guard_end}
  777. """,
  778. GemmKind.PlanarComplex: """
  779. ${compile_guard_start}
  780. manifest.append(new ${gemm_kind}<
  781. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  782. >("${operation_name}"));
  783. ${compile_guard_end}
  784. """,
  785. GemmKind.PlanarComplexArray: """
  786. ${compile_guard_start}
  787. manifest.append(new ${gemm_kind}<
  788. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  789. >("${operation_name}"));
  790. ${compile_guard_end}
  791. """
  792. }
  793. self.header_template = """
  794. /*
  795. Generated by gemm_operation.py - Do not edit.
  796. */
  797. ///////////////////////////////////////////////////////////////////////////////////////////////////
  798. #include "cutlass/arch/wmma.h"
  799. #include "cutlass/cutlass.h"
  800. #include "cutlass/library/library.h"
  801. #include "cutlass/library/manifest.h"
  802. #include "library_internal.h"
  803. #include "gemm_operation.h"
  804. ///////////////////////////////////////////////////////////////////////////////////////////////////
  805. """
  806. self.initialize_function_template = """
  807. ///////////////////////////////////////////////////////////////////////////////////////////////////
  808. namespace cutlass {
  809. namespace library {
  810. ///////////////////////////////////////////////////////////////////////////////////////////////////
  811. void initialize_${configuration_name}(Manifest &manifest) {
  812. """
  813. self.epilogue_template = """
  814. }
  815. ///////////////////////////////////////////////////////////////////////////////////////////////////
  816. } // namespace library
  817. } // namespace cutlass
  818. ///////////////////////////////////////////////////////////////////////////////////////////////////
  819. """
  820. def __enter__(self):
  821. self.configuration_file = open(self.configuration_path, "w")
  822. self.configuration_file.write(self.header_template)
  823. self.instance_definitions = []
  824. self.instance_wrappers = []
  825. self.operations = []
  826. return self
  827. def emit(self, operation):
  828. emitter = self.instance_emitter[operation.gemm_kind]()
  829. self.operations.append(operation)
  830. self.instance_definitions.append(emitter.emit(operation))
  831. self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
  832. 'configuration_name': self.configuration_name,
  833. 'operation_name': operation.procedural_name(),
  834. 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
  835. 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
  836. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
  837. 'compile_guard_end': "#endif" \
  838. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
  839. }))
  840. def __exit__(self, exception_type, exception_value, traceback):
  841. # Write instance definitions in top-level namespace
  842. for instance_definition in self.instance_definitions:
  843. self.configuration_file.write(instance_definition)
  844. # Add wrapper objects within initialize() function
  845. self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
  846. 'configuration_name': self.configuration_name
  847. }))
  848. for instance_wrapper in self.instance_wrappers:
  849. self.configuration_file.write(instance_wrapper)
  850. self.configuration_file.write(self.epilogue_template)
  851. self.configuration_file.close()
  852. ###################################################################################################
  853. ###################################################################################################
  854. class EmitGemmSingleKernelWrapper:
  855. def __init__(self, kernel_path, gemm_operation, short_path=False):
  856. self.short_path = short_path
  857. self.kernel_path = kernel_path
  858. self.operation = gemm_operation
  859. instance_emitters = {
  860. GemmKind.Gemm: EmitGemmInstance(),
  861. GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
  862. }
  863. self.instance_emitter = instance_emitters[self.operation.gemm_kind]
  864. self.header_template = """
  865. #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
  866. // ignore warning of cutlass
  867. #pragma GCC diagnostic push
  868. #pragma GCC diagnostic ignored "-Wunused-parameter"
  869. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  870. #pragma GCC diagnostic ignored "-Wuninitialized"
  871. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  872. #include "cutlass/gemm/device/gemm.h"
  873. #include "cutlass/gemm/device/gemm_splitk_parallel.h"
  874. #include "src/cuda/cutlass/manifest.h"
  875. #include "src/cuda/cutlass/gemm_operation.h"
  876. """
  877. self.instance_template = """
  878. ${operation_instance}
  879. """
  880. self.manifest_template = """
  881. namespace cutlass {
  882. namespace library {
  883. void initialize_${operation_name}(Manifest &manifest) {
  884. manifest.append(new GemmOperation<
  885. Operation_${operation_name}
  886. >("${operation_name}"));
  887. }
  888. } // namespace library
  889. } // namespace cutlass
  890. """
  891. self.epilogue_template = """
  892. #pragma GCC diagnostic pop
  893. #endif
  894. """
  895. #
  896. def __enter__(self):
  897. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  898. self.kernel_file = open(self.kernel_path, "w")
  899. self.kernel_file.write(self.header_template)
  900. return self
  901. #
  902. def emit(self):
  903. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  904. 'operation_instance': self.instance_emitter.emit(self.operation),
  905. }))
  906. # emit manifest helper
  907. manifest = SubstituteTemplate(self.manifest_template, {
  908. 'operation_name': self.operation.procedural_name(),
  909. })
  910. self.kernel_file.write(manifest)
  911. #
  912. def __exit__(self, exception_type, exception_value, traceback):
  913. self.kernel_file.write(self.epilogue_template)
  914. self.kernel_file.close()
  915. ###################################################################################################
  916. ###################################################################################################
  917. class EmitGemvSingleKernelWrapper:
  918. def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
  919. self.kernel_path = kernel_path
  920. self.wrapper_path = wrapper_path
  921. self.operation = gemm_operation
  922. self.short_path = short_path
  923. self.wrapper_template = """
  924. template void megdnn::cuda::cutlass_wrapper::
  925. cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
  926. BatchedGemmCoord const& problem_size,
  927. const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
  928. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
  929. typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
  930. cudaStream_t stream);
  931. """
  932. self.instance_emitter = EmitGemvBatchedStridedInstance()
  933. self.header_template = """
  934. #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
  935. // ignore warning of cutlass
  936. #pragma GCC diagnostic push
  937. #pragma GCC diagnostic ignored "-Wunused-parameter"
  938. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  939. #pragma GCC diagnostic ignored "-Wuninitialized"
  940. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  941. #include "${wrapper_path}"
  942. """
  943. self.instance_template = """
  944. ${operation_instance}
  945. """
  946. self.epilogue_template = """
  947. #pragma GCC diagnostic pop
  948. #endif
  949. """
  950. #
  951. def __enter__(self):
  952. if self.short_path:
  953. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  954. GlobalCnt.cnt += 1
  955. else:
  956. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  957. self.kernel_file = open(self.kernel_path, "w")
  958. self.kernel_file.write(SubstituteTemplate(self.header_template, {
  959. 'wrapper_path': self.wrapper_path,
  960. }))
  961. return self
  962. #
  963. def emit(self):
  964. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  965. 'operation_instance': self.instance_emitter.emit(self.operation),
  966. }))
  967. # emit wrapper
  968. wrapper = SubstituteTemplate(self.wrapper_template, {
  969. 'operation_name': self.operation.procedural_name(),
  970. })
  971. self.kernel_file.write(wrapper)
  972. #
  973. def __exit__(self, exception_type, exception_value, traceback):
  974. self.kernel_file.write(self.epilogue_template)
  975. self.kernel_file.close()
  976. ###################################################################################################
  977. ###################################################################################################

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。