GitOrigin-RevId: f9caf17d24
tags/v1.0.0-rc1
@@ -70,7 +70,7 @@ def set_default_device(device: str = "xpux"): | |||||
multi-threading parallelism at the operator level. For example, | multi-threading parallelism at the operator level. For example, | ||||
'multithread4' will compute with 4 threads. which implements | 'multithread4' will compute with 4 threads. which implements | ||||
The default value is 'xpux' to specify any device available. | |||||
The default value is 'xpux' to specify any device available. The priority of using gpu is higher when both gpu and cpu are available. | |||||
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. | ||||
""" | """ | ||||
@@ -11,6 +11,8 @@ | |||||
import collections | import collections | ||||
from .core import Tensor as _Tensor | from .core import Tensor as _Tensor | ||||
from .core.ops.builtin import Copy | |||||
from .core.tensor.core import apply | |||||
from .device import get_default_device | from .device import get_default_device | ||||
@@ -30,6 +32,9 @@ class Tensor(_Tensor): | |||||
def reset_zero(self): | def reset_zero(self): | ||||
self *= 0 | self *= 0 | ||||
def to(self, cn): | |||||
return apply(Copy(comp_node=cn), self)[0] | |||||
def __getstate__(self): | def __getstate__(self): | ||||
r""" __getstate__ will be called for pickle serialization or deep copy | r""" __getstate__ will be called for pickle serialization or deep copy | ||||
""" | """ | ||||
@@ -322,6 +322,8 @@ def copy_test(dst, src): | |||||
x = tensor(data, device=src) | x = tensor(data, device=src) | ||||
y = F.copy(x, dst) | y = F.copy(x, dst) | ||||
assert np.allclose(data, y.numpy()) | assert np.allclose(data, y.numpy()) | ||||
z = x.to(dst) | |||||
assert np.allclose(data, z.numpy()) | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||