GitOrigin-RevId: ff8eb003c5
tags/v1.3.0
@@ -96,7 +96,7 @@ class Graph(_imperative_rt.ComputingGraph):
         data = data.numpy()
         return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype))

-    def make_const(self, data, dtype=None, device=None):
+    def make_const(self, data, dtype=None, device=None, name=None):
         if isinstance(data, _imperative_rt.DeviceTensorND):
             assert dtype is None and device is None
             return self._wrap(_imperative_rt.make_shared(self, data))
@@ -107,7 +107,9 @@ class Graph(_imperative_rt.ComputingGraph):
             elif data.dtype == np.int64:
                 data = data.astype(np.int32)
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
+        return self._wrap(
+            _imperative_rt.make_const(self, data, device, dtype, name)
+        )

     def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
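Usage note: with the new ``name`` parameter, constants created through the graph API can carry an explicit operator name. A minimal sketch, assuming ``graph`` is a ``Graph`` instance and the name shown is illustrative:

    import numpy as np

    # the name ends up on the ImmutableTensor operator in the dumped graph
    w = graph.make_const(np.ones((2, 3), dtype="float32"), name="simple.k")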
@@ -305,7 +307,7 @@ def dump_graph(
     output_vars: Union[Dict[str, VarNode], List[VarNode]],
     *,
     keep_var_name: int = 1,
-    keep_op_name: bool = True,
+    keep_opr_name: bool = False,
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
@@ -326,7 +328,7 @@ def dump_graph(
         * 0: none of the names are kept
         * 1: (default) keep names of output vars
         * 2: keep names of all (output and internal) vars
-    :param keep_op_name: whether to keep operator names.
+    :param keep_opr_name: whether to keep operator names.
     :param keep_param_name: whether to keep param names, so param values can be
         easily manipulated after loading model
     :param keep_opr_priority: whether to keep priority setting for operators
@@ -370,7 +372,7 @@ def dump_graph(
     dump_content = _imperative_rt.dump_graph(
         ov,
         keep_var_name,
-        keep_op_name,
+        keep_opr_name,
         keep_param_name,
         keep_opr_priority,
         stat,
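Usage note: since the keyword is renamed and its default flips from True to False, existing callers of ``dump_graph`` must opt back in explicitly. A hedged sketch, assuming ``dest_vars`` holds the output VarNodes:

    dump_content, dump_info = G.dump_graph(
        dest_vars, keep_var_name=2, keep_opr_name=True
    )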
@@ -36,6 +36,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
+from ..utils.naming import auto_naming
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -77,6 +78,7 @@ def exclude_from_trace():
 class TensorInfo:
     __slots__ = (
         # collected attributes
+        "name",
        "external",
        "data_read",
        "shape_read",
@@ -96,6 +98,7 @@ class TensorInfo:
     )

     def __init__(self):
+        self.name = None
         self.exported = None
         self.data_read = None
         self.shape_read = None
@@ -290,12 +293,16 @@ class trace:
            h = getattr(x, "_mixin_handle", -1)
            if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
+                name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name
+                info.name = name
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                info.shape = x.shape
                if self._capture_as_const:
-                    info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)
+                    info.bound_data = RawTensor(
+                        x.numpy(), x.dtype, x.device, False, name
+                    )

            ihandles.append(h)
@@ -669,6 +676,12 @@ class trace:
         arg_names=None,
         output_names=None,
         append=False,
+        keep_var_name: int = 1,
+        keep_opr_name: bool = False,
+        keep_param_name: bool = False,
+        keep_opr_priority: bool = False,
+        strip_info_file=None,
+        append_json=False,
         optimize_for_inference=True,
         **kwargs
     ):
@@ -681,6 +694,20 @@ class trace:
             use the default name if not specified.
         :param append: whether output is appended to ``file``.
             Only works when ``file`` is str.
+        :param keep_var_name: level for keeping variable names:
+            * 0: none of the names are kept
+            * 1: (default) keep names of output vars
+            * 2: keep names of all (output and internal) vars
+        :param keep_opr_name: whether to keep operator names.
+        :param keep_param_name: whether to keep param names, so param values can be
+            easily manipulated after loading model
+        :param keep_opr_priority: whether to keep priority setting for operators
+        :param strip_info_file: a string for path or a file handler. If not None,
+            the dump information for code strip would be written to ``strip_info_file``
+        :param append_json: only takes effect when ``strip_info_file`` is not None. If set
+            True, the information for code strip will be appended to ``strip_info_file``;
+            if set False, ``strip_info_file`` will be overwritten
         :param optimize_for_inference: enable optimizations,
             will skip all optimize options if this is False. Default: True
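Usage note: these new keywords are forwarded unchanged to ``G.dump_graph`` (see the forwarding call later in this file). A sketch mirroring the tests added by this change:

    f = trace(func, symbolic=True, capture_as_const=True)
    f(x)
    f.dump(
        file,
        arg_names="x",
        keep_var_name=2,     # keep names of all vars
        keep_opr_name=True,  # keep operator names as well
        optimize_for_inference=False,
    )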
@@ -785,7 +812,10 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 h2v[h] = graph.make_const(
-                    info.bound_data.numpy(), dtype=info.dtype, device=info.device,
+                    info.bound_data.numpy(),
+                    dtype=info.dtype,
+                    device=info.device,
+                    name=info.name,
                 )
                 continue
             ivars = []
@@ -795,13 +825,26 @@ class trace:
                     assert info.external
                     assert info.bound_data
                     h2v[h] = graph.make_const(
-                        info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
+                        info.bound_data.numpy(),
+                        dtype=info.dtype,
+                        device=dumped_device,
+                        name=info.name,
                     )
                 ivars.append(h2v[h])
             ovars = G.apply_normal_varnode(op, *ivars)
+            auto_naming.record_opnode(ovars[0].op)
+
             assert len(ovars) == len(ohandles)
             h2v.update(zip(ohandles, ovars))

+            for i in ohandles:
+                name = auto_naming.get_var_name(i)
+                if name is not None:
+                    h2v[i].name = name
+
+        auto_naming.remove_duplicate_names()
+
         dest_vars = []
         for i, h in enumerate(self._output_bindings):
             v = h2v[h]
@@ -815,7 +858,15 @@ class trace:
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
             file = open(file, permission)
-        dump_content, dump_info = G.dump_graph(dest_vars)
+        dump_content, dump_info = G.dump_graph(
+            dest_vars,
+            keep_var_name=keep_var_name,
+            keep_opr_name=keep_opr_name,
+            keep_param_name=keep_param_name,
+            keep_opr_priority=keep_opr_priority,
+            strip_info_file=strip_info_file,
+            append_json=append_json,
+        )
         file.write(dump_content)
         return dump_info
@@ -1095,20 +1146,22 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
     return active_trace._apply_op(op, args)


-def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
+def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
     if skip_tracing:
         args = [
             RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
             for x in args
         ]
         unset_tracing()
-        ret = RawTensor(value, dtype, device, False)
+        ret = RawTensor(value, dtype, device, False, name)
         set_tracing()
         return ret
     return active_trace._apply_const(value, dtype, device)


 def apply_with_tracing(op: OpDef, *args: RawTensor):
+    if hasattr(op, "scope"):
+        op.scope = auto_naming.get_scope()
     if active_trace._symbolic:
         outputs = apply_symbolic_mode(op, *args)
     else:
@@ -1120,12 +1173,12 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
     return list(outputs)


-def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
+def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
     if active_trace._symbolic:
         outputs = apply_const_symbolic_mode(value, dtype, device)
     else:
         unset_tracing()
-        outputs = (RawTensor(value, dtype, device, False),)
+        outputs = (RawTensor(value, dtype, device, False, name),)
         set_tracing()
     active_trace._record_const(outputs)
     return list(outputs)
@@ -12,12 +12,12 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

 import numpy as np

-from ..core._imperative_rt.core2 import pop_scope, push_scope
 from ..core.tensor.utils import make_shape_tuple
 from ..logger import get_logger
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
 from ..utils.hook import HookHandler
+from ..utils.naming import auto_naming

 logger = get_logger(__name__)
@@ -69,7 +69,9 @@ class Module(metaclass=ABCMeta):
     Base Module class.
     """

-    def __init__(self):
+    def __init__(self, name=""):
+        self.name = name
+
         # runtime attributes
         self.training = True
         self.quantize_disabled = False
@@ -79,6 +81,8 @@ class Module(metaclass=ABCMeta):
         self._forward_hooks = OrderedDict()

         self._modules = []
+
+        # used for profiler and automatic naming
         self._name = "{anonymous}"

     @abstractmethod
@@ -105,7 +109,7 @@ class Module(metaclass=ABCMeta):
         return HookHandler(self._forward_hooks, hook)

     def __call__(self, *inputs, **kwargs):
-        push_scope(self._name)
+        auto_naming.push_scope(self.name if self.name else self._name)
         for hook in self._forward_pre_hooks.values():
             modified_inputs = hook(self, inputs)
             if modified_inputs is not None:
@@ -119,7 +123,7 @@ class Module(metaclass=ABCMeta):
             modified_outputs = hook(self, inputs, outputs)
             if modified_outputs is not None:
                 outputs = modified_outputs
-        pop_scope(self._name)
+        auto_naming.pop_scope()
         return outputs

     def _flatten(
@@ -579,7 +583,7 @@ class Module(metaclass=ABCMeta):
         value = super().__getattribute__(name)
         if name == "_name":
             return value
-        if _is_module(value):
+        if isinstance(value, (Tensor, Module)):
             value._name = name
         return value
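Usage note: a module can now be named either through the new constructor argument or by assigning ``self.name``; when present it takes precedence over the attribute-derived ``_name`` as the traced scope. For example:

    import megengine.module as M

    class Block(M.Module):
        def __init__(self):
            super().__init__(name="block")           # scope prefix "block"
            self.linear = M.Linear(3, 3, name="fc")  # nested scope "block.fc"

        def forward(self, x):
            return self.linear(x)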
@@ -20,6 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
 from .logger import get_logger
 from .utils.deprecation import deprecated
+from .utils.naming import auto_naming


 class Tensor(_Tensor, ArrayMethodMixin):
@@ -27,7 +28,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
     dmap_callback = None
     _q_dict = None

-    def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False):
+    def __new__(
+        cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=""
+    ):
         if device is None:
             cn = get_default_device()
         elif isinstance(device, str):
@@ -51,8 +54,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
         if isinstance(data, np.ndarray):
             if 0 in data.strides:
                 data = data.squeeze().reshape(data.shape)
-        obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache)
+        obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
         return obj

     @property
@@ -91,6 +93,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
             piece += ", device={}".format(self.device) + ")"
         return piece

+    @property
+    def name(self):
+        return self.c_name
+
+    @name.setter
+    def name(self, name):
+        self.c_name = name
+        auto_naming.record_var_name(self._mixin_handle, name)
+
     @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
     def set_value(self, value):
         if not isinstance(value, _Tensor):
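Usage note: ``Tensor.name`` is a thin wrapper over the C++-side ``c_name``; the setter also registers the name with ``auto_naming`` so it survives into a dumped graph. For example:

    from megengine import Parameter, Tensor

    k = Parameter(2.0, name="k")  # forwarded to the underlying C++ tensor
    x = Tensor([1.0, 2.0])
    x.name = "o_x"                # recorded via auto_naming.record_var_name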
@@ -0,0 +1,63 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ..core._imperative_rt.core2 import pop_scope, push_scope
+
+
+class AutoNaming:
+    r"""
+    Name all executed operators automatically during tracing and record all tensors
+    renamed by the user.
+    """
+
+    def __init__(self):
+        self.scopes = []
+        self.c_ops = []
+        self.name2ops = {}
+        self.handle2names = {}
+
+    def clear(self):
+        for var in vars(self).values():
+            var.clear()
+
+    def push_scope(self, scope):
+        push_scope(scope)
+        self.scopes.append(scope)
+
+    def pop_scope(self):
+        scope = self.scopes.pop()
+        pop_scope(scope)
+
+    def get_scope(self):
+        return ".".join(self.scopes)
+
+    def record_var_name(self, handle, name):
+        self.handle2names[handle] = name
+
+    def get_var_name(self, handle):
+        return self.handle2names.pop(handle, None)
+
+    def record_opnode(self, op):
+        ops = self.name2ops.get(op.name, [])
+        ops.append(op)
+        self.name2ops[op.name] = ops
+
+    def remove_duplicate_names(self):
+        for key, ops in self.name2ops.items():
+            if len(ops) == 1:
+                continue
+            for i, op in enumerate(ops):
+                op.name = key + "[%s]" % str(i)
+                if len(op.outputs) == 1:
+                    continue
+                for var in op.outputs:
+                    var.name = var.name.replace(key, op.name)
+        self.name2ops.clear()
+
+
+auto_naming = AutoNaming()
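Usage note: the scope stack is what produces the dotted prefixes asserted in the tests below; a minimal sketch of its behavior:

    from megengine.utils.naming import auto_naming

    auto_naming.push_scope("simple")
    auto_naming.push_scope("linear")
    assert auto_naming.get_scope() == "simple.linear"
    auto_naming.pop_scope()
    assert auto_naming.get_scope() == "simple"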
@@ -294,7 +294,7 @@ void init_graph_rt(py::module m) {
     m.def("dump_graph", [](
             const std::vector<VarNode*>& dest_vars,
             int keep_var_name,
-            bool keep_op_name,
+            bool keep_opr_name,
             bool keep_param_name,
             bool keep_opr_priority,
             py::list& stat,
@@ -307,7 +307,7 @@ void init_graph_rt(py::module m) {
         SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());

         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
-                                            keep_opr_priority, keep_op_name};
+                                            keep_opr_priority, keep_opr_name};

         auto rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
@@ -457,13 +457,17 @@ void init_graph_rt(py::module m) {
         return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
     });

-    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
+    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional<std::string> name) {
         if (!cn.valid()) {
             cn = CompNode::load(get_default_device());
         }
+        OperatorNodeConfig config(cn);
+        if (name) {
+            config.name(*name);
+        }
         auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
-        return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
-    });
+        return opr::ImmutableTensor::make(*graph, hv, config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());

     m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
         if (!cn.valid()) {
@@ -99,6 +99,14 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
 #define py_get_generic(name, attr) \
     py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

+template<typename T>
+PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) {
+    // T: PyOpXXX inst(): return XXX in opdef.h.inl
+    auto& op = reinterpret_cast<T*>(obj)->inst();
+    return pyobj_convert_generic<std::string>::to(op.scope());
+}
+
+#define py_get_scope(class) py_get_scope_impl<PyOp(class)>
+
 template<typename T, typename U, U T::Ty::*attr>
 int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
     if (value == NULL) {
@@ -121,6 +129,27 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
 #define py_set_generic(name, attr) \
     py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

+template<typename T>
+int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) {
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
+        return -1;
+    }
+    auto& op = reinterpret_cast<T*>(obj)->inst();
+    try {
+        op.set_scope(pyobj_convert_generic<std::string>::from(value));
+        return 0;
+    } catch(py::error_already_set& e) {
+        e.restore();
+    } catch(py::builtin_exception& e) {
+        e.set_error();
+    } catch(...) {
+        PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
+    }
+    return -1;
+}
+
+#define py_set_scope(class) py_set_scope_impl<PyOp(class)>
+
 struct PyOpDef {
     PyObject_HEAD
     std::shared_ptr<OpDef> op;
@@ -24,6 +24,7 @@
 #include <pybind11/numpy.h>
 #include <pybind11/operators.h>
 #include <range/v3/all.hpp>
+#include <string>

 #include <unordered_map>
@@ -222,14 +223,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             }
         } else {
             py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
-            if (nargs != 4 && nargs != 5) {
-                throw py::type_error("expect 4 or 5 arguments");
+            if (nargs != 5 && nargs != 6) {
+                throw py::type_error("expect 5 or 6 arguments");
             }
             auto data = tup[0].cast<py::array>();
             DType dtype = tup[1].cast<DType>();
             CompNode cn = tup[2].cast<CompNode>();
             bool is_const = tup[3].cast<bool>();
-            bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false;
+            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
+            std::string name = tup[nargs - 1].cast<std::string>();

             // const op
             if (is_const && is_tracing) {
@@ -259,6 +261,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
             }

             m_tensor = std::make_shared<Tensor>(handle);
+            m_tensor->user_custom_name = name;

             if (data.ndim() == 0) {
                 m_tensor->m_flags |= Tensor::Flags::SCALAR;
@@ -313,6 +316,19 @@ REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)

 #undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC

+#define SET_GET_NAME(member)                                     \
+    PyObject* TensorWrapper::member() {                          \
+        return py::cast(m_tensor->member).release().ptr();       \
+    }                                                            \
+    void TensorWrapper::set_##member(PyObject* dest) {           \
+        auto py_dest = py::reinterpret_borrow<py::object>(dest); \
+        m_tensor->member = py_dest.cast<std::string>();          \
+    }
+SET_GET_NAME(user_custom_name)
+SET_GET_NAME(automatic_name)
+#undef SET_GET_NAME
+
 PyObject* TensorWrapper::handle() {
     return py::cast(m_tensor->m_handle).release().ptr();
 }
@@ -453,7 +469,11 @@ void TensorWrapper::reset(PyObject* tensor) {
     if (!t) {
         throw py::type_error("expect Tensor");
     }
+    std::string user_custom_name = m_tensor->user_custom_name;
+    std::string automatic_name = m_tensor->automatic_name;
     m_tensor = t->m_tensor;
+    m_tensor->user_custom_name = user_custom_name;
+    m_tensor->automatic_name = automatic_name;
 }

 void TensorWrapper::reset_varnode() {
@@ -785,6 +805,8 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
         .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
         .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
+        .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
+        .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
         .finalize();
     if (!tensor_type) throw py::error_already_set();
     py::setattr(m, "Tensor", tensor_type);
@@ -15,6 +15,7 @@
 #include "megbrain/imperative/interpreter.h"
 #include "pybind11/pybind11.h"

+#include <string>
 #include "./pyext17.h"
@@ -70,6 +71,8 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     GradInfo m_grad_info;
     TraceInfo m_trace_info;
     SharedHandle m_handle;
+    std::string user_custom_name;
+    std::string automatic_name;
     cg::VarNode* m_var;

     using Handle = interpreter::Interpreter::Handle;
@@ -170,6 +173,10 @@ struct TensorWrapper {
     void set_compiled_info(PyObject *);
     PyObject* trace_mixin_info();
     void set_trace_mixin_info(PyObject *);
+    PyObject* user_custom_name();
+    void set_user_custom_name(PyObject *);
+    PyObject* automatic_name();
+    void set_automatic_name(PyObject *);
     PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
 };
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import io
+
+import numpy as np
+import pytest
+
+import megengine.functional as F
+import megengine.module as M
+import megengine.utils.comp_graph_tools as cgtools
+from megengine import Parameter, Tensor
+from megengine.core.tensor import megbrain_graph as G
+from megengine.jit.tracing import trace
+from megengine.utils.naming import auto_naming
+
+
+def _dump_and_load(func, symbolic, keep_opr_name=True):
+    auto_naming.clear()
+    func = trace(func, symbolic=symbolic, capture_as_const=True)
+    x = Tensor(np.ones(shape=(2, 3)))
+    func(x).numpy()
+    file = io.BytesIO()
+    func.dump(
+        file,
+        optimize_for_inference=False,
+        arg_names="x",
+        keep_opr_name=keep_opr_name,
+        keep_var_name=2,
+    )
+    file.seek(0)
+    *_, outputs = G.load_graph(file)
+    op = cgtools.get_oprs_seq(outputs)[-1]
+    return op
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_auto_naming(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+
+        def forward(self, x):
+            return x + x
+
+    m = Simple("simple")
+    op = _dump_and_load(m, symbolic)
+    assert op.name == "simple.ADD"
+    assert op.outputs[0].name == "simple.ADD"
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_user_named_tensor(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+            self.k = Parameter(1.0, name="k")
+
+        def forward(self, x):
+            x = x + x
+            x.name = "o_x"
+            return x
+
+    m = Simple("simple")
+    op = _dump_and_load(m, symbolic)
+    assert op.name == "simple.ADD"
+    assert op.outputs[0].name == "o_x"
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_user_named_param(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+            self.k = Parameter(2.0, name="k")
+
+        def forward(self, x):
+            return self.k * x
+
+    m = Simple("simple")
+    op = _dump_and_load(m, symbolic)
+    assert op.inputs[0].name == "x"
+    assert op.inputs[1].name == "simple.k"
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_without_module(symbolic):
+    def f(x):
+        return 2 * x
+
+    op = _dump_and_load(f, symbolic)
+    assert op.name == "MUL"
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_with_submodule(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+            self.linear = M.Linear(3, 3)
+
+        def forward(self, x):
+            x = self.linear(x)
+            return x
+
+    m = Simple("simple")
+    op = _dump_and_load(m, symbolic)
+    assert op.name == "simple.linear.ADD"
+    assert op.inputs[0].owner.name == "simple.linear.MatrixMul"
+    assert op.outputs[0].name == "simple.linear.ADD"
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_named_submodule(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+            self.linear = M.Linear(3, 3, name="x")
+
+        def forward(self, x):
+            x = self.linear(x)
+            return x
+
+    m = Simple("simple")
+    op = _dump_and_load(m, symbolic)
+    assert op.name == "simple.x.ADD"
+    assert op.inputs[0].owner.name == "simple.x.MatrixMul"
+    assert op.outputs[0].name == "simple.x.ADD"
+
+
+@pytest.mark.parametrize("symbolic", [False, True])
+def test_with_same_operators(symbolic):
+    class Simple(M.Module):
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+
+        def forward(self, x):
+            x = F.relu(x)
+            x = F.relu(x)
+            return x
+
+    m = Simple("simple")
+    op = _dump_and_load(m, symbolic)
+    assert op.name == "simple.RELU[1]"
+    assert op.inputs[0].owner.name == "simple.RELU[0]"
+
+
+def test_not_keep_opr_name():
+    def f(x):
+        return 2 * x
+
+    op = _dump_and_load(f, True, False)
+    assert op.name == "MUL(x,2[2])[4]"
@@ -148,7 +148,7 @@ def test_dump():
     dump_info = f.dump(file)
     assert dump_info.nr_opr == 3
     np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
-    np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
+    np.testing.assert_equal(dump_info.outputs, ["ADD"])
     file.seek(0)
     infer_cg = cgtools.GraphInference(file)
     result = list((infer_cg.run(a, b)).values())[0]
@@ -75,10 +75,6 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
     return def.trait()->props(def);
 }

-const char* OpDef::name() const {
-    return trait()->name;
-}
-
 std::string OpDef::to_string() const {
     std::string builder = "{";
     for (auto&& [name, value]: props(*this)) {
@@ -107,6 +103,20 @@ const OpTrait* OpDef::trait() const {
     return m_trait;
 }

+const std::string OpDef::scope() const {
+    return m_scope;
+}
+
+void OpDef::set_scope(const std::string& scope) {
+    m_scope = scope;
+}
+
+const std::string OpDef::make_name() const {
+    if (m_scope.empty())
+        return trait()->make_name(*this);
+    return m_scope + "." + trait()->make_name(*this);
+}
+
 } // namespace imperative
 } // namespace mgb
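Usage note: ``make_name()`` joins the recorded scope with the trait-provided operator name, so an ADD executed under scope "simple.linear" dumps as "simple.linear.ADD". The scope itself is assigned from Python during tracing, roughly as in ``apply_with_tracing`` above:

    if hasattr(op, "scope"):
        op.scope = auto_naming.get_scope()  # exposed via py_get_scope / py_set_scope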
@@ -75,6 +75,7 @@ using GradMaker = detail::OpMeth<
 using Props = detail::OpMeth<decltype(OpDef::props)>;
 using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
 using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
+using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;

 struct OpTrait {
     const char* name;
@@ -88,6 +89,7 @@ struct OpTrait {
     Props props;
     HashFunc hash;
     IsSame is_same_st;
+    MakeNameFunc make_name;
     OpTrait(const char* name);
     static OpTrait* find_by_name(const char* name);
     static OpTrait* find_by_typeinfo(Typeinfo* type);
@@ -104,7 +106,8 @@ struct OpTrait {
     cb(make_backward_graph) \
     cb(props) \
     cb(hash) \
-    cb(is_same_st)
+    cb(is_same_st) \
+    cb(make_name)

 struct OpTraitRegistry {
     OpTrait* trait;
@@ -30,13 +30,14 @@ cg::OperatorNodeBase* apply_on_var_node(
     size_t nr_inp = inputs.size();
     mgb_assert(nr_inp == 3 || nr_inp == 5,
         "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
+    OperatorNodeConfig config{bn_opr.make_name()};
     if (nr_inp == 3) {
         return opr::BatchNorm::make(
-                inputs[0], inputs[1], inputs[2], bn_opr.param())[0]
+                inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0]
             .node()->owner_opr();
     } else {
         return opr::BatchNorm::make(
-                inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0]
+                inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param(), config)[0]
             .node()->owner_opr();
     }
 }
@@ -27,10 +27,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<Broadcast>();
+    auto&& op = def.cast_final_safe<Broadcast>();
     size_t nr_inp = inputs.size();
     mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
-    return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
 }

 bool valid_broadcast(const TensorShape& src_shape,
@@ -96,7 +97,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Reshape&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Reshape::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
 }

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     auto disable = std::make_shared<DTypeScalar>();
     disable->set(0);

-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{comm.make_name()};
     if (comm.comp_node.size() > 0) {
         config.comp_node(CompNode::load(comm.comp_node));
     }
@@ -23,12 +23,12 @@ namespace {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<CondTake>();
+    auto&& op = def.cast_final_safe<CondTake>();
     auto&& graph = inputs[0]->owner_graph();

     opr::CondTake::Param param;
     param.val = 1;
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{op.make_name()};
     cg::OperatorNodeBase* opr = graph->insert_opr(
             std::make_unique<opr::CondTake>(
                     inputs[0], inputs[1], param, config));
@@ -31,7 +31,8 @@ cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
-    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
+    OperatorNodeConfig config{elemwise_opr.make_name()};
+    return opr::Elemwise::make(inputs, elemwise_opr.mode, config).node()->owner_opr();
 }

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -23,7 +23,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const CvtColor&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::CvtColor::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::CvtColor::make(inputs[0], op.param(), config);
 }

 OP_TRAIT_REG(CvtColor, CvtColor)
     .apply_on_var_node(apply_on_var_node)
@@ -32,7 +32,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
             ssprintf("%s:%d", send.addr.data(), send.port));

     auto&& graph = inputs[0]->owner_graph();
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{send.make_name()};
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
                     send.key, inputs[0], group_client, true, config));
@@ -42,11 +42,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
 cg::OperatorNodeBase* apply_on_var_node_remote_recv(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& recv = def.cast_final_safe<RemoteRecv>();
+    OperatorNodeConfig config{recv.cn};
+    config.name(recv.make_name());
     auto group_client = std::make_shared<GroupClientProxy>(
             ssprintf("%s:%d", recv.addr.data(), recv.port));
     auto&& graph = inputs[0]->owner_graph();

     return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
-            recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn},
+            recv.key, inputs[0], *graph, group_client, config,
             recv.shape, recv.dtype));
 }
@@ -21,8 +21,10 @@ namespace {
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<MatrixInverse>();
     mgb_assert(inputs.size() == 1);
-    return opr::MatrixInverse::make(inputs[0]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::MatrixInverse::make(inputs[0], {}, config);
 }
 OP_TRAIT_REG(MatrixInverse, MatrixInverse)
     .apply_on_var_node(apply_on_var_node)
@@ -29,7 +29,9 @@ cg::OperatorNodeBase* apply_on_var_node(
     param.iou_thresh = nms_keep.iou_thresh;
     param.max_output = nms_keep.max_output;

-    return NMSKeepOpr::make(inputs[0], param).node()->owner_opr();
+    OperatorNodeConfig config{nms_keep.make_name()};
+
+    return NMSKeepOpr::make(inputs[0], param, config).node()->owner_opr();
 }

 OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr)
@@ -79,11 +79,13 @@ public:
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& attr = def.cast_final_safe<OprAttr>();
+    auto config = attr.config;
+    config.name(attr.make_name());
     mgb_assert(!inputs.empty());
     auto registry = serialization::OprRegistry::find_by_name(attr.type);
     mgb_assert(registry, "operator %s not found", attr.type.c_str());
     OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, attr.config);
+    return registry->loader(ctx, inputs, config);
 }

 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
@@ -99,10 +101,15 @@ std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     return {};
 }

+std::string make_name(const OpDef& def) {
+    return "OprAttr";
+}
+
 OP_TRAIT_REG(OprAttr, OprAttr)
     .make_from_op_node(make_from_op_node)
     .apply_on_var_node(apply_on_var_node)
     .props(props)
+    .make_name(make_name)
     .fallback();

 } // anonymous namespace
@@ -24,7 +24,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Resize&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Resize::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Resize::make(inputs[0], inputs[1], op.param(), config);
 }

 OP_TRAIT_REG(Resize, Resize)
@@ -46,7 +46,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const Convolution&>(def);
-    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy());
+    OperatorNodeConfig config{conv.make_name()};
+    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
 }

 OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
@@ -60,7 +61,7 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{conv.make_name()};
     if (inputs.size() == 2) {
         return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else {
@@ -88,7 +89,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& ds = static_cast<const Dimshuffle&>(def);
-    return opr::Dimshuffle::make(inputs[0], ds.pattern);
+    OperatorNodeConfig config{ds.make_name()};
+    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
 }

 OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
@@ -107,7 +109,8 @@ auto apply_on_var_node(
     for (auto&& i : add_axis.axis) {
         param.push_back(Desc::make_add(i));
     }
-    return opr::AxisAddRemove::make(inputs[0], param);
+    OperatorNodeConfig config{add_axis.make_name()};
+    return opr::AxisAddRemove::make(inputs[0], param, config);
 }

 OP_TRAIT_REG(AddAxis, AddAxis)
@@ -125,7 +128,8 @@ auto apply_on_var_node(
     for (auto&& i : remove_axis.axis) {
         param.push_back(Desc::make_remove(i));
     }
-    return opr::AxisAddRemove::make(inputs[0], param);
+    OperatorNodeConfig config{remove_axis.make_name()};
+    return opr::AxisAddRemove::make(inputs[0], param, config);
 }

 OP_TRAIT_REG(RemoveAxis, RemoveAxis)
@@ -138,7 +142,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& topk = static_cast<const TopK&>(def);
-    return opr::TopK::make(inputs[0], inputs[1], topk.param())[0]
+    OperatorNodeConfig config{topk.make_name()};
+    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
         .node()->owner_opr();
 }
@@ -152,10 +157,12 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& reduce = static_cast<const Reduce&>(def);
+    OperatorNodeConfig config{reduce.make_name()};
     if (inputs.size() > 1) {
-        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]);
+        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
     } else {
-        return opr::Reduce::make(inputs[0], reduce.param());
+        return opr::Reduce::make(
+                inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
     }
 }
@@ -175,7 +182,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& pool = static_cast<const AdaptivePooling&>(def);
-    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param());
+    OperatorNodeConfig config{pool.make_name()};
+    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
 }

 OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
@@ -189,6 +197,7 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const ConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
+    config.name(conv.make_name());
     if (inputs.size() == 2) {
         return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else if (inputs.size() == 3) {
@@ -210,6 +219,7 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& conv = static_cast<const BatchConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
+    config.name(conv.make_name());
     if (inputs.size() == 2) {
         return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else if (inputs.size() == 3) {
@@ -230,7 +240,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& pool = static_cast<const Pooling&>(def);
-    return opr::Pooling::make(inputs[0], pool.param());
+    OperatorNodeConfig config{pool.make_name()};
+    return opr::Pooling::make(inputs[0], pool.param(), config);
 }
 OP_TRAIT_REG(Pooling, Pooling)
     .apply_on_var_node(apply_on_var_node)
@@ -243,8 +254,9 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const MatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{matmul.make_name()};
     return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
-            matmul.policy());
+            matmul.policy(), config);
 }
 OP_TRAIT_REG(MatrixMul, MatrixMul)
     .apply_on_var_node(apply_on_var_node)
@@ -257,8 +269,9 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{matmul.make_name()};
     return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
-            matmul.policy());
+            matmul.policy(), config);
 }
 OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
     .apply_on_var_node(apply_on_var_node)
@@ -267,10 +280,12 @@ OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
 namespace { namespace dot {
 auto apply_on_var_node(
-        const OpDef&,
+        const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Dot>();
     mgb_assert(inputs.size() == 2);
-    return opr::Dot::make(inputs[0], inputs[1]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Dot::make(inputs[0], inputs[1], config);
 }
 OP_TRAIT_REG(Dot, Dot)
     .apply_on_var_node(apply_on_var_node)
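
Dot (and Identity below) previously ignored its OpDef argument, so the parameter changes from an unnamed `const OpDef&` to `const OpDef& def`. Note the two downcast styles side by side: where dispatch already guarantees the type an unchecked `static_cast` is used, while `cast_final_safe<T>()` additionally asserts the dynamic type at runtime:

    auto&& matmul = static_cast<const MatrixMul&>(def);  // unchecked downcast
    auto&& dot = def.cast_final_safe<Dot>();             // asserts def is a Dot
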
@@ -282,7 +297,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& argsort = static_cast<const Argsort&>(def);
-    return opr::Argsort::make(inputs[0], argsort.param());
+    OperatorNodeConfig config{argsort.make_name()};
+    return opr::Argsort::make(inputs[0], argsort.param(), config);
 }
 OP_TRAIT_REG(Argsort, Argsort)
     .apply_on_var_node(apply_on_var_node)
@@ -294,7 +310,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& argmax = static_cast<const Argmax&>(def);
-    return opr::Argmax::make(inputs[0], argmax.param());
+    OperatorNodeConfig config{argmax.make_name()};
+    return opr::Argmax::make(inputs[0], argmax.param(), config);
 }
 OP_TRAIT_REG(Argmax, Argmax)
     .apply_on_var_node(apply_on_var_node)
@@ -306,7 +323,8 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& argmin = static_cast<const Argmin&>(def);
-    return opr::Argmin::make(inputs[0], argmin.param());
+    OperatorNodeConfig config{argmin.make_name()};
+    return opr::Argmin::make(inputs[0], argmin.param(), config);
 }
 OP_TRAIT_REG(Argmin, Argmin)
     .apply_on_var_node(apply_on_var_node)
@@ -318,11 +336,13 @@ auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& warp = static_cast<const WarpPerspective&>(def);
+    OperatorNodeConfig config{warp.make_name()};
     if (inputs.size() == 3) {
-        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param());
+        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config);
     } else {
         mgb_assert(inputs.size() == 4);
-        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param());
+        return opr::WarpPerspective::make(
+                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
     }
 }
 OP_TRAIT_REG(WarpPerspective, WarpPerspective)
@@ -336,7 +356,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& local = static_cast<const GroupLocal&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::GroupLocal::make(inputs[0], inputs[1], local.param());
+    OperatorNodeConfig config{local.make_name()};
+    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
 }
 OP_TRAIT_REG(GroupLocal, GroupLocal)
     .apply_on_var_node(apply_on_var_node)
@@ -349,7 +370,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const IndexingOneHot&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
 }
 OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
     .apply_on_var_node(apply_on_var_node)
@@ -362,7 +384,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const IndexingSetOneHot&>(def);
     mgb_assert(inputs.size() == 3);
-    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
     .apply_on_var_node(apply_on_var_node)
@@ -375,7 +398,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const TypeCvt&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::TypeCvt::make(inputs[0], op.dtype);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::TypeCvt::make(inputs[0], op.dtype, config);
 }
 OP_TRAIT_REG(TypeCvt, TypeCvt)
     .apply_on_var_node(apply_on_var_node)
@@ -388,6 +412,7 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Concat&>(def);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Concat::make(inputs, op.axis, config);
 }
 OP_TRAIT_REG(Concat, Concat)
@@ -402,6 +427,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Copy&>(def);
     mgb_assert(inputs.size() == 1);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Copy::make(inputs[0], config);
 }
 OP_TRAIT_REG(Copy, Copy)
@@ -411,10 +437,12 @@ OP_TRAIT_REG(Copy, Copy)
 namespace { namespace identity {
 auto apply_on_var_node(
-        const OpDef&,
+        const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Identity>();
     mgb_assert(inputs.size() == 1);
-    return opr::Identity::make(inputs[0]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Identity::make(inputs[0], config);
 }
 OP_TRAIT_REG(Identity, Identity)
     .apply_on_var_node(apply_on_var_node)
@@ -427,7 +455,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const AssertEqual&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config);
 }
@@ -443,7 +472,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const UniformRNG&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::UniformRNG::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::UniformRNG::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(UniformRNG, UniformRNG)
     .apply_on_var_node(apply_on_var_node)
@@ -456,7 +486,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const GaussianRNG&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::GaussianRNG::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::GaussianRNG::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(GaussianRNG, GaussianRNG)
     .apply_on_var_node(apply_on_var_node)
@@ -469,7 +500,9 @@ VarNodeArray apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const ROIAlign&>(def);
     mgb_assert(inputs.size() == 2);
-    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr();
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr = opr::ROIAlign::make(
+            inputs[0], inputs[1], op.param(), config).node()->owner_opr();
     return {opr->output(0), opr->output(1)};
 }
 OP_TRAIT_REG(ROIAlign, ROIAlign)
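
ROIAlign (like ROIPooling and SVD further down) produces more than one output, so make() alone is not enough: the code walks from the first SymbolVar back to the owning operator node and collects every output explicitly. The shape of that pattern:

    // make() returns only the first output var; recover the operator node
    // to reach the rest (here the pooled values plus an auxiliary output,
    // presumably the backward index)
    auto* opr = opr::ROIAlign::make(
            inputs[0], inputs[1], op.param(), config).node()->owner_opr();
    return VarNodeArray{opr->output(0), opr->output(1)};
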
@@ -484,7 +517,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const NvOf&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::NvOf::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::NvOf::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(NvOf, NvOf)
     .apply_on_var_node(apply_on_var_node)
@@ -499,6 +533,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Linspace&>(def);
     mgb_assert(inputs.size() == 3);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(Linspace, Linspace)
@@ -513,6 +548,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Eye&>(def);
     mgb_assert(inputs.size() == 1);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     opr::Eye::Param param{op.k, op.dtype.enumv()};
     return opr::Eye::make(inputs[0], param, config);
 }
@@ -527,7 +563,10 @@ VarNodeArray apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const ROIPooling&>(def);
     mgb_assert(inputs.size() == 3);
-    auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr();
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr = opr::ROIPooling::make(
+            inputs[0], inputs[1], inputs[2], op.param(), config
+    ).node()->owner_opr();
     return {opr->output(0), opr->output(1)};
 }
 OP_TRAIT_REG(ROIPooling, ROIPooling)
@@ -541,7 +580,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Remap&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Remap::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
 }
 OP_TRAIT_REG(Remap, Remap)
     .apply_on_var_node(apply_on_var_node)
@@ -578,7 +618,8 @@ auto apply_on_var_node( \
         const OpDef& def, \
         const VarNodeArray& inputs) { \
     auto&& op = static_cast<const NAME&>(def); \
-    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
+    OperatorNodeConfig config{op.make_name()}; \
+    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
 } \
 OP_TRAIT_REG(NAME, NAME) \
     .apply_on_var_node(apply_on_var_node) \
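
The last hunk above patches a macro that stamps out the fancy-indexing traits. For a hypothetical instantiation with NAME=Subtensor and NR_INPUT=1, and assuming IN1 is a helper macro expanding to inputs[0] (both defined elsewhere in that file), the expansion would look roughly like:

    auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
        auto&& op = static_cast<const Subtensor&>(def);
        OperatorNodeConfig config{op.make_name()};
        return opr::Subtensor::make(
                inputs[0], get_index(inputs, 1, op.items), config);
    }
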
@@ -609,30 +650,35 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const FakeQuant&>(def);
     mgb_assert(inputs.size() == 3);
-    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(FakeQuant, FakeQuant)
     .apply_on_var_node(apply_on_var_node)
     .fallback();
 }} // fake_quant
 namespace { namespace tqt {
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const TQT&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::TQT::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
 }
 OP_TRAIT_REG(TQT, TQT)
     .apply_on_var_node(apply_on_var_node)
     .fallback();
 }} // tqt
 namespace { namespace elemwise_multi_type {
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const ElemwiseMultiType&>(def);
     OperatorNodeConfig config{op.dtype};
+    config.name(op.make_name());
     return opr::ElemwiseMultiType::make(inputs, op.param(), config);
 }
 OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
@@ -646,7 +692,9 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const SVD&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output();
+    OperatorNodeConfig config{op.make_name()};
+    return opr::SVD::make(inputs[0], op.param(), config)[0]
+            .node()->owner_opr()->usable_output();
 }
 OP_TRAIT_REG(SVD, SVD)
     .apply_on_var_node(apply_on_var_node)
@@ -21,7 +21,8 @@ cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
-    return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
+    OperatorNodeConfig config{op_def.make_name()};
+    return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
 }
 DispatchMode decide_dispatch_mode(
@@ -152,7 +153,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
     auto&& graph = inputs[0]->owner_graph();
     auto&& shapes = get_shapes(param.shapes);
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config(param.make_name());
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
                     inputs[0], param.offsets, shapes, config));
@@ -189,7 +190,7 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
     auto&& graph = inputs[0]->owner_graph();
     VarNodeArray inps(inputs.begin(), inputs.end() - 1);
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{param.make_name()};
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
                     inps, inputs.back(), param.offsets, config));
@@ -20,8 +20,9 @@ namespace { namespace tensorrt_runtime {
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const TensorRTRuntime&>(def);
+    OperatorNodeConfig config{op.make_name()};
     SymbolVarArray sinputs(inputs.begin(), inputs.end());
-    return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs);
+    return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs, config);
 }
 OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime)
     .apply_on_var_node(apply_on_var_node)
@@ -21,7 +21,8 @@ namespace { namespace warp_affine {
         const VarNodeArray& inputs) {
     mgb_assert(inputs.size() == 3);
     auto&& op = static_cast<const WarpAffine&>(def);
-    return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(WarpAffine, WarpAffine)
@@ -36,6 +36,7 @@ class OpDef : public Hashable,
                public NonCopyableObj,
                public std::enable_shared_from_this<OpDef> {
     mutable const OpTrait* m_trait = nullptr;
+    std::string m_scope;
 public:
     virtual ~OpDef() = default;
@@ -86,10 +87,14 @@ public:
     const OpTrait* trait() const;
+    const char* name() const;
     std::string to_string() const;
+    const std::string scope() const;
+    const std::string make_name() const;
+    void set_scope(const std::string& scope);
     virtual size_t hash() const;
     virtual bool is_same_st(const Hashable&) const;
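
These header additions carry the rest of the change: every OpDef now stores a scope string, and make_name() derives the display name from it. The definitions are not part of this excerpt; a plausible sketch, assuming the name is the scope-qualified trait name joined with a dot (the separator, fallback, and the trait()->make_name hook are all assumptions):

    const std::string OpDef::scope() const {
        return m_scope;
    }

    void OpDef::set_scope(const std::string& scope) {
        m_scope = scope;
    }

    const std::string OpDef::make_name() const {
        // trait-provided base name, e.g. "MatrixMul", or "ADD" for
        // mode-named ops (see the tablegen changes below)
        std::string name = trait()->make_name(*this);
        return m_scope.empty() ? name : m_scope + "." + name;
    }
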
@@ -113,9 +113,10 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
             "{0}({0}_)", i.name
         ));
     }
+    paramList.push_back("std::string scope_ = {}");
     gen_ctor(llvm::join(paramList, ", "),
             ": " + llvm::join(initList, ", "),
-            " {}");
+            " { set_scope(scope_); }");
 }
 auto packedParams = op.getPackedParams();
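
Every generated op constructor thus gains a trailing, defaulted scope_ parameter whose only job is to call set_scope(). Roughly what the emitted header now looks like for Argsort (reconstructed from the template above, not copied from real generator output; dyn-type boilerplate omitted):

    class Argsort : public OpDefImplBase<Argsort> {
    public:
        using Order = ::megdnn::param::Argsort::Order;
        Order order = Order::ASCENDING;
        Argsort() = default;
        Argsort(Order order_, std::string scope_ = {})
                : order(order_) { set_scope(scope_); }
    };
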
@@ -236,11 +237,19 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
         os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
         os << "}\n";
+        // generate make_name()
+        os << formatv(
+            "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name")
+        );
+        os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx);
+        os << "}\n";
         os << "} // anonymous namespace\n";
         methods.push_back("hash");
         methods.push_back("is_same_st");
         methods.push_back("props");
+        methods.push_back("make_name");
     }
     if (!methods.empty()) {
         os << formatv(
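
For ops that do not opt into mode-based naming, getNameFunctionTemplate() (defined near the end of this diff) reduces to returning the class name, so the emitted function is trivial. Illustrative output for MatrixMul (the actual function name comes from formatMethImpl("make_name") and may be spelled differently):

    std::string MatrixMul_make_name_impl(const OpDef& def_) {
        return "MatrixMul";
    }
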
@@ -327,7 +336,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
             targs.push_back(i.attr.getReturnType());
         }
         os << llvm::join(targs, ", ");
-        os << ">()";
+        os << ", std::string>()";
         for (auto &&i : op.getMgbAttributes()) {
             os << formatv(", py::arg(\"{0}\")", i.name);
             auto defaultValue = i.attr.getDefaultValue();
@@ -337,7 +346,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
             hasDefaultCtor = true;
         }
     }
-    os << ")";
+    os << ", py::arg(\"scope\") = {})";
 }
 if (hasDefaultCtor) {
     os << "\n    .def(py::init<>())";
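
On the pybind11 side, the init template list gains a trailing std::string and the argument list a matching keyword with an empty default, so scope can be passed as a keyword from Python. A hypothetical generated binding for Argsort (spelling reconstructed, not real generator output):

    py::class_<Argsort, std::shared_ptr<Argsort>, OpDef> inst(m, "Argsort");
    inst.def(py::init<::megdnn::param::Argsort::Order, std::string>(),
             py::arg("order"), py::arg("scope") = {});
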
@@ -442,6 +451,10 @@ EnumWrapper<{0}::{1}>::type2str = {{
             className, i.name));
     }
+    getsetters.push_back(formatv(
+        "{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
+        className));
     // generate tp_init
     std::string initBody;
     if (!op.getMgbAttributes().empty()) {
@@ -449,6 +462,7 @@ EnumWrapper<{0}::{1}>::type2str = {{
         llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
             initBody += formatv("\"{0}\", ", attr.name);
         });
+        initBody += "\"scope\", ";
         initBody += "NULL};\n";
         initBody += "    PyObject ";
         std::vector<std::string> attrs;
@@ -456,12 +470,15 @@ EnumWrapper<{0}::{1}>::type2str = {{
             attrs.push_back(formatv("*{0} = NULL", attr.name));
         });
         initBody += llvm::join(attrs, ", ") + ";\n";
+        initBody += "    PyObject *scope = NULL;\n";
         initBody += "    if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
-        initBody += std::string(op.getMgbAttributes().size(), 'O');
+        // an extra slot created for name
+        initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
         initBody += "\", const_cast<char**>(kwlist)";
         llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
-            initBody += formatv(" ,&{0}", attr.name);
+            initBody += formatv(", &{0}", attr.name);
         });
+        initBody += ", &scope";
         initBody += "))\n";
         initBody += "    return -1;\n";
         llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
@@ -483,6 +500,25 @@ EnumWrapper<{0}::{1}>::type2str = {{
     }
 )", className, attr.name);
         });
+        initBody += formatv(R"(
+    if (scope) {{
+        try {{
+            reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
+                pyobj_convert_generic<std::string>::from(scope));
+        } catch(py::error_already_set& e) {{
+            e.restore();
+            return -1;
+        } catch(py::builtin_exception& e) {{
+            e.set_error();
+            return -1;
+        } catch(...) {{
+            PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
+            return -1;
+        }
+    }
+)", className);
     }
     initBody += "\n    return 0;";
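
Put together, the emitted tp_init parses one extra optional object and routes it to set_scope(). A rough reconstruction for a one-attribute op such as Elemwise (error handling from the template above elided; names follow the generator's conventions but are not copied from real output):

    static const char* kwlist[] = {"mode", "scope", NULL};
    PyObject *mode = NULL;
    PyObject *scope = NULL;
    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OO",
            const_cast<char**>(kwlist), &mode, &scope))
        return -1;
    // ... assign mode to inst().mode as in the per-attribute template ...
    if (scope) {
        reinterpret_cast<PyOp(Elemwise)*>(self)->inst().set_scope(
                pyobj_convert_generic<std::string>::from(scope));
    }
    return 0;
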
@@ -241,6 +241,30 @@ private:
         body += "    return props_;\n";
         return body;
     }
+    std::string getModeName() const {
+        std::string body = formatv(
+            "    auto&& op_ = def_.cast_final_safe<{0}>();\n"
+            "    static_cast<void>(op_);\n",
+            getCppClassName()
+        );
+        for (auto&& it : getMgbAttributes()) {
+            if (it.name == "mode") {
+                auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr);
+                body += "    switch (op_.mode){\n";
+                for (auto&& enumMember: enumAttr->getEnumMembers()) {
+                    body += formatv(
+                        "        case {0}::{1}::{2}:\n",
+                        getCppClassName(), enumAttr->getEnumName(), enumMember
+                    );
+                    body += formatv("            return \"{0}\";\n", enumMember);
+                }
+                body += formatv(
+                    "        default: return \"{0}::Unknown\";\n", getCppClassName());
+                body += "    }\n";
+            }
+        }
+        return body;
+    }
 public:
     static bool classof(const Operator* op) {
         return op->getDef().isSubClassOf("MgbHashableOpMixin");
@@ -264,6 +288,12 @@ public:
         }
         return getDefaultPropsFunction();
     }
+    std::string getNameFunctionTemplate() const {
+        if (getDef().getValueAsBit("usingModeName")) {
+            return getModeName();
+        }
+        return formatv("    return \"{0}\";\n", getCppClassName());
+    }
 };
 } // namespace tblgen
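
With usingModeName set (done for Elemwise and ElemwiseMultiType just below), make_name() reports the mode instead of the class, so a dumped graph shows e.g. ADD rather than a generic Elemwise. Roughly the body this generates for Elemwise (member list abbreviated):

    auto&& op_ = def_.cast_final_safe<Elemwise>();
    static_cast<void>(op_);
    switch (op_.mode) {
        case Elemwise::Mode::ADD:
            return "ADD";
        case Elemwise::Mode::MUL:
            return "MUL";
        // ... one case per Elemwise::Mode member ...
        default: return "Elemwise::Unknown";
    }
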
@@ -476,6 +476,7 @@ def main():
     output_mgbvars = feeds["outputs"]
     output_mgbvars = optimize_for_inference(args, output_mgbvars)
+    output_mgbvars = [var._node for var in output_mgbvars]
     inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
     inputs = sorted((i.name, i.dtype) for i in inputs)
@@ -242,6 +242,7 @@ class MgbPackedParamBase<string className, string accessor>:
 class MgbHashableOpMixin {
     string hashFunction = ?;
     string cmpFunction = ?;
+    bit usingModeName = 0;
 }
 class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
     let inputs = (ins Variadic<AnyType>:$input);
     let results = (outs AnyType);
+    let usingModeName = 1;
 }
 def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
@@ -247,6 +248,7 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
     let extraArguments = (ins
         MgbDTypeAttr:$dtype
     );
+    let usingModeName = 1;
 }
 def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;