remove core.tensor, raw_tensor,TensorWrapper
avoid create tensor with zero-stride numpy ndarray
GitOrigin-RevId: 4fe5c4c5ba
release-1.2
@@ -9,5 +9,4 @@ | |||
import os | |||
import sys | |||
from .tensor import Tensor | |||
from .tensor.megbrain_graph import Graph |
@@ -27,8 +27,6 @@ from ..ops.builtin import ( | |||
from ..ops.special import Const | |||
from ..tensor.core import apply | |||
from ..tensor.function import Function | |||
from ..tensor.tensor import Tensor | |||
from ..tensor.tensor_wrapper import TensorWrapper | |||
@functools.singledispatch | |||
@@ -21,7 +21,6 @@ from ..ops.builtin import Elemwise, OpDef, RemoteSend | |||
from ..ops.special import Const | |||
from ..tensor.core import TensorBase, TensorWrapperBase, apply | |||
from ..tensor.function import Function | |||
from ..tensor.tensor import Tensor, get_context | |||
from . import builtin_op_utils | |||
""" Some notes: | |||
@@ -65,238 +64,6 @@ def get_tensor(x): | |||
return get_tensor(x) | |||
class Grad: | |||
def __init__(self, name=None): | |||
if name is None: | |||
global _grad_count | |||
self._name = "grad_" + str(_grad_count) | |||
_grad_count += 1 | |||
else: | |||
self._name = name | |||
assert self._name not in _grad_manager_dict, "grad manager name duplicated" | |||
_grad_manager_dict[self._name] = self | |||
# list of all x in partial(y) / partial(x) | |||
self.xs = [] | |||
# constains weak reference of all OpNode during forward | |||
# OpNode contains inputs, outputs and its backward | |||
# ops forms the computational graph | |||
self.ops = [] | |||
# save remote_send output for backward | |||
self.remote_send_cache = [] | |||
self._attached_tensors = weakref.WeakSet() | |||
self._enabled = True | |||
@property | |||
def name(self): | |||
return self._name | |||
def wrt(self, *args: Tensor, callback=None): | |||
""" Indicates the loss is a function of the input tensors (usually the net trainable parameters), | |||
i.e., d (loss) / d (Tensor) != 0 | |||
callback is used to perform additional operations after gradient is obtained in backward. | |||
e.g., copy the grad to a particular place | |||
A VariableNode will be created and saved in the tensor/s _extra_data slot. | |||
""" | |||
for x in map(get_tensor, args): | |||
v = self._new_variable(x, callback=callback) | |||
assert self not in x._extra_data | |||
x._extra_data[self] = Tracer(v) | |||
self.xs.append(v) | |||
return self | |||
def _new_variable(self, owner, opnode=None, callback=None): | |||
self._attached_tensors.add(owner) | |||
return VariableNode(self, owner, opnode=opnode, callback=callback) | |||
def _new_opnode(self, inputs, outputs): | |||
inputs = tuple(inputs) | |||
for i in inputs: | |||
assert i is None or isinstance(i, VariableNode) | |||
o = OpNode() | |||
o.inputs = inputs | |||
o.outputs = [] | |||
tracers = [] | |||
for i in outputs: | |||
assert isinstance(i, Tensor) | |||
v = self._new_variable(i, o) | |||
o.outputs.append(weakref.ref(v)) | |||
tracers.append(Tracer(v)) | |||
self.ops.append(weakref.ref(o)) | |||
return o, tracers | |||
def copy(self): | |||
raise NotImplementedError | |||
def __enter__(self): | |||
return self | |||
def _exit(self): | |||
"""clear all resources""" | |||
self._enabled = False | |||
for o in self.ops: | |||
o = o() | |||
if o: | |||
o.clear() | |||
for i in self._attached_tensors: | |||
i._extra_data.pop(self, None) | |||
self.remote_send_cache = [] | |||
def __exit__(self, *_): | |||
self._exit() | |||
def __call__(self, ys, dys): | |||
""" Defines Grad(). | |||
:param ys: outputs of forward operators, e.g., the loss tensor | |||
:type ys: list of Tensor or TensorWrapperBase | |||
:param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss, | |||
e.g., one for the loss itself | |||
:type dys: list of Tensor or TensorWrapperBase | |||
""" | |||
assert self._enabled | |||
self._enabled = False | |||
def check_wrapper(): | |||
if isinstance(dys, TensorWrapperBase): | |||
return type(dys) | |||
if isinstance(dys, TensorBase): | |||
return | |||
assert isinstance(dys, (tuple, list)) | |||
for i in dys: | |||
if isinstance(i, TensorWrapperBase): | |||
return type(i) | |||
# use Tensor as defualt wrapper | |||
return mge.Tensor | |||
Wrapper = check_wrapper() | |||
def aslist(x): | |||
if isinstance(x, (Tensor, TensorWrapperBase)): | |||
x = [x] | |||
else: | |||
x = list(x) | |||
x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x] | |||
for i in x: | |||
assert isinstance(i, Tensor) | |||
return x | |||
ys = aslist(ys) | |||
dys = aslist(dys) | |||
assert len(ys) == len(dys) | |||
ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()] | |||
ys = [y for i, y in enumerate(ys) if i in ids] | |||
dys = [dy for i, dy in enumerate(dys) if i in ids] | |||
# ys is changed to a list of VariableNode which contains more information | |||
# such as OpNode, callback, etc. | |||
ys = [i._extra_data[self].node for i in ys] | |||
# NOTE: callback is called only if grad is not None | |||
# the OpNode sequence in backward | |||
op_seq = [] | |||
# VariableNode -> (i, j), where i is time stamp in backward, j means jth input | |||
last_written_to = {} | |||
def schedule(): | |||
reached = set(ys) | |||
# i is the time stamp in backward | |||
i = 0 | |||
for o in self.ops[::-1]: | |||
o = o() | |||
if o is None: | |||
continue | |||
if not o.has_grad_fn(o, reached): | |||
continue | |||
op_seq.append(o) | |||
for j, v in enumerate(o.inputs): | |||
reached.add(v) | |||
last_written_to[v] = i, j | |||
i += 1 | |||
schedule() | |||
# VariableNode -> Tensor | |||
cache = {} | |||
def initialize(): | |||
for y, dy in zip(ys, dys): | |||
cache[y] = dy | |||
if y not in last_written_to and y.callback: | |||
y.callback(y.owner(), dy) | |||
initialize() | |||
# NOTE: None is used to mark a node has been consumed | |||
for seqno, opnode in enumerate(op_seq): | |||
input_nodes = opnode.inputs | |||
output_nodes = [i() for i in opnode.outputs] | |||
backward = opnode.backward | |||
backward_allow_noinput = opnode.backward_allow_noinput | |||
opnode.clear() | |||
output_grads = [] | |||
for i in output_nodes: | |||
if i is not None: | |||
if i in cache: | |||
assert cache[i] is not None | |||
output_grads.append(cache[i]) | |||
else: | |||
output_grads.append(None) | |||
# read by backward, mark consumed | |||
cache[i] = None | |||
else: | |||
output_grads.append(None) | |||
if ( | |||
any([grad is not None for grad in output_grads]) | |||
or backward_allow_noinput | |||
): | |||
input_grads = backward(*output_grads) | |||
else: | |||
input_grads = [None] * len(input_nodes) | |||
assert len(input_nodes) == len(input_grads) | |||
for i, (v, g) in enumerate(zip(input_nodes, input_grads)): | |||
if v is None: | |||
continue | |||
if v in cache: | |||
assert cache[v] | |||
if g is not None: | |||
cache[v] = add(cache[v], g) | |||
elif g is not None: | |||
cache[v] = g | |||
if last_written_to[v] == (seqno, i): | |||
if v.callback: | |||
v.callback( | |||
v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | |||
) | |||
if v.opnode is None: | |||
# won't read by backward, mark consumed | |||
cache[v] = None | |||
for v in cache.values(): | |||
assert v is None | |||
self._exit() | |||
def __del__(self): | |||
self._exit() | |||
class clearable: | |||
__cleared = False | |||
@@ -10,11 +10,6 @@ import warnings | |||
from typing import Union | |||
from ..._imperative_rt import OpDef, ops | |||
from ...tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
# register OpDef as a "virtual subclass" of OpBase, so any of registered | |||
# apply(OpBase, ...) rules could work well on OpDef | |||
OpBase.register(OpDef) | |||
__all__ = ["OpDef"] | |||
@@ -6,4 +6,3 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .tensor_wrapper import TensorWrapper as Tensor |
@@ -13,17 +13,9 @@ import sys | |||
import typing | |||
from abc import ABC | |||
from .._imperative_rt.core2 import apply as apply2 | |||
from .multipledispatch import Dispatcher | |||
def apply_op(op, *args): | |||
Wrapper = type(args[0]) | |||
args = [arg._tensor for arg in args] | |||
results = apply2(op, *args) | |||
return tuple(map(Wrapper, results)) | |||
class OpBase(ABC): | |||
def __call__(self, *args): | |||
return apply(self, *args) | |||
@@ -7,9 +7,6 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..ops.builtin import OpDef | |||
from .core import TensorBase, TensorWrapperBase, apply | |||
from .raw_tensor import RawTensor | |||
from .tensor import Tensor, push_context | |||
from .tensor_wrapper import TensorWrapper | |||
class Function: | |||
@@ -155,13 +152,3 @@ def _(op: Function, *args: TensorWrapperBase): | |||
t._extra_data[k] = i | |||
return tuple(map(Wrapper, outputs)) | |||
@apply.register() | |||
def _(op: Function, *args: Tensor): | |||
raise NotImplementedError | |||
@apply.register() | |||
def _(op: Function, *args: RawTensor): | |||
raise NotImplementedError |
@@ -1,117 +0,0 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import contextlib | |||
import copy | |||
from .core import Dispatcher, OpBase, TensorBase, apply | |||
class Tensor(TensorBase): | |||
def __init__(self, data: TensorBase): | |||
self._data = data | |||
# _extra_data is set up in Grad.wrt | |||
self._extra_data = {} | |||
self._user_data = {} | |||
def __getattr__(self, name): | |||
if name in self._user_data: | |||
return self._user_data[name] | |||
raise AttributeError(name) | |||
def reset(self, other): | |||
assert isinstance(other, __class__) | |||
self.__dict__.clear() | |||
self._data = other.data | |||
self._extra_data = other._extra_data.copy() | |||
self._user_data = other._user_data.copy() | |||
def copy(self): | |||
other = object.__new__(type(self)) | |||
other.reset(self) | |||
return other | |||
# tensor interface | |||
@property | |||
def shape(self): | |||
return self._data.shape | |||
@property | |||
def dtype(self): | |||
return self._data.dtype | |||
@property | |||
def device(self): | |||
return self._data.device | |||
def numpy(self): | |||
return self._data.numpy() | |||
def _drop(self): | |||
self._data._drop() | |||
def _swap_in(self): | |||
self._data._swap_in() | |||
def _swap_out(self): | |||
self._data._swap_out() | |||
class ApplyContext: | |||
__slots__ = ("inputs", "outputs", "key") | |||
def __init__(self): | |||
self.inputs = None | |||
self.outputs = None | |||
self.key = None | |||
_context = None | |||
@contextlib.contextmanager | |||
def push_context(): | |||
global _context | |||
backup = _context | |||
try: | |||
_context = ApplyContext() | |||
yield _context | |||
finally: | |||
_context = backup | |||
def get_context(): | |||
return _context | |||
@apply.register() | |||
def tensor_apply(op: OpBase, *args: Tensor): | |||
data = tuple(i._data for i in args) | |||
# type(Tensor._data) is RawTensor | |||
# dispached to apply.add@RawTensor.py if passed Tensor args | |||
outputs = apply(op, *data) | |||
ret = tuple(map(Tensor, outputs)) | |||
with push_context() as ctx: | |||
ctx.inputs = args | |||
ctx.outputs = ret | |||
for k in set().union(*(i._extra_data for i in args)): | |||
ctx.key = k | |||
data = tuple( | |||
i._extra_data.get(k) if isinstance(i, Tensor) else i for i in args | |||
) | |||
# data are instances of Tracer | |||
# dispatched to apply.add@grad.py | |||
outputs = apply(op, *data) | |||
if outputs is not None: | |||
assert len(outputs) == len(ret) | |||
for t, i in zip(ret, outputs): | |||
t._extra_data[k] = i | |||
return ret |
@@ -19,7 +19,6 @@ from ..ops import builtin | |||
from ..ops.builtin import Elemwise, GetVarShape | |||
from ..ops.special import Const | |||
from . import utils | |||
from .core import OpBase, TensorBase, TensorWrapperBase | |||
from .indexing import getitem as _getitem | |||
from .indexing import setitem as _setitem | |||
from .utils import isscalar | |||
@@ -439,98 +438,3 @@ class ArrayMethodMixin(abc.ABC): | |||
min = _reduce("MIN") | |||
max = _reduce("MAX") | |||
mean = _reduce("MEAN") | |||
class GenericTensorWrapper(ArrayMethodMixin, TensorWrapperBase): | |||
def __init__(self, data): | |||
self.__wrapped__ = data | |||
def _reset(self, other): | |||
if not isinstance(other, __class__): | |||
raise TypeError(type(other)) | |||
self.__wrapped__ = other.__wrapped__ | |||
return self | |||
@property | |||
def dtype(self): | |||
return self.__wrapped__.dtype | |||
@property | |||
def shape(self): | |||
shape = self.__wrapped__.shape | |||
if shape == () or not use_symbolic_shape(): | |||
return shape | |||
return apply(GetVarShape(), self)[0] | |||
@property | |||
def device(self): | |||
return self.__wrapped__.device | |||
def numpy(self): | |||
return self.__wrapped__.numpy() | |||
def _drop(self): | |||
self.__wrapped__._drop() | |||
def _swap_in(self): | |||
self.__wrapped__._swap_in() | |||
def _swap_out(self): | |||
self.__wrapped__._swap_out() | |||
class TensorWrapper(ArrayMethodMixin, TensorBase): | |||
def __init__(self, data, dtype=None, device=None, isscalar=False): | |||
self._isscalar = isscalar | |||
if isinstance(data, Tensor): | |||
self._tensor = data | |||
else: | |||
if device is None: | |||
device = CompNode._get_default_device() | |||
self._tensor = Tensor(data, dtype, device) | |||
def _reset(self, other): | |||
if not isinstance(other, __class__): | |||
raise TypeError(type(other)) | |||
self._tensor = other._tensor | |||
return self | |||
@property | |||
def dtype(self): | |||
return self._tensor.dtype | |||
@property | |||
def shape(self): | |||
if self._isscalar: | |||
return () | |||
shape = self._tensor.shape | |||
if shape == () or not use_symbolic_shape(): | |||
return shape | |||
return apply(GetVarShape(), self)[0] | |||
@property | |||
def device(self): | |||
return self._tensor.device | |||
def numpy(self): | |||
if self._isscalar: | |||
return self._tensor.numpy().squeeze() | |||
return self._tensor.numpy() | |||
def _drop(self): | |||
self._tensor._drop() | |||
def _swap_in(self): | |||
self._tensor._swap_in() | |||
def _swap_out(self): | |||
self._tensor._swap_out() | |||
def __repr__(self): | |||
piece = "Tensor(" | |||
with np.printoptions(precision=4, suppress=True): | |||
piece += "{}".format(str(self.numpy())) | |||
if self.dtype != np.float32: | |||
piece += ", dtype={}".format(np.dtype(self.dtype).name) | |||
piece += ", device={}".format(self.device) + ")" | |||
return piece |
@@ -18,9 +18,8 @@ from ..core.autodiff.grad import ( | |||
tracer_apply, | |||
) | |||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
from ..core.tensor.tensor import Tensor, tensor_apply | |||
from ..device import get_default_device | |||
from ..tensor import tensor | |||
from ..tensor import Tensor | |||
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank | |||
__all__ = [ | |||
@@ -16,7 +16,6 @@ from ..core._imperative_rt.core2 import apply | |||
from ..core.ops import builtin | |||
from ..core.ops.special import Const | |||
from ..core.tensor import utils | |||
from ..core.tensor.core import TensorBase, TensorWrapperBase | |||
from ..tensor import Tensor | |||
from .elemwise import clip, exp, log, log1p | |||
from .tensor import reshape, squeeze | |||
@@ -703,7 +702,7 @@ def topk( | |||
mode = "VALUE_IDX_SORTED" | |||
op = builtin.TopK(mode=mode) | |||
if not isinstance(k, (TensorBase, TensorWrapperBase)): | |||
if not isinstance(k, Tensor): | |||
(k,) = Const(k, dtype="int32", device=inp.device)(inp) | |||
if len(inp.shape) == 1: | |||
@@ -14,7 +14,7 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union | |||
import numpy as np | |||
from ..core._imperative_rt import CompNode | |||
from ..core._imperative_rt.core2 import Tensor, apply | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._wrap import device as as_device | |||
from ..core.ops import builtin | |||
from ..core.ops.special import Const | |||
@@ -19,6 +19,7 @@ import weakref | |||
import numpy as np | |||
from ..core._imperative_rt import GraphProfiler | |||
from ..core._imperative_rt.core2 import Tensor | |||
from ..core._imperative_rt.ops import ( | |||
CollectiveComm, | |||
GaussianRNG, | |||
@@ -32,7 +33,6 @@ from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph as G | |||
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor | |||
from ..core.tensor.tensor import Tensor | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
@@ -10,7 +10,6 @@ from typing import Iterable, Union | |||
import numpy as np | |||
from ..core.tensor.tensor import Tensor | |||
from ..tensor import Parameter, tensor | |||
from .optimizer import Optimizer | |||
@@ -10,7 +10,6 @@ from typing import Iterable, Union | |||
import numpy as np | |||
from ..core.tensor.tensor import Tensor | |||
from ..tensor import Parameter, tensor | |||
from .optimizer import Optimizer | |||
@@ -8,7 +8,6 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Iterable, Tuple, Union | |||
from ..core.tensor.tensor import Tensor | |||
from ..tensor import Parameter, tensor | |||
from .optimizer import Optimizer | |||
@@ -8,7 +8,6 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Iterable, Union | |||
from ..core.tensor.tensor import Tensor | |||
from ..tensor import Parameter, tensor | |||
from .optimizer import Optimizer | |||
@@ -16,8 +16,8 @@ from .core._imperative_rt import CompNode | |||
from .core._imperative_rt.core2 import Tensor as _Tensor | |||
from .core._imperative_rt.core2 import apply | |||
from .core._trace_option import use_symbolic_shape | |||
from .core._wrap import device as as_device | |||
from .core.ops.builtin import Copy, GetVarShape | |||
from .core.tensor.raw_tensor import as_device | |||
from .core.tensor.tensor_wrapper import ArrayMethodMixin | |||
from .device import _valid_device, get_default_device | |||
from .utils.deprecation import deprecated | |||
@@ -43,6 +43,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
if isinstance(data, _Tensor): | |||
obj = _Tensor.__new__(cls, data) | |||
else: | |||
if isinstance(data, np.ndarray): | |||
if 0 in data.strides: | |||
data = data.squeeze().reshape(data.shape) | |||
obj = _Tensor.__new__(cls, data, dtype, cn) | |||
return obj | |||
@@ -13,7 +13,7 @@ import numpy | |||
from ..core import _imperative_rt | |||
from ..core._imperative_rt import OperatorNode, VarNode | |||
from ..core.tensor import megbrain_graph as G | |||
from ..core.tensor.raw_tensor import as_raw_tensor | |||
from ..tensor import Tensor | |||
__all__ = [ | |||
"get_dep_vars", | |||
@@ -309,7 +309,7 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n | |||
cg = new_out_list[0].graph | |||
func = cg.compile(new_out_list) | |||
for node, value in zip(inp_node_list, inp_data_list): | |||
node.set_value(as_raw_tensor(value)._dev_tensor()) | |||
node.set_value(Tensor(value)._dev_tensor()) | |||
func.execute() | |||
out_data_list = [o.get_value().numpy() for o in out_node_list] | |||
return out_data_list |
@@ -13,7 +13,7 @@ import megengine.functional as F | |||
from megengine import Parameter, optimizer | |||
from megengine.jit import trace | |||
from megengine.module import Linear, Module | |||
from megengine.tensor import tensor | |||
from megengine.tensor import Tensor | |||
class MLP(Module): | |||
@@ -54,7 +54,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
for group in opt.param_groups: | |||
group["lr"] += 0.01 | |||
check_func.lr += 0.01 | |||
data = tensor(np.random.random(data_shape).astype(np.float32)) | |||
data = Tensor(np.random.random(data_shape).astype(np.float32)) | |||
opt.clear_grad() | |||
with gm: | |||
@@ -98,7 +98,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): | |||
ori_params[param] = np.copy(param.numpy()) | |||
train_func( | |||
tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm | |||
Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm | |||
) | |||
step += 1 | |||
check_func(ori_params, net.parameters(), step) | |||
@@ -11,7 +11,7 @@ import pickle | |||
import numpy as np | |||
from megengine.core.tensor.dtype import bfloat16 | |||
from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
from megengine.tensor import Tensor | |||
def test_define(): | |||
@@ -42,14 +42,14 @@ def test_cast(): | |||
def test_shared_nd(): | |||
data = np.array([-3.4, 1.394683, 2.323497, -7.439948, -5.2397], dtype=bfloat16) | |||
snd = as_raw_tensor(data, dtype=bfloat16, device="xpux") | |||
snd = Tensor(data, dtype=bfloat16, device="xpux") | |||
assert snd.numpy().dtype == bfloat16 | |||
np.testing.assert_allclose( | |||
snd.numpy(), [-3.40625, 1.398438, 2.328125, -7.4375, -5.25], atol=1e-6 | |||
) | |||
data = np.array([-9.34964, -8.342, 9.4385, 0.18746, 1.48], dtype=bfloat16) | |||
snd = as_raw_tensor(data, dtype=bfloat16, device="xpux") | |||
snd = Tensor(data, dtype=bfloat16, device="xpux") | |||
np.testing.assert_allclose( | |||
snd.numpy(), [-9.375, -8.3125, 9.4375, 0.1875, 1.476562], atol=1e-6 | |||
) | |||
@@ -12,7 +12,7 @@ import numpy as np | |||
import pytest | |||
from megengine.core.tensor.dtype import intb1, intb2, intb4 | |||
from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
from megengine.tensor import Tensor | |||
def bit_define_test(bit, low_bit_type): | |||
@@ -78,11 +78,11 @@ def _shared_nd_test(bit, low_bit_type): | |||
min_value = 1 - (1 << bit) | |||
data = np.arange(min_value, max_value + 2, 2, dtype=low_bit_type) | |||
snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux") | |||
snd = Tensor(data, dtype=low_bit_type, device="xpux") | |||
np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 2)) | |||
data = np.arange(min_value, max_value + 2, 4, dtype=low_bit_type) | |||
snd = as_raw_tensor(data, dtype=low_bit_type, device="xpux") | |||
snd = Tensor(data, dtype=low_bit_type, device="xpux") | |||
np.testing.assert_allclose(snd.numpy(), range(min_value, max_value + 2, 4)) | |||
@@ -32,8 +32,8 @@ from megengine.core.tensor.dtype import ( | |||
quint4, | |||
quint8, | |||
) | |||
from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.tensor import Tensor | |||
def test_dtype_quint8(): | |||
@@ -71,7 +71,7 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None): | |||
temp_rst = calc_func(inp_node.outputs[0]) | |||
oup_node = G.OutputNode(temp_rst) | |||
func = graph.compile(oup_node.outputs[0]) | |||
inp_node.set_value(as_raw_tensor(inp, dtype=dtype, device=device)._dev_tensor()) | |||
inp_node.set_value(Tensor(inp, dtype=dtype, device=device)._dev_tensor()) | |||
func.execute() | |||
return oup_node.get_value().numpy() | |||
@@ -9,15 +9,15 @@ | |||
import numpy as np | |||
import pytest | |||
import megengine.core.tensor.raw_tensor | |||
from megengine.core.tensor.core import apply | |||
import megengine | |||
from megengine.core._imperative_rt.core2 import apply | |||
from megengine.tensor import Tensor | |||
def elemwise(*args, mode): | |||
from megengine.core._imperative_rt.imperative import apply_op | |||
from megengine.core.ops.builtin import Elemwise | |||
return apply_op(Elemwise(mode), args) | |||
return apply(Elemwise(mode), *args) | |||
def test_basic_interface(): | |||
@@ -44,11 +44,11 @@ def test_simple_arith(): | |||
from megengine.core.ops.builtin import Elemwise | |||
x = np.random.rand(10).astype("float32") | |||
xx = megengine.core._imperative_rt.put(x) | |||
xx = Tensor(x) | |||
(yy,) = elemwise(xx, xx, mode=Elemwise.Mode.MUL) | |||
np.testing.assert_allclose(x * x, megengine.core._imperative_rt.get_value(yy)) | |||
megengine.core._imperative_rt.delete(xx) | |||
megengine.core._imperative_rt.delete(yy) | |||
np.testing.assert_allclose(x * x, yy.numpy()) | |||
del xx | |||
del yy | |||
def test_tensor_on_device(): | |||
@@ -62,10 +62,9 @@ def test_tensor_on_device(): | |||
def test_raw_tensor(): | |||
from megengine.core.ops.builtin import Elemwise | |||
from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
x = np.random.rand(10).astype("float32") | |||
xx = as_raw_tensor(x) | |||
xx = Tensor(x) | |||
(yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) | |||
np.testing.assert_allclose(x * x, yy.numpy()) | |||
(yy,) = apply(Elemwise(Elemwise.Mode.MUL), xx, xx) | |||
@@ -12,10 +12,10 @@ import numpy as np | |||
import pytest | |||
import megengine | |||
import megengine.tensor as Tensor | |||
from megengine.core._imperative_rt.core2 import apply | |||
from megengine.core._trace_option import use_symbolic_shape | |||
from megengine.core.ops import builtin | |||
from megengine.tensor import Tensor | |||
def cvt_to_shape_desc(val, inpvar, config=None): | |||
@@ -8,8 +8,6 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import pytest | |||
from megengine.core import Tensor | |||
# from megengine.core.interpreter.hints import function | |||
@@ -11,8 +11,8 @@ from concurrent.futures import Future | |||
import numpy as np | |||
import megengine.functional as F | |||
import megengine.tensor as Tensor | |||
from megengine.core.tensor import megbrain_graph as mgb_graph | |||
from megengine.tensor import Tensor | |||
def test_io(): | |||
@@ -9,12 +9,12 @@ | |||
import numpy as np | |||
import megengine.functional as F | |||
from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
from megengine.tensor import Tensor | |||
def test_as_raw_tensor(): | |||
x = np.arange(6, dtype="float32").reshape(2, 3) | |||
xx = as_raw_tensor(x, device="xpux") | |||
xx = Tensor(x, device="xpux") | |||
yy = F.add(xx, 1).numpy() | |||
assert xx.dtype == np.float32 | |||
assert xx.device == "xpux" | |||
@@ -23,7 +23,7 @@ def test_as_raw_tensor(): | |||
def test_as_raw_tensor_from_int64(): | |||
x = np.arange(6, dtype="int64").reshape(2, 3) | |||
xx = as_raw_tensor(x, dtype="float32", device="xpux") | |||
xx = Tensor(x, dtype="float32", device="xpux") | |||
yy = F.add(xx, 1).numpy() | |||
assert xx.dtype == np.float32 | |||
assert xx.device == "xpux" | |||