|
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import numpy as np
- import torch
-
- import megengine._internal as mgb
-
- _TORCH_NUMPY_MAPPING = {
- torch.float16: np.float16,
- torch.float32: np.float32,
- torch.float64: np.float64,
- torch.int8: np.int8,
- torch.int16: np.int16,
- torch.int32: np.int32,
- }
-
-
- def torch_dtype_to_numpy_dtype(torch_dtype: torch.dtype):
- """map torch dtype to numpy dtype
-
- :param torch_dtype: torch dtype
- :return: numpy dtype
- """
- if not isinstance(torch_dtype, torch.dtype):
- raise TypeError("Argument `torch_dtype` should be an instance of torch.dtype")
- if torch_dtype not in _TORCH_NUMPY_MAPPING:
- raise ValueError("Unknown PyTorch dtype: {}".format(torch_dtype))
- return _TORCH_NUMPY_MAPPING[torch_dtype]
-
-
- def torch_device_to_device(device: torch.device):
- """map torch device to device
-
- :param device: torch device
- :return: device
- """
- if not isinstance(device, torch.device):
- raise TypeError("Argument `device` should be an instance of torch.device")
- index = device.index
- if index is None:
- index = "x"
- if device.type == "cpu":
- return "cpu{}".format(index)
- elif device.type == "cuda":
- return "gpu{}".format(index)
- raise ValueError("Unknown PyTorch device: {}".format(device))
-
-
- def device_to_torch_device(device: mgb.CompNode):
- """map device to torch device
-
- :param device: megbrain compute node
- :return: corresponding torch device
- """
- t, d, _ = device.locator_physical
- if t == "CUDA":
- return torch.device("cuda", d)
- elif t == "CPU":
- return torch.device("cpu", d)
- else:
- raise Exception("Unsupported device type: {}".format(t))
|