You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

library.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. #
  2. # \file library.py
  3. #
  4. # \brief Enums, lookup tables and descriptor classes used to generate the CUTLASS Library's instances
  5. #
  6. import re
  7. ###################################################################################################
  8. import enum
  9. # The following block implements enum.auto() for Python 3.5 variants that don't include it such
  10. # as the default 3.5.2 on Ubuntu 16.04.
  11. #
  12. # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
  13. try:
  14. from enum import auto as enum_auto
  15. except ImportError:
  16. __cutlass_library_auto_enum = 0
  17. def enum_auto() -> int:
  18. global __cutlass_library_auto_enum
  19. i = __cutlass_library_auto_enum
  20. __cutlass_library_auto_enum += 1
  21. return i
  22. ###################################################################################################
  23. #
  24. class GeneratorTarget(enum.Enum):
  25. Library = enum_auto()
  26. #
  27. GeneratorTargetNames = {GeneratorTarget.Library: "library"}
  28. #
  29. ###################################################################################################
  30. #
  31. class DataType(enum.Enum):
  32. b1 = enum_auto()
  33. u4 = enum_auto()
  34. u8 = enum_auto()
  35. u16 = enum_auto()
  36. u32 = enum_auto()
  37. u64 = enum_auto()
  38. s4 = enum_auto()
  39. s8 = enum_auto()
  40. s16 = enum_auto()
  41. s32 = enum_auto()
  42. s64 = enum_auto()
  43. f16 = enum_auto()
  44. bf16 = enum_auto()
  45. f32 = enum_auto()
  46. tf32 = enum_auto()
  47. f64 = enum_auto()
  48. cf16 = enum_auto()
  49. cbf16 = enum_auto()
  50. cf32 = enum_auto()
  51. ctf32 = enum_auto()
  52. cf64 = enum_auto()
  53. cs4 = enum_auto()
  54. cs8 = enum_auto()
  55. cs16 = enum_auto()
  56. cs32 = enum_auto()
  57. cs64 = enum_auto()
  58. cu4 = enum_auto()
  59. cu8 = enum_auto()
  60. cu16 = enum_auto()
  61. cu32 = enum_auto()
  62. cu64 = enum_auto()
  63. invalid = enum_auto()
  64. #
  65. ShortDataTypeNames = {
  66. DataType.s32: "i",
  67. DataType.f16: "h",
  68. DataType.f32: "s",
  69. DataType.f64: "d",
  70. DataType.cf32: "c",
  71. DataType.cf64: "z",
  72. }
  73. #
  74. DataTypeNames = {
  75. DataType.b1: "b1",
  76. DataType.u4: "u4",
  77. DataType.u8: "u8",
  78. DataType.u16: "u16",
  79. DataType.u32: "u32",
  80. DataType.u64: "u64",
  81. DataType.s4: "s4",
  82. DataType.s8: "s8",
  83. DataType.s16: "s16",
  84. DataType.s32: "s32",
  85. DataType.s64: "s64",
  86. DataType.f16: "f16",
  87. DataType.bf16: "bf16",
  88. DataType.f32: "f32",
  89. DataType.tf32: "tf32",
  90. DataType.f64: "f64",
  91. DataType.cf16: "cf16",
  92. DataType.cbf16: "cbf16",
  93. DataType.cf32: "cf32",
  94. DataType.ctf32: "ctf32",
  95. DataType.cf64: "cf64",
  96. DataType.cu4: "cu4",
  97. DataType.cu8: "cu8",
  98. DataType.cu16: "cu16",
  99. DataType.cu32: "cu32",
  100. DataType.cu64: "cu64",
  101. DataType.cs4: "cs4",
  102. DataType.cs8: "cs8",
  103. DataType.cs16: "cs16",
  104. DataType.cs32: "cs32",
  105. DataType.cs64: "cs64",
  106. }
  107. DataTypeTag = {
  108. DataType.b1: "cutlass::uint1b_t",
  109. DataType.u4: "cutlass::uint4b_t",
  110. DataType.u8: "uint8_t",
  111. DataType.u16: "uint16_t",
  112. DataType.u32: "uint32_t",
  113. DataType.u64: "uint64_t",
  114. DataType.s4: "cutlass::int4b_t",
  115. DataType.s8: "int8_t",
  116. DataType.s16: "int16_t",
  117. DataType.s32: "int32_t",
  118. DataType.s64: "int64_t",
  119. DataType.f16: "cutlass::half_t",
  120. DataType.bf16: "cutlass::bfloat16_t",
  121. DataType.f32: "float",
  122. DataType.tf32: "cutlass::tfloat32_t",
  123. DataType.f64: "double",
  124. DataType.cf16: "cutlass::complex<cutlass::half_t>",
  125. DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
  126. DataType.cf32: "cutlass::complex<float>",
  127. DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
  128. DataType.cf64: "cutlass::complex<double>",
  129. DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
  130. DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
  131. DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
  132. DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
  133. DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
  134. DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
  135. DataType.cs8: "cutlass::complex<cutlass::int8_t>",
  136. DataType.cs16: "cutlass::complex<cutlass::int16_t>",
  137. DataType.cs32: "cutlass::complex<cutlass::int32_t>",
  138. DataType.cs64: "cutlass::complex<cutlass::int64_t>",
  139. }
  140. DataTypeSize = {
  141. DataType.b1: 1,
  142. DataType.u4: 4,
  143. DataType.u8: 4,
  144. DataType.u16: 16,
  145. DataType.u32: 32,
  146. DataType.u64: 64,
  147. DataType.s4: 4,
  148. DataType.s8: 8,
  149. DataType.s16: 16,
  150. DataType.s32: 32,
  151. DataType.s64: 64,
  152. DataType.f16: 16,
  153. DataType.bf16: 16,
  154. DataType.f32: 32,
  155. DataType.tf32: 32,
  156. DataType.f64: 64,
  157. DataType.cf16: 32,
  158. DataType.cbf16: 32,
  159. DataType.cf32: 64,
  160. DataType.ctf32: 32,
  161. DataType.cf64: 128,
  162. DataType.cu4: 8,
  163. DataType.cu8: 16,
  164. DataType.cu16: 32,
  165. DataType.cu32: 64,
  166. DataType.cu64: 128,
  167. DataType.cs4: 8,
  168. DataType.cs8: 16,
  169. DataType.cs16: 32,
  170. DataType.cs32: 64,
  171. DataType.cs64: 128,
  172. }
  173. ###################################################################################################
  174. #
  175. class ComplexTransform(enum.Enum):
  176. none = enum_auto()
  177. conj = enum_auto()
  178. #
  179. ComplexTransformTag = {
  180. ComplexTransform.none: "cutlass::ComplexTransform::kNone",
  181. ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate",
  182. }
  183. #
  184. RealComplexBijection = [
  185. (DataType.f16, DataType.cf16),
  186. (DataType.f32, DataType.cf32),
  187. (DataType.f64, DataType.cf64),
  188. ]
  189. #
  190. def is_complex(data_type):
  191. for r, c in RealComplexBijection:
  192. if data_type == c:
  193. return True
  194. return False
  195. #
  196. def get_complex_from_real(real_type):
  197. for r, c in RealComplexBijection:
  198. if real_type == r:
  199. return c
  200. return DataType.invalid
  201. #
  202. def get_real_from_complex(complex_type):
  203. for r, c in RealComplexBijection:
  204. if complex_type == c:
  205. return r
  206. return DataType.invalid
  207. #
  208. class ComplexMultiplyOp(enum.Enum):
  209. multiply_add = enum_auto()
  210. gaussian = enum_auto()
  211. ###################################################################################################
  212. #
  213. class MathOperation(enum.Enum):
  214. multiply_add = enum_auto()
  215. multiply_add_saturate = enum_auto()
  216. xor_popc = enum_auto()
  217. multiply_add_fast_bf16 = enum_auto()
  218. multiply_add_fast_f16 = enum_auto()
  219. multiply_add_complex = enum_auto()
  220. multiply_add_complex_gaussian = enum_auto()
  221. #
  222. MathOperationTag = {
  223. MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
  224. MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
  225. MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
  226. MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
  227. MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
  228. MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
  229. MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
  230. }
  231. ###################################################################################################
  232. #
  233. class LayoutType(enum.Enum):
  234. ColumnMajor = enum_auto()
  235. RowMajor = enum_auto()
  236. ColumnMajorInterleaved2 = enum_auto()
  237. RowMajorInterleaved2 = enum_auto()
  238. ColumnMajorInterleaved32 = enum_auto()
  239. RowMajorInterleaved32 = enum_auto()
  240. ColumnMajorInterleaved64 = enum_auto()
  241. RowMajorInterleaved64 = enum_auto()
  242. TensorNHWC = enum_auto()
  243. TensorNDHWC = enum_auto()
  244. TensorNCHW = enum_auto()
  245. TensorNGHWC = enum_auto()
  246. TensorNC4HW4 = enum_auto()
  247. TensorC4RSK4 = enum_auto()
  248. TensorNC8HW8 = enum_auto()
  249. TensorNC16HW16 = enum_auto()
  250. TensorNC32HW32 = enum_auto()
  251. TensorNC64HW64 = enum_auto()
  252. TensorC32RSK32 = enum_auto()
  253. TensorC64RSK64 = enum_auto()
  254. TensorK4RSC4 = enum_auto()
  255. TensorCK4RS4 = enum_auto()
  256. TensorCK8RS8 = enum_auto()
  257. TensorCK16RS16 = enum_auto()
  258. #
  259. LayoutTag = {
  260. LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor",
  261. LayoutType.RowMajor: "cutlass::layout::RowMajor",
  262. LayoutType.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>",
  263. LayoutType.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>",
  264. LayoutType.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>",
  265. LayoutType.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>",
  266. LayoutType.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>",
  267. LayoutType.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>",
  268. LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC",
  269. LayoutType.TensorNDHWC: "cutlass::layout::TensorNDHWC",
  270. LayoutType.TensorNCHW: "cutlass::layout::TensorNCHW",
  271. LayoutType.TensorNGHWC: "cutlass::layout::TensorNGHWC",
  272. LayoutType.TensorNC4HW4: "cutlass::layout::TensorNCxHWx<4>",
  273. LayoutType.TensorC4RSK4: "cutlass::layout::TensorCxRSKx<4>",
  274. LayoutType.TensorNC8HW8: "cutlass::layout::TensorNCxHWx<8>",
  275. LayoutType.TensorNC16HW16: "cutlass::layout::TensorNCxHWx<16>",
  276. LayoutType.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>",
  277. LayoutType.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>",
  278. LayoutType.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>",
  279. LayoutType.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>",
  280. LayoutType.TensorK4RSC4: "cutlass::layout::TensorKxRSCx<4>",
  281. LayoutType.TensorCK4RS4: "cutlass::layout::TensorCKxRSx<4>",
  282. LayoutType.TensorCK8RS8: "cutlass::layout::TensorCKxRSx<8>",
  283. LayoutType.TensorCK16RS16: "cutlass::layout::TensorCKxRSx<16>",
  284. }
  285. #
  286. TransposedLayout = {
  287. LayoutType.ColumnMajor: LayoutType.RowMajor,
  288. LayoutType.RowMajor: LayoutType.ColumnMajor,
  289. LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
  290. LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
  291. LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
  292. LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
  293. LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
  294. LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
  295. LayoutType.TensorNHWC: LayoutType.TensorNHWC,
  296. }
  297. #
  298. ShortLayoutTypeNames = {
  299. LayoutType.ColumnMajor: "n",
  300. LayoutType.ColumnMajorInterleaved32: "n2",
  301. LayoutType.ColumnMajorInterleaved32: "n32",
  302. LayoutType.ColumnMajorInterleaved64: "n64",
  303. LayoutType.RowMajor: "t",
  304. LayoutType.RowMajorInterleaved2: "t2",
  305. LayoutType.RowMajorInterleaved32: "t32",
  306. LayoutType.RowMajorInterleaved64: "t64",
  307. LayoutType.TensorNHWC: "nhwc",
  308. LayoutType.TensorNDHWC: "ndhwc",
  309. LayoutType.TensorNCHW: "nchw",
  310. LayoutType.TensorNGHWC: "nghwc",
  311. LayoutType.TensorNC4HW4: "nc4hw4",
  312. LayoutType.TensorC4RSK4: "c4rsk4",
  313. LayoutType.TensorNC8HW8: "nc8hw8",
  314. LayoutType.TensorNC16HW16: "nc16hw16",
  315. LayoutType.TensorNC32HW32: "nc32hw32",
  316. LayoutType.TensorNC64HW64: "nc64hw64",
  317. LayoutType.TensorC32RSK32: "c32rsk32",
  318. LayoutType.TensorC64RSK64: "c64rsk64",
  319. LayoutType.TensorK4RSC4: "k4rsc4",
  320. LayoutType.TensorCK4RS4: "ck4rs4",
  321. LayoutType.TensorCK8RS8: "ck8rs8",
  322. LayoutType.TensorCK16RS16: "ck16rs16",
  323. }
  324. #
  325. ShortComplexLayoutNames = {
  326. (LayoutType.ColumnMajor, ComplexTransform.none): "n",
  327. (LayoutType.ColumnMajor, ComplexTransform.conj): "c",
  328. (LayoutType.RowMajor, ComplexTransform.none): "t",
  329. (LayoutType.RowMajor, ComplexTransform.conj): "h",
  330. }
  331. ###################################################################################################
  332. #
  333. class OpcodeClass(enum.Enum):
  334. Simt = enum_auto()
  335. TensorOp = enum_auto()
  336. WmmaTensorOp = enum_auto()
  337. OpcodeClassNames = {
  338. OpcodeClass.Simt: "simt",
  339. OpcodeClass.TensorOp: "tensorop",
  340. OpcodeClass.WmmaTensorOp: "wmma_tensorop",
  341. }
  342. OpcodeClassTag = {
  343. OpcodeClass.Simt: "cutlass::arch::OpClassSimt",
  344. OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp",
  345. OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
  346. }
  347. ###################################################################################################
  348. #
  349. class OperationKind(enum.Enum):
  350. Gemm = enum_auto()
  351. Conv2d = enum_auto()
  352. #
  353. OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"}
  354. #
  355. class Target(enum.Enum):
  356. library = enum_auto()
  357. ArchitectureNames = {
  358. 50: "maxwell",
  359. 60: "pascal",
  360. 61: "pascal",
  361. 70: "volta",
  362. 75: "turing",
  363. 80: "ampere",
  364. }
  365. ###################################################################################################
  366. #
  367. def SubstituteTemplate(template, values):
  368. text = template
  369. changed = True
  370. while changed:
  371. changed = False
  372. for key, value in values.items():
  373. regex = "\\$\\{%s\\}" % key
  374. newtext = re.sub(regex, value, text)
  375. if newtext != text:
  376. changed = True
  377. text = newtext
  378. return text
  379. ###################################################################################################
  380. #
  381. class GemmKind(enum.Enum):
  382. Gemm = enum_auto()
  383. Sparse = enum_auto()
  384. Universal = enum_auto()
  385. PlanarComplex = enum_auto()
  386. PlanarComplexArray = enum_auto()
  387. SplitKParallel = enum_auto()
  388. GemvBatchedStrided = enum_auto()
  389. #
  390. GemmKindNames = {
  391. GemmKind.Gemm: "gemm",
  392. GemmKind.Sparse: "spgemm",
  393. GemmKind.Universal: "gemm",
  394. GemmKind.PlanarComplex: "gemm_planar_complex",
  395. GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
  396. GemmKind.SplitKParallel: "gemm_split_k_parallel",
  397. GemmKind.GemvBatchedStrided: "gemv_batched_strided",
  398. }
  399. #
  400. class EpilogueFunctor(enum.Enum):
  401. LinearCombination = enum_auto()
  402. LinearCombinationClamp = enum_auto()
  403. BiasAddLinearCombination = enum_auto()
  404. BiasAddLinearCombinationRelu = enum_auto()
  405. BiasAddLinearCombinationHSwish = enum_auto()
  406. BiasAddLinearCombinationClamp = enum_auto()
  407. BiasAddLinearCombinationReluClamp = enum_auto()
  408. BiasAddLinearCombinationHSwishClamp = enum_auto()
  409. #
  410. EpilogueFunctorTag = {
  411. EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
  412. EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp",
  413. EpilogueFunctor.BiasAddLinearCombination: "cutlass::epilogue::thread::BiasAddLinearCombination",
  414. EpilogueFunctor.BiasAddLinearCombinationRelu: "cutlass::epilogue::thread::BiasAddLinearCombinationRelu",
  415. EpilogueFunctor.BiasAddLinearCombinationHSwish: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwish",
  416. EpilogueFunctor.BiasAddLinearCombinationClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationClamp",
  417. EpilogueFunctor.BiasAddLinearCombinationReluClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp",
  418. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp",
  419. }
  420. #
  421. ShortEpilogueNames = {
  422. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish",
  423. EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu",
  424. EpilogueFunctor.BiasAddLinearCombinationClamp: "id",
  425. EpilogueFunctor.BiasAddLinearCombinationHSwish: "hswish",
  426. EpilogueFunctor.BiasAddLinearCombinationRelu: "relu",
  427. EpilogueFunctor.BiasAddLinearCombination: "id",
  428. }
  429. #
  430. class SwizzlingFunctor(enum.Enum):
  431. Identity1 = enum_auto()
  432. Identity2 = enum_auto()
  433. Identity4 = enum_auto()
  434. Identity8 = enum_auto()
  435. ConvFpropNCxHWx = enum_auto()
  436. ConvFpropTrans = enum_auto()
  437. ConvDgradNCxHWx = enum_auto()
  438. ConvDgradTrans = enum_auto()
  439. DepthwiseConvolutionFprop = enum_auto()
  440. DepthwiseConvolutionDgrad = enum_auto()
  441. DepthwiseConvolutionWgrad = enum_auto()
  442. #
  443. SwizzlingFunctorTag = {
  444. SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
  445. SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
  446. SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
  447. SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
  448. SwizzlingFunctor.ConvFpropNCxHWx: "cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle",
  449. SwizzlingFunctor.ConvFpropTrans: "cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle",
  450. SwizzlingFunctor.ConvDgradNCxHWx: "cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle",
  451. SwizzlingFunctor.ConvDgradTrans: "cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle",
  452. SwizzlingFunctor.DepthwiseConvolutionFprop: "cutlass::conv::threadblock::DepthwiseConvolutionFpropThreadblockSwizzle",
  453. SwizzlingFunctor.DepthwiseConvolutionDgrad: "cutlass::conv::threadblock::DepthwiseConvolutionDgradThreadblockSwizzle",
  454. SwizzlingFunctor.DepthwiseConvolutionWgrad: "cutlass::conv::threadblock::DepthwiseConvolutionWgradThreadblockSwizzle",
  455. }
  456. ###################################################################################################
  457. class ConvType(enum.Enum):
  458. Convolution = enum_auto()
  459. BatchConvolution = enum_auto()
  460. Local = enum_auto()
  461. LocalShare = enum_auto()
  462. DepthwiseConvolution = enum_auto()
  463. ConvTypeTag = {
  464. ConvType.Convolution: "cutlass::conv::ConvType::kConvolution",
  465. ConvType.BatchConvolution: "cutlass::conv::ConvType::kBatchConvolution",
  466. ConvType.Local: "cutlass::conv::ConvType::kLocal",
  467. ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare",
  468. ConvType.DepthwiseConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution",
  469. }
  470. #
  471. class ConvKind(enum.Enum):
  472. Fprop = enum_auto()
  473. Dgrad = enum_auto()
  474. Wgrad = enum_auto()
  475. #
  476. ConvKindTag = {
  477. ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
  478. ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad",
  479. ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad",
  480. }
  481. ConvKindNames = {
  482. ConvKind.Fprop: "fprop",
  483. ConvKind.Dgrad: "dgrad",
  484. ConvKind.Wgrad: "wgrad",
  485. }
  486. #
  487. class IteratorAlgorithm(enum.Enum):
  488. Analytic = enum_auto()
  489. Optimized = enum_auto()
  490. #
  491. IteratorAlgorithmTag = {
  492. IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
  493. IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
  494. }
  495. IteratorAlgorithmNames = {
  496. IteratorAlgorithm.Analytic: "analytic",
  497. IteratorAlgorithm.Optimized: "optimized",
  498. }
  499. #
  500. class StrideSupport(enum.Enum):
  501. Strided = enum_auto()
  502. Unity = enum_auto()
  503. #
  504. StrideSupportTag = {
  505. StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
  506. StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
  507. }
  508. StrideSupportNames = {StrideSupport.Strided: "", StrideSupport.Unity: "unity_stride"}
  509. class SpecialOptimizeDesc(enum.Enum):
  510. NoneSpecialOpt = enum_auto()
  511. ConvFilterUnity = enum_auto()
  512. DeconvDoubleUpsampling = enum_auto()
  513. SpecialOptimizeDescNames = {
  514. SpecialOptimizeDesc.NoneSpecialOpt: "none",
  515. SpecialOptimizeDesc.ConvFilterUnity: "conv_filter_unity",
  516. SpecialOptimizeDesc.DeconvDoubleUpsampling: "deconv_double_upsampling",
  517. }
  518. SpecialOptimizeDescTag = {
  519. SpecialOptimizeDesc.NoneSpecialOpt: "cutlass::conv::SpecialOptimizeDesc::NONE",
  520. SpecialOptimizeDesc.ConvFilterUnity: "cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY",
  521. SpecialOptimizeDesc.DeconvDoubleUpsampling: "cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING",
  522. }
  523. class ImplicitGemmMode(enum.Enum):
  524. GemmNT = enum_auto()
  525. GemmTN = enum_auto()
  526. ImplicitGemmModeNames = {
  527. ImplicitGemmMode.GemmNT: "gemm_nt",
  528. ImplicitGemmMode.GemmTN: "gemm_tn",
  529. }
  530. ImplicitGemmModeTag = {
  531. ImplicitGemmMode.GemmNT: "cutlass::conv::ImplicitGemmMode::GEMM_NT",
  532. ImplicitGemmMode.GemmTN: "cutlass::conv::ImplicitGemmMode::GEMM_TN",
  533. }
  534. ###################################################################################################
  535. #
  536. class MathInstruction:
  537. def __init__(
  538. self,
  539. instruction_shape,
  540. element_a,
  541. element_b,
  542. element_accumulator,
  543. opcode_class,
  544. math_operation=MathOperation.multiply_add,
  545. ):
  546. self.instruction_shape = instruction_shape
  547. self.element_a = element_a
  548. self.element_b = element_b
  549. self.element_accumulator = element_accumulator
  550. self.opcode_class = opcode_class
  551. self.math_operation = math_operation
  552. #
  553. class TileDescription:
  554. def __init__(
  555. self,
  556. threadblock_shape,
  557. stages,
  558. warp_count,
  559. math_instruction,
  560. min_compute,
  561. max_compute,
  562. ):
  563. self.threadblock_shape = threadblock_shape
  564. self.stages = stages
  565. self.warp_count = warp_count
  566. self.math_instruction = math_instruction
  567. self.minimum_compute_capability = min_compute
  568. self.maximum_compute_capability = max_compute
  569. def procedural_name(self):
  570. return "%dx%d_%dx%d" % (
  571. self.threadblock_shape[0],
  572. self.threadblock_shape[1],
  573. self.threadblock_shape[2],
  574. self.stages,
  575. )
  576. #
  577. class TensorDescription:
  578. def __init__(
  579. self, element, layout, alignment=1, complex_transform=ComplexTransform.none
  580. ):
  581. self.element = element
  582. self.layout = layout
  583. self.alignment = alignment
  584. self.complex_transform = complex_transform
  585. ###################################################################################################
  586. class GlobalCnt:
  587. cnt = 0