
conv2d_operation.py 24 kB

#
# \file conv2d_operation.py
#
# \brief Generates the CUTLASS Library's instances
#
#
import enum
import os.path
import shutil
from typing import Tuple, List

from library import *

###################################################################################################
#
class Conv2dOperation:
    #
    def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst,
                 element_epilogue, epilogue_functor=EpilogueFunctor.LinearCombination,
                 swizzling_functor=SwizzlingFunctor.Identity4,
                 special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
                 implicit_gemm_mode=ImplicitGemmMode.GemmNT, without_shared_load=False,
                 required_cuda_ver_major=9, required_cuda_ver_minor=2):
        self.operation_kind = OperationKind.Conv2d
        self.conv_kind = conv_kind
        self.arch = arch
        self.tile_description = tile_description
        self.conv_type = conv_type
        self.src = src
        self.flt = flt
        self.bias = bias
        self.dst = dst
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor
        self.special_optimization = special_optimization
        self.implicit_gemm_mode = implicit_gemm_mode
        self.without_shared_load = without_shared_load
        self.required_cuda_ver_major = required_cuda_ver_major
        self.required_cuda_ver_minor = required_cuda_ver_minor

    #
    def accumulator_type(self):
        accum = self.tile_description.math_instruction.element_accumulator
        return accum

    #
    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        intermediate_type = ''

        if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
            if self.tile_description.math_instruction.element_a != self.flt.element and \
                    self.tile_description.math_instruction.element_a != self.accumulator_type():
                intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
        else:
            inst_shape = ''

        special_opt = ''
        if self.special_optimization == SpecialOptimizeDesc.ConvFilterUnity:
            special_opt = '_1x1'
        elif self.special_optimization == SpecialOptimizeDesc.DeconvDoubleUpsampling:
            special_opt = '_s2'

        reorder_k = ''
        if self.without_shared_load:
            reorder_k = '_roc'

        return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()],
                                    inst_shape, intermediate_type,
                                    ConvKindNames[self.conv_kind], special_opt,
                                    reorder_k, ShortEpilogueNames[self.epilogue_functor])

    #
    def extended_name(self):
        if self.dst.element != self.tile_description.math_instruction.element_accumulator:
            if self.src.element != self.flt.element:
                extended_name = "${element_dst}_${core_name}_${element_src}_${element_flt}"
            else:
                extended_name = "${element_dst}_${core_name}_${element_src}"
        else:
            if self.src.element != self.flt.element:
                extended_name = "${core_name}_${element_src}_${element_flt}"
            else:
                extended_name = "${core_name}_${element_src}"

        extended_name = SubstituteTemplate(extended_name, {
            'element_src': DataTypeNames[self.src.element],
            'element_flt': DataTypeNames[self.flt.element],
            'element_dst': DataTypeNames[self.dst.element],
            'core_name': self.core_name()
        })
        return extended_name

    #
    def layout_name(self):
        if self.src.layout == self.dst.layout:
            layout_name = "${src_layout}_${flt_layout}"
        else:
            layout_name = "${src_layout}_${flt_layout}_${dst_layout}"

        layout_name = SubstituteTemplate(layout_name, {
            'src_layout': ShortLayoutTypeNames[self.src.layout],
            'flt_layout': ShortLayoutTypeNames[self.flt.layout],
            'dst_layout': ShortLayoutTypeNames[self.dst.layout],
        })
        return layout_name

    #
    def configuration_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        warp_shape = [int(self.tile_description.threadblock_shape[idx] /
                          self.tile_description.warp_count[idx]) for idx in range(3)]

        threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
            self.tile_description.threadblock_shape[0],
            self.tile_description.threadblock_shape[1],
            self.tile_description.threadblock_shape[2],
            warp_shape[0],
            warp_shape[1],
            warp_shape[2],
            self.tile_description.stages,
        )

        configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"

        return SubstituteTemplate(
            configuration_name,
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
            }
        )

    #
    def procedural_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        return self.configuration_name()
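
# Naming sketch (illustrative, not part of the file): procedural_name() yields a string of the
# form "cutlass_<opcode_class>_<extended_name>_<MxNxK threadblock>_<mxnxk warp>_<stages>_<layouts>".
# The tile/tensor descriptions below are assumed to be constructed elsewhere (e.g. by GenerateConv2d).
#
#   op = Conv2dOperation(ConvKind.Fprop, ConvType.Convolution, 75, tile,
#                        src, flt, bias, dst, DataType.f32)
#   print(op.procedural_name())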

###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

class EmitConv2dInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution =
    typename cutlass::conv::device::Convolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode},
        ${without_shared_load}>;
"""

    def emit(self, operation):
        warp_shape = [int(operation.tile_description.threadblock_shape[idx] /
                          operation.tile_description.warp_count[idx]) for idx in range(3)]

        epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element],
                                         128) / DataTypeSize[operation.dst.element])

        values = {
            'operation_name': operation.procedural_name(),
            'conv_type': ConvTypeTag[operation.conv_type],
            'element_src': DataTypeTag[operation.src.element],
            'layout_src': LayoutTag[operation.src.layout],
            'element_flt': DataTypeTag[operation.flt.element],
            'layout_flt': LayoutTag[operation.flt.layout],
            'element_dst': DataTypeTag[operation.dst.element],
            'layout_dst': LayoutTag[operation.dst.layout],
            'element_bias': DataTypeTag[operation.bias.element],
            'layout_bias': LayoutTag[operation.bias.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(operation.tile_description.stages),
            'alignment_src': str(operation.src.alignment),
            'alignment_filter': str(operation.flt.alignment),
            'special_optimization': SpecialOptimizeDescTag[operation.special_optimization],
            'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode],
            'without_shared_load': str(operation.without_shared_load).lower()
        }

        return SubstituteTemplate(self.template, values)
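
# Usage sketch (illustrative): emit() renders the C++ template above by substituting every
# ${...} placeholder from the operation's description; 'op' is assumed to be a Conv2dOperation
# built elsewhere.
#
#   instance_cxx = EmitConv2dInstance().emit(op)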

class EmitDeconvInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Deconvolution =
    typename cutlass::conv::device::Deconvolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${special_optimization},
        ${math_operator},
        ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [int(operation.tile_description.threadblock_shape[idx] /
                          operation.tile_description.warp_count[idx]) for idx in range(3)]

        epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element],
                                         128) / DataTypeSize[operation.dst.element])

        values = {
            'operation_name': operation.procedural_name(),
            'conv_type': ConvTypeTag[operation.conv_type],
            'element_src': DataTypeTag[operation.src.element],
            'layout_src': LayoutTag[operation.src.layout],
            'element_flt': DataTypeTag[operation.flt.element],
            'layout_flt': LayoutTag[operation.flt.layout],
            'element_dst': DataTypeTag[operation.dst.element],
            'layout_dst': LayoutTag[operation.dst.layout],
            'element_bias': DataTypeTag[operation.bias.element],
            'layout_bias': LayoutTag[operation.bias.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(operation.tile_description.stages),
            'alignment_src': str(operation.src.alignment),
            'alignment_filter': str(operation.flt.alignment),
            'special_optimization': SpecialOptimizeDescTag[operation.special_optimization],
            'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
        }

        return SubstituteTemplate(self.template, values)

###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type,
                   min_cc, src_align=32, flt_align=32, dst_align=32,
                   use_special_optimization=SpecialOptimizeDesc.NoneSpecialOpt,
                   implicit_gemm_mode=ImplicitGemmMode.GemmNT, without_shared_load=False,
                   required_cuda_ver_major=9, required_cuda_ver_minor=2):
    operations = []

    element_epilogue = DataType.f32
    if conv_kind == ConvKind.Fprop:
        if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
            swizzling_functor = SwizzlingFunctor.ConvFpropTrans
        else:
            swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
    else:
        if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
            swizzling_functor = SwizzlingFunctor.ConvDgradTrans
        else:
            swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

    # skip rule
    def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
        return layout == LayoutType.TensorNC32HW32 and tile.threadblock_shape[0] % 32 != 0

    # rule for bias_type and epilogues
    def get_bias_type_and_epilogues(tile: TileDescription,
                                    out_dtype: DataType) -> Tuple[DataType, List[EpilogueFunctor]]:
        if tile.math_instruction.element_accumulator == DataType.s32 and out_dtype != DataType.f32:
            bias_type = DataType.s32
            if tile.math_instruction.element_b == DataType.u4:
                epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp,
                             EpilogueFunctor.BiasAddLinearCombinationReluClamp]
            else:
                epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp,
                             EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                             EpilogueFunctor.BiasAddLinearCombinationHSwishClamp]
        elif tile.math_instruction.element_accumulator == DataType.f32 or out_dtype == DataType.f32:
            bias_type = DataType.f32
            epilogues = [EpilogueFunctor.BiasAddLinearCombination,
                         EpilogueFunctor.BiasAddLinearCombinationRelu,
                         EpilogueFunctor.BiasAddLinearCombinationHSwish]
        return bias_type, epilogues

    # rule for filter alignment
    def get_flt_align(tile: TileDescription) -> int:
        nonlocal flt_align
        if tile.math_instruction.opcode_class == OpcodeClass.Simt \
                and tile.math_instruction.element_accumulator == DataType.s32:
            thread_num = tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
            flt_block = tile.threadblock_shape[0] * tile.threadblock_shape[2] \
                        * DataTypeSize[tile.math_instruction.element_a]
            load_per_thread = flt_block // thread_num
            if load_per_thread >= 128:
                flt_align = 128
            elif load_per_thread >= 64:
                flt_align = 64
            else:
                assert load_per_thread >= 32
                flt_align = 32
        return flt_align

    def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
        nonlocal dst_align
        if tile.math_instruction.opcode_class == OpcodeClass.TensorOp \
                and dst_layout == LayoutType.TensorNC4HW4:
            dst_align = 32
        return dst_align

    def filter_epilogue_with_conv_kind(epilogue: EpilogueFunctor, conv_kind: ConvKind) -> bool:
        return conv_kind == ConvKind.Dgrad and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp

    # loop over all tile descriptions
    for tile in tile_descriptions:
        if filter_tile_with_layout(tile, dst_layout):
            continue

        bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
        flt_align = get_flt_align(tile)
        dst_align = get_dst_align(tile, dst_layout)

        for epilogue in epilogues:
            if filter_epilogue_with_conv_kind(epilogue, conv_kind):
                continue

            if dst_type == DataType.f32:
                bias_type = DataType.f32

            #
            src = TensorDescription(tile.math_instruction.element_b, src_layout,
                                    int(src_align / DataTypeSize[tile.math_instruction.element_b]))
            flt = TensorDescription(tile.math_instruction.element_a, flt_layout,
                                    int(flt_align / DataTypeSize[tile.math_instruction.element_a]))
            bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
            dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))

            new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt,
                                            bias, dst, element_epilogue, epilogue, swizzling_functor,
                                            SpecialOptimizeDesc.NoneSpecialOpt, implicit_gemm_mode,
                                            without_shared_load, required_cuda_ver_major,
                                            required_cuda_ver_minor)
            operations.append(new_operation)
            if use_special_optimization != SpecialOptimizeDesc.NoneSpecialOpt:
                new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt,
                                                bias, dst, element_epilogue, epilogue, swizzling_functor,
                                                use_special_optimization, implicit_gemm_mode,
                                                without_shared_load, required_cuda_ver_major,
                                                required_cuda_ver_minor)
                operations.append(new_operation)
    return operations
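
# Usage sketch (illustrative, not part of the generator): GenerateConv2d yields one
# Conv2dOperation per surviving (tile, epilogue) pair; 'tiles' and the layout/dtype
# arguments are placeholders assumed to be defined by the caller.
#
#   ops = GenerateConv2d(ConvKind.Fprop, tiles, src_layout, flt_layout,
#                        dst_layout, dst_type, min_cc=61)
#   for op in ops:
#       print(op.procedural_name())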

###################################################################################################
#
# Emitter functions for all targets
#
###################################################################################################

class EmitConv2dConfigurationLibrary:
    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

        self.instance_emitter = EmitConv2dInstance()

        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
    public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.header_template = """
/*
    Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.configuration_header = """
namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""

        self.configuration_instance = """
    using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
        ${operation_name}>;

    manifest.append(new cutlass::library::Conv2dOperation<
        Operation_${operation_name}>(
            "${operation_name}"));
"""

        self.configuration_epilogue = """
}
"""

        self.epilogue_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    #
    def __enter__(self):
        self.configuration_file = open(self.configuration_path, "w")

        self.configuration_file.write(SubstituteTemplate(self.header_template, {
            'configuration_name': self.configuration_name
        }))

        self.operations = []
        return self

    #
    def emit(self, operation):
        self.operations.append(operation)
        self.configuration_file.write(SubstituteTemplate(self.instance_template, {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'operation_instance': self.instance_emitter.emit(operation)
        }))

    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
            'configuration_name': self.configuration_name
        }))

        for operation in self.operations:
            self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
                'configuration_name': self.configuration_name,
                'operation_name': operation.procedural_name()
            }))

        self.configuration_file.write(self.configuration_epilogue)
        self.configuration_file.write(self.epilogue_template)
        self.configuration_file.close()
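
# Usage sketch (illustrative): a configuration groups several operation instances into one
# generated .cu file; 'out_dir' and 'ops' are placeholders assumed to come from the caller.
#
#   with EmitConv2dConfigurationLibrary(out_dir, ops[0].configuration_name()) as cfg:
#       for op in ops:
#           cfg.emit(op)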

###################################################################################################
###################################################################################################
#
# Emitters for Conv Kernel Wrapper
#
###################################################################################################

class EmitConvSingleKernelWrapper:
    def __init__(self, kernel_path, operation, short_path=False):
        self.kernel_path = kernel_path
        self.operation = operation
        self.short_path = short_path

        if self.operation.conv_kind == ConvKind.Fprop:
            self.instance_emitter = EmitConv2dInstance()
            self.convolution_name = "Convolution"
        else:
            assert self.operation.conv_kind == ConvKind.Dgrad
            self.instance_emitter = EmitDeconvInstance()
            self.convolution_name = "Deconvolution"

        self.header_template = """
#if __CUDACC_VER_MAJOR__ > ${required_cuda_ver_major} || (__CUDACC_VER_MAJOR__ == ${required_cuda_ver_major} && __CUDACC_VER_MINOR__ >= ${required_cuda_ver_minor})
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

#include "cutlass/convolution/device/convolution.h"

#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""

        self.instance_template = """
${operation_instance}
"""

        self.manifest_template = """
namespace cutlass {
namespace library {

void initialize_${operation_name}(Manifest &manifest) {
    manifest.append(new ConvolutionOperation<${convolution_name}>(
        "${operation_name}"
    ));
}

} // namespace library
} // namespace cutlass
"""

        self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

    #
    def __enter__(self):
        if self.short_path:
            self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
            GlobalCnt.cnt += 1
        else:
            self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
        self.kernel_file = open(self.kernel_path, "w")
        self.kernel_file.write(SubstituteTemplate(self.header_template, {
            'required_cuda_ver_major': str(self.operation.required_cuda_ver_major),
            'required_cuda_ver_minor': str(self.operation.required_cuda_ver_minor),
        }))
        return self

    #
    def emit(self):
        self.kernel_file.write(SubstituteTemplate(self.instance_template, {
            'operation_instance': self.instance_emitter.emit(self.operation),
        }))

        # emit manifest helper
        manifest = SubstituteTemplate(self.manifest_template, {
            'operation_name': self.operation.procedural_name(),
            'convolution_name': self.convolution_name
        })
        self.kernel_file.write(manifest)

    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.kernel_file.write(self.epilogue_template)
        self.kernel_file.close()

###################################################################################################
###################################################################################################
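
# Usage sketch (illustrative): each generated operation is written to its own .cu file through
# the context-manager protocol above; 'kernel_dir' and 'operations' are placeholders.
#
#   for op in operations:
#       with EmitConvSingleKernelWrapper(kernel_dir, op, short_path=True) as emitter:
#           emitter.emit()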
