@@ -8,6 +8,9 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from .._imperative_rt import make_const | |||
from .._imperative_rt.core2 import SymbolVar, Tensor | |||
class Const: | |||
def __init__(self, value=None, *, dtype=None, device=None): | |||
@@ -19,7 +22,19 @@ class Const: | |||
from ...tensor import Tensor | |||
device = self.device | |||
if device is None: | |||
device = reference[0].device | |||
if len(reference) != 0: | |||
reference = reference[0] | |||
assert isinstance( | |||
reference, (SymbolVar, Tensor) | |||
), "Reference should be Tensor or VarNode" | |||
if device is None: | |||
device = reference.device | |||
if isinstance(reference, SymbolVar): | |||
cls = type(reference) | |||
rst = cls(make_const(reference.graph, self.value, device, self.dtype)) | |||
return (rst,) | |||
return (Tensor(self.value, self.dtype, self.device, True),) |
@@ -13,7 +13,7 @@ from typing import Union | |||
import numpy as np | |||
from .._imperative_rt.common import CompNode | |||
from .._imperative_rt.core2 import Tensor, apply | |||
from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||
from ..ops import builtin | |||
from ..ops.builtin import Elemwise, GetVarShape | |||
from . import utils | |||
@@ -230,7 +230,9 @@ def _todo(*_): | |||
def _expand_args(args): | |||
if len(args) == 1: | |||
if isinstance(args[0], (collections.abc.Sequence, Tensor, np.ndarray),): | |||
if isinstance( | |||
args[0], (collections.abc.Sequence, Tensor, SymbolVar, np.ndarray), | |||
): | |||
args = args[0] | |||
return args | |||
@@ -10,7 +10,7 @@ from typing import Iterable | |||
import numpy as np | |||
from .._imperative_rt.core2 import Tensor, apply | |||
from .._imperative_rt.core2 import SymbolVar, Tensor, apply | |||
from .._trace_option import use_symbolic_shape | |||
from ..ops import builtin | |||
from ..ops.special import Const | |||
@@ -149,13 +149,13 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
return True | |||
def get_index(i): | |||
if not isinstance(i, (Tensor)): | |||
if not isinstance(i, (Tensor, SymbolVar)): | |||
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | |||
(i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||
else: | |||
(i,) = Const(i, dtype=np.int32, device=inp.device)() | |||
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||
return i | |||
assert isinstance(i, Tensor) | |||
assert isinstance(i, (Tensor, SymbolVar)) | |||
if i.dtype != np.bool_: | |||
return i | |||
_, ind = apply(builtin.CondTake(), i, i) | |||
@@ -197,9 +197,9 @@ def try_condtake(tensor, index): | |||
): | |||
return [] | |||
if isinstance(index, np.ndarray): | |||
(index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||
assert isinstance(index, Tensor) | |||
if not isinstance(tensor, Tensor): | |||
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||
assert isinstance(index, (Tensor, SymbolVar)) | |||
if not isinstance(tensor, (Tensor, SymbolVar)): | |||
raise TypeError("input must be a tensor") | |||
if tensor.device != index.device: | |||
raise ValueError( | |||
@@ -214,11 +214,16 @@ def getitem(tensor, index): | |||
return try_result[0] | |||
tensor, tensors, items, use_subtensor, ret_scalar = unpack_getitem(tensor, index) | |||
for v in tensors: | |||
if v.shape is None: | |||
break | |||
if isinstance(v.shape, v.__class__): | |||
break | |||
if len(v.shape) > 0 and v.shape[0] == 0: | |||
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||
tensor | |||
) | |||
return empty_tensor | |||
if use_subtensor: | |||
op = builtin.Subtensor(items=items) | |||
else: | |||
@@ -235,8 +240,8 @@ def setitem(tensor, index, value): | |||
if len(try_result) == 2: | |||
index = try_result[1] | |||
tensor = tensor.reshape(-1) | |||
if not isinstance(value, Tensor): | |||
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||
if not isinstance(value, (Tensor, SymbolVar)): | |||
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)(tensor) | |||
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | |||
if use_subtensor: | |||
op = builtin.Subtensor(items=items) | |||
@@ -11,8 +11,9 @@ from typing import Iterable, Union | |||
import numpy as np | |||
from .._imperative_rt import VarNode | |||
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||
from .._imperative_rt import make_const | |||
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion, get_device | |||
from .._wrap import device as as_device | |||
from ..ops import builtin | |||
from ..ops.special import Const | |||
from .dtype import is_dtype_equal, is_quantize | |||
@@ -38,13 +39,9 @@ def set_convert_inputs(flag): | |||
def concatenate(inputs, axis=0, *, device=None): | |||
dtype = dtype_promotion(inputs) | |||
device = get_device(inputs) | |||
def convert(x): | |||
return convert_single_value(x, dtype=dtype, device=device) | |||
inputs = tuple(map(convert, inputs)) | |||
inputs = convert_inputs(*inputs) | |||
if device is None: | |||
device = get_device(inputs) | |||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | |||
return result | |||
@@ -60,7 +57,7 @@ def astype(x, dtype): | |||
def convert_single_value(v, *, dtype=None, device=None): | |||
if isinstance(v, (Tensor, VarNode)): | |||
if isinstance(v, (Tensor, SymbolVar)): | |||
if not is_quantize(v.dtype): | |||
v = astype(v, dtype) | |||
else: | |||
@@ -68,17 +65,35 @@ def convert_single_value(v, *, dtype=None, device=None): | |||
return v | |||
def convert_inputs(*args: Tensor): | |||
def convert_inputs(*args, device=None): | |||
if not _enable_convert_inputs: | |||
return args | |||
dtype = dtype_promotion(args) | |||
device = get_device(args) | |||
if device is None: | |||
device = get_device(args) | |||
device = as_device(device) | |||
graph = None | |||
sym_type = None | |||
for a in args: | |||
if isinstance(a, SymbolVar): | |||
if graph is None: | |||
graph = a.var.graph | |||
sym_type = type(a) | |||
else: | |||
assert graph == a.var.graph | |||
args = list(args) | |||
if graph is not None: | |||
for i in range(len(args)): | |||
if not isinstance(args[i], SymbolVar): | |||
rst = make_const(graph, np.array(args[i]), device.to_c(), dtype) | |||
args[i] = sym_type(rst) | |||
def convert(value): | |||
if value is None: | |||
return value | |||
return convert_single_value(value, dtype=dtype, device=device) | |||
return convert_single_value(value, dtype=dtype, device=device.to_c()) | |||
return tuple(map(convert, args)) | |||
@@ -98,14 +113,14 @@ def result_type(*args): | |||
def isscalar(x): | |||
if isinstance(x, Tensor): | |||
if isinstance(x, (Tensor, SymbolVar)): | |||
return x._isscalar() | |||
return np.isscalar(x) | |||
def setscalar(x): | |||
if isinstance(x, Tensor): | |||
if isinstance(x, (Tensor, SymbolVar)): | |||
x._setscalar() | |||
else: | |||
raise NotImplementedError("Unsupport type {}".format(type(x))) | |||
@@ -132,7 +147,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
if not isinstance(x, collections.abc.Sequence): | |||
raise TypeError | |||
if any(isinstance(i, Tensor) for i in x): | |||
if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | |||
x = concatenate(x, device=device) | |||
if dtype is not None: | |||
x = astype(x, dtype) | |||
@@ -142,7 +157,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
def _expand_int(s, i): | |||
if isinstance(i, Tensor): | |||
if isinstance(i, (Tensor, SymbolVar)): | |||
i_np = i.numpy() | |||
if i_np.ndim == 0: | |||
s.append(int(i_np)) | |||
@@ -9,8 +9,7 @@ | |||
# pylint: disable=unused-argument,invalid-name,redefined-builtin,arguments-out-of-order | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._imperative_rt.graph import VarNode | |||
from ..core._imperative_rt.core2 import SymbolVar, apply | |||
from ..core.ops import builtin | |||
from ..core.ops.builtin import Elemwise | |||
from ..core.tensor import utils | |||
@@ -72,7 +71,7 @@ __all__ = [ | |||
def _elwise(*args, mode): | |||
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) | |||
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, SymbolVar)), args)) | |||
if len(tensor_args) == 0: | |||
dtype = utils.dtype_promotion(args) | |||
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | |||
@@ -12,7 +12,7 @@ from typing import Iterable, Optional, Sequence, Union | |||
import numpy as np | |||
from ..core._imperative_rt import CompNode | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._imperative_rt.core2 import SymbolVar, apply | |||
from ..core._wrap import device as as_device | |||
from ..core.ops import builtin | |||
from ..core.ops.builtin import Copy, Identity | |||
@@ -101,7 +101,7 @@ def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Ten | |||
return result | |||
def full(shape, value, dtype="float32", device=None): | |||
def full(shape, value, dtype="float32", device=None) -> Tensor: | |||
""" | |||
Returns a tensor with given shape and value. | |||
""" | |||
@@ -115,7 +115,7 @@ def full(shape, value, dtype="float32", device=None): | |||
return broadcast_to(x, shape) | |||
def ones(shape, dtype="float32", device=None): | |||
def ones(shape, dtype="float32", device=None) -> Tensor: | |||
""" | |||
Returns a ones tensor with given shape. | |||
@@ -142,14 +142,14 @@ def ones(shape, dtype="float32", device=None): | |||
return full(shape, 1.0, dtype=dtype, device=device) | |||
def zeros(shape, dtype="float32", device=None): | |||
def zeros(shape, dtype="float32", device=None) -> Tensor: | |||
""" | |||
Returns a zero tensor with given shape. | |||
""" | |||
return full(shape, 0.0, dtype=dtype, device=device) | |||
def zeros_like(inp: Tensor) -> Tensor: | |||
def zeros_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
""" | |||
Returns a zero tensor with the same shape as input tensor. | |||
@@ -176,21 +176,26 @@ def zeros_like(inp: Tensor) -> Tensor: | |||
[0 0 0]] | |||
""" | |||
return zeros(inp.shape, dtype=inp.dtype, device=inp.device) | |||
return full_like(inp, 0.0) | |||
def ones_like(inp: Tensor) -> Tensor: | |||
def ones_like(inp: Union[Tensor, SymbolVar]) -> Union[Tensor, SymbolVar]: | |||
""" | |||
Returns a ones tensor with the same shape as input tensor. | |||
""" | |||
return ones(inp.shape, dtype=inp.dtype, device=inp.device) | |||
return full_like(inp, 1.0) | |||
def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||
def full_like( | |||
inp: Union[Tensor, SymbolVar], value: Union[int, float] | |||
) -> Union[Tensor, SymbolVar]: | |||
""" | |||
Returns a tensor filled with given value with the same shape as input tensor. | |||
""" | |||
return full(inp.shape, value, dtype=inp.dtype, device=inp.device) | |||
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||
if inp.shape is (): | |||
return x | |||
return broadcast_to(x, inp.shape) | |||
def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||
@@ -259,15 +264,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||
if len(inps) == 1: | |||
return inps[0] | |||
dtype = dtype_promotion(inps) | |||
inps = convert_inputs(*inps, device=device) | |||
if device is None: | |||
device = get_device(inps) | |||
device = as_device(device) | |||
def convert(x): | |||
return convert_single_value(x, dtype=dtype, device=device) | |||
inps = tuple(map(convert, inps)) | |||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | |||
return result | |||
@@ -379,8 +379,14 @@ def split(inp, nsplits_or_sections, axis=0): | |||
Ntotal, axis, Nsections | |||
) | |||
) | |||
func = ( | |||
floor_div | |||
if isinstance(Nsections, (SymbolVar, Tensor)) | |||
else lambda x, y: x // y | |||
) | |||
div_points = [0] + [ | |||
floor_div(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||
func(Ntotal + Nsections - i - 1, Nsections) for i in range(Nsections) | |||
] | |||
for i in range(2, Nsections + 1): | |||
div_points[i] = div_points[i - 1] + div_points[i] | |||
@@ -925,11 +931,15 @@ def linspace( | |||
if not (cur_device is None or device == cur_device): | |||
raise ("ambiguous device for linspace opr") | |||
if not isinstance(start, Tensor): | |||
is_symbolvar = list(isinstance(x, SymbolVar) for x in [start, stop, num]) | |||
if any(is_symbolvar) and not all(is_symbolvar): | |||
raise TypeError("start, stop and num should all be VarNode or none of them") | |||
if not isinstance(start, (Tensor, SymbolVar)): | |||
start = Tensor(start, device=device) | |||
if not isinstance(stop, Tensor): | |||
if not isinstance(stop, (Tensor, SymbolVar)): | |||
stop = Tensor(stop, device=device) | |||
if not isinstance(num, Tensor): | |||
if not isinstance(num, (Tensor, SymbolVar)): | |||
num = Tensor(num, device=device) | |||
op = builtin.Linspace(comp_node=device) | |||
@@ -983,7 +993,7 @@ def arange( | |||
stop = stop.astype("float32") | |||
if isinstance(step, Tensor): | |||
step = step.astype("float32") | |||
num = ceil(Tensor((stop - start) / step, device=device)) | |||
num = ceil((stop - start) / step) | |||
stop = start + step * (num - 1) | |||
result = linspace(start, stop, num, device=device) | |||
if np.dtype(dtype) == np.int32: | |||
@@ -16,6 +16,7 @@ from typing import Dict, List | |||
import numpy as np | |||
from ..core._imperative_rt import ComputingGraph | |||
from ..core._imperative_rt.core2 import SymbolVar | |||
from ..core.tensor import megbrain_graph as G | |||
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | |||
from .network_node import ( | |||
@@ -60,12 +61,12 @@ class Network: | |||
) | |||
outputs = [new_outputs[i] for i in outspec] | |||
self._orig_outputs = outputs | |||
self.add_dep_oprs(*outputs) | |||
for x in self._orig_outputs: | |||
self.output_vars.append(self._get_var(x)) | |||
self.add_dep_oprs() | |||
for x in self._orig_inputs: | |||
self.input_vars.append(self._get_var(x)) | |||
for x in self._orig_outputs: | |||
self.output_vars.append(self._get_var(x)) | |||
self.graph = self._orig_outputs[0].graph | |||
return self | |||
@@ -197,6 +198,8 @@ class Network: | |||
def add_output(self, *vars: VarNode): | |||
"""Adds vars into the network output node list | |||
""" | |||
if not all([var.owner for var in vars]): | |||
self.add_dep_oprs(*vars) | |||
for var in vars: | |||
if var not in self.output_vars: | |||
self.output_vars.append(var) | |||
@@ -209,21 +212,25 @@ class Network: | |||
self.output_vars.remove(var) | |||
def add_dep_oprs(self, *vars): | |||
"""Adds dependent opnodes and varnodes of vars into network | |||
""" | |||
oprs = get_oprs_seq(vars, False, False) | |||
for mge_opr in oprs: | |||
if len(vars) == 0: | |||
vars = self.output_vars | |||
q = list(vars) | |||
while len(q) > 0: | |||
cur = q.pop(0) | |||
if cur.owner is not None: | |||
continue | |||
if cur.name is None: | |||
cur.name = cur.var.name | |||
self.all_vars_map[cur.var.id] = cur | |||
mge_opr = cur.var.owner | |||
if get_opr_type(mge_opr) == "Host2DeviceCopy": | |||
self._orig_inputs.extend(mge_opr.outputs) | |||
opr = self._add_opr(mge_opr) | |||
if opr is not None: | |||
for x in mge_opr.inputs: | |||
opr.add_inp_var(self._get_var(x)) | |||
# set out var | |||
for x in mge_opr.outputs: | |||
opr.add_out_var(self._get_var(x)) | |||
return [self.all_vars_map[var.id] for var in vars] | |||
cur.owner = self._add_opr(mge_opr) | |||
if cur.owner is None: | |||
cur.owner = self.all_oprs_map[mge_opr.id] | |||
continue | |||
q.extend(cur.owner.inputs) | |||
return list(vars) | |||
def modify_opr_names(self, modifier): | |||
"""Modifies names of operators **inplace**; useful for merging loaded | |||
@@ -275,6 +282,9 @@ class Network: | |||
Replaces vars in the graph. | |||
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||
""" | |||
if not all([var.owner for var in repl_dict.values()]): | |||
print(repl_dict.values()) | |||
self.add_dep_oprs(*list(repl_dict.values())) | |||
for var in self.all_vars: | |||
if var in repl_dict: | |||
repl_var = repl_dict[var] | |||
@@ -282,6 +292,7 @@ class Network: | |||
idx = owner.outputs.index(repl_var) | |||
owner.outputs[idx] = var | |||
var.__dict__.update(repl_var.__dict__) | |||
var.var = repl_var.var | |||
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||
""" | |||
@@ -297,6 +308,7 @@ class Network: | |||
for ind, var in enumerate(opr.outputs): | |||
var.owner = repl_dict[opr] | |||
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
var.var = repl_dict[opr].outputs[ind].var | |||
def get_opr_by_type(self, oprcls, unique=True): | |||
assert issubclass(oprcls, OpNode) | |||
@@ -381,11 +393,16 @@ class Network: | |||
return self.opr_filter.as_dict() | |||
# used for loading and building graph | |||
def _add_opr(self, x): | |||
def _add_opr(self, opr): | |||
# TODO: use megbrain C++ RTTI to replace type string | |||
if x.id not in self.all_oprs_map: | |||
self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||
return self.all_oprs_map[x.id] | |||
if opr.id not in self.all_oprs_map: | |||
opnode = str_to_mge_class(get_opr_type(opr)).load(opr) | |||
self.all_oprs_map[opr.id] = opnode | |||
for var in opr.inputs: | |||
opnode.add_inp_var(self._get_var(var)) | |||
for var in opr.outputs: | |||
opnode.add_out_var(self._get_var(var)) | |||
return opnode | |||
else: | |||
return None | |||
@@ -397,7 +414,7 @@ class Network: | |||
def _get_var(self, x): | |||
# auto convert to VarNode of Network | |||
if x.id not in self.all_vars_map: | |||
if x.id not in self.all_vars_map or self.all_vars_map[x.id].var != x: | |||
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | |||
return self.all_vars_map[x.id] | |||
@@ -652,7 +669,7 @@ class NodeFilterHasInput(NodeFilter): | |||
assert isinstance( | |||
i, OpNode | |||
), "has_input() must be used with OpNode; " "got {!r}".format(i) | |||
if self.var in i.inputs: | |||
if any(self.var is _ for _ in i.inputs): | |||
yield i | |||
@@ -6,16 +6,21 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import abc | |||
import json | |||
import sys | |||
from typing import Callable | |||
from typing import Callable, Sequence | |||
import numpy as np | |||
from ..core import _imperative_rt as rt | |||
from ..core._imperative_rt.core2 import SymbolVar | |||
from ..core._wrap import Device | |||
from ..core.ops import builtin | |||
from ..core.tensor.megbrain_graph import InputNode | |||
from ..core.tensor.array_method import ArrayMethodMixin | |||
from ..core.tensor.indexing import getitem as _getitem | |||
from ..core.tensor.indexing import setitem as _setitem | |||
from ..core.tensor.megbrain_graph import InputNode, OutputNode | |||
from ..tensor import Tensor | |||
from .comp_graph_tools import replace_vars | |||
from .module_stats import ( | |||
@@ -29,9 +34,13 @@ class NetworkNode: | |||
pass | |||
class VarNode(NetworkNode): | |||
def __init__(self, owner_opr=None, name=None): | |||
self.var = None | |||
class VarNodeMeta(type(SymbolVar), type(ArrayMethodMixin)): | |||
pass | |||
class VarNode(NetworkNode, SymbolVar, ArrayMethodMixin, metaclass=VarNodeMeta): | |||
def __init__(self, var=None, *, owner_opr=None, name=None): | |||
SymbolVar.__init__(self, var) | |||
self.owner = owner_opr | |||
self.name = name | |||
self.id = id(self) | |||
@@ -58,6 +67,40 @@ class VarNode(NetworkNode): | |||
def dtype(self): | |||
return self.var.dtype if self.var else None | |||
def __bool__(self): | |||
return False | |||
__index__ = None | |||
__int__ = None | |||
__float__ = None | |||
__complex__ = None | |||
def __hash__(self): | |||
return id(self) | |||
@property | |||
def _tuple_shape(self): | |||
return self.var.shape | |||
def numpy(self): | |||
o = OutputNode(self.var) | |||
self.graph.compile(o.outputs).execute() | |||
return o.get_value().numpy() | |||
def __getitem__(self, index): | |||
return _getitem(self, index) | |||
def __setitem__(self, index, value): | |||
if index is not Ellipsis: | |||
value = _setitem(self, index, value) | |||
if self.owner is not None: | |||
idx = self.owner.outputs.index(self) | |||
self.owner.outputs[idx] = VarNode( | |||
self.var, owner_opr=self.owner, name=self.var.name | |||
) | |||
self.var = value.var | |||
self.owner = None | |||
def set_owner_opr(self, owner_opr): | |||
self.owner = owner_opr | |||
@@ -135,7 +178,7 @@ class Host2DeviceCopy(OpNode): | |||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
self._opr = outputs.owner | |||
if len(self.outputs) == 0: | |||
self.outputs.append(VarNode(self, self.name)) | |||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
self.outputs[0].var = outputs | |||
assert self.outputs[0].owner is self | |||
@@ -173,8 +216,8 @@ class ImmutableTensor(OpNode): | |||
def set_value(self, data, device=None): | |||
assert self.graph is not None | |||
cn = device if device else self.device | |||
assert isinstance(data, (int, float, np.ndarray)) | |||
if isinstance(data, (int, float)): | |||
assert isinstance(data, (int, float, Sequence, np.ndarray)) | |||
if not isinstance(data, np.ndarray): | |||
data = np.array(data) | |||
if data.dtype == np.float64: | |||
data = data.astype(np.float32) | |||
@@ -182,7 +225,7 @@ class ImmutableTensor(OpNode): | |||
data = data.astype(np.int32) | |||
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | |||
if len(self.outputs) == 0: | |||
self.outputs.append(VarNode(self, self.name)) | |||
self.outputs.append(VarNode(owner_opr=self, name=self.name)) | |||
self.outputs[0].var = varnode | |||
self._opr = varnode.owner | |||
@@ -160,16 +160,21 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
if (ctx.op->same_type<BackwardGraph>()) { | |||
ctx.backward = true; | |||
} | |||
if (py::isinstance<cg::VarNode>(py::handle(args[0]))){ | |||
SmallVector<cg::VarNode*> vinputs(nargs); | |||
for (size_t i = 0; i < nargs; ++i) { | |||
vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>(); | |||
} | |||
auto op = ctx.op.get(); | |||
return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); | |||
} | |||
if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | |||
SmallVector<cg::VarNode*> vinputs(nargs); | |||
for (size_t i = 0; i < nargs; ++i) { | |||
vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node; | |||
} | |||
auto op = ctx.op.get(); | |||
auto rst = OpDef::apply_on_var_node(*op, vinputs); | |||
auto ret = pybind11::tuple(rst.size()); | |||
auto typeobj = py::handle(args[0]).get_type(); | |||
for (size_t i = 0; i<rst.size(); ++i) { | |||
ret[i] = typeobj(pybind11::cast(rst[i], pybind11::return_value_policy::automatic)); | |||
} | |||
return ret.release().ptr(); | |||
} | |||
for (size_t i = 0; i < nargs; ++i) { | |||
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
@@ -686,9 +691,9 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
continue; | |||
} | |||
if (py::isinstance<cg::VarNode>(py::handle(handle))){ | |||
auto var = py::handle(handle).cast<cg::VarNode *>(); | |||
mgb::DType type = var->dtype(); | |||
if (py::isinstance<PySymbolVar>(py::handle(handle))){ | |||
auto var = py::handle(handle).cast<PySymbolVar*>(); | |||
mgb::DType type = var->m_node->dtype(); | |||
auto && descr = npy::dtype_mgb2np_descr(type); | |||
Py_INCREF(descr.get()); | |||
tensors.emplace_back(descr.get()); | |||
@@ -737,19 +742,26 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
bool valid = false; | |||
CompNode cn; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; | |||
TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
bool is_var = py::isinstance<cg::VarNode>(py::handle(handle)); | |||
if (tw || is_var) { | |||
bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle)); | |||
if (tw || is_symvar) { | |||
if (!valid) { | |||
cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
cn = tw ? tw->m_tensor->comp_node() | |||
: py::handle(handle) | |||
.cast<PySymbolVar*>() | |||
->m_node->comp_node(); | |||
valid = true; | |||
} else { | |||
CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
CompNode cn1 = tw ? tw->m_tensor->comp_node() | |||
: py::handle(handle) | |||
.cast<PySymbolVar*>() | |||
->m_node->comp_node(); | |||
if (cn1 != cn) { | |||
throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||
cn.to_string().c_str(), cn1.to_string().c_str())); | |||
cn.to_string().c_str(), | |||
cn1.to_string().c_str())); | |||
} | |||
} | |||
} | |||
@@ -849,6 +861,32 @@ void init_tensor(py::module m) { | |||
.def("__call__", &TensorWeakRef::operator()) | |||
.def("_use_cnt", &TensorWeakRef::_use_cnt); | |||
py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar") | |||
.def_property_readonly( | |||
"dtype", [](PySymbolVar* v) { return v->m_node->dtype(); }) | |||
.def_property("var", [](PySymbolVar* v) { return v->m_node; }, | |||
[](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; }) | |||
.def_property_readonly( | |||
"device", | |||
[](PySymbolVar* v) { return v->m_node->comp_node(); }) | |||
.def_property_readonly( | |||
"graph", | |||
[](PySymbolVar* v) { return v->m_node->owner_graph(); }) | |||
.def_property_readonly( | |||
"shape", | |||
[](PySymbolVar* v) -> const TensorShape* { | |||
auto&& mgr = v->m_node->owner_graph() | |||
->static_infer_manager(); | |||
return mgr.infer_shape_fallible(v->m_node); | |||
}) | |||
.def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; }) | |||
.def("_setscalar", | |||
[](PySymbolVar* v) { return v->is_scalar = true; }) | |||
.def(py::init([](cg::VarNode* node) { | |||
return std::make_shared<PySymbolVar>(node); | |||
}), | |||
py::arg() = nullptr); | |||
static PyMethodDef method_defs[] = { | |||
MGE_PY_INTERFACE(apply, py_apply), | |||
MGE_PY_INTERFACE(dtype_promotion, dtype_promotion), | |||
@@ -181,6 +181,12 @@ struct TensorWrapper { | |||
PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); }; | |||
}; | |||
struct PySymbolVar { | |||
cg::VarNode* m_node = nullptr; | |||
bool is_scalar = false; | |||
PySymbolVar() = default; | |||
PySymbolVar(VarNode *m): m_node(m){} | |||
}; | |||
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */); | |||
@@ -2,9 +2,11 @@ import io | |||
import numpy as np | |||
import megengine.core.tensor.megbrain_graph as G | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine import tensor | |||
from megengine.jit import trace | |||
from megengine.utils.network_node import VarNode | |||
def _default_compare_fn(x, y): | |||
@@ -14,8 +16,23 @@ def _default_compare_fn(x, y): | |||
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6) | |||
def make_tensor(x, network=None, device=None): | |||
if network is not None: | |||
if isinstance(x, VarNode): | |||
return VarNode(x.var) | |||
return network.make_const(x, device=device) | |||
else: | |||
return tensor(x, device=device) | |||
def opr_test( | |||
cases, func, compare_fn=_default_compare_fn, ref_fn=None, test_trace=True, **kwargs | |||
cases, | |||
func, | |||
compare_fn=_default_compare_fn, | |||
ref_fn=None, | |||
test_trace=True, | |||
network=None, | |||
**kwargs | |||
): | |||
""" | |||
:param cases: the list which have dict element, the list length should be 2 for dynamic shape test. | |||
@@ -44,7 +61,7 @@ def opr_test( | |||
if not isinstance(results, (tuple, list)): | |||
results = (results,) | |||
for r, e in zip(results, expected): | |||
if not isinstance(r, tensor): | |||
if not isinstance(r, (tensor, VarNode)): | |||
r = tensor(r) | |||
compare_fn(r, e) | |||
@@ -72,9 +89,9 @@ def opr_test( | |||
raise ValueError("the input func should be callable") | |||
inp, outp = get_param(cases, 0) | |||
inp_tensor = [tensor(inpi) for inpi in inp] | |||
inp_tensor = [make_tensor(inpi, network) for inpi in inp] | |||
if test_trace: | |||
if test_trace and not network: | |||
copied_inp = inp_tensor.copy() | |||
for symbolic in [False, True]: | |||
traced_func = trace(symbolic=symbolic)(func) | |||
@@ -10,12 +10,17 @@ import collections | |||
import numpy as np | |||
import pytest | |||
from utils import make_tensor | |||
import megengine | |||
import megengine.core.tensor.megbrain_graph as G | |||
import megengine.functional as F | |||
from megengine.core._imperative_rt.core2 import apply | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.core.ops import builtin | |||
from megengine.tensor import Tensor | |||
from megengine.utils.network import Network | |||
from megengine.utils.network_node import VarNode | |||
def cvt_to_shape_desc(val, inpvar, config=None): | |||
@@ -387,108 +392,130 @@ def test_batched_mesh_indexing(): | |||
# high level | |||
def get_value(x): | |||
if isinstance(x, VarNode): | |||
var = x.var | |||
o = G.OutputNode(var) | |||
graph = x.graph | |||
graph.compile(o.outputs).execute() | |||
return o.get_value().numpy() | |||
else: | |||
return x.numpy() | |||
@pytest.mark.parametrize("test_varnode", [True, False]) | |||
def test_advance_indexing_high_level(test_varnode): | |||
if test_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
def test_advance_indexing_high_level(): | |||
x = np.arange(25).reshape(5, 5).astype("int32") | |||
d = np.arange(15).reshape(3, 5).astype("int32") | |||
xx = Tensor(x) | |||
xx = make_tensor(x, network) | |||
np.testing.assert_equal(x[1, :], xx[1, :].numpy()) | |||
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||
np.testing.assert_equal(x[1:3, :], xx[1:3, :].numpy()) | |||
np.testing.assert_equal(x[1, :], get_value(xx[1, :])) | |||
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||
np.testing.assert_equal(x[1:3, :], get_value(xx[1:3, :])) | |||
np.testing.assert_equal(x[:, :], xx[:, :].numpy()) | |||
np.testing.assert_equal(x[1, 1], xx[1, 1].numpy()) | |||
np.testing.assert_equal(x[:, :], get_value(xx[:, :])) | |||
np.testing.assert_equal(x[1, 1], get_value(xx[1, 1])) | |||
yy = xx[(0, 4, 2), :] | |||
np.testing.assert_equal(x[(0, 4, 2), :], yy.numpy()) | |||
np.testing.assert_equal(x[(0, 4, 2), :], get_value(yy)) | |||
x_ = x.copy() | |||
x_[(0, 4, 2), :] = d | |||
xx_ = Tensor(xx) | |||
xx_ = make_tensor(xx, network) | |||
xx_[(0, 4, 2), :] = d | |||
np.testing.assert_equal(x_, xx_.numpy()) | |||
np.testing.assert_equal(x_, get_value(xx_)) | |||
x = np.arange(27).reshape(3, 3, 3).astype("int32") | |||
xx = Tensor(x) | |||
xx = make_tensor(x, network) | |||
np.testing.assert_equal(x[1, :, :], xx[1, :, :].numpy()) | |||
np.testing.assert_equal(x[1, :, 1], xx[1, :, 1].numpy()) | |||
np.testing.assert_equal(x[1, 0:1, :], xx[1, 0:1, :].numpy()) | |||
np.testing.assert_equal(x[0:1, 1, 1], xx[0:1, 1, 1].numpy()) | |||
np.testing.assert_equal(x[:, 1, 1], xx[:, 1, 1].numpy()) | |||
np.testing.assert_equal(x[:, 1], xx[:, 1].numpy()) | |||
np.testing.assert_equal(x[1, 1:2], xx[1, 1:2].numpy()) | |||
np.testing.assert_equal(x[1, :, :], get_value(xx[1, :, :])) | |||
np.testing.assert_equal(x[1, :, 1], get_value(xx[1, :, 1])) | |||
np.testing.assert_equal(x[1, 0:1, :], get_value(xx[1, 0:1, :])) | |||
np.testing.assert_equal(x[0:1, 1, 1], get_value(xx[0:1, 1, 1])) | |||
np.testing.assert_equal(x[:, 1, 1], get_value(xx[:, 1, 1])) | |||
np.testing.assert_equal(x[:, 1], get_value(xx[:, 1])) | |||
np.testing.assert_equal(x[1, 1:2], get_value(xx[1, 1:2])) | |||
x_ = x.copy() | |||
x_[1, 1, 1] = -1 | |||
xx[1, 1, 1] = -1 | |||
np.testing.assert_equal(x_, xx.numpy()) | |||
np.testing.assert_equal(x_, get_value(xx)) | |||
x_[:, 1, 1] = -2 | |||
xx[:, 1, 1] = x_[:, 1, 1] | |||
np.testing.assert_equal(x_, xx.numpy()) | |||
np.testing.assert_equal(x_, get_value(xx)) | |||
x_[0:1, :, 1] = -3 | |||
xx[0:1, :, 1] = x_[0:1, :, 1] | |||
np.testing.assert_equal(x_, xx.numpy()) | |||
np.testing.assert_equal(x_, get_value(xx)) | |||
x_[0:1, :, 1] = -4 | |||
y = Tensor(x_) | |||
y = make_tensor(x_, network) | |||
xx[0:1, :, 1] = y[0:1, :, 1] | |||
np.testing.assert_equal(y.numpy(), xx.numpy()) | |||
np.testing.assert_equal(get_value(y), get_value(xx)) | |||
x[:] = 1 | |||
xx[:] = 1 | |||
np.testing.assert_equal(x, xx.numpy()) | |||
np.testing.assert_equal(x, get_value(xx)) | |||
x = np.arange(9).reshape(3, 3).astype("int32") | |||
xx = Tensor(x) | |||
xx = make_tensor(x, network) | |||
y = np.array([1, 2]) | |||
yy = Tensor(y) | |||
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||
np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||
np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||
yy = make_tensor(y, network) | |||
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||
np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||
np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||
x_ = x.copy() | |||
x_[:, y[0]] = -1 | |||
xx_ = Tensor(x_) | |||
xx_ = make_tensor(x_, network) | |||
xx[:, yy[0]] = xx_[:, yy[0]] | |||
np.testing.assert_equal(x_, xx.numpy()) | |||
np.testing.assert_equal(x_, get_value(xx)) | |||
x_[:, y] = -1 | |||
xx_ = Tensor(x_) | |||
xx_ = make_tensor(x_, network) | |||
xx[:, yy] = xx_[:, yy] | |||
np.testing.assert_equal(x_, xx.numpy()) | |||
np.testing.assert_equal(x_, get_value(xx)) | |||
x = np.arange(9).reshape(3, 3).astype("int32") | |||
xx = Tensor(x) | |||
xx = make_tensor(x, network) | |||
y = np.array([1]) | |||
yy = Tensor(y) | |||
np.testing.assert_equal(x[:, y[0]], xx[:, y[0]].numpy()) | |||
np.testing.assert_equal(x[:, y[0]], xx[:, yy[0]].numpy()) | |||
np.testing.assert_equal(x[:, y], xx[:, y].numpy()) | |||
yy = make_tensor(y, network) | |||
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, y[0]])) | |||
np.testing.assert_equal(x[:, y[0]], get_value(xx[:, yy[0]])) | |||
np.testing.assert_equal(x[:, y], get_value(xx[:, y])) | |||
np.testing.assert_equal(x[:, y], xx[:, yy].numpy()) | |||
np.testing.assert_equal(x[:, y], get_value(xx[:, yy])) | |||
x = np.arange(9).reshape(3, 3).astype("int32") | |||
xx = Tensor(x) | |||
np.testing.assert_equal(x[[0, 1], 0], xx[[0, 1], 0].numpy()) | |||
np.testing.assert_equal(x[0:2, 0], xx[0:2, 0].numpy()) | |||
def test_advance_indexing_with_bool(): | |||
xx = make_tensor(x, network) | |||
np.testing.assert_equal(x[[0, 1], 0], get_value(xx[[0, 1], 0])) | |||
np.testing.assert_equal(x[0:2, 0], get_value(xx[0:2, 0])) | |||
@pytest.mark.parametrize( | |||
"test_varnode", [True, False], | |||
) | |||
def test_advance_indexing_with_bool(test_varnode): | |||
if test_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
a = np.arange(9).reshape(3, 3).astype(np.float32) | |||
b = np.array([1, 2, 3]) | |||
c = np.array([1, 2, 3]) | |||
aa = Tensor(a) | |||
bb = Tensor(b) | |||
cc = Tensor(c) | |||
np.testing.assert_equal(a[b == 1, c == 2], aa[bb == 1, cc == 2].numpy()) | |||
aa = make_tensor(a, network) | |||
bb = make_tensor(b, network) | |||
cc = make_tensor(c, network) | |||
np.testing.assert_equal(a[b == 1, c == 2], get_value(aa[bb == 1, cc == 2])) | |||
a[b == 1, c == 2] = -1.0 | |||
aa[bb == 1, cc == 2] = -1.0 | |||
np.testing.assert_equal(a, aa.numpy()) | |||
np.testing.assert_equal(a, get_value(aa)) | |||
a = np.arange(9).reshape(3, 3).astype(np.float32) | |||
b = np.array([False, True, True]) | |||
@@ -11,13 +11,16 @@ import platform | |||
import numpy as np | |||
import pytest | |||
from utils import opr_test | |||
from utils import make_tensor, opr_test | |||
import megengine.functional as F | |||
from megengine import tensor | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.core.tensor import megbrain_graph as G | |||
from megengine.core.tensor.utils import astensor1d | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.utils.network import Network | |||
from megengine.utils.network_node import VarNode | |||
def test_eye(): | |||
@@ -38,7 +41,13 @@ def test_eye(): | |||
) | |||
def test_concat(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_concat(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
def get_data_shape(length: int): | |||
return (length, 2, 3) | |||
@@ -50,18 +59,30 @@ def test_concat(): | |||
return F.concat([data1, data2]) | |||
cases = [{"input": [data1, data2]}, {"input": [data1, data3]}] | |||
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y])) | |||
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]), network=network) | |||
def test_concat_device(): | |||
data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0") | |||
data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1") | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_concat_device(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
data1 = make_tensor(np.random.random((3, 2, 2)).astype("float32"), network, "cpu0") | |||
data2 = make_tensor(np.random.random((2, 2, 2)).astype("float32"), network, "cpu1") | |||
out = F.concat([data1, data2], device="cpu0") | |||
assert str(out.device).split(":")[0] == "cpu0" | |||
def test_stack(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_stack(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
data1 = np.random.random((3, 2, 2)).astype("float32") | |||
data2 = np.random.random((3, 2, 2)).astype("float32") | |||
data3 = np.random.random((3, 2, 2)).astype("float32") | |||
@@ -72,12 +93,20 @@ def test_stack(): | |||
def run(data1, data2): | |||
return F.stack([data1, data2], axis=ai) | |||
opr_test(cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai)) | |||
opr_test( | |||
cases, run, ref_fn=lambda x, y: np.stack([x, y], axis=ai), network=network | |||
) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_split(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
def test_split(): | |||
data = np.random.random((2, 3, 4, 5)).astype(np.float32) | |||
inp = tensor(data) | |||
inp = make_tensor(data, network) | |||
mge_out0 = F.split(inp, 2, axis=3) | |||
mge_out1 = F.split(inp, [3], axis=3) | |||
@@ -106,26 +135,42 @@ def test_split(): | |||
assert str(e) == "Invalid nsplits_or_secions: [3, 3, 5]" | |||
def test_reshape(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_reshape(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
x = np.arange(6, dtype="float32") | |||
xx = tensor(x) | |||
xx = make_tensor(x, network) | |||
y = x.reshape(1, 2, 3) | |||
for shape in [ | |||
(1, 2, 3), | |||
(1, -1, 3), | |||
(1, tensor(-1), 3), | |||
(1, make_tensor(-1, network), 3), | |||
np.array([1, -1, 3], dtype="int32"), | |||
tensor([1, -1, 3]), | |||
make_tensor([1, -1, 3], network), | |||
]: | |||
yy = F.reshape(xx, shape) | |||
np.testing.assert_equal(yy.numpy(), y) | |||
def test_reshape_shape_inference(): | |||
x_shape_known = tensor([1, 2, 3, 4], dtype="float32") | |||
x_shape_unknown = F.broadcast_to(tensor([1.0]), shape=tensor([1, 1, 1, 1]).sum()) | |||
tshp_unknown = astensor1d((tensor([2]), tensor([2])), x_shape_known) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_reshape_shape_inference(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
x_shape_known = make_tensor([1, 2, 3, 4], network) | |||
x_shape_unknown = F.broadcast_to( | |||
make_tensor([1.0], network), shape=make_tensor([1, 1, 1, 1], network).sum() | |||
) | |||
tshp_unknown = astensor1d( | |||
(make_tensor([2], network), make_tensor([2], network)), x_shape_known | |||
) | |||
tshp_known = astensor1d((2, 2), x_shape_known) | |||
tshp_known_unspec = astensor1d((2, -1), x_shape_known) | |||
@@ -146,12 +191,18 @@ def test_reshape_shape_inference(): | |||
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]}, | |||
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]}, | |||
] | |||
opr_test(cases, func, compare_fn=check_shape, test_trace=True) | |||
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_squeeze(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
def test_squeeze(): | |||
x = np.arange(6, dtype="float32").reshape(1, 2, 3, 1) | |||
xx = tensor(x) | |||
xx = make_tensor(x, network) | |||
for axis in [None, 3, -4, (3, -4)]: | |||
y = np.squeeze(x, axis) | |||
@@ -159,9 +210,15 @@ def test_squeeze(): | |||
np.testing.assert_equal(y, yy.numpy()) | |||
def test_expand_dims(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_expand_dims(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
x = np.arange(6, dtype="float32").reshape(2, 3) | |||
xx = tensor(x) | |||
xx = make_tensor(x, network) | |||
for axis in [2, -3, (3, -4), (1, -4)]: | |||
y = np.expand_dims(x, axis) | |||
@@ -169,11 +226,17 @@ def test_expand_dims(): | |||
np.testing.assert_equal(y, yy.numpy()) | |||
def test_elemwise_dtype_promotion(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_elemwise_dtype_promotion(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
x = np.random.rand(2, 3).astype("float32") | |||
y = np.random.rand(1, 3).astype("float16") | |||
xx = tensor(x) | |||
yy = tensor(y) | |||
xx = make_tensor(x, network) | |||
yy = make_tensor(y, network) | |||
z = xx * yy | |||
np.testing.assert_equal(z.numpy(), x * y) | |||
@@ -184,7 +247,13 @@ def test_elemwise_dtype_promotion(): | |||
np.testing.assert_equal(z.numpy(), x - y) | |||
def test_linspace(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_linspace(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
cases = [ | |||
{"input": [1, 9, 9]}, | |||
{"input": [3, 10, 8]}, | |||
@@ -193,6 +262,7 @@ def test_linspace(): | |||
cases, | |||
F.linspace, | |||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
network=network, | |||
) | |||
cases = [ | |||
@@ -203,20 +273,28 @@ def test_linspace(): | |||
cases, | |||
F.linspace, | |||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||
network=network, | |||
) | |||
cases = [ | |||
{"input": [1, tensor(9), 9]}, | |||
{"input": [tensor(1), 9, tensor(9)]}, | |||
{"input": [1, make_tensor(9, network), 9]}, | |||
{"input": [make_tensor(1, network), 9, make_tensor(9, network)]}, | |||
] | |||
opr_test( | |||
cases, | |||
F.linspace, | |||
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), | |||
network=network, | |||
) | |||
def test_arange(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_arange(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
cases = [ | |||
{"input": [1, 9, 1]}, | |||
{"input": [2, 10, 2]}, | |||
@@ -225,6 +303,7 @@ def test_arange(): | |||
cases, | |||
F.arange, | |||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
network=network, | |||
) | |||
cases = [ | |||
@@ -235,6 +314,7 @@ def test_arange(): | |||
cases, | |||
F.arange, | |||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
network=network, | |||
) | |||
cases = [ | |||
@@ -245,20 +325,33 @@ def test_arange(): | |||
cases, | |||
F.arange, | |||
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), | |||
network=network, | |||
) | |||
def test_round(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_round(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
data1_shape = (15,) | |||
data2_shape = (25,) | |||
data1 = np.random.random(data1_shape).astype(np.float32) | |||
data2 = np.random.random(data2_shape).astype(np.float32) | |||
cases = [{"input": data1}, {"input": data2}] | |||
opr_test(cases, F.round, ref_fn=np.round) | |||
opr_test(cases, F.round, ref_fn=np.round, network=network) | |||
def test_flatten(): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_flatten(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
data0_shape = (2, 3, 4, 5) | |||
data1_shape = (4, 5, 6, 7) | |||
data0 = np.random.random(data0_shape).astype(np.float32) | |||
@@ -273,7 +366,7 @@ def test_flatten(): | |||
{"input": data0, "output": output0}, | |||
{"input": data1, "output": output1}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn) | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network) | |||
output0 = (2, 3 * 4 * 5) | |||
output1 = (4, 5 * 6 * 7) | |||
@@ -281,7 +374,7 @@ def test_flatten(): | |||
{"input": data0, "output": output0}, | |||
{"input": data1, "output": output1}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1) | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network) | |||
output0 = (2, 3, 4 * 5) | |||
output1 = (4, 5, 6 * 7) | |||
@@ -289,7 +382,7 @@ def test_flatten(): | |||
{"input": data0, "output": output0}, | |||
{"input": data1, "output": output1}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2) | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network) | |||
output0 = (2, 3 * 4, 5) | |||
output1 = (4, 5 * 6, 7) | |||
@@ -297,10 +390,23 @@ def test_flatten(): | |||
{"input": data0, "output": output0}, | |||
{"input": data1, "output": output1}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) | |||
opr_test( | |||
cases, | |||
F.flatten, | |||
compare_fn=compare_fn, | |||
start_axis=1, | |||
end_axis=2, | |||
network=network, | |||
) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_broadcast(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
def test_broadcast(): | |||
input1_shape = (20, 30) | |||
output1_shape = (30, 20, 30) | |||
data1 = np.random.random(input1_shape).astype(np.float32) | |||
@@ -321,7 +427,7 @@ def test_broadcast(): | |||
{"input": [data2, output2_shape], "output": output2_shape}, | |||
{"input": [data3, output3_shape], "output": output3_shape}, | |||
] | |||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | |||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network) | |||
x = F.ones((2, 1, 3)) | |||
with pytest.raises(RuntimeError): | |||
@@ -334,35 +440,41 @@ def test_broadcast(): | |||
F.broadcast_to(x, (1, 3)) | |||
def test_utils_astensor1d(): | |||
reference = tensor(0) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_utils_astensor1d(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
reference = make_tensor(0, network) | |||
# literal | |||
x = [1, 2, 3] | |||
for dtype in [None, "float32"]: | |||
xx = astensor1d(x, reference, dtype=dtype) | |||
assert type(xx) is tensor | |||
assert isinstance(xx, type(reference)) | |||
np.testing.assert_equal(xx.numpy(), x) | |||
# numpy array | |||
x = np.asarray([1, 2, 3], dtype="int32") | |||
for dtype in [None, "float32"]: | |||
xx = astensor1d(x, reference, dtype=dtype) | |||
assert type(xx) is tensor | |||
assert isinstance(xx, type(reference)) | |||
np.testing.assert_equal(xx.numpy(), x.astype(dtype) if dtype else x) | |||
# tensor | |||
x = tensor([1, 2, 3], dtype="int32") | |||
x = make_tensor([1, 2, 3], network) | |||
for dtype in [None, "float32"]: | |||
xx = astensor1d(x, reference, dtype=dtype) | |||
assert type(xx) is tensor | |||
assert isinstance(xx, type(reference)) | |||
np.testing.assert_equal(xx.numpy(), x.numpy()) | |||
# mixed | |||
x = [1, tensor(2), 3] | |||
x = [1, make_tensor(2, network), 3] | |||
for dtype in [None, "float32"]: | |||
xx = astensor1d(x, reference, dtype=dtype) | |||
assert type(xx) is tensor | |||
assert isinstance(xx, type(reference)) | |||
np.testing.assert_equal(xx.numpy(), [1, 2, 3]) | |||
@@ -382,35 +494,60 @@ def test_device(): | |||
np.testing.assert_almost_equal(y5.numpy(), y6.numpy()) | |||
def test_identity(): | |||
x = tensor(np.random.random((5, 10)).astype(np.float32)) | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_identity(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
x = make_tensor(np.random.random((5, 10)).astype(np.float32), network) | |||
y = F.copy(x) | |||
np.testing.assert_equal(y.numpy(), x) | |||
def copy_test(dst, src): | |||
def copy_test(dst, src, network): | |||
data = np.random.random((2, 3)).astype(np.float32) | |||
x = tensor(data, device=src) | |||
x = make_tensor(data, device=src, network=network) | |||
y = F.copy(x, dst) | |||
assert np.allclose(data, y.numpy()) | |||
z = x.to(dst) | |||
assert np.allclose(data, z.numpy()) | |||
if network is None: | |||
z = x.to(dst) | |||
assert np.allclose(data, z.numpy()) | |||
@pytest.mark.require_ngpu(1) | |||
def test_copy_h2d(): | |||
copy_test("cpu0", "gpu0") | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_copy_h2d(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
copy_test("cpu0", "gpu0", network=network) | |||
@pytest.mark.require_ngpu(1) | |||
def test_copy_d2h(): | |||
copy_test("gpu0", "cpu0") | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_copy_d2h(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
copy_test("gpu0", "cpu0", network=network) | |||
@pytest.mark.require_ngpu(2) | |||
def test_copy_d2d(): | |||
copy_test("gpu0", "gpu1") | |||
copy_test("gpu0:0", "gpu0:1") | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_copy_d2d(is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
copy_test("gpu0", "gpu1", network=network) | |||
copy_test("gpu0:0", "gpu0:1", network=network) | |||
@pytest.mark.parametrize( | |||
@@ -425,7 +562,13 @@ def test_copy_d2d(): | |||
((), 10, None), | |||
], | |||
) | |||
def test_repeat(shape, repeats, axis): | |||
@pytest.mark.parametrize("is_varnode", [True, False]) | |||
def test_repeat(shape, repeats, axis, is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
def repeat_func(inp): | |||
return F.repeat(inp=inp, repeats=repeats, axis=axis) | |||
@@ -437,7 +580,10 @@ def test_repeat(shape, repeats, axis): | |||
cases = [{"input": np.array(1.23)}] | |||
opr_test( | |||
cases, repeat_func, ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||
cases, | |||
repeat_func, | |||
ref_fn=lambda inp: np.repeat(inp, repeats, axis), | |||
network=network, | |||
) | |||
@@ -450,14 +596,16 @@ def test_repeat(shape, repeats, axis): | |||
((2, 3, 4, 5), (2, 2, 2, 2, 2, 2, 2)), | |||
], | |||
) | |||
def test_tile(shape, reps): | |||
@pytest.mark.parametrize("is_varnode", [True]) | |||
def test_tile(shape, reps, is_varnode): | |||
if is_varnode: | |||
network = Network() | |||
else: | |||
network = None | |||
def tile_func(inp): | |||
return F.tile(inp=inp, reps=reps) | |||
cases = [ | |||
{"input": np.random.randn(*shape).astype("float32")}, | |||
] | |||
cases = [{"input": np.random.randn(*shape).astype("float32")}] | |||
opr_test( | |||
cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), | |||
) | |||
opr_test(cases, tile_func, ref_fn=lambda inp: np.tile(inp, reps), network=network) |
@@ -34,13 +34,11 @@ def test_replace_var(): | |||
vara = graph.var_filter.name("a").as_unique() | |||
varb = graph.var_filter.name("b").as_unique() | |||
out = F.mul(vara.var, varb.var) | |||
out = F.mul(vara, varb) | |||
out = F.relu(out) | |||
var_list = graph.add_dep_oprs(out) | |||
opnode = list(graph.opr_filter.has_input(vara)) | |||
repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||
repl_dict = {opnode[0].outputs[0]: out} | |||
graph.replace_vars(repl_dict) | |||
modified_model = io.BytesIO() | |||
@@ -72,14 +70,12 @@ def test_replace_opr(): | |||
vara = graph.var_filter.name("a").as_unique() | |||
varb = graph.var_filter.name("b").as_unique() | |||
out1 = F.sub(vara.var, varb.var) | |||
out1 = F.sub(vara, varb) | |||
out1 = F.relu(out1) | |||
var_list = graph.add_dep_oprs(out1) | |||
repl_opr = as_oprnode(var_list) | |||
out1 = graph.add_dep_oprs(out1) | |||
orig_opr = graph.opr_filter.has_input(vara).as_unique() | |||
repl_dict = {orig_opr: repl_opr} | |||
repl_dict = {orig_opr: out1[0].owner} | |||
graph.replace_oprs(repl_dict) | |||
modified_model1 = io.BytesIO() | |||
graph.dump(modified_model1) | |||
@@ -171,8 +167,7 @@ def test_add_input(): | |||
inp_c = graph.make_input_node((2,), np.int32, name="c") | |||
varo = graph.var_filter.name("o").as_unique() | |||
out = F.add(varo.var, inp_c.var) | |||
out = graph.add_dep_oprs(out)[0] | |||
out = F.add(varo, inp_c) | |||
out.name = "o1" | |||
graph.remove_output(varo) | |||
graph.add_output(out) | |||
@@ -206,12 +201,11 @@ def test_add_output(): | |||
var_a = net.var_filter.name("a").as_unique() | |||
var_b = net.var_filter.name("b").as_unique() | |||
y = F.add(var_a.var, var_b.var) | |||
y = F.add(var_a, var_b) | |||
y = F.sigmoid(y) | |||
new_vars = net.add_dep_oprs(y)[0] | |||
new_vars.name = "o1" | |||
net.add_output(new_vars) | |||
y.name = "o1" | |||
net.add_output(y) | |||
modified_model = io.BytesIO() | |||
net.dump(modified_model) | |||