
conv2d_operation.py

#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#
#

import enum
import os.path
import shutil
from typing import Tuple, List

from lazy_file import LazyFile
from library import *

###################################################################################################

#
class Conv2dOperation:
    #
    def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst,
                 element_epilogue, epilogue_functor=EpilogueFunctor.LinearCombination,
                 swizzling_functor=SwizzlingFunctor.Identity4, need_load_from_const=True,
                 implicit_gemm_mode=ImplicitGemmMode.GemmNt):
        self.operation_kind = OperationKind.Conv2d
        self.conv_kind = conv_kind
        self.arch = arch
        self.tile_description = tile_description
        self.conv_type = conv_type
        self.src = src
        self.flt = flt
        self.bias = bias
        self.dst = dst
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor
        self.need_load_from_const = need_load_from_const
        self.implicit_gemm_mode = implicit_gemm_mode

    #
    def accumulator_type(self):
        accum = self.tile_description.math_instruction.element_accumulator
        return accum

    #
    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        intermediate_type = ''

        if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
            inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
            if self.tile_description.math_instruction.element_a != self.flt.element and \
                    self.tile_description.math_instruction.element_a != self.accumulator_type():
                intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
        else:
            inst_shape = ''

        unity_kernel = ''
        if not self.need_load_from_const:
            unity_kernel = '_1x1'

        return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()],
                                  inst_shape, intermediate_type, ConvKindNames[self.conv_kind],
                                  unity_kernel, ShortEpilogueNames[self.epilogue_functor])

    #
    def extended_name(self):
        if self.dst.element != self.tile_description.math_instruction.element_accumulator:
            if self.src.element != self.flt.element:
                extended_name = "${element_dst}_${core_name}_${element_src}_${element_flt}"
            else:
                extended_name = "${element_dst}_${core_name}_${element_src}"
        else:
            if self.src.element != self.flt.element:
                extended_name = "${core_name}_${element_src}_${element_flt}"
            else:
                extended_name = "${core_name}_${element_src}"

        extended_name = SubstituteTemplate(extended_name, {
            'element_src': DataTypeNames[self.src.element],
            'element_flt': DataTypeNames[self.flt.element],
            'element_dst': DataTypeNames[self.dst.element],
            'core_name': self.core_name()
        })
        return extended_name

    #
    def layout_name(self):
        if self.src.layout == self.dst.layout:
            layout_name = "${src_layout}_${flt_layout}"
        else:
            layout_name = "${src_layout}_${flt_layout}_${dst_layout}"

        layout_name = SubstituteTemplate(layout_name, {
            'src_layout': ShortLayoutTypeNames[self.src.layout],
            'flt_layout': ShortLayoutTypeNames[self.flt.layout],
            'dst_layout': ShortLayoutTypeNames[self.dst.layout],
        })
        return layout_name

    #
    def configuration_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        warp_shape = [int(self.tile_description.threadblock_shape[idx] / self.tile_description.warp_count[idx])
                      for idx in range(3)]
        threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
            self.tile_description.threadblock_shape[0],
            self.tile_description.threadblock_shape[1],
            self.tile_description.threadblock_shape[2],
            warp_shape[0],
            warp_shape[1],
            warp_shape[2],
            self.tile_description.stages,
        )

        configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"
        return SubstituteTemplate(
            configuration_name,
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
            }
        )

    #
    def procedural_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        return self.configuration_name()
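
# A minimal, illustrative sketch of how a single Conv2dOperation is named. It
# is never called by the generator. The MathInstruction / TileDescription /
# TensorDescription constructor signatures and the concrete shapes below are
# assumptions based on the upstream CUTLASS generator's library.py; adjust to
# the local library module if they differ.
def _example_conv2d_procedural_name():
    math_inst = MathInstruction(
        [8, 8, 16],                              # assumed int8 TensorOp mma shape
        DataType.s8, DataType.s8, DataType.s32,  # element_a, element_b, accumulator
        OpcodeClass.TensorOp,
        MathOperation.multiply_add_saturate)
    tile = TileDescription([64, 64, 32], 2, [2, 1, 1], math_inst, 75, 75)
    src = TensorDescription(DataType.s8, LayoutType.TensorNC4HW4, 4)
    flt = TensorDescription(DataType.s8, LayoutType.TensorNC4HW4, 4)   # placeholder filter layout
    bias = TensorDescription(DataType.s32, LayoutType.TensorNC4HW4, 1)
    dst = TensorDescription(DataType.s8, LayoutType.TensorNC4HW4, 4)
    op = Conv2dOperation(ConvKind.Fprop, ConvType.Convolution, 75, tile,
                         src, flt, bias, dst, DataType.f32)
    # Yields something like "cutlass_tensorop_i8816fprop_..." depending on the
    # name tables defined in library.py.
    return op.procedural_name()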
###################################################################################################

#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

class EmitConv2dInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution =
    typename cutlass::conv::device::Convolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${conv_type},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${nonuninity_kernel},
        ${math_operator},
        ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx])
                      for idx in range(3)]

        # Cap the epilogue vectorization width at a single 128-bit access.
        epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
                                     / DataTypeSize[operation.dst.element])

        values = {
            'operation_name': operation.procedural_name(),
            'conv_type': ConvTypeTag[operation.conv_type],
            'element_src': DataTypeTag[operation.src.element],
            'layout_src': LayoutTag[operation.src.layout],
            'element_flt': DataTypeTag[operation.flt.element],
            'layout_flt': LayoutTag[operation.flt.layout],
            'element_dst': DataTypeTag[operation.dst.element],
            'layout_dst': LayoutTag[operation.dst.layout],
            'element_bias': DataTypeTag[operation.bias.element],
            'layout_bias': LayoutTag[operation.bias.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(operation.tile_description.stages),
            'alignment_src': str(operation.src.alignment),
            'alignment_filter': str(operation.flt.alignment),
            'nonuninity_kernel': str(operation.need_load_from_const).lower(),
            'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
        }
        return SubstituteTemplate(self.template, values)
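
# For reference: SubstituteTemplate is imported from library.py. A minimal
# sketch of its assumed behavior (replace each ${key} placeholder with the
# corresponding value) is:
#
#     def SubstituteTemplate(template, values):
#         text = template
#         for key, value in values.items():
#             text = text.replace('${%s}' % key, value)
#         return text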

# Same pattern as EmitConv2dInstance, but emits a
# cutlass::conv::device::Deconvolution (note: no ${conv_type} parameter).
class EmitDeconvInstance:
    def __init__(self):
        self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Deconvolution =
    typename cutlass::conv::device::Deconvolution<
        ${element_src},
        ${layout_src},
        ${element_flt},
        ${layout_flt},
        ${element_dst},
        ${layout_dst},
        ${element_bias},
        ${layout_bias},
        ${element_accumulator},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor}<
            ${element_dst},
            ${epilogue_vector_length},
            ${element_accumulator},
            ${element_bias},
            ${element_epilogue}
        >,
        ${swizzling_functor},
        ${stages},
        ${alignment_src},
        ${alignment_filter},
        ${nonuninity_kernel},
        ${math_operator},
        ${implicit_gemm_mode}>;
"""

    def emit(self, operation):
        warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx])
                      for idx in range(3)]

        # Cap the epilogue vectorization width at a single 128-bit access.
        epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128)
                                     / DataTypeSize[operation.dst.element])

        values = {
            'operation_name': operation.procedural_name(),
            'element_src': DataTypeTag[operation.src.element],
            'layout_src': LayoutTag[operation.src.layout],
            'element_flt': DataTypeTag[operation.flt.element],
            'layout_flt': LayoutTag[operation.flt.layout],
            'element_dst': DataTypeTag[operation.dst.element],
            'layout_dst': LayoutTag[operation.dst.layout],
            'element_bias': DataTypeTag[operation.bias.element],
            'layout_bias': LayoutTag[operation.bias.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
            'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
            'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(operation.tile_description.stages),
            'alignment_src': str(operation.src.alignment),
            'alignment_filter': str(operation.flt.alignment),
            'nonuninity_kernel': str(operation.need_load_from_const).lower(),
            'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
            'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
        }
        return SubstituteTemplate(self.template, values)
###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type,
                   min_cc, src_align=32, flt_align=32, dst_align=128,
                   skip_unity_kernel=False, implicit_gemm_mode=ImplicitGemmMode.GemmNt):
    operations = []

    element_epilogue = DataType.f32
    if conv_kind == ConvKind.Fprop:
        if src_layout == LayoutType.TensorNHWC:
            swizzling_functor = SwizzlingFunctor.ConvFpropNHWC
        else:
            swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
    else:
        swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

    # skip rule
    def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
        return layout == LayoutType.TensorNC32HW32 and \
            tile.threadblock_shape[0] % 32 != 0

    # rule for bias_type and epilogues
    def get_bias_type_and_epilogues(tile: TileDescription,
                                    out_dtype: DataType) -> Tuple[DataType, List[EpilogueFunctor]]:
        if tile.math_instruction.element_accumulator == DataType.s32 and \
                out_dtype != DataType.f32:
            bias_type = DataType.s32
            if tile.math_instruction.element_b == DataType.u4:
                epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp,
                             EpilogueFunctor.BiasAddLinearCombinationReluClamp]
            else:
                epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp,
                             EpilogueFunctor.BiasAddLinearCombinationReluClamp,
                             EpilogueFunctor.BiasAddLinearCombinationHSwishClamp]
        elif tile.math_instruction.element_accumulator == DataType.f32 or \
                out_dtype == DataType.f32:
            bias_type = DataType.f32
            epilogues = [EpilogueFunctor.BiasAddLinearCombination,
                         EpilogueFunctor.BiasAddLinearCombinationRelu,
                         EpilogueFunctor.BiasAddLinearCombinationHSwish]
        return bias_type, epilogues

    # rule for filter alignment
    def get_flt_align(tile: TileDescription) -> int:
        nonlocal flt_align
        if tile.math_instruction.opcode_class == OpcodeClass.Simt \
                and tile.math_instruction.element_accumulator == DataType.s32:
            thread_num = tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
            flt_block = tile.threadblock_shape[0] * tile.threadblock_shape[2] \
                * DataTypeSize[tile.math_instruction.element_a]
            load_per_thread = flt_block // thread_num
            if load_per_thread >= 128:
                flt_align = 128
            elif load_per_thread >= 64:
                flt_align = 64
            else:
                assert load_per_thread >= 32
                flt_align = 32
        return flt_align

    # rule for dst alignment
    def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
        nonlocal dst_align
        if tile.math_instruction.opcode_class == OpcodeClass.TensorOp \
                and out_layout == LayoutType.TensorNC4HW4:
            dst_align = 32
        return dst_align

    def filter_epilogue_with_conv_kind(epilogue: EpilogueFunctor, conv_kind: ConvKind) -> bool:
        return conv_kind == ConvKind.Dgrad \
            and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp

    # loop over all tile descriptions
    for tile in tile_descriptions:
        if filter_tile_with_layout(tile, dst_layout):
            continue

        bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
        flt_align = get_flt_align(tile)
        dst_align = get_dst_align(tile, dst_layout)

        for epilogue in epilogues:
            if filter_epilogue_with_conv_kind(epilogue, conv_kind):
                continue

            if dst_type == DataType.f32:
                bias_type = DataType.f32
            #
            src = TensorDescription(tile.math_instruction.element_b, src_layout,
                                    int(src_align / DataTypeSize[tile.math_instruction.element_b]))
            flt = TensorDescription(tile.math_instruction.element_a, flt_layout,
                                    int(flt_align / DataTypeSize[tile.math_instruction.element_a]))
            bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
            dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))

            new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt,
                                            bias, dst, element_epilogue, epilogue, swizzling_functor,
                                            True, implicit_gemm_mode)
            operations.append(new_operation)
            if not skip_unity_kernel:
                new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src,
                                                flt, bias, dst, element_epilogue, epilogue,
                                                swizzling_functor, False, implicit_gemm_mode)
                operations.append(new_operation)
    return operations
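
# Illustrative use only: a hypothetical caller generating int8 fprop kernels
# for SM75 (the real tile lists are built elsewhere, e.g. in generator.py).
# The layouts and the TileDescription arguments are assumptions:
#
#     tiles = [TileDescription([64, 64, 32], 2, [2, 1, 1], math_inst, 75, 75)]
#     ops = GenerateConv2d(ConvKind.Fprop, tiles,
#                          LayoutType.TensorNC4HW4,   # src layout
#                          LayoutType.TensorNC4HW4,   # flt layout (placeholder)
#                          LayoutType.TensorNC4HW4,   # dst layout
#                          DataType.s8, 75)
#
# Each (tile, epilogue) pair yields one operation plus a '_1x1' unity-kernel
# variant unless skip_unity_kernel=True.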
###################################################################################################
#
# Emitters functions for all targets
#
###################################################################################################

class EmitConv2dConfigurationLibrary:
    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

        self.instance_emitter = EmitConv2dInstance()

        self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
    public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.header_template = """
/*
    Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

        self.configuration_header = """
namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""

        self.configuration_instance = """
    using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
        ${operation_name}>;

    manifest.append(new cutlass::library::Conv2dOperation<
        Operation_${operation_name}>(
            "${operation_name}"));
"""

        self.configuration_epilogue = """
}
"""

        self.epilogue_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    #
    def __enter__(self):
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(SubstituteTemplate(self.header_template, {
            'configuration_name': self.configuration_name
        }))
        self.operations = []
        return self

    #
    def emit(self, operation):
        self.operations.append(operation)
        self.configuration_file.write(SubstituteTemplate(self.instance_template, {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'operation_instance': self.instance_emitter.emit(operation)
        }))

    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
            'configuration_name': self.configuration_name
        }))

        for operation in self.operations:
            self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
                'configuration_name': self.configuration_name,
                'operation_name': operation.procedural_name()
            }))

        self.configuration_file.write(self.configuration_epilogue)
        self.configuration_file.write(self.epilogue_template)
        self.configuration_file.close()
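
# Illustrative use only (the output directory is hypothetical): the emitter is
# a context manager that writes one .cu file registering every emitted
# operation with the manifest:
#
#     with EmitConv2dConfigurationLibrary('generated', ops[0].configuration_name()) as emitter:
#         for op in ops:
#             emitter.emit(op)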
###################################################################################################
###################################################################################################
#
# Emitters for Conv Kernel Wrapper
#
###################################################################################################

class EmitConvSingleKernelWrapper:
    def __init__(self, kernel_path, operation, wrapper_path):
        self.kernel_path = kernel_path
        self.wrapper_path = wrapper_path
        self.operation = operation

        self.conv_wrappers = {
            ConvKind.Fprop: """
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
    const typename Convolution::ElementSrc* d_src,
    const typename Convolution::ElementFilter* d_filter,
    const typename Convolution::ElementBias* d_bias,
    const typename Convolution::ElementDst* d_z,
    typename Convolution::ElementDst* d_dst,
    int* workspace,
    typename Convolution::ConvolutionParameter const& conv_param,
    typename Convolution::EpilogueOutputOp::Params const& epilogue,
    cudaStream_t stream,
    typename Convolution::ExtraParam extra_param);
""",
            ConvKind.Dgrad: """
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
    const typename Deconvolution::ElementSrc* d_src,
    const typename Deconvolution::ElementFilter* d_filter,
    const typename Deconvolution::ElementBias* d_bias,
    const typename Deconvolution::ElementDst* d_z,
    typename Deconvolution::ElementDst* d_dst,
    int* workspace,
    typename Deconvolution::ConvolutionParameter const& conv_param,
    typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
    cudaStream_t stream);
""",
        }

        if self.operation.conv_kind == ConvKind.Fprop:
            self.instance_emitter = EmitConv2dInstance()
        else:
            assert self.operation.conv_kind == ConvKind.Dgrad
            self.instance_emitter = EmitDeconvInstance()

        self.header_template = """
#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "${wrapper_path}"
"""
        self.instance_template = """
${operation_instance}
"""
        self.wrapper_template = """
${wrapper_instance}
"""
        self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

    #
    def __enter__(self):
        self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
        self.kernel_file = LazyFile(self.kernel_path)
        self.kernel_file.write(SubstituteTemplate(self.header_template, {
            'wrapper_path': self.wrapper_path,
        }))
        return self

    #
    def emit(self):
        self.kernel_file.write(SubstituteTemplate(self.instance_template, {
            'operation_instance': self.instance_emitter.emit(self.operation),
        }))
        # emit wrapper
        wrapper = SubstituteTemplate(self.wrapper_template, {
            'wrapper_instance': self.conv_wrappers[self.operation.conv_kind],
        })
        self.kernel_file.write(wrapper)

    #
    def __exit__(self, exception_type, exception_value, traceback):
        self.kernel_file.write(self.epilogue_template)
        self.kernel_file.close()
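
# Illustrative use only: one generated .cu file per operation; the output
# directory and the wrapper .cuinl path are hypothetical:
#
#     for op in ops:
#         with EmitConvSingleKernelWrapper('kernel_dir', op, 'wrapper.cuinl') as emitter:
#             emitter.emit()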
###################################################################################################
###################################################################################################
