@@ -7,6 +7,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import os | import os | ||||
import re | |||||
from .core._imperative_rt.common import CompNode, DeviceType | from .core._imperative_rt.common import CompNode, DeviceType | ||||
from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config | from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config | ||||
@@ -22,10 +23,8 @@ __all__ = [ | |||||
def _valid_device(inp): | def _valid_device(inp): | ||||
if isinstance(inp, str) and len(inp) == 4: | |||||
if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu": | |||||
if inp[3] == "x" or inp[3].isdigit(): | |||||
return True | |||||
if isinstance(inp, str) and re.match("^[cxg]pu(\d+|\d+:\d+|x)$", inp): | |||||
return True | |||||
return False | return False | ||||
@@ -14,7 +14,7 @@ from .core import Tensor as _Tensor | |||||
from .core.ops.builtin import Copy | from .core.ops.builtin import Copy | ||||
from .core.tensor.core import apply | from .core.tensor.core import apply | ||||
from .core.tensor.raw_tensor import as_device | from .core.tensor.raw_tensor import as_device | ||||
from .device import get_default_device | |||||
from .device import _valid_device, get_default_device | |||||
from .utils.deprecation import deprecated | from .utils.deprecation import deprecated | ||||
@@ -37,6 +37,12 @@ class Tensor(_Tensor): | |||||
self *= 0 | self *= 0 | ||||
def to(self, device): | def to(self, device): | ||||
if isinstance(device, str) and not _valid_device(device): | |||||
raise ValueError( | |||||
"invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format( | |||||
device | |||||
) | |||||
) | |||||
cn = as_device(device).to_c() | cn = as_device(device).to_c() | ||||
return apply(Copy(comp_node=cn), self)[0] | return apply(Copy(comp_node=cn), self)[0] | ||||