You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

library.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694
  1. #
  2. # \file library.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import re
  8. ###################################################################################################
  9. # The following block implements enum.auto() for Python 3.5 variants that don't include it such
  10. # as the default 3.5.2 on Ubuntu 16.04.
  11. #
  12. # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
  13. try:
  14. from enum import auto as enum_auto
  15. except ImportError:
  16. __cutlass_library_auto_enum = 0
  17. def enum_auto() -> int:
  18. global __cutlass_library_auto_enum
  19. i = __cutlass_library_auto_enum
  20. __cutlass_library_auto_enum += 1
  21. return i
  22. ###################################################################################################
  23. #
  24. class GeneratorTarget(enum.Enum):
  25. Library = enum_auto()
  26. #
  27. GeneratorTargetNames = {GeneratorTarget.Library: "library"}
  28. #
  29. ###################################################################################################
  30. #
  31. class DataType(enum.Enum):
  32. b1 = enum_auto()
  33. u4 = enum_auto()
  34. u8 = enum_auto()
  35. u16 = enum_auto()
  36. u32 = enum_auto()
  37. u64 = enum_auto()
  38. s4 = enum_auto()
  39. s8 = enum_auto()
  40. s16 = enum_auto()
  41. s32 = enum_auto()
  42. s64 = enum_auto()
  43. f16 = enum_auto()
  44. bf16 = enum_auto()
  45. f32 = enum_auto()
  46. tf32 = enum_auto()
  47. f64 = enum_auto()
  48. cf16 = enum_auto()
  49. cbf16 = enum_auto()
  50. cf32 = enum_auto()
  51. ctf32 = enum_auto()
  52. cf64 = enum_auto()
  53. cs4 = enum_auto()
  54. cs8 = enum_auto()
  55. cs16 = enum_auto()
  56. cs32 = enum_auto()
  57. cs64 = enum_auto()
  58. cu4 = enum_auto()
  59. cu8 = enum_auto()
  60. cu16 = enum_auto()
  61. cu32 = enum_auto()
  62. cu64 = enum_auto()
  63. invalid = enum_auto()
  64. #
  65. ShortDataTypeNames = {
  66. DataType.s32: "i",
  67. DataType.f16: "h",
  68. DataType.f32: "s",
  69. DataType.f64: "d",
  70. DataType.cf32: "c",
  71. DataType.cf64: "z",
  72. }
  73. #
  74. DataTypeNames = {
  75. DataType.b1: "b1",
  76. DataType.u4: "u4",
  77. DataType.u8: "u8",
  78. DataType.u16: "u16",
  79. DataType.u32: "u32",
  80. DataType.u64: "u64",
  81. DataType.s4: "s4",
  82. DataType.s8: "s8",
  83. DataType.s16: "s16",
  84. DataType.s32: "s32",
  85. DataType.s64: "s64",
  86. DataType.f16: "f16",
  87. DataType.bf16: "bf16",
  88. DataType.f32: "f32",
  89. DataType.tf32: "tf32",
  90. DataType.f64: "f64",
  91. DataType.cf16: "cf16",
  92. DataType.cbf16: "cbf16",
  93. DataType.cf32: "cf32",
  94. DataType.ctf32: "ctf32",
  95. DataType.cf64: "cf64",
  96. DataType.cu4: "cu4",
  97. DataType.cu8: "cu8",
  98. DataType.cu16: "cu16",
  99. DataType.cu32: "cu32",
  100. DataType.cu64: "cu64",
  101. DataType.cs4: "cs4",
  102. DataType.cs8: "cs8",
  103. DataType.cs16: "cs16",
  104. DataType.cs32: "cs32",
  105. DataType.cs64: "cs64",
  106. }
  107. DataTypeTag = {
  108. DataType.b1: "cutlass::uint1b_t",
  109. DataType.u4: "cutlass::uint4b_t",
  110. DataType.u8: "uint8_t",
  111. DataType.u16: "uint16_t",
  112. DataType.u32: "uint32_t",
  113. DataType.u64: "uint64_t",
  114. DataType.s4: "cutlass::int4b_t",
  115. DataType.s8: "int8_t",
  116. DataType.s16: "int16_t",
  117. DataType.s32: "int32_t",
  118. DataType.s64: "int64_t",
  119. DataType.f16: "cutlass::half_t",
  120. DataType.bf16: "cutlass::bfloat16_t",
  121. DataType.f32: "float",
  122. DataType.tf32: "cutlass::tfloat32_t",
  123. DataType.f64: "double",
  124. DataType.cf16: "cutlass::complex<cutlass::half_t>",
  125. DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
  126. DataType.cf32: "cutlass::complex<float>",
  127. DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
  128. DataType.cf64: "cutlass::complex<double>",
  129. DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
  130. DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
  131. DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
  132. DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
  133. DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
  134. DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
  135. DataType.cs8: "cutlass::complex<cutlass::int8_t>",
  136. DataType.cs16: "cutlass::complex<cutlass::int16_t>",
  137. DataType.cs32: "cutlass::complex<cutlass::int32_t>",
  138. DataType.cs64: "cutlass::complex<cutlass::int64_t>",
  139. }
  140. DataTypeSize = {
  141. DataType.b1: 1,
  142. DataType.u4: 4,
  143. DataType.u8: 4,
  144. DataType.u16: 16,
  145. DataType.u32: 32,
  146. DataType.u64: 64,
  147. DataType.s4: 4,
  148. DataType.s8: 8,
  149. DataType.s16: 16,
  150. DataType.s32: 32,
  151. DataType.s64: 64,
  152. DataType.f16: 16,
  153. DataType.bf16: 16,
  154. DataType.f32: 32,
  155. DataType.tf32: 32,
  156. DataType.f64: 64,
  157. DataType.cf16: 32,
  158. DataType.cbf16: 32,
  159. DataType.cf32: 64,
  160. DataType.ctf32: 32,
  161. DataType.cf64: 128,
  162. DataType.cu4: 8,
  163. DataType.cu8: 16,
  164. DataType.cu16: 32,
  165. DataType.cu32: 64,
  166. DataType.cu64: 128,
  167. DataType.cs4: 8,
  168. DataType.cs8: 16,
  169. DataType.cs16: 32,
  170. DataType.cs32: 64,
  171. DataType.cs64: 128,
  172. }
  173. ###################################################################################################
  174. #
  175. class ComplexTransform(enum.Enum):
  176. none = enum_auto()
  177. conj = enum_auto()
  178. #
  179. ComplexTransformTag = {
  180. ComplexTransform.none: "cutlass::ComplexTransform::kNone",
  181. ComplexTransform.conj: "cutlass::ComplexTransform::kConjugate",
  182. }
  183. #
  184. RealComplexBijection = [
  185. (DataType.f16, DataType.cf16),
  186. (DataType.f32, DataType.cf32),
  187. (DataType.f64, DataType.cf64),
  188. ]
  189. #
  190. def is_complex(data_type):
  191. for r, c in RealComplexBijection:
  192. if data_type == c:
  193. return True
  194. return False
  195. #
  196. def get_complex_from_real(real_type):
  197. for r, c in RealComplexBijection:
  198. if real_type == r:
  199. return c
  200. return DataType.invalid
  201. #
  202. def get_real_from_complex(complex_type):
  203. for r, c in RealComplexBijection:
  204. if complex_type == c:
  205. return r
  206. return DataType.invalid
  207. #
  208. class ComplexMultiplyOp(enum.Enum):
  209. multiply_add = enum_auto()
  210. gaussian = enum_auto()
  211. ###################################################################################################
  212. #
  213. class MathOperation(enum.Enum):
  214. multiply_add = enum_auto()
  215. multiply_add_saturate = enum_auto()
  216. xor_popc = enum_auto()
  217. multiply_add_fast_bf16 = enum_auto()
  218. multiply_add_fast_f16 = enum_auto()
  219. multiply_add_complex = enum_auto()
  220. multiply_add_complex_gaussian = enum_auto()
  221. #
  222. MathOperationTag = {
  223. MathOperation.multiply_add: "cutlass::arch::OpMultiplyAdd",
  224. MathOperation.multiply_add_saturate: "cutlass::arch::OpMultiplyAddSaturate",
  225. MathOperation.xor_popc: "cutlass::arch::OpXorPopc",
  226. MathOperation.multiply_add_fast_bf16: "cutlass::arch::OpMultiplyAddFastBF16",
  227. MathOperation.multiply_add_fast_f16: "cutlass::arch::OpMultiplyAddFastF16",
  228. MathOperation.multiply_add_complex: "cutlass::arch::OpMultiplyAddComplex",
  229. MathOperation.multiply_add_complex_gaussian: "cutlass::arch::OpMultiplyAddGaussianComplex",
  230. }
  231. ###################################################################################################
  232. #
  233. class LayoutType(enum.Enum):
  234. ColumnMajor = enum_auto()
  235. RowMajor = enum_auto()
  236. ColumnMajorInterleaved2 = enum_auto()
  237. RowMajorInterleaved2 = enum_auto()
  238. ColumnMajorInterleaved32 = enum_auto()
  239. RowMajorInterleaved32 = enum_auto()
  240. ColumnMajorInterleaved64 = enum_auto()
  241. RowMajorInterleaved64 = enum_auto()
  242. TensorNHWC = enum_auto()
  243. TensorNDHWC = enum_auto()
  244. TensorNCHW = enum_auto()
  245. TensorNGHWC = enum_auto()
  246. TensorNC4HW4 = enum_auto()
  247. TensorC4RSK4 = enum_auto()
  248. TensorNC8HW8 = enum_auto()
  249. TensorNC16HW16 = enum_auto()
  250. TensorNC32HW32 = enum_auto()
  251. TensorNC64HW64 = enum_auto()
  252. TensorC32RSK32 = enum_auto()
  253. TensorC64RSK64 = enum_auto()
  254. TensorK4RSC4 = enum_auto()
  255. TensorCK4RS4 = enum_auto()
  256. TensorCK8RS8 = enum_auto()
  257. TensorCK16RS16 = enum_auto()
  258. #
  259. LayoutTag = {
  260. LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor",
  261. LayoutType.RowMajor: "cutlass::layout::RowMajor",
  262. LayoutType.ColumnMajorInterleaved2: "cutlass::layout::ColumnMajorInterleaved<2>",
  263. LayoutType.RowMajorInterleaved2: "cutlass::layout::RowMajorInterleaved<2>",
  264. LayoutType.ColumnMajorInterleaved32: "cutlass::layout::ColumnMajorInterleaved<32>",
  265. LayoutType.RowMajorInterleaved32: "cutlass::layout::RowMajorInterleaved<32>",
  266. LayoutType.ColumnMajorInterleaved64: "cutlass::layout::ColumnMajorInterleaved<64>",
  267. LayoutType.RowMajorInterleaved64: "cutlass::layout::RowMajorInterleaved<64>",
  268. LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC",
  269. LayoutType.TensorNDHWC: "cutlass::layout::TensorNDHWC",
  270. LayoutType.TensorNCHW: "cutlass::layout::TensorNCHW",
  271. LayoutType.TensorNGHWC: "cutlass::layout::TensorNGHWC",
  272. LayoutType.TensorNC4HW4: "cutlass::layout::TensorNCxHWx<4>",
  273. LayoutType.TensorC4RSK4: "cutlass::layout::TensorCxRSKx<4>",
  274. LayoutType.TensorNC8HW8: "cutlass::layout::TensorNCxHWx<8>",
  275. LayoutType.TensorNC16HW16: "cutlass::layout::TensorNCxHWx<16>",
  276. LayoutType.TensorNC32HW32: "cutlass::layout::TensorNCxHWx<32>",
  277. LayoutType.TensorC32RSK32: "cutlass::layout::TensorCxRSKx<32>",
  278. LayoutType.TensorNC64HW64: "cutlass::layout::TensorNCxHWx<64>",
  279. LayoutType.TensorC64RSK64: "cutlass::layout::TensorCxRSKx<64>",
  280. LayoutType.TensorK4RSC4: "cutlass::layout::TensorKxRSCx<4>",
  281. LayoutType.TensorCK4RS4: "cutlass::layout::TensorCKxRSx<4>",
  282. LayoutType.TensorCK8RS8: "cutlass::layout::TensorCKxRSx<8>",
  283. LayoutType.TensorCK16RS16: "cutlass::layout::TensorCKxRSx<16>",
  284. }
  285. #
  286. TransposedLayout = {
  287. LayoutType.ColumnMajor: LayoutType.RowMajor,
  288. LayoutType.RowMajor: LayoutType.ColumnMajor,
  289. LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
  290. LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
  291. LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
  292. LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
  293. LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
  294. LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
  295. LayoutType.TensorNHWC: LayoutType.TensorNHWC,
  296. }
  297. #
  298. ShortLayoutTypeNames = {
  299. LayoutType.ColumnMajor: "n",
  300. LayoutType.ColumnMajorInterleaved32: "n2",
  301. LayoutType.ColumnMajorInterleaved32: "n32",
  302. LayoutType.ColumnMajorInterleaved64: "n64",
  303. LayoutType.RowMajor: "t",
  304. LayoutType.RowMajorInterleaved2: "t2",
  305. LayoutType.RowMajorInterleaved32: "t32",
  306. LayoutType.RowMajorInterleaved64: "t64",
  307. LayoutType.TensorNHWC: "nhwc",
  308. LayoutType.TensorNDHWC: "ndhwc",
  309. LayoutType.TensorNCHW: "nchw",
  310. LayoutType.TensorNGHWC: "nghwc",
  311. LayoutType.TensorNC4HW4: "nc4hw4",
  312. LayoutType.TensorC4RSK4: "c4rsk4",
  313. LayoutType.TensorNC8HW8: "nc8hw8",
  314. LayoutType.TensorNC16HW16: "nc16hw16",
  315. LayoutType.TensorNC32HW32: "nc32hw32",
  316. LayoutType.TensorNC64HW64: "nc64hw64",
  317. LayoutType.TensorC32RSK32: "c32rsk32",
  318. LayoutType.TensorC64RSK64: "c64rsk64",
  319. LayoutType.TensorK4RSC4: "k4rsc4",
  320. LayoutType.TensorCK4RS4: "ck4rs4",
  321. LayoutType.TensorCK8RS8: "ck8rs8",
  322. LayoutType.TensorCK16RS16: "ck16rs16",
  323. }
  324. #
  325. ShortComplexLayoutNames = {
  326. (LayoutType.ColumnMajor, ComplexTransform.none): "n",
  327. (LayoutType.ColumnMajor, ComplexTransform.conj): "c",
  328. (LayoutType.RowMajor, ComplexTransform.none): "t",
  329. (LayoutType.RowMajor, ComplexTransform.conj): "h",
  330. }
  331. ###################################################################################################
  332. #
  333. class OpcodeClass(enum.Enum):
  334. Simt = enum_auto()
  335. TensorOp = enum_auto()
  336. WmmaTensorOp = enum_auto()
  337. OpcodeClassNames = {
  338. OpcodeClass.Simt: "simt",
  339. OpcodeClass.TensorOp: "tensorop",
  340. OpcodeClass.WmmaTensorOp: "wmma_tensorop",
  341. }
  342. OpcodeClassTag = {
  343. OpcodeClass.Simt: "cutlass::arch::OpClassSimt",
  344. OpcodeClass.TensorOp: "cutlass::arch::OpClassTensorOp",
  345. OpcodeClass.WmmaTensorOp: "cutlass::arch::OpClassWmmaTensorOp",
  346. }
  347. ###################################################################################################
  348. #
  349. class OperationKind(enum.Enum):
  350. Gemm = enum_auto()
  351. Conv2d = enum_auto()
  352. #
  353. OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"}
  354. #
  355. class Target(enum.Enum):
  356. library = enum_auto()
  357. ArchitectureNames = {
  358. 50: "maxwell",
  359. 60: "pascal",
  360. 61: "pascal",
  361. 70: "volta",
  362. 75: "turing",
  363. 80: "ampere",
  364. }
  365. ###################################################################################################
  366. #
  367. def SubstituteTemplate(template, values):
  368. text = template
  369. changed = True
  370. while changed:
  371. changed = False
  372. for key, value in values.items():
  373. regex = "\\$\\{%s\\}" % key
  374. newtext = re.sub(regex, value, text)
  375. if newtext != text:
  376. changed = True
  377. text = newtext
  378. return text
  379. ###################################################################################################
  380. #
  381. class GemmKind(enum.Enum):
  382. Gemm = enum_auto()
  383. Sparse = enum_auto()
  384. Universal = enum_auto()
  385. PlanarComplex = enum_auto()
  386. PlanarComplexArray = enum_auto()
  387. SplitKParallel = enum_auto()
  388. GemvBatchedStrided = enum_auto()
  389. #
  390. GemmKindNames = {
  391. GemmKind.Gemm: "gemm",
  392. GemmKind.Sparse: "spgemm",
  393. GemmKind.Universal: "gemm",
  394. GemmKind.PlanarComplex: "gemm_planar_complex",
  395. GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
  396. GemmKind.SplitKParallel: "gemm_split_k_parallel",
  397. GemmKind.GemvBatchedStrided: "gemv_batched_strided",
  398. }
  399. #
  400. class EpilogueFunctor(enum.Enum):
  401. LinearCombination = enum_auto()
  402. LinearCombinationClamp = enum_auto()
  403. BiasAddLinearCombination = enum_auto()
  404. BiasAddLinearCombinationRelu = enum_auto()
  405. BiasAddLinearCombinationHSwish = enum_auto()
  406. BiasAddLinearCombinationClamp = enum_auto()
  407. BiasAddLinearCombinationReluClamp = enum_auto()
  408. BiasAddLinearCombinationHSwishClamp = enum_auto()
  409. #
  410. EpilogueFunctorTag = {
  411. EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
  412. EpilogueFunctor.LinearCombinationClamp: "cutlass::epilogue::thread::LinearCombinationClamp",
  413. EpilogueFunctor.BiasAddLinearCombination: "cutlass::epilogue::thread::BiasAddLinearCombination",
  414. EpilogueFunctor.BiasAddLinearCombinationRelu: "cutlass::epilogue::thread::BiasAddLinearCombinationRelu",
  415. EpilogueFunctor.BiasAddLinearCombinationHSwish: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwish",
  416. EpilogueFunctor.BiasAddLinearCombinationClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationClamp",
  417. EpilogueFunctor.BiasAddLinearCombinationReluClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp",
  418. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp",
  419. }
  420. #
  421. ShortEpilogueNames = {
  422. EpilogueFunctor.LinearCombination: "id",
  423. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: "hswish",
  424. EpilogueFunctor.BiasAddLinearCombinationReluClamp: "relu",
  425. EpilogueFunctor.BiasAddLinearCombinationClamp: "id",
  426. EpilogueFunctor.BiasAddLinearCombinationHSwish: "hswish",
  427. EpilogueFunctor.BiasAddLinearCombinationRelu: "relu",
  428. EpilogueFunctor.BiasAddLinearCombination: "id",
  429. }
  430. #
  431. class SwizzlingFunctor(enum.Enum):
  432. Identity1 = enum_auto()
  433. Identity2 = enum_auto()
  434. Identity4 = enum_auto()
  435. Identity8 = enum_auto()
  436. ConvFpropNCxHWx = enum_auto()
  437. ConvFpropTrans = enum_auto()
  438. ConvDgradNCxHWx = enum_auto()
  439. ConvDgradTrans = enum_auto()
  440. DepthwiseConvolutionFprop = enum_auto()
  441. DepthwiseConvolutionDgrad = enum_auto()
  442. DepthwiseConvolutionWgrad = enum_auto()
  443. #
  444. SwizzlingFunctorTag = {
  445. SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
  446. SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
  447. SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
  448. SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
  449. SwizzlingFunctor.ConvFpropNCxHWx: "cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle",
  450. SwizzlingFunctor.ConvFpropTrans: "cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle",
  451. SwizzlingFunctor.ConvDgradNCxHWx: "cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle",
  452. SwizzlingFunctor.ConvDgradTrans: "cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle",
  453. SwizzlingFunctor.DepthwiseConvolutionFprop: "cutlass::conv::threadblock::DepthwiseConvolutionFpropThreadblockSwizzle",
  454. SwizzlingFunctor.DepthwiseConvolutionDgrad: "cutlass::conv::threadblock::DepthwiseConvolutionDgradThreadblockSwizzle",
  455. SwizzlingFunctor.DepthwiseConvolutionWgrad: "cutlass::conv::threadblock::DepthwiseConvolutionWgradThreadblockSwizzle",
  456. }
  457. ###################################################################################################
  458. class ConvType(enum.Enum):
  459. Convolution = enum_auto()
  460. BatchConvolution = enum_auto()
  461. Local = enum_auto()
  462. LocalShare = enum_auto()
  463. DepthwiseConvolution = enum_auto()
  464. ConvTypeTag = {
  465. ConvType.Convolution: "cutlass::conv::ConvType::kConvolution",
  466. ConvType.BatchConvolution: "cutlass::conv::ConvType::kBatchConvolution",
  467. ConvType.Local: "cutlass::conv::ConvType::kLocal",
  468. ConvType.LocalShare: "cutlass::conv::ConvType::kLocalShare",
  469. ConvType.DepthwiseConvolution: "cutlass::conv::ConvType::kDepthwiseConvolution",
  470. }
  471. #
  472. class ConvKind(enum.Enum):
  473. Fprop = enum_auto()
  474. Dgrad = enum_auto()
  475. Wgrad = enum_auto()
  476. #
  477. ConvKindTag = {
  478. ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
  479. ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad",
  480. ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad",
  481. }
  482. ConvKindNames = {
  483. ConvKind.Fprop: "fprop",
  484. ConvKind.Dgrad: "dgrad",
  485. ConvKind.Wgrad: "wgrad",
  486. }
  487. #
  488. class IteratorAlgorithm(enum.Enum):
  489. Analytic = enum_auto()
  490. Optimized = enum_auto()
  491. #
  492. IteratorAlgorithmTag = {
  493. IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic",
  494. IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized",
  495. }
  496. IteratorAlgorithmNames = {
  497. IteratorAlgorithm.Analytic: "analytic",
  498. IteratorAlgorithm.Optimized: "optimized",
  499. }
  500. #
  501. class StrideSupport(enum.Enum):
  502. Strided = enum_auto()
  503. Unity = enum_auto()
  504. #
  505. StrideSupportTag = {
  506. StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided",
  507. StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity",
  508. }
  509. StrideSupportNames = {StrideSupport.Strided: "", StrideSupport.Unity: "unity_stride"}
  510. class SpecialOptimizeDesc(enum.Enum):
  511. NoneSpecialOpt = enum_auto()
  512. ConvFilterUnity = enum_auto()
  513. DeconvDoubleUpsampling = enum_auto()
  514. SpecialOptimizeDescNames = {
  515. SpecialOptimizeDesc.NoneSpecialOpt: "none",
  516. SpecialOptimizeDesc.ConvFilterUnity: "conv_filter_unity",
  517. SpecialOptimizeDesc.DeconvDoubleUpsampling: "deconv_double_upsampling",
  518. }
  519. SpecialOptimizeDescTag = {
  520. SpecialOptimizeDesc.NoneSpecialOpt: "cutlass::conv::SpecialOptimizeDesc::NONE",
  521. SpecialOptimizeDesc.ConvFilterUnity: "cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY",
  522. SpecialOptimizeDesc.DeconvDoubleUpsampling: "cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING",
  523. }
  524. class ImplicitGemmMode(enum.Enum):
  525. GemmNT = enum_auto()
  526. GemmTN = enum_auto()
  527. ImplicitGemmModeNames = {
  528. ImplicitGemmMode.GemmNT: "gemm_nt",
  529. ImplicitGemmMode.GemmTN: "gemm_tn",
  530. }
  531. ImplicitGemmModeTag = {
  532. ImplicitGemmMode.GemmNT: "cutlass::conv::ImplicitGemmMode::GEMM_NT",
  533. ImplicitGemmMode.GemmTN: "cutlass::conv::ImplicitGemmMode::GEMM_TN",
  534. }
  535. ###################################################################################################
  536. #
  537. class MathInstruction:
  538. def __init__(
  539. self,
  540. instruction_shape,
  541. element_a,
  542. element_b,
  543. element_accumulator,
  544. opcode_class,
  545. math_operation=MathOperation.multiply_add,
  546. ):
  547. self.instruction_shape = instruction_shape
  548. self.element_a = element_a
  549. self.element_b = element_b
  550. self.element_accumulator = element_accumulator
  551. self.opcode_class = opcode_class
  552. self.math_operation = math_operation
  553. #
  554. class TileDescription:
  555. def __init__(
  556. self,
  557. threadblock_shape,
  558. stages,
  559. warp_count,
  560. math_instruction,
  561. min_compute,
  562. max_compute,
  563. ):
  564. self.threadblock_shape = threadblock_shape
  565. self.stages = stages
  566. self.warp_count = warp_count
  567. self.math_instruction = math_instruction
  568. self.minimum_compute_capability = min_compute
  569. self.maximum_compute_capability = max_compute
  570. def procedural_name(self):
  571. return "%dx%d_%dx%d" % (
  572. self.threadblock_shape[0],
  573. self.threadblock_shape[1],
  574. self.threadblock_shape[2],
  575. self.stages,
  576. )
  577. #
  578. class TensorDescription:
  579. def __init__(
  580. self, element, layout, alignment=1, complex_transform=ComplexTransform.none
  581. ):
  582. self.element = element
  583. self.layout = layout
  584. self.alignment = alignment
  585. self.complex_transform = complex_transform
  586. ###################################################################################################
  587. class GlobalCnt:
  588. cnt = 0