Browse Source

feat(imperative): enable to() to copy to device

GitOrigin-RevId: f9caf17d24
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
44d0b5daf5
3 changed files with 8 additions and 1 deletions
  1. +1
    -1
      imperative/python/megengine/device.py
  2. +5
    -0
      imperative/python/megengine/tensor.py
  3. +2
    -0
      imperative/python/test/unit/functional/test_tensor.py

+ 1
- 1
imperative/python/megengine/device.py View File

@@ -70,7 +70,7 @@ def set_default_device(device: str = "xpux"):
multi-threading parallelism at the operator level. For example, multi-threading parallelism at the operator level. For example,
'multithread4' will compute with 4 threads. which implements 'multithread4' will compute with 4 threads. which implements


The default value is 'xpux' to specify any device available.
The default value is 'xpux' to specify any device available. The priority of using gpu is higher when both gpu and cpu are available.


It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
""" """


+ 5
- 0
imperative/python/megengine/tensor.py View File

@@ -11,6 +11,8 @@
import collections import collections


from .core import Tensor as _Tensor from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
from .device import get_default_device from .device import get_default_device




@@ -30,6 +32,9 @@ class Tensor(_Tensor):
def reset_zero(self): def reset_zero(self):
self *= 0 self *= 0


def to(self, cn):
return apply(Copy(comp_node=cn), self)[0]

def __getstate__(self): def __getstate__(self):
r""" __getstate__ will be called for pickle serialization or deep copy r""" __getstate__ will be called for pickle serialization or deep copy
""" """


+ 2
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -322,6 +322,8 @@ def copy_test(dst, src):
x = tensor(data, device=src) x = tensor(data, device=src)
y = F.copy(x, dst) y = F.copy(x, dst)
assert np.allclose(data, y.numpy()) assert np.allclose(data, y.numpy())
z = x.to(dst)
assert np.allclose(data, z.numpy())




@pytest.mark.skipif( @pytest.mark.skipif(


Loading…
Cancel
Save