You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

library.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. #
  2. # \file library.py
  3. #
  4. # \brief Enumerations, name/tag lookup tables, and description classes used to generate the CUTLASS Library's instances
  5. #
  6. import re
  7. ###################################################################################################
  8. import enum
  9. # The following block implements enum.auto() for Python 3.5 variants that don't include it such
  10. # as the default 3.5.2 on Ubuntu 16.04.
  11. #
  12. # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
  13. try:
  14. from enum import auto as enum_auto
  15. except ImportError:
  16. __cutlass_library_auto_enum = 0
  17. def enum_auto() -> int:
  18. global __cutlass_library_auto_enum
  19. i = __cutlass_library_auto_enum
  20. __cutlass_library_auto_enum += 1
  21. return i
  22. ###################################################################################################
  23. #
  24. class GeneratorTarget(enum.Enum):
  25. Library = enum_auto()
  26. #
  27. GeneratorTargetNames = {
  28. GeneratorTarget.Library: 'library'
  29. }
  30. #
  31. ###################################################################################################
  32. #
  33. class DataType(enum.Enum):
  34. b1 = enum_auto()
  35. u4 = enum_auto()
  36. u8 = enum_auto()
  37. u16 = enum_auto()
  38. u32 = enum_auto()
  39. u64 = enum_auto()
  40. s4 = enum_auto()
  41. s8 = enum_auto()
  42. s16 = enum_auto()
  43. s32 = enum_auto()
  44. s64 = enum_auto()
  45. f16 = enum_auto()
  46. bf16 = enum_auto()
  47. f32 = enum_auto()
  48. tf32 = enum_auto()
  49. f64 = enum_auto()
  50. cf16 = enum_auto()
  51. cbf16 = enum_auto()
  52. cf32 = enum_auto()
  53. ctf32 = enum_auto()
  54. cf64 = enum_auto()
  55. cs4 = enum_auto()
  56. cs8 = enum_auto()
  57. cs16 = enum_auto()
  58. cs32 = enum_auto()
  59. cs64 = enum_auto()
  60. cu4 = enum_auto()
  61. cu8 = enum_auto()
  62. cu16 = enum_auto()
  63. cu32 = enum_auto()
  64. cu64 = enum_auto()
  65. invalid = enum_auto()
  66. #
  67. ShortDataTypeNames = {
  68. DataType.s32: 'i',
  69. DataType.f16: 'h',
  70. DataType.f32: 's',
  71. DataType.f64: 'd',
  72. DataType.cf32: 'c',
  73. DataType.cf64: 'z',
  74. }
  75. #
  76. DataTypeNames = {
  77. DataType.b1: "b1",
  78. DataType.u4: "u4",
  79. DataType.u8: "u8",
  80. DataType.u16: "u16",
  81. DataType.u32: "u32",
  82. DataType.u64: "u64",
  83. DataType.s4: "s4",
  84. DataType.s8: "s8",
  85. DataType.s16: "s16",
  86. DataType.s32: "s32",
  87. DataType.s64: "s64",
  88. DataType.f16: "f16",
  89. DataType.bf16: "bf16",
  90. DataType.f32: "f32",
  91. DataType.tf32: "tf32",
  92. DataType.f64: "f64",
  93. DataType.cf16: "cf16",
  94. DataType.cbf16: "cbf16",
  95. DataType.cf32: "cf32",
  96. DataType.ctf32: "ctf32",
  97. DataType.cf64: "cf64",
  98. DataType.cu4: "cu4",
  99. DataType.cu8: "cu8",
  100. DataType.cu16: "cu16",
  101. DataType.cu32: "cu32",
  102. DataType.cu64: "cu64",
  103. DataType.cs4: "cs4",
  104. DataType.cs8: "cs8",
  105. DataType.cs16: "cs16",
  106. DataType.cs32: "cs32",
  107. DataType.cs64: "cs64",
  108. }
  109. DataTypeTag = {
  110. DataType.b1: "cutlass::uint1b_t",
  111. DataType.u4: "cutlass::uint4b_t",
  112. DataType.u8: "uint8_t",
  113. DataType.u16: "uint16_t",
  114. DataType.u32: "uint32_t",
  115. DataType.u64: "uint64_t",
  116. DataType.s4: "cutlass::int4b_t",
  117. DataType.s8: "int8_t",
  118. DataType.s16: "int16_t",
  119. DataType.s32: "int32_t",
  120. DataType.s64: "int64_t",
  121. DataType.f16: "cutlass::half_t",
  122. DataType.bf16: "cutlass::bfloat16_t",
  123. DataType.f32: "float",
  124. DataType.tf32: "cutlass::tfloat32_t",
  125. DataType.f64: "double",
  126. DataType.cf16: "cutlass::complex<cutlass::half_t>",
  127. DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
  128. DataType.cf32: "cutlass::complex<float>",
  129. DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
  130. DataType.cf64: "cutlass::complex<double>",
  131. DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
  132. DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
  133. DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
  134. DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
  135. DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
  136. DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
  137. DataType.cs8: "cutlass::complex<cutlass::int8_t>",
  138. DataType.cs16: "cutlass::complex<cutlass::int16_t>",
  139. DataType.cs32: "cutlass::complex<cutlass::int32_t>",
  140. DataType.cs64: "cutlass::complex<cutlass::int64_t>",
  141. }
  142. DataTypeSize = {
  143. DataType.b1: 1,
  144. DataType.u4: 4,
  145. DataType.u8: 4,
  146. DataType.u16: 16,
  147. DataType.u32: 32,
  148. DataType.u64: 64,
  149. DataType.s4: 4,
  150. DataType.s8: 8,
  151. DataType.s16: 16,
  152. DataType.s32: 32,
  153. DataType.s64: 64,
  154. DataType.f16: 16,
  155. DataType.bf16: 16,
  156. DataType.f32: 32,
  157. DataType.tf32: 32,
  158. DataType.f64: 64,
  159. DataType.cf16: 32,
  160. DataType.cbf16: 32,
  161. DataType.cf32: 64,
  162. DataType.ctf32: 32,
  163. DataType.cf64: 128,
  164. DataType.cu4: 8,
  165. DataType.cu8: 16,
  166. DataType.cu16: 32,
  167. DataType.cu32: 64,
  168. DataType.cu64: 128,
  169. DataType.cs4: 8,
  170. DataType.cs8: 16,
  171. DataType.cs16: 32,
  172. DataType.cs32: 64,
  173. DataType.cs64: 128,
  174. }
  175. ###################################################################################################
  176. #
  177. class ComplexTransform(enum.Enum):
  178. none = enum_auto()
  179. conj = enum_auto()
  180. #
  181. ComplexTransformTag = {
  182. ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
  183. ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
  184. }
  185. #
  186. RealComplexBijection = [
  187. (DataType.f16, DataType.cf16),
  188. (DataType.f32, DataType.cf32),
  189. (DataType.f64, DataType.cf64),
  190. ]
  191. #
  192. def is_complex(data_type):
  193. for r, c in RealComplexBijection:
  194. if data_type == c:
  195. return True
  196. return False
  197. #
  198. def get_complex_from_real(real_type):
  199. for r, c in RealComplexBijection:
  200. if real_type == r:
  201. return c
  202. return DataType.invalid
  203. #
  204. def get_real_from_complex(complex_type):
  205. for r, c in RealComplexBijection:
  206. if complex_type == c:
  207. return r
  208. return DataType.invalid
  209. #
  210. class ComplexMultiplyOp(enum.Enum):
  211. multiply_add = enum_auto()
  212. gaussian = enum_auto()
  213. ###################################################################################################
  214. #
  215. class MathOperation(enum.Enum):
  216. multiply_add = enum_auto()
  217. multiply_add_saturate = enum_auto()
  218. xor_popc = enum_auto()
  219. multiply_add_fast_bf16 = enum_auto()
  220. multiply_add_fast_f16 = enum_auto()
  221. multiply_add_complex = enum_auto()
  222. multiply_add_complex_gaussian = enum_auto()
  223. #
  224. MathOperationTag = {
  225. MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
  226. MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
  227. MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
  228. MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
  229. MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
  230. MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
  231. MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
  232. }
  233. ###################################################################################################
  234. #
  235. class LayoutType(enum.Enum):
  236. ColumnMajor = enum_auto()
  237. RowMajor = enum_auto()
  238. ColumnMajorInterleaved2 = enum_auto()
  239. RowMajorInterleaved2 = enum_auto()
  240. ColumnMajorInterleaved32 = enum_auto()
  241. RowMajorInterleaved32 = enum_auto()
  242. ColumnMajorInterleaved64 = enum_auto()
  243. RowMajorInterleaved64 = enum_auto()
  244. TensorNHWC = enum_auto()
  245. TensorNDHWC = enum_auto()
  246. TensorNCHW = enum_auto()
  247. TensorNGHWC = enum_auto()
  248. TensorNC4HW4 = enum_auto()
  249. TensorC4RSK4 = enum_auto()
  250. TensorNC8HW8 = enum_auto()
  251. TensorNC16HW16 = enum_auto()
  252. TensorNC32HW32 = enum_auto()
  253. TensorNC64HW64 = enum_auto()
  254. TensorC32RSK32 = enum_auto()
  255. TensorC64RSK64 = enum_auto()
  256. TensorK4RSC4 = enum_auto()
  257. #
  258. LayoutTag = {
  259. LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
  260. LayoutType.RowMajor: 'cutlass::layout::RowMajor',
  261. LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
  262. LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
  263. LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
  264. LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
  265. LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
  266. LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
  267. LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
  268. LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
  269. LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
  270. LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
  271. LayoutType.TensorNC4HW4: 'cutlass::layout::TensorNCxHWx<4>',
  272. LayoutType.TensorC4RSK4: 'cutlass::layout::TensorCxRSKx<4>',
  273. LayoutType.TensorNC8HW8: 'cutlass::layout::TensorNCxHWx<8>',
  274. LayoutType.TensorNC16HW16: 'cutlass::layout::TensorNCxHWx<16>',
  275. LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
  276. LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
  277. LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
  278. LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
  279. LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>',
  280. }
  281. #
  282. TransposedLayout = {
  283. LayoutType.ColumnMajor: LayoutType.RowMajor,
  284. LayoutType.RowMajor: LayoutType.ColumnMajor,
  285. LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
  286. LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
  287. LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
  288. LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
  289. LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
  290. LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
  291. LayoutType.TensorNHWC: LayoutType.TensorNHWC
  292. }
  293. #
  294. ShortLayoutTypeNames = {
  295. LayoutType.ColumnMajor: 'n',
  296. LayoutType.ColumnMajorInterleaved32: 'n2',
  297. LayoutType.ColumnMajorInterleaved32: 'n32',
  298. LayoutType.ColumnMajorInterleaved64: 'n64',
  299. LayoutType.RowMajor: 't',
  300. LayoutType.RowMajorInterleaved2: 't2',
  301. LayoutType.RowMajorInterleaved32: 't32',
  302. LayoutType.RowMajorInterleaved64: 't64',
  303. LayoutType.TensorNHWC: 'nhwc',
  304. LayoutType.TensorNDHWC: 'ndhwc',
  305. LayoutType.TensorNCHW: 'nchw',
  306. LayoutType.TensorNGHWC: 'nghwc',
  307. LayoutType.TensorNC4HW4: 'nc4hw4',
  308. LayoutType.TensorC4RSK4: 'c4rsk4',
  309. LayoutType.TensorNC8HW8: 'nc8hw8',
  310. LayoutType.TensorNC16HW16: 'nc16hw16',
  311. LayoutType.TensorNC32HW32: 'nc32hw32',
  312. LayoutType.TensorNC64HW64: 'nc64hw64',
  313. LayoutType.TensorC32RSK32: 'c32rsk32',
  314. LayoutType.TensorC64RSK64: 'c64rsk64',
  315. LayoutType.TensorK4RSC4: 'k4rsc4',
  316. }
  317. #
  318. ShortComplexLayoutNames = {
  319. (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
  320. (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
  321. (LayoutType.RowMajor, ComplexTransform.none): 't',
  322. (LayoutType.RowMajor, ComplexTransform.conj): 'h'
  323. }
  324. ###################################################################################################
  325. #
  326. class OpcodeClass(enum.Enum):
  327. Simt = enum_auto()
  328. TensorOp = enum_auto()
  329. WmmaTensorOp = enum_auto()
  330. OpcodeClassNames = {
  331. OpcodeClass.Simt: 'simt',
  332. OpcodeClass.TensorOp: 'tensorop',
  333. OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
  334. }
  335. OpcodeClassTag = {
  336. OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
  337. OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
  338. OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
  339. }
  340. ###################################################################################################
  341. #
  342. class OperationKind(enum.Enum):
  343. Gemm = enum_auto()
  344. Conv2d = enum_auto()
  345. #
  346. OperationKindNames = {
  347. OperationKind.Gemm: 'gemm'
  348. , OperationKind.Conv2d: 'conv2d'
  349. }
  350. #
  351. class Target(enum.Enum):
  352. library = enum_auto()
  353. ArchitectureNames = {
  354. 50: 'maxwell',
  355. 60: 'pascal',
  356. 61: 'pascal',
  357. 70: 'volta',
  358. 75: 'turing',
  359. 80: 'ampere',
  360. }
  361. ###################################################################################################
  362. #
  363. def SubstituteTemplate(template, values):
  364. text = template
  365. changed = True
  366. while changed:
  367. changed = False
  368. for key, value in values.items():
  369. regex = "\\$\\{%s\\}" % key
  370. newtext = re.sub(regex, value, text)
  371. if newtext != text:
  372. changed = True
  373. text = newtext
  374. return text
  375. ###################################################################################################
  376. #
  377. class GemmKind(enum.Enum):
  378. Gemm = enum_auto()
  379. Sparse = enum_auto()
  380. Universal = enum_auto()
  381. PlanarComplex = enum_auto()
  382. PlanarComplexArray = enum_auto()
  383. SplitKParallel = enum_auto()
  384. GemvBatchedStrided = enum_auto()
  385. #
  386. GemmKindNames = {
  387. GemmKind.Gemm: "gemm",
  388. GemmKind.Sparse: "spgemm",
  389. GemmKind.Universal: "gemm",
  390. GemmKind.PlanarComplex: "gemm_planar_complex",
  391. GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
  392. GemmKind.SplitKParallel: "gemm_split_k_parallel",
  393. GemmKind.GemvBatchedStrided: "gemv_batched_strided",
  394. }
  395. #
  396. class EpilogueFunctor(enum.Enum):
  397. LinearCombination = enum_auto()
  398. LinearCombinationClamp = enum_auto()
  399. BiasAddLinearCombination = enum_auto()
  400. BiasAddLinearCombinationRelu = enum_auto()
  401. BiasAddLinearCombinationHSwish = enum_auto()
  402. BiasAddLinearCombinationClamp = enum_auto()
  403. BiasAddLinearCombinationReluClamp = enum_auto()
  404. BiasAddLinearCombinationHSwishClamp = enum_auto()
  405. #
  406. EpilogueFunctorTag = {
  407. EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
  408. EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
  409. EpilogueFunctor.BiasAddLinearCombination: 'cutlass::epilogue::thread::BiasAddLinearCombination',
  410. EpilogueFunctor.BiasAddLinearCombinationRelu: 'cutlass::epilogue::thread::BiasAddLinearCombinationRelu',
  411. EpilogueFunctor.BiasAddLinearCombinationHSwish: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwish',
  412. EpilogueFunctor.BiasAddLinearCombinationClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationClamp',
  413. EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp',
  414. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp',
  415. }
  416. #
  417. ShortEpilogueNames = {
  418. EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish',
  419. EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu',
  420. EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity',
  421. EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish',
  422. EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu',
  423. EpilogueFunctor.BiasAddLinearCombination: 'identity',
  424. }
  425. #
  426. class SwizzlingFunctor(enum.Enum):
  427. Identity1 = enum_auto()
  428. Identity2 = enum_auto()
  429. Identity4 = enum_auto()
  430. Identity8 = enum_auto()
  431. ConvFpropNCxHWx = enum_auto()
  432. ConvFpropNHWC = enum_auto()
  433. ConvDgradNCxHWx = enum_auto()
  434. #
  435. SwizzlingFunctorTag = {
  436. SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
  437. SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
  438. SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
  439. SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
  440. SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle',
  441. SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle',
  442. SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle',
  443. }
  444. ###################################################################################################
  445. class ConvType(enum.Enum):
  446. Convolution = enum_auto()
  447. BatchConvolution = enum_auto()
  448. Local = enum_auto()
  449. LocalShare = enum_auto()
  450. ConvTypeTag = {
  451. ConvType.Convolution: 'cutlass::conv::ConvType::kConvolution',
  452. ConvType.BatchConvolution: 'cutlass::conv::ConvType::kBatchConvolution',
  453. ConvType.Local: 'cutlass::conv::ConvType::kLocal',
  454. ConvType.LocalShare : 'cutlass::conv::ConvType::kLocalShare',
  455. }
  456. #
  457. class ConvKind(enum.Enum):
  458. Fprop = enum_auto()
  459. Dgrad = enum_auto()
  460. Wgrad = enum_auto()
  461. #
  462. ConvKindTag = {
  463. ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
  464. ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
  465. ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
  466. }
  467. ConvKindNames = {
  468. ConvKind.Fprop: 'fprop',
  469. ConvKind.Dgrad: 'dgrad',
  470. ConvKind.Wgrad: 'wgrad',
  471. }
  472. #
  473. class IteratorAlgorithm(enum.Enum):
  474. Analytic = enum_auto()
  475. Optimized = enum_auto()
  476. #
  477. IteratorAlgorithmTag = {
  478. IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
  479. IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
  480. }
  481. IteratorAlgorithmNames = {
  482. IteratorAlgorithm.Analytic: 'analytic',
  483. IteratorAlgorithm.Optimized: 'optimized',
  484. }
  485. #
  486. class StrideSupport(enum.Enum):
  487. Strided = enum_auto()
  488. Unity = enum_auto()
  489. #
  490. StrideSupportTag = {
  491. StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
  492. StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
  493. }
  494. StrideSupportNames = {
  495. StrideSupport.Strided: '',
  496. StrideSupport.Unity: 'unity_stride',
  497. }
  498. class ImplicitGemmMode(enum.Enum):
  499. GemmNt = enum_auto()
  500. GemmTn = enum_auto()
  501. ImplicitGemmModeNames = {
  502. ImplicitGemmMode.GemmNt: 'gemm_nt',
  503. ImplicitGemmMode.GemmTn: 'gemm_tn',
  504. }
  505. ImplicitGemmModeTag = {
  506. ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
  507. ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
  508. }
  509. ###################################################################################################
  510. #
  511. class MathInstruction:
  512. def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add):
  513. self.instruction_shape = instruction_shape
  514. self.element_a = element_a
  515. self.element_b = element_b
  516. self.element_accumulator = element_accumulator
  517. self.opcode_class = opcode_class
  518. self.math_operation = math_operation
  519. #
  520. class TileDescription:
  521. def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute):
  522. self.threadblock_shape = threadblock_shape
  523. self.stages = stages
  524. self.warp_count = warp_count
  525. self.math_instruction = math_instruction
  526. self.minimum_compute_capability = min_compute
  527. self.maximum_compute_capability = max_compute
  528. def procedural_name(self):
  529. return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
  530. #
  531. class TensorDescription:
  532. def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
  533. self.element = element
  534. self.layout = layout
  535. self.alignment = alignment
  536. self.complex_transform = complex_transform
  537. ###################################################################################################

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。