#
# \file conv2d_operation.py
#
# \brief Generates the CUTLASS Library's instances
#
#
import enum
import os.path
import shutil
from typing import Tuple, List

from library import *

###################################################################################################

#
class Conv2dOperation:
    #
    def __init__(
        self,
        conv_kind,
        conv_type,
        arch,
        tile_description,
        src,
        flt,
        bias,
        dst,
        element_epilogue,
        epilogue_functor=EpilogueFunctor.LinearCombination,
        swizzling_functor=SwizzlingFunctor.Identity4,
        special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
        implicit_gemm_mode=ImplicitGemmMode.GemmNT,
        without_shared_load=False,
        required_cuda_ver_major=9,
        required_cuda_ver_minor=2,
    ):
        self.operation_kind = OperationKind.Conv2d
        self.conv_kind = conv_kind
        self.arch = arch
        self.tile_description = tile_description
        self.conv_type = conv_type
        self.src = src
        self.flt = flt
        self.bias = bias
        self.dst = dst
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor
        self.special_optimization = special_optimization
        self.implicit_gemm_mode = implicit_gemm_mode
        self.without_shared_load = without_shared_load
        self.required_cuda_ver_major = required_cuda_ver_major
        self.required_cuda_ver_minor = required_cuda_ver_minor
    #
    def accumulator_type(self):
        accum = self.tile_description.math_instruction.element_accumulator
        return accum
    #
    def core_name(self):
        """The basic operation kind is prefixed with a letter indicating the accumulation type."""
        intermediate_type = ""
        if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(
                self.tile_description.math_instruction.instruction_shape
            )
            if (
                self.tile_description.math_instruction.element_a != self.flt.element
                and self.tile_description.math_instruction.element_a
                != self.accumulator_type()
            ):
                intermediate_type = DataTypeNames[
                    self.tile_description.math_instruction.element_a
                ]
        else:
            inst_shape = ""
        special_opt = ""
        if self.special_optimization == SpecialOptimizeDesc.ConvFilterUnity:
            special_opt = "_1x1"
        elif self.special_optimization == SpecialOptimizeDesc.DeconvDoubleUpsampling:
            special_opt = "_s2"
        reorder_k = ""
        if self.without_shared_load:
            reorder_k = "_roc"
        conv_type_name = ""
        if self.conv_type == ConvType.DepthwiseConvolution:
            conv_type_name = "dw"
        return "%s%s%s%s%s%s%s_%s" % (
            ShortDataTypeNames[self.accumulator_type()],
            inst_shape,
            intermediate_type,
            conv_type_name,
            ConvKindNames[self.conv_kind],
            special_opt,
            reorder_k,
            ShortEpilogueNames[self.epilogue_functor],
        )
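    # The format string above assembles a short kernel tag of the form
    #   <accum><inst_shape><intermediate_type><conv_type><conv_kind><special_opt><reorder_k>_<epilogue>
    # e.g. a depthwise fprop TensorOp kernel with an 8x8x16 instruction and the
    # 1x1-filter shortcut would read "<accum>8816dwfprop_1x1_<epilogue>", where
    # the remaining tokens come from the name tables in library.py
    # (illustrative composition, not captured generator output).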
    #
    def extended_name(self):
        if (
            self.dst.element
            != self.tile_description.math_instruction.element_accumulator
        ):
            if self.src.element != self.flt.element:
                extended_name = (
                    "${element_dst}_${core_name}_${element_src}_${element_flt}"
                )
            else:
                extended_name = "${element_dst}_${core_name}_${element_src}"
        else:
            if self.src.element != self.flt.element:
                extended_name = "${core_name}_${element_src}_${element_flt}"
            else:
                extended_name = "${core_name}_${element_src}"
        extended_name = SubstituteTemplate(
            extended_name,
            {
                "element_src": DataTypeNames[self.src.element],
                "element_flt": DataTypeNames[self.flt.element],
                "element_dst": DataTypeNames[self.dst.element],
                "core_name": self.core_name(),
            },
        )
        return extended_name
    #
    def layout_name(self):
        if self.src.layout == self.dst.layout:
            layout_name = "${src_layout}_${flt_layout}"
        else:
            layout_name = "${src_layout}_${flt_layout}_${dst_layout}"
        layout_name = SubstituteTemplate(
            layout_name,
            {
                "src_layout": ShortLayoutTypeNames[self.src.layout],
                "flt_layout": ShortLayoutTypeNames[self.flt.layout],
                "dst_layout": ShortLayoutTypeNames[self.dst.layout],
            },
        )
        return layout_name
    #
    def configuration_name(self):
        """The full configuration name indicates opcode class, extended name, tile size, layout, and alignment."""
        opcode_class_name = OpcodeClassNames[
            self.tile_description.math_instruction.opcode_class
        ]
        warp_shape = [
            int(
                self.tile_description.threadblock_shape[idx]
                / self.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
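        # Worked example (illustrative numbers): a 128x128x32 threadblock with
        # warp_count [2, 2, 1] yields a 64x64x32 per-warp shape.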
        threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
            self.tile_description.threadblock_shape[0],
            self.tile_description.threadblock_shape[1],
            self.tile_description.threadblock_shape[2],
            warp_shape[0],
            warp_shape[1],
            warp_shape[2],
            self.tile_description.stages,
        )
        alignment = "align%dx%d" % (self.src.alignment, self.flt.alignment)
        configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${alignment}"
        return SubstituteTemplate(
            configuration_name,
            {
                "opcode_class": opcode_class_name,
                "extended_name": self.extended_name(),
                "threadblock": threadblock,
                "layout": self.layout_name(),
                "alignment": alignment,
            },
        )
    #
    def procedural_name(self):
        """The procedural name is currently just the configuration name."""
        return self.configuration_name()
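
# SubstituteTemplate comes in via the star import from library above. It
# expands ${key} placeholders from a value dictionary; every generated name
# and kernel instance below relies on it. A minimal stand-in with the same
# observable behavior might look like this sketch (an assumption; see
# library.py for the authoritative definition):
#
#   import re
#
#   def substitute_template(template, values):
#       text, changed = template, True
#       while changed:  # re-scan until no placeholder was rewritten
#           changed = False
#           for key, value in values.items():
#               new_text = re.sub(r"\$\{%s\}" % key, value, text)
#               changed = changed or new_text != text
#               text = new_text
#       return text
#
#   substitute_template("${a}_${b}", {"a": "conv", "b": "fprop"})  # "conv_fprop"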

###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################


class EmitConv2dInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution =
    typename cutlass::conv::device::Convolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode},
        ${without_shared_load}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )
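        # Illustrative arithmetic: an 8-bit destination element with a
        # 16-element alignment gives min(16 * 8, 128) / 8 = 16, i.e. the
        # epilogue vector is capped at one 128-bit access.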
        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_flt": DataTypeTag[operation.flt.element],
            "layout_flt": LayoutTag[operation.flt.layout],
            "element_dst": DataTypeTag[operation.dst.element],
            "layout_dst": LayoutTag[operation.dst.layout],
            "element_bias": DataTypeTag[operation.bias.element],
            "layout_bias": LayoutTag[operation.bias.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_filter": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
            "without_shared_load": str(operation.without_shared_load).lower(),
        }
        return SubstituteTemplate(self.template, values)
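
# A sketch of how an instance emitter is driven (assuming `operation` is a
# Conv2dOperation produced by GenerateConv2d below):
#
#   cpp_source = EmitConv2dInstance().emit(operation)
#   # cpp_source now holds the C++ "using Convolution = ...;" definition above.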

class EmitDeconvInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Deconvolution =
    typename cutlass::conv::device::Deconvolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )
        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_flt": DataTypeTag[operation.flt.element],
            "layout_flt": LayoutTag[operation.flt.layout],
            "element_dst": DataTypeTag[operation.dst.element],
            "layout_dst": LayoutTag[operation.dst.layout],
            "element_bias": DataTypeTag[operation.bias.element],
            "layout_bias": LayoutTag[operation.bias.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_filter": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
        }
        return SubstituteTemplate(self.template, values)

###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(
    conv_type,
    conv_kind,
    tile_descriptions,
    src_layout,
    flt_layout,
    dst_layout,
    dst_type,
    min_cc,
    src_align=32,
    flt_align=32,
    dst_align=32,
    use_special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
    implicit_gemm_mode=ImplicitGemmMode.GemmNT,
    without_shared_load=False,
    required_cuda_ver_major=9,
    required_cuda_ver_minor=2,
):
    operations = []
    element_epilogue = DataType.f32
    if conv_type == ConvType.DepthwiseConvolution:
        if conv_kind == ConvKind.Fprop:
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionFprop
        elif conv_kind == ConvKind.Dgrad:
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionDgrad
        else:
            assert conv_kind == ConvKind.Wgrad
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionWgrad
    elif conv_type == ConvType.Convolution:
        if conv_kind == ConvKind.Fprop:
            if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
                swizzling_functor = SwizzlingFunctor.ConvFpropTrans
            else:
                swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
        else:
            if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
                swizzling_functor = SwizzlingFunctor.ConvDgradTrans
            else:
                swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

    # skip rule
    def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
        return (
            layout == LayoutType.TensorNC32HW32 and tile.threadblock_shape[0] % 32 != 0
        )
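    # (e.g. a tile with threadblock_shape[0] == 48 is skipped for
    # TensorNC32HW32 output, since 48 % 32 != 0; illustrative numbers)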

    # rule for bias_type and epilogues
    def get_bias_type_and_epilogues(
        tile: TileDescription, out_dtype: DataType
    ) -> Tuple[DataType, List[EpilogueFunctor]]:
        if (
            tile.math_instruction.element_accumulator == DataType.s32
            and out_dtype != DataType.f32
        ):
            bias_type = DataType.s32
            if tile.math_instruction.element_b == DataType.u4:
                epilogues = [
                    EpilogueFunctor.BiasAddLinearCombinationClamp,
                    EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                ]
            else:
                epilogues = [
                    EpilogueFunctor.BiasAddLinearCombinationClamp,
                    EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                    EpilogueFunctor.BiasAddLinearCombinationHSwishClamp,
                ]
        elif (
            tile.math_instruction.element_accumulator == DataType.f32
            or tile.math_instruction.element_accumulator == DataType.f16
        ) or (
            tile.math_instruction.element_accumulator == DataType.s32
            and out_dtype == DataType.f32
        ):
            bias_type = out_dtype
            epilogues = [
                EpilogueFunctor.BiasAddLinearCombination,
                EpilogueFunctor.BiasAddLinearCombinationRelu,
            ]
            if conv_type == ConvType.Convolution:
                epilogues.append(EpilogueFunctor.BiasAddLinearCombinationHSwish)
        else:
            assert False, "invalid path"
        return bias_type, epilogues

    # rule for filter alignment
    def get_flt_align(tile: TileDescription) -> int:
        nonlocal flt_align
        if (
            tile.math_instruction.opcode_class == OpcodeClass.Simt
            and tile.math_instruction.element_accumulator == DataType.s32
        ):
            thread_num = (
                tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
            )
            flt_block = (
                tile.threadblock_shape[0]
                * tile.threadblock_shape[2]
                * DataTypeSize[tile.math_instruction.element_a]
            )
            load_per_thread = flt_block // thread_num
            if load_per_thread >= 128:
                flt_align = 128
            elif load_per_thread >= 64:
                flt_align = 64
            else:
                assert load_per_thread >= 32
                flt_align = 32
        return flt_align

    def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
        nonlocal dst_align
        # use the out_layout parameter rather than shadowing dst_layout from
        # the enclosing scope
        if (
            tile.math_instruction.opcode_class == OpcodeClass.TensorOp
            and out_layout == LayoutType.TensorNC4HW4
        ):
            dst_align = 32
        return dst_align

    def filter_epilogue_with_conv_kind(
        epilogue: EpilogueFunctor, conv_kind: ConvKind
    ) -> bool:
        return (
            conv_kind == ConvKind.Dgrad
            and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp
        )
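    # (i.e. Dgrad kernels keep only BiasAddLinearCombinationClamp; every other
    # epilogue is filtered out)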

    # loop over all tile descriptions
    for tile in tile_descriptions:
        if filter_tile_with_layout(tile, dst_layout):
            continue
        bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
        flt_align = get_flt_align(tile)
        dst_align = get_dst_align(tile, dst_layout)
        for epilogue in epilogues:
            if filter_epilogue_with_conv_kind(epilogue, conv_kind):
                continue
            if dst_type == DataType.f32:
                bias_type = DataType.f32
            #
            src = TensorDescription(
                tile.math_instruction.element_b,
                src_layout,
                int(src_align / DataTypeSize[tile.math_instruction.element_b]),
            )
            flt = TensorDescription(
                tile.math_instruction.element_a,
                flt_layout,
                int(flt_align / DataTypeSize[tile.math_instruction.element_a]),
            )
            bias = TensorDescription(
                bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))
            )
            dst = TensorDescription(
                dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])
            )
            new_operation = Conv2dOperation(
                conv_kind,
                conv_type,
                min_cc,
                tile,
                src,
                flt,
                bias,
                dst,
                element_epilogue,
                epilogue,
                swizzling_functor,
                SpecialOptimizeDesc.NoneSpecialOpt,
                implicit_gemm_mode,
                without_shared_load,
                required_cuda_ver_major,
                required_cuda_ver_minor,
            )
            operations.append(new_operation)
            if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt:
                new_operation = Conv2dOperation(
                    conv_kind,
                    conv_type,
                    min_cc,
                    tile,
                    src,
                    flt,
                    bias,
                    dst,
                    element_epilogue,
                    epilogue,
                    swizzling_functor,
                    use_special_optimization,
                    implicit_gemm_mode,
                    without_shared_load,
                    required_cuda_ver_major,
                    required_cuda_ver_minor,
                )
                operations.append(new_operation)
    return operations
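
# A minimal usage sketch (illustrative: `tile_descriptions` would be built
# elsewhere from library.TileDescription, the layout names and 61 as the
# minimum compute capability are placeholders):
#
#   operations = GenerateConv2d(
#       ConvType.Convolution,
#       ConvKind.Fprop,
#       tile_descriptions,
#       LayoutType.TensorNC4HW4,  # src layout
#       flt_layout,               # filter layout from the caller
#       LayoutType.TensorNC4HW4,  # dst layout
#       dst_type,
#       min_cc=61,
#   )
#
# Each surviving (tile, epilogue) pair yields a plain operation, plus a twin
# carrying use_special_optimization when one is requested.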

###################################################################################################
#
# Emitter functions for all targets
#
###################################################################################################

class EmitConv2dConfigurationLibrary:
    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(
            operation_path, "%s.cu" % configuration_name
        )
        self.instance_emitter = EmitConv2dInstance()

        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
    public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.header_template = """
/*
    Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.configuration_header = """
namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""

        self.configuration_instance = """
  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
    ${operation_name}>;

  manifest.append(new cutlass::library::Conv2dOperation<
    Operation_${operation_name}>(
      "${operation_name}"));
"""

        self.configuration_epilogue = """
}
"""

        self.epilogue_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""
    #
    def __enter__(self):
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(
            SubstituteTemplate(
                self.header_template, {"configuration_name": self.configuration_name}
            )
        )
        self.operations = []
        return self
    #
    def emit(self, operation):
        self.operations.append(operation)
        self.configuration_file.write(
            SubstituteTemplate(
                self.instance_template,
                {
                    "configuration_name": self.configuration_name,
                    "operation_name": operation.procedural_name(),
                    "operation_instance": self.instance_emitter.emit(operation),
                },
            )
        )
    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.configuration_file.write(
            SubstituteTemplate(
                self.configuration_header,
                {"configuration_name": self.configuration_name},
            )
        )
        for operation in self.operations:
            self.configuration_file.write(
                SubstituteTemplate(
                    self.configuration_instance,
                    {
                        "configuration_name": self.configuration_name,
                        "operation_name": operation.procedural_name(),
                    },
                )
            )
        self.configuration_file.write(self.configuration_epilogue)
        self.configuration_file.write(self.epilogue_template)
        self.configuration_file.close()
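
# The configuration emitter is a context manager; a typical driver loop
# (a sketch, assuming `operation_path`, `configuration_name`, and `operations`
# are supplied by the surrounding generator) would be:
#
#   with EmitConv2dConfigurationLibrary(operation_path, configuration_name) as emitter:
#       for operation in operations:
#           emitter.emit(operation)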

###################################################################################################
###################################################################################################
#
# Emitters for Conv Kernel Wrapper
#
###################################################################################################

class EmitConvSingleKernelWrapper:
    def __init__(self, kernel_path, operation, short_path=False):
        self.kernel_path = kernel_path
        self.operation = operation
        self.short_path = short_path

        if self.operation.conv_kind == ConvKind.Fprop:
            self.instance_emitter = EmitConv2dInstance()
            self.convolution_name = "Convolution"
        else:
            assert self.operation.conv_kind == ConvKind.Dgrad
            self.instance_emitter = EmitDeconvInstance()
            self.convolution_name = "Deconvolution"

        self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

#include "cutlass/convolution/device/convolution.h"

#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""

        self.instance_template = """
${operation_instance}
"""

        self.manifest_template = """
namespace cutlass {
namespace library {

void initialize_${operation_name}(Manifest &manifest) {
  manifest.append(new ConvolutionOperation<${convolution_name}>(
    "${operation_name}"
  ));
}

}  // namespace library
}  // namespace cutlass
"""

        self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""
    #
    def __enter__(self):
        if self.short_path:
            self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
            GlobalCnt.cnt += 1
        else:
            self.kernel_path = os.path.join(
                self.kernel_path, "%s.cu" % self.operation.procedural_name()
            )
        self.kernel_file = open(self.kernel_path, "w")
        self.kernel_file.write(
            SubstituteTemplate(
                self.header_template,
                {
                    "required_cuda_ver_major": str(
                        self.operation.required_cuda_ver_major
                    ),
                    "required_cuda_ver_minor": str(
                        self.operation.required_cuda_ver_minor
                    ),
                },
            )
        )
        return self
    #
    def emit(self):
        self.kernel_file.write(
            SubstituteTemplate(
                self.instance_template,
                {"operation_instance": self.instance_emitter.emit(self.operation)},
            )
        )

        # emit manifest helper
        manifest = SubstituteTemplate(
            self.manifest_template,
            {
                "operation_name": self.operation.procedural_name(),
                "convolution_name": self.convolution_name,
            },
        )
        self.kernel_file.write(manifest)
    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.kernel_file.write(self.epilogue_template)
        self.kernel_file.close()
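
# Likewise a context manager; a sketch of the expected call pattern (with
# `kernel_path` and `operation` supplied by the surrounding generator script):
#
#   with EmitConvSingleKernelWrapper(kernel_path, operation) as emitter:
#       emitter.emit()
#
# The emitted .cu file is bracketed by the #if/#endif CUDA-version guard from
# header_template and epilogue_template above.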

###################################################################################################
###################################################################################################