|
|
@@ -14,7 +14,7 @@ from .core import Tensor as _Tensor |
|
|
|
from .core.ops.builtin import Copy |
|
|
|
from .core.tensor.core import apply |
|
|
|
from .core.tensor.raw_tensor import as_device |
|
|
|
from .device import get_default_device |
|
|
|
from .device import _valid_device, get_default_device |
|
|
|
from .utils.deprecation import deprecated |
|
|
|
|
|
|
|
|
|
|
@@ -37,6 +37,12 @@ class Tensor(_Tensor): |
|
|
|
self *= 0 |
|
|
|
|
|
|
|
def to(self, device): |
|
|
|
if isinstance(device, str) and not _valid_device(device): |
|
|
|
raise ValueError( |
|
|
|
"invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format( |
|
|
|
device |
|
|
|
) |
|
|
|
) |
|
|
|
cn = as_device(device).to_c() |
|
|
|
return apply(Copy(comp_node=cn), self)[0] |
|
|
|
|
|
|
|