You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_operation.py 43 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085
  1. #
  2. # \file generator.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import os.path
  8. import shutil
  9. import functools
  10. import operator
  11. from lazy_file import LazyFile
  12. from library import *
  13. ###################################################################################################
  14. #
  15. # Data structure modeling a GEMM operation
  16. #
  17. ###################################################################################################
  18. #
  19. class GemmOperation:
  20. #
  21. def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
  22. epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8):
  23. self.operation_kind = OperationKind.Gemm
  24. self.arch = arch
  25. self.tile_description = tile_description
  26. self.gemm_kind = gemm_kind
  27. self.A = A
  28. self.B = B
  29. self.C = C
  30. self.element_epilogue = element_epilogue
  31. self.epilogue_functor = epilogue_functor
  32. self.swizzling_functor = swizzling_functor
  33. #
  34. def is_complex(self):
  35. complex_operators = [
  36. MathOperation.multiply_add_complex,
  37. MathOperation.multiply_add_complex_gaussian
  38. ]
  39. return self.tile_description.math_instruction.math_operation in complex_operators
  40. #
  41. def is_split_k_parallel(self):
  42. return self.gemm_kind == GemmKind.SplitKParallel
  43. #
  44. def is_planar_complex(self):
  45. return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  46. #
  47. def accumulator_type(self):
  48. accum = self.tile_description.math_instruction.element_accumulator
  49. if self.is_complex():
  50. return get_complex_from_real(accum)
  51. return accum
  52. #
  53. def short_math_name(self):
  54. if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
  55. return "g%s" % ShortDataTypeNames[self.accumulator_type()]
  56. return ShortDataTypeNames[self.accumulator_type()]
  57. #
  58. def core_name(self):
  59. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  60. inst_shape = ''
  61. inst_operation = ''
  62. intermediate_type = ''
  63. math_operations_map = {
  64. MathOperation.xor_popc: 'xor',
  65. }
  66. if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
  67. self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
  68. math_op = self.tile_description.math_instruction.math_operation
  69. math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
  70. inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
  71. inst_shape += math_op_string
  72. if self.tile_description.math_instruction.element_a != self.A.element and \
  73. self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
  74. intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
  75. return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
  76. #
  77. def extended_name(self):
  78. ''' Append data types if they differ from compute type. '''
  79. if self.is_complex():
  80. extended_name = "${core_name}"
  81. else:
  82. if self.C.element != self.tile_description.math_instruction.element_accumulator and \
  83. self.A.element != self.tile_description.math_instruction.element_accumulator:
  84. extended_name = "${element_c}_${core_name}_${element_a}"
  85. elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
  86. self.A.element != self.tile_description.math_instruction.element_accumulator:
  87. extended_name = "${core_name}_${element_a}"
  88. else:
  89. extended_name = "${core_name}"
  90. extended_name = SubstituteTemplate(extended_name, {
  91. 'element_a': DataTypeNames[self.A.element],
  92. 'element_c': DataTypeNames[self.C.element],
  93. 'core_name': self.core_name()
  94. })
  95. return extended_name
  96. #
  97. def layout_name(self):
  98. if self.is_complex() or self.is_planar_complex():
  99. return "%s%s" % (
  100. ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
  101. ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
  102. )
  103. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  104. #
  105. def procedural_name(self):
  106. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  107. threadblock = self.tile_description.procedural_name()
  108. opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
  109. alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
  110. return SubstituteTemplate(
  111. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
  112. {
  113. 'opcode_class': opcode_class_name,
  114. 'extended_name': self.extended_name(),
  115. 'threadblock': threadblock,
  116. 'layout': self.layout_name(),
  117. 'alignment': "%d" % self.A.alignment,
  118. }
  119. )
  120. #
  121. def configuration_name(self):
  122. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  123. return self.procedural_name()
  124. ###################################################################################################
  125. #
  126. # Data structure modeling a GEMV Batched Strided operation
  127. #
  128. ###################################################################################################
  129. #
  130. class GemvBatchedStridedOperation:
  131. #
  132. def __init__(self, gemm_kind, arch, math_inst, threadblock_shape, thread_shape, A, B, C):
  133. self.operation_kind = OperationKind.Gemm
  134. self.arch = arch
  135. self.gemm_kind = gemm_kind
  136. self.math_instruction = math_inst
  137. self.threadblock_shape = threadblock_shape
  138. self.thread_shape = thread_shape
  139. self.A = A
  140. self.B = B
  141. self.C = C
  142. #
  143. def accumulator_type(self):
  144. accum = self.math_instruction.element_accumulator
  145. return accum
  146. #
  147. def short_math_name(self):
  148. return ShortDataTypeNames[self.accumulator_type()]
  149. #
  150. def core_name(self):
  151. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  152. return "%s%s" % (self.short_math_name(), \
  153. GemmKindNames[self.gemm_kind])
  154. #
  155. def extended_name(self):
  156. ''' Append data types if they differ from compute type. '''
  157. if self.C.element != self.math_instruction.element_accumulator and \
  158. self.A.element != self.math_instruction.element_accumulator:
  159. extended_name = "${element_c}_${core_name}_${element_a}"
  160. elif self.C.element == self.math_instruction.element_accumulator and \
  161. self.A.element != self.math_instruction.element_accumulator:
  162. extended_name = "${core_name}_${element_a}"
  163. else:
  164. extended_name = "${core_name}"
  165. extended_name = SubstituteTemplate(extended_name, {
  166. 'element_a': DataTypeNames[self.A.element],
  167. 'element_c': DataTypeNames[self.C.element],
  168. 'core_name': self.core_name()
  169. })
  170. return extended_name
  171. #
  172. def layout_name(self):
  173. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  174. #
  175. def procedural_name(self):
  176. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  177. threadblock = "%dx%d_%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2])
  178. opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class]
  179. alignment_a = self.A.alignment
  180. alignment_b = self.B.alignment
  181. return SubstituteTemplate(
  182. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}",
  183. {
  184. 'opcode_class': opcode_class_name,
  185. 'extended_name': self.extended_name(),
  186. 'threadblock': threadblock,
  187. 'layout': self.layout_name(),
  188. 'alignment_a': "%d" % alignment_a,
  189. 'alignment_b': "%d" % alignment_b,
  190. }
  191. )
  192. #
  193. def configuration_name(self):
  194. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  195. return self.procedural_name()
  196. #
  197. def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a = 32, align_b = 32, align_c = 32):
  198. operations = []
  199. swizzling_functor = SwizzlingFunctor.Identity1
  200. element_a, element_b, element_c, element_epilogue = data_type
  201. if tile.math_instruction.element_accumulator == DataType.s32:
  202. epilogues = [EpilogueFunctor.LinearCombinationClamp]
  203. else:
  204. assert tile.math_instruction.element_accumulator == DataType.f32
  205. epilogues = [EpilogueFunctor.LinearCombination]
  206. for epilogue in epilogues:
  207. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  208. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  209. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  210. operations.append(GemmOperation(GemmKind.Gemm, min_cc, tile, A, B, C, \
  211. element_epilogue, epilogue, swizzling_functor))
  212. operations.append(GemmOperation(GemmKind.SplitKParallel, min_cc, tile, A, B, C, \
  213. element_epilogue, epilogue, swizzling_functor))
  214. return operations
  215. def GeneratesGemv(math_inst, threadblock_shape, thread_shape, data_type, layout_a, layout_b, layout_c, min_cc, \
  216. align_a = 32, align_b = 32, align_c = 32):
  217. element_a, element_b, element_c, element_epilogue = data_type
  218. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  219. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  220. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  221. return GemvBatchedStridedOperation(GemmKind.GemvBatchedStrided, min_cc, math_inst, threadblock_shape, thread_shape, \
  222. A, B, C)
  223. ###################################################################################################
  224. #
  225. # Emits single instances of a CUTLASS device-wide operator
  226. #
  227. ###################################################################################################
  228. #
  229. class EmitGemmInstance:
  230. ''' Responsible for emitting a CUTLASS template definition'''
  231. def __init__(self):
  232. self.gemm_template = """
  233. // Gemm operator ${operation_name}
  234. using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  235. ${element_a}, ${layout_a},
  236. ${element_b}, ${layout_b},
  237. ${element_c}, ${layout_c},
  238. ${element_accumulator},
  239. ${opcode_class},
  240. ${arch},
  241. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  242. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  243. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  244. ${epilogue_functor}<
  245. ${element_c},
  246. ${epilogue_vector_length},
  247. ${element_accumulator},
  248. ${element_epilogue}
  249. >,
  250. ${swizzling_functor},
  251. ${stages},
  252. ${align_a},
  253. ${align_b},
  254. false,
  255. ${math_operation}
  256. ${residual}
  257. >;
  258. """
  259. self.gemm_complex_template = """
  260. // Gemm operator ${operation_name}
  261. using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  262. ${element_a}, ${layout_a},
  263. ${element_b}, ${layout_b},
  264. ${element_c}, ${layout_c},
  265. ${element_accumulator},
  266. ${opcode_class},
  267. ${arch},
  268. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  269. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  270. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  271. ${epilogue_functor}<
  272. ${element_c},
  273. ${epilogue_vector_length},
  274. ${element_accumulator},
  275. ${element_epilogue}
  276. >,
  277. ${swizzling_functor},
  278. ${stages},
  279. ${transform_a},
  280. ${transform_b},
  281. ${math_operation}
  282. ${residual}
  283. >;
  284. """
  285. def emit(self, operation):
  286. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  287. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  288. residual = ''
  289. values = {
  290. 'operation_name': operation.procedural_name(),
  291. 'element_a': DataTypeTag[operation.A.element],
  292. 'layout_a': LayoutTag[operation.A.layout],
  293. 'element_b': DataTypeTag[operation.B.element],
  294. 'layout_b': LayoutTag[operation.B.layout],
  295. 'element_c': DataTypeTag[operation.C.element],
  296. 'layout_c': LayoutTag[operation.C.layout],
  297. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  298. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  299. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  300. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  301. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  302. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  303. 'warp_shape_m': str(warp_shape[0]),
  304. 'warp_shape_n': str(warp_shape[1]),
  305. 'warp_shape_k': str(warp_shape[2]),
  306. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  307. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  308. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  309. 'epilogue_vector_length': str(epilogue_vector_length),
  310. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  311. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  312. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  313. 'stages': str(operation.tile_description.stages),
  314. 'align_a': str(operation.A.alignment),
  315. 'align_b': str(operation.B.alignment),
  316. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  317. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  318. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  319. 'residual': residual
  320. }
  321. template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
  322. return SubstituteTemplate(template, values)
  323. #
  324. class EmitGemvBatchedStridedInstance:
  325. ''' Responsible for emitting a CUTLASS template definition'''
  326. def __init__(self):
  327. self.template = """
  328. // Gemm operator ${operation_name}
  329. using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv<
  330. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  331. cutlass::gemm::GemmShape<${thread_shape_m}, ${thread_shape_n}, ${thread_shape_k}>,
  332. ${element_a}, ${layout_a},
  333. ${element_b}, ${layout_b},
  334. ${element_c}, ${layout_c}
  335. >;
  336. """
  337. def emit(self, operation):
  338. values = {
  339. 'operation_name': operation.procedural_name(),
  340. 'element_a': DataTypeTag[operation.A.element],
  341. 'layout_a': LayoutTag[operation.A.layout],
  342. 'element_b': DataTypeTag[operation.B.element],
  343. 'layout_b': LayoutTag[operation.B.layout],
  344. 'element_c': DataTypeTag[operation.C.element],
  345. 'layout_c': LayoutTag[operation.C.layout],
  346. 'threadblock_shape_m': str(operation.threadblock_shape[0]),
  347. 'threadblock_shape_n': str(operation.threadblock_shape[1]),
  348. 'threadblock_shape_k': str(operation.threadblock_shape[2]),
  349. 'thread_shape_m': str(operation.thread_shape[0]),
  350. 'thread_shape_n': str(operation.thread_shape[1]),
  351. 'thread_shape_k': str(operation.thread_shape[2]),
  352. }
  353. return SubstituteTemplate(self.template, values)
  354. ###################################################################################################
  355. class EmitSparseGemmInstance:
  356. ''' Responsible for emitting a CUTLASS template definition'''
  357. def __init__(self):
  358. self.gemm_template = """
  359. // Gemm operator ${operation_name}
  360. using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  361. ${element_a}, ${layout_a},
  362. ${element_b}, ${layout_b},
  363. ${element_c}, ${layout_c},
  364. ${element_accumulator},
  365. ${opcode_class},
  366. ${arch},
  367. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  368. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  369. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  370. ${epilogue_functor}<
  371. ${element_c},
  372. ${epilogue_vector_length},
  373. ${element_accumulator},
  374. ${element_epilogue}
  375. >,
  376. ${swizzling_functor},
  377. ${stages},
  378. ${align_a},
  379. ${align_b},
  380. false,
  381. ${math_operation}
  382. ${residual}
  383. >;
  384. """
  385. def emit(self, operation):
  386. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  387. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  388. residual = ''
  389. values = {
  390. 'operation_name': operation.procedural_name(),
  391. 'element_a': DataTypeTag[operation.A.element],
  392. 'layout_a': LayoutTag[operation.A.layout],
  393. 'element_b': DataTypeTag[operation.B.element],
  394. 'layout_b': LayoutTag[operation.B.layout],
  395. 'element_c': DataTypeTag[operation.C.element],
  396. 'layout_c': LayoutTag[operation.C.layout],
  397. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  398. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  399. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  400. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  401. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  402. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  403. 'warp_shape_m': str(warp_shape[0]),
  404. 'warp_shape_n': str(warp_shape[1]),
  405. 'warp_shape_k': str(warp_shape[2]),
  406. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  407. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  408. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  409. 'epilogue_vector_length': str(epilogue_vector_length),
  410. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  411. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  412. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  413. 'stages': str(operation.tile_description.stages),
  414. 'align_a': str(operation.A.alignment),
  415. 'align_b': str(operation.B.alignment),
  416. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  417. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  418. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  419. 'residual': residual
  420. }
  421. template = self.gemm_template
  422. return SubstituteTemplate(template, values)
  423. ###################################################################################################
  424. #
  425. class EmitGemmUniversalInstance:
  426. ''' Responsible for emitting a CUTLASS template definition'''
  427. def __init__(self):
  428. self.gemm_template = """
  429. // Gemm operator ${operation_name}
  430. using ${operation_name}_base =
  431. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  432. ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
  433. ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
  434. ${element_c}, ${layout_c},
  435. ${element_accumulator},
  436. ${opcode_class},
  437. ${arch},
  438. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  439. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  440. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  441. ${epilogue_functor}<
  442. ${element_c},
  443. ${epilogue_vector_length},
  444. ${element_accumulator},
  445. ${element_epilogue}
  446. >,
  447. ${swizzling_functor},
  448. ${stages},
  449. ${math_operation}
  450. >::GemmKernel;
  451. // Define named type
  452. struct ${operation_name} :
  453. public ${operation_name}_base { };
  454. """
  455. self.gemm_template_interleaved = """
  456. // Gemm operator ${operation_name}
  457. using ${operation_name}_base =
  458. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  459. ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
  460. ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
  461. ${element_c}, ${layout_c},
  462. ${element_accumulator},
  463. ${opcode_class},
  464. ${arch},
  465. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  466. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  467. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  468. ${epilogue_functor}<
  469. ${element_c},
  470. ${epilogue_vector_length},
  471. ${element_accumulator},
  472. ${element_epilogue}
  473. >,
  474. ${swizzling_functor},
  475. ${stages},
  476. ${math_operation}
  477. >::GemmKernel;
  478. // Define named type
  479. struct ${operation_name} :
  480. public ${operation_name}_base { };
  481. """
  482. def emit(self, operation):
  483. threadblock_shape = operation.tile_description.threadblock_shape
  484. warp_count = operation.tile_description.warp_count
  485. warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  486. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  487. transpose_layouts = {
  488. LayoutType.ColumnMajor: LayoutType.RowMajor,
  489. LayoutType.RowMajor: LayoutType.ColumnMajor
  490. }
  491. if operation.A.layout in transpose_layouts.keys() and \
  492. operation.B.layout in transpose_layouts.keys() and \
  493. operation.C.layout in transpose_layouts.keys():
  494. instance_layout_A = transpose_layouts[operation.A.layout]
  495. instance_layout_B = transpose_layouts[operation.B.layout]
  496. instance_layout_C = transpose_layouts[operation.C.layout]
  497. gemm_template = self.gemm_template
  498. else:
  499. instance_layout_A, instance_layout_B, instance_layout_C = \
  500. (operation.A.layout, operation.B.layout, operation.C.layout)
  501. gemm_template = self.gemm_template_interleaved
  502. #
  503. values = {
  504. 'operation_name': operation.procedural_name(),
  505. 'element_a': DataTypeTag[operation.A.element],
  506. 'layout_a': LayoutTag[instance_layout_A],
  507. 'element_b': DataTypeTag[operation.B.element],
  508. 'layout_b': LayoutTag[instance_layout_B],
  509. 'element_c': DataTypeTag[operation.C.element],
  510. 'layout_c': LayoutTag[instance_layout_C],
  511. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  512. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  513. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  514. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  515. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  516. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  517. 'warp_shape_m': str(warp_shape[0]),
  518. 'warp_shape_n': str(warp_shape[1]),
  519. 'warp_shape_k': str(warp_shape[2]),
  520. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  521. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  522. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  523. 'epilogue_vector_length': str(epilogue_vector_length),
  524. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  525. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  526. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  527. 'stages': str(operation.tile_description.stages),
  528. 'align_a': str(operation.A.alignment),
  529. 'align_b': str(operation.B.alignment),
  530. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  531. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  532. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
  533. }
  534. return SubstituteTemplate(gemm_template, values)
  535. ###################################################################################################
  536. #
  537. class EmitGemmPlanarComplexInstance:
  538. ''' Responsible for emitting a CUTLASS template definition'''
  539. def __init__(self):
  540. self.template = """
  541. // Gemm operator ${operation_name}
  542. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  543. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  544. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  545. ${element_c}, cutlass::layout::RowMajor,
  546. ${element_accumulator},
  547. ${opcode_class},
  548. ${arch},
  549. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  550. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  551. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  552. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  553. ${element_c},
  554. ${alignment_c},
  555. ${element_accumulator},
  556. ${element_epilogue}
  557. >,
  558. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  559. ${stages},
  560. ${math_operator}
  561. >::GemmKernel;
  562. struct ${operation_name} :
  563. public Operation_${operation_name} { };
  564. """
  def emit(self, operation):
    """Render the planar-complex GEMM kernel definition for `operation`.

    The generated kernel produces a row-major C, so the operands are
    exchanged: the kernel's A side is filled from `operation.B` (with its
    layout transposed) and its B side from `operation.A` (likewise
    transposed).

    Args:
      operation: GEMM operation description supplying operand, tile and
        epilogue attributes.

    Returns:
      `self.template` with every `${...}` placeholder substituted.
    """
    tile = operation.tile_description
    math_inst = tile.math_instruction
    tb_shape = tile.threadblock_shape
    inst_shape = math_inst.instruction_shape

    # Warp tile extent: threadblock tile divided by warp count, per M/N/K.
    warp_shape = [tb_shape[dim] // tile.warp_count[dim] for dim in range(3)]

    # A and B swap roles (with transposed layouts) because C is row-major.
    a_layout_t = TransposedLayout[operation.A.layout]
    b_layout_t = TransposedLayout[operation.B.layout]
    kernel_a, kernel_b = operation.B, operation.A

    values = {
      'operation_name': operation.procedural_name(),
      'element_a': DataTypeTag[kernel_a.element],
      'layout_a': LayoutTag[b_layout_t],
      'transform_a': ComplexTransformTag[kernel_a.complex_transform],
      'alignment_a': str(kernel_a.alignment),
      'element_b': DataTypeTag[kernel_b.element],
      'layout_b': LayoutTag[a_layout_t],
      'transform_b': ComplexTransformTag[kernel_b.complex_transform],
      'alignment_b': str(kernel_b.alignment),
      'element_c': DataTypeTag[operation.C.element],
      'layout_c': LayoutTag[operation.C.layout],
      'element_accumulator': DataTypeTag[math_inst.element_accumulator],
      'opcode_class': OpcodeClassTag[math_inst.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(tb_shape[0]),
      'threadblock_shape_n': str(tb_shape[1]),
      'threadblock_shape_k': str(tb_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(inst_shape[0]),
      'instruction_shape_n': str(inst_shape[1]),
      'instruction_shape_k': str(inst_shape[2]),
      'alignment_c': str(operation.C.alignment),
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'stages': str(tile.stages),
      'math_operator': 'cutlass::arch::OpMultiplyAdd',
    }
    return SubstituteTemplate(self.template, values)
  600. ###################################################################################################
  601. #
  602. class EmitGemmPlanarComplexArrayInstance:
  603. ''' Responsible for emitting a CUTLASS template definition'''
  604. def __init__(self):
  605. self.template = """
  606. // Gemm operator ${operation_name}
  607. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  608. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  609. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  610. ${element_c}, cutlass::layout::RowMajor,
  611. ${element_accumulator},
  612. ${opcode_class},
  613. ${arch},
  614. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  615. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  616. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  617. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  618. ${element_c},
  619. ${alignment_c},
  620. ${element_accumulator},
  621. ${element_epilogue}
  622. >,
  623. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  624. ${stages},
  625. ${math_operator}
  626. >::GemmArrayKernel;
  627. struct ${operation_name} : public Operation_${operation_name} { };
  628. """
  629. def emit(self, operation):
  630. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  631. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  632. transposed_layout_A = TransposedLayout[operation.A.layout]
  633. transposed_layout_B = TransposedLayout[operation.B.layout]
  634. values = {
  635. 'operation_name': operation.procedural_name(),
  636. 'element_a': DataTypeTag[operation.B.element],
  637. 'layout_a': LayoutTag[transposed_layout_B],
  638. 'transform_a': ComplexTransformTag[operation.B.complex_transform],
  639. 'alignment_a': str(operation.B.alignment),
  640. 'element_b': DataTypeTag[operation.A.element],
  641. 'layout_b': LayoutTag[transposed_layout_A],
  642. 'transform_b': ComplexTransformTag[operation.A.complex_transform],
  643. 'alignment_b': str(operation.A.alignment),
  644. 'element_c': DataTypeTag[operation.C.element],
  645. 'layout_c': LayoutTag[operation.C.layout],
  646. 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
  647. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  648. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  649. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  650. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  651. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  652. 'warp_shape_m': str(warp_shape[0]),
  653. 'warp_shape_n': str(warp_shape[1]),
  654. 'warp_shape_k': str(warp_shape[2]),
  655. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  656. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  657. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  658. 'alignment_c': str(operation.C.alignment),
  659. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  660. 'stages': str(operation.tile_description.stages),
  661. 'math_operator': 'cutlass::arch::OpMultiplyAdd'
  662. }
  663. return SubstituteTemplate(self.template, values)
  664. #
  665. class EmitGemmSplitKParallelInstance:
  666. ''' Responsible for emitting a CUTLASS template definition'''
  667. def __init__(self):
  668. self.template = """
  669. // Gemm operator ${operation_name}
  670. using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel<
  671. ${element_a}, ${layout_a},
  672. ${element_b}, ${layout_b},
  673. ${element_c}, ${layout_c},
  674. ${element_accumulator},
  675. ${opcode_class},
  676. ${arch},
  677. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  678. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  679. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  680. ${epilogue_functor}<
  681. ${element_c},
  682. ${epilogue_vector_length},
  683. ${element_accumulator},
  684. ${element_epilogue}
  685. >
  686. >;
  687. """
  688. def emit(self, operation):
  689. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  690. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  691. values = {
  692. 'operation_name': operation.procedural_name(),
  693. 'element_a': DataTypeTag[operation.A.element],
  694. 'layout_a': LayoutTag[operation.A.layout],
  695. 'element_b': DataTypeTag[operation.B.element],
  696. 'layout_b': LayoutTag[operation.B.layout],
  697. 'element_c': DataTypeTag[operation.C.element],
  698. 'layout_c': LayoutTag[operation.C.layout],
  699. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  700. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  701. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  702. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  703. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  704. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  705. 'warp_shape_m': str(warp_shape[0]),
  706. 'warp_shape_n': str(warp_shape[1]),
  707. 'warp_shape_k': str(warp_shape[2]),
  708. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  709. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  710. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  711. 'epilogue_vector_length': str(epilogue_vector_length),
  712. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  713. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  714. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  715. }
  716. return SubstituteTemplate(self.template, values)
  717. ###################################################################################################
  718. ###################################################################################################
  719. #
  720. # Emitters functions for all targets
  721. #
  722. ###################################################################################################
  723. class EmitGemmConfigurationLibrary:
  724. def __init__(self, operation_path, configuration_name):
  725. self.configuration_name = configuration_name
  726. self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
  727. self.instance_emitter = {
  728. GemmKind.Gemm: EmitGemmInstance,
  729. GemmKind.Sparse: EmitSparseGemmInstance,
  730. GemmKind.Universal: EmitGemmUniversalInstance,
  731. GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
  732. GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
  733. }
  734. self.gemm_kind_wrappers = {
  735. GemmKind.Gemm: 'GemmOperation',
  736. GemmKind.Sparse: 'GemmSparseOperation',
  737. GemmKind.Universal: 'GemmUniversalOperation',
  738. GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
  739. GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
  740. }
  741. self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
  742. self.instance_template = {
  743. GemmKind.Gemm: """
  744. ${compile_guard_start}
  745. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  746. ${compile_guard_end}
  747. """,
  748. GemmKind.Sparse: """
  749. ${compile_guard_start}
  750. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  751. ${compile_guard_end}
  752. """,
  753. GemmKind.Universal: """
  754. ${compile_guard_start}
  755. manifest.append(new ${gemm_kind}<
  756. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  757. >("${operation_name}"));
  758. ${compile_guard_end}
  759. """,
  760. GemmKind.PlanarComplex: """
  761. ${compile_guard_start}
  762. manifest.append(new ${gemm_kind}<
  763. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  764. >("${operation_name}"));
  765. ${compile_guard_end}
  766. """,
  767. GemmKind.PlanarComplexArray: """
  768. ${compile_guard_start}
  769. manifest.append(new ${gemm_kind}<
  770. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  771. >("${operation_name}"));
  772. ${compile_guard_end}
  773. """
  774. }
  775. self.header_template = """
  776. /*
  777. Generated by gemm_operation.py - Do not edit.
  778. */
  779. ///////////////////////////////////////////////////////////////////////////////////////////////////
  780. #include "cutlass/arch/wmma.h"
  781. #include "cutlass/cutlass.h"
  782. #include "cutlass/library/library.h"
  783. #include "cutlass/library/manifest.h"
  784. #include "library_internal.h"
  785. #include "gemm_operation.h"
  786. ///////////////////////////////////////////////////////////////////////////////////////////////////
  787. """
  788. self.initialize_function_template = """
  789. ///////////////////////////////////////////////////////////////////////////////////////////////////
  790. namespace cutlass {
  791. namespace library {
  792. ///////////////////////////////////////////////////////////////////////////////////////////////////
  793. void initialize_${configuration_name}(Manifest &manifest) {
  794. """
  795. self.epilogue_template = """
  796. }
  797. ///////////////////////////////////////////////////////////////////////////////////////////////////
  798. } // namespace library
  799. } // namespace cutlass
  800. ///////////////////////////////////////////////////////////////////////////////////////////////////
  801. """
  802. def __enter__(self):
  803. self.configuration_file = open(self.configuration_path, "w")
  804. self.configuration_file.write(self.header_template)
  805. self.instance_definitions = []
  806. self.instance_wrappers = []
  807. self.operations = []
  808. return self
  809. def emit(self, operation):
  810. emitter = self.instance_emitter[operation.gemm_kind]()
  811. self.operations.append(operation)
  812. self.instance_definitions.append(emitter.emit(operation))
  813. self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
  814. 'configuration_name': self.configuration_name,
  815. 'operation_name': operation.procedural_name(),
  816. 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
  817. 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
  818. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
  819. 'compile_guard_end': "#endif" \
  820. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
  821. }))
  822. def __exit__(self, exception_type, exception_value, traceback):
  823. # Write instance definitions in top-level namespace
  824. for instance_definition in self.instance_definitions:
  825. self.configuration_file.write(instance_definition)
  826. # Add wrapper objects within initialize() function
  827. self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
  828. 'configuration_name': self.configuration_name
  829. }))
  830. for instance_wrapper in self.instance_wrappers:
  831. self.configuration_file.write(instance_wrapper)
  832. self.configuration_file.write(self.epilogue_template)
  833. self.configuration_file.close()
  834. ###################################################################################################
  835. ###################################################################################################
  836. class EmitGemmSingleKernelWrapper:
  837. def __init__(self, kernel_path, gemm_operation, wrapper_path):
  838. self.kernel_path = kernel_path
  839. self.wrapper_path = wrapper_path
  840. self.operation = gemm_operation
  841. gemm_wrapper = """
  842. template void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper<Operation_${operation_name}>(
  843. const typename Operation_${operation_name}::ElementA* d_A, size_t lda,
  844. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb,
  845. typename Operation_${operation_name}::ElementC* d_C, size_t ldc,
  846. int* workspace,
  847. cutlass::gemm::GemmCoord const& problem_size,
  848. typename Operation_${operation_name}::EpilogueOutputOp::Params const& epilogue,
  849. cudaStream_t stream, int split_k_slices);
  850. """
  851. gemv_wrapper = """
  852. template void megdnn::cuda::cutlass_wrapper::
  853. cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
  854. BatchedGemmCoord const& problem_size,
  855. const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
  856. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
  857. typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
  858. cudaStream_t stream);
  859. """
  860. if self.operation.gemm_kind == GemmKind.SplitKParallel or \
  861. self.operation.gemm_kind == GemmKind.Gemm:
  862. self.wrapper_template = gemm_wrapper
  863. else:
  864. assert self.operation.gemm_kind == GemmKind.GemvBatchedStrided
  865. self.wrapper_template = gemv_wrapper
  866. instance_emitters = {
  867. GemmKind.Gemm: EmitGemmInstance(),
  868. GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
  869. GemmKind.GemvBatchedStrided: EmitGemvBatchedStridedInstance(),
  870. }
  871. self.instance_emitter = instance_emitters[self.operation.gemm_kind]
  872. self.header_template = """
  873. #if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
  874. // ignore warning of cutlass
  875. #pragma GCC diagnostic push
  876. #pragma GCC diagnostic ignored "-Wunused-parameter"
  877. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  878. #pragma GCC diagnostic ignored "-Wuninitialized"
  879. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  880. #include "${wrapper_path}"
  881. """
  882. self.instance_template = """
  883. ${operation_instance}
  884. """
  885. self.epilogue_template = """
  886. #pragma GCC diagnostic pop
  887. #endif
  888. """
  889. #
  890. def __enter__(self):
  891. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  892. self.kernel_file = LazyFile(self.kernel_path)
  893. self.kernel_file.write(SubstituteTemplate(self.header_template, {
  894. 'wrapper_path': self.wrapper_path,
  895. }))
  896. return self
  897. #
  898. def emit(self):
  899. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  900. 'operation_instance': self.instance_emitter.emit(self.operation),
  901. }))
  902. # emit wrapper
  903. wrapper = SubstituteTemplate(self.wrapper_template, {
  904. 'operation_name': self.operation.procedural_name(),
  905. })
  906. self.kernel_file.write(wrapper)
  907. #
  908. def __exit__(self, exception_type, exception_value, traceback):
  909. self.kernel_file.write(self.epilogue_template)
  910. self.kernel_file.close()
  911. ###################################################################################################
  912. ###################################################################################################

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。