diff --git a/imperative/python/megengine/device.py b/imperative/python/megengine/device.py index f81893c2..2667536f 100644 --- a/imperative/python/megengine/device.py +++ b/imperative/python/megengine/device.py @@ -7,6 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import os +import re from .core._imperative_rt.common import CompNode, DeviceType from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config @@ -22,10 +23,8 @@ __all__ = [ def _valid_device(inp): - if isinstance(inp, str) and len(inp) == 4: - if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu": - if inp[3] == "x" or inp[3].isdigit(): - return True + if isinstance(inp, str) and re.match("^[cxg]pu(\d+|\d+:\d+|x)$", inp): + return True return False diff --git a/imperative/python/megengine/tensor.py b/imperative/python/megengine/tensor.py index a8eae821..0d30a264 100644 --- a/imperative/python/megengine/tensor.py +++ b/imperative/python/megengine/tensor.py @@ -14,7 +14,7 @@ from .core import Tensor as _Tensor from .core.ops.builtin import Copy from .core.tensor.core import apply from .core.tensor.raw_tensor import as_device -from .device import get_default_device +from .device import _valid_device, get_default_device from .utils.deprecation import deprecated @@ -37,6 +37,12 @@ class Tensor(_Tensor): self *= 0 def to(self, device): + if isinstance(device, str) and not _valid_device(device): + raise ValueError( + "invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format( + device + ) + ) cn = as_device(device).to_c() return apply(Copy(comp_node=cn), self)[0]