
conv2d_operation.py

#
# \file conv2d_operation.py
#
# \brief Generates the CUTLASS Library's instances
#
#
import enum
import os.path
import shutil
from typing import List, Tuple

from library import *

###################################################################################################
#
class Conv2dOperation:
    #
    def __init__(
        self,
        conv_kind,
        conv_type,
        arch,
        tile_description,
        src,
        flt,
        bias,
        dst,
        element_epilogue,
        epilogue_functor=EpilogueFunctor.LinearCombination,
        swizzling_functor=SwizzlingFunctor.Identity4,
        special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
        implicit_gemm_mode=ImplicitGemmMode.GemmNT,
        without_shared_load=False,
        required_cuda_ver_major=9,
        required_cuda_ver_minor=2,
    ):
        self.operation_kind = OperationKind.Conv2d
        self.conv_kind = conv_kind
        self.arch = arch
        self.tile_description = tile_description
        self.conv_type = conv_type
        self.src = src
        self.flt = flt
        self.bias = bias
        self.dst = dst
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor
        self.special_optimization = special_optimization
        self.implicit_gemm_mode = implicit_gemm_mode
        self.without_shared_load = without_shared_load
        self.required_cuda_ver_major = required_cuda_ver_major
        self.required_cuda_ver_minor = required_cuda_ver_minor

    #
    def accumulator_type(self):
        accum = self.tile_description.math_instruction.element_accumulator
        return accum

    #
    def core_name(self):
        """The basic operation kind is prefixed with a letter indicating the accumulation type."""
        intermediate_type = ""
        if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(
                self.tile_description.math_instruction.instruction_shape
            )
            if (
                self.tile_description.math_instruction.element_a != self.flt.element
                and self.tile_description.math_instruction.element_a
                != self.accumulator_type()
            ):
                intermediate_type = DataTypeNames[
                    self.tile_description.math_instruction.element_a
                ]
        else:
            inst_shape = ""

        special_opt = ""
        if self.special_optimization == SpecialOptimizeDesc.ConvFilterUnity:
            special_opt = "_1x1"
        elif self.special_optimization == SpecialOptimizeDesc.DeconvDoubleUpsampling:
            special_opt = "_s2"

        reorder_k = ""
        if self.without_shared_load:
            reorder_k = "_roc"

        conv_type_name = ""
        if self.conv_type == ConvType.DepthwiseConvolution:
            conv_type_name = "dw"

        return "%s%s%s%s%s%s%s_%s" % (
            ShortDataTypeNames[self.accumulator_type()],
            inst_shape,
            intermediate_type,
            conv_type_name,
            ConvKindNames[self.conv_kind],
            special_opt,
            reorder_k,
            ShortEpilogueNames[self.epilogue_functor],
        )

    #
    def extended_name(self):
        if (
            self.dst.element
            != self.tile_description.math_instruction.element_accumulator
        ):
            if self.src.element != self.flt.element:
                extended_name = (
                    "${element_dst}_${core_name}_${element_src}_${element_flt}"
                )
            elif self.src.element == self.flt.element:
                extended_name = "${element_dst}_${core_name}_${element_src}"
        else:
            if self.src.element != self.flt.element:
                extended_name = "${core_name}_${element_src}_${element_flt}"
            elif self.src.element == self.flt.element:
                extended_name = "${core_name}_${element_src}"

        extended_name = SubstituteTemplate(
            extended_name,
            {
                "element_src": DataTypeNames[self.src.element],
                "element_flt": DataTypeNames[self.flt.element],
                "element_dst": DataTypeNames[self.dst.element],
                "core_name": self.core_name(),
            },
        )
        return extended_name

    #
    def layout_name(self):
        if self.src.layout == self.dst.layout:
            layout_name = "${src_layout}_${flt_layout}"
        else:
            layout_name = "${src_layout}_${flt_layout}_${dst_layout}"

        layout_name = SubstituteTemplate(
            layout_name,
            {
                "src_layout": ShortLayoutTypeNames[self.src.layout],
                "flt_layout": ShortLayoutTypeNames[self.flt.layout],
                "dst_layout": ShortLayoutTypeNames[self.dst.layout],
            },
        )
        return layout_name

    #
    def configuration_name(self):
        """The full configuration name indicates opcode class, extended name, tile size, layout, and alignment."""
        opcode_class_name = OpcodeClassNames[
            self.tile_description.math_instruction.opcode_class
        ]

        warp_shape = [
            int(
                self.tile_description.threadblock_shape[idx]
                / self.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]

        threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
            self.tile_description.threadblock_shape[0],
            self.tile_description.threadblock_shape[1],
            self.tile_description.threadblock_shape[2],
            warp_shape[0],
            warp_shape[1],
            warp_shape[2],
            self.tile_description.stages,
        )

        alignment = "align%dx%d" % (self.src.alignment, self.flt.alignment)

        configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${alignment}"

        return SubstituteTemplate(
            configuration_name,
            {
                "opcode_class": opcode_class_name,
                "extended_name": self.extended_name(),
                "threadblock": threadblock,
                "layout": self.layout_name(),
                "alignment": alignment,
            },
        )

    #
    def procedural_name(self):
        """The procedural name of an instance is its configuration name."""
        return self.configuration_name()
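
# Expository note (added alongside the source, not part of the generator itself): core_name()
# concatenates its pieces as
#
#     <accum><inst_shape><intermediate_type><conv_type_name><conv_kind><special_opt><reorder_k>_<epilogue>
#
# So, assuming library.py's short-name tables match their upstream CUTLASS counterparts
# (e.g. ShortDataTypeNames[DataType.s32] == "i"), an s32-accumulating TensorOp with
# instruction shape [8, 8, 16], ConvKind.Fprop, SpecialOptimizeDesc.ConvFilterUnity, and
# without_shared_load=True would produce a core name of the shape "i8816fprop_1x1_roc_<epi>".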
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

class EmitConv2dInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
  typename cutlass::conv::device::Convolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${special_optimization},
    ${math_operator},
    ${implicit_gemm_mode},
    ${without_shared_load}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]

        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )

        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_flt": DataTypeTag[operation.flt.element],
            "layout_flt": LayoutTag[operation.flt.layout],
            "element_dst": DataTypeTag[operation.dst.element],
            "layout_dst": LayoutTag[operation.dst.layout],
            "element_bias": DataTypeTag[operation.bias.element],
            "layout_bias": LayoutTag[operation.bias.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_filter": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
            "without_shared_load": str(operation.without_shared_load).lower(),
        }

        return SubstituteTemplate(self.template, values)
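
# Expository worked example (added note): with an s8 destination (DataTypeSize == 8 bits) and
# dst.alignment == 4 elements, epilogue_vector_length = min(4 * 8, 128) / 8 = 4, i.e. the
# epilogue accesses 4 elements at a time; the min(..., 128) clamp keeps vector accesses at or
# below the 128-bit maximum regardless of how large the destination alignment is.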
class EmitDeconvInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
  typename cutlass::conv::device::Deconvolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${special_optimization},
    ${math_operator},
    ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]

        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )

        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_flt": DataTypeTag[operation.flt.element],
            "layout_flt": LayoutTag[operation.flt.layout],
            "element_dst": DataTypeTag[operation.dst.element],
            "layout_dst": LayoutTag[operation.dst.layout],
            "element_bias": DataTypeTag[operation.bias.element],
            "layout_bias": LayoutTag[operation.bias.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_filter": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
        }

        return SubstituteTemplate(self.template, values)
class EmitConvolutionBackwardFilterInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
  typename cutlass::conv::device::ConvolutionBackwardFilter<
    ${element_src},
    ${layout_src},
    ${element_diff},
    ${layout_diff},
    ${element_grad},
    ${layout_grad},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_grad},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_diff},
    ${special_optimization},
    ${math_operator},
    ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]

        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )

        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_diff": DataTypeTag[operation.flt.element],
            "layout_diff": LayoutTag[operation.flt.layout],
            "element_grad": DataTypeTag[operation.dst.element],
            "layout_grad": LayoutTag[operation.dst.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_diff": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
        }

        return SubstituteTemplate(self.template, values)
###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(
    conv_type,
    conv_kind,
    tile_descriptions,
    src_layout,
    flt_layout,
    dst_layout,
    dst_type,
    min_cc,
    src_align=32,
    flt_align=32,
    dst_align=32,
    use_special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
    implicit_gemm_mode=ImplicitGemmMode.GemmNT,
    without_shared_load=False,
    required_cuda_ver_major=9,
    required_cuda_ver_minor=2,
):
    operations = []

    element_epilogue = DataType.f32
    if conv_type == ConvType.DepthwiseConvolution:
        if conv_kind == ConvKind.Fprop:
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionFprop
        elif conv_kind == ConvKind.Dgrad:
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionDgrad
        else:
            assert conv_kind == ConvKind.Wgrad
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionWgrad
    elif conv_type == ConvType.Convolution:
        if conv_kind == ConvKind.Fprop:
            if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
                swizzling_functor = SwizzlingFunctor.ConvFpropTrans
            else:
                swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
        else:
            if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
                swizzling_functor = SwizzlingFunctor.ConvDgradTrans
            else:
                swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

    # skip rule
    def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
        return (
            layout == LayoutType.TensorNC32HW32 and tile.threadblock_shape[0] % 32 != 0
        )

    # rule for bias_type and epilogues
    def get_bias_type_and_epilogues(
        tile: TileDescription, out_dtype: DataType
    ) -> Tuple[DataType, List[EpilogueFunctor]]:
        if (
            tile.math_instruction.element_accumulator == DataType.s32
            and out_dtype != DataType.f32
        ):
            bias_type = DataType.s32
            if tile.math_instruction.element_b == DataType.u4:
                epilogues = [
                    EpilogueFunctor.BiasAddLinearCombinationClamp,
                    EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                ]
            else:
                epilogues = [
                    EpilogueFunctor.BiasAddLinearCombinationClamp,
                    EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                    EpilogueFunctor.BiasAddLinearCombinationHSwishClamp,
                ]
        elif (
            tile.math_instruction.element_accumulator == DataType.f32
            or tile.math_instruction.element_accumulator == DataType.f16
        ) or (
            tile.math_instruction.element_accumulator == DataType.s32
            and out_dtype == DataType.f32
        ):
            bias_type = out_dtype
            epilogues = [
                EpilogueFunctor.BiasAddLinearCombination,
                EpilogueFunctor.BiasAddLinearCombinationRelu,
                EpilogueFunctor.LinearCombination,
            ]
            if conv_type == ConvType.Convolution:
                epilogues.append(EpilogueFunctor.BiasAddLinearCombinationHSwish)
        else:
            assert False, "invalid path"
        return bias_type, epilogues

    # rule for filter alignment
    def get_flt_align(tile: TileDescription) -> int:
        nonlocal flt_align
        if (
            tile.math_instruction.opcode_class == OpcodeClass.Simt
            and tile.math_instruction.element_accumulator == DataType.s32
        ):
            thread_num = (
                tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
            )
            flt_block = (
                tile.threadblock_shape[0]
                * tile.threadblock_shape[2]
                * DataTypeSize[tile.math_instruction.element_a]
            )
            load_per_thread = flt_block // thread_num
            if load_per_thread >= 128:
                flt_align = 128
            elif load_per_thread >= 64:
                flt_align = 64
            else:
                assert load_per_thread >= 32
                flt_align = 32
        return flt_align

    def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
        nonlocal dst_align
        if (
            tile.math_instruction.opcode_class == OpcodeClass.TensorOp
            and dst_layout == LayoutType.TensorNC4HW4
        ):
            dst_align = 32
        return dst_align

    def filter_epilogue_with_conv_kind(
        epilogue: EpilogueFunctor, conv_kind: ConvKind
    ) -> bool:
        if conv_kind == ConvKind.Fprop:
            return epilogue == EpilogueFunctor.LinearCombination
        elif conv_kind == ConvKind.Dgrad:
            return (
                epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp
                and epilogue != EpilogueFunctor.BiasAddLinearCombination
            )
        elif conv_kind == ConvKind.Wgrad:
            return epilogue != EpilogueFunctor.LinearCombination

    # loop over all tile descriptions
    for tile in tile_descriptions:
        if filter_tile_with_layout(tile, dst_layout):
            continue

        bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
        flt_align = flt_align if conv_kind == ConvKind.Wgrad else get_flt_align(tile)
        dst_align = get_dst_align(tile, dst_layout)

        for epilogue in epilogues:
            if filter_epilogue_with_conv_kind(epilogue, conv_kind):
                continue

            if dst_type == DataType.f32:
                bias_type = DataType.f32

            #
            src = TensorDescription(
                tile.math_instruction.element_b,
                src_layout,
                int(src_align / DataTypeSize[tile.math_instruction.element_b]),
            )
            flt = TensorDescription(
                tile.math_instruction.element_a,
                flt_layout,
                int(flt_align / DataTypeSize[tile.math_instruction.element_a]),
            )
            bias = TensorDescription(
                bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))
            )
            dst = TensorDescription(
                dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])
            )

            new_operation = Conv2dOperation(
                conv_kind,
                conv_type,
                min_cc,
                tile,
                src,
                flt,
                bias,
                dst,
                element_epilogue,
                epilogue,
                swizzling_functor,
                SpecialOptimizeDesc.NoneSpecialOpt,
                implicit_gemm_mode,
                without_shared_load,
                required_cuda_ver_major,
                required_cuda_ver_minor,
            )
            operations.append(new_operation)
            if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt:
                new_operation = Conv2dOperation(
                    conv_kind,
                    conv_type,
                    min_cc,
                    tile,
                    src,
                    flt,
                    bias,
                    dst,
                    element_epilogue,
                    epilogue,
                    swizzling_functor,
                    use_special_optimization,
                    implicit_gemm_mode,
                    without_shared_load,
                    required_cuda_ver_major,
                    required_cuda_ver_minor,
                )
                operations.append(new_operation)
    return operations
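
# Expository usage sketch (added note; `tile_descriptions` and the layout arguments are
# placeholders supplied by the top-level generator script, not defined in this module):
#
#     operations = GenerateConv2d(
#         ConvType.Convolution,
#         ConvKind.Fprop,
#         tile_descriptions,
#         src_layout,          # e.g. LayoutType.TensorNC4HW4
#         flt_layout,          # a matching filter layout from library.py
#         dst_layout,          # e.g. LayoutType.TensorNC4HW4
#         DataType.s8,         # dst_type
#         61,                  # min_cc (hypothetical example value)
#     )
#
# Each surviving (tile, epilogue) pair yields one Conv2dOperation, plus a second instance with
# the requested special optimization when use_special_optimization is not NoneSpecialOpt.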
###################################################################################################
#
# Emitter functions for all targets
#
###################################################################################################

class EmitConv2dConfigurationLibrary:
    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(
            operation_path, "%s.cu" % configuration_name
        )

        self.instance_emitter = EmitConv2dInstance()

        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
  public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.header_template = """
/*
  Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.configuration_header = """
namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""

        self.configuration_instance = """
  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
    ${operation_name}>;

  manifest.append(new cutlass::library::Conv2dOperation<
    Operation_${operation_name}>(
      "${operation_name}"));
"""

        self.configuration_epilogue = """
}
"""

        self.epilogue_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    #
    def __enter__(self):
        self.configuration_file = open(self.configuration_path, "w")

        self.configuration_file.write(
            SubstituteTemplate(
                self.header_template, {"configuration_name": self.configuration_name}
            )
        )

        self.operations = []
        return self

    #
    def emit(self, operation):
        self.operations.append(operation)
        self.configuration_file.write(
            SubstituteTemplate(
                self.instance_template,
                {
                    "configuration_name": self.configuration_name,
                    "operation_name": operation.procedural_name(),
                    "operation_instance": self.instance_emitter.emit(operation),
                },
            )
        )

    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.configuration_file.write(
            SubstituteTemplate(
                self.configuration_header,
                {"configuration_name": self.configuration_name},
            )
        )

        for operation in self.operations:
            self.configuration_file.write(
                SubstituteTemplate(
                    self.configuration_instance,
                    {
                        "configuration_name": self.configuration_name,
                        "operation_name": operation.procedural_name(),
                    },
                )
            )

        self.configuration_file.write(self.configuration_epilogue)
        self.configuration_file.write(self.epilogue_template)
        self.configuration_file.close()
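
# Expository usage sketch (added note): EmitConv2dConfigurationLibrary is a context manager;
# __enter__ writes the header, emit() appends one kernel instance per operation, and __exit__
# writes the initialize_${configuration_name}() registration function and closing namespaces:
#
#     with EmitConv2dConfigurationLibrary(operation_path, configuration_name) as emitter:
#         for operation in operations:
#             emitter.emit(operation)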
###################################################################################################
###################################################################################################
#
# Emitters for Conv Kernel Wrapper
#
###################################################################################################

class EmitConvSingleKernelWrapper:
    def __init__(self, kernel_path, operation, short_path=False):
        self.kernel_path = kernel_path
        self.operation = operation
        self.short_path = short_path

        if self.operation.conv_kind == ConvKind.Fprop:
            self.instance_emitter = EmitConv2dInstance()
            self.convolution_name = "ConvolutionOperation"
        elif self.operation.conv_kind == ConvKind.Dgrad:
            self.instance_emitter = EmitDeconvInstance()
            self.convolution_name = "ConvolutionOperation"
        else:
            assert self.operation.conv_kind == ConvKind.Wgrad
            self.instance_emitter = EmitConvolutionBackwardFilterInstance()
            self.convolution_name = "ConvolutionBackwardFilterOperation"

        self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
// ignore warnings from cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

#include "cutlass/convolution/device/convolution.h"

#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""

        self.instance_template = """
${operation_instance}
"""

        self.manifest_template = """
namespace cutlass {
namespace library {

void initialize_${operation_name}(Manifest &manifest) {
  manifest.append(new ${convolution_name}<Convolution_${operation_name}>(
    "${operation_name}"
  ));
}

} // namespace library
} // namespace cutlass
"""

        self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

    #
    def __enter__(self):
        if self.short_path:
            self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
            GlobalCnt.cnt += 1
        else:
            self.kernel_path = os.path.join(
                self.kernel_path, "%s.cu" % self.operation.procedural_name()
            )
        self.kernel_file = open(self.kernel_path, "w")
        # write the CUDA-version guard opened by header_template; without this the
        # #pragma/#endif emitted in __exit__ would be unbalanced
        self.kernel_file.write(
            SubstituteTemplate(
                self.header_template,
                {
                    "required_cuda_ver_major": str(
                        self.operation.required_cuda_ver_major
                    ),
                    "required_cuda_ver_minor": str(
                        self.operation.required_cuda_ver_minor
                    ),
                },
            )
        )
        return self

    #
    def emit(self):
        self.kernel_file.write(
            SubstituteTemplate(
                self.instance_template,
                {"operation_instance": self.instance_emitter.emit(self.operation)},
            )
        )

        # emit manifest helper
        manifest = SubstituteTemplate(
            self.manifest_template,
            {
                "operation_name": self.operation.procedural_name(),
                "convolution_name": self.convolution_name,
            },
        )
        self.kernel_file.write(manifest)

    #
    def __exit__(self, exception_type, exception_value, traceback):
        # close the guard opened by header_template
        self.kernel_file.write(self.epilogue_template)
        self.kernel_file.close()
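
# Expository usage sketch (added note): each wrapper emits one .cu file per operation; with
# short_path=True the file name comes from the GlobalCnt counter instead of the (long)
# procedural name:
#
#     for op in operations:
#         with EmitConvSingleKernelWrapper(kernel_path, op, short_path=True) as emitter:
#             emitter.emit()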
###################################################################################################
###################################################################################################