You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

manifest.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. #
  2. # \file manifest.py
  3. #
  4. # \brief Generates the CUTLASS Library's instances
  5. #
  6. import enum
  7. import os.path
  8. import shutil
  9. from library import *
  10. from gemm_operation import *
  11. from conv2d_operation import *
  12. ###################################################################################################
  13. class EmitOperationKindLibrary:
  14. def __init__(self, generated_path, kind, args):
  15. self.generated_path = generated_path
  16. self.kind = kind
  17. self.args = args
  18. self.emitters = {
  19. OperationKind.Gemm: EmitGemmConfigurationLibrary,
  20. OperationKind.Conv2d: EmitConv2dConfigurationLibrary,
  21. }
  22. self.configurations = []
  23. self.header_template = """
  24. /*
  25. Generated by manifest.py - Do not edit.
  26. */
  27. #include "cutlass/cutlass.h"
  28. #include "cutlass/library/library.h"
  29. #include "cutlass/library/manifest.h"
  30. namespace cutlass {
  31. namespace library {
  32. ///////////////////////////////////////////////////////////////////////////////////////////////////
  33. """
  34. self.entry_template = """
  35. //
  36. // Entry point to construct operations
  37. //
  38. void initialize_all_${operation_name}_operations(Manifest &manifest) {
  39. """
  40. self.configuration_prototype_template = (
  41. "void initialize_${configuration_name}(Manifest &manifest);\n"
  42. )
  43. self.configuration_template = " initialize_${configuration_name}(manifest);\n"
  44. self.epilogue_template = """
  45. }
  46. ///////////////////////////////////////////////////////////////////////////////////////////////////
  47. } // namespace library
  48. } // namespace cutlass
  49. """
  50. #
  51. def __enter__(self):
  52. self.operation_path = os.path.join(
  53. self.generated_path, OperationKindNames[self.kind]
  54. )
  55. os.mkdir(self.operation_path)
  56. self.top_level_path = os.path.join(
  57. self.operation_path, "all_%s_operations.cu" % OperationKindNames[self.kind]
  58. )
  59. self.top_level_file = open(self.top_level_path, "w")
  60. self.top_level_file.write(self.header_template)
  61. self.source_files = [self.top_level_path]
  62. return self
  63. #
  64. def emit(self, configuration_name, operations):
  65. with self.emitters[self.kind](
  66. self.operation_path, configuration_name
  67. ) as configuration_emitter:
  68. for operation in operations:
  69. configuration_emitter.emit(operation)
  70. self.source_files.append(configuration_emitter.configuration_path)
  71. self.configurations.append(configuration_name)
  72. self.top_level_file.write(
  73. SubstituteTemplate(
  74. self.configuration_prototype_template,
  75. {"configuration_name": configuration_name},
  76. )
  77. )
  78. #
  79. def __exit__(self, exception_type, exception_value, traceback):
  80. self.top_level_file.write(
  81. SubstituteTemplate(
  82. self.entry_template, {"operation_name": OperationKindNames[self.kind]}
  83. )
  84. )
  85. for configuration_name in self.configurations:
  86. self.top_level_file.write(
  87. SubstituteTemplate(
  88. self.configuration_template,
  89. {"configuration_name": configuration_name},
  90. )
  91. )
  92. self.top_level_file.write(self.epilogue_template)
  93. self.top_level_file.close()
  94. ###################################################################################################
  95. ###################################################################################################
  96. class Options:
  97. def __init__(self):
  98. pass
  99. ###################################################################################################
  100. #
  101. class Manifest:
  102. #
  103. def __init__(self, args):
  104. self.operations = {}
  105. self.args = args
  106. architectures = (
  107. args.architectures.split(";") if len(args.architectures) else ["50"]
  108. )
  109. self.compute_capabilities = [int(x) for x in architectures]
  110. self.selected_kernels = []
  111. if args.operations == "all":
  112. self.operations_enabled = []
  113. else:
  114. operations_list = [OperationKind.Gemm, OperationKind.Conv2d]
  115. self.operations_enabled = [
  116. x
  117. for x in operations_list
  118. if OperationKindNames[x] in args.operations.split(",")
  119. ]
  120. if args.kernels == "all":
  121. self.kernel_names = []
  122. else:
  123. self.kernel_names = [x for x in args.kernels.split(",") if x != ""]
  124. self.ignore_kernel_names = [
  125. x for x in args.ignore_kernels.split(",") if x != ""
  126. ]
  127. if args.kernel_filter_file is None:
  128. self.kernel_filter_list = []
  129. else:
  130. self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
  131. self.operation_count = 0
  132. self.operations_by_name = {}
  133. self.top_level_prologue = """
  134. #include "cutlass/library/library.h"
  135. #include "cutlass/library/manifest.h"
  136. namespace cutlass {
  137. namespace library {
  138. ${prototypes}
  139. void initialize_all(Manifest &manifest) {
  140. """
  141. self.top_level_reserve = " manifest.reserve(${operation_count});\n\n"
  142. self.top_level_epilogue = """
  143. }
  144. } // namespace library
  145. } // namespace cutlass
  146. """
  147. def get_kernel_filters(self, kernelListFile):
  148. if os.path.isfile(kernelListFile):
  149. with open(kernelListFile, "r") as fileReader:
  150. lines = [
  151. line.rstrip() for line in fileReader if not line.startswith("#")
  152. ]
  153. lines = [re.compile(line) for line in lines if line]
  154. return lines
  155. else:
  156. return []
  157. def filter_out_kernels(self, kernel_name, kernel_filter_list):
  158. for kernel_filter_re in kernel_filter_list:
  159. if kernel_filter_re.search(kernel_name) is not None:
  160. return True
  161. return False
  162. #
  163. def _filter_string_matches(self, filter_string, haystack):
  164. """ Returns true if all substrings appear in the haystack in order"""
  165. substrings = filter_string.split("*")
  166. for sub in substrings:
  167. idx = haystack.find(sub)
  168. if idx < 0:
  169. return False
  170. haystack = haystack[idx + len(sub) :]
  171. return True
  172. #
  173. def filter(self, operation):
  174. """ Filtering operations based on various criteria"""
  175. # filter based on compute capability
  176. enabled = False
  177. for cc in self.compute_capabilities:
  178. if (
  179. cc >= operation.tile_description.minimum_compute_capability
  180. and cc <= operation.tile_description.maximum_compute_capability
  181. ):
  182. enabled = True
  183. break
  184. if not enabled:
  185. return False
  186. if (
  187. len(self.operations_enabled)
  188. and not operation.operation_kind in self.operations_enabled
  189. ):
  190. return False
  191. # eliminate duplicates
  192. if operation.procedural_name() in self.operations_by_name.keys():
  193. return False
  194. # Filter based on list of valid substrings
  195. if len(self.kernel_names):
  196. name = operation.procedural_name()
  197. enabled = False
  198. # compare against the include list
  199. for name_substr in self.kernel_names:
  200. if self._filter_string_matches(name_substr, name):
  201. enabled = True
  202. break
  203. # compare against the exclude list
  204. for name_substr in self.ignore_kernel_names:
  205. if self._filter_string_matches(name_substr, name):
  206. enabled = False
  207. break
  208. if len(self.kernel_filter_list) > 0:
  209. enabled = False
  210. if self.filter_out_kernels(
  211. operation.procedural_name(), self.kernel_filter_list
  212. ):
  213. enabled = True
  214. # todo: filter based on compute data type
  215. return enabled
  216. #
  217. #
  218. def append(self, operation):
  219. """
  220. Inserts the operation.
  221. operation_kind -> configuration_name -> []
  222. """
  223. if self.filter(operation):
  224. self.selected_kernels.append(operation.procedural_name())
  225. self.operations_by_name[operation.procedural_name()] = operation
  226. # add the configuration
  227. configuration_name = operation.configuration_name()
  228. if operation.operation_kind not in self.operations.keys():
  229. self.operations[operation.operation_kind] = {}
  230. if (
  231. configuration_name
  232. not in self.operations[operation.operation_kind].keys()
  233. ):
  234. self.operations[operation.operation_kind][configuration_name] = []
  235. self.operations[operation.operation_kind][configuration_name].append(
  236. operation
  237. )
  238. self.operation_count += 1
  239. #
  240. #
  241. def emit(self, target=GeneratorTarget.Library):
  242. operation_emitters = {GeneratorTarget.Library: EmitOperationKindLibrary}
  243. generated_path = os.path.join(self.args.curr_build_dir, "generated")
  244. # create generated/
  245. if os.path.exists(generated_path):
  246. shutil.rmtree(generated_path)
  247. os.mkdir(generated_path)
  248. source_files = []
  249. top_level_path = os.path.join(generated_path, "initialize_all.cpp")
  250. with open(top_level_path, "w") as top_level_file:
  251. if target == GeneratorTarget.Library:
  252. source_files.append(top_level_path)
  253. prototypes = []
  254. for operation_kind, configurations in self.operations.items():
  255. prototypes.append(
  256. SubstituteTemplate(
  257. "void initialize_all_${operation_kind}_operations(Manifest &manifest);",
  258. {"operation_kind": OperationKindNames[operation_kind]},
  259. )
  260. )
  261. top_level_file.write(
  262. SubstituteTemplate(
  263. self.top_level_prologue, {"prototypes": "\n".join(prototypes)}
  264. )
  265. )
  266. top_level_file.write(
  267. SubstituteTemplate(
  268. self.top_level_reserve,
  269. {"operation_count": str(self.operation_count)},
  270. )
  271. )
  272. # for each operation kind, emit initializer for all configurations
  273. for operation_kind, configurations in self.operations.items():
  274. with operation_emitters[target](
  275. generated_path, operation_kind, self.args
  276. ) as operation_kind_emitter:
  277. for configuration_name, operations in configurations.items():
  278. operation_kind_emitter.emit(configuration_name, operations)
  279. source_files += operation_kind_emitter.source_files
  280. top_level_file.write(
  281. SubstituteTemplate(
  282. " initialize_all_${operation_kind}_operations(manifest);\n",
  283. {"operation_kind": OperationKindNames[operation_kind]},
  284. )
  285. )
  286. top_level_file.write(self.top_level_epilogue)
  287. # write the manifest.cmake file containing paths from all targets
  288. manifest_path = os.path.join(generated_path, "manifest.cmake")
  289. with open(manifest_path, "w") as manifest_file:
  290. target_name = "cutlass_library_objs"
  291. target_text = SubstituteTemplate(
  292. """cutlass_target_sources(
  293. ${target_name}
  294. BATCH_SOURCES ON
  295. PRIVATE
  296. """,
  297. {"target_name": target_name},
  298. )
  299. manifest_file.write(target_text)
  300. for source_file in source_files:
  301. manifest_file.write(" %s\n" % str(source_file.replace("\\", "/")))
  302. manifest_file.write(")")
  303. #
  304. ###################################################################################################
  305. def GenerateManifest(args, operations, output_dir):
  306. assert isinstance(operations, list)
  307. if len(operations) == 0:
  308. return
  309. op = operations[0]
  310. required_cuda_ver_major = op.required_cuda_ver_major
  311. required_cuda_ver_minor = op.required_cuda_ver_minor
  312. manifest_path = os.path.join(
  313. output_dir, "all_%s_%s_operations.cu" % (args.operations, args.type)
  314. )
  315. f = open(manifest_path, "w")
  316. f.write(
  317. """
  318. /*
  319. Generated by generator.py - Do not edit.
  320. */
  321. #if __CUDACC_VER_MAJOR__ > %s || (__CUDACC_VER_MAJOR__ == %s && __CUDACC_VER_MINOR__ >= %s)
  322. #include "cutlass/cutlass.h"
  323. #include "src/cuda/cutlass/library.h"
  324. #include "src/cuda/cutlass/manifest.h"
  325. namespace cutlass {
  326. namespace library {
  327. """
  328. % (
  329. str(required_cuda_ver_major),
  330. str(required_cuda_ver_major),
  331. str(required_cuda_ver_minor),
  332. )
  333. )
  334. for op in operations:
  335. f.write("void initialize_%s(Manifest &manifest);\n" % op.procedural_name())
  336. f.write(
  337. """
  338. void initialize_all_%s_%s_operations(Manifest &manifest) {
  339. """
  340. % (args.operations, args.type)
  341. )
  342. for op in operations:
  343. f.write(" initialize_%s(manifest);\n" % op.procedural_name())
  344. f.write(
  345. """
  346. }
  347. } // namespace library
  348. } // namespace cutlass
  349. #endif
  350. """
  351. )
  352. f.close()