GitOrigin-RevId: ff8eb003c5
tags/v1.3.0
@@ -96,7 +96,7 @@ class Graph(_imperative_rt.ComputingGraph): | |||
data = data.numpy() | |||
return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype)) | |||
def make_const(self, data, dtype=None, device=None): | |||
def make_const(self, data, dtype=None, device=None, name=None): | |||
if isinstance(data, _imperative_rt.DeviceTensorND): | |||
assert dtype is None and device is None | |||
return self._wrap(_imperative_rt.make_shared(self, data)) | |||
@@ -107,7 +107,9 @@ class Graph(_imperative_rt.ComputingGraph): | |||
elif data.dtype == np.int64: | |||
data = data.astype(np.int32) | |||
device = as_device(device).to_c() | |||
return self._wrap(_imperative_rt.make_const(self, data, device, dtype)) | |||
return self._wrap( | |||
_imperative_rt.make_const(self, data, device, dtype, name) | |||
) | |||
def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None): | |||
opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self) | |||
@@ -305,7 +307,7 @@ def dump_graph( | |||
output_vars: Union[Dict[str, VarNode], List[VarNode]], | |||
*, | |||
keep_var_name: int = 1, | |||
keep_op_name: bool = True, | |||
keep_opr_name: bool = False, | |||
keep_param_name: bool = False, | |||
keep_opr_priority: bool = False, | |||
strip_info_file=None, | |||
@@ -326,7 +328,7 @@ def dump_graph( | |||
* 0: none of the names are kept | |||
* 1: (default)keep names of output vars | |||
* 2: keep names of all (output and internal) vars | |||
:param keep_op_name: whether to keep operator names. | |||
:param keep_opr_name: whether to keep operator names. | |||
:param keep_param_name: whether to keep param names, so param values can be | |||
easily manipulated after loading model | |||
:param keep_opr_priority: whether to keep priority setting for operators | |||
@@ -370,7 +372,7 @@ def dump_graph( | |||
dump_content = _imperative_rt.dump_graph( | |||
ov, | |||
keep_var_name, | |||
keep_op_name, | |||
keep_opr_name, | |||
keep_param_name, | |||
keep_opr_priority, | |||
stat, | |||
@@ -36,6 +36,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef | |||
from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph as G | |||
from ..core.tensor.utils import setscalar | |||
from ..utils.naming import auto_naming | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
@@ -77,6 +78,7 @@ def exclude_from_trace(): | |||
class TensorInfo: | |||
__slots__ = ( | |||
# collected attributes | |||
"name", | |||
"external", | |||
"data_read", | |||
"shape_read", | |||
@@ -96,6 +98,7 @@ class TensorInfo: | |||
) | |||
def __init__(self): | |||
self.name = None | |||
self.exported = None | |||
self.data_read = None | |||
self.shape_read = None | |||
@@ -290,12 +293,16 @@ class trace: | |||
h = getattr(x, "_mixin_handle", -1) | |||
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): | |||
h, info = self._new_handle() | |||
name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name | |||
info.name = name | |||
info.external = True | |||
info.device = x.device | |||
info.dtype = x.dtype | |||
info.shape = x.shape | |||
if self._capture_as_const: | |||
info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False) | |||
info.bound_data = RawTensor( | |||
x.numpy(), x.dtype, x.device, False, name | |||
) | |||
ihandles.append(h) | |||
@@ -669,6 +676,12 @@ class trace: | |||
arg_names=None, | |||
output_names=None, | |||
append=False, | |||
keep_var_name: int = 1, | |||
keep_opr_name: bool = False, | |||
keep_param_name: bool = False, | |||
keep_opr_priority: bool = False, | |||
strip_info_file=None, | |||
append_json=False, | |||
optimize_for_inference=True, | |||
**kwargs | |||
): | |||
@@ -681,6 +694,20 @@ class trace: | |||
use the default name if not specified. | |||
:param append: whether output is appended to ``file``. | |||
Only works when ``file`` is str. | |||
:param keep_var_name: level for keeping variable names: | |||
* 0: none of the names are kept | |||
* 1: (default)keep names of output vars | |||
* 2: keep names of all (output and internal) vars | |||
:param keep_opr_name: whether to keep operator names. | |||
:param keep_param_name: whether to keep param names, so param values can be | |||
easily manipulated after loading model | |||
:param keep_opr_priority: whether to keep priority setting for operators | |||
:param strip_info_file: a string for path or a file handler. if is not None, | |||
then the dump information for code strip would be written to ``strip_info_file`` | |||
:param append_json: will be check when `strip_info_file` is not None. if set | |||
true, the information for code strip will be append to strip_info_file. | |||
if set false, will rewrite strip_info_file | |||
:param optimize_for_inference: enbale optmizations, | |||
will skip all optimize options if this is False. Default: True | |||
@@ -785,7 +812,10 @@ class trace: | |||
assert info.external | |||
assert info.bound_data | |||
h2v[h] = graph.make_const( | |||
info.bound_data.numpy(), dtype=info.dtype, device=info.device, | |||
info.bound_data.numpy(), | |||
dtype=info.dtype, | |||
device=info.device, | |||
name=info.name, | |||
) | |||
continue | |||
ivars = [] | |||
@@ -795,13 +825,26 @@ class trace: | |||
assert info.external | |||
assert info.bound_data | |||
h2v[h] = graph.make_const( | |||
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device | |||
info.bound_data.numpy(), | |||
dtype=info.dtype, | |||
device=dumped_device, | |||
name=info.name, | |||
) | |||
ivars.append(h2v[h]) | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
auto_naming.record_opnode(ovars[0].op) | |||
assert len(ovars) == len(ohandles) | |||
h2v.update(zip(ohandles, ovars)) | |||
for i in ohandles: | |||
name = auto_naming.get_var_name(i) | |||
if name is not None: | |||
h2v[i].name = name | |||
auto_naming.remove_duplicate_names() | |||
dest_vars = [] | |||
for i, h in enumerate(self._output_bindings): | |||
v = h2v[h] | |||
@@ -815,7 +858,15 @@ class trace: | |||
if isinstance(file, str): | |||
permission = "wb" if append == False else "ab" | |||
file = open(file, permission) | |||
dump_content, dump_info = G.dump_graph(dest_vars) | |||
dump_content, dump_info = G.dump_graph( | |||
dest_vars, | |||
keep_var_name=keep_var_name, | |||
keep_opr_name=keep_opr_name, | |||
keep_param_name=keep_param_name, | |||
keep_opr_priority=keep_opr_priority, | |||
strip_info_file=strip_info_file, | |||
append_json=append_json, | |||
) | |||
file.write(dump_content) | |||
return dump_info | |||
@@ -1095,20 +1146,22 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): | |||
return active_trace._apply_op(op, args) | |||
def apply_const_compiled_mode(value, dtype, device, is_const, no_cache): | |||
def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name): | |||
if skip_tracing: | |||
args = [ | |||
RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
for x in args | |||
] | |||
unset_tracing() | |||
ret = RawTensor(value, dtype, device, False) | |||
ret = RawTensor(value, dtype, device, False, name) | |||
set_tracing() | |||
return ret | |||
return active_trace._apply_const(value, dtype, device) | |||
def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
if hasattr(op, "scope"): | |||
op.scope = auto_naming.get_scope() | |||
if active_trace._symbolic: | |||
outputs = apply_symbolic_mode(op, *args) | |||
else: | |||
@@ -1120,12 +1173,12 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
return list(outputs) | |||
def apply_const_with_tracing(value, dtype, device, is_const, no_cache): | |||
def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name): | |||
if active_trace._symbolic: | |||
outputs = apply_const_symbolic_mode(value, dtype, device) | |||
else: | |||
unset_tracing() | |||
outputs = (RawTensor(value, dtype, device, False),) | |||
outputs = (RawTensor(value, dtype, device, False, name),) | |||
set_tracing() | |||
active_trace._record_const(outputs) | |||
return list(outputs) |
@@ -12,12 +12,12 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
from ..core.tensor.utils import make_shape_tuple | |||
from ..logger import get_logger | |||
from ..tensor import Parameter, Tensor | |||
from ..utils.deprecation import deprecated | |||
from ..utils.hook import HookHandler | |||
from ..utils.naming import auto_naming | |||
logger = get_logger(__name__) | |||
@@ -69,7 +69,9 @@ class Module(metaclass=ABCMeta): | |||
Base Module class. | |||
""" | |||
def __init__(self): | |||
def __init__(self, name=""): | |||
self.name = name | |||
# runtime attributes | |||
self.training = True | |||
self.quantize_disabled = False | |||
@@ -79,6 +81,8 @@ class Module(metaclass=ABCMeta): | |||
self._forward_hooks = OrderedDict() | |||
self._modules = [] | |||
# used for profiler and automatic naming | |||
self._name = "{anonymous}" | |||
@abstractmethod | |||
@@ -105,7 +109,7 @@ class Module(metaclass=ABCMeta): | |||
return HookHandler(self._forward_hooks, hook) | |||
def __call__(self, *inputs, **kwargs): | |||
push_scope(self._name) | |||
auto_naming.push_scope(self.name if self.name else self._name) | |||
for hook in self._forward_pre_hooks.values(): | |||
modified_inputs = hook(self, inputs) | |||
if modified_inputs is not None: | |||
@@ -119,7 +123,7 @@ class Module(metaclass=ABCMeta): | |||
modified_outputs = hook(self, inputs, outputs) | |||
if modified_outputs is not None: | |||
outputs = modified_outputs | |||
pop_scope(self._name) | |||
auto_naming.pop_scope() | |||
return outputs | |||
def _flatten( | |||
@@ -579,7 +583,7 @@ class Module(metaclass=ABCMeta): | |||
value = super().__getattribute__(name) | |||
if name == "_name": | |||
return value | |||
if _is_module(value): | |||
if isinstance(value, (Tensor, Module)): | |||
value._name = name | |||
return value | |||
@@ -20,6 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin | |||
from .device import _valid_device, get_default_device | |||
from .logger import get_logger | |||
from .utils.deprecation import deprecated | |||
from .utils.naming import auto_naming | |||
class Tensor(_Tensor, ArrayMethodMixin): | |||
@@ -27,7 +28,9 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
dmap_callback = None | |||
_q_dict = None | |||
def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False): | |||
def __new__( | |||
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name="" | |||
): | |||
if device is None: | |||
cn = get_default_device() | |||
elif isinstance(device, str): | |||
@@ -51,8 +54,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
if isinstance(data, np.ndarray): | |||
if 0 in data.strides: | |||
data = data.squeeze().reshape(data.shape) | |||
obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache) | |||
obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name) | |||
return obj | |||
@property | |||
@@ -91,6 +93,15 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
piece += ", device={}".format(self.device) + ")" | |||
return piece | |||
@property | |||
def name(self): | |||
return self.c_name | |||
@name.setter | |||
def name(self, name): | |||
self.c_name = name | |||
auto_naming.record_var_name(self._mixin_handle, name) | |||
@deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0") | |||
def set_value(self, value): | |||
if not isinstance(value, _Tensor): | |||
@@ -0,0 +1,63 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core._imperative_rt.core2 import pop_scope, push_scope | |||
class AutoNaming: | |||
r""" | |||
Name all executed operators automaticlly during tracing and record all tensors | |||
renamed by the user. | |||
""" | |||
def __init__(self): | |||
self.scopes = [] | |||
self.c_ops = [] | |||
self.name2ops = {} | |||
self.handle2names = {} | |||
def clear(self): | |||
for var in vars(self).values(): | |||
var.clear() | |||
def push_scope(self, scope): | |||
push_scope(scope) | |||
self.scopes.append(scope) | |||
def pop_scope(self): | |||
scope = self.scopes.pop() | |||
pop_scope(scope) | |||
def get_scope(self): | |||
return ".".join(self.scopes) | |||
def record_var_name(self, handle, name): | |||
self.handle2names[handle] = name | |||
def get_var_name(self, handle): | |||
return self.handle2names.pop(handle, None) | |||
def record_opnode(self, op): | |||
ops = self.name2ops.get(op.name, []) | |||
ops.append(op) | |||
self.name2ops[op.name] = ops | |||
def remove_duplicate_names(self): | |||
for key, ops in self.name2ops.items(): | |||
if len(ops) == 1: | |||
continue | |||
for i, op in enumerate(ops): | |||
op.name = key + "[%s]" % str(i) | |||
if len(op.outputs) == 1: | |||
continue | |||
for var in op.outputs: | |||
var.name = var.name.replace(key, op.name) | |||
self.name2ops.clear() | |||
auto_naming = AutoNaming() |
@@ -294,7 +294,7 @@ void init_graph_rt(py::module m) { | |||
m.def("dump_graph", []( | |||
const std::vector<VarNode*>& dest_vars, | |||
int keep_var_name, | |||
bool keep_op_name, | |||
bool keep_opr_name, | |||
bool keep_param_name, | |||
bool keep_opr_priority, | |||
py::list& stat, | |||
@@ -307,7 +307,7 @@ void init_graph_rt(py::module m) { | |||
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); | |||
ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name, | |||
keep_opr_priority, keep_op_name}; | |||
keep_opr_priority, keep_opr_name}; | |||
auto rst = dumper->dump(symvars, config); | |||
for (auto i : rst.inputs) { | |||
@@ -457,13 +457,17 @@ void init_graph_rt(py::module m) { | |||
return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node(); | |||
}); | |||
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | |||
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype, std::optional<std::string> name) { | |||
if (!cn.valid()) { | |||
cn = CompNode::load(get_default_device()); | |||
} | |||
OperatorNodeConfig config(cn); | |||
if (name) { | |||
config.name(*name); | |||
} | |||
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | |||
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | |||
}); | |||
return opr::ImmutableTensor::make(*graph, hv, config).node(); | |||
}, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none()); | |||
m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) { | |||
if (!cn.valid()) { | |||
@@ -99,6 +99,14 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) { | |||
#define py_get_generic(name, attr) \ | |||
py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
template<typename T> | |||
PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) { | |||
// T: PyOpXXX inst(): return XXX in opdef.h.inl | |||
auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
return pyobj_convert_generic<std::string>::to(op.scope()); | |||
} | |||
#define py_get_scope(class) py_get_scope_impl<PyOp(class)> | |||
template<typename T, typename U, U T::Ty::*attr> | |||
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
if (value == NULL) { | |||
@@ -121,6 +129,27 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
#define py_set_generic(name, attr) \ | |||
py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr> | |||
template<typename T> | |||
int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) { | |||
if (value == NULL) { | |||
PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute"); | |||
return -1; | |||
} | |||
auto& op = reinterpret_cast<T*>(obj)->inst(); | |||
try { | |||
op.set_scope(pyobj_convert_generic<std::string>::from(value)); | |||
return 0; | |||
} catch(py::error_already_set& e) { | |||
e.restore(); | |||
} catch(py::builtin_exception& e) { | |||
e.set_error(); | |||
} catch(...) { | |||
PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
} | |||
return -1; | |||
} | |||
#define py_set_scope(class) py_set_scope_impl<PyOp(class)> | |||
struct PyOpDef { | |||
PyObject_HEAD | |||
std::shared_ptr<OpDef> op; | |||
@@ -24,6 +24,7 @@ | |||
#include <pybind11/numpy.h> | |||
#include <pybind11/operators.h> | |||
#include <range/v3/all.hpp> | |||
#include <string> | |||
#include <unordered_map> | |||
@@ -222,14 +223,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
} | |||
} else { | |||
py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType | |||
if (nargs != 4 && nargs != 5) { | |||
throw py::type_error("expect 4 or 5 arguments"); | |||
if (nargs != 5 && nargs != 6) { | |||
throw py::type_error("expect 5 or 6 arguments"); | |||
} | |||
auto data = tup[0].cast<py::array>(); | |||
DType dtype = tup[1].cast<DType>(); | |||
CompNode cn = tup[2].cast<CompNode>(); | |||
bool is_const = tup[3].cast<bool>(); | |||
bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false; | |||
bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false; | |||
std::string name = tup[nargs - 1].cast<std::string>(); | |||
// const op | |||
if (is_const && is_tracing) { | |||
@@ -259,6 +261,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
} | |||
m_tensor = std::make_shared<Tensor>(handle); | |||
m_tensor->user_custom_name = name; | |||
if (data.ndim() == 0) { | |||
m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||
@@ -313,6 +316,19 @@ REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info) | |||
#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC | |||
#define SET_GET_NAME(member) \ | |||
PyObject* TensorWrapper::member() { \ | |||
return py::cast(m_tensor->member).release().ptr(); \ | |||
} \ | |||
void TensorWrapper::set_##member(PyObject* dest) { \ | |||
auto py_dest = py::reinterpret_borrow<py::object>(dest); \ | |||
m_tensor->member = py_dest.cast<std::string>(); \ | |||
} | |||
SET_GET_NAME(user_custom_name) | |||
SET_GET_NAME(automatic_name) | |||
#undef SET_GET_NAME | |||
PyObject* TensorWrapper::handle() { | |||
return py::cast(m_tensor->m_handle).release().ptr(); | |||
} | |||
@@ -453,7 +469,11 @@ void TensorWrapper::reset(PyObject* tensor) { | |||
if (!t) { | |||
throw py::type_error("expect Tensor"); | |||
} | |||
std::string user_custom_name = m_tensor->user_custom_name; | |||
std::string automatic_name = m_tensor->automatic_name; | |||
m_tensor = t->m_tensor; | |||
m_tensor->user_custom_name = user_custom_name; | |||
m_tensor->automatic_name = automatic_name; | |||
} | |||
void TensorWrapper::reset_varnode() { | |||
@@ -785,6 +805,8 @@ void init_tensor(py::module m) { | |||
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | |||
.def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info") | |||
.def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info") | |||
.def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name") | |||
.def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name") | |||
.finalize(); | |||
if (!tensor_type) throw py::error_already_set(); | |||
py::setattr(m, "Tensor", tensor_type); | |||
@@ -15,6 +15,7 @@ | |||
#include "megbrain/imperative/interpreter.h" | |||
#include "pybind11/pybind11.h" | |||
#include <string> | |||
#include "./pyext17.h" | |||
@@ -70,6 +71,8 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
GradInfo m_grad_info; | |||
TraceInfo m_trace_info; | |||
SharedHandle m_handle; | |||
std::string user_custom_name; | |||
std::string automatic_name; | |||
cg::VarNode* m_var; | |||
using Handle = interpreter::Interpreter::Handle; | |||
@@ -170,6 +173,10 @@ struct TensorWrapper { | |||
void set_compiled_info(PyObject *); | |||
PyObject* trace_mixin_info(); | |||
void set_trace_mixin_info(PyObject *); | |||
PyObject* user_custom_name(); | |||
void set_user_custom_name(PyObject *); | |||
PyObject* automatic_name(); | |||
void set_automatic_name(PyObject *); | |||
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | |||
}; | |||
@@ -0,0 +1,169 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import io | |||
import numpy as np | |||
import pytest | |||
import megengine.functional as F | |||
import megengine.module as M | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine import Parameter, Tensor | |||
from megengine.core.tensor import megbrain_graph as G | |||
from megengine.jit.tracing import trace | |||
from megengine.utils.naming import auto_naming | |||
def _dump_and_load(func, symbolic, keep_opr_name=True): | |||
auto_naming.clear() | |||
func = trace(func, symbolic=symbolic, capture_as_const=True) | |||
x = Tensor(np.ones(shape=(2, 3))) | |||
func(x).numpy() | |||
file = io.BytesIO() | |||
func.dump( | |||
file, | |||
optimize_for_inference=False, | |||
arg_names="x", | |||
keep_opr_name=keep_opr_name, | |||
keep_var_name=2, | |||
) | |||
file.seek(0) | |||
*_, outputs = G.load_graph(file) | |||
op = cgtools.get_oprs_seq(outputs)[-1] | |||
return op | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_auto_naming(symbolic): | |||
class Simple(M.Module): | |||
def __init__(self, name): | |||
super().__init__() | |||
self.name = name | |||
def forward(self, x): | |||
return x + x | |||
m = Simple("simple") | |||
op = _dump_and_load(m, symbolic) | |||
assert op.name == "simple.ADD" | |||
assert op.outputs[0].name == "simple.ADD" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_user_named_tensor(symbolic): | |||
class Simple(M.Module): | |||
def __init__(self, name): | |||
super().__init__() | |||
self.name = name | |||
self.k = Parameter(1.0, name="k") | |||
def forward(self, x): | |||
x = x + x | |||
x.name = "o_x" | |||
return x | |||
m = Simple("simple") | |||
op = _dump_and_load(m, symbolic) | |||
assert op.name == "simple.ADD" | |||
assert op.outputs[0].name == "o_x" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_user_named_param(symbolic): | |||
class Simple(M.Module): | |||
def __init__(self, name): | |||
super().__init__() | |||
self.name = name | |||
self.k = Parameter(2.0, name="k") | |||
def forward(self, x): | |||
return self.k * x | |||
m = Simple("simple") | |||
op = _dump_and_load(m, symbolic) | |||
assert op.inputs[0].name == "x" | |||
assert op.inputs[1].name == "simple.k" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_without_module(symbolic): | |||
def f(x): | |||
return 2 * x | |||
op = _dump_and_load(f, symbolic) | |||
assert op.name == "MUL" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_with_submodule(symbolic): | |||
class Simple(M.Module): | |||
def __init__(self, name): | |||
super().__init__() | |||
self.name = name | |||
self.linear = M.Linear(3, 3) | |||
def forward(self, x): | |||
x = self.linear(x) | |||
return x | |||
m = Simple("simple") | |||
op = _dump_and_load(m, symbolic) | |||
assert op.name == "simple.linear.ADD" | |||
assert op.inputs[0].owner.name == "simple.linear.MatrixMul" | |||
assert op.outputs[0].name == "simple.linear.ADD" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_named_submodule(symbolic): | |||
class Simple(M.Module): | |||
def __init__(self, name): | |||
super().__init__() | |||
self.name = name | |||
self.linear = M.Linear(3, 3, name="x") | |||
def forward(self, x): | |||
x = self.linear(x) | |||
return x | |||
m = Simple("simple") | |||
op = _dump_and_load(m, symbolic) | |||
assert op.name == "simple.x.ADD" | |||
assert op.inputs[0].owner.name == "simple.x.MatrixMul" | |||
assert op.outputs[0].name == "simple.x.ADD" | |||
@pytest.mark.parametrize("symbolic", [False, True]) | |||
def test_with_same_operators(symbolic): | |||
class Simple(M.Module): | |||
def __init__(self, name): | |||
super().__init__() | |||
self.name = name | |||
def forward(self, x): | |||
x = F.relu(x) | |||
x = F.relu(x) | |||
return x | |||
m = Simple("simple") | |||
op = _dump_and_load(m, symbolic) | |||
assert op.name == "simple.RELU[1]" | |||
assert op.inputs[0].owner.name == "simple.RELU[0]" | |||
def test_not_keep_opr_name(): | |||
def f(x): | |||
return 2 * x | |||
op = _dump_and_load(f, True, False) | |||
assert op.name == "MUL(x,2[2])[4]" |
@@ -148,7 +148,7 @@ def test_dump(): | |||
dump_info = f.dump(file) | |||
assert dump_info.nr_opr == 3 | |||
np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"]) | |||
np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"]) | |||
np.testing.assert_equal(dump_info.outputs, ["ADD"]) | |||
file.seek(0) | |||
infer_cg = cgtools.GraphInference(file) | |||
result = list((infer_cg.run(a, b)).values())[0] | |||
@@ -75,10 +75,6 @@ std::vector<std::pair<const char*, std::string>> OpDef::props( | |||
return def.trait()->props(def); | |||
} | |||
const char* OpDef::name() const { | |||
return trait()->name; | |||
} | |||
std::string OpDef::to_string() const { | |||
std::string builder = "{"; | |||
for (auto&& [name, value]: props(*this)) { | |||
@@ -107,6 +103,20 @@ const OpTrait* OpDef::trait() const { | |||
return m_trait; | |||
} | |||
const std::string OpDef::scope() const { | |||
return m_scope; | |||
} | |||
void OpDef::set_scope(const std::string& scope) { | |||
m_scope = scope; | |||
} | |||
const std::string OpDef::make_name() const { | |||
if (m_scope.empty()) | |||
return trait()->make_name(*this); | |||
return m_scope + "." + trait()->make_name(*this); | |||
} | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -75,6 +75,7 @@ using GradMaker = detail::OpMeth< | |||
using Props = detail::OpMeth<decltype(OpDef::props)>; | |||
using HashFunc = detail::OpMeth<size_t(const OpDef&)>; | |||
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||
using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>; | |||
struct OpTrait { | |||
const char* name; | |||
@@ -88,6 +89,7 @@ struct OpTrait { | |||
Props props; | |||
HashFunc hash; | |||
IsSame is_same_st; | |||
MakeNameFunc make_name; | |||
OpTrait(const char* name); | |||
static OpTrait* find_by_name(const char* name); | |||
static OpTrait* find_by_typeinfo(Typeinfo* type); | |||
@@ -104,7 +106,8 @@ struct OpTrait { | |||
cb(make_backward_graph) \ | |||
cb(props) \ | |||
cb(hash) \ | |||
cb(is_same_st) | |||
cb(is_same_st) \ | |||
cb(make_name) | |||
struct OpTraitRegistry { | |||
OpTrait* trait; | |||
@@ -30,13 +30,14 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
size_t nr_inp = inputs.size(); | |||
mgb_assert(nr_inp == 3 ||nr_inp == 5, | |||
"BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp); | |||
OperatorNodeConfig config{bn_opr.make_name()}; | |||
if (nr_inp == 3) { | |||
return opr::BatchNorm::make( | |||
inputs[0], inputs[1], inputs[2], bn_opr.param())[0] | |||
inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0] | |||
.node()->owner_opr(); | |||
} else { | |||
return opr::BatchNorm::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0] | |||
inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param(), config)[0] | |||
.node()->owner_opr(); | |||
} | |||
} | |||
@@ -27,10 +27,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
def.cast_final_safe<Broadcast>(); | |||
auto&& op = def.cast_final_safe<Broadcast>(); | |||
size_t nr_inp = inputs.size(); | |||
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr(); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); | |||
} | |||
bool valid_broadcast(const TensorShape& src_shape, | |||
@@ -96,7 +97,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Reshape&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::Reshape::make(inputs[0], inputs[1], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Reshape::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
auto disable = std::make_shared<DTypeScalar>(); | |||
disable->set(0); | |||
cg::OperatorNodeConfig config; | |||
OperatorNodeConfig config{comm.make_name()}; | |||
if (comm.comp_node.size() > 0) { | |||
config.comp_node(CompNode::load(comm.comp_node)); | |||
} | |||
@@ -23,12 +23,12 @@ namespace { | |||
cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
def.cast_final_safe<CondTake>(); | |||
auto&& op = def.cast_final_safe<CondTake>(); | |||
auto&& graph = inputs[0]->owner_graph(); | |||
opr::CondTake::Param param; | |||
param.val = 1; | |||
cg::OperatorNodeConfig config; | |||
OperatorNodeConfig config{op.make_name()}; | |||
cg::OperatorNodeBase* opr = graph->insert_opr( | |||
std::make_unique<opr::CondTake>( | |||
inputs[0], inputs[1], param, config)); | |||
@@ -31,7 +31,8 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& elemwise_opr = def.cast_final_safe<Elemwise>(); | |||
return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr(); | |||
OperatorNodeConfig config{elemwise_opr.make_name()}; | |||
return opr::Elemwise::make(inputs, elemwise_opr.mode, config).node()->owner_opr(); | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
@@ -23,7 +23,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const CvtColor&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::CvtColor::make(inputs[0], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::CvtColor::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(CvtColor, CvtColor) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -32,7 +32,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send( | |||
ssprintf("%s:%d", send.addr.data(), send.port)); | |||
auto&& graph = inputs[0]->owner_graph(); | |||
cg::OperatorNodeConfig config; | |||
OperatorNodeConfig config{send.make_name()}; | |||
cg::OperatorNodeBase* opr = | |||
graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>( | |||
send.key, inputs[0], group_client, true, config)); | |||
@@ -42,11 +42,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send( | |||
cg::OperatorNodeBase* apply_on_var_node_remote_recv( | |||
const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& recv = def.cast_final_safe<RemoteRecv>(); | |||
OperatorNodeConfig config{recv.cn}; | |||
config.name(recv.make_name()); | |||
auto group_client = std::make_shared<GroupClientProxy>( | |||
ssprintf("%s:%d", recv.addr.data(), recv.port)); | |||
auto&& graph = inputs[0]->owner_graph(); | |||
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>( | |||
recv.key, inputs[0], *graph, group_client, OperatorNodeConfig{recv.cn}, | |||
recv.key, inputs[0], *graph, group_client, config, | |||
recv.shape, recv.dtype)); | |||
} | |||
@@ -21,8 +21,10 @@ namespace { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = def.cast_final_safe<MatrixInverse>(); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::MatrixInverse::make(inputs[0]); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::MatrixInverse::make(inputs[0], {}, config); | |||
} | |||
OP_TRAIT_REG(MatrixInverse, MatrixInverse) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -29,7 +29,9 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
param.iou_thresh = nms_keep.iou_thresh; | |||
param.max_output = nms_keep.max_output; | |||
return NMSKeepOpr::make(inputs[0], param).node()->owner_opr(); | |||
OperatorNodeConfig config{nms_keep.make_name()}; | |||
return NMSKeepOpr::make(inputs[0], param, config).node()->owner_opr(); | |||
} | |||
OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr) | |||
@@ -79,11 +79,13 @@ public: | |||
cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, const VarNodeArray& inputs) { | |||
auto&& attr = def.cast_final_safe<OprAttr>(); | |||
auto config = attr.config; | |||
config.name(attr.make_name()); | |||
mgb_assert(!inputs.empty()); | |||
auto registry = serialization::OprRegistry::find_by_name(attr.type); | |||
mgb_assert(registry, "operator %s not found", attr.type.c_str()); | |||
OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; | |||
return registry->loader(ctx, inputs, attr.config); | |||
return registry->loader(ctx, inputs, config); | |||
} | |||
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) { | |||
@@ -99,10 +101,15 @@ std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { | |||
return {}; | |||
} | |||
std::string make_name(const OpDef& def) { | |||
return "OprAttr"; | |||
} | |||
OP_TRAIT_REG(OprAttr, OprAttr) | |||
.make_from_op_node(make_from_op_node) | |||
.apply_on_var_node(apply_on_var_node) | |||
.props(props) | |||
.make_name(make_name) | |||
.fallback(); | |||
} // anonymous namespace | |||
@@ -24,7 +24,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Resize&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::Resize::make(inputs[0], inputs[1], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Resize::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Resize, Resize) | |||
@@ -46,7 +46,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const Convolution&>(def); | |||
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy()); | |||
OperatorNodeConfig config{conv.make_name()}; | |||
return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
} | |||
OP_TRAIT_REG(Convolution, Convolution, opr::Convolution) | |||
@@ -60,7 +61,7 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const ConvolutionBackwardData&>(def); | |||
cg::OperatorNodeConfig config; | |||
OperatorNodeConfig config{conv.make_name()}; | |||
if (inputs.size() == 2) { | |||
return opr::ConvolutionBackwardData::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
} else { | |||
@@ -88,7 +89,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& ds = static_cast<const Dimshuffle&>(def); | |||
return opr::Dimshuffle::make(inputs[0], ds.pattern); | |||
OperatorNodeConfig config{ds.make_name()}; | |||
return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config); | |||
} | |||
OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | |||
@@ -107,7 +109,8 @@ auto apply_on_var_node( | |||
for (auto&& i : add_axis.axis) { | |||
param.push_back(Desc::make_add(i)); | |||
} | |||
return opr::AxisAddRemove::make(inputs[0], param); | |||
OperatorNodeConfig config{add_axis.make_name()}; | |||
return opr::AxisAddRemove::make(inputs[0], param, config); | |||
} | |||
OP_TRAIT_REG(AddAxis, AddAxis) | |||
@@ -125,7 +128,8 @@ auto apply_on_var_node( | |||
for (auto&& i : remove_axis.axis) { | |||
param.push_back(Desc::make_remove(i)); | |||
} | |||
return opr::AxisAddRemove::make(inputs[0], param); | |||
OperatorNodeConfig config{remove_axis.make_name()}; | |||
return opr::AxisAddRemove::make(inputs[0], param, config); | |||
} | |||
OP_TRAIT_REG(RemoveAxis, RemoveAxis) | |||
@@ -138,7 +142,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& topk = static_cast<const TopK&>(def); | |||
return opr::TopK::make(inputs[0], inputs[1], topk.param())[0] | |||
OperatorNodeConfig config{topk.make_name()}; | |||
return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0] | |||
.node()->owner_opr(); | |||
} | |||
@@ -152,10 +157,12 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& reduce = static_cast<const Reduce&>(def); | |||
OperatorNodeConfig config{reduce.make_name()}; | |||
if (inputs.size() > 1) { | |||
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]); | |||
return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | |||
} else { | |||
return opr::Reduce::make(inputs[0], reduce.param()); | |||
return opr::Reduce::make( | |||
inputs[0], reduce.param(), (cg::VarNode*)nullptr, config); | |||
} | |||
} | |||
@@ -175,7 +182,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& pool = static_cast<const AdaptivePooling&>(def); | |||
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param()); | |||
OperatorNodeConfig config{pool.make_name()}; | |||
return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config); | |||
} | |||
OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | |||
@@ -189,6 +197,7 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const ConvBias&>(def); | |||
cg::OperatorNodeConfig config{conv.dtype}; | |||
config.name(conv.make_name()); | |||
if (inputs.size() == 2) { | |||
return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 3) { | |||
@@ -210,6 +219,7 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& conv = static_cast<const BatchConvBias&>(def); | |||
cg::OperatorNodeConfig config{conv.dtype}; | |||
config.name(conv.make_name()); | |||
if (inputs.size() == 2) { | |||
return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
} else if (inputs.size() == 3) { | |||
@@ -230,7 +240,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& pool = static_cast<const Pooling&>(def); | |||
return opr::Pooling::make(inputs[0], pool.param()); | |||
OperatorNodeConfig config{pool.make_name()}; | |||
return opr::Pooling::make(inputs[0], pool.param(), config); | |||
} | |||
OP_TRAIT_REG(Pooling, Pooling) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -243,8 +254,9 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& matmul = static_cast<const MatrixMul&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{matmul.make_name()}; | |||
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(), | |||
matmul.policy()); | |||
matmul.policy(), config); | |||
} | |||
OP_TRAIT_REG(MatrixMul, MatrixMul) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -257,8 +269,9 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
OperatorNodeConfig config{matmul.make_name()}; | |||
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(), | |||
matmul.policy()); | |||
matmul.policy(), config); | |||
} | |||
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -267,10 +280,12 @@ OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
namespace { namespace dot { | |||
auto apply_on_var_node( | |||
const OpDef&, | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = def.cast_final_safe<Dot>(); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::Dot::make(inputs[0], inputs[1]); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Dot::make(inputs[0], inputs[1], config); | |||
} | |||
OP_TRAIT_REG(Dot, Dot) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -282,7 +297,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& argsort = static_cast<const Argsort&>(def); | |||
return opr::Argsort::make(inputs[0], argsort.param()); | |||
OperatorNodeConfig config{argsort.make_name()}; | |||
return opr::Argsort::make(inputs[0], argsort.param(), config); | |||
} | |||
OP_TRAIT_REG(Argsort, Argsort) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -294,7 +310,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& argmax = static_cast<const Argmax&>(def); | |||
return opr::Argmax::make(inputs[0], argmax.param()); | |||
OperatorNodeConfig config{argmax.make_name()}; | |||
return opr::Argmax::make(inputs[0], argmax.param(), config); | |||
} | |||
OP_TRAIT_REG(Argmax, Argmax) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -306,7 +323,8 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& argmin = static_cast<const Argmin&>(def); | |||
return opr::Argmin::make(inputs[0], argmin.param()); | |||
OperatorNodeConfig config{argmin.make_name()}; | |||
return opr::Argmin::make(inputs[0], argmin.param(), config); | |||
} | |||
OP_TRAIT_REG(Argmin, Argmin) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -318,11 +336,13 @@ auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& warp = static_cast<const WarpPerspective&>(def); | |||
OperatorNodeConfig config{warp.make_name()}; | |||
if (inputs.size() == 3) { | |||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param()); | |||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config); | |||
} else { | |||
mgb_assert(inputs.size() == 4); | |||
return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param()); | |||
return opr::WarpPerspective::make( | |||
inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config); | |||
} | |||
} | |||
OP_TRAIT_REG(WarpPerspective, WarpPerspective) | |||
@@ -336,7 +356,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& local = static_cast<const GroupLocal&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::GroupLocal::make(inputs[0], inputs[1], local.param()); | |||
OperatorNodeConfig config{local.make_name()}; | |||
return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); | |||
} | |||
OP_TRAIT_REG(GroupLocal, GroupLocal) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -349,7 +370,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const IndexingOneHot&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -362,7 +384,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const IndexingSetOneHot&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
} | |||
OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -375,7 +398,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const TypeCvt&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::TypeCvt::make(inputs[0], op.dtype); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::TypeCvt::make(inputs[0], op.dtype, config); | |||
} | |||
OP_TRAIT_REG(TypeCvt, TypeCvt) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -388,6 +412,7 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Concat&>(def); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
config.name(op.make_name()); | |||
return opr::Concat::make(inputs, op.axis, config); | |||
} | |||
OP_TRAIT_REG(Concat, Concat) | |||
@@ -402,6 +427,7 @@ auto apply_on_var_node( | |||
auto&& op = static_cast<const Copy&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
config.name(op.make_name()); | |||
return opr::Copy::make(inputs[0], config); | |||
} | |||
OP_TRAIT_REG(Copy, Copy) | |||
@@ -411,10 +437,12 @@ OP_TRAIT_REG(Copy, Copy) | |||
namespace { namespace identity { | |||
auto apply_on_var_node( | |||
const OpDef&, | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = def.cast_final_safe<Identity>(); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::Identity::make(inputs[0]); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Identity::make(inputs[0], config); | |||
} | |||
OP_TRAIT_REG(Identity, Identity) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -427,7 +455,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const AssertEqual&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
@@ -443,7 +472,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const UniformRNG&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::UniformRNG::make(inputs[0], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::UniformRNG::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(UniformRNG, UniformRNG) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -456,7 +486,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const GaussianRNG&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::GaussianRNG::make(inputs[0], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::GaussianRNG::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(GaussianRNG, GaussianRNG) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -469,7 +500,9 @@ VarNodeArray apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ROIAlign&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr(); | |||
OperatorNodeConfig config{op.make_name()}; | |||
auto* opr = opr::ROIAlign::make( | |||
inputs[0], inputs[1], op.param(), config).node()->owner_opr(); | |||
return {opr->output(0), opr->output(1)}; | |||
} | |||
OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
@@ -484,7 +517,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const NvOf&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::NvOf::make(inputs[0], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::NvOf::make(inputs[0], op.param(), config); | |||
} | |||
OP_TRAIT_REG(NvOf, NvOf) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -499,6 +533,7 @@ auto apply_on_var_node( | |||
auto&& op = static_cast<const Linspace&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
config.name(op.make_name()); | |||
return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Linspace, Linspace) | |||
@@ -513,6 +548,7 @@ auto apply_on_var_node( | |||
auto&& op = static_cast<const Eye&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
cg::OperatorNodeConfig config{op.comp_node}; | |||
config.name(op.make_name()); | |||
opr::Eye::Param param{op.k, op.dtype.enumv()}; | |||
return opr::Eye::make(inputs[0], param, config); | |||
} | |||
@@ -527,7 +563,10 @@ VarNodeArray apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ROIPooling&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr(); | |||
OperatorNodeConfig config{op.make_name()}; | |||
auto* opr = opr::ROIPooling::make( | |||
inputs[0], inputs[1], inputs[2], op.param(), config | |||
).node()->owner_opr(); | |||
return {opr->output(0), opr->output(1)}; | |||
} | |||
OP_TRAIT_REG(ROIPooling, ROIPooling) | |||
@@ -541,7 +580,8 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const Remap&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::Remap::make(inputs[0], inputs[1], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::Remap::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(Remap, Remap) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -578,7 +618,8 @@ auto apply_on_var_node( \ | |||
const OpDef& def, \ | |||
const VarNodeArray& inputs) { \ | |||
auto&& op = static_cast<const NAME&>(def); \ | |||
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \ | |||
OperatorNodeConfig config{op.make_name()}; \ | |||
return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \ | |||
} \ | |||
OP_TRAIT_REG(NAME, NAME) \ | |||
.apply_on_var_node(apply_on_var_node) \ | |||
@@ -609,30 +650,35 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const FakeQuant&>(def); | |||
mgb_assert(inputs.size() == 3); | |||
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
} | |||
OP_TRAIT_REG(FakeQuant, FakeQuant) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // fake_quant | |||
namespace { namespace tqt { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const TQT&>(def); | |||
mgb_assert(inputs.size() == 2); | |||
return opr::TQT::make(inputs[0], inputs[1], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::TQT::make(inputs[0], inputs[1], op.param(), config); | |||
} | |||
OP_TRAIT_REG(TQT, TQT) | |||
.apply_on_var_node(apply_on_var_node) | |||
.fallback(); | |||
}} // tqt | |||
namespace { namespace elemwise_multi_type { | |||
auto apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const ElemwiseMultiType&>(def); | |||
OperatorNodeConfig config{op.dtype}; | |||
config.name(op.make_name()); | |||
return opr::ElemwiseMultiType::make(inputs, op.param(), config); | |||
} | |||
OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) | |||
@@ -646,7 +692,9 @@ auto apply_on_var_node( | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const SVD&>(def); | |||
mgb_assert(inputs.size() == 1); | |||
return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output(); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::SVD::make(inputs[0], op.param(), config)[0] | |||
.node()->owner_opr()->usable_output(); | |||
} | |||
OP_TRAIT_REG(SVD, SVD) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -21,7 +21,8 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op_def = def.cast_final_safe<GetVarShape>(); | |||
return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); | |||
OperatorNodeConfig config{op_def.make_name()}; | |||
return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr(); | |||
} | |||
DispatchMode decide_dispatch_mode( | |||
@@ -152,7 +153,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( | |||
auto&& graph = inputs[0]->owner_graph(); | |||
auto&& shapes = get_shapes(param.shapes); | |||
cg::OperatorNodeConfig config; | |||
OperatorNodeConfig config(param.make_name()); | |||
cg::OperatorNodeBase* opr = | |||
graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>( | |||
inputs[0], param.offsets, shapes, config)); | |||
@@ -189,7 +190,7 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node( | |||
auto&& graph = inputs[0]->owner_graph(); | |||
VarNodeArray inps(inputs.begin(), inputs.end() - 1); | |||
cg::OperatorNodeConfig config; | |||
OperatorNodeConfig config{param.make_name()}; | |||
cg::OperatorNodeBase* opr = | |||
graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>( | |||
inps, inputs.back(), param.offsets, config)); | |||
@@ -20,8 +20,9 @@ namespace { namespace tensorrt_runtime { | |||
const OpDef& def, | |||
const VarNodeArray& inputs) { | |||
auto&& op = static_cast<const TensorRTRuntime&>(def); | |||
OperatorNodeConfig config{op.make_name()}; | |||
SymbolVarArray sinputs(inputs.begin(), inputs.end()); | |||
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs); | |||
return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs, config); | |||
} | |||
OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime) | |||
.apply_on_var_node(apply_on_var_node) | |||
@@ -21,7 +21,8 @@ namespace { namespace warp_affine { | |||
const VarNodeArray& inputs) { | |||
mgb_assert(inputs.size() == 3); | |||
auto&& op = static_cast<const WarpAffine&>(def); | |||
return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param()); | |||
OperatorNodeConfig config{op.make_name()}; | |||
return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
} | |||
OP_TRAIT_REG(WarpAffine, WarpAffine) | |||
@@ -36,6 +36,7 @@ class OpDef : public Hashable, | |||
public NonCopyableObj, | |||
public std::enable_shared_from_this<OpDef> { | |||
mutable const OpTrait* m_trait = nullptr; | |||
std::string m_scope; | |||
public: | |||
virtual ~OpDef() = default; | |||
@@ -86,10 +87,14 @@ public: | |||
const OpTrait* trait() const; | |||
const char* name() const; | |||
std::string to_string() const; | |||
const std::string scope() const; | |||
const std::string make_name() const; | |||
void set_scope(const std::string& scope); | |||
virtual size_t hash() const; | |||
virtual bool is_same_st(const Hashable&) const; | |||
@@ -113,9 +113,10 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) { | |||
"{0}({0}_)", i.name | |||
)); | |||
} | |||
paramList.push_back("std::string scope_ = {}"); | |||
gen_ctor(llvm::join(paramList, ", "), | |||
": " + llvm::join(initList, ", "), | |||
" {}"); | |||
" { set_scope(scope_); }"); | |||
} | |||
auto packedParams = op.getPackedParams(); | |||
@@ -236,11 +237,19 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) { | |||
os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
// generate make_name() | |||
os << formatv( | |||
"std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name") | |||
); | |||
os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx); | |||
os << "}\n"; | |||
os << "} // anonymous namespace\n"; | |||
methods.push_back("hash"); | |||
methods.push_back("is_same_st"); | |||
methods.push_back("props"); | |||
methods.push_back("make_name"); | |||
} | |||
if (!methods.empty()) { | |||
os << formatv( | |||
@@ -327,7 +336,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||
targs.push_back(i.attr.getReturnType()); | |||
} | |||
os << llvm::join(targs, ", "); | |||
os << ">()"; | |||
os << ", std::string>()"; | |||
for (auto &&i : op.getMgbAttributes()) { | |||
os << formatv(", py::arg(\"{0}\")", i.name); | |||
auto defaultValue = i.attr.getDefaultValue(); | |||
@@ -337,7 +346,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext& | |||
hasDefaultCtor = true; | |||
} | |||
} | |||
os << ")"; | |||
os << ", py::arg(\"scope\") = {})"; | |||
} | |||
if (hasDefaultCtor) { | |||
os << "\n .def(py::init<>())"; | |||
@@ -442,6 +451,10 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
className, i.name)); | |||
} | |||
getsetters.push_back(formatv( | |||
"{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},", | |||
className)); | |||
// generate tp_init | |||
std::string initBody; | |||
if (!op.getMgbAttributes().empty()) { | |||
@@ -449,6 +462,7 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
initBody += formatv("\"{0}\", ", attr.name); | |||
}); | |||
initBody += "\"scope\", "; | |||
initBody += "NULL};\n"; | |||
initBody += " PyObject "; | |||
std::vector<std::string> attrs; | |||
@@ -456,12 +470,15 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
attrs.push_back(formatv("*{0} = NULL", attr.name)); | |||
}); | |||
initBody += llvm::join(attrs, ", ") + ";\n"; | |||
initBody += " PyObject *scope = NULL;\n"; | |||
initBody += " if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|"; | |||
initBody += std::string(op.getMgbAttributes().size(), 'O'); | |||
// an extra slot created for name | |||
initBody += std::string(op.getMgbAttributes().size() + 1, 'O'); | |||
initBody += "\", const_cast<char**>(kwlist)"; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
initBody += formatv(" ,&{0}", attr.name); | |||
initBody += formatv(", &{0}", attr.name); | |||
}); | |||
initBody += ", &scope"; | |||
initBody += "))\n"; | |||
initBody += " return -1;\n"; | |||
llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) { | |||
@@ -483,6 +500,25 @@ EnumWrapper<{0}::{1}>::type2str = {{ | |||
} | |||
)", className, attr.name); | |||
}); | |||
initBody += formatv(R"( | |||
if (scope) {{ | |||
try {{ | |||
reinterpret_cast<PyOp({0})*>(self)->inst().set_scope( | |||
pyobj_convert_generic<std::string>::from(scope)); | |||
} catch(py::error_already_set& e) {{ | |||
e.restore(); | |||
return -1; | |||
} catch(py::builtin_exception& e) {{ | |||
e.set_error(); | |||
return -1; | |||
} catch(...) {{ | |||
PyErr_SetString(PyExc_RuntimeError, "Unknown Error"); | |||
return -1; | |||
} | |||
} | |||
)", className); | |||
} | |||
initBody += "\n return 0;"; | |||
@@ -241,6 +241,30 @@ private: | |||
body += " return props_;\n"; | |||
return body; | |||
} | |||
std::string getModeName() const { | |||
std::string body = formatv( | |||
" auto&& op_ = def_.cast_final_safe<{0}>();\n" | |||
" static_cast<void>(op_);\n", | |||
getCppClassName() | |||
); | |||
for (auto&& it : getMgbAttributes()) { | |||
if (it.name == "mode") { | |||
auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr); | |||
body += " switch (op_.mode){\n"; | |||
for (auto&& enumMember: enumAttr->getEnumMembers()) { | |||
body += formatv( | |||
" case {0}::{1}::{2}:\n", | |||
getCppClassName(), enumAttr->getEnumName(), enumMember | |||
); | |||
body += formatv(" return \"{0}\";\n", enumMember); | |||
} | |||
body += formatv( | |||
" default: return \"{0}::Unknown\";\n", getCppClassName()); | |||
body += " }\n"; | |||
} | |||
} | |||
return body; | |||
} | |||
public: | |||
static bool classof(const Operator* op) { | |||
return op->getDef().isSubClassOf("MgbHashableOpMixin"); | |||
@@ -264,6 +288,12 @@ public: | |||
} | |||
return getDefaultPropsFunction(); | |||
} | |||
std::string getNameFunctionTemplate() const { | |||
if (getDef().getValueAsBit("usingModeName")) { | |||
return getModeName(); | |||
} | |||
return formatv(" return \"{0}\";\n", getCppClassName()); | |||
} | |||
}; | |||
} // namespace tblgen | |||
@@ -476,6 +476,7 @@ def main(): | |||
output_mgbvars = feeds["outputs"] | |||
output_mgbvars = optimize_for_inference(args, output_mgbvars) | |||
output_mgbvars = [var._node for var in output_mgbvars] | |||
inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy") | |||
inputs = sorted((i.name, i.dtype) for i in inputs) | |||
@@ -242,6 +242,7 @@ class MgbPackedParamBase<string className, string accessor>: | |||
class MgbHashableOpMixin { | |||
string hashFunction = ?; | |||
string cmpFunction = ?; | |||
bit usingModeName = 0; | |||
} | |||
class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>: | |||
@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" | |||
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> { | |||
let inputs = (ins Variadic<AnyType>:$input); | |||
let results = (outs AnyType); | |||
let usingModeName = 1; | |||
} | |||
def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>; | |||
@@ -247,6 +248,7 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara | |||
let extraArguments = (ins | |||
MgbDTypeAttr:$dtype | |||
); | |||
let usingModeName = 1; | |||
} | |||
def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | |||