You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

library.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. #
  2. # \file library.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import re
  7. ###################################################################################################
  8. import enum
  9. # The following block implements enum.auto() for Python 3.5 variants that don't include it such
  10. # as the default 3.5.2 on Ubuntu 16.04.
  11. #
  12. # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
  13. try:
  14. from enum import auto as enum_auto
  15. except ImportError:
  16. __cutlass_library_auto_enum = 0
  17. def enum_auto() -> int:
  18. global __cutlass_library_auto_enum
  19. i = __cutlass_library_auto_enum
  20. __cutlass_library_auto_enum += 1
  21. return i
  22. ###################################################################################################
  23. #
  24. class GeneratorTarget(enum.Enum):
  25. Library = enum_auto()
  26. #
  27. GeneratorTargetNames = {
  28. GeneratorTarget.Library: 'library'
  29. }
  30. #
  31. ###################################################################################################
  32. #
  33. class DataType(enum.Enum):
  34. b1 = enum_auto()
  35. u4 = enum_auto()
  36. u8 = enum_auto()
  37. u16 = enum_auto()
  38. u32 = enum_auto()
  39. u64 = enum_auto()
  40. s4 = enum_auto()
  41. s8 = enum_auto()
  42. s16 = enum_auto()
  43. s32 = enum_auto()
  44. s64 = enum_auto()
  45. f16 = enum_auto()
  46. bf16 = enum_auto()
  47. f32 = enum_auto()
  48. tf32 = enum_auto()
  49. f64 = enum_auto()
  50. cf16 = enum_auto()
  51. cbf16 = enum_auto()
  52. cf32 = enum_auto()
  53. ctf32 = enum_auto()
  54. cf64 = enum_auto()
  55. cs4 = enum_auto()
  56. cs8 = enum_auto()
  57. cs16 = enum_auto()
  58. cs32 = enum_auto()
  59. cs64 = enum_auto()
  60. cu4 = enum_auto()
  61. cu8 = enum_auto()
  62. cu16 = enum_auto()
  63. cu32 = enum_auto()
  64. cu64 = enum_auto()
  65. invalid = enum_auto()
  66. #
  67. ShortDataTypeNames = {
  68. DataType.s32: 'i',
  69. DataType.f16: 'h',
  70. DataType.f32: 's',
  71. DataType.f64: 'd',
  72. DataType.cf32: 'c',
  73. DataType.cf64: 'z',
  74. }
  75. #
  76. DataTypeNames = {
  77. DataType.b1: "b1",
  78. DataType.u4: "u4",
  79. DataType.u8: "u8",
  80. DataType.u16: "u16",
  81. DataType.u32: "u32",
  82. DataType.u64: "u64",
  83. DataType.s4: "s4",
  84. DataType.s8: "s8",
  85. DataType.s16: "s16",
  86. DataType.s32: "s32",
  87. DataType.s64: "s64",
  88. DataType.f16: "f16",
  89. DataType.bf16: "bf16",
  90. DataType.f32: "f32",
  91. DataType.tf32: "tf32",
  92. DataType.f64: "f64",
  93. DataType.cf16: "cf16",
  94. DataType.cbf16: "cbf16",
  95. DataType.cf32: "cf32",
  96. DataType.ctf32: "ctf32",
  97. DataType.cf64: "cf64",
  98. DataType.cu4: "cu4",
  99. DataType.cu8: "cu8",
  100. DataType.cu16: "cu16",
  101. DataType.cu32: "cu32",
  102. DataType.cu64: "cu64",
  103. DataType.cs4: "cs4",
  104. DataType.cs8: "cs8",
  105. DataType.cs16: "cs16",
  106. DataType.cs32: "cs32",
  107. DataType.cs64: "cs64",
  108. }
  109. DataTypeTag = {
  110. DataType.b1: "cutlass::uint1b_t",
  111. DataType.u4: "cutlass::uint4b_t",
  112. DataType.u8: "uint8_t",
  113. DataType.u16: "uint16_t",
  114. DataType.u32: "uint32_t",
  115. DataType.u64: "uint64_t",
  116. DataType.s4: "cutlass::int4b_t",
  117. DataType.s8: "int8_t",
  118. DataType.s16: "int16_t",
  119. DataType.s32: "int32_t",
  120. DataType.s64: "int64_t",
  121. DataType.f16: "cutlass::half_t",
  122. DataType.bf16: "cutlass::bfloat16_t",
  123. DataType.f32: "float",
  124. DataType.tf32: "cutlass::tfloat32_t",
  125. DataType.f64: "double",
  126. DataType.cf16: "cutlass::complex<cutlass::half_t>",
  127. DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
  128. DataType.cf32: "cutlass::complex<float>",
  129. DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
  130. DataType.cf64: "cutlass::complex<double>",
  131. DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
  132. DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
  133. DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
  134. DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
  135. DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
  136. DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
  137. DataType.cs8: "cutlass::complex<cutlass::int8_t>",
  138. DataType.cs16: "cutlass::complex<cutlass::int16_t>",
  139. DataType.cs32: "cutlass::complex<cutlass::int32_t>",
  140. DataType.cs64: "cutlass::complex<cutlass::int64_t>",
  141. }
  142. DataTypeSize = {
  143. DataType.b1: 1,
  144. DataType.u4: 4,
  145. DataType.u8: 4,
  146. DataType.u16: 16,
  147. DataType.u32: 32,
  148. DataType.u64: 64,
  149. DataType.s4: 4,
  150. DataType.s8: 8,
  151. DataType.s16: 16,
  152. DataType.s32: 32,
  153. DataType.s64: 64,
  154. DataType.f16: 16,
  155. DataType.bf16: 16,
  156. DataType.f32: 32,
  157. DataType.tf32: 32,
  158. DataType.f64: 64,
  159. DataType.cf16: 32,
  160. DataType.cbf16: 32,
  161. DataType.cf32: 64,
  162. DataType.ctf32: 32,
  163. DataType.cf64: 128,
  164. DataType.cu4: 8,
  165. DataType.cu8: 16,
  166. DataType.cu16: 32,
  167. DataType.cu32: 64,
  168. DataType.cu64: 128,
  169. DataType.cs4: 8,
  170. DataType.cs8: 16,
  171. DataType.cs16: 32,
  172. DataType.cs32: 64,
  173. DataType.cs64: 128,
  174. }
  175. ###################################################################################################
  176. #
  177. class ComplexTransform(enum.Enum):
  178. none = enum_auto()
  179. conj = enum_auto()
  180. #
  181. ComplexTransformTag = {
  182. ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
  183. ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
  184. }
  185. #
  186. RealComplexBijection = [
  187. (DataType.f16, DataType.cf16),
  188. (DataType.f32, DataType.cf32),
  189. (DataType.f64, DataType.cf64),
  190. ]
  191. #
  192. def is_complex(data_type):
  193. for r, c in RealComplexBijection:
  194. if data_type == c:
  195. return True
  196. return False
  197. #
  198. def get_complex_from_real(real_type):
  199. for r, c in RealComplexBijection:
  200. if real_type == r:
  201. return c
  202. return DataType.invalid
  203. #
  204. def get_real_from_complex(complex_type):
  205. for r, c in RealComplexBijection:
  206. if complex_type == c:
  207. return r
  208. return DataType.invalid
  209. #
  210. class ComplexMultiplyOp(enum.Enum):
  211. multiply_add = enum_auto()
  212. gaussian = enum_auto()
  213. ###################################################################################################
  214. #
  215. class MathOperation(enum.Enum):
  216. multiply_add = enum_auto()
  217. multiply_add_saturate = enum_auto()
  218. xor_popc = enum_auto()
  219. multiply_add_fast_bf16 = enum_auto()
  220. multiply_add_fast_f16 = enum_auto()
  221. multiply_add_complex = enum_auto()
  222. multiply_add_complex_gaussian = enum_auto()
  223. #
  224. MathOperationTag = {
  225. MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
  226. MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
  227. MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
  228. MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
  229. MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
  230. MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
  231. MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
  232. }
  233. ###################################################################################################
  234. #
  235. class LayoutType(enum.Enum):
  236. ColumnMajor = enum_auto()
  237. RowMajor = enum_auto()
  238. ColumnMajorInterleaved2 = enum_auto()
  239. RowMajorInterleaved2 = enum_auto()
  240. ColumnMajorInterleaved32 = enum_auto()
  241. RowMajorInterleaved32 = enum_auto()
  242. ColumnMajorInterleaved64 = enum_auto()
  243. RowMajorInterleaved64 = enum_auto()
  244. TensorNHWC = enum_auto()
  245. TensorNDHWC = enum_auto()
  246. TensorNCHW = enum_auto()
  247. TensorNGHWC = enum_auto()
  248. TensorNC4HW4 = enum_auto()
  249. TensorC4RSK4 = enum_auto()
  250. TensorNC8HW8 = enum_auto()
  251. TensorNC16HW16 = enum_auto()
  252. TensorNC32HW32 = enum_auto()
  253. TensorNC64HW64 = enum_auto()
  254. TensorC32RSK32 = enum_auto()
  255. TensorC64RSK64 = enum_auto()
  256. TensorK4RSC4 = enum_auto()
  257. TensorCK4RS4 = enum_auto()
  258. TensorCK8RS8 = enum_auto()
  259. TensorCK16RS16 = enum_auto()
  260. #
  261. LayoutTag = {
  262. LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
  263. LayoutType.RowMajor: 'cutlass::layout::RowMajor',
  264. LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
  265. LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
  266. LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
  267. LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
  268. LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
  269. LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
  270. LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
  271. LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
  272. LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
  273. LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
  274. LayoutType.TensorNC4HW4: 'cutlass::layout::TensorNCxHWx<4>',
  275. LayoutType.TensorC4RSK4: 'cutlass::layout::TensorCxRSKx<4>',
  276. LayoutType.TensorNC8HW8: 'cutlass::layout::TensorNCxHWx<8>',
  277. LayoutType.TensorNC16HW16: 'cutlass::layout::TensorNCxHWx<16>',
  278. LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
  279. LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
  280. LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
  281. LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
  282. LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>',
  283. LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>',
  284. LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>',
  285. LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>',
  286. }
  287. #
  288. TransposedLayout = {
  289. LayoutType.ColumnMajor: LayoutType.RowMajor,
  290. LayoutType.RowMajor: LayoutType.ColumnMajor,
  291. LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
  292. LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
  293. LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
  294. LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
  295. LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
  296. LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
  297. LayoutType.TensorNHWC: LayoutType.TensorNHWC
  298. }
  299. #
  300. ShortLayoutTypeNames = {
  301. LayoutType.ColumnMajor: 'n',
  302. LayoutType.ColumnMajorInterleaved32: 'n2',
  303. LayoutType.ColumnMajorInterleaved32: 'n32',
  304. LayoutType.ColumnMajorInterleaved64: 'n64',
  305. LayoutType.RowMajor: 't',
  306. LayoutType.RowMajorInterleaved2: 't2',
  307. LayoutType.RowMajorInterleaved32: 't32',
  308. LayoutType.RowMajorInterleaved64: 't64',
  309. LayoutType.TensorNHWC: 'nhwc',
  310. LayoutType.TensorNDHWC: 'ndhwc',
  311. LayoutType.TensorNCHW: 'nchw',
  312. LayoutType.TensorNGHWC: 'nghwc',
  313. LayoutType.TensorNC4HW4: 'nc4hw4',
  314. LayoutType.TensorC4RSK4: 'c4rsk4',
  315. LayoutType.TensorNC8HW8: 'nc8hw8',
  316. LayoutType.TensorNC16HW16: 'nc16hw16',
  317. LayoutType.TensorNC32HW32: 'nc32hw32',
  318. LayoutType.TensorNC64HW64: 'nc64hw64',
  319. LayoutType.TensorC32RSK32: 'c32rsk32',
  320. LayoutType.TensorC64RSK64: 'c64rsk64',
  321. LayoutType.TensorK4RSC4: 'k4rsc4',
  322. LayoutType.TensorCK4RS4: 'ck4rs4',
  323. LayoutType.TensorCK8RS8: 'ck8rs8',
  324. LayoutType.TensorCK16RS16: 'ck16rs16',
  325. }
  326. #
  327. ShortComplexLayoutNames = {
  328. (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
  329. (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
  330. (LayoutType.RowMajor, ComplexTransform.none): 't',
  331. (LayoutType.RowMajor, ComplexTransform.conj): 'h'
  332. }
  333. ###################################################################################################
  334. #
  335. class OpcodeClass(enum.Enum):
  336. Simt = enum_auto()
  337. TensorOp = enum_auto()
  338. WmmaTensorOp = enum_auto()
  339. OpcodeClassNames = {
  340. OpcodeClass.Simt: 'simt',
  341. OpcodeClass.TensorOp: 'tensorop',
  342. OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
  343. }
  344. OpcodeClassTag = {
  345. OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
  346. OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
  347. OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
  348. }
  349. ###################################################################################################
  350. #
  351. class OperationKind(enum.Enum):
  352. Gemm = enum_auto()
  353. Conv2d = enum_auto()
  354. #
  355. OperationKindNames = {
  356. OperationKind.Gemm: 'gemm'
  357. , OperationKind.Conv2d: 'conv2d'
  358. }
  359. #
  360. class Target(enum.Enum):
  361. library = enum_auto()
  362. ArchitectureNames = {
  363. 50: 'maxwell',
  364. 60: 'pascal',
  365. 61: 'pascal',
  366. 70: 'volta',
  367. 75: 'turing',
  368. 80: 'ampere',
  369. }
  370. ###################################################################################################
  371. #
  372. def SubstituteTemplate(template, values):
  373. text = template
  374. changed = True
  375. while changed:
  376. changed = False
  377. for key, value in values.items():
  378. regex = "\\$\\{%s\\}" % key
  379. newtext = re.sub(regex, value, text)
  380. if newtext != text:
  381. changed = True
  382. text = newtext
  383. return text
  384. ###################################################################################################
  385. #
  386. class GemmKind(enum.Enum):
  387. Gemm = enum_auto()
  388. Sparse = enum_auto()
  389. Universal = enum_auto()
  390. PlanarComplex = enum_auto()
  391. PlanarComplexArray = enum_auto()
  392. SplitKParallel = enum_auto()
  393. GemvBatchedStrided = enum_auto()
  394. #
  395. GemmKindNames = {
  396. GemmKind.Gemm: "gemm",
  397. GemmKind.Sparse: "spgemm",
  398. GemmKind.Universal: "gemm",
  399. GemmKind.PlanarComplex: "gemm_planar_complex",
  400. GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
  401. GemmKind.SplitKParallel: "gemm_split_k_parallel",
  402. GemmKind.GemvBatchedStrided: "gemv_batched_strided",
  403. }
  404. #
  405. class EpilogueFunctor(enum.Enum):
  406. LinearCombination = enum_auto()
  407. LinearCombinationClamp = enum_auto()
  408. BiasAddLinearCombination = enum_auto()
  409. BiasAddLinearCombinationRelu = enum_auto()
  410. BiasAddLinearCombinationHSwish = enum_auto()
  411. BiasAddLinearCombinationClamp = enum_auto()
  412. BiasAddLinearCombinationReluClamp = enum_auto()
  413. BiasAddLinearCombinationHSwishClamp = enum_auto()
  414. #
  415. EpilogueFunctorTag = {
  416. EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
  417. EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
  418. EpilogueFunctor.BiasAddLinearCombination: 'cutlass::epilogue::thread::BiasAddLinearCombination',
  419. EpilogueFunctor.BiasAddLinearCombinationRelu: 'cutlass::epilogue::thread::BiasAddLinearCombinationRelu',
  420. EpilogueFunctor.BiasAddLinearCombinationHSwish: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwish',
  421. EpilogueFunctor.BiasAddLinearCombinationClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationClamp',
  422. EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp',
  423. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp',
  424. }
  425. #
  426. ShortEpilogueNames = {
  427. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish',
  428. EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu',
  429. EpilogueFunctor.BiasAddLinearCombinationClamp: 'id',
  430. EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish',
  431. EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu',
  432. EpilogueFunctor.BiasAddLinearCombination: 'id',
  433. }
  434. #
  435. class SwizzlingFunctor(enum.Enum):
  436. Identity1 = enum_auto()
  437. Identity2 = enum_auto()
  438. Identity4 = enum_auto()
  439. Identity8 = enum_auto()
  440. ConvFpropNCxHWx = enum_auto()
  441. ConvFpropTrans = enum_auto()
  442. ConvDgradNCxHWx = enum_auto()
  443. ConvDgradTrans = enum_auto()
  444. #
  445. SwizzlingFunctorTag = {
  446. SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
  447. SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
  448. SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
  449. SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
  450. SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle',
  451. SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle',
  452. SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle',
  453. SwizzlingFunctor.ConvDgradTrans: 'cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle',
  454. }
  455. ###################################################################################################
  456. class ConvType(enum.Enum):
  457. Convolution = enum_auto()
  458. BatchConvolution = enum_auto()
  459. Local = enum_auto()
  460. LocalShare = enum_auto()
  461. ConvTypeTag = {
  462. ConvType.Convolution: 'cutlass::conv::ConvType::kConvolution',
  463. ConvType.BatchConvolution: 'cutlass::conv::ConvType::kBatchConvolution',
  464. ConvType.Local: 'cutlass::conv::ConvType::kLocal',
  465. ConvType.LocalShare : 'cutlass::conv::ConvType::kLocalShare',
  466. }
  467. #
  468. class ConvKind(enum.Enum):
  469. Fprop = enum_auto()
  470. Dgrad = enum_auto()
  471. Wgrad = enum_auto()
  472. #
  473. ConvKindTag = {
  474. ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
  475. ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
  476. ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
  477. }
  478. ConvKindNames = {
  479. ConvKind.Fprop: 'fprop',
  480. ConvKind.Dgrad: 'dgrad',
  481. ConvKind.Wgrad: 'wgrad',
  482. }
  483. #
  484. class IteratorAlgorithm(enum.Enum):
  485. Analytic = enum_auto()
  486. Optimized = enum_auto()
  487. #
  488. IteratorAlgorithmTag = {
  489. IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
  490. IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
  491. }
  492. IteratorAlgorithmNames = {
  493. IteratorAlgorithm.Analytic: 'analytic',
  494. IteratorAlgorithm.Optimized: 'optimized',
  495. }
  496. #
  497. class StrideSupport(enum.Enum):
  498. Strided = enum_auto()
  499. Unity = enum_auto()
  500. #
  501. StrideSupportTag = {
  502. StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
  503. StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
  504. }
  505. StrideSupportNames = {
  506. StrideSupport.Strided: '',
  507. StrideSupport.Unity: 'unity_stride',
  508. }
  509. class SpecialOptimizeDesc(enum.Enum):
  510. NoneSpecialOpt = enum_auto()
  511. ConvFilterUnity = enum_auto()
  512. DeconvDoubleUpsampling = enum_auto()
  513. SpecialOptimizeDescNames = {
  514. SpecialOptimizeDesc.NoneSpecialOpt: 'none',
  515. SpecialOptimizeDesc.ConvFilterUnity: 'conv_filter_unity',
  516. SpecialOptimizeDesc.DeconvDoubleUpsampling: 'deconv_double_upsampling',
  517. }
  518. SpecialOptimizeDescTag = {
  519. SpecialOptimizeDesc.NoneSpecialOpt: 'cutlass::conv::SpecialOptimizeDesc::NONE',
  520. SpecialOptimizeDesc.ConvFilterUnity: 'cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY',
  521. SpecialOptimizeDesc.DeconvDoubleUpsampling: 'cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING',
  522. }
  523. class ImplicitGemmMode(enum.Enum):
  524. GemmNT = enum_auto()
  525. GemmTN = enum_auto()
  526. ImplicitGemmModeNames = {
  527. ImplicitGemmMode.GemmNT: 'gemm_nt',
  528. ImplicitGemmMode.GemmTN: 'gemm_tn',
  529. }
  530. ImplicitGemmModeTag = {
  531. ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
  532. ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
  533. }
  534. ###################################################################################################
  535. #
  536. class MathInstruction:
  537. def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add):
  538. self.instruction_shape = instruction_shape
  539. self.element_a = element_a
  540. self.element_b = element_b
  541. self.element_accumulator = element_accumulator
  542. self.opcode_class = opcode_class
  543. self.math_operation = math_operation
  544. #
  545. class TileDescription:
  546. def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute):
  547. self.threadblock_shape = threadblock_shape
  548. self.stages = stages
  549. self.warp_count = warp_count
  550. self.math_instruction = math_instruction
  551. self.minimum_compute_capability = min_compute
  552. self.maximum_compute_capability = max_compute
  553. def procedural_name(self):
  554. return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
  555. #
  556. class TensorDescription:
  557. def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
  558. self.element = element
  559. self.layout = layout
  560. self.alignment = alignment
  561. self.complex_transform = complex_transform
  562. ###################################################################################################
  563. class GlobalCnt:
  564. cnt = 0