You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gen_ops.py 8.8 kB


  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from io import StringIO
  10. import re
  11. import argparse
  12. import subprocess
  13. import os
  14. import textwrap
  15. import inspect
  16. def camel2underscore(
  17. name, *,
  18. first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
  19. all_cap_re = re.compile('([a-z])([A-Z]+)')):
  20. if name.isupper():
  21. return name.lower()
  22. s1 = first_cap_re.sub(r'\1_\2', name)
  23. return all_cap_re.sub(r'\1_\2', s1).lower()
  24. def caller_lineno(level=1):
  25. f = inspect.stack()[level+1]
  26. return '%s:%d' % (f.filename, f.lineno)
  27. class Doc:
  28. """wrap an identifier and doc"""
  29. _id = None
  30. def __init__(self, id_, doc, typestr=None, default=None):
  31. self._id = id_
  32. self.doc = doc
  33. self.typestr = typestr
  34. self.default = default
  35. def __str__(self):
  36. return self._id
  37. class Context:
  38. fout = None
  39. def __init__(self):
  40. self.fout = StringIO()
  41. self.indent = 0
  42. self.generated = []
  43. self.skipped = []
  44. def write(self, text, *fmt, indent=0):
  45. text = textwrap.dedent(text)
  46. text = textwrap.indent(text, ' '*4*(self.indent + indent))
  47. text = text % fmt
  48. if not text.endswith('\n'):
  49. text += '\n'
  50. self.fout.write(text)
  51. def _gen_signature(self, params, *, have_config=True,
  52. has_out_dtype=False):
  53. sig = ['self', '*']
  54. for i, _ in params:
  55. sig.append('{}=None'.format(i))
  56. if have_config:
  57. sig.extend(['name=None', 'comp_node=None', 'config=None'])
  58. if has_out_dtype:
  59. sig.append('dtype=None')
  60. if params:
  61. sig.append('**kwargs')
  62. if sig[-1] == '*':
  63. sig.pop()
  64. return ', '.join(sig)
  65. def _write_canonize_inputs(self, inputs, convert_inputs,
  66. convert_inputs_args=None,
  67. has_out_dtype=False):
  68. self._write_gen_config(has_out_dtype)
  69. inputs = list(map(str, inputs))
  70. if convert_inputs_args is None:
  71. if inputs[0][0] == '*':
  72. arg = inputs[0][1:]
  73. else:
  74. arg = '[{}]'.format(', '.join(inputs))
  75. else:
  76. arg = convert_inputs_args
  77. self.write('inputs = helper.%s(%s, config=config)',
  78. convert_inputs, arg)
  79. def _write_gen_config(self, has_out_dtype=False):
  80. self.write('''\
  81. config = config or Config()
  82. if name:
  83. config.name = name
  84. if comp_node:
  85. config.comp_node = comp_node
  86. ''')
  87. if has_out_dtype:
  88. self.write('''\
  89. if dtype:
  90. config.dtype = dtype
  91. ''')
  92. self.write('self.config = config')
  93. def _write_make_params(self, params):
  94. for pname, ptype in params:
  95. self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)',
  96. pname, pname, ptype)
  97. self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)')
  98. def _write_doc(self, inputs, params, desc):
  99. self.write('"""')
  100. if isinstance(desc, Doc):
  101. assert desc._id is None
  102. self.write(desc.doc)
  103. elif desc:
  104. for i in textwrap.wrap(desc, 75):
  105. self.write(i)
  106. self.write('')
  107. for i in inputs:
  108. name = str(i)
  109. typestr = ':class:`.Tensor`'
  110. if name[0] == '*':
  111. name = name[1:]
  112. typestr = 'list of ' + typestr
  113. if isinstance(i, Doc):
  114. self.write(':param %s: %s', name, i.doc)
  115. if i.typestr is not None:
  116. typestr = i.typestr
  117. if typestr:
  118. if not isinstance(i, Doc):
  119. self.write(':param %s: ', name)
  120. self.write(':type %s: %s', name, typestr)
  121. for pname, ptype in params:
  122. self.write(':param %s: ', pname)
  123. self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`',
  124. pname, ptype)
  125. self.write(':param comp_node: see doc for *config*')
  126. self.write(':param name: see doc for *config*')
  127. self.write(
  128. ':param config: give a :class:`.OperatorNodeConfig` object to set '
  129. 'operator name and comp node. This can also be achieved by passing '
  130. '*comp_node* and *name* separately.')
  131. self.write('"""')
  132. def _write_return(self, name, outputs):
  133. self.write('opdef = helper.PodOpVisitor("%s", config, params)', name)
  134. self.write('outputs = helper.create_op(opdef, inputs)')
  135. if outputs:
  136. self.write('outputs = [outputs[i] for i in %s]',
  137. list(map(int, outputs)))
  138. self.write('return helper.convert_outputs(outputs)')
  139. def decl_opr(self, name, *, inputs, params, desc=None, pyname=None,
  140. canonize_input_vars=None,
  141. canonize_input_vars_args=None, body=None,
  142. outputs=None, version=0, has_out_dtype=False):
  143. """
  144. :param inputs: name of variable inputs; a name starting with `*' means
  145. a list of vars
  146. :type inputs: list of str
  147. :param params: (param name, param type) pairs; it can be a single
  148. string representing the param type, and param name defaults to
  149. 'param'
  150. :type params: list of pair of str, or str
  151. :param pyname: python function name
  152. :param body: extra statements to be placed before calling _create_opr
  153. :param outputs: the indices of output vars to be selected from raw opr
  154. result
  155. """
  156. if body:
  157. self.skipped.append(name)
  158. return
  159. body = body or []
  160. if isinstance(params, str):
  161. params = [('param', params)]
  162. assert params
  163. self.write('# %s', caller_lineno())
  164. self.write('class %s(PodOpVisitor):', name)
  165. self.indent += 1
  166. param_names, _ = zip(*params)
  167. self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
  168. self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
  169. self.write('\n')
  170. self.write('def __init__(%s):',
  171. self._gen_signature(params,
  172. has_out_dtype=has_out_dtype))
  173. self.indent += 1
  174. self._write_gen_config(has_out_dtype=has_out_dtype)
  175. self.write('\n')
  176. self._write_make_params(params)
  177. self.write('\n')
  178. self.indent -= 2
  179. self.generated.append(name)
  180. def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
  181. desc=None, local_defs=[], have_config=True):
  182. self.skipped.append(name)
  183. def get_str(self):
  184. return self.fout.getvalue()
  185. def all_list(self):
  186. buf = StringIO()
  187. print(
  188. '[',
  189. *(' "%s",' % i for i in self.generated),
  190. ']',
  191. sep='\n',
  192. file=buf
  193. )
  194. return buf.getvalue()
  195. def main():
  196. parser = argparse.ArgumentParser(
  197. description='generate operator function def code from decl file')
  198. parser.add_argument('inputs', nargs='+')
  199. parser.add_argument('--output', '-o')
  200. args = parser.parse_args()
  201. gen = Context()
  202. exec_globals = {
  203. 'decl_opr': gen.decl_opr,
  204. 'decl_raw_opr': gen.decl_raw_opr,
  205. 'Doc': Doc,
  206. 'camel2underscore': camel2underscore,
  207. }
  208. for i in args.inputs:
  209. print('generate ops from {}'.format(i))
  210. with open(i) as fin:
  211. exec(compile(fin.read(), i, 'exec'), exec_globals)
  212. try:
  213. git_commit = subprocess.check_output(
  214. ['git', 'rev-parse', 'HEAD'], universal_newlines=True,
  215. cwd=os.path.dirname(os.path.realpath(__file__))).strip()
  216. except:
  217. git_commit = 'NOT_A_GIT_REPO'
  218. def relpath(*args):
  219. d = os.path.dirname(__file__)
  220. return os.path.join(d, *args)
  221. with open(relpath('ops.tpl.py')) as fin:
  222. with open(args.output, 'w') as fout:
  223. fout.write(fin.read()
  224. .replace('{%all%}', gen.all_list())
  225. .replace('{%body%}', gen.get_str())
  226. .replace('{%git_commit%}', git_commit))
  227. print('Skipped:')
  228. print(*gen.skipped, sep='\n')
# Run the generator only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台