|
|
@@ -14,7 +14,6 @@ import os |
|
|
|
import textwrap |
|
|
|
import inspect |
|
|
|
|
|
|
|
|
|
|
|
def camel2underscore( |
|
|
|
name, *, |
|
|
|
first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'), |
|
|
@@ -50,9 +49,9 @@ class Context: |
|
|
|
def __init__(self): |
|
|
|
self.fout = StringIO() |
|
|
|
self.indent = 0 |
|
|
|
self.generated = [] |
|
|
|
self.skipped = [] |
|
|
|
self.generated_signature = set() |
|
|
|
self.generated_opr = dict() |
|
|
|
|
|
|
|
def write(self, text, *fmt, indent=0): |
|
|
|
text = textwrap.dedent(text) |
|
|
@@ -181,6 +180,15 @@ class Context: |
|
|
|
:param outputs: the indices of output vars to be selected from raw opr |
|
|
|
result |
|
|
|
""" |
|
|
|
|
|
|
|
class OprItem: |
|
|
|
def __init__(self, inputs, desc, params, version, has_out_dtype): |
|
|
|
self.inputs = inputs |
|
|
|
self.desc = desc |
|
|
|
self.params = params |
|
|
|
self.version = version |
|
|
|
self.has_out_dtype = has_out_dtype |
|
|
|
|
|
|
|
if body: |
|
|
|
self.skipped.append(name) |
|
|
|
return |
|
|
@@ -197,29 +205,56 @@ class Context: |
|
|
|
params = [('param', params)] |
|
|
|
assert params |
|
|
|
|
|
|
|
self.write('# %s', caller_lineno()) |
|
|
|
self.write('class %s(PodOpVisitor):', name) |
|
|
|
self.indent += 1 |
|
|
|
if name in self.generated_opr: |
|
|
|
org_opr = self.generated_opr[name] |
|
|
|
if version > org_opr.version: |
|
|
|
def compare_doc(a, b): |
|
|
|
if isinstance(a, str): |
|
|
|
return a == b |
|
|
|
else: |
|
|
|
assert isinstance(a, Doc) |
|
|
|
return a.doc == b.doc |
|
|
|
|
|
|
|
assert compare_doc(desc, org_opr.desc) |
|
|
|
assert len(inputs) == len(org_opr.inputs) |
|
|
|
for i, j in zip(inputs, org_opr.inputs): |
|
|
|
assert compare_doc(i, j) |
|
|
|
|
|
|
|
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) |
|
|
|
else: |
|
|
|
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype) |
|
|
|
|
|
|
|
def write_generated_oprs(self): |
|
|
|
|
|
|
|
for opr, opr_item in self.generated_opr.items(): |
|
|
|
|
|
|
|
name = opr |
|
|
|
params = opr_item.params |
|
|
|
version = opr_item.version |
|
|
|
has_out_dtype = opr_item.has_out_dtype |
|
|
|
|
|
|
|
self.write('# %s', caller_lineno()) |
|
|
|
self.write('class %s(PodOpVisitor):', name) |
|
|
|
self.indent += 1 |
|
|
|
|
|
|
|
param_names, _ = zip(*params) |
|
|
|
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names))) |
|
|
|
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name) |
|
|
|
self.write('\n') |
|
|
|
param_names, _ = zip(*params) |
|
|
|
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names))) |
|
|
|
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name) |
|
|
|
self.write('\n') |
|
|
|
|
|
|
|
self.write('def __init__(%s):', |
|
|
|
self._gen_signature(params, |
|
|
|
has_out_dtype=has_out_dtype)) |
|
|
|
self.indent += 1 |
|
|
|
self.write('def __init__(%s):', |
|
|
|
self._gen_signature(params, |
|
|
|
has_out_dtype=has_out_dtype)) |
|
|
|
self.indent += 1 |
|
|
|
|
|
|
|
self._write_gen_config(has_out_dtype=has_out_dtype) |
|
|
|
self.write('\n') |
|
|
|
self._write_gen_config(has_out_dtype=has_out_dtype) |
|
|
|
self.write('\n') |
|
|
|
|
|
|
|
self._write_make_params(params) |
|
|
|
self._write_make_params(params) |
|
|
|
|
|
|
|
self.write('\n') |
|
|
|
self.indent -= 2 |
|
|
|
self.write('\n') |
|
|
|
self.indent -= 2 |
|
|
|
|
|
|
|
self.generated.append(name) |
|
|
|
|
|
|
|
def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None, |
|
|
|
desc=None, local_defs=[], have_config=True): |
|
|
@@ -232,7 +267,7 @@ class Context: |
|
|
|
buf = StringIO() |
|
|
|
print( |
|
|
|
'[', |
|
|
|
*(' "%s",' % i for i in self.generated), |
|
|
|
*(' "%s",' % i for i in self.generated_opr), |
|
|
|
']', |
|
|
|
sep='\n', |
|
|
|
file=buf |
|
|
@@ -259,6 +294,7 @@ def main(): |
|
|
|
with open(i) as fin: |
|
|
|
exec(compile(fin.read(), i, 'exec'), exec_globals) |
|
|
|
|
|
|
|
gen.write_generated_oprs() |
|
|
|
try: |
|
|
|
git_commit = subprocess.check_output( |
|
|
|
['git', 'rev-parse', 'HEAD'], universal_newlines=True, |
|
|
|