#
# \file conv2d_operation.py
#
# \brief Generates the CUTLASS Library's instances
#
#
import enum
import os.path
import shutil
from typing import List, Tuple

from library import *

###################################################################################################
#
class Conv2dOperation:
    #
    def __init__(
        self,
        conv_kind,
        conv_type,
        arch,
        tile_description,
        src,
        flt,
        bias,
        dst,
        element_epilogue,
        epilogue_functor=EpilogueFunctor.LinearCombination,
        swizzling_functor=SwizzlingFunctor.Identity4,
        special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
        implicit_gemm_mode=ImplicitGemmMode.GemmNT,
        without_shared_load=False,
        required_cuda_ver_major=9,
        required_cuda_ver_minor=2,
        rin=None,
        rout=None,
    ):
        self.operation_kind = OperationKind.Conv2d
        self.conv_kind = conv_kind
        self.arch = arch
        self.tile_description = tile_description
        self.conv_type = conv_type
        self.src = src
        self.flt = flt
        self.bias = bias
        self.dst = dst
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor
        self.special_optimization = special_optimization
        self.implicit_gemm_mode = implicit_gemm_mode
        self.without_shared_load = without_shared_load
        self.required_cuda_ver_major = required_cuda_ver_major
        self.required_cuda_ver_minor = required_cuda_ver_minor
        self.rin = rin
        self.rout = rout

    #
    def accumulator_type(self):
        accum = self.tile_description.math_instruction.element_accumulator
        return accum

    #
    def core_name(self):
        """The basic operation kind is prefixed with a letter indicating the accumulation type."""
        intermediate_type = ""
        if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(
                self.tile_description.math_instruction.instruction_shape
            )
            if (
                self.tile_description.math_instruction.element_a != self.flt.element
                and self.tile_description.math_instruction.element_a
                != self.accumulator_type()
            ):
                intermediate_type = DataTypeNames[
                    self.tile_description.math_instruction.element_a
                ]
        else:
            inst_shape = ""

        special_opt = ""
        if self.special_optimization == SpecialOptimizeDesc.ConvFilterUnity:
            special_opt = "_1x1"
        elif self.special_optimization == SpecialOptimizeDesc.DeconvDoubleUpsampling:
            special_opt = "_s2"

        reorder_k = ""
        if self.without_shared_load:
            reorder_k = "_roc"

        conv_type_name = ""
        if self.conv_type == ConvType.DepthwiseConvolution:
            conv_type_name = "dw"
        elif self.conv_type == ConvType.RegionRestrictedConvolution:
            conv_type_name = "rr"

        return "%s%s%s%s%s%s%s_%s" % (
            ShortDataTypeNames[self.accumulator_type()],
            inst_shape,
            intermediate_type,
            conv_type_name,
            ConvKindNames[self.conv_kind],
            special_opt,
            reorder_k,
            ShortEpilogueNames[self.epilogue_functor],
        )
    #
    def extended_name(self):
        if (
            self.dst.element
            != self.tile_description.math_instruction.element_accumulator
        ):
            if self.src.element != self.flt.element:
                extended_name = (
                    "${element_dst}_${core_name}_${element_src}_${element_flt}"
                )
            else:
                extended_name = "${element_dst}_${core_name}_${element_src}"
        else:
            if self.src.element != self.flt.element:
                extended_name = "${core_name}_${element_src}_${element_flt}"
            else:
                extended_name = "${core_name}_${element_src}"

        # self.rin may be None, so only reference (and evaluate) the
        # "element_rin" substitution when a region-restricted input exists.
        values = {
            "element_src": DataTypeNames[self.src.element],
            "element_flt": DataTypeNames[self.flt.element],
            "element_dst": DataTypeNames[self.dst.element],
            "core_name": self.core_name(),
        }
        if self.rin is not None:
            extended_name += "_${element_rin}"
            values["element_rin"] = DataTypeNames[self.rin.element]
        extended_name = SubstituteTemplate(extended_name, values)
        return extended_name
    #
    def layout_name(self):
        if self.src.layout == self.dst.layout:
            layout_name = "${src_layout}_${flt_layout}"
        else:
            layout_name = "${src_layout}_${flt_layout}_${dst_layout}"
        layout_name = SubstituteTemplate(
            layout_name,
            {
                "src_layout": ShortLayoutTypeNames[self.src.layout],
                "flt_layout": ShortLayoutTypeNames[self.flt.layout],
                "dst_layout": ShortLayoutTypeNames[self.dst.layout],
            },
        )
        return layout_name

    #
    def configuration_name(self):
        """The full procedural name indicates architecture, extended name, tile size, and layout."""
        opcode_class_name = OpcodeClassNames[
            self.tile_description.math_instruction.opcode_class
        ]
        warp_shape = [
            int(
                self.tile_description.threadblock_shape[idx]
                / self.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
        threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
            self.tile_description.threadblock_shape[0],
            self.tile_description.threadblock_shape[1],
            self.tile_description.threadblock_shape[2],
            warp_shape[0],
            warp_shape[1],
            warp_shape[2],
            self.tile_description.stages,
        )
        alignment = "align%dx%d" % (self.src.alignment, self.flt.alignment)
        configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${alignment}"
        return SubstituteTemplate(
            configuration_name,
            {
                "opcode_class": opcode_class_name,
                "extended_name": self.extended_name(),
                "threadblock": threadblock,
                "layout": self.layout_name(),
                "alignment": alignment,
            },
        )

    #
    def procedural_name(self):
        """The full procedural name indicates architecture, extended name, tile size, and layout."""
        return self.configuration_name()
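
# A hedged usage sketch (values are hypothetical; TileDescription,
# TensorDescription, and the enums come from library.py, and the exact tokens
# in the emitted name depend on the name tables defined there):
#
#   op = Conv2dOperation(
#       ConvKind.Fprop,
#       ConvType.Convolution,
#       75,                    # arch (SM75)
#       tile,                  # a TileDescription built elsewhere
#       src, flt, bias, dst,   # TensorDescription instances
#       DataType.f32,          # element_epilogue
#   )
#   op.procedural_name()
#   # -> "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${alignment}"
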
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################


class EmitConv2dInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
    typename cutlass::conv::device::Convolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode},
        ${without_shared_load}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )
        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_flt": DataTypeTag[operation.flt.element],
            "layout_flt": LayoutTag[operation.flt.layout],
            "element_dst": DataTypeTag[operation.dst.element],
            "layout_dst": LayoutTag[operation.dst.layout],
            "element_bias": DataTypeTag[operation.bias.element],
            "layout_bias": LayoutTag[operation.bias.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_filter": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
            "without_shared_load": str(operation.without_shared_load).lower(),
        }
        return SubstituteTemplate(self.template, values)
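
# A hedged sketch of driving the emitter (`op` stands for a Conv2dOperation
# assembled as in the example above):
#
#   emitter = EmitConv2dInstance()
#   cpp_source = emitter.emit(op)  # the C++ "using Convolution_... = ..." alias
#
# Note on epilogue_vector_length: min(alignment * sizeof(element), 128) /
# sizeof(element) caps the epilogue's per-access vector width at 128 bits and
# converts it back into an element count.
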

class EmitDeconvInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
    typename cutlass::conv::device::Deconvolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )
        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_flt": DataTypeTag[operation.flt.element],
            "layout_flt": LayoutTag[operation.flt.layout],
            "element_dst": DataTypeTag[operation.dst.element],
            "layout_dst": LayoutTag[operation.dst.layout],
            "element_bias": DataTypeTag[operation.bias.element],
            "layout_bias": LayoutTag[operation.bias.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_filter": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
        }
        return SubstituteTemplate(self.template, values)

class EmitConvolutionBackwardFilterInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
    typename cutlass::conv::device::ConvolutionBackwardFilter<
        ${element_src},
        ${layout_src},
        ${element_diff},
        ${layout_diff},
        ${element_grad},
        ${layout_grad},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_grad},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_diff},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )
        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_diff": DataTypeTag[operation.flt.element],
            "layout_diff": LayoutTag[operation.flt.layout],
            "element_grad": DataTypeTag[operation.dst.element],
            "layout_grad": LayoutTag[operation.dst.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_diff": str(operation.flt.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
        }
        return SubstituteTemplate(self.template, values)

class EmitRegionRestrictedConvolutionBackwardFilterInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution_${operation_name} =
    typename cutlass::conv::device::RegionRestrictedConvolutionBackwardFilter<
        ${element_src},
        ${layout_src},
        ${element_diff},
        ${layout_diff},
        ${element_src_mask},
        ${layout_src_mask},
        ${element_output_mask},
        ${layout_output_mask},
        ${element_grad},
        ${layout_grad},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_grad},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_diff},
        ${alignment_src_mask},
        ${alignment_output_mask},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [
            int(
                operation.tile_description.threadblock_shape[idx]
                / operation.tile_description.warp_count[idx]
            )
            for idx in range(3)
        ]
        epilogue_vector_length = int(
            min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
            / DataTypeSize[operation.dst.element]
        )
        values = {
            "operation_name": operation.procedural_name(),
            "conv_type": ConvTypeTag[operation.conv_type],
            "element_src": DataTypeTag[operation.src.element],
            "layout_src": LayoutTag[operation.src.layout],
            "element_diff": DataTypeTag[operation.flt.element],
            "layout_diff": LayoutTag[operation.flt.layout],
            "element_src_mask": DataTypeTag[operation.rin.element],
            "layout_src_mask": LayoutTag[operation.rin.layout],
            "element_output_mask": DataTypeTag[operation.rout.element],
            "layout_output_mask": LayoutTag[operation.rout.layout],
            "element_grad": DataTypeTag[operation.dst.element],
            "layout_grad": LayoutTag[operation.dst.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[
                operation.tile_description.math_instruction.opcode_class
            ],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
            "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
            "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(
                operation.tile_description.math_instruction.instruction_shape[0]
            ),
            "instruction_shape_n": str(
                operation.tile_description.math_instruction.instruction_shape[1]
            ),
            "instruction_shape_k": str(
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "epilogue_vector_length": str(epilogue_vector_length),
            "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
            "element_epilogue": str(DataTypeTag[operation.element_epilogue]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(operation.tile_description.stages),
            "alignment_src": str(operation.src.alignment),
            "alignment_diff": str(operation.flt.alignment),
            "alignment_src_mask": str(operation.rin.alignment),
            "alignment_output_mask": str(operation.rout.alignment),
            "special_optimization": SpecialOptimizeDescTag[
                operation.special_optimization
            ],
            "math_operator": MathOperationTag[
                operation.tile_description.math_instruction.math_operation
            ],
            "implicit_gemm_mode": ImplicitGemmModeTag[operation.implicit_gemm_mode],
        }
        return SubstituteTemplate(self.template, values)

###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(
    conv_type,
    conv_kind,
    tile_descriptions,
    src_layout,
    flt_layout,
    dst_layout,
    dst_type,
    min_cc,
    src_align=32,
    flt_align=32,
    dst_align=32,
    use_special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
    implicit_gemm_mode=ImplicitGemmMode.GemmNT,
    without_shared_load=False,
    required_cuda_ver_major=9,
    required_cuda_ver_minor=2,
):
    operations = []

    element_epilogue = DataType.f32
    if (
        conv_type == ConvType.DepthwiseConvolution
        or conv_type == ConvType.RegionRestrictedConvolution
    ):
        if conv_kind == ConvKind.Fprop:
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionFprop
        elif conv_kind == ConvKind.Dgrad:
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionDgrad
        else:
            assert conv_kind == ConvKind.Wgrad
            swizzling_functor = SwizzlingFunctor.DepthwiseConvolutionWgrad
    elif conv_type == ConvType.Convolution:
        if conv_kind == ConvKind.Fprop:
            if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
                swizzling_functor = SwizzlingFunctor.ConvFpropTrans
            else:
                swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
        else:
            if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
                swizzling_functor = SwizzlingFunctor.ConvDgradTrans
            else:
                swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

    # skip rule
    def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
        return (
            layout == LayoutType.TensorNC32HW32 and tile.threadblock_shape[0] % 32 != 0
        )

    # rule for bias_type and epilogues
    def get_bias_type_and_epilogues(
        tile: TileDescription, out_dtype: DataType
    ) -> Tuple[DataType, List[EpilogueFunctor]]:
        if (
            tile.math_instruction.element_accumulator == DataType.s32
            and out_dtype != DataType.f32
        ):
            bias_type = DataType.s32
            if tile.math_instruction.element_b == DataType.u4:
                epilogues = [
                    EpilogueFunctor.BiasAddLinearCombinationClamp,
                    EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                ]
            else:
                epilogues = [
                    EpilogueFunctor.BiasAddLinearCombinationClamp,
                    EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                    EpilogueFunctor.BiasAddLinearCombinationHSwishClamp,
                ]
        elif (
            tile.math_instruction.element_accumulator == DataType.f32
            or tile.math_instruction.element_accumulator == DataType.f16
        ) or (
            tile.math_instruction.element_accumulator == DataType.s32
            and out_dtype == DataType.f32
        ):
            bias_type = out_dtype
            epilogues = [
                EpilogueFunctor.BiasAddLinearCombination,
                EpilogueFunctor.BiasAddLinearCombinationRelu,
                EpilogueFunctor.LinearCombination,
            ]
            if conv_type == ConvType.Convolution:
                epilogues.append(EpilogueFunctor.BiasAddLinearCombinationHSwish)
        else:
            assert False, "invalid path"
        return bias_type, epilogues

    # rule for filter alignment
    def get_flt_align(tile: TileDescription) -> int:
        nonlocal flt_align
        if (
            tile.math_instruction.opcode_class == OpcodeClass.Simt
            and tile.math_instruction.element_accumulator == DataType.s32
        ):
            thread_num = (
                tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
            )
            flt_block = (
                tile.threadblock_shape[0]
                * tile.threadblock_shape[2]
                * DataTypeSize[tile.math_instruction.element_a]
            )
            load_per_thread = flt_block // thread_num
            if load_per_thread >= 128:
                flt_align = 128
            elif load_per_thread >= 64:
                flt_align = 64
            else:
                assert load_per_thread >= 32
                flt_align = 32
        return flt_align

    def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
        nonlocal dst_align
        if (
            tile.math_instruction.opcode_class == OpcodeClass.TensorOp
            and dst_layout == LayoutType.TensorNC4HW4
        ):
            dst_align = 32
        return dst_align

    def filter_epilogue_with_conv_kind(
        epilogue: EpilogueFunctor, conv_kind: ConvKind
    ) -> bool:
        if conv_kind == ConvKind.Fprop:
            return epilogue == EpilogueFunctor.LinearCombination
        elif conv_kind == ConvKind.Dgrad:
            return (
                epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp
                and epilogue != EpilogueFunctor.BiasAddLinearCombination
            )
        elif conv_kind == ConvKind.Wgrad:
            return epilogue != EpilogueFunctor.LinearCombination

    # loop over all tile descriptions
    for tile in tile_descriptions:
        if filter_tile_with_layout(tile, dst_layout):
            continue

        bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
        flt_align = flt_align if conv_kind == ConvKind.Wgrad else get_flt_align(tile)
        dst_align = get_dst_align(tile, dst_layout)

        for epilogue in epilogues:
            if filter_epilogue_with_conv_kind(epilogue, conv_kind):
                continue

            if dst_type == DataType.f32:
                bias_type = DataType.f32
            #
            src = TensorDescription(
                tile.math_instruction.element_b,
                src_layout,
                int(src_align / DataTypeSize[tile.math_instruction.element_b]),
            )
            flt = TensorDescription(
                tile.math_instruction.element_a,
                flt_layout,
                int(flt_align / DataTypeSize[tile.math_instruction.element_a]),
            )
            rin = TensorDescription(
                tile.math_instruction.element_rin,
                src_layout,
                int(src_align / DataTypeSize[tile.math_instruction.element_rin]),
            )
            rout = TensorDescription(
                tile.math_instruction.element_rout,
                dst_layout,
                int(dst_align / DataTypeSize[tile.math_instruction.element_rout]),
            )
            bias = TensorDescription(
                bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))
            )
            dst = TensorDescription(
                dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])
            )

            new_operation = Conv2dOperation(
                conv_kind,
                conv_type,
                min_cc,
                tile,
                src,
                flt,
                bias,
                dst,
                element_epilogue,
                epilogue,
                swizzling_functor,
                SpecialOptimizeDesc.NoneSpecialOpt,
                implicit_gemm_mode,
                without_shared_load,
                required_cuda_ver_major,
                required_cuda_ver_minor,
                rin,
                rout,
            )
            operations.append(new_operation)
            if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt:
                new_operation = Conv2dOperation(
                    conv_kind,
                    conv_type,
                    min_cc,
                    tile,
                    src,
                    flt,
                    bias,
                    dst,
                    element_epilogue,
                    epilogue,
                    swizzling_functor,
                    use_special_optimization,
                    implicit_gemm_mode,
                    without_shared_load,
                    required_cuda_ver_major,
                    required_cuda_ver_minor,
                    rin,
                    rout,
                )
                operations.append(new_operation)
    return operations
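
# A hedged example of generating a kernel family (the layout/dtype tokens are
# illustrative placeholders; the real values come from library.py):
#
#   operations = GenerateConv2d(
#       ConvType.Convolution,
#       ConvKind.Fprop,
#       tile_descriptions,           # list of TileDescription
#       LayoutType.TensorNC4HW4,     # src_layout
#       LayoutType.TensorC4RSK4,     # flt_layout (hypothetical token)
#       LayoutType.TensorNC4HW4,     # dst_layout
#       DataType.s8,                 # dst_type
#       61,                          # min_cc
#   )
#
# Each surviving tile yields one operation per epilogue that passes
# filter_epilogue_with_conv_kind, plus a second variant when
# use_special_optimization is not NoneSpecialOpt.
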
###################################################################################################
#
# Emitter functions for all targets
#
###################################################################################################


class EmitConv2dConfigurationLibrary:
    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(
            operation_path, "%s.cu" % configuration_name
        )
        self.instance_emitter = EmitConv2dInstance()
        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
  public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""
        self.header_template = """
/*
  Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""
        self.configuration_header = """

namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""
        self.configuration_instance = """
  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
    ${operation_name}>;

  manifest.append(new cutlass::library::Conv2dOperation<
    Operation_${operation_name}>(
      "${operation_name}"));
"""
        self.configuration_epilogue = """
}
"""
        self.epilogue_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    #
    def __enter__(self):
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(
            SubstituteTemplate(
                self.header_template, {"configuration_name": self.configuration_name}
            )
        )
        self.operations = []
        return self

    #
    def emit(self, operation):
        self.operations.append(operation)
        self.configuration_file.write(
            SubstituteTemplate(
                self.instance_template,
                {
                    "configuration_name": self.configuration_name,
                    "operation_name": operation.procedural_name(),
                    "operation_instance": self.instance_emitter.emit(operation),
                },
            )
        )

    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.configuration_file.write(
            SubstituteTemplate(
                self.configuration_header,
                {"configuration_name": self.configuration_name},
            )
        )
        for operation in self.operations:
            self.configuration_file.write(
                SubstituteTemplate(
                    self.configuration_instance,
                    {
                        "configuration_name": self.configuration_name,
                        "operation_name": operation.procedural_name(),
                    },
                )
            )
        self.configuration_file.write(self.configuration_epilogue)
        self.configuration_file.write(self.epilogue_template)
        self.configuration_file.close()
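
# The class above is a context manager: __enter__ writes the file header,
# emit() appends one kernel instance per operation, and __exit__ writes the
# initialize_${configuration_name} registration function. A hedged sketch:
#
#   with EmitConv2dConfigurationLibrary(out_dir, "cutlass_simt_sfprop_example") as cfg:
#       for op in operations:
#           cfg.emit(op)
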
###################################################################################################

###################################################################################################
# Emitters for Conv Kernel Wrapper
#
###################################################################################################


class EmitConvSingleKernelWrapper:
    def __init__(self, kernel_path, operation, short_path=False):
        self.kernel_path = kernel_path
        self.operation = operation
        self.short_path = short_path

        if self.operation.conv_kind == ConvKind.Fprop:
            self.instance_emitter = EmitConv2dInstance()
            self.convolution_name = "ConvolutionOperation"
        elif self.operation.conv_kind == ConvKind.Dgrad:
            self.instance_emitter = EmitDeconvInstance()
            self.convolution_name = "ConvolutionOperation"
        else:
            assert self.operation.conv_kind == ConvKind.Wgrad
            self.instance_emitter = EmitConvolutionBackwardFilterInstance()
            self.convolution_name = "ConvolutionBackwardFilterOperation"

        self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

#include "cutlass/convolution/device/convolution.h"

#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""
        self.instance_template = """
${operation_instance}
"""
        self.manifest_template = """
namespace cutlass {
namespace library {

void initialize_${operation_name}(Manifest &manifest) {
  manifest.append(new ${convolution_name}<Convolution_${operation_name}>(
    "${operation_name}"
  ));
}

}  // namespace library
}  // namespace cutlass
"""
        self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

    #
    def __enter__(self):
        if self.short_path:
            self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
            GlobalCnt.cnt += 1
        else:
            self.kernel_path = os.path.join(
                self.kernel_path, "%s.cu" % self.operation.procedural_name()
            )
        self.kernel_file = open(self.kernel_path, "w")
        # write the CUDA-version guard and includes; without this the
        # generated .cu would lack headers and the matching #endif
        self.kernel_file.write(
            SubstituteTemplate(
                self.header_template,
                {
                    "required_cuda_ver_major": str(
                        self.operation.required_cuda_ver_major
                    ),
                    "required_cuda_ver_minor": str(
                        self.operation.required_cuda_ver_minor
                    ),
                },
            )
        )
        return self

    #
    def emit(self):
        self.kernel_file.write(
            SubstituteTemplate(
                self.instance_template,
                {"operation_instance": self.instance_emitter.emit(self.operation)},
            )
        )

        # emit manifest helper
        manifest = SubstituteTemplate(
            self.manifest_template,
            {
                "operation_name": self.operation.procedural_name(),
                "convolution_name": self.convolution_name,
            },
        )
        self.kernel_file.write(manifest)

    #
    def __exit__(self, exception_type, exception_value, traceback):
        # close the version guard opened in __enter__ before closing the file
        self.kernel_file.write(self.epilogue_template)
        self.kernel_file.close()
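
# Hedged usage, mirroring the configuration emitter above:
#
#   with EmitConvSingleKernelWrapper(out_dir, op) as wrapper:
#       wrapper.emit()
#
# With short_path=True the .cu file is named from a global counter (GlobalCnt,
# assumed to be provided by library.py) rather than the long procedural name.
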

class EmitRegionRestrictedConvSingleKernelWrapper:
    def __init__(self, kernel_path, operation, short_path=False):
        self.kernel_path = kernel_path
        self.operation = operation
        self.short_path = short_path

        # Now only support wgrad
        assert self.operation.conv_kind == ConvKind.Wgrad
        self.instance_emitter = EmitRegionRestrictedConvolutionBackwardFilterInstance()
        self.convolution_name = "RegionRestrictedConvolutionBackwardFilterOperation"

        self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

#include "cutlass/convolution/device/convolution.h"

#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""
        self.instance_template = """
${operation_instance}
"""
        self.manifest_template = """
namespace cutlass {
namespace library {

void initialize_${operation_name}(Manifest &manifest) {
  manifest.append(new ${convolution_name}<Convolution_${operation_name}>(
    "${operation_name}"
  ));
}

}  // namespace library
}  // namespace cutlass
"""
        self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

    #
    def __enter__(self):
        if self.short_path:
            self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
            GlobalCnt.cnt += 1
        else:
            self.kernel_path = os.path.join(
                self.kernel_path, "%s.cu" % self.operation.procedural_name()
            )
        self.kernel_file = open(self.kernel_path, "w")
        # write the CUDA-version guard and includes, as in
        # EmitConvSingleKernelWrapper above
        self.kernel_file.write(
            SubstituteTemplate(
                self.header_template,
                {
                    "required_cuda_ver_major": str(
                        self.operation.required_cuda_ver_major
                    ),
                    "required_cuda_ver_minor": str(
                        self.operation.required_cuda_ver_minor
                    ),
                },
            )
        )
        return self

    #
    def emit(self):
        self.kernel_file.write(
            SubstituteTemplate(
                self.instance_template,
                {"operation_instance": self.instance_emitter.emit(self.operation)},
            )
        )

        # emit manifest helper
        manifest = SubstituteTemplate(
            self.manifest_template,
            {
                "operation_name": self.operation.procedural_name(),
                "convolution_name": self.convolution_name,
            },
        )
        self.kernel_file.write(manifest)

    #
    def __exit__(self, exception_type, exception_value, traceback):
        # close the version guard opened in __enter__ before closing the file
        self.kernel_file.write(self.epilogue_template)
        self.kernel_file.close()

###################################################################################################
###################################################################################################