You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gemm_operation.py 47 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179
  1. #
  2. # \file gemm_operation.py
  3. #
  4. # \brief Describes GEMM/GEMV operations and emits the CUTLASS Library's instances for them
  5. #
  6. import enum
  7. import os.path
  8. import shutil
  9. import functools
  10. import operator
  11. from library import *
  12. ###################################################################################################
  13. #
  14. # Data structure modeling a GEMM operation
  15. #
  16. ###################################################################################################
  17. #
  18. class GemmOperation:
  19. #
  20. def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
  21. epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
  22. required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  23. self.operation_kind = OperationKind.Gemm
  24. self.arch = arch
  25. self.tile_description = tile_description
  26. self.gemm_kind = gemm_kind
  27. self.A = A
  28. self.B = B
  29. self.C = C
  30. self.element_epilogue = element_epilogue
  31. self.epilogue_functor = epilogue_functor
  32. self.swizzling_functor = swizzling_functor
  33. self.required_cuda_ver_major = required_cuda_ver_major
  34. self.required_cuda_ver_minor = required_cuda_ver_minor
  35. #
  36. def is_complex(self):
  37. complex_operators = [
  38. MathOperation.multiply_add_complex,
  39. MathOperation.multiply_add_complex_gaussian
  40. ]
  41. return self.tile_description.math_instruction.math_operation in complex_operators
  42. #
  43. def is_split_k_parallel(self):
  44. return self.gemm_kind == GemmKind.SplitKParallel
  45. #
  46. def is_planar_complex(self):
  47. return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
  48. #
  49. def accumulator_type(self):
  50. accum = self.tile_description.math_instruction.element_accumulator
  51. if self.is_complex():
  52. return get_complex_from_real(accum)
  53. return accum
  54. #
  55. def short_math_name(self):
  56. if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
  57. return "g%s" % ShortDataTypeNames[self.accumulator_type()]
  58. return ShortDataTypeNames[self.accumulator_type()]
  59. #
  60. def core_name(self):
  61. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  62. inst_shape = ''
  63. inst_operation = ''
  64. intermediate_type = ''
  65. math_operations_map = {
  66. MathOperation.xor_popc: 'xor',
  67. }
  68. if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
  69. self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
  70. math_op = self.tile_description.math_instruction.math_operation
  71. math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
  72. inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
  73. inst_shape += math_op_string
  74. if self.tile_description.math_instruction.element_a != self.A.element and \
  75. self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
  76. intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
  77. return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
  78. #
  79. def extended_name(self):
  80. ''' Append data types if they differ from compute type. '''
  81. if self.is_complex():
  82. extended_name = "${core_name}"
  83. else:
  84. if self.C.element != self.tile_description.math_instruction.element_accumulator and \
  85. self.A.element != self.tile_description.math_instruction.element_accumulator:
  86. extended_name = "${element_c}_${core_name}_${element_a}"
  87. elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
  88. self.A.element != self.tile_description.math_instruction.element_accumulator:
  89. extended_name = "${core_name}_${element_a}"
  90. else:
  91. extended_name = "${core_name}"
  92. extended_name = SubstituteTemplate(extended_name, {
  93. 'element_a': DataTypeNames[self.A.element],
  94. 'element_c': DataTypeNames[self.C.element],
  95. 'core_name': self.core_name()
  96. })
  97. return extended_name
  98. #
  99. def layout_name(self):
  100. if self.is_complex() or self.is_planar_complex():
  101. return "%s%s" % (
  102. ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
  103. ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
  104. )
  105. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  106. #
  107. def procedural_name(self):
  108. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  109. threadblock = self.tile_description.procedural_name()
  110. opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
  111. alignment = max([self.A.alignment, self.B.alignment, self.C.alignment])
  112. return SubstituteTemplate(
  113. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}",
  114. {
  115. 'opcode_class': opcode_class_name,
  116. 'extended_name': self.extended_name(),
  117. 'threadblock': threadblock,
  118. 'layout': self.layout_name(),
  119. 'alignment': "%d" % self.A.alignment,
  120. }
  121. )
  122. #
  123. def configuration_name(self):
  124. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  125. return self.procedural_name()
  126. ###################################################################################################
  127. #
  128. # Data structure modeling a GEMV Batched Strided operation
  129. #
  130. ###################################################################################################
  131. #
  132. class GemvBatchedStridedOperation:
  133. #
  134. def __init__(self, gemm_kind, arch, math_inst, threadblock_shape, thread_shape, A, B, C, \
  135. required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  136. self.operation_kind = OperationKind.Gemm
  137. self.arch = arch
  138. self.gemm_kind = gemm_kind
  139. self.math_instruction = math_inst
  140. self.threadblock_shape = threadblock_shape
  141. self.thread_shape = thread_shape
  142. self.A = A
  143. self.B = B
  144. self.C = C
  145. self.required_cuda_ver_major = required_cuda_ver_major
  146. self.required_cuda_ver_minor = required_cuda_ver_minor
  147. #
  148. def accumulator_type(self):
  149. accum = self.math_instruction.element_accumulator
  150. return accum
  151. #
  152. def short_math_name(self):
  153. return ShortDataTypeNames[self.accumulator_type()]
  154. #
  155. def core_name(self):
  156. ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
  157. return "%s%s" % (self.short_math_name(), \
  158. GemmKindNames[self.gemm_kind])
  159. #
  160. def extended_name(self):
  161. ''' Append data types if they differ from compute type. '''
  162. if self.C.element != self.math_instruction.element_accumulator and \
  163. self.A.element != self.math_instruction.element_accumulator:
  164. extended_name = "${element_c}_${core_name}_${element_a}"
  165. elif self.C.element == self.math_instruction.element_accumulator and \
  166. self.A.element != self.math_instruction.element_accumulator:
  167. extended_name = "${core_name}_${element_a}"
  168. else:
  169. extended_name = "${core_name}"
  170. extended_name = SubstituteTemplate(extended_name, {
  171. 'element_a': DataTypeNames[self.A.element],
  172. 'element_c': DataTypeNames[self.C.element],
  173. 'core_name': self.core_name()
  174. })
  175. return extended_name
  176. #
  177. def layout_name(self):
  178. return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
  179. #
  180. def procedural_name(self):
  181. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  182. threadblock = "%dx%d_%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2])
  183. opcode_class_name = OpcodeClassNames[self.math_instruction.opcode_class]
  184. alignment_a = self.A.alignment
  185. alignment_b = self.B.alignment
  186. return SubstituteTemplate(
  187. "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment_a}x${alignment_b}",
  188. {
  189. 'opcode_class': opcode_class_name,
  190. 'extended_name': self.extended_name(),
  191. 'threadblock': threadblock,
  192. 'layout': self.layout_name(),
  193. 'alignment_a': "%d" % alignment_a,
  194. 'alignment_b': "%d" % alignment_b,
  195. }
  196. )
  197. #
  198. def configuration_name(self):
  199. ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
  200. return self.procedural_name()
  201. #
  202. def GeneratesGemm(tile, data_type, layout_a, layout_b, layout_c, min_cc, align_a = 32, align_b = 32, align_c = 32, required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  203. operations = []
  204. swizzling_functor = SwizzlingFunctor.Identity1
  205. element_a, element_b, element_c, element_epilogue = data_type
  206. if tile.math_instruction.element_accumulator == DataType.s32:
  207. epilogues = [EpilogueFunctor.LinearCombinationClamp]
  208. else:
  209. assert tile.math_instruction.element_accumulator == DataType.f32 or \
  210. tile.math_instruction.element_accumulator == DataType.f16
  211. epilogues = [EpilogueFunctor.LinearCombination]
  212. for epilogue in epilogues:
  213. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  214. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  215. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  216. operations.append(GemmOperation(GemmKind.Gemm, min_cc, tile, A, B, C, \
  217. element_epilogue, epilogue, swizzling_functor, \
  218. required_cuda_ver_major, required_cuda_ver_minor))
  219. operations.append(GemmOperation(GemmKind.SplitKParallel, min_cc, tile, A, B, C, \
  220. element_epilogue, epilogue, swizzling_functor, \
  221. required_cuda_ver_major, required_cuda_ver_minor))
  222. return operations
  223. def GeneratesGemv(math_inst, threadblock_shape, thread_shape, data_type, layout_a, layout_b, layout_c, min_cc, \
  224. align_a = 32, align_b = 32, align_c = 32, \
  225. required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  226. element_a, element_b, element_c, element_epilogue = data_type
  227. A = TensorDescription(element_a, layout_a, int(align_a//DataTypeSize[element_a]))
  228. B = TensorDescription(element_b, layout_b, int(align_b//DataTypeSize[element_b]))
  229. C = TensorDescription(element_c, layout_c, int(align_c//DataTypeSize[element_c]))
  230. return GemvBatchedStridedOperation(GemmKind.GemvBatchedStrided, min_cc, math_inst, threadblock_shape, thread_shape, \
  231. A, B, C, required_cuda_ver_major, required_cuda_ver_minor)
  232. ###################################################################################################
  233. #
  234. # Emits single instances of a CUTLASS device-wide operator
  235. #
  236. ###################################################################################################
  237. #
  238. class EmitGemmInstance:
  239. ''' Responsible for emitting a CUTLASS template definition'''
  240. def __init__(self):
  241. self.gemm_template = """
  242. // Gemm operator ${operation_name}
  243. using Operation_${operation_name} = cutlass::gemm::device::Gemm<
  244. ${element_a}, ${layout_a},
  245. ${element_b}, ${layout_b},
  246. ${element_c}, ${layout_c},
  247. ${element_accumulator},
  248. ${opcode_class},
  249. ${arch},
  250. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  251. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  252. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  253. ${epilogue_functor}<
  254. ${element_c},
  255. ${epilogue_vector_length},
  256. ${element_accumulator},
  257. ${element_epilogue}
  258. >,
  259. ${swizzling_functor},
  260. ${stages},
  261. ${align_a},
  262. ${align_b},
  263. false,
  264. ${math_operation}
  265. ${residual}
  266. >;
  267. """
  268. self.gemm_complex_template = """
  269. // Gemm operator ${operation_name}
  270. using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
  271. ${element_a}, ${layout_a},
  272. ${element_b}, ${layout_b},
  273. ${element_c}, ${layout_c},
  274. ${element_accumulator},
  275. ${opcode_class},
  276. ${arch},
  277. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  278. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  279. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  280. ${epilogue_functor}<
  281. ${element_c},
  282. ${epilogue_vector_length},
  283. ${element_accumulator},
  284. ${element_epilogue}
  285. >,
  286. ${swizzling_functor},
  287. ${stages},
  288. ${transform_a},
  289. ${transform_b},
  290. ${math_operation}
  291. ${residual}
  292. >;
  293. """
  294. def emit(self, operation):
  295. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  296. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  297. residual = ''
  298. values = {
  299. 'operation_name': operation.procedural_name(),
  300. 'element_a': DataTypeTag[operation.A.element],
  301. 'layout_a': LayoutTag[operation.A.layout],
  302. 'element_b': DataTypeTag[operation.B.element],
  303. 'layout_b': LayoutTag[operation.B.layout],
  304. 'element_c': DataTypeTag[operation.C.element],
  305. 'layout_c': LayoutTag[operation.C.layout],
  306. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  307. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  308. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  309. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  310. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  311. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  312. 'warp_shape_m': str(warp_shape[0]),
  313. 'warp_shape_n': str(warp_shape[1]),
  314. 'warp_shape_k': str(warp_shape[2]),
  315. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  316. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  317. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  318. 'epilogue_vector_length': str(epilogue_vector_length),
  319. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  320. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  321. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  322. 'stages': str(operation.tile_description.stages),
  323. 'align_a': str(operation.A.alignment),
  324. 'align_b': str(operation.B.alignment),
  325. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  326. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  327. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  328. 'residual': residual
  329. }
  330. template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
  331. return SubstituteTemplate(template, values)
  332. #
  333. class EmitGemvBatchedStridedInstance:
  334. ''' Responsible for emitting a CUTLASS template definition'''
  335. def __init__(self):
  336. self.template = """
  337. // Gemm operator ${operation_name}
  338. using Operation_${operation_name} = cutlass::gemm::kernel::DefaultGemv<
  339. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  340. cutlass::gemm::GemmShape<${thread_shape_m}, ${thread_shape_n}, ${thread_shape_k}>,
  341. ${element_a}, ${layout_a},
  342. ${element_b}, ${layout_b},
  343. ${element_c}, ${layout_c}
  344. >;
  345. """
  346. def emit(self, operation):
  347. values = {
  348. 'operation_name': operation.procedural_name(),
  349. 'element_a': DataTypeTag[operation.A.element],
  350. 'layout_a': LayoutTag[operation.A.layout],
  351. 'element_b': DataTypeTag[operation.B.element],
  352. 'layout_b': LayoutTag[operation.B.layout],
  353. 'element_c': DataTypeTag[operation.C.element],
  354. 'layout_c': LayoutTag[operation.C.layout],
  355. 'threadblock_shape_m': str(operation.threadblock_shape[0]),
  356. 'threadblock_shape_n': str(operation.threadblock_shape[1]),
  357. 'threadblock_shape_k': str(operation.threadblock_shape[2]),
  358. 'thread_shape_m': str(operation.thread_shape[0]),
  359. 'thread_shape_n': str(operation.thread_shape[1]),
  360. 'thread_shape_k': str(operation.thread_shape[2]),
  361. }
  362. return SubstituteTemplate(self.template, values)
  363. ###################################################################################################
  364. class EmitSparseGemmInstance:
  365. ''' Responsible for emitting a CUTLASS template definition'''
  366. def __init__(self):
  367. self.gemm_template = """
  368. // Gemm operator ${operation_name}
  369. using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
  370. ${element_a}, ${layout_a},
  371. ${element_b}, ${layout_b},
  372. ${element_c}, ${layout_c},
  373. ${element_accumulator},
  374. ${opcode_class},
  375. ${arch},
  376. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  377. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  378. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  379. ${epilogue_functor}<
  380. ${element_c},
  381. ${epilogue_vector_length},
  382. ${element_accumulator},
  383. ${element_epilogue}
  384. >,
  385. ${swizzling_functor},
  386. ${stages},
  387. ${align_a},
  388. ${align_b},
  389. false,
  390. ${math_operation}
  391. ${residual}
  392. >;
  393. """
  394. def emit(self, operation):
  395. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  396. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  397. residual = ''
  398. values = {
  399. 'operation_name': operation.procedural_name(),
  400. 'element_a': DataTypeTag[operation.A.element],
  401. 'layout_a': LayoutTag[operation.A.layout],
  402. 'element_b': DataTypeTag[operation.B.element],
  403. 'layout_b': LayoutTag[operation.B.layout],
  404. 'element_c': DataTypeTag[operation.C.element],
  405. 'layout_c': LayoutTag[operation.C.layout],
  406. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  407. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  408. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  409. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  410. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  411. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  412. 'warp_shape_m': str(warp_shape[0]),
  413. 'warp_shape_n': str(warp_shape[1]),
  414. 'warp_shape_k': str(warp_shape[2]),
  415. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  416. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  417. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  418. 'epilogue_vector_length': str(epilogue_vector_length),
  419. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  420. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  421. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  422. 'stages': str(operation.tile_description.stages),
  423. 'align_a': str(operation.A.alignment),
  424. 'align_b': str(operation.B.alignment),
  425. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  426. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  427. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  428. 'residual': residual
  429. }
  430. template = self.gemm_template
  431. return SubstituteTemplate(template, values)
  432. ###################################################################################################
  433. #
  434. class EmitGemmUniversalInstance:
  435. ''' Responsible for emitting a CUTLASS template definition'''
  436. def __init__(self):
  437. self.gemm_template = """
  438. // Gemm operator ${operation_name}
  439. using ${operation_name}_base =
  440. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  441. ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
  442. ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
  443. ${element_c}, ${layout_c},
  444. ${element_accumulator},
  445. ${opcode_class},
  446. ${arch},
  447. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  448. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  449. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  450. ${epilogue_functor}<
  451. ${element_c},
  452. ${epilogue_vector_length},
  453. ${element_accumulator},
  454. ${element_epilogue}
  455. >,
  456. ${swizzling_functor},
  457. ${stages},
  458. ${math_operation}
  459. >::GemmKernel;
  460. // Define named type
  461. struct ${operation_name} :
  462. public ${operation_name}_base { };
  463. """
  464. self.gemm_template_interleaved = """
  465. // Gemm operator ${operation_name}
  466. using ${operation_name}_base =
  467. typename cutlass::gemm::kernel::DefaultGemmUniversal<
  468. ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
  469. ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
  470. ${element_c}, ${layout_c},
  471. ${element_accumulator},
  472. ${opcode_class},
  473. ${arch},
  474. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  475. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  476. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  477. ${epilogue_functor}<
  478. ${element_c},
  479. ${epilogue_vector_length},
  480. ${element_accumulator},
  481. ${element_epilogue}
  482. >,
  483. ${swizzling_functor},
  484. ${stages},
  485. ${math_operation}
  486. >::GemmKernel;
  487. // Define named type
  488. struct ${operation_name} :
  489. public ${operation_name}_base { };
  490. """
  491. def emit(self, operation):
  492. threadblock_shape = operation.tile_description.threadblock_shape
  493. warp_count = operation.tile_description.warp_count
  494. warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  495. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  496. transpose_layouts = {
  497. LayoutType.ColumnMajor: LayoutType.RowMajor,
  498. LayoutType.RowMajor: LayoutType.ColumnMajor
  499. }
  500. if operation.A.layout in transpose_layouts.keys() and \
  501. operation.B.layout in transpose_layouts.keys() and \
  502. operation.C.layout in transpose_layouts.keys():
  503. instance_layout_A = transpose_layouts[operation.A.layout]
  504. instance_layout_B = transpose_layouts[operation.B.layout]
  505. instance_layout_C = transpose_layouts[operation.C.layout]
  506. gemm_template = self.gemm_template
  507. else:
  508. instance_layout_A, instance_layout_B, instance_layout_C = \
  509. (operation.A.layout, operation.B.layout, operation.C.layout)
  510. gemm_template = self.gemm_template_interleaved
  511. #
  512. values = {
  513. 'operation_name': operation.procedural_name(),
  514. 'element_a': DataTypeTag[operation.A.element],
  515. 'layout_a': LayoutTag[instance_layout_A],
  516. 'element_b': DataTypeTag[operation.B.element],
  517. 'layout_b': LayoutTag[instance_layout_B],
  518. 'element_c': DataTypeTag[operation.C.element],
  519. 'layout_c': LayoutTag[instance_layout_C],
  520. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  521. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  522. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  523. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  524. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  525. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  526. 'warp_shape_m': str(warp_shape[0]),
  527. 'warp_shape_n': str(warp_shape[1]),
  528. 'warp_shape_k': str(warp_shape[2]),
  529. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  530. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  531. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  532. 'epilogue_vector_length': str(epilogue_vector_length),
  533. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  534. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  535. 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
  536. 'stages': str(operation.tile_description.stages),
  537. 'align_a': str(operation.A.alignment),
  538. 'align_b': str(operation.B.alignment),
  539. 'transform_a': ComplexTransformTag[operation.A.complex_transform],
  540. 'transform_b': ComplexTransformTag[operation.B.complex_transform],
  541. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
  542. }
  543. return SubstituteTemplate(gemm_template, values)
  544. ###################################################################################################
  545. #
  546. class EmitGemmPlanarComplexInstance:
  547. ''' Responsible for emitting a CUTLASS template definition'''
  548. def __init__(self):
  549. self.template = """
  550. // Gemm operator ${operation_name}
  551. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  552. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  553. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  554. ${element_c}, cutlass::layout::RowMajor,
  555. ${element_accumulator},
  556. ${opcode_class},
  557. ${arch},
  558. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  559. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  560. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  561. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  562. ${element_c},
  563. ${alignment_c},
  564. ${element_accumulator},
  565. ${element_epilogue}
  566. >,
  567. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  568. ${stages},
  569. ${math_operator}
  570. >::GemmKernel;
  571. struct ${operation_name} :
  572. public Operation_${operation_name} { };
  573. """
  574. def emit(self, operation):
  575. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  576. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  577. transposed_layout_A = TransposedLayout[operation.A.layout]
  578. transposed_layout_B = TransposedLayout[operation.B.layout]
  579. values = {
  580. 'operation_name': operation.procedural_name(),
  581. 'element_a': DataTypeTag[operation.B.element],
  582. 'layout_a': LayoutTag[transposed_layout_B],
  583. 'transform_a': ComplexTransformTag[operation.B.complex_transform],
  584. 'alignment_a': str(operation.B.alignment),
  585. 'element_b': DataTypeTag[operation.A.element],
  586. 'layout_b': LayoutTag[transposed_layout_A],
  587. 'transform_b': ComplexTransformTag[operation.A.complex_transform],
  588. 'alignment_b': str(operation.A.alignment),
  589. 'element_c': DataTypeTag[operation.C.element],
  590. 'layout_c': LayoutTag[operation.C.layout],
  591. 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
  592. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  593. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  594. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  595. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  596. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  597. 'warp_shape_m': str(warp_shape[0]),
  598. 'warp_shape_n': str(warp_shape[1]),
  599. 'warp_shape_k': str(warp_shape[2]),
  600. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  601. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  602. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  603. 'alignment_c': str(operation.C.alignment),
  604. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  605. 'stages': str(operation.tile_description.stages),
  606. 'math_operator': 'cutlass::arch::OpMultiplyAdd'
  607. }
  608. return SubstituteTemplate(self.template, values)
  609. ###################################################################################################
  610. #
  611. class EmitGemmPlanarComplexArrayInstance:
  612. ''' Responsible for emitting a CUTLASS template definition'''
  613. def __init__(self):
  614. self.template = """
  615. // Gemm operator ${operation_name}
  616. using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
  617. ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
  618. ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
  619. ${element_c}, cutlass::layout::RowMajor,
  620. ${element_accumulator},
  621. ${opcode_class},
  622. ${arch},
  623. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  624. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  625. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  626. cutlass::epilogue::thread::LinearCombinationPlanarComplex<
  627. ${element_c},
  628. ${alignment_c},
  629. ${element_accumulator},
  630. ${element_epilogue}
  631. >,
  632. cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  633. ${stages},
  634. ${math_operator}
  635. >::GemmArrayKernel;
  636. struct ${operation_name} : public Operation_${operation_name} { };
  637. """
  638. def emit(self, operation):
  639. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  640. # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
  641. transposed_layout_A = TransposedLayout[operation.A.layout]
  642. transposed_layout_B = TransposedLayout[operation.B.layout]
  643. values = {
  644. 'operation_name': operation.procedural_name(),
  645. 'element_a': DataTypeTag[operation.B.element],
  646. 'layout_a': LayoutTag[transposed_layout_B],
  647. 'transform_a': ComplexTransformTag[operation.B.complex_transform],
  648. 'alignment_a': str(operation.B.alignment),
  649. 'element_b': DataTypeTag[operation.A.element],
  650. 'layout_b': LayoutTag[transposed_layout_A],
  651. 'transform_b': ComplexTransformTag[operation.A.complex_transform],
  652. 'alignment_b': str(operation.A.alignment),
  653. 'element_c': DataTypeTag[operation.C.element],
  654. 'layout_c': LayoutTag[operation.C.layout],
  655. 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
  656. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  657. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  658. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  659. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  660. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  661. 'warp_shape_m': str(warp_shape[0]),
  662. 'warp_shape_n': str(warp_shape[1]),
  663. 'warp_shape_k': str(warp_shape[2]),
  664. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  665. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  666. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  667. 'alignment_c': str(operation.C.alignment),
  668. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  669. 'stages': str(operation.tile_description.stages),
  670. 'math_operator': 'cutlass::arch::OpMultiplyAdd'
  671. }
  672. return SubstituteTemplate(self.template, values)
  673. #
  674. class EmitGemmSplitKParallelInstance:
  675. ''' Responsible for emitting a CUTLASS template definition'''
  676. def __init__(self):
  677. self.template = """
  678. // Gemm operator ${operation_name}
  679. using Operation_${operation_name} = cutlass::gemm::device::GemmSplitKParallel<
  680. ${element_a}, ${layout_a},
  681. ${element_b}, ${layout_b},
  682. ${element_c}, ${layout_c},
  683. ${element_accumulator},
  684. ${opcode_class},
  685. ${arch},
  686. cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
  687. cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
  688. cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
  689. ${epilogue_functor}<
  690. ${element_c},
  691. ${epilogue_vector_length},
  692. ${element_accumulator},
  693. ${element_epilogue}
  694. >,
  695. cutlass::epilogue::thread::Convert<
  696. ${element_accumulator},
  697. ${epilogue_vector_length},
  698. ${element_accumulator}
  699. >,
  700. cutlass::reduction::thread::ReduceAdd<
  701. ${element_accumulator},
  702. ${element_accumulator},
  703. ${epilogue_vector_length}
  704. >,
  705. cutlass::gemm::threadblock::GemmSplitKHorizontalThreadblockSwizzle,
  706. ${stages},
  707. ${align_a},
  708. ${align_b},
  709. ${math_operation}
  710. >;
  711. """
  712. def emit(self, operation):
  713. warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
  714. epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
  715. values = {
  716. 'operation_name': operation.procedural_name(),
  717. 'element_a': DataTypeTag[operation.A.element],
  718. 'layout_a': LayoutTag[operation.A.layout],
  719. 'element_b': DataTypeTag[operation.B.element],
  720. 'layout_b': LayoutTag[operation.B.layout],
  721. 'element_c': DataTypeTag[operation.C.element],
  722. 'layout_c': LayoutTag[operation.C.layout],
  723. 'element_accumulator': DataTypeTag[operation.accumulator_type()],
  724. 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
  725. 'arch': "cutlass::arch::Sm%d" % operation.arch,
  726. 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
  727. 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
  728. 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
  729. 'warp_shape_m': str(warp_shape[0]),
  730. 'warp_shape_n': str(warp_shape[1]),
  731. 'warp_shape_k': str(warp_shape[2]),
  732. 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
  733. 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
  734. 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
  735. 'epilogue_vector_length': str(epilogue_vector_length),
  736. 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
  737. 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
  738. 'stages': str(operation.tile_description.stages),
  739. 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
  740. 'align_a': str(operation.A.alignment),
  741. 'align_b': str(operation.B.alignment),
  742. }
  743. return SubstituteTemplate(self.template, values)
  744. ###################################################################################################
  745. ###################################################################################################
  746. #
  747. # Emitters functions for all targets
  748. #
  749. ###################################################################################################
  750. class EmitGemmConfigurationLibrary:
  751. def __init__(self, operation_path, configuration_name):
  752. self.configuration_name = configuration_name
  753. self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
  754. self.instance_emitter = {
  755. GemmKind.Gemm: EmitGemmInstance,
  756. GemmKind.Sparse: EmitSparseGemmInstance,
  757. GemmKind.Universal: EmitGemmUniversalInstance,
  758. GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
  759. GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance
  760. }
  761. self.gemm_kind_wrappers = {
  762. GemmKind.Gemm: 'GemmOperation',
  763. GemmKind.Sparse: 'GemmSparseOperation',
  764. GemmKind.Universal: 'GemmUniversalOperation',
  765. GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
  766. GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation'
  767. }
  768. self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
  769. self.instance_template = {
  770. GemmKind.Gemm: """
  771. ${compile_guard_start}
  772. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  773. ${compile_guard_end}
  774. """,
  775. GemmKind.Sparse: """
  776. ${compile_guard_start}
  777. manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
  778. ${compile_guard_end}
  779. """,
  780. GemmKind.Universal: """
  781. ${compile_guard_start}
  782. manifest.append(new ${gemm_kind}<
  783. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  784. >("${operation_name}"));
  785. ${compile_guard_end}
  786. """,
  787. GemmKind.PlanarComplex: """
  788. ${compile_guard_start}
  789. manifest.append(new ${gemm_kind}<
  790. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  791. >("${operation_name}"));
  792. ${compile_guard_end}
  793. """,
  794. GemmKind.PlanarComplexArray: """
  795. ${compile_guard_start}
  796. manifest.append(new ${gemm_kind}<
  797. cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  798. >("${operation_name}"));
  799. ${compile_guard_end}
  800. """
  801. }
  802. self.header_template = """
  803. /*
  804. Generated by gemm_operation.py - Do not edit.
  805. */
  806. ///////////////////////////////////////////////////////////////////////////////////////////////////
  807. #include "cutlass/arch/wmma.h"
  808. #include "cutlass/cutlass.h"
  809. #include "cutlass/library/library.h"
  810. #include "cutlass/library/manifest.h"
  811. #include "library_internal.h"
  812. #include "gemm_operation.h"
  813. ///////////////////////////////////////////////////////////////////////////////////////////////////
  814. """
  815. self.initialize_function_template = """
  816. ///////////////////////////////////////////////////////////////////////////////////////////////////
  817. namespace cutlass {
  818. namespace library {
  819. ///////////////////////////////////////////////////////////////////////////////////////////////////
  820. void initialize_${configuration_name}(Manifest &manifest) {
  821. """
  822. self.epilogue_template = """
  823. }
  824. ///////////////////////////////////////////////////////////////////////////////////////////////////
  825. } // namespace library
  826. } // namespace cutlass
  827. ///////////////////////////////////////////////////////////////////////////////////////////////////
  828. """
  829. def __enter__(self):
  830. self.configuration_file = open(self.configuration_path, "w")
  831. self.configuration_file.write(self.header_template)
  832. self.instance_definitions = []
  833. self.instance_wrappers = []
  834. self.operations = []
  835. return self
  836. def emit(self, operation):
  837. emitter = self.instance_emitter[operation.gemm_kind]()
  838. self.operations.append(operation)
  839. self.instance_definitions.append(emitter.emit(operation))
  840. self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.gemm_kind], {
  841. 'configuration_name': self.configuration_name,
  842. 'operation_name': operation.procedural_name(),
  843. 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
  844. 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
  845. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
  846. 'compile_guard_end': "#endif" \
  847. if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
  848. }))
  849. def __exit__(self, exception_type, exception_value, traceback):
  850. # Write instance definitions in top-level namespace
  851. for instance_definition in self.instance_definitions:
  852. self.configuration_file.write(instance_definition)
  853. # Add wrapper objects within initialize() function
  854. self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
  855. 'configuration_name': self.configuration_name
  856. }))
  857. for instance_wrapper in self.instance_wrappers:
  858. self.configuration_file.write(instance_wrapper)
  859. self.configuration_file.write(self.epilogue_template)
  860. self.configuration_file.close()
  861. ###################################################################################################
  862. ###################################################################################################
  863. class EmitGemmSingleKernelWrapper:
  864. def __init__(self, kernel_path, gemm_operation, short_path=False):
  865. self.short_path = short_path
  866. self.kernel_path = kernel_path
  867. self.operation = gemm_operation
  868. instance_emitters = {
  869. GemmKind.Gemm: EmitGemmInstance(),
  870. GemmKind.SplitKParallel: EmitGemmSplitKParallelInstance(),
  871. }
  872. self.instance_emitter = instance_emitters[self.operation.gemm_kind]
  873. self.header_template = """
  874. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  875. // ignore warning of cutlass
  876. #pragma GCC diagnostic push
  877. #pragma GCC diagnostic ignored "-Wunused-parameter"
  878. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  879. #pragma GCC diagnostic ignored "-Wuninitialized"
  880. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  881. #include "cutlass/gemm/device/gemm.h"
  882. #include "cutlass/gemm/device/gemm_splitk_parallel.h"
  883. #include "src/cuda/cutlass/manifest.h"
  884. #include "src/cuda/cutlass/gemm_operation.h"
  885. """
  886. self.instance_template = """
  887. ${operation_instance}
  888. """
  889. self.manifest_template = """
  890. namespace cutlass {
  891. namespace library {
  892. void initialize_${operation_name}(Manifest &manifest) {
  893. manifest.append(new GemmOperation<
  894. Operation_${operation_name}
  895. >("${operation_name}"));
  896. }
  897. } // namespace library
  898. } // namespace cutlass
  899. """
  900. self.epilogue_template = """
  901. #pragma GCC diagnostic pop
  902. #endif
  903. """
  904. #
  905. def __enter__(self):
  906. if self.short_path:
  907. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  908. GlobalCnt.cnt += 1
  909. else:
  910. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  911. self.kernel_file = open(self.kernel_path, "w")
  912. self.kernel_file.write(SubstituteTemplate(self.header_template, {
  913. 'required_cuda_ver_major': str(self.operation.required_cuda_ver_major),
  914. 'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor),
  915. }))
  916. return self
  917. #
  918. def emit(self):
  919. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  920. 'operation_instance': self.instance_emitter.emit(self.operation),
  921. }))
  922. # emit manifest helper
  923. manifest = SubstituteTemplate(self.manifest_template, {
  924. 'operation_name': self.operation.procedural_name(),
  925. })
  926. self.kernel_file.write(manifest)
  927. #
  928. def __exit__(self, exception_type, exception_value, traceback):
  929. self.kernel_file.write(self.epilogue_template)
  930. self.kernel_file.close()
  931. ###################################################################################################
  932. ###################################################################################################
  933. class EmitGemvSingleKernelWrapper:
  934. def __init__(self, kernel_path, gemm_operation, wrapper_path, short_path=False):
  935. self.kernel_path = kernel_path
  936. self.wrapper_path = wrapper_path
  937. self.operation = gemm_operation
  938. self.short_path = short_path
  939. self.wrapper_template = """
  940. template void megdnn::cuda::cutlass_wrapper::
  941. cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_${operation_name}>(
  942. BatchedGemmCoord const& problem_size,
  943. const typename Operation_${operation_name}::ElementA* d_A, size_t lda, size_t batch_stride_a,
  944. const typename Operation_${operation_name}::ElementB* d_B, size_t ldb, size_t batch_stride_b,
  945. typename Operation_${operation_name}::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
  946. cudaStream_t stream);
  947. """
  948. self.instance_emitter = EmitGemvBatchedStridedInstance()
  949. self.header_template = """
  950. #if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
  951. // ignore warning of cutlass
  952. #pragma GCC diagnostic push
  953. #pragma GCC diagnostic ignored "-Wunused-parameter"
  954. #pragma GCC diagnostic ignored "-Wstrict-aliasing"
  955. #pragma GCC diagnostic ignored "-Wuninitialized"
  956. #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  957. #include "${wrapper_path}"
  958. """
  959. self.instance_template = """
  960. ${operation_instance}
  961. """
  962. self.epilogue_template = """
  963. #pragma GCC diagnostic pop
  964. #endif
  965. """
  966. #
  967. def __enter__(self):
  968. if self.short_path:
  969. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
  970. GlobalCnt.cnt += 1
  971. else:
  972. self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
  973. self.kernel_file = open(self.kernel_path, "w")
  974. self.kernel_file.write(SubstituteTemplate(self.header_template, {
  975. 'wrapper_path': self.wrapper_path,
  976. 'required_cuda_ver_major': str(self.operation.required_cuda_ver_major),
  977. 'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor),
  978. }))
  979. return self
  980. #
  981. def emit(self):
  982. self.kernel_file.write(SubstituteTemplate(self.instance_template, {
  983. 'operation_instance': self.instance_emitter.emit(self.operation),
  984. }))
  985. # emit wrapper
  986. wrapper = SubstituteTemplate(self.wrapper_template, {
  987. 'operation_name': self.operation.procedural_name(),
  988. })
  989. self.kernel_file.write(wrapper)
  990. #
  991. def __exit__(self, exception_type, exception_value, traceback):
  992. self.kernel_file.write(self.epilogue_template)
  993. self.kernel_file.close()
  994. ###################################################################################################
  995. ###################################################################################################

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。