You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

gen_stubapi.py 22 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578
  1. import os
  2. import re
  3. import sys
  4. import logging
  5. logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
  6. level=logging.INFO)
  7. """
  8. this attr is used for symbol table visible
  9. """
  10. GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY'
  11. """
  12. generate stub func body by return type
  13. """
  14. RETURN_STATEMENTS = {
  15. 'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n '
  16. ' << "environment variables and compilation options to make sure you use the correct library."\n'
  17. ' << std::endl;\n'
  18. ' return ACL_ERROR_COMPILING_STUB_MODE;',
  19. 'Status': ' return SUCCESS;',
  20. 'Graph': ' return Graph();',
  21. 'Graph&': ' return *this;',
  22. 'Format': ' return Format();',
  23. 'Format&': ' return *this;',
  24. 'Shape': ' return Shape();',
  25. 'Shape&': ' return *this;',
  26. 'TensorDesc': ' return TensorDesc();',
  27. 'TensorDesc&': ' return *this;',
  28. 'Tensor': ' return Tensor();',
  29. 'Tensor&': ' return *this;',
  30. 'Operator': ' return Operator();',
  31. 'Operator&': ' return *this;',
  32. 'Ptr': ' return nullptr;',
  33. 'std::string': ' return "";',
  34. 'std::string&': ' return "";',
  35. 'string': ' return "";',
  36. 'int': ' return 0;',
  37. 'DataType': ' return DT_FLOAT;',
  38. 'InferenceContextPtr': ' return nullptr;',
  39. 'SubgraphBuilder': ' return nullptr;',
  40. 'OperatorImplPtr': ' return nullptr;',
  41. 'OutHandler': ' return nullptr;',
  42. 'std::vector<std::string>': ' return {};',
  43. 'std::vector<int64_t>': ' return {};',
  44. 'std::map': ' return {};',
  45. 'uint32_t': ' return 0;',
  46. 'int64_t': ' return 0;',
  47. 'uint64_t': ' return 0;',
  48. 'size_t': ' return 0;',
  49. 'float': ' return 0.0f;',
  50. 'bool': ' return false;',
  51. }
  52. """
  53. max code len per line in hua_wei software programming specifications
  54. """
  55. max_code_len_per_line = 100
  56. """
  57. white_list_for_debug, include_dir_key_words is to
  58. determines which header files to generate cc files from
  59. when DEBUG on
  60. """
  61. white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h",
  62. "ge_ir_build.h", "ge_api.h", "ge_prof.h", "tensorflow_parser.h", "caffe_parser.h"]
  63. include_dir_key_words = ["ge", "graph", "parser"]
  64. DEBUG = True
  65. def need_generate_func(func_line):
  66. """
  67. :param func_line:
  68. :return:
  69. """
  70. if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \
  71. or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"):
  72. return False
  73. return True
  74. def file_endswith_white_list_suffix(file):
  75. """
  76. :param file:
  77. :return:
  78. """
  79. if DEBUG:
  80. for suffix in white_list_for_debug:
  81. if file.endswith(suffix):
  82. return True
  83. return False
  84. else:
  85. return True
  86. """
  87. belows are patterns used for analyse .h file
  88. """
  89. # pattern function
  90. pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after
  91. ([a-zA-Z~_] # void int likely
  92. .*
  93. [)] #we find )
  94. (?!.*{) # we do not want the case int abc() const
  95. .*)
  96. (;.*) #we want to find ; and after for we will replace these later
  97. \n$
  98. """, re.VERBOSE | re.MULTILINE | re.DOTALL)
  99. # pattern comment
  100. pattern_comment = re.compile(r'^\s*//')
  101. pattern_comment_2_start = re.compile(r'^\s*/[*]')
  102. pattern_comment_2_end = re.compile(r'[*]/\s*$')
  103. # pattern define
  104. pattern_define = re.compile(r'^\s*#define')
  105. pattern_define_return = re.compile(r'\\\s*$')
  106. # blank line
  107. pattern_blank_line = re.compile(r'^\s*$')
  108. # virtual,explicit,friend,static
  109. pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
  110. # lead space
  111. pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
  112. # functions will have patterns such as func ( or func(
  113. # but operator is an exception; the class name is preceded by an operator, and the above mode does not exist
  114. # format like :"operator = ()"
  115. pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
  116. # template
  117. pattern_template = re.compile(r'^\s*template')
  118. pattern_template_end = re.compile(r'>\s*$')
  119. # namespace
  120. pattern_namespace = re.compile(r'namespace.*{')
  121. # class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with
  122. pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
  123. # {}
  124. pattern_start = re.compile('{')
  125. pattern_end = re.compile('}')
  126. line_index = 0
  127. class H2CC(object):
  128. def __init__(self, input_file, output_file, shared_includes_content):
  129. """
  130. :param input_file:
  131. :param output_file:
  132. :param shared_includes_content:
  133. """
  134. self.input_file = input_file
  135. self.output_file = output_file
  136. self.shared_includes_content = shared_includes_content
  137. self.line_index = 0
  138. self.input_fd = open(self.input_file, 'r')
  139. self.input_content = self.input_fd.readlines()
  140. self.output_fd = open(self.output_file, 'w')
  141. # The state may be normal_now(in the middle of {}),class_now,namespace_now
  142. self.stack = []
  143. self.stack_class = []
  144. self.stack_template = []
  145. # record funcs generated by h2cc func
  146. self.func_list_exist = []
  147. def __del__(self):
  148. self.input_fd.close()
  149. self.output_fd.close()
  150. del self.stack
  151. del self.stack_class
  152. del self.stack_template
  153. del self.func_list_exist
  154. def just_skip(self):
  155. # skip blank line or comment
  156. if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
  157. self.input_content[self.line_index]): # /n or comment using //
  158. self.line_index += 1
  159. if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /*
  160. while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */
  161. self.line_index += 1
  162. self.line_index += 1
  163. # skip define
  164. if pattern_define.search(self.input_content[self.line_index]):
  165. while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
  166. self.input_content[self.line_index]):
  167. self.line_index += 1
  168. self.line_index += 1
  169. def write_inc_content(self):
  170. for shared_include_content in self.shared_includes_content:
  171. self.output_fd.write(shared_include_content)
  172. def h2cc(self):
  173. """
  174. :return:
  175. """
  176. logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
  177. global pattern_comment
  178. global pattern_comment_2_start
  179. global pattern_comment_2_end
  180. global pattern_blank_line
  181. global pattern_func
  182. global pattern_keyword
  183. global pattern_leading_space
  184. global pattern_func_name
  185. global pattern_template
  186. global pattern_template_end
  187. global pattern_namespace
  188. global pattern_class
  189. global pattern_start
  190. global pattern_end
  191. global line_index
  192. # write inc content
  193. self.write_inc_content()
  194. # core processing cycle, process the input .h file by line
  195. while self.line_index < len(self.input_content):
  196. # handle comment and blank line
  197. self.just_skip()
  198. # match namespace
  199. self.handle_namespace()
  200. # match template
  201. template_string = self.handle_template()
  202. # match class
  203. line = self.input_content[self.line_index]
  204. match_class = pattern_class.search(line)
  205. match_start = pattern_start.search(line)
  206. handle_class_result = self.handle_class(template_string, line, match_start, match_class)
  207. if handle_class_result == "continue":
  208. continue
  209. # match "}"
  210. handle_stack_result = self.handle_stack(match_start)
  211. if handle_stack_result == "continue":
  212. continue
  213. # handle func
  214. handle_func1_result, line, start_i = self.handle_func1(line)
  215. if handle_func1_result == "continue":
  216. continue
  217. # here means func is found
  218. # delete key word
  219. line = pattern_keyword.sub('', line)
  220. logging.info("line[%s]", line)
  221. # Class member function
  222. # if friend we will not add class name
  223. friend_match = re.search('friend ', line)
  224. if len(self.stack_class) > 0 and not friend_match:
  225. line, func_name = self.handle_class_member_func(line, template_string)
  226. # Normal functions
  227. else:
  228. line, func_name = self.handle_normal_func(line, template_string)
  229. need_generate = need_generate_func(line)
  230. # func body
  231. line += self.implement_function(line)
  232. # comment
  233. line = self.gen_comment(start_i) + line
  234. # write to out file
  235. self.write_func_content(line, func_name, need_generate)
  236. # next loop
  237. self.line_index += 1
  238. logging.info('Added %s functions', len(self.func_list_exist))
  239. logging.info('Successfully converted,please see ' + self.output_file)
  240. def handle_func1(self, line):
  241. """
  242. :param line:
  243. :return:
  244. """
  245. find1 = re.search('[(]', line)
  246. if not find1:
  247. self.line_index += 1
  248. return "continue", line, None
  249. find2 = re.search('[)]', line)
  250. start_i = self.line_index
  251. space_match = pattern_leading_space.search(line)
  252. # deal with
  253. # int abc(int a,
  254. # int b)
  255. if find1 and (not find2):
  256. self.line_index += 1
  257. line2 = self.input_content[self.line_index]
  258. if space_match:
  259. line2 = re.sub('^' + space_match.group(1), '', line2)
  260. line += line2
  261. while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
  262. self.line_index += 1
  263. line2 = self.input_content[self.line_index]
  264. line2 = re.sub('^' + space_match.group(1), '', line2)
  265. line += line2
  266. match_start = pattern_start.search(self.input_content[self.line_index])
  267. match_end = pattern_end.search(self.input_content[self.line_index])
  268. if match_start: # like ) { or ) {} int the last line
  269. if not match_end:
  270. self.stack.append('normal_now')
  271. ii = start_i
  272. while ii <= self.line_index:
  273. ii += 1
  274. self.line_index += 1
  275. return "continue", line, start_i
  276. logging.info("line[%s]", line)
  277. # ' int abc();'->'int abc()'
  278. (line, match) = pattern_func.subn(r'\2\n', line)
  279. logging.info("line[%s]", line)
  280. # deal with case:
  281. # 'int \n abc(int a, int b)'
  282. if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
  283. line = self.input_content[start_i - 1] + line
  284. line = line.lstrip()
  285. if not match:
  286. self.line_index += 1
  287. return "continue", line, start_i
  288. return "pass", line, start_i
  289. def handle_stack(self, match_start):
  290. """
  291. :param match_start:
  292. :return:
  293. """
  294. line = self.input_content[self.line_index]
  295. match_end = pattern_end.search(line)
  296. if match_start:
  297. self.stack.append('normal_now')
  298. if match_end:
  299. top_status = self.stack.pop()
  300. if top_status == 'namespace_now':
  301. self.output_fd.write(line + '\n')
  302. elif top_status == 'class_now':
  303. self.stack_class.pop()
  304. self.stack_template.pop()
  305. if match_start or match_end:
  306. self.line_index += 1
  307. return "continue"
  308. if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
  309. self.line_index += 1
  310. return "continue"
  311. return "pass"
  312. def handle_class(self, template_string, line, match_start, match_class):
  313. """
  314. :param template_string:
  315. :param line:
  316. :param match_start:
  317. :param match_class:
  318. :return:
  319. """
  320. if match_class: # we face a class
  321. self.stack_template.append(template_string)
  322. self.stack.append('class_now')
  323. class_name = match_class.group(3)
  324. # class template specializations: class A<u,Node<u> >
  325. if '<' in class_name:
  326. k = line.index('<')
  327. fit = 1
  328. for ii in range(k + 1, len(line)):
  329. if line[ii] == '<':
  330. fit += 1
  331. if line[ii] == '>':
  332. fit -= 1
  333. if fit == 0:
  334. break
  335. class_name += line[k + 1:ii + 1]
  336. logging.info('class_name[%s]', class_name)
  337. self.stack_class.append(class_name)
  338. while not match_start:
  339. self.line_index += 1
  340. line = self.input_content[self.line_index]
  341. match_start = pattern_start.search(line)
  342. self.line_index += 1
  343. return "continue"
  344. return "pass"
  345. def handle_template(self):
  346. line = self.input_content[self.line_index]
  347. match_template = pattern_template.search(line)
  348. template_string = ''
  349. if match_template:
  350. match_template_end = pattern_template_end.search(line)
  351. template_string = line
  352. while not match_template_end:
  353. self.line_index += 1
  354. line = self.input_content[self.line_index]
  355. template_string += line
  356. match_template_end = pattern_template_end.search(line)
  357. self.line_index += 1
  358. return template_string
  359. def handle_namespace(self):
  360. line = self.input_content[self.line_index]
  361. match_namespace = pattern_namespace.search(line)
  362. if match_namespace: # we face namespace
  363. self.output_fd.write(line + '\n')
  364. self.stack.append('namespace_now')
  365. self.line_index += 1
  366. def handle_normal_func(self, line, template_string):
  367. template_line = ''
  368. self.stack_template.append(template_string)
  369. if self.stack_template[-1] != '':
  370. template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
  371. # change '< class T = a, class U = A(3)>' to '<class T, class U>'
  372. template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
  373. template_line = re.sub(r'\s*=.*,', ',', template_line)
  374. template_line = re.sub(r'\s*=.*', '', template_line)
  375. line = re.sub(r'\s*=.*,', ',', line)
  376. line = re.sub(r'\s*=.*\)', ')', line)
  377. line = template_line + line
  378. self.stack_template.pop()
  379. func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
  380. logging.info("line[%s]", line)
  381. logging.info("func_name[%s]", func_name)
  382. return line, func_name
  383. def handle_class_member_func(self, line, template_string):
  384. template_line = ''
  385. x = ''
  386. if template_string != '':
  387. template_string = re.sub(r'\s*template', 'template', template_string)
  388. template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
  389. template_string = re.sub(r'\s*=.*,', ',', template_string)
  390. template_string = re.sub(r'\s*=.*', '', template_string)
  391. if self.stack_template[-1] != '':
  392. if not (re.search(r'<\s*>', stack_template[-1])):
  393. template_line = re.sub(r'^\s*template', 'template', stack_template[-1])
  394. if not (re.search(r'<.*>', self.stack_class[-1])):
  395. # for x we get like template<class T, typename U> -> <T,U>
  396. x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U>
  397. x = re.sub(r'\n', '', x)
  398. x = re.sub(r'\s*=.*,', ',', x)
  399. x = re.sub(r'\s*=.*\>', '>', x)
  400. x = x.rstrip() # remove \n
  401. x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
  402. x) # remove class,typename -> <T, U>
  403. x = re.sub(r'<\s+', '<', x)
  404. x = re.sub(r'\s+>', '>', x)
  405. x = re.sub(r'\s+,', ',', x)
  406. x = re.sub(r',\s+', ', ', x)
  407. line = re.sub(r'\s*=\s+0', '', line)
  408. line = re.sub(r'\s*=\s+.*,', ',', line)
  409. line = re.sub(r'\s*=\s+.*\)', ')', line)
  410. logging.info("x[%s]\nline[%s]", x, line)
  411. # if the function is long, void ABC::foo()
  412. # breaks into two lines void ABC::\n foo()
  413. temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
  414. if len(temp_line) > max_code_len_per_line:
  415. line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
  416. else:
  417. line = temp_line
  418. logging.info("line[%s]", line)
  419. # add template as the above if there is one
  420. template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
  421. template_line = re.sub(r'\s*=.*,', ',', template_line)
  422. template_line = re.sub(r'\s*=.*', '', template_line)
  423. line = template_line + template_string + line
  424. func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
  425. logging.info("line[%s]", line)
  426. logging.info("func_name[%s]", func_name)
  427. return line, func_name
  428. def write_func_content(self, content, func_name, need_generate):
  429. if not (func_name in self.func_list_exist) and need_generate:
  430. self.output_fd.write(content)
  431. self.func_list_exist.append(func_name)
  432. logging.info('add func:[%s]', func_name)
  433. def gen_comment(self, start_i):
  434. comment_line = ''
  435. # Function comments are on top of function declarations, copy them over
  436. k = start_i - 1 # one line before this func start
  437. if pattern_template.search(self.input_content[k]):
  438. k -= 1
  439. if pattern_comment_2_end.search(self.input_content[k]):
  440. comment_line = self.input_content[k].lstrip()
  441. while not pattern_comment_2_start.search(self.input_content[k]):
  442. k -= 1
  443. comment_line = self.input_content[k].lstrip() + comment_line
  444. else:
  445. for j in range(k, 0, -1):
  446. c_line = self.input_content[j]
  447. if pattern_comment.search(c_line):
  448. c_line = re.sub(r'\s*//', '//', c_line)
  449. comment_line = c_line + comment_line
  450. else:
  451. break
  452. return comment_line
  453. @staticmethod
  454. def implement_function(func):
  455. function_def = ''
  456. function_def += '{\n'
  457. all_items = func.split()
  458. start = 0
  459. return_type = all_items[start]
  460. if return_type == "const":
  461. start += 1
  462. return_type = all_items[start]
  463. if return_type.startswith(('std::map', 'std::set', 'std::vector')):
  464. return_type = "std::map"
  465. if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
  466. return_type = "Ptr"
  467. if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
  468. return_type += "&"
  469. if RETURN_STATEMENTS.__contains__(return_type):
  470. function_def += RETURN_STATEMENTS[return_type]
  471. else:
  472. logging.warning("Unhandled return type[%s]", return_type)
  473. function_def += '\n'
  474. function_def += '}\n'
  475. function_def += '\n'
  476. return function_def
  477. def collect_header_files(path):
  478. """
  479. :param path:
  480. :return:
  481. """
  482. header_files = []
  483. shared_includes_content = []
  484. for root, dirs, files in os.walk(path):
  485. files.sort()
  486. for file in files:
  487. if file.find("git") >= 0:
  488. continue
  489. if not file.endswith('.h'):
  490. continue
  491. file_path = os.path.join(root, file)
  492. file_path = file_path.replace('\\', '/')
  493. header_files.append(file_path)
  494. include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
  495. shared_includes_content.append(include_str)
  496. # for acl error code
  497. shared_includes_content.append('#include <iostream>\n')
  498. shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n')
  499. return header_files, shared_includes_content
  500. def generate_stub_file(inc_dir, out_cc_dir):
  501. """
  502. :param inc_dir:
  503. :param out_cc_dir:
  504. :return:
  505. """
  506. target_header_files, shared_includes_content = collect_header_files(inc_dir)
  507. for header_file in target_header_files:
  508. if not file_endswith_white_list_suffix(header_file):
  509. continue
  510. cc_file = re.sub('.h*$', '.cc', header_file)
  511. h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
  512. h_2_cc.h2cc()
  513. def gen_code(inc_dir, out_cc_dir):
  514. """
  515. :param inc_dir:
  516. :param out_cc_dir:
  517. :return:
  518. """
  519. if not inc_dir.endswith('/'):
  520. inc_dir += '/'
  521. if not out_cc_dir.endswith('/'):
  522. out_cc_dir += '/'
  523. for include_dir_key_word in include_dir_key_words:
  524. generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir)
  525. if __name__ == '__main__':
  526. inc_dir = sys.argv[1]
  527. out_cc_dir = sys.argv[2]
  528. gen_code(inc_dir, out_cc_dir)

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示