GitOrigin-RevId: baef3d348c
release-1.2
@@ -8,7 +8,7 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import numpy as np | import numpy as np | ||||
from .._imperative_rt.core2 import Tensor | |||||
# from .._imperative_rt.core2 import Tensor | |||||
from ..tensor.core import OpBase, TensorBase, apply | from ..tensor.core import OpBase, TensorBase, apply | ||||
@@ -19,5 +19,10 @@ class Const: | |||||
self.device = device | self.device = device | ||||
def __call__(self, *reference): | def __call__(self, *reference): | ||||
Wrapper = type(reference[0]) | |||||
return (Wrapper(self.value, self.dtype, self.device, True),) | |||||
from ...tensor import Tensor | |||||
device = self.device | |||||
if device is None: | |||||
device = reference[0].device | |||||
return (Tensor(self.value, self.dtype, self.device, True),) |
@@ -13,6 +13,12 @@ import numpy as np | |||||
# normal dtype related | # normal dtype related | ||||
from .._imperative_rt import bfloat16, intb1, intb2, intb4 | from .._imperative_rt import bfloat16, intb1, intb2, intb4 | ||||
from .._imperative_rt.common import ( | |||||
get_scale, | |||||
get_zero_point, | |||||
is_dtype_equal, | |||||
is_quantize, | |||||
) | |||||
def is_lowbit(dtype): | def is_lowbit(dtype): | ||||
@@ -42,41 +48,6 @@ _metadata_dict = { | |||||
} | } | ||||
def is_quantize(dtype): | |||||
return ( | |||||
hasattr(dtype, "metadata") | |||||
and dtype.metadata is not None | |||||
and "mgb_dtype" in dtype.metadata | |||||
) | |||||
def get_scale(dtype): | |||||
assert is_quantize(dtype) | |||||
return dtype.metadata["mgb_dtype"]["scale"] | |||||
def get_zero_point(dtype): | |||||
assert is_quantize(dtype) | |||||
metadata = dtype.metadata["mgb_dtype"] | |||||
assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||||
return metadata["zero_point"] | |||||
def is_equal(dt0, dt1): | |||||
def _get_zero_point(dtype): | |||||
assert is_quantize(dtype) | |||||
metadata = dtype.metadata["mgb_dtype"] | |||||
return metadata.get("zero_point") | |||||
if is_quantize(dt0) and is_quantize(dt1): | |||||
return get_scale(dt0) == get_scale(dt1) and _get_zero_point( | |||||
dt0 | |||||
) == _get_zero_point(dt1) | |||||
if not (is_quantize(dt0) or is_quantize(dt1)): | |||||
return dt0 == dt1 | |||||
return False | |||||
def _check_zero_point(zp: int, dtype_str: str): | def _check_zero_point(zp: int, dtype_str: str): | ||||
qmin = _metadata_dict[dtype_str].qmin | qmin = _metadata_dict[dtype_str].qmin | ||||
qmax = _metadata_dict[dtype_str].qmax | qmax = _metadata_dict[dtype_str].qmax | ||||
@@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||||
def get_index(i): | def get_index(i): | ||||
if not isinstance(i, (Tensor)): | if not isinstance(i, (Tensor)): | ||||
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | ||||
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||||
(i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||||
else: | else: | ||||
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||||
(i,) = Const(i, dtype=np.int32, device=inp.device)() | |||||
return i | return i | ||||
assert isinstance(i, Tensor) | assert isinstance(i, Tensor) | ||||
if i.dtype != np.bool_: | if i.dtype != np.bool_: | ||||
@@ -197,7 +197,7 @@ def try_condtake(tensor, index): | |||||
): | ): | ||||
return [] | return [] | ||||
if isinstance(index, np.ndarray): | if isinstance(index, np.ndarray): | ||||
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||||
(index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||||
assert isinstance(index, Tensor) | assert isinstance(index, Tensor) | ||||
if not isinstance(tensor, Tensor): | if not isinstance(tensor, Tensor): | ||||
raise TypeError("input must be a tensor") | raise TypeError("input must be a tensor") | ||||
@@ -217,9 +217,7 @@ def getitem(tensor, index): | |||||
if isinstance(v.shape, v.__class__): | if isinstance(v.shape, v.__class__): | ||||
break | break | ||||
if len(v.shape) > 0 and v.shape[0] == 0: | if len(v.shape) > 0 and v.shape[0] == 0: | ||||
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||||
tensor | |||||
) | |||||
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||||
return empty_tensor | return empty_tensor | ||||
if use_subtensor: | if use_subtensor: | ||||
op = builtin.Subtensor(items=items) | op = builtin.Subtensor(items=items) | ||||
@@ -240,8 +238,7 @@ def setitem(tensor, index, value): | |||||
return tensor | return tensor | ||||
tensor = tensor.reshape(-1) | tensor = tensor.reshape(-1) | ||||
if not isinstance(value, Tensor): | if not isinstance(value, Tensor): | ||||
op = Const(value, dtype=tensor.dtype, device=tensor.device) | |||||
(value,) = op(tensor) | |||||
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||||
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | ||||
for v in tensors: | for v in tensors: | ||||
if len(v.shape) > 0 and v.shape[0] == 0: | if len(v.shape) > 0 and v.shape[0] == 0: | ||||
@@ -11,10 +11,10 @@ from typing import Iterable, Union | |||||
import numpy as np | import numpy as np | ||||
from .._imperative_rt.core2 import Tensor, apply | |||||
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||||
from ..ops import builtin | from ..ops import builtin | ||||
from ..ops.special import Const | from ..ops.special import Const | ||||
from .dtype import is_equal, is_quantize | |||||
from .dtype import is_dtype_equal, is_quantize | |||||
from .megbrain_graph import VarNode | from .megbrain_graph import VarNode | ||||
_enable_convert_inputs = True | _enable_convert_inputs = True | ||||
@@ -37,94 +37,12 @@ def set_convert_inputs(flag): | |||||
return backup | return backup | ||||
def dtype_promotion(inputs): | |||||
""" | |||||
Returns the dtype that would result from performing an arithmetic | |||||
operation on the provided input tensors and scalars. | |||||
""" | |||||
# map numpy.dtype.kind to priority | |||||
category_priority = { | |||||
"f": 3, # floating-point | |||||
"i": 2, # signed integer | |||||
"u": 2, # unsigned integer | |||||
"b": 1, # boolean | |||||
} | |||||
def scalar2dtype(x): | |||||
""" | |||||
For scalar `x`, returns its corresponding type. A floating point scalar | |||||
has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'. | |||||
A boolean scalar has dtype 'bool'. | |||||
""" | |||||
if isinstance(x, bool): | |||||
return np.bool_ | |||||
if isinstance(x, int): | |||||
return np.int32 | |||||
if isinstance(x, float): | |||||
return np.float32 | |||||
def promote_types(types, cat): | |||||
""" | |||||
Returns the data type with sufficient size to hold all types of | |||||
category `cat` in the list `types`. | |||||
""" | |||||
used_types = [ | |||||
i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat | |||||
] | |||||
assert len(used_types) > 0 | |||||
res = used_types[0] | |||||
for i in used_types: | |||||
res = np.promote_types(res, i) | |||||
return res | |||||
def max_priority(types): | |||||
""" | |||||
Returns the maximum value of the priority of each type in the list | |||||
`types`. | |||||
""" | |||||
if not types: | |||||
return 0 | |||||
else: | |||||
return max([category_priority.get(np.dtype(i).kind, 0) for i in types]) | |||||
scalars = [] | |||||
tensors = [] | |||||
for data in inputs: | |||||
if hasattr(data, "dtype"): | |||||
tensors.append(data.dtype) | |||||
elif isinstance(data, (float, int, bool)): | |||||
scalars.append(scalar2dtype(data)) | |||||
max_pri_scalars = max_priority(scalars) | |||||
max_pri_tensors = max_priority(tensors) | |||||
assert max_pri_scalars > 0 or max_pri_tensors > 0 | |||||
if max_pri_scalars > max_pri_tensors: | |||||
return promote_types(scalars, max_pri_scalars) | |||||
else: | |||||
return promote_types(tensors, max_pri_tensors) | |||||
def get_device(inputs): | |||||
device = None | |||||
for i in inputs: | |||||
if isinstance(i, (Tensor, VarNode)): | |||||
if device is None: | |||||
device = i.device | |||||
elif device != i.device: | |||||
raise ValueError("ambiguous device: {} vs {}".format(device, i.device)) | |||||
assert device is not None | |||||
return device | |||||
def concatenate(inputs, axis=0, *, device=None): | def concatenate(inputs, axis=0, *, device=None): | ||||
dtype = dtype_promotion(inputs) | dtype = dtype_promotion(inputs) | ||||
device = get_device(inputs) | device = get_device(inputs) | ||||
def convert(x): | def convert(x): | ||||
return convert_single_value(x, inputs, dtype=dtype) | |||||
return convert_single_value(x, dtype=dtype, device=device) | |||||
inputs = tuple(map(convert, inputs)) | inputs = tuple(map(convert, inputs)) | ||||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | (result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | ||||
@@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
def astype(x, dtype): | def astype(x, dtype): | ||||
dtype = np.dtype(dtype) | dtype = np.dtype(dtype) | ||||
if not is_equal(x.dtype, dtype): | |||||
if not is_dtype_equal(x.dtype, dtype): | |||||
isscalar = x.isscalar() | isscalar = x.isscalar() | ||||
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) | (x,) = apply(builtin.TypeCvt(dtype=dtype), x) | ||||
if isscalar: | if isscalar: | ||||
@@ -141,13 +59,12 @@ def astype(x, dtype): | |||||
return x | return x | ||||
def convert_single_value(v, inputs, *, dtype=None, device=None): | |||||
tensors = [i for i in inputs if isinstance(i, (Tensor, VarNode))] | |||||
assert len(tensors) > 0 | |||||
def convert_single_value(v, *, dtype=None, device=None): | |||||
if isinstance(v, (Tensor, VarNode)): | if isinstance(v, (Tensor, VarNode)): | ||||
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) | |||||
if not is_quantize(v.dtype): | |||||
v = astype(v, dtype) | |||||
else: | else: | ||||
(v,) = Const(v, dtype=dtype, device=device)(*tensors) | |||||
(v,) = Const(v, dtype=dtype, device=device)() | |||||
return v | return v | ||||
@@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor): | |||||
def convert(value): | def convert(value): | ||||
if value is None: | if value is None: | ||||
return value | return value | ||||
return convert_single_value(value, args, dtype=dtype, device=device) | |||||
return convert_single_value(value, dtype=dtype, device=device) | |||||
return tuple(map(convert, args)) | return tuple(map(convert, args)) | ||||
@@ -703,7 +703,7 @@ def topk( | |||||
op = builtin.TopK(mode=mode) | op = builtin.TopK(mode=mode) | ||||
if not isinstance(k, Tensor): | if not isinstance(k, Tensor): | ||||
(k,) = Const(k, dtype="int32", device=inp.device)(inp) | |||||
(k,) = Const(k, dtype="int32", device=inp.device)() | |||||
if len(inp.shape) == 1: | if len(inp.shape) == 1: | ||||
inp = inp.reshape(1, -1) | inp = inp.reshape(1, -1) | ||||
@@ -658,7 +658,7 @@ def batch_norm( | |||||
def make_full_if_none(x, value): | def make_full_if_none(x, value): | ||||
if x is None: | if x is None: | ||||
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||||
(x,) = Const(value, dtype=inp.dtype, device=inp.device)() | |||||
shape = utils.astensor1d( | shape = utils.astensor1d( | ||||
(1, C, 1, 1), inp, dtype="int32", device=inp.device | (1, C, 1, 1), inp, dtype="int32", device=inp.device | ||||
) | ) | ||||
@@ -1567,7 +1567,7 @@ def indexing_one_hot( | |||||
""" | """ | ||||
assert isinstance(src, Tensor), "src must be of Tensor type" | assert isinstance(src, Tensor), "src must be of Tensor type" | ||||
op = builtin.IndexingOneHot(axis=axis) | op = builtin.IndexingOneHot(axis=axis) | ||||
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | |||||
index = utils.convert_single_value(index, dtype="int32", device=src.device) | |||||
(result,) = apply(op, src, index) | (result,) = apply(op, src, index) | ||||
if not keepdims: | if not keepdims: | ||||
result = squeeze(result, axis) | result = squeeze(result, axis) | ||||
@@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None): | |||||
shape = (shape,) | shape = (shape,) | ||||
if device is None: | if device is None: | ||||
device = get_default_device() | device = get_default_device() | ||||
(x,) = Const(value, dtype=dtype, device=device)( | |||||
Tensor(value, dtype=dtype, device=device) | |||||
) | |||||
(x,) = Const(value, dtype=dtype, device=device)() | |||||
if len(shape) == 0: # scalar | if len(shape) == 0: # scalar | ||||
return x | return x | ||||
return broadcast_to(x, shape) | return broadcast_to(x, shape) | ||||
@@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||||
device = as_device(device) | device = as_device(device) | ||||
def convert(x): | def convert(x): | ||||
return convert_single_value(x, inps, dtype=dtype) | |||||
return convert_single_value(x, dtype=dtype, device=device) | |||||
inps = tuple(map(convert, inps)) | inps = tuple(map(convert, inps)) | ||||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | (result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | ||||
@@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
else: | else: | ||||
cn = CompNode(device) | cn = CompNode(device) | ||||
else: | else: | ||||
assert isinstance(device, CompNode) | |||||
cn = device | |||||
if isinstance(device, CompNode): | |||||
cn = device | |||||
else: | |||||
cn = device._cn | |||||
# import pdb; pdb.set_trace() | # import pdb; pdb.set_trace() | ||||
if isinstance(data, _Tensor): | if isinstance(data, _Tensor): | ||||
@@ -179,4 +179,5 @@ void init_common(py::module m) { | |||||
init_npy_num_bfloat16(m); | init_npy_num_bfloat16(m); | ||||
init_npy_num_intbx(m); | init_npy_num_intbx(m); | ||||
init_dtypes(m); | |||||
} | } |
@@ -158,7 +158,7 @@ void PyExceptionForward::throw_() { | |||||
/* ============== namespace npy ============== */ | /* ============== namespace npy ============== */ | ||||
namespace { | |||||
namespace npy { | |||||
int to_mgb_supported_dtype_raw(int dtype) { | int to_mgb_supported_dtype_raw(int dtype) { | ||||
if (dtype == NPY_INT64) | if (dtype == NPY_INT64) | ||||
@@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) { | |||||
"can not convert dtype %s to numpy dtype", dtype.name())); | "can not convert dtype %s to numpy dtype", dtype.name())); | ||||
} | } | ||||
struct PyArrayDescrDeleter { | |||||
void operator()(PyArray_Descr* obj) { | |||||
Py_XDECREF(obj); | |||||
} | |||||
}; | |||||
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | //! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | ||||
//! reference to the descriptor. | //! reference to the descriptor. | ||||
std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | ||||
@@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) { | |||||
HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr)); | HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr)); | ||||
} | } | ||||
} // anonymous namespace | |||||
PyObject* npy::ndarray_from_tensor( | |||||
PyObject* ndarray_from_tensor( | |||||
const HostTensorND &val, ShareType share_type) { | const HostTensorND &val, ShareType share_type) { | ||||
if (!val.layout().is_contiguous() && !val.shape().is_empty()) { | if (!val.layout().is_contiguous() && !val.shape().is_empty()) { | ||||
mgb_assert(share_type != ShareType::MUST_SHARE); | mgb_assert(share_type != ShareType::MUST_SHARE); | ||||
@@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor( | |||||
return ret; | return ret; | ||||
} | } | ||||
HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||||
HostTensorND np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||||
auto ret_full = np2tensor_try_borrow(obj, meth, dtype); | auto ret_full = np2tensor_try_borrow(obj, meth, dtype); | ||||
if (meth.must_borrow_) { | if (meth.must_borrow_) { | ||||
mgb_assert(ret_full.second, | mgb_assert(ret_full.second, | ||||
@@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||||
return ret_full.first; | return ret_full.first; | ||||
} | } | ||||
PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||||
PyObject* dtype_mgb2np(mgb::DType dtype) { | |||||
PYTHON_GIL; | PYTHON_GIL; | ||||
// According to | // According to | ||||
// https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType | // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType | ||||
@@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||||
return typeobj; | return typeobj; | ||||
} | } | ||||
mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||||
mgb::DType dtype_np2mgb(PyObject *obj) { | |||||
mgb_assert(obj && obj != Py_None, | mgb_assert(obj && obj != Py_None, | ||||
"can not convert null PyObject to numpy dtype"); | "can not convert null PyObject to numpy dtype"); | ||||
// see | // see | ||||
@@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||||
return result; | return result; | ||||
} | } | ||||
PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||||
PyObject* to_mgb_supported_dtype(PyObject* dtype) { | |||||
PYTHON_GIL; | PYTHON_GIL; | ||||
PyArray_Descr* descr; | PyArray_Descr* descr; | ||||
@@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||||
return PyArray_TypeObjectFromType(type_num); | return PyArray_TypeObjectFromType(type_num); | ||||
} | } | ||||
TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||||
TensorShape vec2shape(const std::vector<size_t> &vec) { | |||||
TensorShape shape; | TensorShape shape; | ||||
mgb_assert(vec.size() <= TensorShape::MAX_NDIM, | mgb_assert(vec.size() <= TensorShape::MAX_NDIM, | ||||
"dim too large: %zd (max %zd)", | "dim too large: %zd (max %zd)", | ||||
@@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||||
mgb_assert(shape.ndim, "shape should not be empty"); | mgb_assert(shape.ndim, "shape should not be empty"); | ||||
return shape; | return shape; | ||||
} | } | ||||
} // namespace npy |
@@ -11,7 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/graph.h" | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/utils/persistent_cache.h" | #include "megbrain/utils/persistent_cache.h" | ||||
#include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
@@ -26,6 +26,8 @@ | |||||
#include <pybind11/numpy.h> | #include <pybind11/numpy.h> | ||||
#include <pybind11/functional.h> | #include <pybind11/functional.h> | ||||
#include "./numpy_dtypes.h" | |||||
pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr); | pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr); | ||||
pybind11::module rel_import(pybind11::str name, pybind11::module m, int level); | pybind11::module rel_import(pybind11::str name, pybind11::module m, int level); | ||||
@@ -182,6 +184,18 @@ namespace npy { | |||||
//! convert raw vector to tensor shape | //! convert raw vector to tensor shape | ||||
mgb::TensorShape vec2shape(const std::vector<size_t> &vec); | mgb::TensorShape vec2shape(const std::vector<size_t> &vec); | ||||
struct PyArrayDescrDeleter { | |||||
void operator()(PyArray_Descr* obj) { | |||||
Py_XDECREF(obj); | |||||
} | |||||
}; | |||||
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | |||||
//! reference to the descriptor. | |||||
std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(mgb::DType dtype); | |||||
mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr); | |||||
//! convert megbrain dtype to numpy dtype object; return new reference | //! convert megbrain dtype to numpy dtype object; return new reference | ||||
PyObject* dtype_mgb2np(mgb::DType dtype); | PyObject* dtype_mgb2np(mgb::DType dtype); | ||||
@@ -0,0 +1,179 @@ | |||||
/** | |||||
* \file imperative/python/src/numpy_dtypes.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "./numpy_dtypes.h" | |||||
#include "./helper.h" | |||||
#include "./pyext17.h" | |||||
#include "pybind11/pybind11.h" | |||||
#include <cstring> | |||||
namespace py = pybind11; | |||||
namespace mgb { | |||||
namespace { | |||||
inline bool _is_quantize(PyArray_Descr* dtype) { | |||||
static PyObject* PY_MGB_DTYPE_KEY = PyUnicode_FromString("mgb_dtype"); | |||||
return dtype->metadata && | |||||
PyDict_CheckExact(dtype->metadata) && | |||||
PyDict_Contains(dtype->metadata, PY_MGB_DTYPE_KEY) == 1; | |||||
} | |||||
PyObject* _get_mgb_dtype(PyArray_Descr* dtype) { | |||||
// Return value: New reference. | |||||
if (!_is_quantize(dtype)) { | |||||
throw py::type_error("expact quantize dtype"); | |||||
} | |||||
PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype"); | |||||
if (!PyDict_CheckExact(ob)) { | |||||
throw py::type_error("mgb_dtype is not dict"); | |||||
} | |||||
Py_INCREF(ob); | |||||
return ob; | |||||
} | |||||
double _get_scale(PyArray_Descr* dtype) { | |||||
PyObject* ob = _get_mgb_dtype(dtype); | |||||
PyObject* scale = PyDict_GetItemString(ob, "scale"); | |||||
if (!scale) { | |||||
Py_DECREF(ob); | |||||
throw py::key_error("scale"); | |||||
} | |||||
if (!PyFloat_Check(scale)) { | |||||
Py_DECREF(ob); | |||||
throw py::type_error("scale is not float"); | |||||
} | |||||
double ret = PyFloat_AsDouble(scale); | |||||
Py_DECREF(ob); | |||||
return ret; | |||||
} | |||||
long _get_zero_point(PyArray_Descr* dtype) { | |||||
PyObject* ob = _get_mgb_dtype(dtype); | |||||
PyObject* name = PyDict_GetItemString(ob, "name"); | |||||
if (!name) { | |||||
Py_DECREF(ob); | |||||
throw py::key_error("name"); | |||||
} | |||||
const char* s = PyUnicode_AsUTF8(name); | |||||
if (strcmp(s, "Quantized8Asymm") != 0 && strcmp(s, "Quantized4Asymm") != 0) { | |||||
Py_DECREF(ob); | |||||
throw py::value_error(ssprintf("expect name to be \"Quantized8Asymm\" or \"Quantized4Asymm\", got %s", s)); | |||||
} | |||||
PyObject* zp = PyDict_GetItemString(ob, "zero_point"); | |||||
if (!zp) { | |||||
Py_DECREF(ob); | |||||
throw py::key_error("zero_point"); | |||||
} | |||||
long ret = PyLong_AsLong(zp); | |||||
Py_DECREF(ob); | |||||
return ret; | |||||
} | |||||
bool _is_dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { | |||||
bool q1 = _is_quantize(dt1), | |||||
q2 = _is_quantize(dt2); | |||||
if (q1 && q2) { | |||||
if (_get_scale(dt1) != _get_scale(dt2)) { | |||||
return false; | |||||
} | |||||
PyObject* zp1 = PyDict_GetItemString( | |||||
PyDict_GetItemString(dt1->metadata, "mgb_dtype"), "zero_point"); | |||||
PyObject* zp2 = PyDict_GetItemString( | |||||
PyDict_GetItemString(dt2->metadata, "mgb_dtype"), "zero_point"); | |||||
if (!zp1 || !zp2) { | |||||
throw py::key_error("zero_point"); | |||||
} | |||||
return PyLong_AsLong(zp1) == PyLong_AsLong(zp2); | |||||
} | |||||
if (!q1 && !q2) { | |||||
return dt1->type_num == dt2->type_num; | |||||
} | |||||
return false; | |||||
} | |||||
template<auto f> | |||||
struct _wrap { | |||||
static constexpr size_t n_args = []() { | |||||
using F = decltype(f); | |||||
using T = PyArray_Descr*; | |||||
static_assert(std::is_pointer<F>::value); | |||||
if constexpr (std::is_invocable<F, T>::value) { | |||||
return 1; | |||||
} else if constexpr (std::is_invocable<F, T, T>::value) { | |||||
return 2; | |||||
} else { | |||||
static_assert(!std::is_same_v<F, F>, "unreachable"); | |||||
} | |||||
}(); | |||||
static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
if (nargs != n_args) { | |||||
PyErr_Format(PyExc_ValueError, "expected %lu arguments", n_args); | |||||
return nullptr; | |||||
} | |||||
for (size_t i=0; i<nargs; ++i) { | |||||
if (args[i] == Py_None) { | |||||
PyErr_SetString(PyExc_ValueError, "can not convert null PyObject to numpy dtype"); | |||||
return nullptr; | |||||
} | |||||
} | |||||
try { | |||||
PyArray_Descr *dt1; | |||||
if(!PyArray_DescrConverter(args[0], &dt1)) { | |||||
throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||||
args[0]->ob_type->tp_name)); | |||||
} | |||||
if constexpr (n_args == 1) { | |||||
auto res = (*f)(dt1); | |||||
Py_DECREF(dt1); | |||||
return py::cast(res).release().ptr(); | |||||
} else { | |||||
PyArray_Descr *dt2; | |||||
if(!PyArray_DescrConverter(args[1], &dt2)) { | |||||
Py_DECREF(dt1); | |||||
throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||||
args[1]->ob_type->tp_name)); | |||||
} | |||||
auto&& res = (*f)(dt1, dt2); | |||||
Py_DECREF(dt1); | |||||
Py_DECREF(dt2); | |||||
return py::cast(res).release().ptr(); | |||||
} | |||||
} catch (std::exception& e) { | |||||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
return nullptr; | |||||
} | |||||
} | |||||
}; | |||||
} // anonymous namespace | |||||
void init_dtypes(py::module m) { | |||||
static PyMethodDef method_defs[] = { | |||||
{"is_quantize", (PyCFunction)_wrap<&_is_quantize>::impl, METH_FASTCALL, nullptr}, | |||||
{"get_scale", (PyCFunction)_wrap<&_get_scale>::impl, METH_FASTCALL, nullptr}, | |||||
{"get_zero_point", (PyCFunction)_wrap<&_get_zero_point>::impl, METH_FASTCALL, nullptr}, | |||||
{"is_dtype_equal", (PyCFunction)_wrap<&_is_dtype_equal>::impl, METH_FASTCALL, nullptr}, | |||||
{nullptr, nullptr, 0, nullptr} | |||||
}; | |||||
for (auto&& def: method_defs) { | |||||
if (def.ml_meth != nullptr) { | |||||
auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||||
if (!func) throw py::error_already_set(); | |||||
py::setattr(m, def.ml_name, func); | |||||
} | |||||
} | |||||
} | |||||
} // namespace mgb |
@@ -36,6 +36,7 @@ namespace mgb { | |||||
int npy_num_intb##n(); | int npy_num_intb##n(); | ||||
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | ||||
#undef DEFINE_NPY_INTBX | #undef DEFINE_NPY_INTBX | ||||
void init_dtypes(pybind11::module m); | |||||
void init_npy_num_intbx(pybind11::module m); | void init_npy_num_intbx(pybind11::module m); | ||||
//! numpy type num for bfloat16 type | //! numpy type num for bfloat16 type | ||||
@@ -9,16 +9,22 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "megbrain/dtype.h" | |||||
#include "megbrain/common.h" | |||||
#include "./tensor.h" | #include "./tensor.h" | ||||
#include "./grad.h" | #include "./grad.h" | ||||
#include "./trace.h" | #include "./trace.h" | ||||
#include "./common.h" | #include "./common.h" | ||||
#include "./numpy_dtypes.h" | #include "./numpy_dtypes.h" | ||||
#include "./graph_rt.h" | #include "./graph_rt.h" | ||||
#include "./helper.h" | |||||
#include <pybind11/numpy.h> | #include <pybind11/numpy.h> | ||||
#include <pybind11/operators.h> | #include <pybind11/operators.h> | ||||
#include "./helper.h" | |||||
#include <unordered_map> | |||||
namespace py = pybind11; | namespace py = pybind11; | ||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
@@ -413,6 +419,198 @@ struct TensorWeakRef { | |||||
} | } | ||||
}; | }; | ||||
/* ============== convert inputs ============== */ | |||||
// map numpy.dtype.kind to priority | |||||
inline uint8_t category_priority(char c) { | |||||
switch (c) { | |||||
case 'f': return 3; // floating-point | |||||
case 'i': return 2; // signed integer | |||||
case 'u': return 2; // unsigned integer | |||||
case 'b': return 1; // boolean | |||||
default: return 0; | |||||
} | |||||
} | |||||
// Returns the maximum value of the priority of each type in the list `types`. | |||||
uint8_t max_priority(SmallVector<PyArray_Descr*> types) { | |||||
if (types.size() == 0) { | |||||
return 0; | |||||
} else { | |||||
uint8_t max_p = 0; | |||||
for (auto&& desc: types) { | |||||
max_p = std::max(max_p, category_priority(desc->kind)); | |||||
} | |||||
return max_p; | |||||
} | |||||
} | |||||
// Returns the data type with sufficient size to hold all types of | |||||
// category `cat` in the list `types`. | |||||
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) { | |||||
// Return value: New reference | |||||
SmallVector<PyArray_Descr*> used_types; | |||||
for (auto&& desc: types) { | |||||
auto&& v = category_priority(desc->kind); | |||||
if (v == cat) { | |||||
used_types.emplace_back(desc); | |||||
} | |||||
} | |||||
mgb_assert(used_types.size() > 0, "size of used_types is 0"); | |||||
PyArray_Descr* res = used_types[0]; | |||||
Py_INCREF(res); | |||||
for (size_t i = 1; i < used_types.size(); ++i) { | |||||
PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res); | |||||
Py_DECREF(res); | |||||
res = tmp; | |||||
} | |||||
return res; | |||||
} | |||||
PyArray_Descr* scalar2dtype(PyObject* arg) { | |||||
// Return value: New reference | |||||
if (PyBool_Check(arg)) { | |||||
auto&& descr = PyArray_DescrFromType(NPY_BOOL); | |||||
return descr; | |||||
} | |||||
if (PyLong_CheckExact(arg)) { | |||||
auto&& descr = PyArray_DescrFromType(NPY_INT32); | |||||
return descr; | |||||
} | |||||
if (PyFloat_CheckExact(arg)) { | |||||
auto&& descr = PyArray_DescrFromType(NPY_FLOAT32); | |||||
return descr; | |||||
} | |||||
return nullptr; | |||||
} | |||||
PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||||
// Return value: New reference | |||||
SmallVector<PyArray_Descr*> tensors; | |||||
SmallVector<PyArray_Descr*> scalars; | |||||
bool is_tuple = false; | |||||
PyObject* tuple; | |||||
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||||
if (PyList_Check(args[0])) { | |||||
tuple = PyList_AsTuple(args[0]); | |||||
} else { | |||||
tuple = args[0]; | |||||
Py_INCREF(tuple); | |||||
} | |||||
nargs = PyTuple_Size(tuple); | |||||
is_tuple = true; | |||||
} | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||||
if (handle == Py_None) continue; | |||||
TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||||
if (tw) { | |||||
mgb::DType type = tw->m_tensor->dtype(); | |||||
auto&& descr = npy::dtype_mgb2np_descr(type); | |||||
Py_INCREF(descr.get()); | |||||
tensors.emplace_back(descr.get()); | |||||
}else{ | |||||
if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) { | |||||
auto&& descr = PyArray_DescrFromObject(handle, nullptr); | |||||
tensors.emplace_back(descr); | |||||
continue; | |||||
} | |||||
PyArray_Descr* descr = scalar2dtype(handle); | |||||
if (descr) { | |||||
scalars.emplace_back(descr); | |||||
continue; | |||||
} | |||||
} | |||||
} | |||||
auto max_pri_scalars = max_priority(scalars); | |||||
auto max_pri_tensors = max_priority(tensors); | |||||
if (max_pri_scalars <= 0 && max_pri_tensors <= 0) { | |||||
throw py::value_error("invalid input, no dtype avaliable"); | |||||
} | |||||
PyArray_Descr* res; | |||||
if (max_pri_scalars > max_pri_tensors) { | |||||
res = promote_types(scalars, max_pri_scalars); | |||||
}else{ | |||||
res = promote_types(tensors, max_pri_tensors); | |||||
} | |||||
for (auto *p: tensors) { Py_DECREF(p); } | |||||
for (auto *p: scalars) { Py_DECREF(p); } | |||||
Py_DECREF(tuple); | |||||
return res; | |||||
} | |||||
CompNode _get_device(PyObject*const* args, size_t nargs) { | |||||
bool is_tuple = false; | |||||
PyObject* tuple; | |||||
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||||
if (PyList_Check(args[0])) { | |||||
tuple = PyList_AsTuple(args[0]); | |||||
} else { | |||||
tuple = args[0]; | |||||
Py_INCREF(tuple); | |||||
} | |||||
nargs = PyTuple_Size(tuple); | |||||
is_tuple = true; | |||||
} | |||||
bool valid = false; | |||||
CompNode cn; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||||
TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||||
if (tw) { | |||||
if (!valid) { | |||||
cn = tw->m_tensor->comp_node(); | |||||
valid = true; | |||||
} else { | |||||
CompNode cn1 = tw->m_tensor->comp_node(); | |||||
if (cn1 != cn) { | |||||
throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||||
cn.to_string().c_str(), cn1.to_string().c_str())); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
if (!valid) { | |||||
mgb_assert(0, "expact at least 1 device"); | |||||
} | |||||
Py_DECREF(tuple); | |||||
return cn; | |||||
} | |||||
// Returns the dtype that would result from performing an arithmetic | |||||
// operation on the provided input tensors and scalars. | |||||
PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
if (!nargs) { | |||||
PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||||
return nullptr; | |||||
} | |||||
try { | |||||
PyArray_Descr* res = _dtype_promotion(args, nargs); | |||||
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); | |||||
} catch (std::exception& e) { | |||||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
return nullptr; | |||||
} | |||||
} | |||||
PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { | |||||
if (!nargs) { | |||||
PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||||
return nullptr; | |||||
} | |||||
try { | |||||
CompNode cn = _get_device(args, nargs); | |||||
return py::cast(cn).release().ptr(); | |||||
} catch (std::exception& e) { | |||||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||||
return nullptr; | |||||
} | |||||
} | |||||
void init_tensor(py::module m) { | void init_tensor(py::module m) { | ||||
interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | ||||
@@ -444,10 +642,19 @@ void init_tensor(py::module m) { | |||||
.def(py::init<const TensorWrapper&>()) | .def(py::init<const TensorWrapper&>()) | ||||
.def("__call__", &TensorWeakRef::operator()); | .def("__call__", &TensorWeakRef::operator()); | ||||
static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||||
auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr); | |||||
if (!apply_func) throw py::error_already_set(); | |||||
py::setattr(m, "apply", apply_func); | |||||
static PyMethodDef method_defs[] = { | |||||
{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}, | |||||
{"dtype_promotion", (PyCFunction)dtype_promotion, METH_FASTCALL, nullptr}, | |||||
{"get_device", (PyCFunction)get_device, METH_FASTCALL, nullptr}, | |||||
{nullptr, nullptr, 0, nullptr} | |||||
}; | |||||
for (auto&& def: method_defs) { | |||||
if (def.ml_meth != nullptr) { | |||||
auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||||
if (!func) throw py::error_already_set(); | |||||
py::setattr(m, def.ml_name, func); | |||||
} | |||||
} | |||||
m.def("_set_swap_flag", | m.def("_set_swap_flag", | ||||
[](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | [](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | ||||
@@ -113,7 +113,7 @@ def test_quint8_typecvt(): | |||||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | data = np.random.random(shape).astype(np.float32) * 5 - 1 | ||||
def typecvt(x, dt=None): | def typecvt(x, dt=None): | ||||
(y,) = apply(ops.TypeCvt(dtype=dt), x) | |||||
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||||
return y | return y | ||||
# convert to quint8 | # convert to quint8 | ||||
@@ -194,7 +194,7 @@ def test_quint4_typecvt(): | |||||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | data = np.random.random(shape).astype(np.float32) * 5 - 1 | ||||
def typecvt(x, dt=None): | def typecvt(x, dt=None): | ||||
(y,) = apply(ops.TypeCvt(dtype=dt), x) | |||||
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||||
return y | return y | ||||
# convert to quint4 | # convert to quint4 | ||||