
conv2d_operation.py

#
# \file conv2d_operation.py
#
# \brief Generates the CUTLASS Library's Conv2d instances
#
#

import enum
import os.path
import shutil
from typing import Tuple, List

from lazy_file import LazyFile
from library import *

###################################################################################################
#
class Conv2dOperation:
  #
  def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \
               epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \
               need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):

    self.operation_kind = OperationKind.Conv2d
    self.conv_kind = conv_kind
    self.arch = arch
    self.tile_description = tile_description
    self.conv_type = conv_type
    self.src = src
    self.flt = flt
    self.bias = bias
    self.dst = dst
    self.element_epilogue = element_epilogue
    self.epilogue_functor = epilogue_functor
    self.swizzling_functor = swizzling_functor
    self.need_load_from_const = need_load_from_const
    self.implicit_gemm_mode = implicit_gemm_mode
    self.without_shared_load = without_shared_load

  #
  def accumulator_type(self):
    accum = self.tile_description.math_instruction.element_accumulator
    return accum

  #
  def core_name(self):
    ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''

    intermediate_type = ''

    if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
      inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
      if self.tile_description.math_instruction.element_a != self.flt.element and \
          self.tile_description.math_instruction.element_a != self.accumulator_type():
        intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
    else:
      inst_shape = ''

    unity_kernel = ''
    if not self.need_load_from_const:
      unity_kernel = '_1x1'

    reorder_k = ''
    if self.without_shared_load:
      reorder_k = '_roc'

    return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
      inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \
      reorder_k, ShortEpilogueNames[self.epilogue_functor])

  #
  def extended_name(self):
    if self.dst.element != self.tile_description.math_instruction.element_accumulator:
      if self.src.element != self.flt.element:
        extended_name = "${element_dst}_${core_name}_${element_src}_${element_flt}"
      elif self.src.element == self.flt.element:
        extended_name = "${element_dst}_${core_name}_${element_src}"
    else:
      if self.src.element != self.flt.element:
        extended_name = "${core_name}_${element_src}_${element_flt}"
      elif self.src.element == self.flt.element:
        extended_name = "${core_name}_${element_src}"

    extended_name = SubstituteTemplate(extended_name, {
      'element_src': DataTypeNames[self.src.element],
      'element_flt': DataTypeNames[self.flt.element],
      'element_dst': DataTypeNames[self.dst.element],
      'core_name': self.core_name()
    })

    return extended_name

  #
  def layout_name(self):
    if self.src.layout == self.dst.layout:
      layout_name = "${src_layout}_${flt_layout}"
    else:
      layout_name = "${src_layout}_${flt_layout}_${dst_layout}"

    layout_name = SubstituteTemplate(layout_name, {
      'src_layout': ShortLayoutTypeNames[self.src.layout],
      'flt_layout': ShortLayoutTypeNames[self.flt.layout],
      'dst_layout': ShortLayoutTypeNames[self.dst.layout],
    })

    return layout_name

  #
  def configuration_name(self):
    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''

    opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

    warp_shape = [int(self.tile_description.threadblock_shape[idx] / self.tile_description.warp_count[idx]) for idx in range(3)]

    threadblock = "%dx%dx%d_%dx%dx%d_%d" % (
      self.tile_description.threadblock_shape[0],
      self.tile_description.threadblock_shape[1],
      self.tile_description.threadblock_shape[2],
      warp_shape[0],
      warp_shape[1],
      warp_shape[2],
      self.tile_description.stages,
    )

    configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}"

    return SubstituteTemplate(
      configuration_name,
      {
        'opcode_class': opcode_class_name,
        'extended_name': self.extended_name(),
        'threadblock': threadblock,
        'layout': self.layout_name(),
      }
    )

  #
  def procedural_name(self):
    ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
    return self.configuration_name()

###################################################################################################
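
# Usage sketch (illustrative, not part of the original file): given a
# TileDescription `tile` and TensorDescription objects `src`, `flt`, `bias`,
# `dst` built from the enums in library.py, the naming helpers compose as
# follows (the exact strings depend on the name tables in library.py):
#
#   op = Conv2dOperation(ConvKind.Fprop, ConvType.Convolution, 75, tile,
#                        src, flt, bias, dst, DataType.f32)
#   op.core_name()        # accumulator prefix + instruction shape + "fprop" + epilogue suffix
#   op.extended_name()    # core name decorated with src/flt/dst element types
#   op.procedural_name()  # "cutlass_<opcode_class>_<extended_name>_<tile>_<layout>"
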
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

class EmitConv2dInstance:
  def __init__(self):
    self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Convolution =
  typename cutlass::conv::device::Convolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${conv_type},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${nonunity_kernel},
    ${math_operator},
    ${implicit_gemm_mode},
    ${without_shared_load}>;
"""

  def emit(self, operation):
    warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]

    epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element])

    values = {
      'operation_name': operation.procedural_name(),
      'conv_type': ConvTypeTag[operation.conv_type],
      'element_src': DataTypeTag[operation.src.element],
      'layout_src': LayoutTag[operation.src.layout],
      'element_flt': DataTypeTag[operation.flt.element],
      'layout_flt': LayoutTag[operation.flt.layout],
      'element_dst': DataTypeTag[operation.dst.element],
      'layout_dst': LayoutTag[operation.dst.layout],
      'element_bias': DataTypeTag[operation.bias.element],
      'layout_bias': LayoutTag[operation.bias.layout],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
      'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
      'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
      'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
      'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
      'epilogue_vector_length': str(epilogue_vector_length),
      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
      'stages': str(operation.tile_description.stages),
      'alignment_src': str(operation.src.alignment),
      'alignment_filter': str(operation.flt.alignment),
      'nonunity_kernel': str(operation.need_load_from_const).lower(),
      'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode],
      'without_shared_load': str(operation.without_shared_load).lower()
    }

    return SubstituteTemplate(self.template, values)
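
# Usage sketch (illustrative): emit() fills the template above and returns the
# C++ kernel instantiation as a string.
#
#   emitter = EmitConv2dInstance()
#   cxx_source = emitter.emit(op)   # `op` is a Conv2dOperation as sketched earlier
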
class EmitDeconvInstance:
  def __init__(self):
    self.template = """
// kernel instance "${operation_name}" generated by cutlass generator
using Deconvolution =
  typename cutlass::conv::device::Deconvolution<
    ${element_src},
    ${layout_src},
    ${element_flt},
    ${layout_flt},
    ${element_dst},
    ${layout_dst},
    ${element_bias},
    ${layout_bias},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_dst},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_bias},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${alignment_src},
    ${alignment_filter},
    ${nonunity_kernel},
    ${math_operator},
    ${implicit_gemm_mode}>;
"""

  def emit(self, operation):
    warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]

    epilogue_vector_length = int(min(operation.dst.alignment * DataTypeSize[operation.dst.element], 128) / DataTypeSize[operation.dst.element])

    values = {
      'operation_name': operation.procedural_name(),
      'element_src': DataTypeTag[operation.src.element],
      'layout_src': LayoutTag[operation.src.layout],
      'element_flt': DataTypeTag[operation.flt.element],
      'layout_flt': LayoutTag[operation.flt.layout],
      'element_dst': DataTypeTag[operation.dst.element],
      'layout_dst': LayoutTag[operation.dst.layout],
      'element_bias': DataTypeTag[operation.bias.element],
      'layout_bias': LayoutTag[operation.bias.layout],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
      'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
      'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
      'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
      'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
      'epilogue_vector_length': str(epilogue_vector_length),
      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
      'stages': str(operation.tile_description.stages),
      'alignment_src': str(operation.src.alignment),
      'alignment_filter': str(operation.flt.alignment),
      'nonunity_kernel': str(operation.need_load_from_const).lower(),
      'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
    }

    return SubstituteTemplate(self.template, values)
###################################################################################################
#
# Generator functions for all layouts
#
###################################################################################################

#
def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, \
                   src_align = 32, flt_align = 32, dst_align = 128, \
                   skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
  operations = []

  element_epilogue = DataType.f32
  if conv_kind == ConvKind.Fprop:
    if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
      swizzling_functor = SwizzlingFunctor.ConvFpropTrans
    else:
      swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
  else:
    swizzling_functor = SwizzlingFunctor.ConvDgradNCxHWx

  # skip rule
  def filter_tile_with_layout(tile: TileDescription, layout: LayoutType) -> bool:
    return layout == LayoutType.TensorNC32HW32 and \
           tile.threadblock_shape[0] % 32 != 0

  # rule for bias_type and epilogues
  def get_bias_type_and_epilogues(tile: TileDescription, \
                                  out_dtype: DataType) -> Tuple[DataType, List[EpilogueFunctor]]:
    if tile.math_instruction.element_accumulator == DataType.s32 and \
       out_dtype != DataType.f32:
      bias_type = DataType.s32
      if tile.math_instruction.element_b == DataType.u4:
        epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp]
      else:
        epilogues = [EpilogueFunctor.BiasAddLinearCombinationClamp, EpilogueFunctor.BiasAddLinearCombinationReluClamp, \
                     EpilogueFunctor.BiasAddLinearCombinationHSwishClamp]
    elif tile.math_instruction.element_accumulator == DataType.f32 or \
         out_dtype == DataType.f32:
      bias_type = DataType.f32
      epilogues = [EpilogueFunctor.BiasAddLinearCombination, EpilogueFunctor.BiasAddLinearCombinationRelu, \
                   EpilogueFunctor.BiasAddLinearCombinationHSwish]
    return bias_type, epilogues

  # rule for filter alignment
  def get_flt_align(tile: TileDescription) -> int:
    nonlocal flt_align
    if tile.math_instruction.opcode_class == OpcodeClass.Simt \
       and tile.math_instruction.element_accumulator == DataType.s32:
      thread_num = tile.warp_count[0] * tile.warp_count[1] * tile.warp_count[2] * 32
      flt_block = tile.threadblock_shape[0] * tile.threadblock_shape[2] \
                  * DataTypeSize[tile.math_instruction.element_a]
      load_per_thread = flt_block // thread_num
      if load_per_thread >= 128:
        flt_align = 128
      elif load_per_thread >= 64:
        flt_align = 64
      else:
        assert load_per_thread >= 32
        flt_align = 32
    return flt_align

  def get_dst_align(tile: TileDescription, out_layout: LayoutType) -> int:
    nonlocal dst_align
    if tile.math_instruction.opcode_class == OpcodeClass.TensorOp \
       and out_layout == LayoutType.TensorNC4HW4:
      dst_align = 32
    return dst_align

  def filter_epilogue_with_conv_kind(epilogue: EpilogueFunctor, conv_kind: ConvKind) -> bool:
    return conv_kind == ConvKind.Dgrad \
           and epilogue != EpilogueFunctor.BiasAddLinearCombinationClamp

  # loop over all tile descriptions
  for tile in tile_descriptions:
    if filter_tile_with_layout(tile, dst_layout):
      continue

    bias_type, epilogues = get_bias_type_and_epilogues(tile, dst_type)
    flt_align = get_flt_align(tile)
    dst_align = get_dst_align(tile, dst_layout)

    for epilogue in epilogues:
      if filter_epilogue_with_conv_kind(epilogue, conv_kind):
        continue

      if dst_type == DataType.f32:
        bias_type = DataType.f32

      #
      src = TensorDescription(tile.math_instruction.element_b, src_layout, int(src_align / DataTypeSize[tile.math_instruction.element_b]))
      flt = TensorDescription(tile.math_instruction.element_a, flt_layout, int(flt_align / DataTypeSize[tile.math_instruction.element_a]))
      bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
      dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))

      new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load)
      operations.append(new_operation)
      if not skip_unity_kernel:
        new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load)
        operations.append(new_operation)

  return operations
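
# Usage sketch (illustrative; `tile_descriptions` comes from the per-architecture
# generator script, the layout arguments are LayoutType enum values from
# library.py, and the *_align arguments are in bits):
#
#   ops = GenerateConv2d(ConvKind.Fprop, tile_descriptions,
#                        src_layout, flt_layout, dst_layout,
#                        dst_type=DataType.s8, min_cc=61)
#
# For every surviving (tile, epilogue) pair this returns two operations: one
# with need_load_from_const=True and, unless skip_unity_kernel is set, a
# second "_1x1" (unity kernel) variant with it disabled.
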
###################################################################################################
#
# Emitters functions for all targets
#
###################################################################################################

class EmitConv2dConfigurationLibrary:
  def __init__(self, operation_path, configuration_name):
    self.configuration_name = configuration_name
    self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)

    self.instance_emitter = EmitConv2dInstance()

    self.instance_template = """
${operation_instance}

// Derived class
struct ${operation_name} :
  public ${operation_name}_base { };

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    self.header_template = """
/*
  Generated by conv2d_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "conv2d_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

    self.configuration_header = """
namespace cutlass {
namespace library {

// Initialize all instances
void initialize_${configuration_name}(Manifest &manifest) {
"""

    self.configuration_instance = """
  using Operation_${operation_name} = cutlass::conv::device::ImplicitGemmConvolution<
    ${operation_name}>;

  manifest.append(new cutlass::library::Conv2dOperation<
    Operation_${operation_name}>(
      "${operation_name}"));
"""

    self.configuration_epilogue = """
}
"""

    self.epilogue_template = """
///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
"""

  #
  def __enter__(self):
    self.configuration_file = open(self.configuration_path, "w")

    self.configuration_file.write(SubstituteTemplate(self.header_template, {
      'configuration_name': self.configuration_name
    }))

    self.operations = []
    return self

  #
  def emit(self, operation):
    self.operations.append(operation)
    self.configuration_file.write(SubstituteTemplate(self.instance_template, {
      'configuration_name': self.configuration_name,
      'operation_name': operation.procedural_name(),
      'operation_instance': self.instance_emitter.emit(operation)
    }))

  #
  def __exit__(self, exception_type, exception_value, traceback):
    self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
      'configuration_name': self.configuration_name
    }))

    for operation in self.operations:
      self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
        'configuration_name': self.configuration_name,
        'operation_name': operation.procedural_name()
      }))

    self.configuration_file.write(self.configuration_epilogue)
    self.configuration_file.write(self.epilogue_template)
    self.configuration_file.close()
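
# Usage sketch (illustrative): the configuration emitter is a context manager
# that writes one .cu file registering every emitted operation with the Manifest.
#
#   with EmitConv2dConfigurationLibrary(operation_path, configuration_name) as cfg:
#     for op in ops:
#       cfg.emit(op)
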
###################################################################################################
###################################################################################################

#
# Emitters for Conv Kernel Wrapper
#
###################################################################################################

class EmitConvSingleKernelWrapper():
  def __init__(self, kernel_path, operation, wrapper_path):
    self.kernel_path = kernel_path
    self.wrapper_path = wrapper_path
    self.operation = operation
    self.conv_wrappers = { \
      ConvKind.Fprop: """
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
    const typename Convolution::ElementSrc* d_src,
    const typename Convolution::ElementFilter* d_filter,
    const typename Convolution::ElementBias* d_bias,
    const typename Convolution::ElementDst* d_z,
    typename Convolution::ElementDst* d_dst,
    int* workspace,
    typename Convolution::ConvolutionParameter const& conv_param,
    typename Convolution::EpilogueOutputOp::Params const& epilogue,
    cudaStream_t stream,
    typename Convolution::ExtraParam extra_param);
""", \
      ConvKind.Dgrad: """
template void megdnn::cuda::cutlass_wrapper::cutlass_deconvolution_wrapper<Deconvolution>(
    const typename Deconvolution::ElementSrc* d_src,
    const typename Deconvolution::ElementFilter* d_filter,
    const typename Deconvolution::ElementBias* d_bias,
    const typename Deconvolution::ElementDst* d_z,
    typename Deconvolution::ElementDst* d_dst,
    int* workspace,
    typename Deconvolution::ConvolutionParameter const& conv_param,
    typename Deconvolution::EpilogueOutputOp::Params const& epilogue,
    cudaStream_t stream);
""", \
    }

    if self.operation.conv_kind == ConvKind.Fprop:
      self.instance_emitter = EmitConv2dInstance()
    else:
      assert self.operation.conv_kind == ConvKind.Dgrad
      self.instance_emitter = EmitDeconvInstance()

    self.header_template = """
#if !MEGDNN_TEGRA_X1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"

#include "${wrapper_path}"
"""

    self.instance_template = """
${operation_instance}
"""

    self.wrapper_template = """
${wrapper_instance}
"""

    self.epilogue_template = """
#pragma GCC diagnostic pop
#endif
"""

  #
  def __enter__(self):
    self.kernel_path = os.path.join(self.kernel_path, "%s.cu" % self.operation.procedural_name())
    self.kernel_file = LazyFile(self.kernel_path)
    self.kernel_file.write(SubstituteTemplate(self.header_template, {
      'wrapper_path': self.wrapper_path,
    }))
    return self

  #
  def emit(self):
    self.kernel_file.write(SubstituteTemplate(self.instance_template, {
      'operation_instance': self.instance_emitter.emit(self.operation),
    }))

    # emit wrapper
    wrapper = SubstituteTemplate(self.wrapper_template, {
      'wrapper_instance': self.conv_wrappers[self.operation.conv_kind],
    })
    self.kernel_file.write(wrapper)

  #
  def __exit__(self, exception_type, exception_value, traceback):
    self.kernel_file.write(self.epilogue_template)
    self.kernel_file.close()

###################################################################################################
###################################################################################################
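
Putting the pieces together, a driver script pairs GenerateConv2d with EmitConvSingleKernelWrapper to write one .cu file per kernel. A minimal sketch, assuming a `tile_descriptions` list and a `wrapper_header` path supplied by the build system (both are placeholders here):

  operations = GenerateConv2d(ConvKind.Fprop, tile_descriptions,
                              src_layout, flt_layout, dst_layout,
                              DataType.s8, min_cc=61)
  for op in operations:
    with EmitConvSingleKernelWrapper(kernel_dir, op, wrapper_header) as w:
      w.emit()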

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.