You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_operation.py 47 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175
  1. #
  2. # \file gemm_operation.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import os.path
  8. import shutil
  9. import functools
  10. import operator
  11. from library import *
  12. ###################################################################################################
  13. #
  14. # Data structure modeling a GEMM operation
  15. #
  16. ###################################################################################################
  17. #
  18. class GemmOperation:
  19. #
  20. def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
  21. epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
  22. required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  23. self.operation_kind = OperationKind.Gemm
  24. self.arch = arch
  25. self.tile_description = tile_description
  26. self.gemm_kind = gemm_kind
  27. self.A = A
  28. self.B = B
  29. self.C = C
  30. self.element_epilogue = element_epilogue
  31. self.epilogue_functor = epilogue_functor
  32. self.swizzling_functor = swizzling_functor
  33. self.required_cuda_ver_major = required_cuda_ver_major
  34. self.required_cuda_ver_minor = required_cuda_ver_minor
  35. #
  36. def is_complex(self):
  37. complex_operators = [
  38. MathOperation.multiply_add_complex,
  39. MathOperation.multiply_add_complex_gaussian
  40. ]
  41. return self.tile_description.math_instruction.math_operation in complex_operators
  42. #
  43. def is_split_k_parallel(self):
  44. return self.gemm_kind == GemmKind.SplitKParallel
  45. #
  46. def is_planar_complex(self):
  47. return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  48. #
  49. def accumulator_type(self):
  50. accum = self.tile_description.math_instruction.element_accumulator
  51. if self.is_complex():
  52. return get_complex_from_real(accum)
  53. return accum
  54. #
  55. def short_math_name(self):
  56. if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
  57. return "g%s" % ShortDataTypeNames[self.accumulator_type()]
  58. return ShortDataTypeNames[self.accumulator_type()]
  59. #
  60. def core_name(self):
  61. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  62. inst_shape = ''
  63. inst_operation = ''
  64. intermediate_type = ''
  65. math_operations_map = {
  66. MathOperation.xor_popc: 'xor',
  67. }
  68. if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
  69. self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
  70. math_op = self.tile_description.math_instruction.math_operation
  71. math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
  72. inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
  73. inst_shape += math_op_string
  74. if self.tile_description.math_instruction.element_a != self.A.element and \
  75. self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
  76. intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
  77. return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
  78. #
  79. def extended_name(self):
  80. ''' Append data types if they differ from compute type. '''
  81. if self.is_complex():
  82. extended_name = "${core_name}"
  83. else:
  84. if self.C.element != self.tile_description.math_instruction.element_accumulator and \
  85. self.A.element != self.tile_description.math_instruction.element_accumulator:
  86. extended_name = "${element_c}_${core_name}_${element_a}"
  87. elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
  88. self.A.element != self.tile_description.math_instruction.element_accumulator:
  89. extended_name = "${core_name}_${element_a}"
  90. else:
  91. extended_name = "${core_name}"
  92. extended_name = SubstituteTemplate(extended_name, {
  93. 'element_a': DataTypeNames[self.A.element],
  94. 'element_c': DataTypeNames[self.C.element],
  95. 'core_name': self.core_name()
  96. })
  97. return extended_name
  98. #
  99. def layout_name(self):
  100. if self.is_complex() or self.is_planar_complex():
  101. return "%s%s" % (
  102. ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
  103. ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
  104. )
  105. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  106. #
  107. def procedural_name(self):
  108. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  109. threadblock = self.tile_description.procedural_name()
  110. opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
  111. alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
  112. return SubstituteTemplate(
  113. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
  114. {
  115. 'opcode_class': opcode_class_name,
  116. 'extended_name': self.extended_name(),
  117. 'threadblock': threadblock,
  118. 'layout': self.layout_name(),
  119. 'alignment': "%d" % self.A.alignment,
  120. }
  121. )
  122. #
  123. def configuration_name(self):
  124. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  125. return self.procedural_name()
  126. ###################################################################################################
  127. #
  128. # Data structure modeling a GEMV Batched Strided operation
  129. #
  130. ###################################################################################################
  131. #
  132. class GemvBatchedStridedOperation:
  133. #
  134. def __init__(self, gemm_kind, arch, math_inst, threadblock_shape, thread_shape, A, B, C, \
  135. required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  136. self.operation_kind = OperationKind.Gemm
  137. self.arch = arch
  138. self.gemm_kind = gemm_kind
  139. self.math_instruction = math_inst
  140. self.threadblock_shape = threadblock_shape
  141. self.thread_shape = thread_shape
  142. self.A = A
  143. self.B = B
  144. self.C = C
  145. self.required_cuda_ver_major = required_cuda_ver_major
  146. self.required_cuda_ver_minor = required_cuda_ver_minor
  147. #
  148. def accumulator_type(self):
  149. accum = self.math_instruction.element_accumulator
  150. return accum
  151. #
  152. def short_math_name(self):
  153. return ShortDataTypeNames[self.accumulator_type()]
  154. #
  155. def core_name(self):
  156. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  157. return "%s%s" % (self.short_math_name(), \
  158. GemmKindNames[self.gemm_kind])
  159. #
  160. def extended_name(self):
  161. ''' Append data types if they differ from compute type. '''
  162. if self.C.element != self.math_instruction.element_accumulator and \
  163. self.A.element != self.math_instruction.element_accumulator:
  164. extended_name = "${element_c}_${core_name}_${element_a}"
  165. elif self.C.element == self.math_instruction.element_accumulator and \
  166. self.A.element != self.math_instruction.element_accumulator:
  167. extended_name = "${core_name}_${element_a}"
  168. else:
  169. extended_name = "${core_name}"
  170. extended_name = SubstituteTemplate(extended_name, {
  171. 'element_a': DataTypeNames[self.A.element],
  172. 'element_c': DataTypeNames[self.C.element],
  173. 'core_name': self.core_name()
  174. })
  175. return extended_name
  176. #
  177. def layout_name(self):
  178. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  179. #
  180. def procedural_name(self):
  181. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  182. threadblock = "%dx%d_%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2])
  183. opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class]
  184. alignment_a = self.A.alignment
  185. alignment_b = self.B.alignment
  186. return SubstituteTemplate(
  187. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}",
  188. {
  189. 'opcode_class': opcode_class_name,
  190. 'extended_name': self.extended_name(),
  191. 'threadblock': threadblock,
  192. 'layout': self.layout_name(),
  193. 'alignment_a': "%d" % alignment_a,
  194. 'alignment_b': "%d" % alignment_b,
  195. }
  196. )
  197. #
  198. def configuration_name(self):
  199. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  200. return self.procedural_name()
  201. #
  202. def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a = 32, align_b = 32, align_c = 32, required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  203. operations = []
  204. swizzling_functor = SwizzlingFunctor.Identity1
  205. element_a, element_b, element_c, element_epilogue = data_type
  206. if tile.math_instruction.element_accumulator == DataType.s32:
  207. epilogues = [EpilogueFunctor.LinearCombinationClamp]
  208. else:
  209. assert tile.math_instruction.element_accumulator == DataType.f32 or \
  210. tile.math_instruction.element_accumulator == DataType.f16
  211. epilogues = [EpilogueFunctor.LinearCombination]
  212. for epilogue in epilogues:
  213. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  214. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  215. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  216. operations.append(GemmOperation(GemmKind.Gemm, min_cc, tile, A, B, C, \
  217. element_epilogue, epilogue, swizzling_functor, \
  218. required_cuda_ver_major, required_cuda_ver_minor))
  219. operations.append(GemmOperation(GemmKind.SplitKParallel, min_cc, tile, A, B, C, \
  220. element_epilogue, epilogue, swizzling_functor, \
  221. required_cuda_ver_major, required_cuda_ver_minor))
  222. return operations
  223. def GeneratesGemv(math_inst, threadblock_shape, thread_shape, data_type, layout_a, layout_b, layout_c, min_cc, \
  224. align_a = 32, align_b = 32, align_c = 32, \
  225. required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  226. element_a, element_b, element_c, element_epilogue = data_type
  227. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  228. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  229. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  230. return GemvBatchedStridedOperation(GemmKind.GemvBatchedStrided, min_cc, math_inst, threadblock_shape, thread_shape, \
  231. A, B, C, required_cuda_ver_major, required_cuda_ver_minor)
  232. ###################################################################################################
  233. #
  234. # Emits single instances of a CUTLASS device-wide operator
  235. #
  236. ###################################################################################################
  237. #
  238. class EmitGemmInstance:
  239. ''' Responsible for emitting a CUTLASS template definition'''
  240. def __init__(self):
  241. self.gemm_template = """
  242. // Gemm operator ${operation_name}
  243. using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  244. ${element_a}, ${layout_a},
  245. ${element_b}, ${layout_b},
  246. ${element_c}, ${layout_c},
  247. ${element_accumulator},
  248. ${opcode_class},
  249. ${arch},
  250. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  251. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  252. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  253. ${epilogue_functor}<
  254. ${element_c},
  255. ${epilogue_vector_length},
  256. ${element_accumulator},
  257. ${element_epilogue}
  258. >,
  259. ${swizzling_functor},
  260. ${stages},
  261. ${align_a},
  262. ${align_b},
  263. false,
  264. ${math_operation}
  265. ${residual}
  266. >;
  267. """
  268. self.gemm_complex_template = """
  269. // Gemm operator ${operation_name}
  270. using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  271. ${element_a}, ${layout_a},
  272. ${element_b}, ${layout_b},
  273. ${element_c}, ${layout_c},
  274. ${element_accumulator},
  275. ${opcode_class},
  276. ${arch},
  277. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  278. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  279. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  280. ${epilogue_functor}<
  281. ${element_c},
  282. ${epilogue_vector_length},
  283. ${element_accumulator},
  284. ${element_epilogue}
  285. >,
  286. ${swizzling_functor},
  287. ${stages},
  288. ${transform_a},
  289. ${transform_b},
  290. ${math_operation}
  291. ${residual}
  292. >;
  293. """
  294. def emit(self, operation):
  295. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  296. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  297. residual = ''
  298. values = {
  299. 'operation_name': operation.procedural_name(),
  300. 'element_a': DataTypeTag[operation.A.element],
  301. 'layout_a': LayoutTag[operation.A.layout],
  302. 'element_b': DataTypeTag[operation.B.element],
  303. 'layout_b': LayoutTag[operation.B.layout],
  304. 'element_c': DataTypeTag[operation.C.element],
  305. 'layout_c': LayoutTag[operation.C.layout],
  306. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  307. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  308. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  309. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  310. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  311. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  312. 'warp_shape_m': str(warp_shape[0]),
  313. 'warp_shape_n': str(warp_shape[1]),
  314. 'warp_shape_k': str(warp_shape[2]),
  315. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  316. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  317. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  318. 'epilogue_vector_length': str(epilogue_vector_length),
  319. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  320. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  321. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  322. 'stages': str(operation.tile_description.stages),
  323. 'align_a': str(operation.A.alignment),
  324. 'align_b': str(operation.B.alignment),
  325. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  326. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  327. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  328. 'residual': residual
  329. }
  330. template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
  331. return SubstituteTemplate(template, values)
  332. #
  333. class EmitGemvBatchedStridedInstance:
  334. ''' Responsible for emitting a CUTLASS template definition'''
  335. def __init__(self):
  336. self.template = """
  337. // Gemm operator ${operation_name}
  338. using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv<
  339. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  340. cutlass::gemm::GemmShape<${thread_shape_m}, ${thread_shape_n}, ${thread_shape_k}>,
  341. ${element_a}, ${layout_a},
  342. ${element_b}, ${layout_b},
  343. ${element_c}, ${layout_c}
  344. >;
  345. """
  346. def emit(self, operation):
  347. values = {
  348. 'operation_name': operation.procedural_name(),
  349. 'element_a': DataTypeTag[operation.A.element],
  350. 'layout_a': LayoutTag[operation.A.layout],
  351. 'element_b': DataTypeTag[operation.B.element],
  352. 'layout_b': LayoutTag[operation.B.layout],
  353. 'element_c': DataTypeTag[operation.C.element],
  354. 'layout_c': LayoutTag[operation.C.layout],
  355. 'threadblock_shape_m': str(operation.threadblock_shape[0]),
  356. 'threadblock_shape_n': str(operation.threadblock_shape[1]),
  357. 'threadblock_shape_k': str(operation.threadblock_shape[2]),
  358. 'thread_shape_m': str(operation.thread_shape[0]),
  359. 'thread_shape_n': str(operation.thread_shape[1]),
  360. 'thread_shape_k': str(operation.thread_shape[2]),
  361. }
  362. return SubstituteTemplate(self.template, values)
  363. ###################################################################################################
  364. class EmitSparseGemmInstance:
  365. ''' Responsible for emitting a CUTLASS template definition'''
  366. def __init__(self):
  367. self.gemm_template = """
  368. // Gemm operator ${operation_name}
  369. using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  370. ${element_a}, ${layout_a},
  371. ${element_b}, ${layout_b},
  372. ${element_c}, ${layout_c},
  373. ${element_accumulator},
  374. ${opcode_class},
  375. ${arch},
  376. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  377. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  378. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  379. ${epilogue_functor}<
  380. ${element_c},
  381. ${epilogue_vector_length},
  382. ${element_accumulator},
  383. ${element_epilogue}
  384. >,
  385. ${swizzling_functor},
  386. ${stages},
  387. ${align_a},
  388. ${align_b},
  389. false,
  390. ${math_operation}
  391. ${residual}
  392. >;
  393. """
  394. def emit(self, operation):
  395. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  396. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  397. residual = ''
  398. values = {
  399. 'operation_name': operation.procedural_name(),
  400. 'element_a': DataTypeTag[operation.A.element],
  401. 'layout_a': LayoutTag[operation.A.layout],
  402. 'element_b': DataTypeTag[operation.B.element],
  403. 'layout_b': LayoutTag[operation.B.layout],
  404. 'element_c': DataTypeTag[operation.C.element],
  405. 'layout_c': LayoutTag[operation.C.layout],
  406. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  407. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  408. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  409. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  410. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  411. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  412. 'warp_shape_m': str(warp_shape[0]),
  413. 'warp_shape_n': str(warp_shape[1]),
  414. 'warp_shape_k': str(warp_shape[2]),
  415. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  416. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  417. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  418. 'epilogue_vector_length': str(epilogue_vector_length),
  419. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  420. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  421. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  422. 'stages': str(operation.tile_description.stages),
  423. 'align_a': str(operation.A.alignment),
  424. 'align_b': str(operation.B.alignment),
  425. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  426. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  427. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  428. 'residual': residual
  429. }
  430. template = self.gemm_template
  431. return SubstituteTemplate(template, values)
  432. ###################################################################################################
  433. #
  434. class EmitGemmUniversalInstance:
  435. ''' Responsible for emitting a CUTLASS template definition'''
  436. def __init__(self):
  437. self.gemm_template = """
  438. // Gemm operator ${operation_name}
  439. using ${operation_name}_base =
  440. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  441. ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
  442. ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
  443. ${element_c}, ${layout_c},
  444. ${element_accumulator},
  445. ${opcode_class},
  446. ${arch},
  447. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  448. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  449. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  450. ${epilogue_functor}<
  451. ${element_c},
  452. ${epilogue_vector_length},
  453. ${element_accumulator},
  454. ${element_epilogue}
  455. >,
  456. ${swizzling_functor},
  457. ${stages},
  458. ${math_operation}
  459. >::GemmKernel;
  460. // Define named type
  461. struct ${operation_name} :
  462. public ${operation_name}_base { };
  463. """
  464. self.gemm_template_interleaved = """
  465. // Gemm operator ${operation_name}
  466. using ${operation_name}_base =
  467. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  468. ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
  469. ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
  470. ${element_c}, ${layout_c},
  471. ${element_accumulator},
  472. ${opcode_class},
  473. ${arch},
  474. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  475. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  476. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  477. ${epilogue_functor}<
  478. ${element_c},
  479. ${epilogue_vector_length},
  480. ${element_accumulator},
  481. ${element_epilogue}
  482. >,
  483. ${swizzling_functor},
  484. ${stages},
  485. ${math_operation}
  486. >::GemmKernel;
  487. // Define named type
  488. struct ${operation_name} :
  489. public ${operation_name}_base { };
  490. """
  491. def emit(self, operation):
  492. threadblock_shape = operation.tile_description.threadblock_shape
  493. warp_count = operation.tile_description.warp_count
  494. warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  495. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  496. transpose_layouts = {
  497. LayoutType.ColumnMajor: LayoutType.RowMajor,
  498. LayoutType.RowMajor: LayoutType.ColumnMajor
  499. }
  500. if operation.A.layout in transpose_layouts.keys() and \
  501. operation.B.layout in transpose_layouts.keys() and \
  502. operation.C.layout in transpose_layouts.keys():
  503. instance_layout_A = transpose_layouts[operation.A.layout]
  504. instance_layout_B = transpose_layouts[operation.B.layout]
  505. instance_layout_C = transpose_layouts[operation.C.layout]
  506. gemm_template = self.gemm_template
  507. else:
  508. instance_layout_A, instance_layout_B, instance_layout_C = \
  509. (operation.A.layout, operation.B.layout, operation.C.layout)
  510. gemm_template = self.gemm_template_interleaved
  511. #
  512. values = {
  513. 'operation_name': operation.procedural_name(),
  514. 'element_a': DataTypeTag[operation.A.element],
  515. 'layout_a': LayoutTag[instance_layout_A],
  516. 'element_b': DataTypeTag[operation.B.element],
  517. 'layout_b': LayoutTag[instance_layout_B],
  518. 'element_c': DataTypeTag[operation.C.element],
  519. 'layout_c': LayoutTag[instance_layout_C],
  520. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  521. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  522. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  523. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  524. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  525. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  526. 'warp_shape_m': str(warp_shape[0]),
  527. 'warp_shape_n': str(warp_shape[1]),
  528. 'warp_shape_k': str(warp_shape[2]),
  529. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  530. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  531. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  532. 'epilogue_vector_length': str(epilogue_vector_length),
  533. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  534. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  535. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  536. 'stages': str(operation.tile_description.stages),
  537. 'align_a': str(operation.A.alignment),
  538. 'align_b': str(operation.B.alignment),
  539. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  540. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  541. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
  542. }
  543. return SubstituteTemplate(gemm_template, values)
  544. ###################################################################################################
  545. #
  546. class EmitGemmPlanarComplexInstance:
  547. ''' Responsible for emitting a CUTLASS template definition'''
  548. def __init__(self):
  549. self.template = """
  550. // Gemm operator ${operation_name}
  551. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  552. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  553. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  554. ${element_c}, cutlass::layout::RowMajor,
  555. ${element_accumulator},
  556. ${opcode_class},
  557. ${arch},
  558. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  559. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  560. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  561. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  562. ${element_c},
  563. ${alignment_c},
  564. ${element_accumulator},
  565. ${element_epilogue}
  566. >,
  567. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  568. ${stages},
  569. ${math_operator}
  570. >::GemmKernel;
  571. struct ${operation_name} :
  572. public Operation_${operation_name} { };
  573. """
  574. def emit(self, operation):
  575. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  576. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  577. transposed_layout_A = TransposedLayout[operation.A.layout]
  578. transposed_layout_B = TransposedLayout[operation.B.layout]
  579. values = {
  580. 'operation_name': operation.procedural_name(),
  581. 'element_a': DataTypeTag[operation.B.element],
  582. 'layout_a': LayoutTag[transposed_layout_B],
  583. 'transform_a': ComplexTransformTag[operation.B.complex_transform],
  584. 'alignment_a': str(operation.B.alignment),
  585. 'element_b': DataTypeTag[operation.A.element],
  586. 'layout_b': LayoutTag[transposed_layout_A],
  587. 'transform_b': ComplexTransformTag[operation.A.complex_transform],
  588. 'alignment_b': str(operation.A.alignment),
  589. 'element_c': DataTypeTag[operation.C.element],
  590. 'layout_c': LayoutTag[operation.C.layout],
  591. 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
  592. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  593. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  594. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  595. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  596. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  597. 'warp_shape_m': str(warp_shape[0]),
  598. 'warp_shape_n': str(warp_shape[1]),
  599. 'warp_shape_k': str(warp_shape[2]),
  600. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  601. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  602. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  603. 'alignment_c': str(operation.C.alignment),
  604. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  605. 'stages': str(operation.tile_description.stages),
  606. 'math_operator': 'cutlass::arch::OpMultiplyAdd'
  607. }
  608. return SubstituteTemplate(self.template, values)
  609. ###################################################################################################
  610. #
  611. class EmitGemmPlanarComplexArrayInstance:
  612. ''' Responsible for emitting a CUTLASS template definition'''
  613. def __init__(self):
  614. self.template = """
  615. // Gemm operator ${operation_name}
  616. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  617. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  618. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  619. ${element_c}, cutlass::layout::RowMajor,
  620. ${element_accumulator},
  621. ${opcode_class},
  622. ${arch},
  623. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  624. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  625. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  626. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  627. ${element_c},
  628. ${alignment_c},
  629. ${element_accumulator},
  630. ${element_epilogue}
  631. >,
  632. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  633. ${stages},
  634. ${math_operator}
  635. >::GemmArrayKernel;
  636. struct ${operation_name} : public Operation_${operation_name} { };
  637. """
  638. def emit(self, operation):
  639. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  640. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  641. transposed_layout_A = TransposedLayout[operation.A.layout]
  642. transposed_layout_B = TransposedLayout[operation.B.layout]
  643. values = {
  644. 'operation_name': operation.procedural_name(),
  645. 'element_a': DataTypeTag[operation.B.element],
  646. 'layout_a': LayoutTag[transposed_layout_B],
  647. 'transform_a': ComplexTransformTag[operation.B.complex_transform],
  648. 'alignment_a': str(operation.B.alignment),
  649. 'element_b': DataTypeTag[operation.A.element],
  650. 'layout_b': LayoutTag[transposed_layout_A],
  651. 'transform_b': ComplexTransformTag[operation.A.complex_transform],
  652. 'alignment_b': str(operation.A.alignment),
  653. 'element_c': DataTypeTag[operation.C.element],
  654. 'layout_c': LayoutTag[operation.C.layout],
  655. 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
  656. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  657. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  658. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  659. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  660. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  661. 'warp_shape_m': str(warp_shape[0]),
  662. 'warp_shape_n': str(warp_shape[1]),
  663. 'warp_shape_k': str(warp_shape[2]),
  664. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  665. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  666. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  667. 'alignment_c': str(operation.C.alignment),
  668. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  669. 'stages': str(operation.tile_description.stages),
  670. 'math_operator': 'cutlass::arch::OpMultiplyAdd'
  671. }
  672. return SubstituteTemplate(self.template, values)
  673. #
  674. class EmitGemmSplitKParallelInstance:
  675. ''' Responsible for emitting a CUTLASS template definition'''
  676. def __init__(self):
  677. self.template = """
  678. // Gemm operator ${operation_name}
  679. using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel<
  680. ${element_a}, ${layout_a},
  681. ${element_b}, ${layout_b},
  682. ${element_c}, ${layout_c},
  683. ${element_accumulator},
  684. ${opcode_class},
  685. ${arch},
  686. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  687. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  688. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  689. ${epilogue_functor}<
  690. ${element_c},
  691. ${epilogue_vector_length},
  692. ${element_accumulator},
  693. ${element_epilogue}
  694. >,
  695. cutlass::epilogue::thread::Convert<
  696. ${element_accumulator},
  697. ${epilogue_vector_length},
  698. ${element_accumulator}
  699. >,
  700. cutlass::reduction::thread::ReduceAdd<
  701. ${element_accumulator},
  702. ${element_accumulator},
  703. ${epilogue_vector_length}
  704. >,
  705. cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
  706. ${stages},
  707. ${align_a},
  708. ${align_b},
  709. ${math_operation}
  710. >;
  711. """
  712. def emit(self, operation):
  713. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  714. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  715. values = {
  716. 'operation_name': operation.procedural_name(),
  717. 'element_a': DataTypeTag[operation.A.element],
  718. 'layout_a': LayoutTag[operation.A.layout],
  719. 'element_b': DataTypeTag[operation.B.element],
  720. 'layout_b': LayoutTag[operation.B.layout],
  721. 'element_c': DataTypeTag[operation.C.element],
  722. 'layout_c': LayoutTag[operation.C.layout],
  723. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  724. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  725. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  726. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  727. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  728. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  729. 'warp_shape_m': str(warp_shape[0]),
  730. 'warp_shape_n': str(warp_shape[1]),
  731. 'warp_shape_k': str(warp_shape[2]),
  732. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  733. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  734. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  735. 'epilogue_vector_length': str(epilogue_vector_length),
  736. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  737. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  738. 'stages': str(operation.tile_description.stages),
  739. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  740. 'align_a': str(operation.A.alignment),
  741. 'align_b': str(operation.B.alignment),
  742. }
  743. return SubstituteTemplate(self.template, values)
  744. ###################################################################################################
  745. ###################################################################################################
  746. #
  747. # Emitters functions for all targets
  748. #
  749. ###################################################################################################
  750. class EmitGemmConfigurationLibrary:
  751. def __init__(self, operation_path, configuration_name):
  752. self.configuration_name = configuration_name
  753. self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
  754. self.instance_emitter = {
  755. GemmKind.Gemm: EmitGemmInstance,
  756. GemmKind.Sparse: EmitSparseGemmInstance,
  757. GemmKind.Universal: EmitGemmUniversalInstance,
  758. GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
  759. GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
  760. }
  761. self.gemm_kind_wrappers = {
  762. GemmKind.Gemm: 'GemmOperation',
  763. GemmKind.Sparse: 'GemmSparseOperation',
  764. GemmKind.Universal: 'GemmUniversalOperation',
  765. GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
  766. GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
  767. }
  768. self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
  769. self.instance_template = {
  770. GemmKind.Gemm: """
  771. ${compile_guard_start}
  772. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  773. ${compile_guard_end}
  774. """,
  775. GemmKind.Sparse: """
  776. ${compile_guard_start}
  777. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  778. ${compile_guard_end}
  779. """,
  780. GemmKind.Universal: """
  781. ${compile_guard_start}
  782. manifest.append(new ${gemm_kind}<
  783. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  784. >("${operation_name}"));
  785. ${compile_guard_end}
  786. """,
  787. GemmKind.PlanarComplex: """
  788. ${compile_guard_start}
  789. manifest.append(new ${gemm_kind}<
  790. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  791. >("${operation_name}"));
  792. ${compile_guard_end}
  793. """,
  794. GemmKind.PlanarComplexArray: """
  795. ${compile_guard_start}
  796. manifest.append(new ${gemm_kind}<
  797. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  798. >("${operation_name}"));
  799. ${compile_guard_end}
  800. """
  801. }
  802. self.header_template = """
  803. /*
  804. Generated by gemm_operation.py - Do not edit.
  805. */
  806. ///////////////////////////////////////////////////////////////////////////////////////////////////
  807. #include "cutlass/arch/wmma.h"
  808. #include "cutlass/cutlass.h"
  809. #include "cutlass/library/library.h"
  810. #include "cutlass/library/manifest.h"
  811. #include "library_internal.h"
  812. #include "gemm_operation.h"
  813. ///////////////////////////////////////////////////////////////////////////////////////////////////
  814. """
  815. self.initialize_function_template = """
  816. ///////////////////////////////////////////////////////////////////////////////////////////////////
  817. namespace cutlass {
  818. namespace library {
  819. ///////////////////////////////////////////////////////////////////////////////////////////////////
  820. void initialize_${configuration_name}(Manifest &manifest) {
  821. """
  822. self.epilogue_template = """
  823. }
  824. ///////////////////////////////////////////////////////////////////////////////////////////////////
  825. } // namespace library
  826. } // namespace cutlass
  827. ///////////////////////////////////////////////////////////////////////////////////////////////////
  828. """
  829. def __enter__(self):
  830. self.configuration_file = open(self.configuration_path, "w")
  831. self.configuration_file.write(self.header_template)
  832. self.instance_definitions = []
  833. self.instance_wrappers = []
  834. self.operations = []
  835. return self
  836. def emit(self, operation):
  837. emitter = self.instance_emitter[operation.gemm_kind]()
  838. self.operations.append(operation)
  839. self.instance_definitions.append(emitter.emit(operation))
  840. self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
  841. 'configuration_name': self.configuration_name,
  842. 'operation_name': operation.procedural_name(),
  843. 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
  844. 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
  845. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
  846. 'compile_guard_end': "#endif" \
  847. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
  848. }))
  849. def __exit__(self, exception_type, exception_value, traceback):
  850. # Write instance definitions in top-level namespace
  851. for instance_definition in self.instance_definitions:
  852. self.configuration_file.write(instance_definition)
  853. # Add wrapper objects within initialize() function
  854. self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
  855. 'configuration_name': self.configuration_name
  856. }))
  857. for instance_wrapper in self.instance_wrappers:
  858. self.configuration_file.write(instance_wrapper)
  859. self.configuration_file.write(self.epilogue_template)
  860. self.configuration_file.close()
  861. ###################################################################################################
  862. ###################################################################################################
  863. class EmitGemmSingleKernelWrapper:
  864. def __init__(self, kernel_path, gemm_operation, short_path=False):
  865. self.short_path = short_path
  866. self.kernel_path = kernel_path
  867. self.operation = gemm_operation
  868. instance_emitters = {
  869. GemmKind.Gemm: EmitGemmInstance(),
  870. GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
  871. }
  872. self.instance_emitter = instance_emitters[self.operation.gemm_kind]
  873. self.header_template = """
  874. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  875. // ignore warning of cutlass
  876. #pragma GCC diagnostic push
  877. #pragma GCC diagnostic ignored "-Wunused-parameter"
  878. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  879. #pragma GCC diagnostic ignored "-Wuninitialized"
  880. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  881. #include "cutlass/gemm/device/gemm.h"
  882. #include "cutlass/gemm/device/gemm_splitk_parallel.h"
  883. #include "src/cuda/cutlass/manifest.h"
  884. #include "src/cuda/cutlass/gemm_operation.h"
  885. """
  886. self.instance_template = """
  887. ${operation_instance}
  888. """
  889. self.manifest_template = """
  890. namespace cutlass {
  891. namespace library {
  892. void initialize_${operation_name}(Manifest &manifest) {
  893. manifest.append(new GemmOperation<
  894. Operation_${operation_name}
  895. >("${operation_name}"));
  896. }
  897. } // namespace library
  898. } // namespace cutlass
  899. """
  900. self.epilogue_template = """
  901. #pragma GCC diagnostic pop
  902. #endif
  903. """
  904. #
  905. def __enter__(self):
  906. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  907. self.kernel_file = open(self.kernel_path, "w")
  908. self.kernel_file.write(SubstituteTemplate(self.header_template, {
  909. 'required_cuda_ver_major': str(self.operation.required_cuda_ver_major),
  910. 'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor),
  911. }))
  912. return self
  913. #
  914. def emit(self):
  915. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  916. 'operation_instance': self.instance_emitter.emit(self.operation),
  917. }))
  918. # emit manifest helper
  919. manifest = SubstituteTemplate(self.manifest_template, {
  920. 'operation_name': self.operation.procedural_name(),
  921. })
  922. self.kernel_file.write(manifest)
  923. #
  924. def __exit__(self, exception_type, exception_value, traceback):
  925. self.kernel_file.write(self.epilogue_template)
  926. self.kernel_file.close()
  927. ###################################################################################################
  928. ###################################################################################################
  929. class EmitGemvSingleKernelWrapper:
  930. def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
  931. self.kernel_path = kernel_path
  932. self.wrapper_path = wrapper_path
  933. self.operation = gemm_operation
  934. self.short_path = short_path
  935. self.wrapper_template = """
  936. template void megdnn::cuda::cutlass_wrapper::
  937. cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
  938. BatchedGemmCoord const& problem_size,
  939. const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
  940. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
  941. typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
  942. cudaStream_t stream);
  943. """
  944. self.instance_emitter = EmitGemvBatchedStridedInstance()
  945. self.header_template = """
  946. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  947. // ignore warning of cutlass
  948. #pragma GCC diagnostic push
  949. #pragma GCC diagnostic ignored "-Wunused-parameter"
  950. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  951. #pragma GCC diagnostic ignored "-Wuninitialized"
  952. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  953. #include "${wrapper_path}"
  954. """
  955. self.instance_template = """
  956. ${operation_instance}
  957. """
  958. self.epilogue_template = """
  959. #pragma GCC diagnostic pop
  960. #endif
  961. """
  962. #
  963. def __enter__(self):
  964. if self.short_path:
  965. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  966. GlobalCnt.cnt += 1
  967. else:
  968. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  969. self.kernel_file = open(self.kernel_path, "w")
  970. self.kernel_file.write(SubstituteTemplate(self.header_template, {
  971. 'wrapper_path': self.wrapper_path,
  972. 'required_cuda_ver_major': str(self.operation.required_cuda_ver_major),
  973. 'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor),
  974. }))
  975. return self
  976. #
  977. def emit(self):
  978. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  979. 'operation_instance': self.instance_emitter.emit(self.operation),
  980. }))
  981. # emit wrapper
  982. wrapper = SubstituteTemplate(self.wrapper_template, {
  983. 'operation_name': self.operation.procedural_name(),
  984. })
  985. self.kernel_file.write(wrapper)
  986. #
  987. def __exit__(self, exception_type, exception_value, traceback):
  988. self.kernel_file.write(self.epilogue_template)
  989. self.kernel_file.close()
  990. ###################################################################################################
  991. ###################################################################################################

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台