You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_stubapi.py 22 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. #!/usr/bin/python3
  2. # -*- coding: UTF-8 -*-
  3. #-------------------------------------------------------------------
  4. # Purpose:
  5. # Copyright 2020 Huawei Technologies Co., Ltd. All rights reserved.
  6. #-------------------------------------------------------------------
  7. import os
  8. import re
  9. import sys
  10. import logging
  11. logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
  12. level=logging.INFO)
  13. """
  14. this attr is used for symbol table visible
  15. """
  16. GE_ATTR = 'GE_FUNC_VISIBILITY'
  17. """
  18. generate stub func body by return type
  19. """
  20. RETURN_STATEMENTS = {
  21. 'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n '
  22. ' << "environment variables and compilation options to make sure you use the correct library."\n'
  23. ' << std::endl;\n'
  24. ' return ACL_ERROR_COMPILING_STUB_MODE;',
  25. 'Status': ' return SUCCESS;',
  26. 'Graph': ' return Graph();',
  27. 'Graph&': ' return *this;',
  28. 'Format': ' return Format();',
  29. 'Format&': ' return *this;',
  30. 'Shape': ' return Shape();',
  31. 'Shape&': ' return *this;',
  32. 'TensorDesc': ' return TensorDesc();',
  33. 'TensorDesc&': ' return *this;',
  34. 'Tensor': ' return Tensor();',
  35. 'Tensor&': ' return *this;',
  36. 'Operator': ' return Operator();',
  37. 'Operator&': ' return *this;',
  38. 'Ptr': ' return nullptr;',
  39. 'std::string': ' return "";',
  40. 'std::string&': ' return "";',
  41. 'string': ' return "";',
  42. 'int': ' return 0;',
  43. 'DataType': ' return DT_FLOAT;',
  44. 'InferenceContextPtr': ' return nullptr;',
  45. 'SubgraphBuilder': ' return nullptr;',
  46. 'OperatorImplPtr': ' return nullptr;',
  47. 'OutHandler': ' return nullptr;',
  48. 'std::vector<std::string>': ' return {};',
  49. 'std::vector<int64_t>': ' return {};',
  50. 'std::map': ' return {};',
  51. 'uint32_t': ' return 0;',
  52. 'int64_t': ' return 0;',
  53. 'uint64_t': ' return 0;',
  54. 'size_t': ' return 0;',
  55. 'float': ' return 0.0f;',
  56. 'bool': ' return false;',
  57. }
  58. """
  59. max code len per line in hua_wei software programming specifications
  60. """
  61. max_code_len_per_line = 100
  62. """
  63. white_list_for_debug, include_dir_key_words is to
  64. determines which header files to generate cc files from
  65. when DEBUG on
  66. """
  67. white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h",
  68. "ge_ir_build.h", "ge_api.h", "tensorflow_parser.h", "caffe_parser.h"]
  69. include_dir_key_words = ["ge", "graph", "parser"]
  70. DEBUG = True
  71. def need_generate_func(func_line):
  72. """
  73. :param func_line:
  74. :return:
  75. """
  76. if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \
  77. or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"):
  78. return False
  79. return True
  80. def file_endswith_white_list_suffix(file):
  81. """
  82. :param file:
  83. :return:
  84. """
  85. if DEBUG:
  86. for suffix in white_list_for_debug:
  87. if file.endswith(suffix):
  88. return True
  89. return False
  90. else:
  91. return True
  92. """
  93. belows are patterns used for analyse .h file
  94. """
  95. # pattern function
  96. pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after
  97. ([a-zA-Z~_] # void int likely
  98. .*
  99. [)] #we find )
  100. (?!.*{) # we do not want the case int abc() const
  101. .*)
  102. (;.*) #we want to find ; and after for we will replace these later
  103. \n$
  104. """, re.VERBOSE | re.MULTILINE | re.DOTALL)
  105. # pattern comment
  106. pattern_comment = re.compile(r'^\s*//')
  107. pattern_comment_2_start = re.compile(r'^\s*/[*]')
  108. pattern_comment_2_end = re.compile(r'[*]/\s*$')
  109. # pattern define
  110. pattern_define = re.compile(r'^\s*#define')
  111. pattern_define_return = re.compile(r'\\\s*$')
  112. # blank line
  113. pattern_blank_line = re.compile(r'^\s*$')
  114. # virtual,explicit,friend,static
  115. pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
  116. # lead space
  117. pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
  118. # functions will have patterns such as func ( or func(
  119. # but operator is an exception; the class name is preceded by an operator, and the above mode does not exist
  120. # format like :"operator = ()"
  121. pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
  122. # template
  123. pattern_template = re.compile(r'^\s*template')
  124. pattern_template_end = re.compile(r'>\s*$')
  125. # namespace
  126. pattern_namespace = re.compile(r'namespace.*{')
  127. # class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with
  128. pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
  129. # {}
  130. pattern_start = re.compile('{')
  131. pattern_end = re.compile('}')
  132. line_index = 0
  133. class H2CC(object):
  134. def __init__(self, input_file, output_file, shared_includes_content):
  135. """
  136. :param input_file:
  137. :param output_file:
  138. :param shared_includes_content:
  139. """
  140. self.input_file = input_file
  141. self.output_file = output_file
  142. self.shared_includes_content = shared_includes_content
  143. self.line_index = 0
  144. self.input_fd = open(self.input_file, 'r')
  145. self.input_content = self.input_fd.readlines()
  146. self.output_fd = open(self.output_file, 'w')
  147. # The state may be normal_now(in the middle of {}),class_now,namespace_now
  148. self.stack = []
  149. self.stack_class = []
  150. self.stack_template = []
  151. # record funcs generated by h2cc func
  152. self.func_list_exist = []
  153. def __del__(self):
  154. self.input_fd.close()
  155. self.output_fd.close()
  156. del self.stack
  157. del self.stack_class
  158. del self.stack_template
  159. del self.func_list_exist
  160. def just_skip(self):
  161. # skip blank line or comment
  162. if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
  163. self.input_content[self.line_index]): # /n or comment using //
  164. self.line_index += 1
  165. if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /*
  166. while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */
  167. self.line_index += 1
  168. self.line_index += 1
  169. # skip define
  170. if pattern_define.search(self.input_content[self.line_index]):
  171. while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
  172. self.input_content[self.line_index]):
  173. self.line_index += 1
  174. self.line_index += 1
  175. def write_inc_content(self):
  176. for shared_include_content in self.shared_includes_content:
  177. self.output_fd.write(shared_include_content)
  178. def h2cc(self):
  179. """
  180. :return:
  181. """
  182. logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
  183. global pattern_comment
  184. global pattern_comment_2_start
  185. global pattern_comment_2_end
  186. global pattern_blank_line
  187. global pattern_func
  188. global pattern_keyword
  189. global pattern_leading_space
  190. global pattern_func_name
  191. global pattern_template
  192. global pattern_template_end
  193. global pattern_namespace
  194. global pattern_class
  195. global pattern_start
  196. global pattern_end
  197. global line_index
  198. # write inc content
  199. self.write_inc_content()
  200. # core processing cycle, process the input .h file by line
  201. while self.line_index < len(self.input_content):
  202. # handle comment and blank line
  203. self.just_skip()
  204. # match namespace
  205. self.handle_namespace()
  206. # match template
  207. template_string = self.handle_template()
  208. # match class
  209. line = self.input_content[self.line_index]
  210. match_class = pattern_class.search(line)
  211. match_start = pattern_start.search(line)
  212. handle_class_result = self.handle_class(template_string, line, match_start, match_class)
  213. if handle_class_result == "continue":
  214. continue
  215. # match "}"
  216. handle_stack_result = self.handle_stack(match_start)
  217. if handle_stack_result == "continue":
  218. continue
  219. # handle func
  220. handle_func1_result, line, start_i = self.handle_func1(line)
  221. if handle_func1_result == "continue":
  222. continue
  223. # here means func is found
  224. # delete key word
  225. line = pattern_keyword.sub('', line)
  226. logging.info("line[%s]", line)
  227. # Class member function
  228. # if friend we will not add class name
  229. friend_match = re.search('friend ', line)
  230. if len(self.stack_class) > 0 and not friend_match:
  231. line, func_name = self.handle_class_member_func(line, template_string)
  232. # Normal functions
  233. else:
  234. line, func_name = self.handle_normal_func(line, template_string)
  235. need_generate = need_generate_func(line)
  236. # func body
  237. line += self.implement_function(line)
  238. # comment
  239. line = self.gen_comment(start_i) + line
  240. # write to out file
  241. self.write_func_content(line, func_name, need_generate)
  242. # next loop
  243. self.line_index += 1
  244. logging.info('Added %s functions', len(self.func_list_exist))
  245. logging.info('Successfully converted,please see ' + self.output_file)
  246. def handle_func1(self, line):
  247. """
  248. :param line:
  249. :return:
  250. """
  251. find1 = re.search('[(]', line)
  252. if not find1:
  253. self.line_index += 1
  254. return "continue", line, None
  255. find2 = re.search('[)]', line)
  256. start_i = self.line_index
  257. space_match = pattern_leading_space.search(line)
  258. # deal with
  259. # int abc(int a,
  260. # int b)
  261. if find1 and (not find2):
  262. self.line_index += 1
  263. line2 = self.input_content[self.line_index]
  264. if space_match:
  265. line2 = re.sub('^' + space_match.group(1), '', line2)
  266. line += line2
  267. while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
  268. self.line_index += 1
  269. line2 = self.input_content[self.line_index]
  270. line2 = re.sub('^' + space_match.group(1), '', line2)
  271. line += line2
  272. match_start = pattern_start.search(self.input_content[self.line_index])
  273. match_end = pattern_end.search(self.input_content[self.line_index])
  274. if match_start: # like ) { or ) {} int the last line
  275. if not match_end:
  276. self.stack.append('normal_now')
  277. ii = start_i
  278. while ii <= self.line_index:
  279. ii += 1
  280. self.line_index += 1
  281. return "continue", line, start_i
  282. logging.info("line[%s]", line)
  283. # ' int abc();'->'int abc()'
  284. (line, match) = pattern_func.subn(r'\2\n', line)
  285. logging.info("line[%s]", line)
  286. # deal with case:
  287. # 'int \n abc(int a, int b)'
  288. if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
  289. line = self.input_content[start_i - 1] + line
  290. line = line.lstrip()
  291. if not match:
  292. self.line_index += 1
  293. return "continue", line, start_i
  294. return "pass", line, start_i
  295. def handle_stack(self, match_start):
  296. """
  297. :param match_start:
  298. :return:
  299. """
  300. line = self.input_content[self.line_index]
  301. match_end = pattern_end.search(line)
  302. if match_start:
  303. self.stack.append('normal_now')
  304. if match_end:
  305. top_status = self.stack.pop()
  306. if top_status == 'namespace_now':
  307. self.output_fd.write(line + '\n')
  308. elif top_status == 'class_now':
  309. self.stack_class.pop()
  310. self.stack_template.pop()
  311. if match_start or match_end:
  312. self.line_index += 1
  313. return "continue"
  314. if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
  315. self.line_index += 1
  316. return "continue"
  317. return "pass"
  318. def handle_class(self, template_string, line, match_start, match_class):
  319. """
  320. :param template_string:
  321. :param line:
  322. :param match_start:
  323. :param match_class:
  324. :return:
  325. """
  326. if match_class: # we face a class
  327. self.stack_template.append(template_string)
  328. self.stack.append('class_now')
  329. class_name = match_class.group(3)
  330. # class template specializations: class A<u,Node<u> >
  331. if '<' in class_name:
  332. k = line.index('<')
  333. fit = 1
  334. for ii in range(k + 1, len(line)):
  335. if line[ii] == '<':
  336. fit += 1
  337. if line[ii] == '>':
  338. fit -= 1
  339. if fit == 0:
  340. break
  341. class_name += line[k + 1:ii + 1]
  342. logging.info('class_name[%s]', class_name)
  343. self.stack_class.append(class_name)
  344. while not match_start:
  345. self.line_index += 1
  346. line = self.input_content[self.line_index]
  347. match_start = pattern_start.search(line)
  348. self.line_index += 1
  349. return "continue"
  350. return "pass"
  351. def handle_template(self):
  352. line = self.input_content[self.line_index]
  353. match_template = pattern_template.search(line)
  354. template_string = ''
  355. if match_template:
  356. match_template_end = pattern_template_end.search(line)
  357. template_string = line
  358. while not match_template_end:
  359. self.line_index += 1
  360. line = self.input_content[self.line_index]
  361. template_string += line
  362. match_template_end = pattern_template_end.search(line)
  363. self.line_index += 1
  364. return template_string
  365. def handle_namespace(self):
  366. line = self.input_content[self.line_index]
  367. match_namespace = pattern_namespace.search(line)
  368. if match_namespace: # we face namespace
  369. self.output_fd.write(line + '\n')
  370. self.stack.append('namespace_now')
  371. self.line_index += 1
  372. def handle_normal_func(self, line, template_string):
  373. template_line = ''
  374. self.stack_template.append(template_string)
  375. if self.stack_template[-1] != '':
  376. template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
  377. # change '< class T = a, class U = A(3)>' to '<class T, class U>'
  378. template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
  379. template_line = re.sub(r'\s*=.*,', ',', template_line)
  380. template_line = re.sub(r'\s*=.*', '', template_line)
  381. line = re.sub(r'\s*=.*,', ',', line)
  382. line = re.sub(r'\s*=.*\)', ')', line)
  383. line = template_line + line
  384. self.stack_template.pop()
  385. func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
  386. logging.info("line[%s]", line)
  387. logging.info("func_name[%s]", func_name)
  388. return line, func_name
  389. def handle_class_member_func(self, line, template_string):
  390. template_line = ''
  391. x = ''
  392. if template_string != '':
  393. template_string = re.sub(r'\s*template', 'template', template_string)
  394. template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
  395. template_string = re.sub(r'\s*=.*,', ',', template_string)
  396. template_string = re.sub(r'\s*=.*', '', template_string)
  397. if self.stack_template[-1] != '':
  398. if not (re.search(r'<\s*>', stack_template[-1])):
  399. template_line = re.sub(r'^\s*template', 'template', stack_template[-1])
  400. if not (re.search(r'<.*>', self.stack_class[-1])):
  401. # for x we get like template<class T, typename U> -> <T,U>
  402. x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U>
  403. x = re.sub(r'\n', '', x)
  404. x = re.sub(r'\s*=.*,', ',', x)
  405. x = re.sub(r'\s*=.*\>', '>', x)
  406. x = x.rstrip() # remove \n
  407. x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
  408. x) # remove class,typename -> <T, U>
  409. x = re.sub(r'<\s+', '<', x)
  410. x = re.sub(r'\s+>', '>', x)
  411. x = re.sub(r'\s+,', ',', x)
  412. x = re.sub(r',\s+', ', ', x)
  413. line = re.sub(r'\s*=\s+0', '', line)
  414. line = re.sub(r'\s*=\s+.*,', ',', line)
  415. line = re.sub(r'\s*=\s+.*\)', ')', line)
  416. logging.info("x[%s]\nline[%s]", x, line)
  417. # if the function is long, void ABC::foo()
  418. # breaks into two lines void ABC::\n foo()
  419. temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
  420. if len(temp_line) > max_code_len_per_line:
  421. line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
  422. else:
  423. line = temp_line
  424. logging.info("line[%s]", line)
  425. # add template as the above if there is one
  426. template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
  427. template_line = re.sub(r'\s*=.*,', ',', template_line)
  428. template_line = re.sub(r'\s*=.*', '', template_line)
  429. line = template_line + template_string + line
  430. func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
  431. logging.info("line[%s]", line)
  432. logging.info("func_name[%s]", func_name)
  433. return line, func_name
  434. def write_func_content(self, content, func_name, need_generate):
  435. if not (func_name in self.func_list_exist) and need_generate:
  436. self.output_fd.write(content)
  437. self.func_list_exist.append(func_name)
  438. logging.info('add func:[%s]', func_name)
  439. def gen_comment(self, start_i):
  440. comment_line = ''
  441. # Function comments are on top of function declarations, copy them over
  442. k = start_i - 1 # one line before this func start
  443. if pattern_template.search(self.input_content[k]):
  444. k -= 1
  445. if pattern_comment_2_end.search(self.input_content[k]):
  446. comment_line = self.input_content[k].lstrip()
  447. while not pattern_comment_2_start.search(self.input_content[k]):
  448. k -= 1
  449. comment_line = self.input_content[k].lstrip() + comment_line
  450. else:
  451. for j in range(k, 0, -1):
  452. c_line = self.input_content[j]
  453. if pattern_comment.search(c_line):
  454. c_line = re.sub(r'\s*//', '//', c_line)
  455. comment_line = c_line + comment_line
  456. else:
  457. break
  458. return comment_line
  459. @staticmethod
  460. def implement_function(func):
  461. function_def = ''
  462. function_def += '{\n'
  463. all_items = func.split()
  464. start = 0
  465. return_type = all_items[start]
  466. if return_type == "const":
  467. start += 1
  468. return_type = all_items[start]
  469. if return_type.startswith(('std::map', 'std::set', 'std::vector')):
  470. return_type = "std::map"
  471. if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
  472. return_type = "Ptr"
  473. if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
  474. return_type += "&"
  475. if RETURN_STATEMENTS.__contains__(return_type):
  476. function_def += RETURN_STATEMENTS[return_type]
  477. else:
  478. logging.warning("Unhandled return type[%s]", return_type)
  479. function_def += '\n'
  480. function_def += '}\n'
  481. function_def += '\n'
  482. return function_def
  483. def collect_header_files(path):
  484. """
  485. :param path:
  486. :return:
  487. """
  488. header_files = []
  489. shared_includes_content = []
  490. for root, dirs, files in os.walk(path):
  491. files.sort()
  492. for file in files:
  493. if file.find("git") >= 0:
  494. continue
  495. if not file.endswith('.h'):
  496. continue
  497. file_path = os.path.join(root, file)
  498. file_path = file_path.replace('\\', '/')
  499. header_files.append(file_path)
  500. include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
  501. shared_includes_content.append(include_str)
  502. # for acl error code
  503. shared_includes_content.append('#include <iostream>\n')
  504. shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n')
  505. return header_files, shared_includes_content
  506. def generate_stub_file(inc_dir, out_cc_dir):
  507. """
  508. :param inc_dir:
  509. :param out_cc_dir:
  510. :return:
  511. """
  512. target_header_files, shared_includes_content = collect_header_files(inc_dir)
  513. for header_file in target_header_files:
  514. if not file_endswith_white_list_suffix(header_file):
  515. continue
  516. cc_file = re.sub('.h*$', '.cc', header_file)
  517. h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
  518. h_2_cc.h2cc()
  519. def gen_code(inc_dir, out_cc_dir):
  520. """
  521. :param inc_dir:
  522. :param out_cc_dir:
  523. :return:
  524. """
  525. if not inc_dir.endswith('/'):
  526. inc_dir += '/'
  527. if not out_cc_dir.endswith('/'):
  528. out_cc_dir += '/'
  529. for include_dir_key_word in include_dir_key_words:
  530. generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir)
  531. if __name__ == '__main__':
  532. inc_dir = sys.argv[1]
  533. out_cc_dir = sys.argv[2]
  534. gen_code(inc_dir, out_cc_dir)

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。