@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import ( | |||
SymbolVar, | |||
Tensor, | |||
apply, | |||
astype_cpp, | |||
broadcast_cpp, | |||
dtype_promotion, | |||
) | |||
@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp | |||
from ..ops import builtin | |||
from . import amp | |||
from .indexing import getitem, setitem | |||
from .utils import ( | |||
_normalize_axis, | |||
astensor1d, | |||
astype, | |||
cast_tensors, | |||
make_shape_tuple, | |||
subgraph, | |||
) | |||
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph | |||
_ElwMod = builtin.Elemwise.Mode | |||
@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC): | |||
r"""Returns a :class:`Tensor` with the same data and number of elements | |||
with the specified :attr:`~.Tensor.dtype`. | |||
""" | |||
return astype(self, dtype) | |||
return astype_cpp(self, dtype) | |||
def reshape(self, *args): | |||
r"""See :func:`~.reshape`.""" | |||
@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import ( | |||
_get_convert_inputs, | |||
_set_convert_inputs, | |||
apply, | |||
astype_cpp, | |||
convert_inputs_cpp, | |||
convert_single_value_cpp, | |||
dtype_promotion, | |||
get_device, | |||
make_shape_tuple, | |||
@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None): | |||
return result | |||
def astype(x, dtype):
    """Cast ``x`` to ``dtype``; return ``x`` unchanged when it already has it."""
    target = np.dtype(dtype)
    if is_dtype_equal(x.dtype, target):
        return x
    (x,) = apply(builtin.TypeCvt(dtype=target), x)
    return x
def convert_single_value(v, *, dtype=None, device=None):
    # Promote a single value to a Tensor/SymbolVar of the requested dtype/device.
    # NOTE(review): this span is a fused diff hunk — the Python body below was
    # deleted and replaced by the single convert_single_value_cpp call; the
    # second `return` is the patch's replacement line, not dead code.
    if isinstance(v, (Tensor, SymbolVar)):
        # quantized tensors keep their dtype; everything else is cast
        if not is_quantize(v.dtype):
            v = astype(v, dtype)
    else:
        # plain python / numpy value: wrap as a constant on the target device
        v = Const(v, dtype, device, None)
    return v
    return convert_single_value_cpp(v, dtype, device)
def convert_inputs(*args, device=None):
    # Promote all inputs to a common dtype and device before an elemwise op.
    # NOTE(review): fused diff hunk — the body below was deleted in favour of
    # the final convert_inputs_cpp call, which is the patch's new implementation.
    if not _get_convert_inputs():
        # conversion globally disabled: pass inputs through untouched
        return args
    dtype = dtype_promotion(args)
    if device is None:
        device = get_device(args)
    device = as_device(device)
    graph = None
    sym_type = None
    for a in args:
        if isinstance(a, SymbolVar):
            if graph is None:
                graph = a.var.graph
                sym_type = type(a)
            else:
                # all symbolic inputs must live in one computing graph
                assert graph == a.var.graph
    args = list(args)
    if graph is not None:
        # lift plain values into the symbolic graph as constants
        for i in range(len(args)):
            if not isinstance(args[i], SymbolVar):
                rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
                args[i] = sym_type(rst)

    def convert(value):
        if value is None:
            return value
        return convert_single_value(value, dtype=dtype, device=device.to_c())

    return tuple(map(convert, args))
    return convert_inputs_cpp(*args, device)
def cast_tensors(*args, promote=False): | |||
@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
pass | |||
except ValueError: | |||
if dtype is not None and dtype != x.dtype: | |||
x = astype(x, dtype) | |||
x = astype_cpp(x, dtype) | |||
if device is not None: | |||
cn = as_device(device).to_c() | |||
(x,) = apply(builtin.Copy(comp_node=cn), x) | |||
@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||
if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | |||
x = concatenate(x, device=device) if len(x) > 1 else x[0] | |||
if dtype is not None: | |||
x = astype(x, dtype) | |||
x = astype_cpp(x, dtype) | |||
return x | |||
x = Const(x, dtype, device, reference) | |||
return x | |||
@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph | |||
from ..core.tensor.array_method import _elwise_apply | |||
from ..core.tensor.utils import ( | |||
astensor1d, | |||
astype, | |||
cast_tensors, | |||
convert_single_value, | |||
make_shape_tuple, | |||
@@ -170,6 +170,12 @@ struct _wrap { | |||
} // anonymous namespace | |||
namespace imperative::python { | |||
// Compare two numpy dtype descriptors for equality; thin public wrapper
// around the file-local _is_dtype_equal helper so other TUs can call it.
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) {
    return _is_dtype_equal(dt1, dt2);
}
} // namespace imperative::python | |||
#ifdef METH_FASTCALL | |||
#define MGE_PY_INTERFACE(NAME, FUN) \ | |||
{ #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr } | |||
@@ -26,6 +26,11 @@ | |||
cb(BFloat16, npy_num_bfloat16()) | |||
namespace mgb { | |||
namespace imperative::python { | |||
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2); | |||
} // namespace imperative::python | |||
//! numpy type num for intb1/2/4 type | |||
#define DEFINE_NPY_INTBX(n) int npy_num_intb##n(); | |||
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | |||
@@ -400,223 +400,6 @@ struct TensorWeakRef { | |||
int _use_cnt() { return wptr.use_count(); } | |||
}; | |||
/* ============== convert inputs ============== */ | |||
// Map a numpy dtype kind character to a promotion-priority bucket:
// floating > integral (signed or unsigned) > boolean > everything else.
inline uint8_t category_priority(char c) {
    if (c == 'f')
        return 3;  // floating-point
    if (c == 'i' || c == 'u')
        return 2;  // signed / unsigned integer
    if (c == 'b')
        return 1;  // boolean
    return 0;  // unknown / unsupported kind
}
// Returns the highest category priority among `types` (0 for an empty list).
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    uint8_t best = 0;
    for (auto&& desc : types) {
        uint8_t p = category_priority(desc->kind);
        if (p > best) {
            best = p;
        }
    }
    return best;
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    // keep only descriptors whose kind matches the requested category
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    // left-fold PyArray_PromoteTypes over the matches, dropping each
    // intermediate reference as it is superseded
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}
// Map a python scalar (bool / int / float) to the numpy dtype MegEngine uses
// for it; any other object yields nullptr.
// Return value: New reference.
PyArray_Descr* scalar2dtype(PyObject* arg) {
    if (PyBool_Check(arg)) {
        return PyArray_DescrFromType(NPY_BOOL);
    }
    if (PyLong_CheckExact(arg)) {
        return PyArray_DescrFromType(NPY_INT32);
    }
    if (PyFloat_CheckExact(arg)) {
        return PyArray_DescrFromType(NPY_FLOAT32);
    }
    return nullptr;
}
// Compute the numpy dtype resulting from an arithmetic op over the given
// inputs (tensors, symbolvars, numpy arrays/scalars, python scalars).
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    // a single list/tuple argument is unpacked and treated as the arg pack
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);  // owned either way; released at the end
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None)
            continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());  // keep alive past the holder's lifetime
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }
            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }
            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }
    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);
    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        // NOTE(review): typo "avaliable" -> "available"; also the collected
        // descr references and `tuple` leak on this throw path.
        throw py::value_error("invalid input, no dtype avaliable");
    }
    // scalars participate only when their category outranks every tensor
    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }
    for (auto* p : tensors) {
        Py_DECREF(p);
    }
    for (auto* p : scalars) {
        Py_DECREF(p);
    }
    Py_XDECREF(tuple);
    return res;
}
// Infer the single computing device shared by all Tensor / SymbolVar inputs;
// falls back to the default device when none is present, raises on mismatch.
CompNode _get_device(PyObject* const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    // a single list/tuple argument is unpacked as the arg pack
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (tw || is_symvar) {
            if (!valid) {
                cn = tw ? tw->m_tensor->comp_node()
                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw ? tw->m_tensor->comp_node()
                                  : py::handle(handle)
                                            .cast<PySymbolVar*>()
                                            ->m_node->comp_node();
                if (cn1 != cn) {
                    // NOTE(review): `tuple` also leaks on this throw path
                    throw py::value_error(ssprintf(
                            "ambiguous device: %s (from %s) vs %s (from %s)",
                            cn.to_string().c_str(), cn.to_string_logical().c_str(),
                            cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
                }
            }
        }
    }
    // NOTE(review): this early return skips the Py_XDECREF(tuple) below,
    // leaking the tuple when the inputs were passed as a single list/tuple
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    Py_XDECREF(tuple);
    return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        // NOTE(review): `res` is a new reference; confirm dtype_np2mgb_descr
        // takes ownership, otherwise it leaks here.
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// Python-visible entry: return the unique device of the given inputs,
// or the default device when none of them carries one.
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
#ifdef METH_FASTCALL | |||
#define MGE_PY_INTERFACE(NAME, FUNC) \ | |||
{ #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr } | |||
@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp); | |||
WRAP_FUNC_PY35(broadcast_cpp); | |||
WRAP_FUNC_PY35(reshape_cpp); | |||
WRAP_FUNC_PY35(Const); | |||
WRAP_FUNC_PY35(astype_cpp); | |||
WRAP_FUNC_PY35(convert_single_value_cpp); | |||
WRAP_FUNC_PY35(convert_inputs_cpp); | |||
#undef WRAP_FUNC_PY35 | |||
#define MGE_PY_INTERFACE(NAME, FUNC) \ | |||
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | |||
@@ -779,6 +565,9 @@ void init_tensor(py::module m) { | |||
MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), | |||
MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), | |||
MGE_PY_INTERFACE(Const, Const), | |||
MGE_PY_INTERFACE(astype_cpp, astype_cpp), | |||
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), | |||
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), | |||
{nullptr, nullptr, 0, nullptr}}; | |||
for (auto&& def : method_defs) { | |||
if (def.ml_meth != nullptr) { | |||
@@ -52,6 +52,223 @@ namespace views = ranges::views; | |||
namespace mgb::imperative::python { | |||
/* ============== convert inputs ============== */ | |||
// Map a numpy dtype kind character to a promotion-priority bucket:
// floating > integral (signed or unsigned) > boolean > everything else.
inline uint8_t category_priority(char c) {
    if (c == 'f')
        return 3;  // floating-point
    if (c == 'i' || c == 'u')
        return 2;  // signed / unsigned integer
    if (c == 'b')
        return 1;  // boolean
    return 0;  // unknown / unsupported kind
}
// Returns the highest category priority among `types` (0 for an empty list).
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    uint8_t best = 0;
    for (auto&& desc : types) {
        uint8_t p = category_priority(desc->kind);
        if (p > best) {
            best = p;
        }
    }
    return best;
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    // keep only descriptors whose kind matches the requested category
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    // left-fold PyArray_PromoteTypes over the matches, dropping each
    // intermediate reference as it is superseded
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}
// Map a python scalar (bool / int / float) to the numpy dtype MegEngine uses
// for it; any other object yields nullptr.
// Return value: New reference.
PyArray_Descr* scalar2dtype(PyObject* arg) {
    if (PyBool_Check(arg)) {
        return PyArray_DescrFromType(NPY_BOOL);
    }
    if (PyLong_CheckExact(arg)) {
        return PyArray_DescrFromType(NPY_INT32);
    }
    if (PyFloat_CheckExact(arg)) {
        return PyArray_DescrFromType(NPY_FLOAT32);
    }
    return nullptr;
}
// Compute the numpy dtype resulting from an arithmetic op over the given
// inputs (tensors, symbolvars, numpy arrays/scalars, python scalars).
// Return value: New reference.
// Throws py::value_error when no input carries a usable dtype.
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    // a single list/tuple argument is unpacked and treated as the arg pack
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    // Release every owned reference collected above; needed on both the
    // success and the error path (the original leaked them when throwing).
    auto cleanup = [&]() {
        for (auto* p : tensors) {
            Py_DECREF(p);
        }
        for (auto* p : scalars) {
            Py_DECREF(p);
        }
        Py_XDECREF(tuple);
    };
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None)
            continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());  // keep alive past the holder's lifetime
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }
            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }
            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }
    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);
    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        cleanup();
        // fix: typo "avaliable" -> "available"
        throw py::value_error("invalid input, no dtype available");
    }
    // scalars participate only when their category outranks every tensor
    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }
    cleanup();
    return res;
}
// Infer the single computing device shared by all Tensor / SymbolVar inputs;
// falls back to the default device when none is present, raises on mismatch.
CompNode _get_device(PyObject* const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    // a single list/tuple argument is unpacked as the arg pack
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (!tw && !is_symvar) {
            continue;
        }
        CompNode cur = tw ? tw->m_tensor->comp_node()
                          : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
        if (!valid) {
            cn = cur;
            valid = true;
        } else if (cur != cn) {
            Py_XDECREF(tuple);  // fix: release before throwing as well
            throw py::value_error(ssprintf(
                    "ambiguous device: %s (from %s) vs %s (from %s)",
                    cn.to_string().c_str(), cn.to_string_logical().c_str(),
                    cur.to_string().c_str(), cur.to_string_logical().c_str()));
        }
    }
    // fix: the original leaked `tuple` on the early default-device return
    Py_XDECREF(tuple);
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    return cn;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        // NOTE(review): `res` is a new reference; confirm dtype_np2mgb_descr
        // takes ownership, otherwise it leaks here.
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// Python-visible entry: return the unique device of the given inputs,
// or the default device when none of them carries one.
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
bool is_scalar(PyObject* tensor) { | |||
if (py::isinstance<PySymbolVar>(py::handle(tensor))) { | |||
auto var = py::handle(tensor).cast<PySymbolVar*>(); | |||
@@ -147,7 +364,6 @@ py::object _Const( | |||
"dmap_callback"); | |||
if (dmap.ptr() != Py_None) { | |||
device_obj = dmap(device); | |||
py::print(device_obj); | |||
} else { | |||
device_obj = py::cast(CompNode::load(device.cast<std::string>())); | |||
} | |||
@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) { | |||
return ret[0]; | |||
} | |||
// Fetch the MegEngine dtype of a Tensor or SymbolVar handle.
mgb::DType _get_dtype(py::handle tensor) {
    auto* tw = TensorWrapper::try_cast(tensor.ptr());
    if (tw) {
        return tw->m_tensor->dtype();
    }
    return tensor.cast<PySymbolVar*>()->m_node->dtype();
}
// Cast `tensor` to the dtype described by `dtype_hdl` via a TypeCvt op;
// returns the input unchanged when the dtypes already match.
// Throws py::value_error when `dtype_hdl` is not convertible to a numpy dtype.
py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
    PyArray_Descr* descr;
    if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
        throw py::value_error(ssprintf(
                "can not convert to numpy.dtype from %s",
                dtype_hdl.ptr()->ob_type->tp_name));
    }
    // Fix: keep the descriptor holder alive across the comparison. The
    // original called .get() on the temporary returned by dtype_mgb2np_descr,
    // leaving `cur` dangling once the temporary was destroyed.
    auto cur_holder = npy::dtype_mgb2np_descr(_get_dtype(tensor));
    PyArray_Descr* cur = cur_holder.get();
    if (!dtype_equal(cur, descr)) {
        std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
        py::object Op = py::cast(op);
        std::vector<PyObject*> p(2);
        p[0] = Op.ptr();
        p[1] = tensor.ptr();
        py::tuple ret =
                py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
        return ret[0];
    } else {
        // NOTE(review): `descr` is a new reference from PyArray_DescrConverter
        // and is not released on this path (same as the original); confirm the
        // ownership convention of dtype_np2mgb_descr before adding Py_DECREF.
        return py::reinterpret_borrow<py::object>(tensor);
    }
}
// Promote one value: plain python values become constants; non-quantized
// tensors/symbolvars are cast to `dtype`; quantized ones pass through as-is.
py::object _convert_single_value_cpp(
        py::handle value, py::handle dtype, py::handle device) {
    if (!is_tensor_or_symbolvar(value)) {
        return _Const(value, dtype, device, py::none());
    }
    if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
        return _astype_cpp(value, dtype);
    }
    return py::reinterpret_borrow<py::object>(value);
}
// C++ counterpart of python convert_inputs: promote every arg to `dtype` on
// `device`; when any arg is a SymbolVar, all symbolic args must share one
// computing graph, and values are lifted into it as ImmutableTensor nodes.
py::object _convert_inputs_cpp(
        PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
    ComputingGraph* graph = nullptr;
    py::handle typeobj;  // python type of the first SymbolVar seen
    py::list lis;
    // first pass: collect args and locate the (unique) computing graph
    for (size_t i = 0; i < nargs; ++i) {
        py::handle h = py::handle(args[i]);
        lis.append(h);
        if (py::isinstance<PySymbolVar>(h)) {
            auto var = h.cast<PySymbolVar*>();
            auto g = var->m_node->owner_graph();
            if (!graph) {
                graph = g;
                typeobj = h.get_type();
            } else {
                // all symbolic inputs must live in one graph
                mgb_assert(graph == g);
            }
        }
    }
    if (graph) {
        CompNode cn = device.cast<CompNode>();
        for (size_t i = 0; i < nargs; ++i) {
            OperatorNodeConfig config(cn);
            auto hv = npy::np2tensor(
                    lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
            // NOTE(review): the python version lifts NON-symbol values into the
            // graph (`if not isinstance(args[i], SymbolVar)`), but this branch
            // re-wraps the symbol vars instead — confirm the condition is not
            // inverted relative to the intended behavior.
            if (py::isinstance<PySymbolVar>(lis[i])) {
                lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
            }
        }
    }
    // second pass: per-element promotion; None entries pass through untouched
    auto convert = [&](py::object value) {
        if (value.ptr() == Py_None) {
            return value;
        }
        return _convert_single_value_cpp(value, dtype, device);
    };
    for (size_t i = 0; i < lis.size(); ++i) {
        lis[i] = convert(lis[i]);
    }
    return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | |||
try { | |||
return _make_shape_tuple(py::handle(args[0])).release().ptr(); | |||
@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { | |||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||
} | |||
// METH_FASTCALL entry point for astype: args = (tensor, dtype).
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        py::object res = _astype_cpp(py::handle(args[0]), py::handle(args[1]));
        return res.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// METH_FASTCALL entry point: args = (value, dtype, device).
PyObject* convert_single_value_cpp(
        PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        return _convert_single_value_cpp(
                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// METH_FASTCALL entry point: convert_inputs_cpp(*values, device). The trailing
// argument is the device (None => inferred from the tensor/symbolvar values).
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    try {
        // fix: `nargs - 1` below underflows size_t on an empty call
        if (nargs == 0) {
            PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
            return nullptr;
        }
        py::object dtype = py::reinterpret_steal<py::object>(
                dtype_promotion(self, args, nargs - 1));
        // fix: propagate a raised python error instead of casting a null object
        if (!dtype) {
            return nullptr;
        }
        py::object device;
        if (args[nargs - 1] == Py_None) {
            device = py::reinterpret_steal<py::object>(
                    get_device(self, args, nargs - 1));
            if (!device) {
                return nullptr;
            }
        } else {
            device = py::reinterpret_borrow<py::object>(args[nargs - 1]);
        }
        return _convert_inputs_cpp(args, nargs - 1, dtype, device).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python |
@@ -2,6 +2,10 @@ | |||
namespace mgb::imperative::python { | |||
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs); | |||
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs); | |||
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs); | |||
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); | |||
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
} // namespace mgb::imperative::python |