Browse Source

feat(mge): add device name check

GitOrigin-RevId: d9910b6275
release-1.1
Megvii Engine Team 4 years ago
parent
commit
d225cbcdbe
2 changed files with 10 additions and 5 deletions
  1. +3
    -4
      imperative/python/megengine/device.py
  2. +7
    -1
      imperative/python/megengine/tensor.py

+ 3
- 4
imperative/python/megengine/device.py View File

@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
import re

from .core._imperative_rt.common import CompNode, DeviceType
from .core._imperative_rt.common import set_prealloc_config as _set_prealloc_config
@@ -22,10 +23,8 @@ __all__ = [


def _valid_device(inp):
if isinstance(inp, str) and len(inp) == 4:
if inp[0] in {"x", "c", "g"} and inp[1:3] == "pu":
if inp[3] == "x" or inp[3].isdigit():
return True
if isinstance(inp, str) and re.match("^[cxg]pu(\d+|\d+:\d+|x)$", inp):
return True
return False




+ 7
- 1
imperative/python/megengine/tensor.py View File

@@ -14,7 +14,7 @@ from .core import Tensor as _Tensor
from .core.ops.builtin import Copy
from .core.tensor.core import apply
from .core.tensor.raw_tensor import as_device
from .device import get_default_device
from .device import _valid_device, get_default_device
from .utils.deprecation import deprecated


@@ -37,6 +37,12 @@ class Tensor(_Tensor):
self *= 0

def to(self, device):
if isinstance(device, str) and not _valid_device(device):
raise ValueError(
"invalid device name {}. For the correct format of the device name, please refer to the instruction of megengine.device.set_default_device()".format(
device
)
)
cn = as_device(device).to_c()
return apply(Copy(comp_node=cn), self)[0]



Loading…
Cancel
Save