Browse Source

fix(mge/imperative): remove duplicated opr

GitOrigin-RevId: 7d49785fad
release-1.1
Megvii Engine Team 4 years ago
parent
commit
fe5649e456
2 changed files with 71 additions and 21 deletions
  1. +56
    -20
      imperative/python/tools/gen_ops.py
  2. +15
    -1
      src/opr/impl/dnn/dnn.oprdecl

+ 56
- 20
imperative/python/tools/gen_ops.py View File

@@ -14,7 +14,6 @@ import os
import textwrap
import inspect


def camel2underscore(
name, *,
first_cap_re=re.compile('([A-Z])([A-Z][a-z]+)'),
@@ -50,9 +49,9 @@ class Context:
def __init__(self):
self.fout = StringIO()
self.indent = 0
self.generated = []
self.skipped = []
self.generated_signature = set()
self.generated_opr = dict()

def write(self, text, *fmt, indent=0):
text = textwrap.dedent(text)
@@ -181,6 +180,15 @@ class Context:
:param outputs: the indices of output vars to be selected from raw opr
result
"""

class OprItem:
def __init__(self, inputs, desc, params, version, has_out_dtype):
self.inputs = inputs
self.desc = desc
self.params = params
self.version = version
self.has_out_dtype = has_out_dtype

if body:
self.skipped.append(name)
return
@@ -197,29 +205,56 @@ class Context:
params = [('param', params)]
assert params

self.write('# %s', caller_lineno())
self.write('class %s(PodOpVisitor):', name)
self.indent += 1
if name in self.generated_opr:
org_opr = self.generated_opr[name]
if version > org_opr.version:
def compare_doc(a, b):
if isinstance(a, str):
return a == b
else:
assert isinstance(a, Doc)
return a.doc == b.doc

assert compare_doc(desc, org_opr.desc)
assert len(inputs) == len(org_opr.inputs)
for i, j in zip(inputs, org_opr.inputs):
assert compare_doc(i, j)

self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)
else:
self.generated_opr[name] = OprItem(inputs, desc, params, version, has_out_dtype)

def write_generated_oprs(self):

for opr, opr_item in self.generated_opr.items():

name = opr
params = opr_item.params
version = opr_item.version
has_out_dtype = opr_item.has_out_dtype

self.write('# %s', caller_lineno())
self.write('class %s(PodOpVisitor):', name)
self.indent += 1

param_names, _ = zip(*params)
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
self.write('\n')
param_names, _ = zip(*params)
self.write('param_names = (%s,)', ', '.join(map('"{}"'.format, param_names)))
self.write('name = "%s"', '{}V{}'.format(name, version) if version else name)
self.write('\n')

self.write('def __init__(%s):',
self._gen_signature(params,
has_out_dtype=has_out_dtype))
self.indent += 1
self.write('def __init__(%s):',
self._gen_signature(params,
has_out_dtype=has_out_dtype))
self.indent += 1

self._write_gen_config(has_out_dtype=has_out_dtype)
self.write('\n')
self._write_gen_config(has_out_dtype=has_out_dtype)
self.write('\n')

self._write_make_params(params)
self._write_make_params(params)

self.write('\n')
self.indent -= 2
self.write('\n')
self.indent -= 2

self.generated.append(name)

def decl_raw_opr(self, name, *, inputs, inputs_cvt=[], body=None,
desc=None, local_defs=[], have_config=True):
@@ -232,7 +267,7 @@ class Context:
buf = StringIO()
print(
'[',
*(' "%s",' % i for i in self.generated),
*(' "%s",' % i for i in self.generated_opr),
']',
sep='\n',
file=buf
@@ -259,6 +294,7 @@ def main():
with open(i) as fin:
exec(compile(fin.read(), i, 'exec'), exec_globals)

gen.write_generated_oprs()
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', 'HEAD'], universal_newlines=True,


+ 15
- 1
src/opr/impl/dnn/dnn.oprdecl View File

@@ -95,6 +95,7 @@ r"""
"""))

decl_opr('Local',
pyname='local',
inputs=[Doc('src',
'input image in (batch, channel, row, col) format'),
Doc('filter',
@@ -105,6 +106,19 @@ decl_opr('Local',
desc='batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions')

decl_opr('Local',
pyname='local_v1',
inputs=[Doc('src',
'input image in (batch, channel, row, col) format'),
Doc('filter',
'convolution kernel in '
'(out row, out col, in channel, '
'kern row, kern col, out channel) format')],
params='Convolution',
desc='batched convolution on channeled 2D images, but kernels are '
'not shared across different output positions',
version=1)

decl_opr('GroupLocal',
inputs=[Doc('src',
'input image in (batch, channel, row, col) format'),
@@ -113,7 +127,7 @@ decl_opr('GroupLocal',
'(group, out row, out col, in channel / group, '
'kern row, kern col, out channel / group) format')],
params=[('param', 'Convolution')],
desc='batched convolution on groupped channeled 2D images, but '
desc='batched convolution on groupped channeled 2D images, but '
'kernels are not shared across different output positions',
version=1)



Loading…
Cancel
Save