
conv2d_operation.py

#
# \file conv2d_operation.py
#
# \brief Generates the CUTLASS Library's instances
#
#
import enum
import os.path
import shutil
from typing import Tuple, List

from library import *
###################################################################################################

#
class Conv2dOperation:
  #
  def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \
    epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \
    need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):

    self.operation_kind = OperationKind.Conv2d
    self.conv_kind = conv_kind
    self.arch = arch
    self.tile_description = tile_description
    self.conv_type = conv_type
    self.src = src
    self.flt = flt
    self.bias = bias
    self.dst = dst
    self.element_epilogue = element_epilogue
    self.epilogue_functor = epilogue_functor
    self.swizzling_functor = swizzling_functor
    self.need_load_from_const = need_load_from_const
    self.implicit_gemm_mode = implicit_gemm_mode
    self.without_shared_load = without_shared_load

  #
  def accumulator_type(self):
    accum = self.tile_description.math_instruction.element_accumulator
    return accum

  #
  def core_name(self):
    ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
    intermediate_type = ''

    if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
      inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
      if self.tile_description.math_instruction.element_a != self.flt.element and \
          self.tile_description.math_instruction.element_a != self.accumulator_type():
        intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
    else:
      inst_shape = ''

    unity_kernel = ''
    if not self.need_load_from_const:
      unity_kernel = '_1x1'

    reorder_k = ''
    if self.without_shared_load:
      reorder_k = '_roc'

    return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
      inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \
      reorder_k, ShortEpilogueNames[self.epilogue_functor])

  #
  def extended_name(self):
    if self.dst.element != self.tile_description.math_instruction.element_accumulator:
      if self.src.element != self.flt.element:
        extended_name = "${element_dst}_${core_name}_${element_src}_${element_flt}"
      else:
        extended_name = "${element_dst}_${core_name}_${element_src}"
    else:
      if self.src.element != self.flt.element:
        extended_name = "${core_name}_${element_src}_${element_flt}"
      else:
        extended_name = "${core_name}_${element_src}"

    extended_name = SubstituteTemplate(extended_name, {
      'element_src': DataTypeNames[self.src.element],
      'element_flt': DataTypeNames[self.flt.element],
      'element_dst': DataTypeNames[self.dst.element],
      'core_name': self.core_name()
    })

    return extended_name

  #
  def layout_name(self):
    if self.src.layout == self.dst.layout:
      layout_name = "${src_layout}_${flt_layout}"
    else:
      layout_name = "${src_layout}_${flt_layout}_${dst_layout}"

    layout_name = SubstituteTemplate(layout_name, {
      'src_layout': ShortLayoutTypeNames[self.src.layout],
      'flt_layout': ShortLayoutTypeNames[self.flt.layout],
      'dst_layout': ShortLayoutTypeNames[self.dst.layout],
    })

    return layout_name

  #
  def configuration_name(self):
    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
    opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

    warp_shape = [int(self.tile_description.threadblock_shape[idx] / self.tile_description.warp_count[idx]) for idx in range(3)]

    threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
      self.tile_description.threadblock_shape[0],
      self.tile_description.threadblock_shape[1],
      self.tile_description.threadblock_shape[2],
      warp_shape[0],
      warp_shape[1],
      warp_shape[2],
      self.tile_description.stages,
    )

    configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"

    return SubstituteTemplate(
      configuration_name,
      {
        'opcode_class': opcode_class_name,
        'extended_name': self.extended_name(),
        'threadblock': threadblock,
        'layout': self.layout_name(),
      }
    )

  #
  def procedural_name(self):
    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
    return self.configuration_name()
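
# Illustrative sketch (editor's addition, not part of the original generator):
# how a Conv2dOperation is assembled and named. The layout/dtype enum values
# below (LayoutType.TensorC4RSK4, DataType.s8) are assumptions about library.py
# based on the values this file already uses; the resulting name depends
# entirely on the tile description passed in. Assumes an int8 tensor-op tile.
def _example_procedural_name(tile):
  src = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNC4HW4,
                          int(32 / DataTypeSize[tile.math_instruction.element_b]))
  flt = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorC4RSK4,
                          int(32 / DataTypeSize[tile.math_instruction.element_a]))
  bias = TensorDescription(DataType.s32, LayoutType.TensorNC4HW4, 1)
  dst = TensorDescription(DataType.s8, LayoutType.TensorNC4HW4, 4)
  op = Conv2dOperation(ConvKind.Fprop, ConvType.Convolution, 61, tile,
                       src, flt, bias, dst, DataType.f32)
  # Name roughly follows cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}
  return op.procedural_name()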
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

class EmitConv2dInstance:
  def __init__(self):
    self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution =
  typename cutlass::conv::device::Convolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${nonunity_kernel},
    ${math_operator},
    ${implicit_gemm_mode},
    ${without_shared_load}>;
"""

  def emit(self, operation):

    warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]

    epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element])

    values = {
      'operation_name': operation.procedural_name(),
      'conv_type': ConvTypeTag[operation.conv_type],
      'element_src': DataTypeTag[operation.src.element],
      'layout_src': LayoutTag[operation.src.layout],
      'element_flt': DataTypeTag[operation.flt.element],
      'layout_flt': LayoutTag[operation.flt.layout],
      'element_dst': DataTypeTag[operation.dst.element],
      'layout_dst': LayoutTag[operation.dst.layout],
      'element_bias': DataTypeTag[operation.bias.element],
      'layout_bias': LayoutTag[operation.bias.layout],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
      'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
      'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
      'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
      'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
      'epilogue_vector_length': str(epilogue_vector_length),
      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
      'stages': str(operation.tile_description.stages),
      'alignment_src': str(operation.src.alignment),
      'alignment_filter': str(operation.flt.alignment),
      'nonunity_kernel': str(operation.need_load_from_const).lower(),
      'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode],
      'without_shared_load': str(operation.without_shared_load).lower()
    }

    return SubstituteTemplate(self.template, values)
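
# Worked note (editor's addition, not part of the original generator): the
# epilogue_vector_length computed in emit() caps the epilogue store vector at
# 128 bits, expressed back in elements. For example, an s8 destination with
# alignment 16 elements gives min(16 * 8, 128) / 8 = 16 elements; an f32
# destination with alignment 4 gives min(4 * 32, 128) / 32 = 4 elements.
def _example_epilogue_vector_length(dst_alignment_elems, dst_element_bits):
  # Mirrors the expression in EmitConv2dInstance.emit() / EmitDeconvInstance.emit().
  return int(min(dst_alignment_elems * dst_element_bits, 128) / dst_element_bits)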
class EmitDeconvInstance:
  def __init__(self):
    self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Deconvolution =
  typename cutlass::conv::device::Deconvolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${nonunity_kernel},
    ${math_operator},
    ${implicit_gemm_mode}>;
"""

  def emit(self, operation):

    warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]

    epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element])

    values = {
      'operation_name': operation.procedural_name(),
      'conv_type': ConvTypeTag[operation.conv_type],
      'element_src': DataTypeTag[operation.src.element],
      'layout_src': LayoutTag[operation.src.layout],
      'element_flt': DataTypeTag[operation.flt.element],
      'layout_flt': LayoutTag[operation.flt.layout],
      'element_dst': DataTypeTag[operation.dst.element],
      'layout_dst': LayoutTag[operation.dst.layout],
      'element_bias': DataTypeTag[operation.bias.element],
      'layout_bias': LayoutTag[operation.bias.layout],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
      'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
      'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
      'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
      'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
      'epilogue_vector_length': str(epilogue_vector_length),
      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
      'stages': str(operation.tile_description.stages),
      'alignment_src': str(operation.src.alignment),
      'alignment_filter': str(operation.flt.alignment),
      'nonunity_kernel': str(operation.need_load_from_const).lower(),
      'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
    }

    return SubstituteTemplate(self.template, values)
###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \
  skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
  operations = []

  element_epilogue = DataType.f32
  if conv_kind == ConvKind.Fprop:
    if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
      swizzling_functor = SwizzlingFunctor.ConvFpropTrans
    else:
      swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
  else:
    swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

  # skip rule
  def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
    return layout == LayoutType.TensorNC32HW32 and \
      tile.threadblock_shape[0] % 32 != 0

  # rule for bias_type and epilogues
  def get_bias_type_and_epilogues(tile: TileDescription, \
                                  out_dtype: DataType) -> Tuple[DataType, List[EpilogueFunctor]]:
    if tile.math_instruction.element_accumulator == DataType.s32 and \
        out_dtype != DataType.f32:
      bias_type = DataType.s32
      if tile.math_instruction.element_b == DataType.u4:
        epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp]
      else:
        epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp, \
                     EpilogueFunctor.BiasAddLinearCombinationHSwishClamp]
    elif tile.math_instruction.element_accumulator == DataType.f32 or \
        out_dtype == DataType.f32:
      bias_type = DataType.f32
      epilogues = [EpilogueFunctor.BiasAddLinearCombination, EpilogueFunctor.BiasAddLinearCombinationRelu, \
                   EpilogueFunctor.BiasAddLinearCombinationHSwish]
    return bias_type, epilogues

  # rule for filter alignment
  def get_flt_align(tile: TileDescription) -> int:
    nonlocal flt_align
    if tile.math_instruction.opcode_class == OpcodeClass.Simt \
        and tile.math_instruction.element_accumulator == DataType.s32:
      thread_num = tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
      flt_block = tile.threadblock_shape[0] * tile.threadblock_shape[2] \
        * DataTypeSize[tile.math_instruction.element_a]
      load_per_thread = flt_block // thread_num
      if load_per_thread >= 128:
        flt_align = 128
      elif load_per_thread >= 64:
        flt_align = 64
      else:
        assert load_per_thread >= 32
        flt_align = 32
    return flt_align

  def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
    nonlocal dst_align
    if tile.math_instruction.opcode_class == OpcodeClass.TensorOp \
        and dst_layout == LayoutType.TensorNC4HW4:
      dst_align = 32
    return dst_align

  def filter_epilogue_with_conv_kind(epilogue: EpilogueFunctor, conv_kind: ConvKind) -> bool:
    return conv_kind == ConvKind.Dgrad \
      and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp

  # loop over all tile descriptions
  for tile in tile_descriptions:
    if filter_tile_with_layout(tile, dst_layout):
      continue

    bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
    flt_align = get_flt_align(tile)
    dst_align = get_dst_align(tile, dst_layout)

    for epilogue in epilogues:
      if filter_epilogue_with_conv_kind(epilogue, conv_kind):
        continue

      if dst_type == DataType.f32:
        bias_type = DataType.f32

      #
      src = TensorDescription(tile.math_instruction.element_b, src_layout, int(src_align / DataTypeSize[tile.math_instruction.element_b]))
      flt = TensorDescription(tile.math_instruction.element_a, flt_layout, int(flt_align / DataTypeSize[tile.math_instruction.element_a]))
      bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
      dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))

      new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load)
      operations.append(new_operation)
      if not skip_unity_kernel:
        new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load)
        operations.append(new_operation)

  return operations
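
# Usage sketch (editor's addition, not part of the original generator): produce
# all int8 fprop operations for one layout combination. The layout and dtype
# values are assumptions drawn from what this file already references; for each
# (tile, epilogue) pair GenerateConv2d appends one operation with
# need_load_from_const=True and, unless skip_unity_kernel is set, a second
# '_1x1' variant with need_load_from_const=False.
def _example_generate(tile_descriptions):
  return GenerateConv2d(ConvKind.Fprop, tile_descriptions,
                        LayoutType.TensorNC4HW4, LayoutType.TensorC4RSK4,
                        LayoutType.TensorNC4HW4, DataType.s8, min_cc=61)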
###################################################################################################
#
# Emitters functions for all targets
#
###################################################################################################

class EmitConv2dConfigurationLibrary:
  def __init__(self, operation_path, configuration_name):
    self.configuration_name = configuration_name
    self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

    self.instance_emitter = EmitConv2dInstance()

    self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
  public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    self.header_template = """
/*
  Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    self.configuration_header = """
namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""

    self.configuration_instance = """
  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
    ${operation_name}>;

  manifest.append(new cutlass::library::Conv2dOperation<
    Operation_${operation_name}>(
      "${operation_name}"));
"""

    self.configuration_epilogue = """
}
"""

    self.epilogue_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

  #
  def __enter__(self):
    self.configuration_file = open(self.configuration_path, "w")

    self.configuration_file.write(SubstituteTemplate(self.header_template, {
      'configuration_name': self.configuration_name
    }))

    self.operations = []
    return self

  #
  def emit(self, operation):
    self.operations.append(operation)
    self.configuration_file.write(SubstituteTemplate(self.instance_template, {
      'configuration_name': self.configuration_name,
      'operation_name': operation.procedural_name(),
      'operation_instance': self.instance_emitter.emit(operation)
    }))

  #
  def __exit__(self, exception_type, exception_value, traceback):
    self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
      'configuration_name': self.configuration_name
    }))

    for operation in self.operations:
      self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
        'configuration_name': self.configuration_name,
        'operation_name': operation.procedural_name()
      }))

    self.configuration_file.write(self.configuration_epilogue)
    self.configuration_file.write(self.epilogue_template)
    self.configuration_file.close()
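
# Usage sketch (editor's addition, not part of the original generator): the
# class is a context manager. __enter__ opens <operation_path>/<configuration_name>.cu
# and writes the header, each emit() appends one kernel instance, and __exit__
# writes the initialize_<configuration_name>() registration body before closing.
def _example_emit_configuration(operation_path, configuration_name, operations):
  with EmitConv2dConfigurationLibrary(operation_path, configuration_name) as emitter:
    for operation in operations:
      emitter.emit(operation)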
###################################################################################################

###################################################################################################
#
# Emitters for Conv Kernel Wrapper
#
###################################################################################################

class EmitConvSingleKernelWrapper:
  def __init__(self, kernel_path, operation, short_path=False):
    self.kernel_path = kernel_path
    self.operation = operation
    self.short_path = short_path

    if self.operation.conv_kind == ConvKind.Fprop:
      self.instance_emitter = EmitConv2dInstance()
      self.convolution_name = "Convolution"
    else:
      assert self.operation.conv_kind == ConvKind.Dgrad
      self.instance_emitter = EmitDeconvInstance()
      self.convolution_name = "Deconvolution"

    self.header_template = """
#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"

#include "cutlass/convolution/device/convolution.h"

#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/convolution_operation.h"
"""

    self.instance_template = """
${operation_instance}
"""

    self.manifest_template = """
namespace cutlass {
namespace library {

void initialize_${operation_name}(Manifest &manifest) {
  manifest.append(new ConvolutionOperation<${convolution_name}>(
    "${operation_name}"
  ));
}

} // namespace library
} // namespace cutlass
"""

    self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

  #
  def __enter__(self):
    if self.short_path:
      self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % GlobalCnt.cnt)
      GlobalCnt.cnt += 1
    else:
      self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
    self.kernel_file = open(self.kernel_path, "w")
    self.kernel_file.write(self.header_template)
    return self

  #
  def emit(self):
    self.kernel_file.write(SubstituteTemplate(self.instance_template, {
      'operation_instance': self.instance_emitter.emit(self.operation),
    }))

    # emit manifest helper
    manifest = SubstituteTemplate(self.manifest_template, {
      'operation_name': self.operation.procedural_name(),
      'convolution_name': self.convolution_name
    })
    self.kernel_file.write(manifest)

  #
  def __exit__(self, exception_type, exception_value, traceback):
    self.kernel_file.write(self.epilogue_template)
    self.kernel_file.close()
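
# Usage sketch (editor's addition, not part of the original generator): one
# wrapper .cu file is written per operation; with short_path=True the files are
# numbered via GlobalCnt (imported from library.py) instead of using the long
# procedural name.
def _example_emit_kernel_wrappers(kernel_path, operations, short_path=False):
  for operation in operations:
    with EmitConvSingleKernelWrapper(kernel_path, operation, short_path) as emitter:
      emitter.emit()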
###################################################################################################
###################################################################################################
