|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from io import StringIO
- import re
- import argparse
- import subprocess
- import os
- import textwrap
- import inspect
-
- def camel2underscore(
- name, *,
- first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
- all_cap_re = re.compile('([a-z])([A-Z]+)')):
- if name.isupper():
- return name.lower()
- s1 = first_cap_re.sub(r'\1_\2', name)
- return all_cap_re.sub(r'\1_\2', s1).lower()
-
-
- def caller_lineno(level=1):
- f = inspect.stack()[level+1]
- return '%s:%d' % (f.filename, f.lineno)
-
-
- class Doc:
- """wrap an identifier and doc"""
- _id = None
-
- def __init__(self, id_, doc, typestr=None, default=None):
- self._id = id_
- self.doc = doc
- self.typestr = typestr
- self.default = default
-
- def __str__(self):
- return self._id
-
-
- class Context:
- fout = None
-
- def __init__(self):
- self.fout = StringIO()
- self.indent = 0
- self.skipped = []
- self.generated_signature = set()
- self.generated_opr = dict()
-
- def write(self, text, *fmt, indent=0):
- text = textwrap.dedent(text)
- text = textwrap.indent(text, ' '*4*(self.indent + indent))
- text = text % fmt
- if not text.endswith('\n'):
- text += '\n'
- self.fout.write(text)
-
- def _gen_signature(self, params, *, have_config=True,
- has_out_dtype=False):
- sig = ['self', '*']
-
- for i, _ in params:
- sig.append('{}=None'.format(i))
-
- if have_config:
- sig.extend(['name=None', 'comp_node=None', 'config=None'])
- if has_out_dtype:
- sig.append('dtype=None')
-
- if params:
- sig.append('**kwargs')
-
- if sig[-1] == '*':
- sig.pop()
- return ', '.join(sig)
-
- def _write_canonize_inputs(self, inputs, convert_inputs,
- convert_inputs_args=None,
- has_out_dtype=False):
- self._write_gen_config(has_out_dtype)
- inputs = list(map(str, inputs))
- if convert_inputs_args is None:
- if inputs[0][0] == '*':
- arg = inputs[0][1:]
- else:
- arg = '[{}]'.format(', '.join(inputs))
- else:
- arg = convert_inputs_args
- self.write('inputs = helper.%s(%s, config=config)',
- convert_inputs, arg)
-
- def _write_gen_config(self, has_out_dtype=False):
- self.write('''\
- config = config or Config()
- if name:
- config.name = name
- if comp_node:
- config.comp_node = comp_node
- ''')
- if has_out_dtype:
- self.write('''\
- if dtype:
- config.dtype = dtype
- ''')
- self.write('self.config = config')
-
- def _write_make_params(self, params):
- for pname, ptype in params:
- self.write('self.%s = helper.make_param(%s, param_defs.%s, kwargs)',
- pname, pname, ptype)
- self.write('assert not kwargs, "extra kwargs: {}".format(kwargs)')
-
- def _write_doc(self, inputs, params, desc):
- self.write('"""')
- if isinstance(desc, Doc):
- assert desc._id is None
- self.write(desc.doc)
- elif desc:
- for i in textwrap.wrap(desc, 75):
- self.write(i)
-
- self.write('')
- for i in inputs:
- name = str(i)
- typestr = ':class:`.Tensor`'
- if name[0] == '*':
- name = name[1:]
- typestr = 'list of ' + typestr
- if isinstance(i, Doc):
- self.write(':param %s: %s', name, i.doc)
- if i.typestr is not None:
- typestr = i.typestr
- if typestr:
- if not isinstance(i, Doc):
- self.write(':param %s: ', name)
- self.write(':type %s: %s', name, typestr)
-
- for pname, ptype in params:
- self.write(':param %s: ', pname)
- self.write(':type %s: :class:`~megbrain.opr_param_defs.%s`',
- pname, ptype)
-
- self.write(':param comp_node: see doc for *config*')
- self.write(':param name: see doc for *config*')
- self.write(
- ':param config: give a :class:`.OperatorNodeConfig` object to set '
- 'operator name and comp node. This can also be achieved by passing '
- '*comp_node* and *name* separately.')
-
- self.write('"""')
-
- def _write_return(self, name, outputs):
- self.write('opdef = helper.PodOpVisitor("%s", config, params)', name)
- self.write('outputs = helper.create_op(opdef, inputs)')
- if outputs:
- self.write('outputs = [outputs[i] for i in %s]',
- list(map(int, outputs)))
- self.write('return helper.convert_outputs(outputs)')
-
- def decl_opr(self, name, *, inputs, params, desc=None, pyname=None,
- canonize_input_vars=None,
- canonize_input_vars_args=None, body=None,
- outputs=None, version=0, has_out_dtype=False):
- """
- :param inputs: name of variable inputs; a name starting with `*' means
- a list of vars
- :type inputs: list of str
- :param params: (param name, param type) pairs; it can be a single
- string representing the param type, and param name defaults to
- 'param'
- :type params: list of pair of str, or str
- :param pyname: python function name
- :param body: extra statements to be placed before calling _create_opr
- :param outputs: the indices of output vars to be selected from raw opr
- result
- """
-
- class OprItem:
- def __init__(self, inputs, desc, params, version, has_out_dtype):
- self.inputs = inputs
- self.desc = desc
- self.params = params
- self.version = version
- self.has_out_dtype = has_out_dtype
-
- if body:
- self.skipped.append(name)
- return
-
- signature = (name, params if isinstance(params, str) else frozenset(params), has_out_dtype, version)
- if signature in self.generated_signature:
- self.skipped.append(name)
- return
- else:
- self.generated_signature.add(signature)
-
- body = body or []
- if isinstance(params, str):
- params = [('param', params)]
- assert params
-
- if name in self.generated_opr:
- org_opr = self.generated_opr[name]
- if version > org_opr.version:
- def compare_doc(a, b):
- if isinstance(a, str):
- return a == b
- else:
- assert isinstance(a, Doc)
- return a.doc == b.doc
-
- assert compare_doc(desc, org_opr.desc)
- assert len(inputs) == len(org_opr.inputs)
- for i, j in zip(inputs, org_opr.inputs):
- assert compare_doc(i, j)
-
- self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
- else:
- self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
-
- def write_generated_oprs(self):
-
- for opr, opr_item in self.generated_opr.items():
-
- name = opr
- params = opr_item.params
- version = opr_item.version
- has_out_dtype = opr_item.has_out_dtype
-
- self.write('# %s', caller_lineno())
- self.write('class %s(PodOpVisitor):', name)
- self.indent += 1
-
- param_names, _ = zip(*params)
- self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
- self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
- self.write('\n')
-
- self.write('def __init__(%s):',
- self._gen_signature(params,
- has_out_dtype=has_out_dtype))
- self.indent += 1
-
- self._write_gen_config(has_out_dtype=has_out_dtype)
- self.write('\n')
-
- self._write_make_params(params)
-
- self.write('\n')
- self.indent -= 2
-
-
- def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
- desc=None, local_defs=[], have_config=True, params=None, has_out_dtype=False):
- self.skipped.append(name)
-
- def get_str(self):
- return self.fout.getvalue()
-
- def all_list(self):
- buf = StringIO()
- print(
- '[',
- *(' "%s",' % i for i in self.generated_opr),
- ']',
- sep='\n',
- file=buf
- )
- return buf.getvalue()
-
-
- def main():
- parser = argparse.ArgumentParser(
- description='generate operator function def code from decl file')
- parser.add_argument('inputs', nargs='+')
- parser.add_argument('--output', '-o')
- args = parser.parse_args()
-
- gen = Context()
- exec_globals = {
- 'decl_opr': gen.decl_opr,
- 'decl_raw_opr': gen.decl_raw_opr,
- 'Doc': Doc,
- 'camel2underscore': camel2underscore,
- }
- for i in args.inputs:
- print('generate ops from {}'.format(i))
- with open(i) as fin:
- exec(compile(fin.read(), i, 'exec'), exec_globals)
-
- gen.write_generated_oprs()
- try:
- git_commit = subprocess.check_output(
- ['git', 'rev-parse', 'HEAD'], universal_newlines=True,
- cwd=os.path.dirname(os.path.realpath(__file__))).strip()
- except:
- git_commit = 'NOT_A_GIT_REPO'
-
- def relpath(*args):
- d = os.path.dirname(__file__)
- return os.path.join(d, *args)
-
- with open(relpath('ops.tpl.py')) as fin:
- with open(args.output, 'w') as fout:
- fout.write(fin.read()
- .replace('{%all%}', gen.all_list())
- .replace('{%body%}', gen.get_str())
- .replace('{%git_commit%}', git_commit))
-
- print('Skipped:')
- print(*gen.skipped, sep='\n')
-
- if __name__ == '__main__':
- main()
|