From a926878c01d50c836dcf5ebc802c8aef19d4ce82 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 15 May 2022 20:52:59 +0800 Subject: [PATCH] feat(imperative): remove symbolvar of imperative GitOrigin-RevId: 16da6d1491526b707ea6851fb68e330c02cc788a --- .../python/megengine/core/tensor/array_method.py | 12 +- imperative/python/megengine/core/tensor/utils.py | 5 +- imperative/python/megengine/functional/elemwise.py | 2 +- imperative/python/megengine/functional/math.py | 2 +- imperative/python/megengine/functional/nn.py | 4 +- imperative/python/megengine/functional/tensor.py | 36 ++-- .../python/megengine/functional/tensor_cache.py | 4 +- imperative/python/megengine/tensor.py | 2 + imperative/python/megengine/traced_module/expr.py | 2 +- imperative/python/megengine/utils/network.py | 2 +- imperative/python/megengine/utils/network_node.py | 112 +++++++----- imperative/python/src/graph_rt.cpp | 74 ++++---- imperative/python/src/graph_rt.h | 3 + imperative/python/src/tensor.cpp | 200 +++++---------------- imperative/python/src/tensor.h | 12 +- imperative/python/src/tensor_utils.cpp | 172 +++++++----------- imperative/python/src/transformation.h | 3 +- imperative/python/test/helpers/utils.py | 2 +- .../python/test/unit/functional/test_tensor.py | 12 ++ imperative/python/test/unit/utils/test_network.py | 6 +- imperative/src/impl/basic_operators.cpp | 8 + .../include/megbrain/imperative/basic_operators.h | 17 ++ .../src/include/megbrain/imperative/basic_values.h | 19 ++ .../megbrain/imperative/transformations/symbol.h | 42 ++++- imperative/src/include/megbrain/imperative/value.h | 1 + 25 files changed, 356 insertions(+), 398 deletions(-) diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 08e83bd1..a8710731 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -7,9 +7,7 @@ from typing import Union import numpy as np from .. import _config -from .._imperative_rt.common import CompNode from .._imperative_rt.core2 import ( - SymbolVar, Tensor, apply, astype_cpp, @@ -17,9 +15,11 @@ from .._imperative_rt.core2 import ( broadcast_cpp, getitem_cpp, matmul_cpp, + reshape_cpp, + setitem_cpp, + squeeze_cpp, + transpose_cpp, ) -from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar -from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp from ..ops import builtin from . import amp from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph @@ -189,9 +189,7 @@ def _todo(*_): def _expand_args(args): if len(args) == 1: - if isinstance( - args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), - ): + if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): args = args[0] return args diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 781161af..afa1f08f 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -8,7 +8,6 @@ import numpy as np from .._imperative_rt import make_const from .._imperative_rt.core2 import ( Const, - SymbolVar, Tensor, _get_convert_inputs, _set_convert_inputs, @@ -77,7 +76,7 @@ def result_type(*args): def isscalar(x): - if isinstance(x, (Tensor, SymbolVar)): + if isinstance(x, Tensor): return x._isscalar() return np.isscalar(x) @@ -283,7 +282,7 @@ def interpret_subgraph(func, dtype, device): return results def apply_const(value, dtype=dtype, device=device): - return Const(value, dtype, device, None) + return Const(value, dtype, device) outputs, outputs_has_grad = func(args, apply_expr, apply_const) outputs = [ diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index c47ab9c8..e92cb2d9 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -2,7 +2,7 @@ # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order import numpy as np -from ..core._imperative_rt.core2 import SymbolVar, apply +from ..core._imperative_rt.core2 import apply from ..core.ops import builtin from ..core.ops.builtin import Elemwise from ..core.tensor.array_method import _elwise diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index becd3fdf..c414b9b0 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -538,7 +538,7 @@ def topk( op = builtin.TopK(mode=mode) if not isinstance(k, Tensor): - k = Const(k, "int32", inp.device, None) + k = Const(k, "int32", inp.device) if len(inp.shape) == 1: if kth_only: diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 0e97954e..52900b21 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1222,7 +1222,7 @@ def batch_norm( raise ValueError("Invalid param_dim {}".format(param_dim)) if x is None: - x = Const(value, inp.dtype, inp.device, None) + x = Const(value, inp.dtype, inp.device) shape = astensor1d(pshape, inp, dtype="int32", device=inp.device) (result,) = apply(builtin.Broadcast(), x, shape) return result @@ -1446,7 +1446,7 @@ def sync_batch_norm( def _make_full_if_none(x, value): if x is None: - x = Const(value, inp.dtype, _device, None) + x = Const(value, inp.dtype, _device) (result,) = apply(builtin.Broadcast(), x, reduce_shape) return result elif x.ndim == 1: diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index 3a37e5b4..ada8a3b0 100755 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -7,7 +7,6 @@ import numpy as np from ..core._imperative_rt import CompNode from ..core._imperative_rt.core2 import ( Const, - SymbolVar, apply, broadcast_cpp, dtype_promotion, @@ -151,7 +150,7 @@ def full( shape = (shape,) if device is None: device = get_default_device() - x = Const(value, dtype, device, None) + x = Const(value, dtype, device) if type(shape) in (list, tuple) and len(shape) == 0: return x return broadcast_to(x, shape) @@ -216,7 +215,7 @@ def zeros( return full(shape, 0.0, dtype=dtype, device=device) -def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: +def zeros_like(inp: Tensor) -> Tensor: r"""Returns a tensor filled with zeros with the same shape and data type as input tensor. Args: @@ -235,7 +234,7 @@ def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: return full_like(inp, 0.0) -def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: +def ones_like(inp: Tensor) -> Tensor: r"""Returns a tensor filled with ones with the same shape and data type as input tensor. Args: @@ -253,9 +252,7 @@ def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: return full_like(inp, 1.0) -def full_like( - inp: Union[Tensor, SymbolVar], value: Union[int, float] -) -> Union[Tensor, SymbolVar]: +def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: r"""Returns a tensor filled with given value with the same shape as input tensor. Args: @@ -272,7 +269,7 @@ def full_like( Tensor([[2 2 2] [2 2 2]], dtype=int32, device=xpux:0) """ - x = Const(value, inp.dtype, inp.device, inp) + x = Const(value, inp.dtype, inp.device) if inp.ndim == 0: return x return broadcast_to(x, inp.shape) @@ -668,9 +665,9 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: >>> print(v.numpy(), index.numpy()) [1. 4.] [0 3] """ - if not isinstance(x, (Tensor, SymbolVar)): + if not isinstance(x, Tensor): raise TypeError("input must be a tensor") - if not isinstance(mask, (Tensor, SymbolVar)): + if not isinstance(mask, Tensor): raise TypeError("mask must be a tensor") if mask.dtype != np.bool_: raise ValueError("mask must be bool") @@ -843,15 +840,11 @@ def linspace( if not (cur_device is None or device == cur_device): raise ("ambiguous device for linspace opr") - is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num]) - if any(is_symbolvar) and not all(is_symbolvar): - raise TypeError("start, stop and num should all be VarNode or none of them") - - if not isinstance(start, (Tensor, SymbolVar)): + if not isinstance(start, Tensor): start = Tensor(start, device=device) - if not isinstance(stop, (Tensor, SymbolVar)): + if not isinstance(stop, Tensor): stop = Tensor(stop, device=device) - if not isinstance(num, (Tensor, SymbolVar)): + if not isinstance(num, Tensor): num = Tensor(num, device=device) op = builtin.Linspace(comp_node=device) @@ -901,9 +894,12 @@ def arange( if stop is None: start, stop = 0, start - start = Tensor(start, dtype="float32") - stop = Tensor(stop, dtype="float32") - step = Tensor(step, dtype="float32") + if not isinstance(start, Tensor): + start = Tensor(start, dtype="float32") + if not isinstance(stop, Tensor): + stop = Tensor(stop, dtype="float32") + if not isinstance(step, Tensor): + step = Tensor(step, dtype="float32") num = ceil((stop - start) / step) stop = start + step * (num - 1) diff --git a/imperative/python/megengine/functional/tensor_cache.py b/imperative/python/megengine/functional/tensor_cache.py index 50846839..415fae56 100644 --- a/imperative/python/megengine/functional/tensor_cache.py +++ b/imperative/python/megengine/functional/tensor_cache.py @@ -7,11 +7,11 @@ small_tensor_cache = {} def _get_scalar_tensor_with_value(value, dtype=None, device=None): global small_tensor_cache if is_tracing(): - ret = Const(value, dtype, device, None) + ret = Const(value, dtype, device) else: cache_key = (value, dtype, device) if cache_key not in small_tensor_cache: - ret = Const(value, dtype, device, None) + ret = Const(value, dtype, device) small_tensor_cache[cache_key] = ret else: ret = small_tensor_cache[cache_key] diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index dc610f98..3f3fa41a 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -154,6 +154,8 @@ class Tensor(_Tensor, ArrayMethodMixin): @name.setter def name(self, name): self._custom_name = name + if name == None: + name = "" self._name = self._prefix + "." + name if self._prefix else name self._set_name(self._name) diff --git a/imperative/python/megengine/traced_module/expr.py b/imperative/python/megengine/traced_module/expr.py index 408ee8a6..e35abdf5 100644 --- a/imperative/python/megengine/traced_module/expr.py +++ b/imperative/python/megengine/traced_module/expr.py @@ -756,7 +756,7 @@ class Constant(Expr): def interpret(self, *inputs): if isinstance(self.value, RawTensor): - return (Const(self.value.numpy(), None, None, None),) + return (Const(self.value.numpy(), None, None),) return (self.value,) def __repr__(self): diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index aed19b5b..398d925b 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -395,7 +395,7 @@ class Network: for ind, var in enumerate(opr.outputs): var.owner = repl_dict[opr] var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) - var.var = repl_dict[opr].outputs[ind].var + var._reset_var(repl_dict[opr].outputs[ind].var) repl_dict[opr].outputs = opr.outputs self._compile() diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index d5778abe..d357da9c 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -6,11 +6,11 @@ from typing import Sequence import numpy as np from ..core import _imperative_rt as rt -from ..core._imperative_rt.core2 import SymbolVar, apply +from ..core._imperative_rt.core2 import apply, set_py_varnode_type from ..core._trace_option import use_symbolic_shape from ..core._wrap import Device from ..core.ops import builtin -from ..core.tensor.array_method import ArrayMethodMixin +from ..tensor import Tensor from .comp_graph_tools import replace_vars from .module_stats import ( preprocess_receptive_field, @@ -23,26 +23,72 @@ class NetworkNode: pass -class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): - pass +class VarNode(NetworkNode, Tensor): + _users = None + _owner = None + _name = None + _id = None + def __new__(cls, var, *, owner_opr=None, name=None): + obj = Tensor.__new__(cls, var) + return obj -class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): - def __init__(self, var=None, *, owner_opr=None, name=None): - SymbolVar.__init__(self, var) - self.users = [] # List[OpNode] - self.owner = owner_opr + def __init__(self, var, *, owner_opr=None, name=None): + self._owner = owner_opr self.name = name - self.id = id(self) @classmethod def load(cls, sym_var, owner_opr): - obj = cls() + obj = cls(sym_var) obj.var = sym_var # mgb varnode obj.name = sym_var.name obj.owner = owner_opr return obj + @property + def users(self): + if self._users is None: + self._users = [] + return self._users + + @property + def owner(self): + return self._owner + + @owner.setter + def owner(self, owner): + self._owner = owner + + @property + def id(self): + if self._id is None: + self._id = id(self) + return self._id + + @property + def var(self): + return super().var() + + @var.setter + def var(self, var): + self._reset(var) + + def _reset(self, other): + if not isinstance(other, Tensor): + other = VarNode(other) + super()._reset(other) + self.owner = None + + def _reset_var(self, var): + origin_owner = self.owner + self.var = var + self.var.name = self.name + self.owner = origin_owner + + @property + def graph(self): + return super().graph() + def _get_var_shape(self, axis=None): opdef = ( builtin.GetVarShape() if axis is None else builtin.GetVarShape(axis=axis) @@ -77,14 +123,6 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): return rst return self._get_var_shape() if self.var else None - @property - def dtype(self): - return self.var.dtype if self.var else None - - @property - def ndim(self): - return super().ndim - def __bool__(self): return False @@ -92,27 +130,11 @@ class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): __int__ = None __float__ = None __complex__ = None + __repr__ = lambda self: "VarNode:" + self.name def __hash__(self): return id(self) - def numpy(self): - return super().numpy() - - def _reset(self, other): - if not isinstance(other, VarNode): - assert self.graph, "VarNode _reset must have graph" - node = ImmutableTensor(other, graph=self.graph) - node.compile(self.graph) - other = node.outputs[0] - if self.owner is not None: - idx = self.owner.outputs.index(self) - self.owner.outputs[idx] = VarNode( - self.var, owner_opr=self.owner, name=self.var.name - ) - self.var = other.var - self.owner = None - def set_owner_opr(self, owner_opr): self.owner = owner_opr @@ -158,8 +180,7 @@ class OpNode(NetworkNode): assert len(outputs) == len(self.outputs) self._opr = outputs[0].owner for i in range(len(self.outputs)): - self.outputs[i].var = outputs[i] - self.outputs[i].var.name = self.outputs[i].name + self.outputs[i]._reset_var(outputs[i]) assert self.outputs[i].owner is self def add_inp_var(self, x): @@ -214,8 +235,9 @@ class Host2DeviceCopy(OpNode): outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) self._opr = outputs.owner if len(self.outputs) == 0: - self.outputs.append(VarNode(owner_opr=self, name=self.name)) - self.outputs[0].var = outputs + self.outputs.append(VarNode(outputs, owner_opr=self, name=self.name)) + else: + self.outputs[0]._reset_var(outputs) assert self.outputs[0].owner is self @@ -262,8 +284,9 @@ class ConstOpBase(OpNode): data = data.astype(np.int32) varnode = type(self).rt_fun(self.graph, data, cn, data.dtype, self.name) if len(self.outputs) == 0: - self.outputs.append(VarNode(owner_opr=self, name=self.name)) - self.outputs[0].var = varnode + self.outputs.append(VarNode(varnode, owner_opr=self, name=self.name)) + else: + self.outputs[0]._reset_var(varnode) self._opr = varnode.owner @classmethod @@ -313,7 +336,7 @@ class ReadOnlyOpNode(OpNode): if bool(repl_dict): out_vars = replace_vars(self._opr.outputs, repl_dict) for ind, o in enumerate(self.outputs): - o.var = out_vars[ind] + o._reset_var(out_vars[ind]) class Elemwise(OpNode): @@ -785,3 +808,6 @@ class AssertEqual(OpNode): class CvtColorForward(OpNode): type = "CvtColor" opdef = builtin.CvtColor + + +set_py_varnode_type(VarNode) diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 5b968e76..beabf78e 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -114,6 +114,8 @@ void _set_priority_to_id(const std::vector& dest_vars) { } } +py::object Py_Varnode = py::none(); + void init_graph_rt(py::module m) { static const std::unique_ptr _imperative_sm_opr_footprint_ptr{ std::make_unique()}; @@ -124,40 +126,44 @@ void init_graph_rt(py::module m) { def_rendezvous(m, "TensorAttrRendezvous"); - py::class_>(m, "VarNode") - .def_property_readonly( - "owner", [](cg::VarNode* v) { return v->owner_opr(); }) - .def_property_readonly( - "graph", [](cg::VarNode* v) { return v->owner_graph(); }) - .def_property( - "name", py::overload_cast<>(&VarNode::name, py::const_), - py::overload_cast(&VarNode::name)) - .def_property_readonly("dtype", [](cg::VarNode* v) { return v->dtype(); }) - .def_property_readonly( - "comp_node", [](cg::VarNode* v) { return v->comp_node(); }) - .def_property_readonly( - "shape", - [](cg::VarNode* v) -> const TensorShape* { - auto&& mgr = v->owner_graph()->static_infer_manager(); - return mgr.infer_shape_fallible(v); - }) - .def_property_readonly( - "value", - [](cg::VarNode* v) -> py::object { - auto&& mgr = v->owner_graph()->static_infer_manager(); - auto&& type = mgr.get_infer_type(v); - using InferType = cg::static_infer::InferType; - if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { - return py::none(); - } - auto* val = mgr.infer_value_fallible(v); - if (!val) { - return py::none(); - } - return py::cast(*val).attr("numpy")(); - }) - .def_property_readonly("id", [](cg::VarNode* v) { return (v->id()); }) - .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); }); + Py_Varnode = + py::class_>(m, "VarNode") + .def_property_readonly( + "owner", [](cg::VarNode* v) { return v->owner_opr(); }) + .def_property_readonly( + "graph", [](cg::VarNode* v) { return v->owner_graph(); }) + .def_property( + "name", py::overload_cast<>(&VarNode::name, py::const_), + py::overload_cast(&VarNode::name)) + .def_property_readonly( + "dtype", [](cg::VarNode* v) { return v->dtype(); }) + .def_property_readonly( + "comp_node", [](cg::VarNode* v) { return v->comp_node(); }) + .def_property_readonly( + "shape", + [](cg::VarNode* v) -> const TensorShape* { + auto&& mgr = v->owner_graph()->static_infer_manager(); + return mgr.infer_shape_fallible(v); + }) + .def_property_readonly( + "value", + [](cg::VarNode* v) -> py::object { + auto&& mgr = v->owner_graph()->static_infer_manager(); + auto&& type = mgr.get_infer_type(v); + using InferType = cg::static_infer::InferType; + if (!(type.value & + (InferType::CONST | InferType::RT_STATIC))) { + return py::none(); + } + auto* val = mgr.infer_value_fallible(v); + if (!val) { + return py::none(); + } + return py::cast(*val).attr("numpy")(); + }) + .def_property_readonly( + "id", [](cg::VarNode* v) { return (v->id()); }) + .def("__repr__", [](cg::VarNode* v) { return "Var:" + v->name(); }); py::class_>( m, "OperatorNode") diff --git a/imperative/python/src/graph_rt.h b/imperative/python/src/graph_rt.h index 3fa90fd8..8dd004e2 100644 --- a/imperative/python/src/graph_rt.h +++ b/imperative/python/src/graph_rt.h @@ -8,6 +8,9 @@ #include "megbrain/graph.h" #include "megbrain/plugin/opr_footprint.h" +namespace py = pybind11; +extern py::object Py_Varnode; + template class GraphNodePtr { std::shared_ptr m_graph; diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index ab7af91a..4cd4a27f 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -48,58 +48,11 @@ namespace mgb::imperative::python { namespace { WeakKeyMap module_trace_info_map; - -struct SymbolVarContext { - TransformationContext context; - std::shared_ptr symbol_tsf; - std::shared_ptr scalar_tsf; - std::shared_ptr dtype_promote_tsf; - std::shared_ptr dim_expansion_tsf; - - SymbolVarContext(cg::ComputingGraph* graph) { - symbol_tsf = std::make_shared(graph); - scalar_tsf = std::make_shared(); - dtype_promote_tsf = std::make_shared(); - dim_expansion_tsf = std::make_shared(); - Transformation::swap_context(context); - } - - void init() { - symbol_tsf->register_at(Transformation::top()); - scalar_tsf->register_at(Transformation::top()); - dtype_promote_tsf->register_at(Transformation::top()); - dim_expansion_tsf->register_at(Transformation::top()); - } - - ValueRef symvar2val(py::handle py_symbol_var) { - auto* symbol_var = py_symbol_var.cast(); - ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node); - if (symbol_var->is_scalar) { - value = scalar_tsf->value_type().make(value); - } - return value; - } - - py::object val2symvar(py::handle typeobj, ValueRef value) { - bool is_scalar = false; - if (auto* scalar_value = value.as(scalar_tsf->value_type())) { - value = scalar_value->value(); - is_scalar = true; - } - auto* node = value.cast(symbol_tsf->value_type()).node(); - auto py_symbol_var = - typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic)); - py_symbol_var.cast()->is_scalar = is_scalar; - return py_symbol_var; - } - - ~SymbolVarContext() { Transformation::swap_context(context); } -}; - } // namespace interpreter::Interpreter::Channel* interpreter_for_py = nullptr; PyTypeObject* py_tensor_type = nullptr; +PyTypeObject* py_varnode_type = nullptr; pybind11::handle py_device_type = nullptr; PyObject* cpp_use_symbolic_shape; @@ -136,22 +89,6 @@ PyObject* py_apply( auto op = py::handle(py_op).cast>(); SmallVector tensors(nargs); - SmallVector is_symbol_var(nargs, false); - ComputingGraph* cg = nullptr; - for (size_t i = 0; i < nargs; ++i) { - if ((!TensorWrapper::try_cast(args[i])) && - py::isinstance(py::handle(args[i]))) { - is_symbol_var[i] = true; - ComputingGraph* cur_cg = - py::handle(args[i]).cast()->m_node->owner_graph(); - if (cg == nullptr) { - cg = cur_cg; - } else { - mgb_assert(cg == cur_cg); - } - } - } - mgb::CompNode target_cn; mgb::DType target_dtype; @@ -174,35 +111,11 @@ PyObject* py_apply( } }; - if (cg != nullptr) { - // swap to a special context to reuse scalar handle - size_t symbol_var_idx = 8; - SymbolVarContext context(cg); - context.init(); - for (size_t i = 0; i < nargs; ++i) { - if (is_symbol_var[i]) { - symbol_var_idx = i; - tensors[i] = context.symvar2val(args[i]); - } else if ( - DTypePromoteCfg::convert_input_enabled && - op->same_type()) { - tensors[i] = convert_pyinput_to_tensor(i); - } else { - PyErr_SetString( - PyExc_TypeError, "py_apply expects tensor as inputs"); - return nullptr; - } - } - auto outputs = imperative::apply(*op, tensors); - auto ret = pybind11::tuple(outputs.size()); - auto typeobj = py::handle(args[symbol_var_idx]).get_type(); - for (size_t i = 0; i < outputs.size(); ++i) { - ret[i] = context.val2symvar(typeobj, outputs[i]); - } - return ret.release().ptr(); - } - + bool is_varnode_apply = false; for (size_t i = 0; i < nargs; ++i) { + if (PyObject_TypeCheck(args[i], py_varnode_type)) { + is_varnode_apply = true; + } if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { tensors[i] = tw->m_tensor->data(); } else if ( @@ -218,8 +131,9 @@ PyObject* py_apply( auto outputs = [&] { return imperative::apply(*op, tensors); }(); size_t nout = outputs.size(); auto ret = py::tuple(nout); + PyTypeObject* py_type = is_varnode_apply ? py_varnode_type : py_tensor_type; for (size_t i = 0; i < nout; ++i) { - ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i])); + ret[i] = TensorWrapper::make(py_type, std::move(outputs[i])); } return ret.release().ptr(); } @@ -622,9 +536,17 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { CreateTensor::Kind kind = is_const ? CreateTensor::Const : no_cache ? CreateTensor::Unique : CreateTensor::Common; - auto&& hval = pyobj2hval(data, cn, dtype); - auto val = imperative::apply( - CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0]; + ValueRef val; + if (py::isinstance(data, Py_Varnode)) { + cg::VarNode* m_node = py::handle(data).cast(); + val = imperative::apply( + CreateNode(m_node), Span(nullptr, nullptr))[0]; + } else { + auto&& hval = pyobj2hval(data, cn, dtype); + val = imperative::apply( + CreateTensor(kind, cn, hval.dtype, hval.shape), + hval.storage)[0]; + } m_tensor.emplace(val); } @@ -734,6 +656,20 @@ PyObject* TensorWrapper::isscalar() { } } +PyObject* TensorWrapper::_var() { + TypedValueRef value = + imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref(); + auto* node = value->node(); + return py::cast(node).release().ptr(); +} + +PyObject* TensorWrapper::_graph() { + TypedValueRef value = + imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref(); + auto* graph = value->graph(); + return py::cast(graph).release().ptr(); +} + struct TensorWeakRef { ValueWeakRef data; @@ -808,6 +744,10 @@ void init_tensor(py::module m) { std::make_shared()) .release()); MGB_MARK_USED_VAR(transformations + .register_at( + std::make_shared()) + .release()); + MGB_MARK_USED_VAR(transformations .register_at( std::make_shared()) .release()); @@ -863,6 +803,8 @@ void init_tensor(py::module m) { .def<&TensorWrapper::_detail>("_detail") .def<&TensorWrapper::_set_name>("_set_name") .def<&TensorWrapper::_watch>("_watch") + .def<&TensorWrapper::_var>("var") + .def<&TensorWrapper::_graph>("graph") .def_getset< &TensorWrapper::module_trace_info, &TensorWrapper::set_module_trace_info>("_NodeMixin__node") @@ -875,43 +817,6 @@ void init_tensor(py::module m) { .def(py::init()) .def("__call__", &TensorWeakRef::operator()); - py::class_>(m, "SymbolVar") - .def_property_readonly( - "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); }) - .def_property( - "var", [](PySymbolVar* v) { return v->m_node; }, - [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; }) - .def_property_readonly( - "device", [](PySymbolVar* v) { return v->m_node->comp_node(); }) - .def_property_readonly( - "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); }) - .def_property_readonly( - "shape", - [](PySymbolVar* v) -> const TensorShape* { - auto&& mgr = v->m_node->owner_graph()->static_infer_manager(); - return mgr.infer_shape_fallible(v->m_node); - }) - .def("numpy", - [](PySymbolVar* v) { - auto&& mgr = v->m_node->owner_graph()->static_infer_manager(); - auto&& type = mgr.get_infer_type(v->m_node); - using InferType = cg::static_infer::InferType; - if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { - throw py::value_error("value invalid!"); - } - auto* val = mgr.infer_value_fallible(v->m_node); - if (!val) { - throw py::value_error("value invalid!"); - } - auto np_val = py::cast(*val).attr("numpy")(); - return np_val; - }) - .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) - .def(py::init([](cg::VarNode* node) { - return std::make_shared(node); - }), - py::arg() = nullptr); - static PyMethodDef method_defs[] = { MGE_PY_INTERFACE(apply, py_apply), MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), @@ -1027,6 +932,10 @@ void init_tensor(py::module m) { py_tensor_type = reinterpret_cast(type_obj.inc_ref().ptr()); }); + m.def("set_py_varnode_type", [](py::object type_obj) { + py_varnode_type = reinterpret_cast(type_obj.inc_ref().ptr()); + }); + m.def("set_py_device_type", [](py::object type_obj) { py_device_type = type_obj.inc_ref(); }); @@ -1217,31 +1126,6 @@ void init_tensor(py::module m) { } }); - m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object { - auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) { - auto make_scalar_shape = [&](CompNode device) { - return imperative::apply( - CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}), - HostStorage::make(device))[0]; - }; - return imperative::apply(op, input, make_scalar_shape(*input.device()))[0]; - }; - if (py::isinstance(tensor)) { - auto* graph = tensor.cast()->m_node->owner_graph(); - SymbolVarContext context(graph); - context.init(); - auto output = reduce_to_scalar( - *op.cast>(), context.symvar2val(tensor)); - auto typeobj = tensor.get_type(); - return context.val2symvar(typeobj, output); - } else { - auto* tw = TensorWrapper::try_cast(tensor.ptr()); - auto output = reduce_to_scalar( - *op.cast>(), tw->m_tensor->data()); - return TensorWrapper::make(py_tensor_type, output); - } - }); - m.def("name_tensor", [](std::string name, py::object tensor) { auto* tw = TensorWrapper::try_cast(tensor.ptr()); auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index df9c8ffe..86394cba 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -10,6 +10,8 @@ #include "./pyext17.h" #include "megbrain/imperative/dispatch.h" +#include "megbrain/imperative/transformations/scalar.h" +#include "megbrain/imperative/transformations/symbol.h" #include "megbrain/imperative/utils/span.h" namespace mgb::imperative::python { @@ -27,6 +29,7 @@ namespace mgb::imperative::python { extern interpreter::Interpreter::Channel* interpreter_for_py; extern PyTypeObject* py_tensor_type; +extern PyTypeObject* py_varnode_type; extern pybind11::handle py_device_type; extern PyObject* cpp_use_symbolic_shape; extern PyObject* cpp_astensor1d; @@ -126,16 +129,11 @@ public: void set_module_trace_info(PyObject*); void _set_name(PyObject*); PyObject* _detail(); + PyObject* _var(); + PyObject* _graph(); void _watch(); }; -struct PySymbolVar { - cg::VarNode* m_node = nullptr; - bool is_scalar = false; - PySymbolVar() = default; - PySymbolVar(VarNode* m) : m_node(m) {} -}; - PyObject* py_apply( PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */); diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 688c69d6..c2506f6f 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -146,15 +146,6 @@ PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) { continue; } - if (py::isinstance(py::handle(handle))) { - auto var = py::handle(handle).cast(); - mgb::DType type = var->m_node->dtype(); - auto&& descr = npy::dtype_mgb2np_descr(type); - Py_INCREF(descr.get()); - tensors.emplace_back(descr.get()); - continue; - } - PyArray_Descr* descr = scalar2dtype(handle); if (descr) { scalars.emplace_back(descr); @@ -204,17 +195,12 @@ CompNode _get_device(PyObject* const* args, size_t nargs) { PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; TensorWrapper* tw = TensorWrapper::try_cast(handle); - bool is_symvar = py::isinstance(py::handle(handle)); - if (tw || is_symvar) { + if (tw) { if (!valid) { - cn = tw ? tw->m_tensor->comp_node() - : py::handle(handle).cast()->m_node->comp_node(); + cn = tw->m_tensor->comp_node(); valid = true; } else { - CompNode cn1 = tw ? tw->m_tensor->comp_node() - : py::handle(handle) - .cast() - ->m_node->comp_node(); + CompNode cn1 = tw->m_tensor->comp_node(); if (cn1 != cn) { throw py::value_error(ssprintf( "ambiguous device: %s (from %s) vs %s (from %s)", @@ -258,10 +244,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) { } bool is_scalar(PyObject* tensor) { - if (py::isinstance(py::handle(tensor))) { - auto var = py::handle(tensor).cast(); - return var->is_scalar; - } auto* tw = TensorWrapper::try_cast(tensor); if (tw) { return tw->m_tensor->is_scalar(); @@ -319,8 +301,7 @@ py::object device2obj(py::handle device, bool mapping = false) { } } -py::object _Const( - py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) { +py::object _Const(py::handle value, py::handle dtype, py::handle device) { py::object val = py::reinterpret_borrow(value); if (PyArray_Check(value.ptr())) { py::tuple strides = @@ -338,32 +319,6 @@ py::object _Const( val = val.attr("reshape")(orig_shp); } } - py::object ref; - if (py::isinstance(ref_hdl)) { - py::tuple tup = py::reinterpret_borrow(ref_hdl); - if (tup.size()) { - ref = tup[0]; - } else { - ref = py::none(); - } - } else { - ref = py::reinterpret_borrow(ref_hdl); - } - if (py::isinstance(ref)) { - auto ref_var = ref.cast(); - auto* graph = ref_var->m_node->owner_graph(); - CompNode cn; - if (device.ptr() == Py_None) { - cn = ref_var->m_node->comp_node(); - } else { - cn = device2obj(device).cast(); - } - OperatorNodeConfig config(cn); - auto hv = npy::np2tensor( - val.ptr(), npy::Meth::borrow(cn), dtype.cast()); - auto typeobj = ref.get_type(); - return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); - } py::object device_obj = device2obj(device, true); py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none()); return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr); @@ -373,7 +328,7 @@ py::tuple _make_shape_tuple(py::handle shape) { py::list orig; py::list ret(0); auto solve_one = [&](py::handle val) { - if (TensorWrapper::try_cast(val.ptr()) || py::isinstance(val)) { + if (TensorWrapper::try_cast(val.ptr())) { py::object np = getattr(val, "numpy")(); PyArrayObject* arr = (PyArrayObject*)np.ptr(); PyObject* maybe_list = PyArray_ToList(arr); @@ -415,25 +370,53 @@ py::tuple _make_shape_tuple(py::handle shape) { return py::reinterpret_steal(PyList_AsTuple(ret.ptr())); } -bool is_tensor_or_symbolvar(py::handle arg) { - return bool(TensorWrapper::try_cast(arg.ptr())) || py::isinstance(arg); +bool is_tensor(py::handle arg) { + return bool(TensorWrapper::try_cast(arg.ptr())); } bool is_py_sequence(py::handle arg) { - if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr()) || - py::isinstance(arg)) { + if (PyArray_Check(arg.ptr()) || TensorWrapper::try_cast(arg.ptr())) { return false; } return PySequence_Check(arg.ptr()); } -mgb::DType _get_dtype(py::handle tensor) { - if (auto tw = TensorWrapper::try_cast(tensor.ptr())) { - return tw->m_tensor->dtype(); +py::object get_res_by_refhdl( + py::handle value, py::handle dtype, py::handle device, py::handle ref_hdl) { + py::object res = _Const(value, dtype, device); + py::object ref; + if (py::isinstance(ref_hdl)) { + py::tuple tup = py::reinterpret_borrow(ref_hdl); + if (tup.size()) { + ref = tup[0]; + } else { + ref = py::none(); + } } else { - auto var = tensor.cast(); - return var->m_node->dtype(); + ref = py::reinterpret_borrow(ref_hdl); + } + if (PyObject_TypeCheck(ref.ptr(), py_varnode_type)) { + auto temp = dtype.cast(); + ComputingGraph* graph = getattr(ref, "graph").cast(); + cg::VarNode* node = getattr(ref, "var").cast(); + CompNode cn; + if (device.ptr() == Py_None) { + cn = node->comp_node(); + } else { + cn = device2obj(device).cast(); + } + OperatorNodeConfig config(cn); + auto hv = npy::np2tensor( + value.ptr(), npy::Meth::borrow(cn), dtype.cast()); + auto typeobj = ref.get_type(); + return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); } + return res; +} + +mgb::DType _get_dtype(py::handle tensor) { + auto tw = TensorWrapper::try_cast(tensor.ptr()); + return tw->m_tensor->dtype(); } py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) { @@ -457,12 +440,12 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) { py::object _convert_single_value_cpp( py::handle value, py::handle dtype, py::handle device) { - if (is_tensor_or_symbolvar(value)) { + if (is_tensor(value)) { if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) { return _astype_cpp(value, dtype); } } else { - return _Const(value, dtype, device, py::none()); + return _Const(value, dtype, device); } return py::reinterpret_borrow(value); } @@ -475,28 +458,8 @@ py::object _convert_inputs_cpp( for (size_t i = 0; i < nargs; ++i) { py::handle h = py::handle(args[i]); lis.append(h); - if (py::isinstance(h)) { - auto var = h.cast(); - auto g = var->m_node->owner_graph(); - if (!graph) { - graph = g; - typeobj = h.get_type(); - } else { - mgb_assert(graph == g); - } - } - } - if (graph) { - CompNode cn = device2obj(device).cast(); - for (size_t i = 0; i < nargs; ++i) { - OperatorNodeConfig config(cn); - auto hv = npy::np2tensor( - lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast()); - if (!py::isinstance(lis[i])) { - lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); - } - } } + auto convert = [&](py::object value) { if (value.is_none()) { return value; @@ -517,7 +480,8 @@ py::object _astensor1d_cpp( if (device.ptr() != Py_None) { device_obj = device2obj(device); } - if (py::isinstance(value)) { + + if (PyObject_TypeCheck(value.ptr(), py_varnode_type)) { try { getattr(value, "ndim"); } catch (py::error_already_set& err) { @@ -537,14 +501,15 @@ py::object _astensor1d_cpp( return ret; } } + size_t ndim = 999; if (hasattr(value, "ndim")) { ndim = getattr(value, "ndim").cast(); if (ndim != 0 && ndim != 1) { throw py::value_error("ndim != 1 or 0, get : " + std::to_string(ndim)); } - if (!is_tensor_or_symbolvar(value)) { - return _Const(value, dtype, device, ref); + if (!is_tensor(value)) { + return get_res_by_refhdl(value, dtype, device, ref); } else { return py::reinterpret_borrow(value); } @@ -555,13 +520,13 @@ py::object _astensor1d_cpp( py::list lis = py::reinterpret_steal(PySequence_List(value.ptr())); bool need_concat = false; for (size_t i = 0; i < lis.size(); ++i) { - if (is_tensor_or_symbolvar(lis[i])) { + if (is_tensor(lis[i])) { need_concat = true; break; } } if (!need_concat) { - return _Const(value, dtype, device, ref); + return get_res_by_refhdl(value, dtype, device, ref); } if (lis.size() > 1) { std::vector c_args(lis.size() + 1); @@ -600,10 +565,9 @@ py::object _astensor1d_cpp( } py::object _get_index(py::object tensor, py::object src) { - if (!TensorWrapper::try_cast(tensor.ptr()) && - !py::isinstance(tensor)) { + if (!TensorWrapper::try_cast(tensor.ptr())) { auto get_const = [&](mgb::DType dtype) -> py::object { - return _Const(tensor, py::cast(dtype), src.attr("device"), src); + return _Const(tensor, py::cast(dtype), src.attr("device")); }; if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) { tensor = get_const(dtype::Bool()); @@ -636,9 +600,8 @@ py::tuple _try_cond_take(py::handle tensor, py::handle index) { } py::object iobj; if (PyArray_Check(index.ptr())) { - iobj = - _Const(index, py::cast((mgb::DType)dtype::Bool()), - getattr(tensor, "device"), tensor); + iobj = _Const( + index, py::cast((mgb::DType)dtype::Bool()), getattr(tensor, "device")); } else { iobj = py::reinterpret_borrow(index); } @@ -920,8 +883,8 @@ py::object _expand_args(py::handle args) { return py::reinterpret_borrow(args); } py::tuple args_tup = py::reinterpret_borrow(args.ptr()); - if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || - is_tensor_or_symbolvar(args_tup[0].ptr()))) { + if (args_tup.size() == 1 && + (PySequence_Check(args_tup[0].ptr()) || is_tensor(args_tup[0].ptr()))) { return py::reinterpret_borrow(args_tup[0]); } else { return py::reinterpret_steal(PySequence_List(args_tup.ptr())); @@ -948,7 +911,8 @@ std::tuple, bool> tuple2vector(py::object shape) { bool enable_fastpath(py::handle inp) { auto&& tm_tr = TransformationManager::get_instance() .segments[TransformationManager::Segment::ModuleTrace]; - if (!TensorWrapper::try_cast(inp.ptr()) || + bool is_varnode = PyObject_TypeCheck(inp.ptr(), py_varnode_type); + if (is_varnode || TransformationManager::get_instance() .segments[TransformationManager::Segment::Trace] .size() > 0 || @@ -1181,10 +1145,8 @@ py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) { py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) { py::object org_shape = getattr(inp_hdl, "shape"); py::object val = py::reinterpret_borrow(val_hdl); - if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance(val)) { - val = - _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"), - inp_hdl); + if (!TensorWrapper::try_cast(val.ptr())) { + val = _Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device")); } py::tuple up = _unpack_indexes(inp_hdl, idx_hdl); @@ -1308,12 +1270,12 @@ py::object _split_cpp( repr(nsplits_or_sections_hdl).cast()); } py::object pos = div_points[i] - div_points[i - 1]; - if (is_tensor_or_symbolvar(pos)) { + if (is_tensor(pos)) { partitions.append(pos); } else { partitions.append( _Const(pos, py::cast((mgb::DType)dtype::Int32()), - getattr(inp_hdl, "device"), inp_hdl)); + getattr(inp_hdl, "device"))); } } op = Split::make(axis, 0); @@ -1438,7 +1400,7 @@ py::object _squeeze_cpp(py::handle inp_hdl, py::handle axis_hdl) { py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { py::object obj = _expand_args(args); py::list lis; - if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) { + if (!is_tensor(obj.ptr()) && PySequence_Check(obj.ptr())) { lis = py::reinterpret_steal(PySequence_List(obj.ptr())); } else { py::object np = getattr(obj, "numpy")(); @@ -1631,7 +1593,7 @@ PyObject* pixel_shuffle_cpp(PyObject* self, PyObject* const* args, size_t nargs) PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { try { - return _Const(args[0], args[1], args[2], args[3]).release().ptr(); + return _Const(args[0], args[1], args[2]).release().ptr(); } PYEXT17_TRANSLATE_EXC_RET(nullptr) } @@ -1696,4 +1658,4 @@ PyObject* astensor1d_cpp(PyObject* self, PyObject* const* args, size_t nargs) { PYEXT17_TRANSLATE_EXC_RET(nullptr) } -} // namespace mgb::imperative::python +} // namespace mgb::imperative::python \ No newline at end of file diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h index c59f0eb7..ae5cf59a 100644 --- a/imperative/python/src/transformation.h +++ b/imperative/python/src/transformation.h @@ -20,11 +20,12 @@ public: DimExpansion, Grad, Scalar, + Symbol, Trace, Eval, }; - std::array>, 7> segments; + std::array>, 8> segments; private: template diff --git a/imperative/python/test/helpers/utils.py b/imperative/python/test/helpers/utils.py index ef19085f..5d73ffb6 100644 --- a/imperative/python/test/helpers/utils.py +++ b/imperative/python/test/helpers/utils.py @@ -11,7 +11,7 @@ from megengine.utils.network_node import VarNode def _default_compare_fn(x, y): - if isinstance(x, tensor): + if isinstance(x, tensor) and not isinstance(x, VarNode): x = x.numpy() elif not isinstance(x, np.ndarray): x = get_var_value(x) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 7a5c7ccc..05d2abd5 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -679,6 +679,18 @@ def test_utils_astensor1d(is_varnode): assert isinstance(xx, type(reference)) np.testing.assert_equal(xx.numpy(), [1, 2, 3]) + # varnode + if is_varnode: + a = np.array([[1, 2, 3], [4, 5, 6]]).astype("float32") + b = np.array([[True, False, True], [False, True, True]]) + aa = make_tensor(a, network) + bb = make_tensor(b, network) + x, y = F.cond_take(bb, aa) + for dtype in [None, "float32"]: + xx = astensor1d(x, reference, dtype=dtype) + assert isinstance(xx, type(reference)) + np.testing.assert_equal(get_var_value(xx), get_var_value(x)) + def test_device(): x = tensor([1, 2, 3], dtype="float32") diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index ca5abe68..243b7164 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -114,8 +114,10 @@ def test_replace_opr(): vara = graph.var_filter.name("a").as_unique() varb = graph.var_filter.name("b").as_unique() - out1 = F.sub(vara, varb) + out1 = F.mul(vara, varb) out1 = F.relu(out1) + out1 += 2 + out1 *= 3 out1 = graph.add_dep_oprs(out1) orig_opr = graph.opr_filter.has_input(vara).as_unique() @@ -135,7 +137,7 @@ def test_replace_opr(): load_graph = GraphInference(modified_model1) out = load_graph.run(a, b) - np.testing.assert_equal(out["o"], [0, 0]) + np.testing.assert_equal(out["o"], [30, 60]) def test_splice_network(): diff --git a/imperative/src/impl/basic_operators.cpp b/imperative/src/impl/basic_operators.cpp index d7b57e55..bb4c265a 100644 --- a/imperative/src/impl/basic_operators.cpp +++ b/imperative/src/impl/basic_operators.cpp @@ -82,6 +82,10 @@ std::string DTRCommand::to_string() const { return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind); } +std::string CreateNode::to_string() const { + return "CreateNode"; +} + std::string GetName::to_string() const { return "GetName{}"; } @@ -94,5 +98,9 @@ std::string IsScalar::to_string() const { return "IsScalar"; } +std::string GetVarVal::to_string() const { + return "GetVarVal"; +} + } // namespace imperative } // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/basic_operators.h b/imperative/src/include/megbrain/imperative/basic_operators.h index c93913e7..963f24d1 100644 --- a/imperative/src/include/megbrain/imperative/basic_operators.h +++ b/imperative/src/include/megbrain/imperative/basic_operators.h @@ -157,5 +157,22 @@ public: std::string to_string() const override; }; +class GetVarVal final : public OperatorImpl { +public: + std::string to_string() const override; +}; + +class CreateNode final : public OperatorImpl { +private: + cg::VarNode* m_node; + +public: + CreateNode(cg::VarNode* node) : m_node(node) {} + + cg::VarNode* node() const { return m_node; } + + std::string to_string() const override; +}; + } // namespace imperative } // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/basic_values.h b/imperative/src/include/megbrain/imperative/basic_values.h index 1b634d2d..3aa0d3d1 100644 --- a/imperative/src/include/megbrain/imperative/basic_values.h +++ b/imperative/src/include/megbrain/imperative/basic_values.h @@ -173,5 +173,24 @@ public: std::string to_string() const override; }; +class NodeStorage { +private: + cg::VarNode* m_node; + +public: + NodeStorage() = default; + NodeStorage(VarNode* node) : m_node(node) {} + VarNode* node() const { return m_node; } + ComputingGraph* graph() const { return m_node->owner_graph(); } + std::string to_string() const { return m_node->name(); } +}; + +class NodeValue final : public PrimitiveValue { +public: + using PrimitiveValue::PrimitiveValue; + + std::string to_string() const override { return NodeStorage::to_string(); } +}; + } // namespace imperative } // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/transformations/symbol.h b/imperative/src/include/megbrain/imperative/transformations/symbol.h index 2cabe6dc..9c546278 100644 --- a/imperative/src/include/megbrain/imperative/transformations/symbol.h +++ b/imperative/src/include/megbrain/imperative/transformations/symbol.h @@ -39,13 +39,39 @@ private: ObjectType m_value_type{"SymbolValue"}; public: - SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} + SymbolTransformation() {} ValueRefList apply_transformation( const Operator& op, Span inputs) override { + ComputingGraph* cg = nullptr; + if (auto* node_value = op.as()) { + return {m_value_type.make(node_value->node())}; + } + for (auto&& input : inputs) { + if (auto* val = input.as(m_value_type)) { + auto* node = val->node(); + ComputingGraph* cur_cg = node->owner_graph(); + if (cg == nullptr) { + cg = cur_cg; + } else { + mgb_assert(cg == cur_cg, "input varnode gragh should be the same"); + } + } + } + if (!cg) { + return imperative::apply(op, inputs); + } + if (auto* apply_op = op.as()) { SmallVector input_nodes; for (auto&& input : inputs) { - input_nodes.push_back(input.cast(m_value_type).node()); + if (!input.is(m_value_type)) { + auto* node = opr::ImmutableTensor::make( + *cg, input.numpy()->as_nd(true), {}) + .node(); + input_nodes.push_back(node); + } else { + input_nodes.push_back(input.cast(m_value_type).node()); + } } auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); ValueRefList outputs(output_nodes.size()); @@ -53,15 +79,9 @@ public: outputs[i] = m_value_type.make(output_nodes[i]); } return outputs; - } else if (auto* create_tensor = op.as()) { - auto&& args = create_tensor->parse(inputs); - mgb_assert( - args.kind == CreateTensor::Const, - "only const value is allowed here"); - auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); - return {m_value_type.make(node)}; } else if (auto* get_attr = op.as()) { auto* node = inputs.item().cast(m_value_type).node(); + auto* m_graph = node->owner_graph(); switch (get_attr->attr()) { case GetAttr::DType: return {DTypeValue::make(node->dtype())}; @@ -105,6 +125,10 @@ public: MegBrainError, "Symbol: malformed GetAttr: %s", op.to_string().c_str()); } + } else if (auto* get_attr = op.as()) { + cg::VarNode* node = inputs.item().cast(m_value_type).node(); + NodeStorage inp_var = NodeStorage(node); + return {NodeValue::make(inp_var)}; } else { return op.fallback(inputs); } diff --git a/imperative/src/include/megbrain/imperative/value.h b/imperative/src/include/megbrain/imperative/value.h index 4c78c9f8..44970e86 100644 --- a/imperative/src/include/megbrain/imperative/value.h +++ b/imperative/src/include/megbrain/imperative/value.h @@ -33,6 +33,7 @@ class ShapeValue; class DTypeValue; class CompNodeValue; class StringValue; +class NodeValue; class Operator;