You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_ops.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from io import StringIO
  10. import re
  11. import argparse
  12. import subprocess
  13. import os
  14. import textwrap
  15. import inspect
  16. def camel2underscore(
  17. name, *,
  18. first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
  19. all_cap_re = re.compile('([a-z])([A-Z]+)')):
  20. if name.isupper():
  21. return name.lower()
  22. s1 = first_cap_re.sub(r'\1_\2', name)
  23. return all_cap_re.sub(r'\1_\2', s1).lower()
  24. def caller_lineno(level=1):
  25. f = inspect.stack()[level+1]
  26. return '%s:%d' % (f.filename, f.lineno)
  27. class Doc:
  28. """wrap an identifier and doc"""
  29. _id = None
  30. def __init__(self, id_, doc, typestr=None, default=None):
  31. self._id = id_
  32. self.doc = doc
  33. self.typestr = typestr
  34. self.default = default
  35. def __str__(self):
  36. return self._id
  37. class Context:
  38. fout = None
  39. def __init__(self):
  40. self.fout = StringIO()
  41. self.indent = 0
  42. self.skipped = []
  43. self.generated_signature = set()
  44. self.generated_opr = dict()
  45. def write(self, text, *fmt, indent=0):
  46. text = textwrap.dedent(text)
  47. text = textwrap.indent(text, ' '*4*(self.indent + indent))
  48. text = text % fmt
  49. if not text.endswith('\n'):
  50. text += '\n'
  51. self.fout.write(text)
  52. def _gen_signature(self, params, *, have_config=True,
  53. has_out_dtype=False):
  54. sig = ['self', '*']
  55. for i, _ in params:
  56. sig.append('{}=None'.format(i))
  57. if have_config:
  58. sig.extend(['name=None', 'comp_node=None', 'config=None'])
  59. if has_out_dtype:
  60. sig.append('dtype=None')
  61. if params:
  62. sig.append('**kwargs')
  63. if sig[-1] == '*':
  64. sig.pop()
  65. return ', '.join(sig)
  66. def _write_canonize_inputs(self, inputs, convert_inputs,
  67. convert_inputs_args=None,
  68. has_out_dtype=False):
  69. self._write_gen_config(has_out_dtype)
  70. inputs = list(map(str, inputs))
  71. if convert_inputs_args is None:
  72. if inputs[0][0] == '*':
  73. arg = inputs[0][1:]
  74. else:
  75. arg = '[{}]'.format(', '.join(inputs))
  76. else:
  77. arg = convert_inputs_args
  78. self.write('inputs = helper.%s(%s, config=config)',
  79. convert_inputs, arg)
  80. def _write_gen_config(self, has_out_dtype=False):
  81. self.write('''\
  82. config = config or Config()
  83. if name:
  84. config.name = name
  85. if comp_node:
  86. config.comp_node = comp_node
  87. ''')
  88. if has_out_dtype:
  89. self.write('''\
  90. if dtype:
  91. config.dtype = dtype
  92. ''')
  93. self.write('self.config = config')
  94. def _write_make_params(self, params):
  95. for pname, ptype in params:
  96. self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)',
  97. pname, pname, ptype)
  98. self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)')
  99. def _write_doc(self, inputs, params, desc):
  100. self.write('"""')
  101. if isinstance(desc, Doc):
  102. assert desc._id is None
  103. self.write(desc.doc)
  104. elif desc:
  105. for i in textwrap.wrap(desc, 75):
  106. self.write(i)
  107. self.write('')
  108. for i in inputs:
  109. name = str(i)
  110. typestr = ':class:`.Tensor`'
  111. if name[0] == '*':
  112. name = name[1:]
  113. typestr = 'list of ' + typestr
  114. if isinstance(i, Doc):
  115. self.write(':param %s: %s', name, i.doc)
  116. if i.typestr is not None:
  117. typestr = i.typestr
  118. if typestr:
  119. if not isinstance(i, Doc):
  120. self.write(':param %s: ', name)
  121. self.write(':type %s: %s', name, typestr)
  122. for pname, ptype in params:
  123. self.write(':param %s: ', pname)
  124. self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`',
  125. pname, ptype)
  126. self.write(':param comp_node: see doc for *config*')
  127. self.write(':param name: see doc for *config*')
  128. self.write(
  129. ':param config: give a :class:`.OperatorNodeConfig` object to set '
  130. 'operator name and comp node. This can also be achieved by passing '
  131. '*comp_node* and *name* separately.')
  132. self.write('"""')
  133. def _write_return(self, name, outputs):
  134. self.write('opdef = helper.PodOpVisitor("%s", config, params)', name)
  135. self.write('outputs = helper.create_op(opdef, inputs)')
  136. if outputs:
  137. self.write('outputs = [outputs[i] for i in %s]',
  138. list(map(int, outputs)))
  139. self.write('return helper.convert_outputs(outputs)')
  140. def decl_opr(self, name, *, inputs, params, desc=None, pyname=None,
  141. canonize_input_vars=None,
  142. canonize_input_vars_args=None, body=None,
  143. outputs=None, version=0, has_out_dtype=False):
  144. """
  145. :param inputs: name of variable inputs; a name starting with `*' means
  146. a list of vars
  147. :type inputs: list of str
  148. :param params: (param name, param type) pairs; it can be a single
  149. string representing the param type, and param name defaults to
  150. 'param'
  151. :type params: list of pair of str, or str
  152. :param pyname: python function name
  153. :param body: extra statements to be placed before calling _create_opr
  154. :param outputs: the indices of output vars to be selected from raw opr
  155. result
  156. """
  157. class OprItem:
  158. def __init__(self, inputs, desc, params, version, has_out_dtype):
  159. self.inputs = inputs
  160. self.desc = desc
  161. self.params = params
  162. self.version = version
  163. self.has_out_dtype = has_out_dtype
  164. if body:
  165. self.skipped.append(name)
  166. return
  167. signature = (name, params if isinstance(params, str) else frozenset(params), has_out_dtype, version)
  168. if signature in self.generated_signature:
  169. self.skipped.append(name)
  170. return
  171. else:
  172. self.generated_signature.add(signature)
  173. body = body or []
  174. if isinstance(params, str):
  175. params = [('param', params)]
  176. assert params
  177. if name in self.generated_opr:
  178. org_opr = self.generated_opr[name]
  179. if version > org_opr.version:
  180. def compare_doc(a, b):
  181. if isinstance(a, str):
  182. return a == b
  183. else:
  184. assert isinstance(a, Doc)
  185. return a.doc == b.doc
  186. assert compare_doc(desc, org_opr.desc)
  187. assert len(inputs) == len(org_opr.inputs)
  188. for i, j in zip(inputs, org_opr.inputs):
  189. assert compare_doc(i, j)
  190. self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
  191. else:
  192. self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
  193. def write_generated_oprs(self):
  194. for opr, opr_item in self.generated_opr.items():
  195. name = opr
  196. params = opr_item.params
  197. version = opr_item.version
  198. has_out_dtype = opr_item.has_out_dtype
  199. self.write('# %s', caller_lineno())
  200. self.write('class %s(PodOpVisitor):', name)
  201. self.indent += 1
  202. param_names, _ = zip(*params)
  203. self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
  204. self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
  205. self.write('\n')
  206. self.write('def __init__(%s):',
  207. self._gen_signature(params,
  208. has_out_dtype=has_out_dtype))
  209. self.indent += 1
  210. self._write_gen_config(has_out_dtype=has_out_dtype)
  211. self.write('\n')
  212. self._write_make_params(params)
  213. self.write('\n')
  214. self.indent -= 2
  215. def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
  216. desc=None, local_defs=[], have_config=True, params=None, has_out_dtype=False):
  217. self.skipped.append(name)
  218. def get_str(self):
  219. return self.fout.getvalue()
  220. def all_list(self):
  221. buf = StringIO()
  222. print(
  223. '[',
  224. *(' "%s",' % i for i in self.generated_opr),
  225. ']',
  226. sep='\n',
  227. file=buf
  228. )
  229. return buf.getvalue()
  230. def main():
  231. parser = argparse.ArgumentParser(
  232. description='generate operator function def code from decl file')
  233. parser.add_argument('inputs', nargs='+')
  234. parser.add_argument('--output', '-o')
  235. args = parser.parse_args()
  236. gen = Context()
  237. exec_globals = {
  238. 'decl_opr': gen.decl_opr,
  239. 'decl_raw_opr': gen.decl_raw_opr,
  240. 'Doc': Doc,
  241. 'camel2underscore': camel2underscore,
  242. }
  243. for i in args.inputs:
  244. print('generate ops from {}'.format(i))
  245. with open(i) as fin:
  246. exec(compile(fin.read(), i, 'exec'), exec_globals)
  247. gen.write_generated_oprs()
  248. try:
  249. git_commit = subprocess.check_output(
  250. ['git', 'rev-parse', 'HEAD'], universal_newlines=True,
  251. cwd=os.path.dirname(os.path.realpath(__file__))).strip()
  252. except:
  253. git_commit = 'NOT_A_GIT_REPO'
  254. def relpath(*args):
  255. d = os.path.dirname(__file__)
  256. return os.path.join(d, *args)
  257. with open(relpath('ops.tpl.py')) as fin:
  258. with open(args.output, 'w') as fout:
  259. fout.write(fin.read()
  260. .replace('{%all%}', gen.all_list())
  261. .replace('{%body%}', gen.get_str())
  262. .replace('{%git_commit%}', git_commit))
  263. print('Skipped:')
  264. print(*gen.skipped, sep='\n')
  265. if __name__ == '__main__':
  266. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。