GitOrigin-RevId: baef3d348c
release-1.2
@@ -8,7 +8,7 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from .._imperative_rt.core2 import Tensor | |||
# from .._imperative_rt.core2 import Tensor | |||
from ..tensor.core import OpBase, TensorBase, apply | |||
@@ -19,5 +19,10 @@ class Const: | |||
self.device = device | |||
def __call__(self, *reference): | |||
Wrapper = type(reference[0]) | |||
return (Wrapper(self.value, self.dtype, self.device, True),) | |||
from ...tensor import Tensor | |||
device = self.device | |||
if device is None: | |||
device = reference[0].device | |||
return (Tensor(self.value, self.dtype, self.device, True),) |
@@ -13,6 +13,12 @@ import numpy as np | |||
# normal dtype related | |||
from .._imperative_rt import bfloat16, intb1, intb2, intb4 | |||
from .._imperative_rt.common import ( | |||
get_scale, | |||
get_zero_point, | |||
is_dtype_equal, | |||
is_quantize, | |||
) | |||
def is_lowbit(dtype): | |||
@@ -42,41 +48,6 @@ _metadata_dict = { | |||
} | |||
def is_quantize(dtype): | |||
return ( | |||
hasattr(dtype, "metadata") | |||
and dtype.metadata is not None | |||
and "mgb_dtype" in dtype.metadata | |||
) | |||
def get_scale(dtype): | |||
assert is_quantize(dtype) | |||
return dtype.metadata["mgb_dtype"]["scale"] | |||
def get_zero_point(dtype): | |||
assert is_quantize(dtype) | |||
metadata = dtype.metadata["mgb_dtype"] | |||
assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm") | |||
return metadata["zero_point"] | |||
def is_equal(dt0, dt1): | |||
def _get_zero_point(dtype): | |||
assert is_quantize(dtype) | |||
metadata = dtype.metadata["mgb_dtype"] | |||
return metadata.get("zero_point") | |||
if is_quantize(dt0) and is_quantize(dt1): | |||
return get_scale(dt0) == get_scale(dt1) and _get_zero_point( | |||
dt0 | |||
) == _get_zero_point(dt1) | |||
if not (is_quantize(dt0) or is_quantize(dt1)): | |||
return dt0 == dt1 | |||
return False | |||
def _check_zero_point(zp: int, dtype_str: str): | |||
qmin = _metadata_dict[dtype_str].qmin | |||
qmax = _metadata_dict[dtype_str].qmax | |||
@@ -151,9 +151,9 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
def get_index(i): | |||
if not isinstance(i, (Tensor)): | |||
if is_bool_list(i) or isinstance(i, np.ndarray) and i.dtype == np.bool_: | |||
(i,) = Const(i, dtype=np.bool_, device=inp.device)(inp) | |||
(i,) = Const(i, dtype=np.bool_, device=inp.device)() | |||
else: | |||
(i,) = Const(i, dtype=np.int32, device=inp.device)(inp) | |||
(i,) = Const(i, dtype=np.int32, device=inp.device)() | |||
return i | |||
assert isinstance(i, Tensor) | |||
if i.dtype != np.bool_: | |||
@@ -197,7 +197,7 @@ def try_condtake(tensor, index): | |||
): | |||
return [] | |||
if isinstance(index, np.ndarray): | |||
(index,) = Const(index, dtype=np.bool_, device=tensor.device)(tensor) | |||
(index,) = Const(index, dtype=np.bool_, device=tensor.device)() | |||
assert isinstance(index, Tensor) | |||
if not isinstance(tensor, Tensor): | |||
raise TypeError("input must be a tensor") | |||
@@ -217,9 +217,7 @@ def getitem(tensor, index): | |||
if isinstance(v.shape, v.__class__): | |||
break | |||
if len(v.shape) > 0 and v.shape[0] == 0: | |||
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)( | |||
tensor | |||
) | |||
(empty_tensor,) = Const([], dtype=tensor.dtype, device=tensor.device)() | |||
return empty_tensor | |||
if use_subtensor: | |||
op = builtin.Subtensor(items=items) | |||
@@ -240,8 +238,7 @@ def setitem(tensor, index, value): | |||
return tensor | |||
tensor = tensor.reshape(-1) | |||
if not isinstance(value, Tensor): | |||
op = Const(value, dtype=tensor.dtype, device=tensor.device) | |||
(value,) = op(tensor) | |||
(value,) = Const(value, dtype=tensor.dtype, device=tensor.device)() | |||
tensor, tensors, items, use_subtensor, _ = unpack_getitem(tensor, index) | |||
for v in tensors: | |||
if len(v.shape) > 0 and v.shape[0] == 0: | |||
@@ -11,10 +11,10 @@ from typing import Iterable, Union | |||
import numpy as np | |||
from .._imperative_rt.core2 import Tensor, apply | |||
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||
from ..ops import builtin | |||
from ..ops.special import Const | |||
from .dtype import is_equal, is_quantize | |||
from .dtype import is_dtype_equal, is_quantize | |||
from .megbrain_graph import VarNode | |||
_enable_convert_inputs = True | |||
@@ -37,94 +37,12 @@ def set_convert_inputs(flag): | |||
return backup | |||
def dtype_promotion(inputs): | |||
""" | |||
Returns the dtype that would result from performing an arithmetic | |||
operation on the provided input tensors and scalars. | |||
""" | |||
# map numpy.dtype.kind to priority | |||
category_priority = { | |||
"f": 3, # floating-point | |||
"i": 2, # signed integer | |||
"u": 2, # unsigned integer | |||
"b": 1, # boolean | |||
} | |||
def scalar2dtype(x): | |||
""" | |||
For scalar `x`, returns its corresponding type. A floating point scalar | |||
has dtype 'float32'. An integral non-boolean scalar has dtype 'int32'. | |||
A boolean scalar has dtype 'bool'. | |||
""" | |||
if isinstance(x, bool): | |||
return np.bool_ | |||
if isinstance(x, int): | |||
return np.int32 | |||
if isinstance(x, float): | |||
return np.float32 | |||
def promote_types(types, cat): | |||
""" | |||
Returns the data type with sufficient size to hold all types of | |||
category `cat` in the list `types`. | |||
""" | |||
used_types = [ | |||
i for i in types if category_priority.get(np.dtype(i).kind, 0) == cat | |||
] | |||
assert len(used_types) > 0 | |||
res = used_types[0] | |||
for i in used_types: | |||
res = np.promote_types(res, i) | |||
return res | |||
def max_priority(types): | |||
""" | |||
Returns the maximum value of the priority of each type in the list | |||
`types`. | |||
""" | |||
if not types: | |||
return 0 | |||
else: | |||
return max([category_priority.get(np.dtype(i).kind, 0) for i in types]) | |||
scalars = [] | |||
tensors = [] | |||
for data in inputs: | |||
if hasattr(data, "dtype"): | |||
tensors.append(data.dtype) | |||
elif isinstance(data, (float, int, bool)): | |||
scalars.append(scalar2dtype(data)) | |||
max_pri_scalars = max_priority(scalars) | |||
max_pri_tensors = max_priority(tensors) | |||
assert max_pri_scalars > 0 or max_pri_tensors > 0 | |||
if max_pri_scalars > max_pri_tensors: | |||
return promote_types(scalars, max_pri_scalars) | |||
else: | |||
return promote_types(tensors, max_pri_tensors) | |||
def get_device(inputs): | |||
device = None | |||
for i in inputs: | |||
if isinstance(i, (Tensor, VarNode)): | |||
if device is None: | |||
device = i.device | |||
elif device != i.device: | |||
raise ValueError("ambiguous device: {} vs {}".format(device, i.device)) | |||
assert device is not None | |||
return device | |||
def concatenate(inputs, axis=0, *, device=None): | |||
dtype = dtype_promotion(inputs) | |||
device = get_device(inputs) | |||
def convert(x): | |||
return convert_single_value(x, inputs, dtype=dtype) | |||
return convert_single_value(x, dtype=dtype, device=device) | |||
inputs = tuple(map(convert, inputs)) | |||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device), *inputs) | |||
@@ -133,7 +51,7 @@ def concatenate(inputs, axis=0, *, device=None): | |||
def astype(x, dtype): | |||
dtype = np.dtype(dtype) | |||
if not is_equal(x.dtype, dtype): | |||
if not is_dtype_equal(x.dtype, dtype): | |||
isscalar = x.isscalar() | |||
(x,) = apply(builtin.TypeCvt(dtype=dtype), x) | |||
if isscalar: | |||
@@ -141,13 +59,12 @@ def astype(x, dtype): | |||
return x | |||
def convert_single_value(v, inputs, *, dtype=None, device=None): | |||
tensors = [i for i in inputs if isinstance(i, (Tensor, VarNode))] | |||
assert len(tensors) > 0 | |||
def convert_single_value(v, *, dtype=None, device=None): | |||
if isinstance(v, (Tensor, VarNode)): | |||
v = astype(v, v.dtype if is_quantize(v.dtype) else dtype) | |||
if not is_quantize(v.dtype): | |||
v = astype(v, dtype) | |||
else: | |||
(v,) = Const(v, dtype=dtype, device=device)(*tensors) | |||
(v,) = Const(v, dtype=dtype, device=device)() | |||
return v | |||
@@ -161,7 +78,7 @@ def convert_inputs(*args: Tensor): | |||
def convert(value): | |||
if value is None: | |||
return value | |||
return convert_single_value(value, args, dtype=dtype, device=device) | |||
return convert_single_value(value, dtype=dtype, device=device) | |||
return tuple(map(convert, args)) | |||
@@ -703,7 +703,7 @@ def topk( | |||
op = builtin.TopK(mode=mode) | |||
if not isinstance(k, Tensor): | |||
(k,) = Const(k, dtype="int32", device=inp.device)(inp) | |||
(k,) = Const(k, dtype="int32", device=inp.device)() | |||
if len(inp.shape) == 1: | |||
inp = inp.reshape(1, -1) | |||
@@ -658,7 +658,7 @@ def batch_norm( | |||
def make_full_if_none(x, value): | |||
if x is None: | |||
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||
(x,) = Const(value, dtype=inp.dtype, device=inp.device)() | |||
shape = utils.astensor1d( | |||
(1, C, 1, 1), inp, dtype="int32", device=inp.device | |||
) | |||
@@ -1567,7 +1567,7 @@ def indexing_one_hot( | |||
""" | |||
assert isinstance(src, Tensor), "src must be of Tensor type" | |||
op = builtin.IndexingOneHot(axis=axis) | |||
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device) | |||
index = utils.convert_single_value(index, dtype="int32", device=src.device) | |||
(result,) = apply(op, src, index) | |||
if not keepdims: | |||
result = squeeze(result, axis) | |||
@@ -107,9 +107,7 @@ def full(shape, value, dtype="float32", device=None): | |||
shape = (shape,) | |||
if device is None: | |||
device = get_default_device() | |||
(x,) = Const(value, dtype=dtype, device=device)( | |||
Tensor(value, dtype=dtype, device=device) | |||
) | |||
(x,) = Const(value, dtype=dtype, device=device)() | |||
if len(shape) == 0: # scalar | |||
return x | |||
return broadcast_to(x, shape) | |||
@@ -265,7 +263,7 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||
device = as_device(device) | |||
def convert(x): | |||
return convert_single_value(x, inps, dtype=dtype) | |||
return convert_single_value(x, dtype=dtype, device=device) | |||
inps = tuple(map(convert, inps)) | |||
(result,) = apply(builtin.Concat(axis=axis, comp_node=device.to_c()), *inps) | |||
@@ -37,8 +37,10 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
else: | |||
cn = CompNode(device) | |||
else: | |||
assert isinstance(device, CompNode) | |||
cn = device | |||
if isinstance(device, CompNode): | |||
cn = device | |||
else: | |||
cn = device._cn | |||
# import pdb; pdb.set_trace() | |||
if isinstance(data, _Tensor): | |||
@@ -179,4 +179,5 @@ void init_common(py::module m) { | |||
init_npy_num_bfloat16(m); | |||
init_npy_num_intbx(m); | |||
init_dtypes(m); | |||
} |
@@ -158,7 +158,7 @@ void PyExceptionForward::throw_() { | |||
/* ============== namespace npy ============== */ | |||
namespace { | |||
namespace npy { | |||
int to_mgb_supported_dtype_raw(int dtype) { | |||
if (dtype == NPY_INT64) | |||
@@ -199,12 +199,6 @@ int dtype_mgb2np_raw(DType dtype) { | |||
"can not convert dtype %s to numpy dtype", dtype.name())); | |||
} | |||
struct PyArrayDescrDeleter { | |||
void operator()(PyArray_Descr* obj) { | |||
Py_XDECREF(obj); | |||
} | |||
}; | |||
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | |||
//! reference to the descriptor. | |||
std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | |||
@@ -585,9 +579,7 @@ void ndarray_shared_from_tensor_py_capsule_dtor(PyObject *cap) { | |||
HostTensorNDRefHolder::free(static_cast<HostTensorNDRefHolder*>(ptr)); | |||
} | |||
} // anonymous namespace | |||
PyObject* npy::ndarray_from_tensor( | |||
PyObject* ndarray_from_tensor( | |||
const HostTensorND &val, ShareType share_type) { | |||
if (!val.layout().is_contiguous() && !val.shape().is_empty()) { | |||
mgb_assert(share_type != ShareType::MUST_SHARE); | |||
@@ -634,7 +626,7 @@ PyObject* npy::ndarray_from_tensor( | |||
return ret; | |||
} | |||
HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||
HostTensorND np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||
auto ret_full = np2tensor_try_borrow(obj, meth, dtype); | |||
if (meth.must_borrow_) { | |||
mgb_assert(ret_full.second, | |||
@@ -645,7 +637,7 @@ HostTensorND npy::np2tensor(PyObject* obj, const Meth& meth, DType dtype) { | |||
return ret_full.first; | |||
} | |||
PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||
PyObject* dtype_mgb2np(mgb::DType dtype) { | |||
PYTHON_GIL; | |||
// According to | |||
// https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType | |||
@@ -668,7 +660,7 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) { | |||
return typeobj; | |||
} | |||
mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||
mgb::DType dtype_np2mgb(PyObject *obj) { | |||
mgb_assert(obj && obj != Py_None, | |||
"can not convert null PyObject to numpy dtype"); | |||
// see | |||
@@ -686,7 +678,7 @@ mgb::DType npy::dtype_np2mgb(PyObject *obj) { | |||
return result; | |||
} | |||
PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||
PyObject* to_mgb_supported_dtype(PyObject* dtype) { | |||
PYTHON_GIL; | |||
PyArray_Descr* descr; | |||
@@ -702,7 +694,7 @@ PyObject* npy::to_mgb_supported_dtype(PyObject* dtype) { | |||
return PyArray_TypeObjectFromType(type_num); | |||
} | |||
TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||
TensorShape vec2shape(const std::vector<size_t> &vec) { | |||
TensorShape shape; | |||
mgb_assert(vec.size() <= TensorShape::MAX_NDIM, | |||
"dim too large: %zd (max %zd)", | |||
@@ -718,3 +710,5 @@ TensorShape npy::vec2shape(const std::vector<size_t> &vec) { | |||
mgb_assert(shape.ndim, "shape should not be empty"); | |||
return shape; | |||
} | |||
} // namespace npy |
@@ -11,7 +11,7 @@ | |||
#pragma once | |||
#include "megbrain/graph.h" | |||
#include "megbrain/common.h" | |||
#include "megbrain/utils/persistent_cache.h" | |||
#include "megbrain/imperative/op_def.h" | |||
@@ -26,6 +26,8 @@ | |||
#include <pybind11/numpy.h> | |||
#include <pybind11/functional.h> | |||
#include "./numpy_dtypes.h" | |||
pybind11::module submodule(pybind11::module parent, const char* name, const char* doc = nullptr); | |||
pybind11::module rel_import(pybind11::str name, pybind11::module m, int level); | |||
@@ -182,6 +184,18 @@ namespace npy { | |||
//! convert raw vector to tensor shape | |||
mgb::TensorShape vec2shape(const std::vector<size_t> &vec); | |||
struct PyArrayDescrDeleter { | |||
void operator()(PyArray_Descr* obj) { | |||
Py_XDECREF(obj); | |||
} | |||
}; | |||
//! Convert MegBrain DType to NumPy DType descriptor, the caller receives a new | |||
//! reference to the descriptor. | |||
std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(mgb::DType dtype); | |||
mgb::DType dtype_np2mgb_descr(PyArray_Descr* descr); | |||
//! convert megbrain dtype to numpy dtype object; return new reference | |||
PyObject* dtype_mgb2np(mgb::DType dtype); | |||
@@ -0,0 +1,179 @@ | |||
/** | |||
* \file imperative/python/src/numpy_dtypes.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./numpy_dtypes.h" | |||
#include "./helper.h" | |||
#include "./pyext17.h" | |||
#include "pybind11/pybind11.h" | |||
#include <cstring> | |||
namespace py = pybind11; | |||
namespace mgb { | |||
namespace { | |||
inline bool _is_quantize(PyArray_Descr* dtype) { | |||
static PyObject* PY_MGB_DTYPE_KEY = PyUnicode_FromString("mgb_dtype"); | |||
return dtype->metadata && | |||
PyDict_CheckExact(dtype->metadata) && | |||
PyDict_Contains(dtype->metadata, PY_MGB_DTYPE_KEY) == 1; | |||
} | |||
PyObject* _get_mgb_dtype(PyArray_Descr* dtype) { | |||
// Return value: New reference. | |||
if (!_is_quantize(dtype)) { | |||
throw py::type_error("expact quantize dtype"); | |||
} | |||
PyObject* ob = PyDict_GetItemString(dtype->metadata, "mgb_dtype"); | |||
if (!PyDict_CheckExact(ob)) { | |||
throw py::type_error("mgb_dtype is not dict"); | |||
} | |||
Py_INCREF(ob); | |||
return ob; | |||
} | |||
double _get_scale(PyArray_Descr* dtype) { | |||
PyObject* ob = _get_mgb_dtype(dtype); | |||
PyObject* scale = PyDict_GetItemString(ob, "scale"); | |||
if (!scale) { | |||
Py_DECREF(ob); | |||
throw py::key_error("scale"); | |||
} | |||
if (!PyFloat_Check(scale)) { | |||
Py_DECREF(ob); | |||
throw py::type_error("scale is not float"); | |||
} | |||
double ret = PyFloat_AsDouble(scale); | |||
Py_DECREF(ob); | |||
return ret; | |||
} | |||
long _get_zero_point(PyArray_Descr* dtype) { | |||
PyObject* ob = _get_mgb_dtype(dtype); | |||
PyObject* name = PyDict_GetItemString(ob, "name"); | |||
if (!name) { | |||
Py_DECREF(ob); | |||
throw py::key_error("name"); | |||
} | |||
const char* s = PyUnicode_AsUTF8(name); | |||
if (strcmp(s, "Quantized8Asymm") != 0 && strcmp(s, "Quantized4Asymm") != 0) { | |||
Py_DECREF(ob); | |||
throw py::value_error(ssprintf("expect name to be \"Quantized8Asymm\" or \"Quantized4Asymm\", got %s", s)); | |||
} | |||
PyObject* zp = PyDict_GetItemString(ob, "zero_point"); | |||
if (!zp) { | |||
Py_DECREF(ob); | |||
throw py::key_error("zero_point"); | |||
} | |||
long ret = PyLong_AsLong(zp); | |||
Py_DECREF(ob); | |||
return ret; | |||
} | |||
bool _is_dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { | |||
bool q1 = _is_quantize(dt1), | |||
q2 = _is_quantize(dt2); | |||
if (q1 && q2) { | |||
if (_get_scale(dt1) != _get_scale(dt2)) { | |||
return false; | |||
} | |||
PyObject* zp1 = PyDict_GetItemString( | |||
PyDict_GetItemString(dt1->metadata, "mgb_dtype"), "zero_point"); | |||
PyObject* zp2 = PyDict_GetItemString( | |||
PyDict_GetItemString(dt2->metadata, "mgb_dtype"), "zero_point"); | |||
if (!zp1 || !zp2) { | |||
throw py::key_error("zero_point"); | |||
} | |||
return PyLong_AsLong(zp1) == PyLong_AsLong(zp2); | |||
} | |||
if (!q1 && !q2) { | |||
return dt1->type_num == dt2->type_num; | |||
} | |||
return false; | |||
} | |||
template<auto f> | |||
struct _wrap { | |||
static constexpr size_t n_args = []() { | |||
using F = decltype(f); | |||
using T = PyArray_Descr*; | |||
static_assert(std::is_pointer<F>::value); | |||
if constexpr (std::is_invocable<F, T>::value) { | |||
return 1; | |||
} else if constexpr (std::is_invocable<F, T, T>::value) { | |||
return 2; | |||
} else { | |||
static_assert(!std::is_same_v<F, F>, "unreachable"); | |||
} | |||
}(); | |||
static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargs) { | |||
if (nargs != n_args) { | |||
PyErr_Format(PyExc_ValueError, "expected %lu arguments", n_args); | |||
return nullptr; | |||
} | |||
for (size_t i=0; i<nargs; ++i) { | |||
if (args[i] == Py_None) { | |||
PyErr_SetString(PyExc_ValueError, "can not convert null PyObject to numpy dtype"); | |||
return nullptr; | |||
} | |||
} | |||
try { | |||
PyArray_Descr *dt1; | |||
if(!PyArray_DescrConverter(args[0], &dt1)) { | |||
throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||
args[0]->ob_type->tp_name)); | |||
} | |||
if constexpr (n_args == 1) { | |||
auto res = (*f)(dt1); | |||
Py_DECREF(dt1); | |||
return py::cast(res).release().ptr(); | |||
} else { | |||
PyArray_Descr *dt2; | |||
if(!PyArray_DescrConverter(args[1], &dt2)) { | |||
Py_DECREF(dt1); | |||
throw ConversionError(ssprintf("can not convert to numpy.dtype from %s", | |||
args[1]->ob_type->tp_name)); | |||
} | |||
auto&& res = (*f)(dt1, dt2); | |||
Py_DECREF(dt1); | |||
Py_DECREF(dt2); | |||
return py::cast(res).release().ptr(); | |||
} | |||
} catch (std::exception& e) { | |||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
return nullptr; | |||
} | |||
} | |||
}; | |||
} // anonymous namespace | |||
void init_dtypes(py::module m) { | |||
static PyMethodDef method_defs[] = { | |||
{"is_quantize", (PyCFunction)_wrap<&_is_quantize>::impl, METH_FASTCALL, nullptr}, | |||
{"get_scale", (PyCFunction)_wrap<&_get_scale>::impl, METH_FASTCALL, nullptr}, | |||
{"get_zero_point", (PyCFunction)_wrap<&_get_zero_point>::impl, METH_FASTCALL, nullptr}, | |||
{"is_dtype_equal", (PyCFunction)_wrap<&_is_dtype_equal>::impl, METH_FASTCALL, nullptr}, | |||
{nullptr, nullptr, 0, nullptr} | |||
}; | |||
for (auto&& def: method_defs) { | |||
if (def.ml_meth != nullptr) { | |||
auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||
if (!func) throw py::error_already_set(); | |||
py::setattr(m, def.ml_name, func); | |||
} | |||
} | |||
} | |||
} // namespace mgb |
@@ -36,6 +36,7 @@ namespace mgb { | |||
int npy_num_intb##n(); | |||
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | |||
#undef DEFINE_NPY_INTBX | |||
void init_dtypes(pybind11::module m); | |||
void init_npy_num_intbx(pybind11::module m); | |||
//! numpy type num for bfloat16 type | |||
@@ -9,16 +9,22 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megbrain/dtype.h" | |||
#include "megbrain/common.h" | |||
#include "./tensor.h" | |||
#include "./grad.h" | |||
#include "./trace.h" | |||
#include "./common.h" | |||
#include "./numpy_dtypes.h" | |||
#include "./graph_rt.h" | |||
#include "./helper.h" | |||
#include <pybind11/numpy.h> | |||
#include <pybind11/operators.h> | |||
#include "./helper.h" | |||
#include <unordered_map> | |||
namespace py = pybind11; | |||
namespace mgb::imperative::python { | |||
@@ -413,6 +419,198 @@ struct TensorWeakRef { | |||
} | |||
}; | |||
/* ============== convert inputs ============== */ | |||
// map numpy.dtype.kind to priority | |||
inline uint8_t category_priority(char c) { | |||
switch (c) { | |||
case 'f': return 3; // floating-point | |||
case 'i': return 2; // signed integer | |||
case 'u': return 2; // unsigned integer | |||
case 'b': return 1; // boolean | |||
default: return 0; | |||
} | |||
} | |||
// Returns the maximum value of the priority of each type in the list `types`. | |||
uint8_t max_priority(SmallVector<PyArray_Descr*> types) { | |||
if (types.size() == 0) { | |||
return 0; | |||
} else { | |||
uint8_t max_p = 0; | |||
for (auto&& desc: types) { | |||
max_p = std::max(max_p, category_priority(desc->kind)); | |||
} | |||
return max_p; | |||
} | |||
} | |||
// Returns the data type with sufficient size to hold all types of | |||
// category `cat` in the list `types`. | |||
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) { | |||
// Return value: New reference | |||
SmallVector<PyArray_Descr*> used_types; | |||
for (auto&& desc: types) { | |||
auto&& v = category_priority(desc->kind); | |||
if (v == cat) { | |||
used_types.emplace_back(desc); | |||
} | |||
} | |||
mgb_assert(used_types.size() > 0, "size of used_types is 0"); | |||
PyArray_Descr* res = used_types[0]; | |||
Py_INCREF(res); | |||
for (size_t i = 1; i < used_types.size(); ++i) { | |||
PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res); | |||
Py_DECREF(res); | |||
res = tmp; | |||
} | |||
return res; | |||
} | |||
PyArray_Descr* scalar2dtype(PyObject* arg) { | |||
// Return value: New reference | |||
if (PyBool_Check(arg)) { | |||
auto&& descr = PyArray_DescrFromType(NPY_BOOL); | |||
return descr; | |||
} | |||
if (PyLong_CheckExact(arg)) { | |||
auto&& descr = PyArray_DescrFromType(NPY_INT32); | |||
return descr; | |||
} | |||
if (PyFloat_CheckExact(arg)) { | |||
auto&& descr = PyArray_DescrFromType(NPY_FLOAT32); | |||
return descr; | |||
} | |||
return nullptr; | |||
} | |||
PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
// Return value: New reference | |||
SmallVector<PyArray_Descr*> tensors; | |||
SmallVector<PyArray_Descr*> scalars; | |||
bool is_tuple = false; | |||
PyObject* tuple; | |||
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||
if (PyList_Check(args[0])) { | |||
tuple = PyList_AsTuple(args[0]); | |||
} else { | |||
tuple = args[0]; | |||
Py_INCREF(tuple); | |||
} | |||
nargs = PyTuple_Size(tuple); | |||
is_tuple = true; | |||
} | |||
for (size_t i = 0; i < nargs; ++i) { | |||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
if (handle == Py_None) continue; | |||
TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||
if (tw) { | |||
mgb::DType type = tw->m_tensor->dtype(); | |||
auto&& descr = npy::dtype_mgb2np_descr(type); | |||
Py_INCREF(descr.get()); | |||
tensors.emplace_back(descr.get()); | |||
}else{ | |||
if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) { | |||
auto&& descr = PyArray_DescrFromObject(handle, nullptr); | |||
tensors.emplace_back(descr); | |||
continue; | |||
} | |||
PyArray_Descr* descr = scalar2dtype(handle); | |||
if (descr) { | |||
scalars.emplace_back(descr); | |||
continue; | |||
} | |||
} | |||
} | |||
auto max_pri_scalars = max_priority(scalars); | |||
auto max_pri_tensors = max_priority(tensors); | |||
if (max_pri_scalars <= 0 && max_pri_tensors <= 0) { | |||
throw py::value_error("invalid input, no dtype avaliable"); | |||
} | |||
PyArray_Descr* res; | |||
if (max_pri_scalars > max_pri_tensors) { | |||
res = promote_types(scalars, max_pri_scalars); | |||
}else{ | |||
res = promote_types(tensors, max_pri_tensors); | |||
} | |||
for (auto *p: tensors) { Py_DECREF(p); } | |||
for (auto *p: scalars) { Py_DECREF(p); } | |||
Py_DECREF(tuple); | |||
return res; | |||
} | |||
CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
bool is_tuple = false; | |||
PyObject* tuple; | |||
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { | |||
if (PyList_Check(args[0])) { | |||
tuple = PyList_AsTuple(args[0]); | |||
} else { | |||
tuple = args[0]; | |||
Py_INCREF(tuple); | |||
} | |||
nargs = PyTuple_Size(tuple); | |||
is_tuple = true; | |||
} | |||
bool valid = false; | |||
CompNode cn; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
TensorWrapper* tw = TensorWrapper::cast_safe(handle); | |||
if (tw) { | |||
if (!valid) { | |||
cn = tw->m_tensor->comp_node(); | |||
valid = true; | |||
} else { | |||
CompNode cn1 = tw->m_tensor->comp_node(); | |||
if (cn1 != cn) { | |||
throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||
cn.to_string().c_str(), cn1.to_string().c_str())); | |||
} | |||
} | |||
} | |||
} | |||
if (!valid) { | |||
mgb_assert(0, "expact at least 1 device"); | |||
} | |||
Py_DECREF(tuple); | |||
return cn; | |||
} | |||
// Returns the dtype that would result from performing an arithmetic | |||
// operation on the provided input tensors and scalars. | |||
PyObject* dtype_promotion(PyObject* self, PyObject*const* args, size_t nargs) { | |||
if (!nargs) { | |||
PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||
return nullptr; | |||
} | |||
try { | |||
PyArray_Descr* res = _dtype_promotion(args, nargs); | |||
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); | |||
} catch (std::exception& e) { | |||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
return nullptr; | |||
} | |||
} | |||
PyObject* get_device(PyObject* self, PyObject*const* args, size_t nargs) { | |||
if (!nargs) { | |||
PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); | |||
return nullptr; | |||
} | |||
try { | |||
CompNode cn = _get_device(args, nargs); | |||
return py::cast(cn).release().ptr(); | |||
} catch (std::exception& e) { | |||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
return nullptr; | |||
} | |||
} | |||
void init_tensor(py::module m) { | |||
interpreter_for_py = interpreter::Interpreter::inst().create_channel(); | |||
@@ -444,10 +642,19 @@ void init_tensor(py::module m) { | |||
.def(py::init<const TensorWrapper&>()) | |||
.def("__call__", &TensorWeakRef::operator()); | |||
static PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||
auto* apply_func = PyCFunction_NewEx(&apply_def, nullptr, nullptr); | |||
if (!apply_func) throw py::error_already_set(); | |||
py::setattr(m, "apply", apply_func); | |||
static PyMethodDef method_defs[] = { | |||
{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}, | |||
{"dtype_promotion", (PyCFunction)dtype_promotion, METH_FASTCALL, nullptr}, | |||
{"get_device", (PyCFunction)get_device, METH_FASTCALL, nullptr}, | |||
{nullptr, nullptr, 0, nullptr} | |||
}; | |||
for (auto&& def: method_defs) { | |||
if (def.ml_meth != nullptr) { | |||
auto* func = PyCFunction_NewEx(&def, nullptr, nullptr); | |||
if (!func) throw py::error_already_set(); | |||
py::setattr(m, def.ml_name, func); | |||
} | |||
} | |||
m.def("_set_swap_flag", | |||
[](bool flag) { interpreter_for_py->set_swap_flag(flag); }); | |||
@@ -113,7 +113,7 @@ def test_quint8_typecvt(): | |||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
def typecvt(x, dt=None): | |||
(y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
return y | |||
# convert to quint8 | |||
@@ -194,7 +194,7 @@ def test_quint4_typecvt(): | |||
data = np.random.random(shape).astype(np.float32) * 5 - 1 | |||
def typecvt(x, dt=None): | |||
(y,) = apply(ops.TypeCvt(dtype=dt), x) | |||
(y,) = G.apply_normal_op(ops.TypeCvt(dtype=dt), x) | |||
return y | |||
# convert to quint4 | |||