
feat(imperative/src): name operators automatically when tracing

GitOrigin-RevId: ff8eb003c5
tags/v1.3.0
Megvii Engine Team · 4 years ago
parent commit 6fb19b664b
34 changed files with 621 additions and 99 deletions
 1. +7   -5   imperative/python/megengine/core/tensor/megbrain_graph.py
 2. +61  -8   imperative/python/megengine/jit/tracing.py
 3. +9   -5   imperative/python/megengine/module/module.py
 4. +14  -3   imperative/python/megengine/tensor.py
 5. +63  -0   imperative/python/megengine/utils/naming.py
 6. +9   -5   imperative/python/src/graph_rt.cpp
 7. +29  -0   imperative/python/src/ops.cpp
 8. +25  -3   imperative/python/src/tensor.cpp
 9. +7   -0   imperative/python/src/tensor.h
10. +169 -0   imperative/python/test/unit/test_dump_naming.py
11. +1   -1   imperative/python/test/unit/test_tracing.py
12. +14  -4   imperative/src/impl/op_def.cpp
13. +4   -1   imperative/src/impl/op_trait.h
14. +3   -2   imperative/src/impl/ops/batch_norm.cpp
15. +5   -3   imperative/src/impl/ops/broadcast.cpp
16. +1   -1   imperative/src/impl/ops/collective_comm.cpp
17. +2   -2   imperative/src/impl/ops/cond_take.cpp
18. +2   -1   imperative/src/impl/ops/elemwise.cpp
19. +2   -1   imperative/src/impl/ops/img_proc.cpp
20. +4   -2   imperative/src/impl/ops/io_remote.cpp
21. +3   -1   imperative/src/impl/ops/matrix_inverse.cpp
22. +3   -1   imperative/src/impl/ops/nms.cpp
23. +8   -1   imperative/src/impl/ops/opr_attr.cpp
24. +2   -1   imperative/src/impl/ops/resize.cpp
25. +84  -36  imperative/src/impl/ops/specializations.cpp
26. +4   -3   imperative/src/impl/ops/tensor_manip.cpp
27. +2   -1   imperative/src/impl/ops/tensorrt_runtime.cpp
28. +2   -1   imperative/src/impl/ops/warp_affine.cpp
29. +7   -2   imperative/src/include/megbrain/imperative/op_def.h
30. +41  -5   imperative/tablegen/autogen.cpp
31. +30  -0   imperative/tablegen/helper.h
32. +1   -0   sdk/load-and-run/dump_with_testcase_mge.py
33. +1   -0   src/core/include/megbrain/ir/base.td
34. +2   -0   src/core/include/megbrain/ir/ops.td
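
For orientation before the per-file diffs, a minimal usage sketch of what this commit enables (not part of the diff; the module and variable names are made up, but the calls mirror the new tests and dump kwargs below): operators executed while tracing a named Module get scope-qualified names, and dump keeps them when keep_opr_name=True is passed.

    import io
    import numpy as np
    import megengine.module as M
    import megengine.utils.comp_graph_tools as cgtools
    from megengine import Tensor
    from megengine.core.tensor import megbrain_graph as G
    from megengine.jit.tracing import trace

    class Simple(M.Module):          # hypothetical module, mirrors the new tests
        def __init__(self, name):
            super().__init__()
            self.name = name         # user-given name becomes the naming scope

        def forward(self, x):
            return x + x

    f = trace(Simple("simple"), symbolic=True, capture_as_const=True)
    f(Tensor(np.ones((2, 3)))).numpy()

    out = io.BytesIO()
    # keep_opr_name / keep_var_name are dump kwargs introduced by this commit
    f.dump(out, arg_names="x", keep_opr_name=True, keep_var_name=2,
           optimize_for_inference=False)
    out.seek(0)
    *_, outputs = G.load_graph(out)
    print(cgtools.get_oprs_seq(outputs)[-1].name)   # expected: "simple.ADD"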

+7 -5  imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -96,7 +96,7 @@ class Graph(_imperative_rt.ComputingGraph):
             data = data.numpy()
         return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype))

-    def make_const(self, data, dtype=None, device=None):
+    def make_const(self, data, dtype=None, device=None, name=None):
         if isinstance(data, _imperative_rt.DeviceTensorND):
             assert dtype is None and device is None
             return self._wrap(_imperative_rt.make_shared(self, data))
@@ -107,7 +107,9 @@ class Graph(_imperative_rt.ComputingGraph):
         elif data.dtype == np.int64:
             data = data.astype(np.int32)
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
+        return self._wrap(
+            _imperative_rt.make_const(self, data, device, dtype, name)
+        )

     def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
@@ -305,7 +307,7 @@ def dump_graph(
     output_vars: Union[Dict[str, VarNode], List[VarNode]],
     *,
     keep_var_name: int = 1,
-    keep_op_name: bool = True,
+    keep_opr_name: bool = False,
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
@@ -326,7 +328,7 @@ def dump_graph(
        * 0: none of the names are kept
        * 1: (default)keep names of output vars
        * 2: keep names of all (output and internal) vars
-    :param keep_op_name: whether to keep operator names.
+    :param keep_opr_name: whether to keep operator names.
    :param keep_param_name: whether to keep param names, so param values can be
        easily manipulated after loading model
    :param keep_opr_priority: whether to keep priority setting for operators
@@ -370,7 +372,7 @@ def dump_graph(
    dump_content = _imperative_rt.dump_graph(
        ov,
        keep_var_name,
-        keep_op_name,
+        keep_opr_name,
        keep_param_name,
        keep_opr_priority,
        stat,


+61 -8  imperative/python/megengine/jit/tracing.py

@@ -36,6 +36,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
+from ..utils.naming import auto_naming
 from .sublinear_memory_config import SublinearMemoryConfig

@@ -77,6 +78,7 @@ def exclude_from_trace():
 class TensorInfo:
     __slots__ = (
         # collected attributes
+        "name",
         "external",
         "data_read",
         "shape_read",
@@ -96,6 +98,7 @@
     )

     def __init__(self):
+        self.name = None
         self.exported = None
         self.data_read = None
         self.shape_read = None
@@ -290,12 +293,16 @@ class trace:
             h = getattr(x, "_mixin_handle", -1)
             if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                 h, info = self._new_handle()
+                name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name
+                info.name = name
                 info.external = True
                 info.device = x.device
                 info.dtype = x.dtype
                 info.shape = x.shape
                 if self._capture_as_const:
-                    info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)
+                    info.bound_data = RawTensor(
+                        x.numpy(), x.dtype, x.device, False, name
+                    )

             ihandles.append(h)

@@ -669,6 +676,12 @@
         arg_names=None,
         output_names=None,
         append=False,
+        keep_var_name: int = 1,
+        keep_opr_name: bool = False,
+        keep_param_name: bool = False,
+        keep_opr_priority: bool = False,
+        strip_info_file=None,
+        append_json=False,
         optimize_for_inference=True,
         **kwargs
     ):
@@ -681,6 +694,20 @@
             use the default name if not specified.
         :param append: whether output is appended to ``file``.
             Only works when ``file`` is str.
+        :param keep_var_name: level for keeping variable names:
+
+            * 0: none of the names are kept
+            * 1: (default)keep names of output vars
+            * 2: keep names of all (output and internal) vars
+        :param keep_opr_name: whether to keep operator names.
+        :param keep_param_name: whether to keep param names, so param values can be
+            easily manipulated after loading model
+        :param keep_opr_priority: whether to keep priority setting for operators
+        :param strip_info_file: a string for path or a file handler. if is not None,
+            then the dump information for code strip would be written to ``strip_info_file``
+        :param append_json: will be check when `strip_info_file` is not None. if set
+            true, the information for code strip will be append to strip_info_file.
+            if set false, will rewrite strip_info_file
         :param optimize_for_inference: enbale optmizations,
             will skip all optimize options if this is False. Default: True
@@ -785,7 +812,10 @@
                     assert info.external
                     assert info.bound_data
                     h2v[h] = graph.make_const(
-                        info.bound_data.numpy(), dtype=info.dtype, device=info.device,
+                        info.bound_data.numpy(),
+                        dtype=info.dtype,
+                        device=info.device,
+                        name=info.name,
                     )
                     continue
             ivars = []
@@ -795,13 +825,26 @@
                     assert info.external
                     assert info.bound_data
                     h2v[h] = graph.make_const(
-                        info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
+                        info.bound_data.numpy(),
+                        dtype=info.dtype,
+                        device=dumped_device,
+                        name=info.name,
                     )
                 ivars.append(h2v[h])
             ovars = G.apply_normal_varnode(op, *ivars)
+
+            auto_naming.record_opnode(ovars[0].op)
+
             assert len(ovars) == len(ohandles)
             h2v.update(zip(ohandles, ovars))
+
+            for i in ohandles:
+                name = auto_naming.get_var_name(i)
+                if name is not None:
+                    h2v[i].name = name
+
+        auto_naming.remove_duplicate_names()
+
         dest_vars = []
         for i, h in enumerate(self._output_bindings):
             v = h2v[h]
@@ -815,7 +858,15 @@
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
             file = open(file, permission)
-        dump_content, dump_info = G.dump_graph(dest_vars)
+        dump_content, dump_info = G.dump_graph(
+            dest_vars,
+            keep_var_name=keep_var_name,
+            keep_opr_name=keep_opr_name,
+            keep_param_name=keep_param_name,
+            keep_opr_priority=keep_opr_priority,
+            strip_info_file=strip_info_file,
+            append_json=append_json,
+        )
         file.write(dump_content)
         return dump_info
@@ -1095,20 +1146,22 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
     return active_trace._apply_op(op, args)


-def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
+def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
     if skip_tracing:
         args = [
             RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
             for x in args
         ]
         unset_tracing()
-        ret = RawTensor(value, dtype, device, False)
+        ret = RawTensor(value, dtype, device, False, name)
         set_tracing()
         return ret
     return active_trace._apply_const(value, dtype, device)


 def apply_with_tracing(op: OpDef, *args: RawTensor):
+    if hasattr(op, "scope"):
+        op.scope = auto_naming.get_scope()
     if active_trace._symbolic:
         outputs = apply_symbolic_mode(op, *args)
     else:
@@ -1120,12 +1173,12 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
     return list(outputs)


-def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
+def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
     if active_trace._symbolic:
         outputs = apply_const_symbolic_mode(value, dtype, device)
     else:
         unset_tracing()
-        outputs = (RawTensor(value, dtype, device, False),)
+        outputs = (RawTensor(value, dtype, device, False, name),)
         set_tracing()
     active_trace._record_const(outputs)
     return list(outputs)
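
The new dump keyword arguments above are forwarded verbatim to G.dump_graph, so serialization options can now be controlled from the trace object. A hedged fragment (output and strip-info paths are hypothetical):

    f.dump(
        "model.mge",                         # hypothetical output path
        keep_opr_name=True,                  # keep the scope-qualified operator names
        keep_var_name=2,                     # keep names of all vars, not just outputs
        keep_param_name=True,
        strip_info_file="model_strip.json",  # hypothetical strip-info path
        append_json=False,
    )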

+9 -5  imperative/python/megengine/module/module.py

@@ -12,12 +12,12 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

 import numpy as np

-from ..core._imperative_rt.core2 import pop_scope, push_scope
 from ..core.tensor.utils import make_shape_tuple
 from ..logger import get_logger
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
 from ..utils.hook import HookHandler
+from ..utils.naming import auto_naming

 logger = get_logger(__name__)

@@ -69,7 +69,9 @@
     Base Module class.
     """

-    def __init__(self):
+    def __init__(self, name=""):
+        self.name = name
+
         # runtime attributes
         self.training = True
         self.quantize_disabled = False
@@ -79,6 +81,8 @@
         self._forward_hooks = OrderedDict()

         self._modules = []
+
+        # used for profiler and automatic naming
         self._name = "{anonymous}"

     @abstractmethod
@@ -105,7 +109,7 @@
         return HookHandler(self._forward_hooks, hook)

     def __call__(self, *inputs, **kwargs):
-        push_scope(self._name)
+        auto_naming.push_scope(self.name if self.name else self._name)
         for hook in self._forward_pre_hooks.values():
             modified_inputs = hook(self, inputs)
             if modified_inputs is not None:
@@ -119,7 +123,7 @@
             modified_outputs = hook(self, inputs, outputs)
             if modified_outputs is not None:
                 outputs = modified_outputs
-        pop_scope(self._name)
+        auto_naming.pop_scope()
         return outputs

     def _flatten(
@@ -579,7 +583,7 @@
         value = super().__getattribute__(name)
         if name == "_name":
             return value
-        if _is_module(value):
+        if isinstance(value, (Tensor, Module)):
             value._name = name
         return value
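
With Module.__init__ now accepting a name argument, the scope pushed by __call__ is either the user-supplied name or the attribute-derived _name. Mirroring test_named_submodule further down (the class name is illustrative), the expected dumped names are scope-qualified:

    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name                        # overrides the attribute-derived scope
            self.linear = M.Linear(3, 3, name="x")  # name= reaches Module.__init__

        def forward(self, x):
            return self.linear(x)

    # After tracing Simple("simple") and dumping with keep_opr_name=True, the last
    # operator is expected to be "simple.x.ADD" and its producer "simple.x.MatrixMul".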




+14 -3  imperative/python/megengine/tensor.py

@@ -20,6 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
 from .logger import get_logger
 from .utils.deprecation import deprecated
+from .utils.naming import auto_naming


 class Tensor(_Tensor, ArrayMethodMixin):
@@ -27,7 +28,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
     dmap_callback = None
     _q_dict = None

-    def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False):
+    def __new__(
+        cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=""
+    ):
         if device is None:
             cn = get_default_device()
         elif isinstance(device, str):
@@ -51,8 +54,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
         if isinstance(data, np.ndarray):
             if 0 in data.strides:
                 data = data.squeeze().reshape(data.shape)
-
-        obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache)
+        obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
         return obj

     @property
@@ -91,6 +93,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
         piece += ", device={}".format(self.device) + ")"
         return piece

+    @property
+    def name(self):
+        return self.c_name
+
+    @name.setter
+    def name(self, name):
+        self.c_name = name
+        auto_naming.record_var_name(self._mixin_handle, name)
+
     @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
     def set_value(self, value):
         if not isinstance(value, _Tensor):
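
The new Tensor.name property writes through to c_name and registers the handle with auto_naming, so a tensor renamed inside a traced forward renames the dumped variable while the operator keeps its automatic name. A fragment mirroring test_user_named_tensor below:

    def forward(self, x):
        x = x + x
        x.name = "o_x"   # recorded via auto_naming.record_var_name(handle, name)
        return x

    # Expected after dump: the operator is still "simple.ADD", but its output
    # variable is named "o_x".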


+63 -0  imperative/python/megengine/utils/naming.py

@@ -0,0 +1,63 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..core._imperative_rt.core2 import pop_scope, push_scope


class AutoNaming:
    r"""
    Name all executed operators automaticlly during tracing and record all tensors
    renamed by the user.
    """

    def __init__(self):
        self.scopes = []
        self.c_ops = []
        self.name2ops = {}
        self.handle2names = {}

    def clear(self):
        for var in vars(self).values():
            var.clear()

    def push_scope(self, scope):
        push_scope(scope)
        self.scopes.append(scope)

    def pop_scope(self):
        scope = self.scopes.pop()
        pop_scope(scope)

    def get_scope(self):
        return ".".join(self.scopes)

    def record_var_name(self, handle, name):
        self.handle2names[handle] = name

    def get_var_name(self, handle):
        return self.handle2names.pop(handle, None)

    def record_opnode(self, op):
        ops = self.name2ops.get(op.name, [])
        ops.append(op)
        self.name2ops[op.name] = ops

    def remove_duplicate_names(self):
        for key, ops in self.name2ops.items():
            if len(ops) == 1:
                continue
            for i, op in enumerate(ops):
                op.name = key + "[%s]" % str(i)
                if len(op.outputs) == 1:
                    continue
                for var in op.outputs:
                    var.name = var.name.replace(key, op.name)
        self.name2ops.clear()


auto_naming = AutoNaming()
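
A small sketch of how the auto_naming singleton is meant to be driven (the scope strings here are made up): module calls push and pop nested scopes, get_scope() joins them with dots, and remove_duplicate_names() later disambiguates repeated operator names with an "[i]" suffix at dump time (e.g. "simple.RELU[0]" / "simple.RELU[1]", as asserted in the tests below).

    from megengine.utils.naming import auto_naming

    auto_naming.push_scope("net")       # e.g. entered by the outer Module.__call__
    auto_naming.push_scope("block0")    # entered by a submodule
    assert auto_naming.get_scope() == "net.block0"
    auto_naming.pop_scope()
    auto_naming.pop_scope()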

+9 -5  imperative/python/src/graph_rt.cpp

@@ -294,7 +294,7 @@ void init_graph_rt(py::module m) {
     m.def("dump_graph", [](
             const std::vector<VarNode*>& dest_vars,
             int keep_var_name,
-            bool keep_op_name,
+            bool keep_opr_name,
             bool keep_param_name,
             bool keep_opr_priority,
             py::list& stat,
@@ -307,7 +307,7 @@ void init_graph_rt(py::module m) {
         SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());

         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
-                                            keep_opr_priority, keep_op_name};
+                                            keep_opr_priority, keep_opr_name};

         auto rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
@@ -457,13 +457,17 @@ void init_graph_rt(py::module m) {
         return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
     });

-    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
+    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional<std::string> name) {
         if (!cn.valid()) {
             cn = CompNode::load(get_default_device());
         }
+        OperatorNodeConfig config(cn);
+        if (name) {
+            config.name(*name);
+        }
         auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
-        return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
-    });
+        return opr::ImmutableTensor::make(*graph, hv, config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());

     m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
         if (!cn.valid()) {


+29 -0  imperative/python/src/ops.cpp

@@ -99,6 +99,14 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
 #define py_get_generic(name, attr) \
     py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

+template<typename T>
+PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) {
+    // T: PyOpXXX inst(): return XXX in opdef.h.inl
+    auto& op = reinterpret_cast<T*>(obj)->inst();
+    return pyobj_convert_generic<std::string>::to(op.scope());
+}
+#define py_get_scope(class) py_get_scope_impl<PyOp(class)>
+
 template<typename T, typename U, U T::Ty::*attr>
 int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
     if (value == NULL) {
@@ -121,6 +129,27 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
 #define py_set_generic(name, attr) \
     py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

+template<typename T>
+int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) {
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
+        return -1;
+    }
+    auto& op = reinterpret_cast<T*>(obj)->inst();
+    try {
+        op.set_scope(pyobj_convert_generic<std::string>::from(value));
+        return 0;
+    } catch(py::error_already_set& e) {
+        e.restore();
+    } catch(py::builtin_exception& e) {
+        e.set_error();
+    } catch(...) {
+        PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
+    }
+    return -1;
+}
+#define py_set_scope(class) py_set_scope_impl<PyOp(class)>
+
 struct PyOpDef {
     PyObject_HEAD
     std::shared_ptr<OpDef> op;
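
These getters and setters back a writable "scope" attribute on the generated Python op classes; apply_with_tracing (see tracing.py above) assigns the current scope to it before the op is applied. A hedged Python sketch, where the exact constructor signature of the generated op is an assumption and not part of this diff:

    from megengine.core.ops.builtin import Elemwise   # generated OpDef class
    from megengine.utils.naming import auto_naming

    op = Elemwise(mode="ADD")           # constructor details assumed
    if hasattr(op, "scope"):            # same guard used by apply_with_tracing
        op.scope = auto_naming.get_scope()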


+25 -3  imperative/python/src/tensor.cpp

@@ -24,6 +24,7 @@
 #include <pybind11/numpy.h>
 #include <pybind11/operators.h>
 #include <range/v3/all.hpp>
+#include <string>

 #include <unordered_map>

@@ -222,14 +223,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         }
     } else {
         py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
-        if (nargs != 4 && nargs != 5) {
-            throw py::type_error("expect 4 or 5 arguments");
+        if (nargs != 5 && nargs != 6) {
+            throw py::type_error("expect 5 or 6 arguments");
         }
         auto data = tup[0].cast<py::array>();
         DType dtype = tup[1].cast<DType>();
         CompNode cn = tup[2].cast<CompNode>();
         bool is_const = tup[3].cast<bool>();
-        bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false;
+        bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
+        std::string name = tup[nargs - 1].cast<std::string>();

         // const op
         if (is_const && is_tracing) {
@@ -259,6 +261,7 @@
         }

         m_tensor = std::make_shared<Tensor>(handle);
+        m_tensor->user_custom_name = name;

         if (data.ndim() == 0) {
             m_tensor->m_flags |= Tensor::Flags::SCALAR;
@@ -313,6 +316,19 @@ REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)
 #undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC


+#define SET_GET_NAME(member) \
+    PyObject* TensorWrapper::member() { \
+        return py::cast(m_tensor->member).release().ptr(); \
+    } \
+    void TensorWrapper::set_##member(PyObject* dest) { \
+        auto py_dest = py::reinterpret_borrow<py::object>(dest); \
+        m_tensor->member = py_dest.cast<std::string>(); \
+    }
+SET_GET_NAME(user_custom_name)
+SET_GET_NAME(automatic_name)
+#undef SET_GET_NAME
+
+
 PyObject* TensorWrapper::handle() {
     return py::cast(m_tensor->m_handle).release().ptr();
 }
@@ -453,7 +469,11 @@ void TensorWrapper::reset(PyObject* tensor) {
     if (!t) {
         throw py::type_error("expect Tensor");
     }
+    std::string user_custom_name = m_tensor->user_custom_name;
+    std::string automatic_name = m_tensor->automatic_name;
     m_tensor = t->m_tensor;
+    m_tensor->user_custom_name = user_custom_name;
+    m_tensor->automatic_name = automatic_name;
 }

 void TensorWrapper::reset_varnode() {
@@ -785,6 +805,8 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
         .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
         .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
+        .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
+        .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
         .finalize();
     if (!tensor_type) throw py::error_already_set();
     py::setattr(m, "Tensor", tensor_type);


+7 -0  imperative/python/src/tensor.h

@@ -15,6 +15,7 @@

 #include "megbrain/imperative/interpreter.h"
 #include "pybind11/pybind11.h"
+#include <string>

 #include "./pyext17.h"

@@ -70,6 +71,8 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     GradInfo m_grad_info;
     TraceInfo m_trace_info;
     SharedHandle m_handle;
+    std::string user_custom_name;
+    std::string automatic_name;
     cg::VarNode* m_var;

     using Handle = interpreter::Interpreter::Handle;
@@ -170,6 +173,10 @@ struct TensorWrapper {
     void set_compiled_info(PyObject *);
     PyObject* trace_mixin_info();
     void set_trace_mixin_info(PyObject *);
+    PyObject* user_custom_name();
+    void set_user_custom_name(PyObject *);
+    PyObject* automatic_name();
+    void set_automatic_name(PyObject *);
     PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
 };




+169 -0  imperative/python/test/unit/test_dump_naming.py

@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io

import numpy as np
import pytest

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, Tensor
from megengine.core.tensor import megbrain_graph as G
from megengine.jit.tracing import trace
from megengine.utils.naming import auto_naming


def _dump_and_load(func, symbolic, keep_opr_name=True):
    auto_naming.clear()
    func = trace(func, symbolic=symbolic, capture_as_const=True)
    x = Tensor(np.ones(shape=(2, 3)))
    func(x).numpy()
    file = io.BytesIO()
    func.dump(
        file,
        optimize_for_inference=False,
        arg_names="x",
        keep_opr_name=keep_opr_name,
        keep_var_name=2,
    )
    file.seek(0)
    *_, outputs = G.load_graph(file)
    op = cgtools.get_oprs_seq(outputs)[-1]
    return op


@pytest.mark.parametrize("symbolic", [False, True])
def test_auto_naming(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def forward(self, x):
            return x + x

    m = Simple("simple")
    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.ADD"
    assert op.outputs[0].name == "simple.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_tensor(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.k = Parameter(1.0, name="k")

        def forward(self, x):
            x = x + x
            x.name = "o_x"
            return x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.ADD"
    assert op.outputs[0].name == "o_x"


@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_param(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.k = Parameter(2.0, name="k")

        def forward(self, x):
            return self.k * x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)
    assert op.inputs[0].name == "x"
    assert op.inputs[1].name == "simple.k"


@pytest.mark.parametrize("symbolic", [False, True])
def test_without_module(symbolic):
    def f(x):
        return 2 * x

    op = _dump_and_load(f, symbolic)
    assert op.name == "MUL"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.linear = M.Linear(3, 3)

        def forward(self, x):
            x = self.linear(x)
            return x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.linear.ADD"
    assert op.inputs[0].owner.name == "simple.linear.MatrixMul"
    assert op.outputs[0].name == "simple.linear.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_named_submodule(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.linear = M.Linear(3, 3, name="x")

        def forward(self, x):
            x = self.linear(x)
            return x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.x.ADD"
    assert op.inputs[0].owner.name == "simple.x.MatrixMul"
    assert op.outputs[0].name == "simple.x.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_same_operators(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def forward(self, x):
            x = F.relu(x)
            x = F.relu(x)
            return x

    m = Simple("simple")

    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.RELU[1]"
    assert op.inputs[0].owner.name == "simple.RELU[0]"


def test_not_keep_opr_name():
    def f(x):
        return 2 * x

    op = _dump_and_load(f, True, False)
    assert op.name == "MUL(x,2[2])[4]"

+1 -1  imperative/python/test/unit/test_tracing.py

@@ -148,7 +148,7 @@ def test_dump():
     dump_info = f.dump(file)
     assert dump_info.nr_opr == 3
     np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
-    np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
+    np.testing.assert_equal(dump_info.outputs, ["ADD"])
     file.seek(0)
     infer_cg = cgtools.GraphInference(file)
     result = list((infer_cg.run(a, b)).values())[0]


+14 -4  imperative/src/impl/op_def.cpp

@@ -75,10 +75,6 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
     return def.trait()->props(def);
 }

-const char* OpDef::name() const {
-    return trait()->name;
-}
-
 std::string OpDef::to_string() const {
     std::string builder = "{";
     for (auto&& [name, value]: props(*this)) {
@@ -107,6 +103,20 @@
     return m_trait;
 }

+const std::string OpDef::scope() const {
+    return m_scope;
+}
+
+void OpDef::set_scope(const std::string& scope) {
+    m_scope = scope;
+}
+
+const std::string OpDef::make_name() const {
+    if (m_scope.empty())
+        return trait()->make_name(*this);
+    return m_scope + "." + trait()->make_name(*this);
+}
+
 } // namespace imperative
 } // namespace mgb




+4 -1  imperative/src/impl/op_trait.h

@@ -75,6 +75,7 @@ using GradMaker = detail::OpMeth<
 using Props = detail::OpMeth<decltype(OpDef::props)>;
 using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
 using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
+using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;

 struct OpTrait {
     const char* name;
@@ -88,6 +89,7 @@ struct OpTrait {
     Props props;
     HashFunc hash;
     IsSame is_same_st;
+    MakeNameFunc make_name;
     OpTrait(const char* name);
     static OpTrait* find_by_name(const char* name);
     static OpTrait* find_by_typeinfo(Typeinfo* type);
@@ -104,7 +106,8 @@
     cb(make_backward_graph) \
     cb(props) \
     cb(hash) \
-    cb(is_same_st)
+    cb(is_same_st) \
+    cb(make_name)

 struct OpTraitRegistry {
     OpTrait* trait;

+3 -2  imperative/src/impl/ops/batch_norm.cpp

@@ -30,13 +30,14 @@ cg::OperatorNodeBase* apply_on_var_node(
     size_t nr_inp = inputs.size();
     mgb_assert(nr_inp == 3 ||nr_inp == 5,
                "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
+    OperatorNodeConfig config{bn_opr.make_name()};
     if (nr_inp == 3) {
         return opr::BatchNorm::make(
-                    inputs[0], inputs[1], inputs[2], bn_opr.param())[0]
+                    inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0]
                 .node()->owner_opr();
     } else {
         return opr::BatchNorm::make(
-                    inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0]
+                    inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param(), config)[0]
                 .node()->owner_opr();
     }
 }


+5 -3  imperative/src/impl/ops/broadcast.cpp

@@ -27,10 +27,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<Broadcast>();
+    auto&& op = def.cast_final_safe<Broadcast>();
     size_t nr_inp = inputs.size();
     mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
-    return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
 }

 bool valid_broadcast(const TensorShape& src_shape,
@@ -96,7 +97,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Reshape&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Reshape::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
 }

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(


+1 -1  imperative/src/impl/ops/collective_comm.cpp

@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     auto disable = std::make_shared<DTypeScalar>();
     disable->set(0);

-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{comm.make_name()};
     if (comm.comp_node.size() > 0) {
         config.comp_node(CompNode::load(comm.comp_node));
     }


+2 -2  imperative/src/impl/ops/cond_take.cpp

@@ -23,12 +23,12 @@ namespace {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<CondTake>();
+    auto&& op = def.cast_final_safe<CondTake>();
     auto&& graph = inputs[0]->owner_graph();

     opr::CondTake::Param param;
     param.val = 1;
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{op.make_name()};
     cg::OperatorNodeBase* opr = graph->insert_opr(
             std::make_unique<opr::CondTake>(
                     inputs[0], inputs[1], param, config));


+2 -1  imperative/src/impl/ops/elemwise.cpp

@@ -31,7 +31,8 @@ cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
-    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
+    OperatorNodeConfig config{elemwise_opr.make_name()};
+    return opr::Elemwise::make(inputs, elemwise_opr.mode, config).node()->owner_opr();
 }

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(


+2 -1  imperative/src/impl/ops/img_proc.cpp

@@ -23,7 +23,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const CvtColor&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::CvtColor::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::CvtColor::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(CvtColor, CvtColor)
     .apply_on_var_node(apply_on_var_node)


+4 -2  imperative/src/impl/ops/io_remote.cpp

@@ -32,7 +32,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
             ssprintf("%s:%d", send.addr.data(), send.port));
     auto&& graph = inputs[0]->owner_graph();

-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{send.make_name()};
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
                     send.key, inputs[0], group_client, true, config));
@@ -42,11 +42,13 @@
 cg::OperatorNodeBase* apply_on_var_node_remote_recv(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& recv = def.cast_final_safe<RemoteRecv>();
+    OperatorNodeConfig config{recv.cn};
+    config.name(recv.make_name());
     auto group_client = std::make_shared<GroupClientProxy>(
             ssprintf("%s:%d", recv.addr.data(), recv.port));
     auto&& graph = inputs[0]->owner_graph();
     return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
-            recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn},
+            recv.key, inputs[0], *graph, group_client, config,
             recv.shape, recv.dtype));
 }



+3 -1  imperative/src/impl/ops/matrix_inverse.cpp

@@ -21,8 +21,10 @@ namespace {
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<MatrixInverse>();
     mgb_assert(inputs.size() == 1);
-    return opr::MatrixInverse::make(inputs[0]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::MatrixInverse::make(inputs[0], {}, config);
 }
 OP_TRAIT_REG(MatrixInverse, MatrixInverse)
     .apply_on_var_node(apply_on_var_node)


+3 -1  imperative/src/impl/ops/nms.cpp

@@ -29,7 +29,9 @@ cg::OperatorNodeBase* apply_on_var_node(
     param.iou_thresh = nms_keep.iou_thresh;
     param.max_output = nms_keep.max_output;

-    return NMSKeepOpr::make(inputs[0], param).node()->owner_opr();
+    OperatorNodeConfig config{nms_keep.make_name()};
+
+    return NMSKeepOpr::make(inputs[0], param, config).node()->owner_opr();
 }

 OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr)


+8 -1  imperative/src/impl/ops/opr_attr.cpp

@@ -79,11 +79,13 @@ public:
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& attr = def.cast_final_safe<OprAttr>();
+    auto config = attr.config;
+    config.name(attr.make_name());
     mgb_assert(!inputs.empty());
     auto registry = serialization::OprRegistry::find_by_name(attr.type);
     mgb_assert(registry, "operator %s not found", attr.type.c_str());
     OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, attr.config);
+    return registry->loader(ctx, inputs, config);
 }

 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
@@ -99,10 +101,15 @@ std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     return {};
 }

+std::string make_name(const OpDef& def) {
+    return "OprAttr";
+}
+
 OP_TRAIT_REG(OprAttr, OprAttr)
     .make_from_op_node(make_from_op_node)
     .apply_on_var_node(apply_on_var_node)
     .props(props)
+    .make_name(make_name)
     .fallback();

 } // anonymous namespace


+2 -1  imperative/src/impl/ops/resize.cpp

@@ -24,7 +24,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Resize&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Resize::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Resize::make(inputs[0], inputs[1], op.param(), config);
 }

 OP_TRAIT_REG(Resize, Resize)


+84 -36  imperative/src/impl/ops/specializations.cpp

@@ -46,7 +46,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const Convolution&>(def);
-    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy());
+    OperatorNodeConfig config{conv.make_name()};
+    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
 }

 OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
@@ -60,7 +61,7 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{conv.make_name()};
     if (inputs.size() == 2) {
         return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else {
@@ -88,7 +89,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& ds = static_cast<const Dimshuffle&>(def);
-    return opr::Dimshuffle::make(inputs[0], ds.pattern);
+    OperatorNodeConfig config{ds.make_name()};
+    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
 }

 OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
@@ -107,7 +109,8 @@ auto apply_on_var_node(
     for (auto&& i : add_axis.axis) {
         param.push_back(Desc::make_add(i));
     }
-    return opr::AxisAddRemove::make(inputs[0], param);
+    OperatorNodeConfig config{add_axis.make_name()};
+    return opr::AxisAddRemove::make(inputs[0], param, config);
 }

 OP_TRAIT_REG(AddAxis, AddAxis)
@@ -125,7 +128,8 @@ auto apply_on_var_node(
     for (auto&& i : remove_axis.axis) {
         param.push_back(Desc::make_remove(i));
     }
-    return opr::AxisAddRemove::make(inputs[0], param);
+    OperatorNodeConfig config{remove_axis.make_name()};
+    return opr::AxisAddRemove::make(inputs[0], param, config);
 }

 OP_TRAIT_REG(RemoveAxis, RemoveAxis)
@@ -138,7 +142,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& topk = static_cast<const TopK&>(def);
-    return opr::TopK::make(inputs[0], inputs[1], topk.param())[0]
+    OperatorNodeConfig config{topk.make_name()};
+    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
             .node()->owner_opr();
 }

@@ -152,10 +157,12 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& reduce = static_cast<const Reduce&>(def);
+    OperatorNodeConfig config{reduce.make_name()};
     if (inputs.size() > 1) {
-        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]);
+        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
     } else {
-        return opr::Reduce::make(inputs[0], reduce.param());
+        return opr::Reduce::make(
+                inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
     }
 }

@@ -175,7 +182,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& pool = static_cast<const AdaptivePooling&>(def);
-    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param());
+    OperatorNodeConfig config{pool.make_name()};
+    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
 }

 OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
@@ -189,6 +197,7 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const ConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
+    config.name(conv.make_name());
     if (inputs.size() == 2) {
         return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else if (inputs.size() == 3) {
@@ -210,6 +219,7 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const BatchConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
+    config.name(conv.make_name());
     if (inputs.size() == 2) {
         return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else if (inputs.size() == 3) {
@@ -230,7 +240,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& pool = static_cast<const Pooling&>(def);
-    return opr::Pooling::make(inputs[0], pool.param());
+    OperatorNodeConfig config{pool.make_name()};
+    return opr::Pooling::make(inputs[0], pool.param(), config);
 }
 OP_TRAIT_REG(Pooling, Pooling)
     .apply_on_var_node(apply_on_var_node)
@@ -243,8 +254,9 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const MatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{matmul.make_name()};
     return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
-                                matmul.policy());
+                                matmul.policy(), config);
 }
 OP_TRAIT_REG(MatrixMul, MatrixMul)
     .apply_on_var_node(apply_on_var_node)
@@ -257,8 +269,9 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{matmul.make_name()};
     return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
-                                       matmul.policy());
+                                       matmul.policy(), config);
 }
 OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
     .apply_on_var_node(apply_on_var_node)
@@ -267,10 +280,12 @@ OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)

 namespace { namespace dot {
 auto apply_on_var_node(
-        const OpDef&,
+        const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Dot>();
     mgb_assert(inputs.size() == 2);
-    return opr::Dot::make(inputs[0], inputs[1]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Dot::make(inputs[0], inputs[1], config);
 }
 OP_TRAIT_REG(Dot, Dot)
     .apply_on_var_node(apply_on_var_node)
@@ -282,7 +297,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& argsort = static_cast<const Argsort&>(def);
-    return opr::Argsort::make(inputs[0], argsort.param());
+    OperatorNodeConfig config{argsort.make_name()};
+    return opr::Argsort::make(inputs[0], argsort.param(), config);
 }
 OP_TRAIT_REG(Argsort, Argsort)
     .apply_on_var_node(apply_on_var_node)
@@ -294,7 +310,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& argmax = static_cast<const Argmax&>(def);
-    return opr::Argmax::make(inputs[0], argmax.param());
+    OperatorNodeConfig config{argmax.make_name()};
+    return opr::Argmax::make(inputs[0], argmax.param(), config);
 }
 OP_TRAIT_REG(Argmax, Argmax)
     .apply_on_var_node(apply_on_var_node)
@@ -306,7 +323,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& argmin = static_cast<const Argmin&>(def);
-    return opr::Argmin::make(inputs[0], argmin.param());
+    OperatorNodeConfig config{argmin.make_name()};
+    return opr::Argmin::make(inputs[0], argmin.param(), config);
 }
 OP_TRAIT_REG(Argmin, Argmin)
     .apply_on_var_node(apply_on_var_node)
@@ -318,11 +336,13 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& warp = static_cast<const WarpPerspective&>(def);
+    OperatorNodeConfig config{warp.make_name()};
     if (inputs.size() == 3) {
-        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param());
+        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config);
     } else {
         mgb_assert(inputs.size() == 4);
-        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param());
+        return opr::WarpPerspective::make(
+                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
     }
 }
 OP_TRAIT_REG(WarpPerspective, WarpPerspective)
@@ -336,7 +356,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& local = static_cast<const GroupLocal&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::GroupLocal::make(inputs[0], inputs[1], local.param());
+    OperatorNodeConfig config{local.make_name()};
+    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
 }
 OP_TRAIT_REG(GroupLocal, GroupLocal)
     .apply_on_var_node(apply_on_var_node)
@@ -349,7 +370,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const IndexingOneHot&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
 }
 OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
     .apply_on_var_node(apply_on_var_node)
@@ -362,7 +384,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const IndexingSetOneHot&>(def);
     mgb_assert(inputs.size() == 3);
-    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
     .apply_on_var_node(apply_on_var_node)
@@ -375,7 +398,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const TypeCvt&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::TypeCvt::make(inputs[0], op.dtype);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::TypeCvt::make(inputs[0], op.dtype, config);
 }
 OP_TRAIT_REG(TypeCvt, TypeCvt)
     .apply_on_var_node(apply_on_var_node)
@@ -388,6 +412,7 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Concat&>(def);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Concat::make(inputs, op.axis, config);
 }
 OP_TRAIT_REG(Concat, Concat)
@@ -402,6 +427,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Copy&>(def);
     mgb_assert(inputs.size() == 1);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Copy::make(inputs[0], config);
 }
 OP_TRAIT_REG(Copy, Copy)
@@ -411,10 +437,12 @@ OP_TRAIT_REG(Copy, Copy)

 namespace { namespace identity {
 auto apply_on_var_node(
-        const OpDef&,
+        const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Identity>();
     mgb_assert(inputs.size() == 1);
-    return opr::Identity::make(inputs[0]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Identity::make(inputs[0], config);
 }
 OP_TRAIT_REG(Identity, Identity)
     .apply_on_var_node(apply_on_var_node)
@@ -427,7 +455,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const AssertEqual&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::AssertEqual::make(inputs[0],inputs[1],op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config);

 }

@@ -443,7 +472,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const UniformRNG&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::UniformRNG::make(inputs[0], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::UniformRNG::make(inputs[0], op.param(), config);
} }
OP_TRAIT_REG(UniformRNG, UniformRNG) OP_TRAIT_REG(UniformRNG, UniformRNG)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
@@ -456,7 +486,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const GaussianRNG&>(def); auto&& op = static_cast<const GaussianRNG&>(def);
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
return opr::GaussianRNG::make(inputs[0], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::GaussianRNG::make(inputs[0], op.param(), config);
} }
OP_TRAIT_REG(GaussianRNG, GaussianRNG) OP_TRAIT_REG(GaussianRNG, GaussianRNG)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
@@ -469,7 +500,9 @@ VarNodeArray apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIAlign&>(def); auto&& op = static_cast<const ROIAlign&>(def);
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr();
OperatorNodeConfig config{op.make_name()};
auto* opr = opr::ROIAlign::make(
inputs[0], inputs[1], op.param(), config).node()->owner_opr();
return {opr->output(0), opr->output(1)}; return {opr->output(0), opr->output(1)};
} }
OP_TRAIT_REG(ROIAlign, ROIAlign) OP_TRAIT_REG(ROIAlign, ROIAlign)
@@ -484,7 +517,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const NvOf&>(def); auto&& op = static_cast<const NvOf&>(def);
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
return opr::NvOf::make(inputs[0], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::NvOf::make(inputs[0], op.param(), config);
} }
OP_TRAIT_REG(NvOf, NvOf) OP_TRAIT_REG(NvOf, NvOf)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
@@ -499,6 +533,7 @@ auto apply_on_var_node(
auto&& op = static_cast<const Linspace&>(def); auto&& op = static_cast<const Linspace&>(def);
mgb_assert(inputs.size() == 3); mgb_assert(inputs.size() == 3);
cg::OperatorNodeConfig config{op.comp_node}; cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
} }
OP_TRAIT_REG(Linspace, Linspace) OP_TRAIT_REG(Linspace, Linspace)
@@ -513,6 +548,7 @@ auto apply_on_var_node(
auto&& op = static_cast<const Eye&>(def); auto&& op = static_cast<const Eye&>(def);
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
cg::OperatorNodeConfig config{op.comp_node}; cg::OperatorNodeConfig config{op.comp_node};
config.name(op.make_name());
opr::Eye::Param param{op.k, op.dtype.enumv()}; opr::Eye::Param param{op.k, op.dtype.enumv()};
return opr::Eye::make(inputs[0], param, config); return opr::Eye::make(inputs[0], param, config);
} }
@@ -527,7 +563,10 @@ VarNodeArray apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const ROIPooling&>(def); auto&& op = static_cast<const ROIPooling&>(def);
mgb_assert(inputs.size() == 3); mgb_assert(inputs.size() == 3);
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr();
OperatorNodeConfig config{op.make_name()};
auto* opr = opr::ROIPooling::make(
inputs[0], inputs[1], inputs[2], op.param(), config
).node()->owner_opr();
return {opr->output(0), opr->output(1)}; return {opr->output(0), opr->output(1)};
} }
OP_TRAIT_REG(ROIPooling, ROIPooling) OP_TRAIT_REG(ROIPooling, ROIPooling)
@@ -541,7 +580,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const Remap&>(def); auto&& op = static_cast<const Remap&>(def);
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
return opr::Remap::make(inputs[0], inputs[1], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
} }
OP_TRAIT_REG(Remap, Remap) OP_TRAIT_REG(Remap, Remap)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
@@ -578,7 +618,8 @@ auto apply_on_var_node( \
const OpDef& def, \ const OpDef& def, \
const VarNodeArray& inputs) { \ const VarNodeArray& inputs) { \
auto&& op = static_cast<const NAME&>(def); \ auto&& op = static_cast<const NAME&>(def); \
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
OperatorNodeConfig config{op.make_name()}; \
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
} \ } \
OP_TRAIT_REG(NAME, NAME) \ OP_TRAIT_REG(NAME, NAME) \
.apply_on_var_node(apply_on_var_node) \ .apply_on_var_node(apply_on_var_node) \
@@ -609,30 +650,35 @@ auto apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const FakeQuant&>(def); auto&& op = static_cast<const FakeQuant&>(def);
mgb_assert(inputs.size() == 3); mgb_assert(inputs.size() == 3);
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
} }
OP_TRAIT_REG(FakeQuant, FakeQuant) OP_TRAIT_REG(FakeQuant, FakeQuant)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.fallback(); .fallback();
}} // fake_quant }} // fake_quant

namespace { namespace tqt { namespace { namespace tqt {
auto apply_on_var_node( auto apply_on_var_node(
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const TQT&>(def); auto&& op = static_cast<const TQT&>(def);
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
return opr::TQT::make(inputs[0], inputs[1], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
} }
OP_TRAIT_REG(TQT, TQT) OP_TRAIT_REG(TQT, TQT)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.fallback(); .fallback();
}} // tqt }} // tqt

namespace { namespace elemwise_multi_type { namespace { namespace elemwise_multi_type {
auto apply_on_var_node( auto apply_on_var_node(
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const ElemwiseMultiType&>(def); auto&& op = static_cast<const ElemwiseMultiType&>(def);
OperatorNodeConfig config{op.dtype}; OperatorNodeConfig config{op.dtype};
config.name(op.make_name());
return opr::ElemwiseMultiType::make(inputs, op.param(), config); return opr::ElemwiseMultiType::make(inputs, op.param(), config);
} }
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
@@ -646,7 +692,9 @@ auto apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const SVD&>(def); auto&& op = static_cast<const SVD&>(def);
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output();
OperatorNodeConfig config{op.make_name()};
return opr::SVD::make(inputs[0], op.param(), config)[0]
.node()->owner_opr()->usable_output();
} }
OP_TRAIT_REG(SVD, SVD) OP_TRAIT_REG(SVD, SVD)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)


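Every hunk in this file follows the same recipe: build an OperatorNodeConfig from OpDef::make_name() and hand it to the graph-level make(), so the name recorded while tracing survives into the dumped graph. A minimal sketch of that recipe, with SomeOp as a placeholder for any of the single-input operators above (not a real operator):

auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const SomeOp&>(def);   // SomeOp is a placeholder
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};     // carries scope + op name
    return opr::SomeOp::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(SomeOp, SomeOp)
    .apply_on_var_node(apply_on_var_node)
    .fallback();

Multi-input and multi-output operators above follow the same shape; only the arguments of make() change.
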
+ 4
- 3
imperative/src/impl/ops/tensor_manip.cpp View File

@@ -21,7 +21,8 @@ cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op_def = def.cast_final_safe<GetVarShape>(); auto&& op_def = def.cast_final_safe<GetVarShape>();
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
OperatorNodeConfig config{op_def.make_name()};
return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
} }


DispatchMode decide_dispatch_mode( DispatchMode decide_dispatch_mode(
@@ -152,7 +153,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
auto&& graph = inputs[0]->owner_graph(); auto&& graph = inputs[0]->owner_graph();


auto&& shapes = get_shapes(param.shapes); auto&& shapes = get_shapes(param.shapes);
cg::OperatorNodeConfig config;
OperatorNodeConfig config(param.make_name());
cg::OperatorNodeBase* opr = cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>( graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
inputs[0], param.offsets, shapes, config)); inputs[0], param.offsets, shapes, config));
@@ -189,7 +190,7 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
auto&& graph = inputs[0]->owner_graph(); auto&& graph = inputs[0]->owner_graph();


VarNodeArray inps(inputs.begin(), inputs.end() - 1); VarNodeArray inps(inputs.begin(), inputs.end() - 1);
cg::OperatorNodeConfig config;
OperatorNodeConfig config{param.make_name()};
cg::OperatorNodeBase* opr = cg::OperatorNodeBase* opr =
graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>( graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
inps, inputs.back(), param.offsets, config)); inps, inputs.back(), param.offsets, config));


+ 2
- 1
imperative/src/impl/ops/tensorrt_runtime.cpp View File

@@ -20,8 +20,9 @@ namespace { namespace tensorrt_runtime {
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& op = static_cast<const TensorRTRuntime&>(def); auto&& op = static_cast<const TensorRTRuntime&>(def);
OperatorNodeConfig config{op.make_name()};
SymbolVarArray sinputs(inputs.begin(), inputs.end()); SymbolVarArray sinputs(inputs.begin(), inputs.end());
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs);
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs, config);
} }
OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime) OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)


+ 2
- 1
imperative/src/impl/ops/warp_affine.cpp View File

@@ -21,7 +21,8 @@ namespace { namespace warp_affine {
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 3); mgb_assert(inputs.size() == 3);
auto&& op = static_cast<const WarpAffine&>(def); auto&& op = static_cast<const WarpAffine&>(def);
return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param());
OperatorNodeConfig config{op.make_name()};
return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param(), config);
} }


OP_TRAIT_REG(WarpAffine, WarpAffine) OP_TRAIT_REG(WarpAffine, WarpAffine)


+ 7
- 2
imperative/src/include/megbrain/imperative/op_def.h View File

@@ -36,6 +36,7 @@ class OpDef : public Hashable,
public NonCopyableObj, public NonCopyableObj,
public std::enable_shared_from_this<OpDef> { public std::enable_shared_from_this<OpDef> {
mutable const OpTrait* m_trait = nullptr; mutable const OpTrait* m_trait = nullptr;
std::string m_scope;
public: public:
virtual ~OpDef() = default; virtual ~OpDef() = default;


@@ -86,10 +87,14 @@ public:


const OpTrait* trait() const; const OpTrait* trait() const;


const char* name() const;

std::string to_string() const; std::string to_string() const;


const std::string scope() const;

const std::string make_name() const;

void set_scope(const std::string& scope);

virtual size_t hash() const; virtual size_t hash() const;


virtual bool is_same_st(const Hashable&) const; virtual bool is_same_st(const Hashable&) const;


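The header only declares the new naming hooks; their definitions live in imperative/src/impl/op_def.cpp, which this commit also changes but which is not shown in this section. The header suggests a scope-prefixed name; a standalone sketch under that assumption (the dot separator is an assumption, not taken from the diff):

#include <string>

// Sketch only: assumed behaviour of OpDef::make_name(), i.e. prefix the
// per-op name with the scope recorded via set_scope(), joined by a dot.
std::string compose_op_name(const std::string& scope, const std::string& base) {
    return scope.empty() ? base : scope + "." + base;
}
// e.g. compose_op_name("backbone.layer1", "Elemwise") -> "backbone.layer1.Elemwise"
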
+ 41
- 5
imperative/tablegen/autogen.cpp View File

@@ -113,9 +113,10 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
"{0}({0}_)", i.name "{0}({0}_)", i.name
)); ));
} }
paramList.push_back("std::string scope_ = {}");
gen_ctor(llvm::join(paramList, ", "), gen_ctor(llvm::join(paramList, ", "),
": " + llvm::join(initList, ", "), ": " + llvm::join(initList, ", "),
" {}");
" { set_scope(scope_); }");
} }


auto packedParams = op.getPackedParams(); auto packedParams = op.getPackedParams();
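With this change every generated OpDef subclass grows a trailing scope_ constructor parameter whose only job is to call set_scope(). A hypothetical rendering of the generated header for an op Foo with a single attribute axis (the class name, the attribute, and the surrounding generated boilerplate are assumptions; the real headers are emitted at build time):

class Foo : public OpDefImplBase<Foo> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    int32_t axis = 0;
    Foo() = default;
    Foo(int32_t axis_, std::string scope_ = {}): axis(axis_) { set_scope(scope_); }
};
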
@@ -236,11 +237,19 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
os << "}\n"; os << "}\n";


// generate make_name()
os << formatv(
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
);
os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx);
os << "}\n";

os << "} // anonymous namespace\n"; os << "} // anonymous namespace\n";


methods.push_back("hash"); methods.push_back("hash");
methods.push_back("is_same_st"); methods.push_back("is_same_st");
methods.push_back("props"); methods.push_back("props");
methods.push_back("make_name");
} }
if (!methods.empty()) { if (!methods.empty()) {
os << formatv( os << formatv(
@@ -327,7 +336,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
targs.push_back(i.attr.getReturnType()); targs.push_back(i.attr.getReturnType());
} }
os << llvm::join(targs, ", "); os << llvm::join(targs, ", ");
os << ">()";
os << ", std::string>()";
for (auto &&i : op.getMgbAttributes()) { for (auto &&i : op.getMgbAttributes()) {
os << formatv(", py::arg(\"{0}\")", i.name); os << formatv(", py::arg(\"{0}\")", i.name);
auto defaultValue = i.attr.getDefaultValue(); auto defaultValue = i.attr.getDefaultValue();
@@ -337,7 +346,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
hasDefaultCtor = true; hasDefaultCtor = true;
} }
} }
os << ")";
os << ", py::arg(\"scope\") = {})";
} }
if (hasDefaultCtor) { if (hasDefaultCtor) {
os << "\n .def(py::init<>())"; os << "\n .def(py::init<>())";
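
On the pybind11 side the generated py::init gains a trailing std::string and a defaulted scope keyword, so Python callers may pass scope="..." or omit it. A hypothetical rendering for the same placeholder op Foo (the registration boilerplate is assumed, and the default value is spelled out explicitly here):

py::class_<Foo, std::shared_ptr<Foo>, OpDef> FooInst(m, "Foo");
FooInst
    .def(py::init<int32_t, std::string>(),
         py::arg("axis"),
         py::arg("scope") = std::string())
    .def(py::init<>());
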
@@ -442,6 +451,10 @@ EnumWrapper<{0}::{1}>::type2str = {{
className, i.name)); className, i.name));
} }


getsetters.push_back(formatv(
"{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
className));

// generate tp_init // generate tp_init
std::string initBody; std::string initBody;
if (!op.getMgbAttributes().empty()) { if (!op.getMgbAttributes().empty()) {
@@ -449,6 +462,7 @@ EnumWrapper<{0}::{1}>::type2str = {{
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv("\"{0}\", ", attr.name); initBody += formatv("\"{0}\", ", attr.name);
}); });
initBody += "\"scope\", ";
initBody += "NULL};\n"; initBody += "NULL};\n";
initBody += " PyObject "; initBody += " PyObject ";
std::vector<std::string> attrs; std::vector<std::string> attrs;
@@ -456,12 +470,15 @@ EnumWrapper<{0}::{1}>::type2str = {{
attrs.push_back(formatv("*{0} = NULL", attr.name)); attrs.push_back(formatv("*{0} = NULL", attr.name));
}); });
initBody += llvm::join(attrs, ", ") + ";\n"; initBody += llvm::join(attrs, ", ") + ";\n";
initBody += " PyObject *scope = NULL;\n";
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
initBody += std::string(op.getMgbAttributes().size(), 'O');
// an extra slot created for name
initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
initBody += "\", const_cast<char**>(kwlist)"; initBody += "\", const_cast<char**>(kwlist)";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
initBody += formatv(" ,&{0}", attr.name);
initBody += formatv(", &{0}", attr.name);
}); });
initBody += ", &scope";
initBody += "))\n"; initBody += "))\n";
initBody += " return -1;\n"; initBody += " return -1;\n";
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
@@ -483,6 +500,25 @@ EnumWrapper<{0}::{1}>::type2str = {{
} }
)", className, attr.name); )", className, attr.name);
}); });

initBody += formatv(R"(
if (scope) {{
try {{
reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
pyobj_convert_generic<std::string>::from(scope));
} catch(py::error_already_set& e) {{
e.restore();
return -1;
} catch(py::builtin_exception& e) {{
e.set_error();
return -1;
} catch(...) {{
PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
return -1;
}
}
)", className);

} }
initBody += "\n return 0;"; initBody += "\n return 0;";




+ 30
- 0
imperative/tablegen/helper.h View File

@@ -241,6 +241,30 @@ private:
body += " return props_;\n"; body += " return props_;\n";
return body; return body;
} }
std::string getModeName() const {
std::string body = formatv(
" auto&& op_ = def_.cast_final_safe<{0}>();\n"
" static_cast<void>(op_);\n",
getCppClassName()
);
for (auto&& it : getMgbAttributes()) {
if (it.name == "mode") {
auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr);
body += " switch (op_.mode){\n";
for (auto&& enumMember: enumAttr->getEnumMembers()) {
body += formatv(
" case {0}::{1}::{2}:\n",
getCppClassName(), enumAttr->getEnumName(), enumMember
);
body += formatv(" return \"{0}\";\n", enumMember);
}
body += formatv(
" default: return \"{0}::Unknown\";\n", getCppClassName());
body += " }\n";
}
}
return body;
}
public: public:
static bool classof(const Operator* op) { static bool classof(const Operator* op) {
return op->getDef().isSubClassOf("MgbHashableOpMixin"); return op->getDef().isSubClassOf("MgbHashableOpMixin");
@@ -264,6 +288,12 @@ public:
} }
return getDefaultPropsFunction(); return getDefaultPropsFunction();
} }
std::string getNameFunctionTemplate() const {
if (getDef().getValueAsBit("usingModeName")) {
return getModeName();
}
return formatv(" return \"{0}\";\n", getCppClassName());
}
}; };


} // namespace tblgen } // namespace tblgen

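For ops that set usingModeName (Elemwise and ElemwiseMultiType in ops.td below), getModeName makes the generated make_name report the mode rather than the class name. Roughly what the template expands to for Elemwise, abbreviated to two modes; the exact generated function name comes from formatMethImpl and is omitted here:

std::string make_name(const OpDef& def_) {
    auto&& op_ = def_.cast_final_safe<Elemwise>();
    static_cast<void>(op_);
    switch (op_.mode) {
        case Elemwise::Mode::ADD: return "ADD";
        case Elemwise::Mode::MUL: return "MUL";
        // ... one case per Elemwise mode ...
        default: return "Elemwise::Unknown";
    }
}
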

+ 1
- 0
sdk/load-and-run/dump_with_testcase_mge.py View File

@@ -476,6 +476,7 @@ def main():


output_mgbvars = feeds["outputs"] output_mgbvars = feeds["outputs"]
output_mgbvars = optimize_for_inference(args, output_mgbvars) output_mgbvars = optimize_for_inference(args, output_mgbvars)
output_mgbvars = [var._node for var in output_mgbvars]


inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
inputs = sorted((i.name, i.dtype) for i in inputs) inputs = sorted((i.name, i.dtype) for i in inputs)


+ 1
- 0
src/core/include/megbrain/ir/base.td View File

@@ -242,6 +242,7 @@ class MgbPackedParamBase<string className, string accessor>:
class MgbHashableOpMixin { class MgbHashableOpMixin {
string hashFunction = ?; string hashFunction = ?;
string cmpFunction = ?; string cmpFunction = ?;
bit usingModeName = 0;
} }


class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:


+ 2
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
let inputs = (ins Variadic<AnyType>:$input); let inputs = (ins Variadic<AnyType>:$input);
let results = (outs AnyType); let results = (outs AnyType);
let usingModeName = 1;
} }


def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
@@ -247,6 +248,7 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
let extraArguments = (ins let extraArguments = (ins
MgbDTypeAttr:$dtype MgbDTypeAttr:$dtype
); );
let usingModeName = 1;
} }


def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;

