
conv2d_operation.py

#
# \file conv2d_operation.py
#
# \brief Generates the CUTLASS Library's instances
#
#

import enum
import os.path
import shutil
from typing import Tuple, List

from library import *

###################################################################################################
#
class Conv2dOperation:
  #
  def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \
    epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \
    special_optimization = SpecialOptimizeDesc.NoneSpecialOpt, implicit_gemm_mode = ImplicitGemmMode.GemmNT, \
    without_shared_load = False, required_cuda_ver_major = 9, required_cuda_ver_minor = 2):

    self.operation_kind = OperationKind.Conv2d
    self.conv_kind = conv_kind
    self.arch = arch
    self.tile_description = tile_description
    self.conv_type = conv_type
    self.src = src
    self.flt = flt
    self.bias = bias
    self.dst = dst
    self.element_epilogue = element_epilogue
    self.epilogue_functor = epilogue_functor
    self.swizzling_functor = swizzling_functor
    self.special_optimization = special_optimization
    self.implicit_gemm_mode = implicit_gemm_mode
    self.without_shared_load = without_shared_load
    self.required_cuda_ver_major = required_cuda_ver_major
    self.required_cuda_ver_minor = required_cuda_ver_minor

  #
  def accumulator_type(self):
    return self.tile_description.math_instruction.element_accumulator

  #
  def core_name(self):
    ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''

    intermediate_type = ''

    if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
      inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
      if self.tile_description.math_instruction.element_a != self.flt.element and \
          self.tile_description.math_instruction.element_a != self.accumulator_type():
        intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
    else:
      inst_shape = ''

    special_opt = ''
    if self.special_optimization == SpecialOptimizeDesc.ConvFilterUnity:
      special_opt = '_1x1'
    elif self.special_optimization == SpecialOptimizeDesc.DeconvDoubleUpsampling:
      special_opt = '_s2'

    reorder_k = ''
    if self.without_shared_load:
      reorder_k = '_roc'

    return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
      inst_shape, intermediate_type, ConvKindNames[self.conv_kind], special_opt, \
      reorder_k, ShortEpilogueNames[self.epilogue_functor])

  #
  def extended_name(self):
    if self.dst.element != self.tile_description.math_instruction.element_accumulator:
      if self.src.element != self.flt.element:
        extended_name = "${element_dst}_${core_name}_${element_src}_${element_flt}"
      elif self.src.element == self.flt.element:
        extended_name = "${element_dst}_${core_name}_${element_src}"
    else:
      if self.src.element != self.flt.element:
        extended_name = "${core_name}_${element_src}_${element_flt}"
      elif self.src.element == self.flt.element:
        extended_name = "${core_name}_${element_src}"

    extended_name = SubstituteTemplate(extended_name, {
      'element_src': DataTypeNames[self.src.element],
      'element_flt': DataTypeNames[self.flt.element],
      'element_dst': DataTypeNames[self.dst.element],
      'core_name': self.core_name()
      })

    return extended_name

  #
  def layout_name(self):
    if self.src.layout == self.dst.layout:
      layout_name = "${src_layout}_${flt_layout}"
    else:
      layout_name = "${src_layout}_${flt_layout}_${dst_layout}"

    layout_name = SubstituteTemplate(layout_name, {
      'src_layout': ShortLayoutTypeNames[self.src.layout],
      'flt_layout': ShortLayoutTypeNames[self.flt.layout],
      'dst_layout': ShortLayoutTypeNames[self.dst.layout],
      })

    return layout_name

  #
  def configuration_name(self):
    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''

    opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

    warp_shape = [int(self.tile_description.threadblock_shape[idx] / self.tile_description.warp_count[idx]) for idx in range(3)]

    threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
      self.tile_description.threadblock_shape[0],
      self.tile_description.threadblock_shape[1],
      self.tile_description.threadblock_shape[2],
      warp_shape[0],
      warp_shape[1],
      warp_shape[2],
      self.tile_description.stages,
    )

    configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"

    return SubstituteTemplate(
      configuration_name,
      {
        'opcode_class': opcode_class_name,
        'extended_name': self.extended_name(),
        'threadblock': threadblock,
        'layout': self.layout_name(),
      }
    )

  #
  def procedural_name(self):
    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
    return self.configuration_name()

###################################################################################################
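
# Illustrative sketch, defined but never called by this module: it shows how
# configuration_name() derives the warp tile from the threadblock tile, which
# produces the "128x128x32_64x64x32_2" portion of a procedural name. The
# shapes below are assumptions chosen for the example, not values from a
# real kernel.
def _example_warp_shape():
  threadblock_shape = [128, 128, 32]  # assumed threadblock tile (M, N, K)
  warp_count = [2, 2, 1]              # assumed warp grid within the threadblock
  # Each warp-tile dimension is the threadblock dimension over the warp count.
  warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
  assert warp_shape == [64, 64, 32]
  return warp_shape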
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

class EmitConv2dInstance:
  def __init__(self):
    self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution =
  typename cutlass::conv::device::Convolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${special_optimization},
    ${math_operator},
    ${implicit_gemm_mode},
    ${without_shared_load}>;
"""

  def emit(self, operation):

    warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]

    epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element])

    values = {
      'operation_name': operation.procedural_name(),
      'conv_type': ConvTypeTag[operation.conv_type],
      'element_src': DataTypeTag[operation.src.element],
      'layout_src': LayoutTag[operation.src.layout],
      'element_flt': DataTypeTag[operation.flt.element],
      'layout_flt': LayoutTag[operation.flt.layout],
      'element_dst': DataTypeTag[operation.dst.element],
      'layout_dst': LayoutTag[operation.dst.layout],
      'element_bias': DataTypeTag[operation.bias.element],
      'layout_bias': LayoutTag[operation.bias.layout],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
      'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
      'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
      'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
      'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
      'epilogue_vector_length': str(epilogue_vector_length),
      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
      'stages': str(operation.tile_description.stages),
      'alignment_src': str(operation.src.alignment),
      'alignment_filter': str(operation.flt.alignment),
      'special_optimization': SpecialOptimizeDescTag[operation.special_optimization],
      'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode],
      'without_shared_load': str(operation.without_shared_load).lower()
    }

    return SubstituteTemplate(self.template, values)
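
# Illustrative sketch, defined but never called by the emitters: the epilogue
# vector length computed in emit() caps each epilogue memory access at 128 bits
# and converts that budget back into destination elements. The alignment and
# element width below are assumptions for the example (e.g. an s8 destination).
def _example_epilogue_vector_length():
  dst_alignment = 16    # assumed dst alignment, in elements
  dst_element_bits = 8  # assumed 8-bit destination type
  # min(alignment in bits, 128 bits), expressed in elements of the dst type.
  return min(dst_alignment * dst_element_bits, 128) // dst_element_bits  # == 16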
class EmitDeconvInstance:
  def __init__(self):
    self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Deconvolution =
  typename cutlass::conv::device::Deconvolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${special_optimization},
    ${math_operator},
    ${implicit_gemm_mode}>;
"""

  def emit(self, operation):

    warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]

    epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element])

    values = {
      'operation_name': operation.procedural_name(),
      'conv_type': ConvTypeTag[operation.conv_type],
      'element_src': DataTypeTag[operation.src.element],
      'layout_src': LayoutTag[operation.src.layout],
      'element_flt': DataTypeTag[operation.flt.element],
      'layout_flt': LayoutTag[operation.flt.layout],
      'element_dst': DataTypeTag[operation.dst.element],
      'layout_dst': LayoutTag[operation.dst.layout],
      'element_bias': DataTypeTag[operation.bias.element],
      'layout_bias': LayoutTag[operation.bias.layout],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
      'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
      'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
      'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
      'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
      'epilogue_vector_length': str(epilogue_vector_length),
      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
      'stages': str(operation.tile_description.stages),
      'alignment_src': str(operation.src.alignment),
      'alignment_filter': str(operation.flt.alignment),
      'special_optimization': SpecialOptimizeDescTag[operation.special_optimization],
      'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
    }

    return SubstituteTemplate(self.template, values)
###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 32, \
  use_special_optimization = SpecialOptimizeDesc.NoneSpecialOpt, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False, \
  required_cuda_ver_major = 9, required_cuda_ver_minor = 2):
  operations = []

  element_epilogue = DataType.f32
  if conv_kind == ConvKind.Fprop:
    if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
      swizzling_functor = SwizzlingFunctor.ConvFpropTrans
    else:
      swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
  else:
    swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

  # skip rule
  def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
    return layout == LayoutType.TensorNC32HW32 and \
      tile.threadblock_shape[0] % 32 != 0

  # rule for bias_type and epilogues
  def get_bias_type_and_epilogues(tile: TileDescription, \
                                  out_dtype: DataType) -> Tuple[DataType, List[EpilogueFunctor]]:
    if tile.math_instruction.element_accumulator == DataType.s32 and \
        out_dtype != DataType.f32:
      bias_type = DataType.s32
      if tile.math_instruction.element_b == DataType.u4:
        epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp]
      else:
        epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp, \
                     EpilogueFunctor.BiasAddLinearCombinationHSwishClamp]
    elif tile.math_instruction.element_accumulator == DataType.f32 or \
        out_dtype == DataType.f32:
      bias_type = DataType.f32
      epilogues = [EpilogueFunctor.BiasAddLinearCombination, EpilogueFunctor.BiasAddLinearCombinationRelu, \
                   EpilogueFunctor.BiasAddLinearCombinationHSwish]
    return bias_type, epilogues

  # rule for filter alignment
  def get_flt_align(tile: TileDescription) -> int:
    nonlocal flt_align
    if tile.math_instruction.opcode_class == OpcodeClass.Simt \
        and tile.math_instruction.element_accumulator == DataType.s32:
      thread_num = tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
      flt_block = tile.threadblock_shape[0] * tile.threadblock_shape[2] \
                  * DataTypeSize[tile.math_instruction.element_a]
      load_per_thread = flt_block // thread_num
      if load_per_thread >= 128:
        flt_align = 128
      elif load_per_thread >= 64:
        flt_align = 64
      else:
        assert load_per_thread >= 32
        flt_align = 32
    return flt_align

  # rule for destination alignment
  def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
    nonlocal dst_align
    if tile.math_instruction.opcode_class == OpcodeClass.TensorOp \
        and out_layout == LayoutType.TensorNC4HW4:
      dst_align = 32
    return dst_align

  def filter_epilogue_with_conv_kind(epilogue: EpilogueFunctor, conv_kind: ConvKind) -> bool:
    return conv_kind == ConvKind.Dgrad \
      and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp

  # loop over all tile descriptions
  for tile in tile_descriptions:
    if filter_tile_with_layout(tile, dst_layout):
      continue

    bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
    flt_align = get_flt_align(tile)
    dst_align = get_dst_align(tile, dst_layout)

    for epilogue in epilogues:
      if filter_epilogue_with_conv_kind(epilogue, conv_kind):
        continue

      if dst_type == DataType.f32:
        bias_type = DataType.f32

      #
      src = TensorDescription(tile.math_instruction.element_b, src_layout, int(src_align / DataTypeSize[tile.math_instruction.element_b]))
      flt = TensorDescription(tile.math_instruction.element_a, flt_layout, int(flt_align / DataTypeSize[tile.math_instruction.element_a]))
      bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
      dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))

      new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, SpecialOptimizeDesc.NoneSpecialOpt, implicit_gemm_mode, without_shared_load, required_cuda_ver_major, required_cuda_ver_minor)
      operations.append(new_operation)
      if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt:
        new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, use_special_optimization, implicit_gemm_mode, without_shared_load, required_cuda_ver_major, required_cuda_ver_minor)
        operations.append(new_operation)

  return operations
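
# Hedged usage sketch for GenerateConv2d, shown as comments because the
# tile_descriptions list is assumed to come from an architecture-specific
# generator script; the layout and dtype arguments are illustrative enum
# members, not a configuration taken from this file.
#
#   operations = GenerateConv2d(
#       ConvKind.Fprop,            # forward convolution
#       tile_descriptions,         # assumed: list of TileDescription objects
#       LayoutType.TensorNC4HW4,   # src layout (assumed)
#       LayoutType.TensorC4RSK4,   # filter layout (assumed)
#       LayoutType.TensorNC4HW4,   # dst layout (assumed)
#       DataType.s8,               # dst element type (assumed)
#       61,                        # minimum compute capability, Sm61
#   )
#
# Each surviving tile contributes one Conv2dOperation per epilogue functor,
# plus a specially-optimized variant when use_special_optimization is set.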
###################################################################################################
#
# Emitter functions for all targets
#
###################################################################################################

class EmitConv2dConfigurationLibrary:
  def __init__(self, operation_path, configuration_name):
    self.configuration_name = configuration_name
    self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

    self.instance_emitter = EmitConv2dInstance()

    self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
  public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    self.header_template = """
/*
  Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    self.configuration_header = """
namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""

    self.configuration_instance = """
  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
    ${operation_name}>;

  manifest.append(new cutlass::library::Conv2dOperation<
    Operation_${operation_name}>(
      "${operation_name}"));
"""

    self.configuration_epilogue = """
}
"""

    self.epilogue_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

  #
  def __enter__(self):
    self.configuration_file = open(self.configuration_path, "w")

    self.configuration_file.write(SubstituteTemplate(self.header_template, {
      'configuration_name': self.configuration_name
      }))

    self.operations = []
    return self

  #
  def emit(self, operation):
    self.operations.append(operation)
    self.configuration_file.write(SubstituteTemplate(self.instance_template, {
      'configuration_name': self.configuration_name,
      'operation_name': operation.procedural_name(),
      'operation_instance': self.instance_emitter.emit(operation)
      }))

  #
  def __exit__(self, exception_type, exception_value, traceback):
    self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
      'configuration_name': self.configuration_name
      }))

    for operation in self.operations:
      self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
        'configuration_name': self.configuration_name,
        'operation_name': operation.procedural_name()
        }))

    self.configuration_file.write(self.configuration_epilogue)
    self.configuration_file.write(self.epilogue_template)
    self.configuration_file.close()
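
# Hedged usage sketch, shown as comments: EmitConv2dConfigurationLibrary is a
# context manager that writes one .cu file per configuration. The names
# "operation_path", "name", and "ops" are illustrative placeholders, with ops
# assumed to come from GenerateConv2d above.
#
#   with EmitConv2dConfigurationLibrary(operation_path, name) as emitter:
#       for op in ops:
#           emitter.emit(op)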
###################################################################################################
###################################################################################################

#
# Emitters for Conv Kernel Wrapper
#
###################################################################################################

class EmitConvSingleKernelWrapper:
  def __init__(self, kernel_path, operation, short_path=False):
    self.kernel_path = kernel_path
    self.operation = operation
    self.short_path = short_path

    if self.operation.conv_kind == ConvKind.Fprop:
      self.instance_emitter = EmitConv2dInstance()
      self.convolution_name = "Convolution"
    else:
      assert self.operation.conv_kind == ConvKind.Dgrad
      self.instance_emitter = EmitDeconvInstance()
      self.convolution_name = "Deconvolution"

    self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

#include "cutlass/convolution/device/convolution.h"

#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""

    self.instance_template = """
${operation_instance}
"""

    self.manifest_template = """
namespace cutlass {
namespace library {

void initialize_${operation_name}(Manifest &manifest) {
  manifest.append(new ConvolutionOperation<${convolution_name}>(
    "${operation_name}"
  ));
}

} // namespace library
} // namespace cutlass
"""

    self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

  #
  def __enter__(self):
    if self.short_path:
      self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
      GlobalCnt.cnt += 1
    else:
      self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
    self.kernel_file = open(self.kernel_path, "w")
    self.kernel_file.write(SubstituteTemplate(self.header_template, {
      'required_cuda_ver_major': str(self.operation.required_cuda_ver_major),
      'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor),
      }))
    return self

  #
  def emit(self):
    self.kernel_file.write(SubstituteTemplate(self.instance_template, {
      'operation_instance': self.instance_emitter.emit(self.operation),
      }))

    # emit manifest helper
    manifest = SubstituteTemplate(self.manifest_template, {
      'operation_name': self.operation.procedural_name(),
      'convolution_name': self.convolution_name
      })
    self.kernel_file.write(manifest)

  #
  def __exit__(self, exception_type, exception_value, traceback):
    self.kernel_file.write(self.epilogue_template)
    self.kernel_file.close()

###################################################################################################
###################################################################################################
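
# Hedged usage sketch, shown as comments: the wrapper emits one kernel instance
# plus its manifest initializer into a single .cu file, guarded by the CUDA
# version check in header_template. "kernel_dir" and "op" are illustrative
# placeholders (op: a Conv2dOperation produced by GenerateConv2d).
#
#   with EmitConvSingleKernelWrapper(kernel_dir, op) as wrapper:
#       wrapper.emit()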
