
feat(mge/utils): add array method for varnode

GitOrigin-RevId: 6e4d05b475
Branch: release-1.3
Author: Megvii Engine Team (4 years ago)
Parent commit: 24b91b98c7
14 changed files with 582 additions and 246 deletions
1. imperative/python/megengine/core/ops/special.py (+17, -2)
2. imperative/python/megengine/core/tensor/array_method.py (+4, -2)
3. imperative/python/megengine/core/tensor/indexing.py (+16, -11)
4. imperative/python/megengine/core/tensor/utils.py (+32, -17)
5. imperative/python/megengine/functional/elemwise.py (+2, -3)
6. imperative/python/megengine/functional/tensor.py (+31, -21)
7. imperative/python/megengine/utils/network.py (+39, -22)
8. imperative/python/megengine/utils/network_node.py (+52, -9)
9. imperative/python/src/tensor.cpp (+57, -19)
10. imperative/python/src/tensor.h (+6, -0)
11. imperative/python/test/helpers/utils.py (+21, -4)
12. imperative/python/test/unit/core/test_indexing_op.py (+79, -52)
13. imperative/python/test/unit/functional/test_tensor.py (+217, -69)
14. imperative/python/test/unit/utils/test_network.py (+9, -15)

imperative/python/megengine/core/ops/special.py (+17, -2)

@@ -8,6 +8,9 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
 
+from .._imperative_rt import make_const
+from .._imperative_rt.core2 import SymbolVar, Tensor
+
 
 class Const:
     def __init__(self, value=None, *, dtype=None, device=None):
@@ -19,7 +22,19 @@ class Const:
         from ...tensor import Tensor
 
         device = self.device
-        if device is None:
-            device = reference[0].device
+
+        if len(reference) != 0:
+            reference = reference[0]
+            assert isinstance(
+                reference, (SymbolVar, Tensor)
+            ), "Reference should be Tensor or VarNode"
+
+            if device is None:
+                device = reference.device
+
+            if isinstance(reference, SymbolVar):
+                cls = type(reference)
+                rst = cls(make_const(reference.graph, self.value, device, self.dtype))
+                return (rst,)
 
         return (Tensor(self.value, self.dtype, self.device, True),)
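Note: `Const` now keys off its positional reference. Given a `Tensor` it behaves as before; given a `SymbolVar` it emits the constant on that var's owner graph via `make_const`, so downstream ops stay symbolic. A minimal sketch (assuming a build with this patch; the variable names are illustrative):

    import numpy as np
    from megengine import tensor
    from megengine.core.ops.special import Const

    ref = tensor(np.ones((2, 2), dtype="float32"))
    # device (and, for SymbolVar references, the target graph) is taken
    # from the reference argument rather than requiring an explicit device
    (c,) = Const(1.0, dtype="float32")(ref)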

imperative/python/megengine/core/tensor/array_method.py (+4, -2)

@@ -13,7 +13,7 @@ from typing import Union
 import numpy as np
 
 from .._imperative_rt.common import CompNode
-from .._imperative_rt.core2 import Tensor, apply
+from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from ..ops import builtin
 from ..ops.builtin import Elemwise, GetVarShape
 from . import utils
@@ -230,7 +230,9 @@ def _todo(*_):
 
 def _expand_args(args):
     if len(args) == 1:
-        if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),):
+        if isinstance(
+            args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray),
+        ):
            args = args[0]
     return args
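For context, `_expand_args` is what lets array methods accept either a packed shape-like argument or varargs; this change merely lets a lone `SymbolVar` play the packed role. A toy illustration of the (unchanged) unwrapping behavior, using the module's private helper:

    from megengine.core.tensor.array_method import _expand_args  # private helper

    assert _expand_args(((2, 3),)) == (2, 3)  # a single sequence is unwrapped
    assert _expand_args((2, 3)) == (2, 3)     # multiple args pass through untouched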




imperative/python/megengine/core/tensor/indexing.py (+16, -11)

@@ -10,7 +10,7 @@ from typing import Iterable
 
 import numpy as np
 
-from .._imperative_rt.core2 import Tensor, apply
+from .._imperative_rt.core2 import SymbolVar, Tensor, apply
 from .._trace_option import use_symbolic_shape
 from ..ops import builtin
 from ..ops.special import Const
@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
         return True
 
     def get_index(i):
-        if not isinstance(i, (Tensor)):
+        if not isinstance(i, (Tensor, SymbolVar)):
             if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_:
-                (i,) = Const(i, dtype=np.bool_, device=inp.device)()
+                (i,) = Const(i, dtype=np.bool_, device=inp.device)(inp)
             else:
-                (i,) = Const(i, dtype=np.int32, device=inp.device)()
+                (i,) = Const(i, dtype=np.int32, device=inp.device)(inp)
             return i
-        assert isinstance(i, Tensor)
+        assert isinstance(i, (Tensor, SymbolVar))
         if i.dtype != np.bool_:
             return i
         _, ind = apply(builtin.CondTake(), i, i)
@@ -197,9 +197,9 @@ def try_condtake(tensor, index):
     ):
         return []
     if isinstance(index, np.ndarray):
-        (index,) = Const(index, dtype=np.bool_, device=tensor.device)()
-    assert isinstance(index, Tensor)
-    if not isinstance(tensor, Tensor):
+        (index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor)
+    assert isinstance(index, (Tensor, SymbolVar))
+    if not isinstance(tensor, (Tensor, SymbolVar)):
         raise TypeError("input must be a tensor")
     if tensor.device != index.device:
         raise ValueError(
@@ -214,11 +214,16 @@ def getitem(tensor, index):
         return try_result[0]
     tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index)
     for v in tensors:
+        if v.shape is None:
+            break
         if isinstance(v.shape, v.__class__):
             break
         if len(v.shape) > 0 and v.shape[0] == 0:
-            (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)()
+            (empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)(
+                tensor
+            )
             return empty_tensor
+
     if use_subtensor:
         op = builtin.Subtensor(items=items)
     else:
@@ -235,8 +240,8 @@ def setitem(tensor, index, value):
     if len(try_result) == 2:
         index = try_result[1]
         tensor = tensor.reshape(-1)
-    if not isinstance(value, Tensor):
-        (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)()
+    if not isinstance(value, (Tensor, SymbolVar)):
+        (value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor)
     tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index)
     if use_subtensor:
         op = builtin.Subtensor(items=items)


imperative/python/megengine/core/tensor/utils.py (+32, -17)

@@ -11,8 +11,9 @@ from typing import Iterable, Union
 
 import numpy as np
 
-from .._imperative_rt import VarNode
-from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device
+from .._imperative_rt import make_const
+from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device
+from .._wrap import device as as_device
 from ..ops import builtin
 from ..ops.special import Const
 from .dtype import is_dtype_equal, is_quantize
@@ -38,13 +39,9 @@ def set_convert_inputs(flag):
 
 
 def concatenate(inputs, axis=0, *, device=None):
-    dtype = dtype_promotion(inputs)
-    device = get_device(inputs)
-
-    def convert(x):
-        return convert_single_value(x, dtype=dtype, device=device)
-
-    inputs = tuple(map(convert, inputs))
+    inputs = convert_inputs(*inputs)
+    if device is None:
+        device = get_device(inputs)
     (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs)
     return result
@@ -60,7 +57,7 @@ def astype(x, dtype):
 
 
 def convert_single_value(v, *, dtype=None, device=None):
-    if isinstance(v, (Tensor, VarNode)):
+    if isinstance(v, (Tensor, SymbolVar)):
         if not is_quantize(v.dtype):
             v = astype(v, dtype)
     else:
@@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None):
     return v
 
 
-def convert_inputs(*args: Tensor):
+def convert_inputs(*args, device=None):
     if not _enable_convert_inputs:
         return args
 
     dtype = dtype_promotion(args)
-    device = get_device(args)
+    if device is None:
+        device = get_device(args)
+    device = as_device(device)
+
+    graph = None
+    sym_type = None
+    for a in args:
+        if isinstance(a, SymbolVar):
+            if graph is None:
+                graph = a.var.graph
+                sym_type = type(a)
+            else:
+                assert graph == a.var.graph
+
+    args = list(args)
+    if graph is not None:
+        for i in range(len(args)):
+            if not isinstance(args[i], SymbolVar):
+                rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
+                args[i] = sym_type(rst)
 
     def convert(value):
         if value is None:
             return value
-        return convert_single_value(value, dtype=dtype, device=device)
+        return convert_single_value(value, dtype=dtype, device=device.to_c())
 
     return tuple(map(convert, args))
@@ -98,14 +113,14 @@ def result_type(*args):
 
 
 def isscalar(x):
 
-    if isinstance(x, Tensor):
+    if isinstance(x, (Tensor, SymbolVar)):
         return x._isscalar()
 
     return np.isscalar(x)
 
 
 def setscalar(x):
-    if isinstance(x, Tensor):
+    if isinstance(x, (Tensor, SymbolVar)):
         x._setscalar()
     else:
         raise NotImplementedError("Unsupport type {}".format(type(x)))
@@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
     if not isinstance(x, collections.abc.Sequence):
         raise TypeError
 
-    if any(isinstance(i, Tensor) for i in x):
+    if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
         x = concatenate(x, device=device)
         if dtype is not None:
             x = astype(x, dtype)
@@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
 
 
 def _expand_int(s, i):
-    if isinstance(i, Tensor):
+    if isinstance(i, (Tensor, SymbolVar)):
         i_np = i.numpy()
         if i_np.ndim == 0:
             s.append(int(i_np))
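`convert_inputs` now accepts an explicit target device and, when any argument is a `SymbolVar`, lifts the remaining plain Python or numpy values onto that var's graph with `make_const` (all SymbolVar arguments are asserted to share one graph). A hedged sketch, where `x` stands in for any existing SymbolVar:

    from megengine.core.tensor.utils import convert_inputs

    # 2.0 is promoted to a const var on x's graph with the promoted dtype,
    # so both outputs are SymbolVars and can feed a single graph operator
    a, b = convert_inputs(x, 2.0)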


imperative/python/megengine/functional/elemwise.py (+2, -3)

@@ -9,8 +9,7 @@
 # pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order
 import numpy as np
 
-from ..core._imperative_rt.core2 import apply
-from ..core._imperative_rt.graph import VarNode
+from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core.ops import builtin
 from ..core.ops.builtin import Elemwise
 from ..core.tensor import utils
@@ -72,7 +71,7 @@ __all__ = [
 
 
 def _elwise(*args, mode):
-    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
+    tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args))
     if len(tensor_args) == 0:
         dtype = utils.dtype_promotion(args)
         first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())


imperative/python/megengine/functional/tensor.py (+31, -21)

@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union
 import numpy as np
 
 from ..core._imperative_rt import CompNode
-from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.core2 import SymbolVar, apply
 from ..core._wrap import device as as_device
 from ..core.ops import builtin
 from ..core.ops.builtin import Copy, Identity
@@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
     return result
 
 
-def full(shape, value, dtype="float32", device=None):
+def full(shape, value, dtype="float32", device=None) -> Tensor:
     """
     Returns a tensor with given shape and value.
     """
@@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None):
     return broadcast_to(x, shape)
 
 
-def ones(shape, dtype="float32", device=None):
+def ones(shape, dtype="float32", device=None) -> Tensor:
     """
     Returns a ones tensor with given shape.
 
@@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None):
     return full(shape, 1.0, dtype=dtype, device=device)
 
 
-def zeros(shape, dtype="float32", device=None):
+def zeros(shape, dtype="float32", device=None) -> Tensor:
     """
     Returns a zero tensor with given shape.
     """
     return full(shape, 0.0, dtype=dtype, device=device)
 
 
-def zeros_like(inp: Tensor) -> Tensor:
+def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
     """
     Returns a zero tensor with the same shape as input tensor.
 
@@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor:
     [0 0 0]]
 
     """
-    return zeros(inp.shape, dtype=inp.dtype, device=inp.device)
+    return full_like(inp, 0.0)
 
 
-def ones_like(inp: Tensor) -> Tensor:
+def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]:
     """
     Returns a ones tensor with the same shape as input tensor.
     """
-    return ones(inp.shape, dtype=inp.dtype, device=inp.device)
+    return full_like(inp, 1.0)
 
 
-def full_like(inp: Tensor, value: Union[int, float]) -> Tensor:
+def full_like(
+    inp: Union[Tensor, SymbolVar], value: Union[int, float]
+) -> Union[Tensor, SymbolVar]:
     """
     Returns a tensor filled with given value with the same shape as input tensor.
     """
-    return full(inp.shape, value, dtype=inp.dtype, device=inp.device)
+    (x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp)
+    if inp.shape is ():
+        return x
+    return broadcast_to(x, inp.shape)
 
 
 def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
@@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
     if len(inps) == 1:
         return inps[0]
 
-    dtype = dtype_promotion(inps)
+    inps = convert_inputs(*inps, device=device)
     if device is None:
         device = get_device(inps)
     device = as_device(device)
-
-    def convert(x):
-        return convert_single_value(x, dtype=dtype, device=device)
-
-    inps = tuple(map(convert, inps))
     (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps)
     return result
@@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0):
                 Ntotal, axis, Nsections
             )
         )
+
+    func = (
+        floor_div
+        if isinstance(Nsections, (SymbolVar, Tensor))
+        else lambda x, y: x // y
+    )
     div_points = [0] + [
-        floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
+        func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections)
     ]
     for i in range(2, Nsections + 1):
         div_points[i] = div_points[i - 1] + div_points[i]
@@ -925,11 +931,15 @@ def linspace(
     if not (cur_device is None or device == cur_device):
         raise ("ambiguous device for linspace opr")
 
-    if not isinstance(start, Tensor):
+    is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num])
+    if any(is_symbolvar) and not all(is_symbolvar):
+        raise TypeError("start, stop and num should all be VarNode or none of them")
+
+    if not isinstance(start, (Tensor, SymbolVar)):
         start = Tensor(start, device=device)
-    if not isinstance(stop, Tensor):
+    if not isinstance(stop, (Tensor, SymbolVar)):
         stop = Tensor(stop, device=device)
-    if not isinstance(num, Tensor):
+    if not isinstance(num, (Tensor, SymbolVar)):
         num = Tensor(num, device=device)
 
     op = builtin.Linspace(comp_node=device)
@@ -983,7 +993,7 @@ def arange(
         stop = stop.astype("float32")
     if isinstance(step, Tensor):
         step = step.astype("float32")
-    num = ceil(Tensor((stop - start) / step, device=device))
+    num = ceil((stop - start) / step)
     stop = start + step * (num - 1)
     result = linspace(start, stop, num, device=device)
     if np.dtype(dtype) == np.int32:
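The `*_like` family now builds its constant against the input itself (`Const(...)(inp)`), so the same code path serves imperative Tensors and graph VarNodes. A minimal sketch using an empty `Network`, as the updated tests do:

    import numpy as np
    import megengine.functional as F
    from megengine.utils.network import Network

    net = Network()
    x = net.make_const(np.arange(6, dtype="float32").reshape(2, 3))
    y = F.zeros_like(x)  # a VarNode on x's graph, not an imperative Tensor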


imperative/python/megengine/utils/network.py (+39, -22)

@@ -16,6 +16,7 @@ from typing import Dict, List
 import numpy as np
 
 from ..core._imperative_rt import ComputingGraph
+from ..core._imperative_rt.core2 import SymbolVar
 from ..core.tensor import megbrain_graph as G
 from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
 from .network_node import (
@@ -60,12 +61,12 @@ class Network:
         )
         outputs = [new_outputs[i] for i in outspec]
         self._orig_outputs = outputs
-        self.add_dep_oprs(*outputs)
+        for x in self._orig_outputs:
+            self.output_vars.append(self._get_var(x))
+        self.add_dep_oprs()
         for x in self._orig_inputs:
             self.input_vars.append(self._get_var(x))
 
-        for x in self._orig_outputs:
-            self.output_vars.append(self._get_var(x))
         self.graph = self._orig_outputs[0].graph
         return self
@@ -197,6 +198,8 @@ class Network:
     def add_output(self, *vars: VarNode):
         """Adds vars into the network output node list
         """
+        if not all([var.owner for var in vars]):
+            self.add_dep_oprs(*vars)
         for var in vars:
             if var not in self.output_vars:
                 self.output_vars.append(var)
@@ -209,21 +212,25 @@ class Network:
             self.output_vars.remove(var)
 
     def add_dep_oprs(self, *vars):
-        """Adds dependent opnodes and varnodes of vars into network
-        """
-        oprs = get_oprs_seq(vars, False, False)
-        for mge_opr in oprs:
+        if len(vars) == 0:
+            vars = self.output_vars
+        q = list(vars)
+        while len(q) > 0:
+            cur = q.pop(0)
+            if cur.owner is not None:
+                continue
+            if cur.name is None:
+                cur.name = cur.var.name
+            self.all_vars_map[cur.var.id] = cur
+            mge_opr = cur.var.owner
             if get_opr_type(mge_opr) == "Host2DeviceCopy":
                 self._orig_inputs.extend(mge_opr.outputs)
-            opr = self._add_opr(mge_opr)
-            if opr is not None:
-                for x in mge_opr.inputs:
-                    opr.add_inp_var(self._get_var(x))
-                # set out var
-                for x in mge_opr.outputs:
-                    opr.add_out_var(self._get_var(x))
-
-        return [self.all_vars_map[var.id] for var in vars]
+            cur.owner = self._add_opr(mge_opr)
+            if cur.owner is None:
+                cur.owner = self.all_oprs_map[mge_opr.id]
+                continue
+            q.extend(cur.owner.inputs)
+        return list(vars)
 
     def modify_opr_names(self, modifier):
         """Modifies names of operators **inplace**; useful for merging loaded
@@ -275,6 +282,9 @@ class Network:
         Replaces vars in the graph.
         :param repl_dict: the map {old_var: new_var} that specifies how to replace the vars.
         """
+        if not all([var.owner for var in repl_dict.values()]):
+            print(repl_dict.values())
+            self.add_dep_oprs(*list(repl_dict.values()))
         for var in self.all_vars:
             if var in repl_dict:
                 repl_var = repl_dict[var]
@@ -282,6 +292,7 @@ class Network:
                 idx = owner.outputs.index(repl_var)
                 owner.outputs[idx] = var
                 var.__dict__.update(repl_var.__dict__)
+                var.var = repl_var.var
 
     def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]):
         """
@@ -297,6 +308,7 @@ class Network:
             for ind, var in enumerate(opr.outputs):
                 var.owner = repl_dict[opr]
                 var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
+                var.var = repl_dict[opr].outputs[ind].var
 
     def get_opr_by_type(self, oprcls, unique=True):
         assert issubclass(oprcls, OpNode)
@@ -381,11 +393,16 @@ class Network:
         return self.opr_filter.as_dict()
 
     # used for loading and building graph
-    def _add_opr(self, x):
+    def _add_opr(self, opr):
         # TODO: use megbrain C++ RTTI to replace type string
-        if x.id not in self.all_oprs_map:
-            self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x)
-            return self.all_oprs_map[x.id]
+        if opr.id not in self.all_oprs_map:
+            opnode = str_to_mge_class(get_opr_type(opr)).load(opr)
+            self.all_oprs_map[opr.id] = opnode
+            for var in opr.inputs:
+                opnode.add_inp_var(self._get_var(var))
+            for var in opr.outputs:
+                opnode.add_out_var(self._get_var(var))
+            return opnode
         else:
             return None
@@ -397,7 +414,7 @@ class Network:
 
     def _get_var(self, x):
         # auto convert to VarNode of Network
-        if x.id not in self.all_vars_map:
+        if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x:
             self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner))
         return self.all_vars_map[x.id]
@@ -652,7 +669,7 @@ class NodeFilterHasInput(NodeFilter):
             assert isinstance(
                 i, OpNode
             ), "has_input() must be used with OpNode; " "got {!r}".format(i)
-            if self.var in i.inputs:
+            if any(self.var is _ for _ in i.inputs):
                 yield i
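`add_dep_oprs` changes from an eager `get_oprs_seq` pass into a lazy breadth-first walk that resolves `owner` links on demand, which is what lets `add_output` and `replace_vars` accept vars plucked from the middle of a loaded graph. A hedged sketch ("model.mge" is a placeholder path):

    from megengine.utils.network import Network

    net = Network.load("model.mge")  # placeholder model file
    v = net.all_vars[0]
    net.add_output(v)  # missing owners along v's dependency chain are filled in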






imperative/python/megengine/utils/network_node.py (+52, -9)

@@ -6,16 +6,21 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import abc
 import json
 import sys
-from typing import Callable
+from typing import Callable, Sequence
 
 import numpy as np
 
 from ..core import _imperative_rt as rt
+from ..core._imperative_rt.core2 import SymbolVar
 from ..core._wrap import Device
 from ..core.ops import builtin
-from ..core.tensor.megbrain_graph import InputNode
+from ..core.tensor.array_method import ArrayMethodMixin
+from ..core.tensor.indexing import getitem as _getitem
+from ..core.tensor.indexing import setitem as _setitem
+from ..core.tensor.megbrain_graph import InputNode, OutputNode
 from ..tensor import Tensor
 from .comp_graph_tools import replace_vars
 from .module_stats import (
@@ -29,9 +34,13 @@ class NetworkNode:
     pass
 
 
-class VarNode(NetworkNode):
-    def __init__(self, owner_opr=None, name=None):
-        self.var = None
+class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)):
+    pass
+
+
+class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta):
+    def __init__(self, var=None, *, owner_opr=None, name=None):
+        SymbolVar.__init__(self, var)
         self.owner = owner_opr
         self.name = name
         self.id = id(self)
@@ -58,6 +67,40 @@
     def dtype(self):
         return self.var.dtype if self.var else None
 
+    def __bool__(self):
+        return False
+
+    __index__ = None
+    __int__ = None
+    __float__ = None
+    __complex__ = None
+
+    def __hash__(self):
+        return id(self)
+
+    @property
+    def _tuple_shape(self):
+        return self.var.shape
+
+    def numpy(self):
+        o = OutputNode(self.var)
+        self.graph.compile(o.outputs).execute()
+        return o.get_value().numpy()
+
+    def __getitem__(self, index):
+        return _getitem(self, index)
+
+    def __setitem__(self, index, value):
+        if index is not Ellipsis:
+            value = _setitem(self, index, value)
+        if self.owner is not None:
+            idx = self.owner.outputs.index(self)
+            self.owner.outputs[idx] = VarNode(
+                self.var, owner_opr=self.owner, name=self.var.name
+            )
+        self.var = value.var
+        self.owner = None
+
     def set_owner_opr(self, owner_opr):
         self.owner = owner_opr
@@ -135,7 +178,7 @@ class Host2DeviceCopy(OpNode):
         outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name)
         self._opr = outputs.owner
         if len(self.outputs) == 0:
-            self.outputs.append(VarNode(self, self.name))
+            self.outputs.append(VarNode(owner_opr=self, name=self.name))
         self.outputs[0].var = outputs
         assert self.outputs[0].owner is self
@@ -173,8 +216,8 @@ class ImmutableTensor(OpNode):
     def set_value(self, data, device=None):
         assert self.graph is not None
         cn = device if device else self.device
-        assert isinstance(data, (int, float, np.ndarray))
-        if isinstance(data, (int, float)):
+        assert isinstance(data, (int, float, Sequence, np.ndarray))
+        if not isinstance(data, np.ndarray):
             data = np.array(data)
             if data.dtype == np.float64:
                 data = data.astype(np.float32)
@@ -182,7 +225,7 @@ class ImmutableTensor(OpNode):
                 data = data.astype(np.int32)
         varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name)
         if len(self.outputs) == 0:
-            self.outputs.append(VarNode(self, self.name))
+            self.outputs.append(VarNode(owner_opr=self, name=self.name))
         self.outputs[0].var = varnode
         self._opr = varnode.owner
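This is the heart of the commit: `VarNode` now derives from both `SymbolVar` and `ArrayMethodMixin` (their metaclasses reconciled by `VarNodeMeta`), so a graph var supports the Tensor-like array surface — arithmetic, `__getitem__`/`__setitem__`, and `numpy()`, which compiles an `OutputNode` and executes the graph. A minimal sketch, assuming this patch:

    import numpy as np
    from megengine.utils.network import Network

    net = Network()
    x = net.make_const(np.arange(9, dtype="int32").reshape(3, 3))
    y = x[1, :] * 2   # indexing and elemwise build new oprs on x's graph
    print(y.numpy())  # compiles and executes the graph; prints [ 6  8 10]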




imperative/python/src/tensor.cpp (+57, -19)

@@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) {
     if (ctx.op->same_type<BackwardGraph>()) {
         ctx.backward = true;
     }
-    if (py::isinstance<cg::VarNode>(py::handle(args[0]))){
-        SmallVector<cg::VarNode*> vinputs(nargs);
-        for (size_t i = 0; i < nargs; ++i) {
-            vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>();
-        }
-        auto op = ctx.op.get();
-        return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr();
-    }
+    if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
+        SmallVector<cg::VarNode*> vinputs(nargs);
+        for (size_t i = 0; i < nargs; ++i) {
+            vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
+        }
+        auto op = ctx.op.get();
+        auto rst = OpDef::apply_on_var_node(*op, vinputs);
+        auto ret = pybind11::tuple(rst.size());
+        auto typeobj = py::handle(args[0]).get_type();
+        for (size_t i = 0; i < rst.size(); ++i) {
+            ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic));
+        }
+        return ret.release().ptr();
+    }
 
     for (size_t i = 0; i < nargs; ++i) {
         if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
@@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
             continue;
         }
 
-        if (py::isinstance<cg::VarNode>(py::handle(handle))){
-            auto var = py::handle(handle).cast<cg::VarNode *>();
-            mgb::DType type = var->dtype();
+        if (py::isinstance<PySymbolVar>(py::handle(handle))){
+            auto var = py::handle(handle).cast<PySymbolVar*>();
+            mgb::DType type = var->m_node->dtype();
             auto && descr = npy::dtype_mgb2np_descr(type);
             Py_INCREF(descr.get());
             tensors.emplace_back(descr.get());
@@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
     bool valid = false;
     CompNode cn;
     for (size_t i = 0; i < nargs; ++i) {
-        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
+        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
         TensorWrapper* tw = TensorWrapper::try_cast(handle);
 
-        bool is_var = py::isinstance<cg::VarNode>(py::handle(handle));
-        if (tw || is_var) {
+        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
+        if (tw || is_symvar) {
             if (!valid) {
-                cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
+                cn = tw ? tw->m_tensor->comp_node()
+                        : py::handle(handle)
+                                  .cast<PySymbolVar*>()
+                                  ->m_node->comp_node();
                 valid = true;
             } else {
-                CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node();
+                CompNode cn1 = tw ? tw->m_tensor->comp_node()
+                                  : py::handle(handle)
+                                            .cast<PySymbolVar*>()
+                                            ->m_node->comp_node();
                 if (cn1 != cn) {
                     throw py::value_error(ssprintf("ambiguous device: %s vs %s",
-                        cn.to_string().c_str(), cn1.to_string().c_str()));
+                                                   cn.to_string().c_str(),
+                                                   cn1.to_string().c_str()));
                 }
             }
         }
@@ -849,6 +861,32 @@ void init_tensor(py::module m) {
             .def("__call__", &TensorWeakRef::operator())
             .def("_use_cnt", &TensorWeakRef::_use_cnt);
 
+    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
+            .def_property_readonly(
+                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
+            .def_property("var", [](PySymbolVar* v) { return v->m_node; },
+                          [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
+            .def_property_readonly(
+                    "device",
+                    [](PySymbolVar* v) { return v->m_node->comp_node(); })
+            .def_property_readonly(
+                    "graph",
+                    [](PySymbolVar* v) { return v->m_node->owner_graph(); })
+            .def_property_readonly(
+                    "shape",
+                    [](PySymbolVar* v) -> const TensorShape* {
+                        auto&& mgr = v->m_node->owner_graph()
+                                             ->static_infer_manager();
+                        return mgr.infer_shape_fallible(v->m_node);
+                    })
+            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
+            .def("_setscalar",
+                 [](PySymbolVar* v) { return v->is_scalar = true; })
+            .def(py::init([](cg::VarNode* node) {
+                     return std::make_shared<PySymbolVar>(node);
+                 }),
+                 py::arg() = nullptr);
+
     static PyMethodDef method_defs[] = {
             MGE_PY_INTERFACE(apply, py_apply),
             MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
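Seen from Python, the new binding is a thin handle over a `cg::VarNode*`; `py_apply` dispatches on the first argument's type and rebuilds results with the caller's own (sub)class, which is how the pure-Python `VarNode` subclass survives op application. A rough self-contained check:

    import numpy as np
    from megengine.core._imperative_rt.core2 import SymbolVar
    from megengine.utils.network import Network

    net = Network()
    v = net.make_const(np.zeros((2, 3), dtype="float32"))
    assert isinstance(v, SymbolVar)    # VarNode subclasses the new binding
    print(v.dtype, v.device, v.shape)  # shape comes from static shape inference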


imperative/python/src/tensor.h (+6, -0)

@@ -181,6 +181,12 @@ struct TensorWrapper {
     PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
 };
 
+struct PySymbolVar {
+    cg::VarNode* m_node = nullptr;
+    bool is_scalar = false;
+    PySymbolVar() = default;
+    PySymbolVar(VarNode* m) : m_node(m) {}
+};
+
 PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);




imperative/python/test/helpers/utils.py (+21, -4)

@@ -2,9 +2,11 @@ import io
 
 import numpy as np
 
+import megengine.core.tensor.megbrain_graph as G
 import megengine.utils.comp_graph_tools as cgtools
 from megengine import tensor
 from megengine.jit import trace
+from megengine.utils.network_node import VarNode
 
 
 def _default_compare_fn(x, y):
@@ -14,8 +16,23 @@ def _default_compare_fn(x, y):
         np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
 
 
+def make_tensor(x, network=None, device=None):
+    if network is not None:
+        if isinstance(x, VarNode):
+            return VarNode(x.var)
+        return network.make_const(x, device=device)
+    else:
+        return tensor(x, device=device)
+
+
 def opr_test(
-    cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs
+    cases,
+    func,
+    compare_fn=_default_compare_fn,
+    ref_fn=None,
+    test_trace=True,
+    network=None,
+    **kwargs
 ):
     """
     :param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
@@ -44,7 +61,7 @@ def opr_test(
         if not isinstance(results, (tuple, list)):
             results = (results,)
         for r, e in zip(results, expected):
-            if not isinstance(r, tensor):
+            if not isinstance(r, (tensor, VarNode)):
                 r = tensor(r)
             compare_fn(r, e)
 
@@ -72,9 +89,9 @@ def opr_test(
         raise ValueError("the input func should be callable")
 
     inp, outp = get_param(cases, 0)
-    inp_tensor = [tensor(inpi) for inpi in inp]
+    inp_tensor = [make_tensor(inpi, network) for inpi in inp]
 
-    if test_trace:
+    if test_trace and not network:
         copied_inp = inp_tensor.copy()
         for symbolic in [False, True]:
             traced_func = trace(symbolic=symbolic)(func)
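With the new `network` parameter, the same case table exercises both execution modes: imperative Tensors (with tracing) when `network` is None, and graph VarNodes (tracing skipped) otherwise. A hedged usage sketch mirroring the parametrized tests below:

    import numpy as np
    import megengine.functional as F
    from megengine.utils.network import Network
    from utils import opr_test  # the test helper defined above

    cases = [
        {"input": np.random.rand(15).astype("float32")},
        {"input": np.random.rand(25).astype("float32")},
    ]
    opr_test(cases, F.round, ref_fn=np.round, network=Network())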


imperative/python/test/unit/core/test_indexing_op.py (+79, -52)

@@ -10,12 +10,17 @@ import collections
 
 import numpy as np
 import pytest
+from utils import make_tensor
 
 import megengine
+import megengine.core.tensor.megbrain_graph as G
+import megengine.functional as F
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.ops import builtin
 from megengine.tensor import Tensor
+from megengine.utils.network import Network
+from megengine.utils.network_node import VarNode
 
 
 def cvt_to_shape_desc(val, inpvar, config=None):
@@ -387,108 +392,130 @@ def test_batched_mesh_indexing():
 
 
 # high level
+def get_value(x):
+    if isinstance(x, VarNode):
+        var = x.var
+        o = G.OutputNode(var)
+        graph = x.graph
+        graph.compile(o.outputs).execute()
+        return o.get_value().numpy()
+    else:
+        return x.numpy()
+
+
+@pytest.mark.parametrize("test_varnode", [True, False])
+def test_advance_indexing_high_level(test_varnode):
+    if test_varnode:
+        network = Network()
+    else:
+        network = None
 
-def test_advance_indexing_high_level():
     x = np.arange(25).reshape(5, 5).astype("int32")
     d = np.arange(15).reshape(3, 5).astype("int32")
-    xx = Tensor(x)
+    xx = make_tensor(x, network)
 
-    np.testing.assert_equal(x[1, :], xx[1, :].numpy())
-    np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
-    np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy())
+    np.testing.assert_equal(x[1, :], get_value(xx[1, :]))
+    np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
+    np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :]))
 
-    np.testing.assert_equal(x[:, :], xx[:, :].numpy())
-    np.testing.assert_equal(x[1, 1], xx[1, 1].numpy())
+    np.testing.assert_equal(x[:, :], get_value(xx[:, :]))
+    np.testing.assert_equal(x[1, 1], get_value(xx[1, 1]))
     yy = xx[(0, 4, 2), :]
-    np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy())
+    np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy))
 
     x_ = x.copy()
     x_[(0, 4, 2), :] = d
-    xx_ = Tensor(xx)
+    xx_ = make_tensor(xx, network)
     xx_[(0, 4, 2), :] = d
-    np.testing.assert_equal(x_, xx_.numpy())
+    np.testing.assert_equal(x_, get_value(xx_))
 
     x = np.arange(27).reshape(3, 3, 3).astype("int32")
-    xx = Tensor(x)
+    xx = make_tensor(x, network)
 
-    np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy())
-    np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy())
-    np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy())
-    np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy())
-    np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy())
-    np.testing.assert_equal(x[:, 1], xx[:, 1].numpy())
-    np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy())
+    np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :]))
+    np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1]))
+    np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :]))
+    np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1]))
+    np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1]))
+    np.testing.assert_equal(x[:, 1], get_value(xx[:, 1]))
+    np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2]))
 
     x_ = x.copy()
     x_[1, 1, 1] = -1
     xx[1, 1, 1] = -1
-    np.testing.assert_equal(x_, xx.numpy())
+    np.testing.assert_equal(x_, get_value(xx))
 
     x_[:, 1, 1] = -2
     xx[:, 1, 1] = x_[:, 1, 1]
-    np.testing.assert_equal(x_, xx.numpy())
+    np.testing.assert_equal(x_, get_value(xx))
 
     x_[0:1, :, 1] = -3
     xx[0:1, :, 1] = x_[0:1, :, 1]
-    np.testing.assert_equal(x_, xx.numpy())
+    np.testing.assert_equal(x_, get_value(xx))
 
     x_[0:1, :, 1] = -4
-    y = Tensor(x_)
+    y = make_tensor(x_, network)
     xx[0:1, :, 1] = y[0:1, :, 1]
-    np.testing.assert_equal(y.numpy(), xx.numpy())
+    np.testing.assert_equal(get_value(y), get_value(xx))
 
     x[:] = 1
     xx[:] = 1
-    np.testing.assert_equal(x, xx.numpy())
+    np.testing.assert_equal(x, get_value(xx))
 
     x = np.arange(9).reshape(3, 3).astype("int32")
-    xx = Tensor(x)
+    xx = make_tensor(x, network)
     y = np.array([1, 2])
-    yy = Tensor(y)
-    np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
-    np.testing.assert_equal(x[:, y], xx[:, y].numpy())
-    np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
+    yy = make_tensor(y, network)
+    np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
+    np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
+    np.testing.assert_equal(x[:, y], get_value(xx[:, y]))
+    np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))
 
     x_ = x.copy()
     x_[:, y[0]] = -1
-    xx_ = Tensor(x_)
+    xx_ = make_tensor(x_, network)
     xx[:, yy[0]] = xx_[:, yy[0]]
-    np.testing.assert_equal(x_, xx.numpy())
+    np.testing.assert_equal(x_, get_value(xx))
 
     x_[:, y] = -1
-    xx_ = Tensor(x_)
+    xx_ = make_tensor(x_, network)
     xx[:, yy] = xx_[:, yy]
-    np.testing.assert_equal(x_, xx.numpy())
+    np.testing.assert_equal(x_, get_value(xx))
 
     x = np.arange(9).reshape(3, 3).astype("int32")
-    xx = Tensor(x)
+    xx = make_tensor(x, network)
     y = np.array([1])
-    yy = Tensor(y)
-    np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy())
-    np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy())
-    np.testing.assert_equal(x[:, y], xx[:, y].numpy())
+    yy = make_tensor(y, network)
+    np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]]))
+    np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]]))
+    np.testing.assert_equal(x[:, y], get_value(xx[:, y]))
 
-    np.testing.assert_equal(x[:, y], xx[:, yy].numpy())
+    np.testing.assert_equal(x[:, y], get_value(xx[:, yy]))
 
     x = np.arange(9).reshape(3, 3).astype("int32")
-    xx = Tensor(x)
-    np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy())
-    np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy())
-
-
-def test_advance_indexing_with_bool():
+    xx = make_tensor(x, network)
+    np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0]))
+    np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0]))
+
+
+@pytest.mark.parametrize(
+    "test_varnode", [True, False],
+)
+def test_advance_indexing_with_bool(test_varnode):
+    if test_varnode:
+        network = Network()
+    else:
+        network = None
     a = np.arange(9).reshape(3, 3).astype(np.float32)
     b = np.array([1, 2, 3])
     c = np.array([1, 2, 3])
-    aa = Tensor(a)
-    bb = Tensor(b)
-    cc = Tensor(c)
-    np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy())
+    aa = make_tensor(a, network)
+    bb = make_tensor(b, network)
+    cc = make_tensor(c, network)
+    np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2]))
     a[b == 1, c == 2] = -1.0
     aa[bb == 1, cc == 2] = -1.0
-    np.testing.assert_equal(a, aa.numpy())
+    np.testing.assert_equal(a, get_value(aa))
 
     a = np.arange(9).reshape(3, 3).astype(np.float32)
     b = np.array([False, True, True])


+ 217
- 69
imperative/python/test/unit/functional/test_tensor.py View File

@@ -11,13 +11,16 @@ import platform


import numpy as np import numpy as np
import pytest import pytest
from utils import opr_test
from utils import make_tensor, opr_test


import megengine.functional as F import megengine.functional as F
from megengine import tensor from megengine import tensor
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.utils import astensor1d from megengine.core.tensor.utils import astensor1d
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.utils.network import Network
from megengine.utils.network_node import VarNode




def test_eye(): def test_eye():
@@ -38,7 +41,13 @@ def test_eye():
) )




def test_concat():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat(is_varnode):
if is_varnode:
network = Network()
else:
network = None

def get_data_shape(length: int): def get_data_shape(length: int):
return (length, 2, 3) return (length, 2, 3)


@@ -50,18 +59,30 @@ def test_concat():
return F.concat([data1, data2]) return F.concat([data1, data2])


cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] cases = [{"input": [data1, data2]}, {"input": [data1, data3]}]
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]))
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network)




def test_concat_device():
data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0")
data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_concat_device(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0")
data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1")


out = F.concat([data1, data2], device="cpu0") out = F.concat([data1, data2], device="cpu0")
assert str(out.device).split(":")[0] == "cpu0" assert str(out.device).split(":")[0] == "cpu0"




def test_stack():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_stack(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data1 = np.random.random((3, 2, 2)).astype("float32") data1 = np.random.random((3, 2, 2)).astype("float32")
data2 = np.random.random((3, 2, 2)).astype("float32") data2 = np.random.random((3, 2, 2)).astype("float32")
data3 = np.random.random((3, 2, 2)).astype("float32") data3 = np.random.random((3, 2, 2)).astype("float32")
@@ -72,12 +93,20 @@ def test_stack():
def run(data1, data2): def run(data1, data2):
return F.stack([data1, data2], axis=ai) return F.stack([data1, data2], axis=ai)


opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai))
opr_test(
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network
)



@pytest.mark.parametrize("is_varnode", [True, False])
def test_split(is_varnode):
if is_varnode:
network = Network()
else:
network = None


def test_split():
data = np.random.random((2, 3, 4, 5)).astype(np.float32) data = np.random.random((2, 3, 4, 5)).astype(np.float32)
inp = tensor(data)
inp = make_tensor(data, network)


mge_out0 = F.split(inp, 2, axis=3) mge_out0 = F.split(inp, 2, axis=3)
mge_out1 = F.split(inp, [3], axis=3) mge_out1 = F.split(inp, [3], axis=3)
@@ -106,26 +135,42 @@ def test_split():
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]"




def test_reshape():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.arange(6, dtype="float32") x = np.arange(6, dtype="float32")
xx = tensor(x)
xx = make_tensor(x, network)
y = x.reshape(1, 2, 3) y = x.reshape(1, 2, 3)


for shape in [ for shape in [
(1, 2, 3), (1, 2, 3),
(1, -1, 3), (1, -1, 3),
(1, tensor(-1), 3),
(1, make_tensor(-1, network), 3),
np.array([1, -1, 3], dtype="int32"), np.array([1, -1, 3], dtype="int32"),
tensor([1, -1, 3]),
make_tensor([1, -1, 3], network),
]: ]:
yy = F.reshape(xx, shape) yy = F.reshape(xx, shape)
np.testing.assert_equal(yy.numpy(), y) np.testing.assert_equal(yy.numpy(), y)




def test_reshape_shape_inference():
x_shape_known = tensor([1, 2, 3, 4], dtype="float32")
x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum())
tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_reshape_shape_inference(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x_shape_known = make_tensor([1, 2, 3, 4], network)
x_shape_unknown = F.broadcast_to(
make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum()
)
tshp_unknown = astensor1d(
(make_tensor([2], network), make_tensor([2], network)), x_shape_known
)
tshp_known = astensor1d((2, 2), x_shape_known) tshp_known = astensor1d((2, 2), x_shape_known)
tshp_known_unspec = astensor1d((2, -1), x_shape_known) tshp_known_unspec = astensor1d((2, -1), x_shape_known)


@@ -146,12 +191,18 @@ def test_reshape_shape_inference():
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, {"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, {"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
] ]
opr_test(cases, func, compare_fn=check_shape, test_trace=True)
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)



@pytest.mark.parametrize("is_varnode", [True, False])
def test_squeeze(is_varnode):
if is_varnode:
network = Network()
else:
network = None


def test_squeeze():
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1)
xx = tensor(x)
xx = make_tensor(x, network)


for axis in [None, 3, -4, (3, -4)]: for axis in [None, 3, -4, (3, -4)]:
y = np.squeeze(x, axis) y = np.squeeze(x, axis)
@@ -159,9 +210,15 @@ def test_squeeze():
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())




def test_expand_dims():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_expand_dims(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.arange(6, dtype="float32").reshape(2, 3) x = np.arange(6, dtype="float32").reshape(2, 3)
xx = tensor(x)
xx = make_tensor(x, network)


for axis in [2, -3, (3, -4), (1, -4)]: for axis in [2, -3, (3, -4), (1, -4)]:
y = np.expand_dims(x, axis) y = np.expand_dims(x, axis)
@@ -169,11 +226,17 @@ def test_expand_dims():
np.testing.assert_equal(y, yy.numpy()) np.testing.assert_equal(y, yy.numpy())




def test_elemwise_dtype_promotion():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_elemwise_dtype_promotion(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = np.random.rand(2, 3).astype("float32") x = np.random.rand(2, 3).astype("float32")
y = np.random.rand(1, 3).astype("float16") y = np.random.rand(1, 3).astype("float16")
xx = tensor(x)
yy = tensor(y)
xx = make_tensor(x, network)
yy = make_tensor(y, network)
z = xx * yy z = xx * yy
np.testing.assert_equal(z.numpy(), x * y) np.testing.assert_equal(z.numpy(), x * y)


@@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion():
np.testing.assert_equal(z.numpy(), x - y) np.testing.assert_equal(z.numpy(), x - y)




def test_linspace():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_linspace(is_varnode):
if is_varnode:
network = Network()
else:
network = None

cases = [ cases = [
{"input": [1, 9, 9]}, {"input": [1, 9, 9]},
{"input": [3, 10, 8]}, {"input": [3, 10, 8]},
@@ -193,6 +262,7 @@ def test_linspace():
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
) )


cases = [ cases = [
@@ -203,20 +273,28 @@ def test_linspace():
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
network=network,
) )


cases = [ cases = [
{"input": [1, tensor(9), 9]},
{"input": [tensor(1), 9, tensor(9)]},
{"input": [1, make_tensor(9, network), 9]},
{"input": [make_tensor(1, network), 9, make_tensor(9, network)]},
] ]
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
network=network,
) )




def test_arange():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_arange(is_varnode):
if is_varnode:
network = Network()
else:
network = None

cases = [ cases = [
{"input": [1, 9, 1]}, {"input": [1, 9, 1]},
{"input": [2, 10, 2]}, {"input": [2, 10, 2]},
@@ -225,6 +303,7 @@ def test_arange():
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
) )


cases = [ cases = [
@@ -235,6 +314,7 @@ def test_arange():
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
) )


cases = [ cases = [
@@ -245,20 +325,33 @@ def test_arange():
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
network=network,
) )




def test_round():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_round(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data1_shape = (15,) data1_shape = (15,)
data2_shape = (25,) data2_shape = (25,)
data1 = np.random.random(data1_shape).astype(np.float32) data1 = np.random.random(data1_shape).astype(np.float32)
data2 = np.random.random(data2_shape).astype(np.float32) data2 = np.random.random(data2_shape).astype(np.float32)


cases = [{"input": data1}, {"input": data2}] cases = [{"input": data1}, {"input": data2}]
opr_test(cases, F.round, ref_fn=np.round)
opr_test(cases, F.round, ref_fn=np.round, network=network)




def test_flatten():
@pytest.mark.parametrize("is_varnode", [True, False])
def test_flatten(is_varnode):
if is_varnode:
network = Network()
else:
network = None

data0_shape = (2, 3, 4, 5) data0_shape = (2, 3, 4, 5)
data1_shape = (4, 5, 6, 7) data1_shape = (4, 5, 6, 7)
data0 = np.random.random(data0_shape).astype(np.float32) data0 = np.random.random(data0_shape).astype(np.float32)
@@ -273,7 +366,7 @@ def test_flatten():
{"input": data0, "output": output0}, {"input": data0, "output": output0},
{"input": data1, "output": output1}, {"input": data1, "output": output1},
] ]
opr_test(cases, F.flatten, compare_fn=compare_fn)
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)


output0 = (2, 3 * 4 * 5) output0 = (2, 3 * 4 * 5)
output1 = (4, 5 * 6 * 7) output1 = (4, 5 * 6 * 7)
@@ -281,7 +374,7 @@ def test_flatten():
{"input": data0, "output": output0}, {"input": data0, "output": output0},
{"input": data1, "output": output1}, {"input": data1, "output": output1},
] ]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)


output0 = (2, 3, 4 * 5) output0 = (2, 3, 4 * 5)
output1 = (4, 5, 6 * 7) output1 = (4, 5, 6 * 7)
@@ -289,7 +382,7 @@ def test_flatten():
{"input": data0, "output": output0}, {"input": data0, "output": output0},
{"input": data1, "output": output1}, {"input": data1, "output": output1},
] ]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2)
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)


output0 = (2, 3 * 4, 5) output0 = (2, 3 * 4, 5)
output1 = (4, 5 * 6, 7) output1 = (4, 5 * 6, 7)
@@ -297,10 +390,23 @@ def test_flatten():
{"input": data0, "output": output0}, {"input": data0, "output": output0},
{"input": data1, "output": output1}, {"input": data1, "output": output1},
] ]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2)
opr_test(
cases,
F.flatten,
compare_fn=compare_fn,
start_axis=1,
end_axis=2,
network=network,
)



@pytest.mark.parametrize("is_varnode", [True, False])
def test_broadcast(is_varnode):
if is_varnode:
network = Network()
else:
network = None


def test_broadcast():
input1_shape = (20, 30) input1_shape = (20, 30)
output1_shape = (30, 20, 30) output1_shape = (30, 20, 30)
data1 = np.random.random(input1_shape).astype(np.float32) data1 = np.random.random(input1_shape).astype(np.float32)
@@ -321,7 +427,7 @@ def test_broadcast():
{"input": [data2, output2_shape], "output": output2_shape}, {"input": [data2, output2_shape], "output": output2_shape},
{"input": [data3, output3_shape], "output": output3_shape}, {"input": [data3, output3_shape], "output": output3_shape},
] ]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)


x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError):
@@ -334,35 +440,41 @@ def test_broadcast():
F.broadcast_to(x, (1, 3))


def test_utils_astensor1d():
reference = tensor(0)
@pytest.mark.parametrize("is_varnode", [True, False])
def test_utils_astensor1d(is_varnode):
if is_varnode:
network = Network()
else:
network = None

reference = make_tensor(0, network)

# literal
x = [1, 2, 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x)

# numpy array
x = np.asarray([1, 2, 3], dtype="int32")
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x)

# tensor
x = tensor([1, 2, 3], dtype="int32")
x = make_tensor([1, 2, 3], network)
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), x.numpy())

# mixed
x = [1, tensor(2), 3]
x = [1, make_tensor(2, network), 3]
for dtype in [None, "float32"]:
xx = astensor1d(x, reference, dtype=dtype)
assert type(xx) is tensor
assert isinstance(xx, type(reference))
np.testing.assert_equal(xx.numpy(), [1, 2, 3])


@@ -382,35 +494,60 @@ def test_device():
np.testing.assert_almost_equal(y5.numpy(), y6.numpy())


def test_identity():
x = tensor(np.random.random((5, 10)).astype(np.float32))
@pytest.mark.parametrize("is_varnode", [True, False])
def test_identity(is_varnode):
if is_varnode:
network = Network()
else:
network = None

x = make_tensor(np.random.random((5, 10)).astype(np.float32), network)
y = F.copy(x)
np.testing.assert_equal(y.numpy(), x)


def copy_test(dst, src):
def copy_test(dst, src, network):
data = np.random.random((2, 3)).astype(np.float32)
x = tensor(data, device=src)
x = make_tensor(data, device=src, network=network)
y = F.copy(x, dst)
assert np.allclose(data, y.numpy())
z = x.to(dst)
assert np.allclose(data, z.numpy())
if network is None:
z = x.to(dst)
assert np.allclose(data, z.numpy())


@pytest.mark.require_ngpu(1)
def test_copy_h2d():
copy_test("cpu0", "gpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_h2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None

copy_test("cpu0", "gpu0", network=network)


@pytest.mark.require_ngpu(1)
def test_copy_d2h():
copy_test("gpu0", "cpu0")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2h(is_varnode):
if is_varnode:
network = Network()
else:
network = None

copy_test("gpu0", "cpu0", network=network)


@pytest.mark.require_ngpu(2)
def test_copy_d2d():
copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1")
@pytest.mark.parametrize("is_varnode", [True, False])
def test_copy_d2d(is_varnode):
if is_varnode:
network = Network()
else:
network = None

copy_test("gpu0", "gpu1", network=network)
copy_test("gpu0:0", "gpu0:1", network=network)


@pytest.mark.parametrize(
@@ -425,7 +562,13 @@ def test_copy_d2d():
((), 10, None),
],
)
def test_repeat(shape, repeats, axis):
@pytest.mark.parametrize("is_varnode", [True, False])
def test_repeat(shape, repeats, axis, is_varnode):
if is_varnode:
network = Network()
else:
network = None

def repeat_func(inp):
return F.repeat(inp=inp, repeats=repeats, axis=axis)

@@ -437,7 +580,10 @@ def test_repeat(shape, repeats, axis):
cases = [{"input": np.array(1.23)}]

opr_test(
cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis),
cases,
repeat_func,
ref_fn=lambda inp: np.repeat(inp, repeats, axis),
network=network,
)


@@ -450,14 +596,16 @@ def test_repeat(shape, repeats, axis):
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)),
],
)
def test_tile(shape, reps):
@pytest.mark.parametrize("is_varnode", [True])
def test_tile(shape, reps, is_varnode):
if is_varnode:
network = Network()
else:
network = None

def tile_func(inp):
return F.tile(inp=inp, reps=reps)

cases = [
{"input": np.random.randn(*shape).astype("float32")},
]
cases = [{"input": np.random.randn(*shape).astype("float32")}]

opr_test(
cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps),
)
opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network)

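Note on the pattern used throughout this file: each test is parametrized over is_varnode, and make_tensor(value, network) picks between an eager Tensor and a constant VarNode on the Network under test, so the same assertions cover both execution modes. A minimal sketch of such a helper, assuming Network.make_const is the constant-creation entry point (the helper actually used by these tests may differ in its details):

    import megengine as mge

    def make_tensor(x, network=None, device=None):
        # Graph mode: build a constant VarNode on the given Network so the
        # functional ops under test receive SymbolVar/VarNode inputs.
        if network is not None:
            return network.make_const(x, device=device)
        # Eager mode: an ordinary Tensor, as before this change.
        return mge.tensor(x, device=device)
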
+ 9
- 15
imperative/python/test/unit/utils/test_network.py View File

@@ -34,13 +34,11 @@ def test_replace_var():
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()

out = F.mul(vara.var, varb.var)
out = F.mul(vara, varb)
out = F.relu(out)

var_list = graph.add_dep_oprs(out)

opnode = list(graph.opr_filter.has_input(vara))
repl_dict = {opnode[0].outputs[0]: var_list[0]}
repl_dict = {opnode[0].outputs[0]: out}
graph.replace_vars(repl_dict)

modified_model = io.BytesIO()
@@ -72,14 +70,12 @@ def test_replace_opr():
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()

out1 = F.sub(vara.var, varb.var)
out1 = F.sub(vara, varb)
out1 = F.relu(out1)

var_list = graph.add_dep_oprs(out1)
repl_opr = as_oprnode(var_list)
out1 = graph.add_dep_oprs(out1)
orig_opr = graph.opr_filter.has_input(vara).as_unique()

repl_dict = {orig_opr: repl_opr}
repl_dict = {orig_opr: out1[0].owner}
graph.replace_oprs(repl_dict)
modified_model1 = io.BytesIO()
graph.dump(modified_model1)
@@ -171,8 +167,7 @@ def test_add_input():
inp_c = graph.make_input_node((2,), np.int32, name="c")
varo = graph.var_filter.name("o").as_unique()

out = F.add(varo.var, inp_c.var)
out = graph.add_dep_oprs(out)[0]
out = F.add(varo, inp_c)
out.name = "o1"
graph.remove_output(varo)
graph.add_output(out)
@@ -206,12 +201,11 @@ def test_add_output():
var_a = net.var_filter.name("a").as_unique()
var_b = net.var_filter.name("b").as_unique()

y = F.add(var_a.var, var_b.var)
y = F.add(var_a, var_b)
y = F.sigmoid(y)

new_vars = net.add_dep_oprs(y)[0]
new_vars.name = "o1"
net.add_output(new_vars)
y.name = "o1"
net.add_output(y)

modified_model = io.BytesIO()
net.dump(modified_model)


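Taken together, these test_network.py edits show what the new VarNode array methods buy: functional ops now consume and produce graph variables directly, removing the old detour through .var unwrapping, graph.add_dep_oprs, and hand-built OprNodes. A sketch of a post-change editing session, assuming a dumped model that exposes variables named "a" and "b" as in these tests ("model.mge" is a hypothetical path):

    import io
    import megengine.functional as F
    from megengine.utils.network import Network

    net = Network.load("model.mge")  # hypothetical path to a dumped model
    var_a = net.var_filter.name("a").as_unique()
    var_b = net.var_filter.name("b").as_unique()
    y = F.sigmoid(F.add(var_a, var_b))  # ops take VarNodes directly now
    y.name = "o1"
    net.add_output(y)
    buf = io.BytesIO()
    net.dump(buf)  # serialize the modified graph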