@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import ( | |||||
SymbolVar, | SymbolVar, | ||||
Tensor, | Tensor, | ||||
apply, | apply, | ||||
astype_cpp, | |||||
broadcast_cpp, | broadcast_cpp, | ||||
dtype_promotion, | dtype_promotion, | ||||
) | ) | ||||
@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp | |||||
from ..ops import builtin | from ..ops import builtin | ||||
from . import amp | from . import amp | ||||
from .indexing import getitem, setitem | from .indexing import getitem, setitem | ||||
from .utils import ( | |||||
_normalize_axis, | |||||
astensor1d, | |||||
astype, | |||||
cast_tensors, | |||||
make_shape_tuple, | |||||
subgraph, | |||||
) | |||||
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph | |||||
_ElwMod = builtin.Elemwise.Mode | _ElwMod = builtin.Elemwise.Mode | ||||
@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC): | |||||
r"""Returns a :class:`Tensor` with the same data and number of elements | r"""Returns a :class:`Tensor` with the same data and number of elements | ||||
with the specified :attr:`~.Tensor.dtype`. | with the specified :attr:`~.Tensor.dtype`. | ||||
""" | """ | ||||
return astype(self, dtype) | |||||
return astype_cpp(self, dtype) | |||||
def reshape(self, *args): | def reshape(self, *args): | ||||
r"""See :func:`~.reshape`.""" | r"""See :func:`~.reshape`.""" | ||||
@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import ( | |||||
_get_convert_inputs, | _get_convert_inputs, | ||||
_set_convert_inputs, | _set_convert_inputs, | ||||
apply, | apply, | ||||
astype_cpp, | |||||
convert_inputs_cpp, | |||||
convert_single_value_cpp, | |||||
dtype_promotion, | dtype_promotion, | ||||
get_device, | get_device, | ||||
make_shape_tuple, | make_shape_tuple, | ||||
@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None): | |||||
return result | return result | ||||
def astype(x, dtype):
    """Return ``x`` converted to ``dtype``.

    ``dtype`` is normalised through ``np.dtype``. When it already equals
    ``x.dtype`` the input object itself is returned; otherwise a ``TypeCvt``
    op is applied and its single output is returned.
    """
    target = np.dtype(dtype)
    if is_dtype_equal(x.dtype, target):
        return x
    (converted,) = apply(builtin.TypeCvt(dtype=target), x)
    return converted
def convert_single_value(v, *, dtype=None, device=None):
    """Convert one value to the requested dtype/device.

    The work is delegated to ``convert_single_value_cpp``; the previous
    pure-Python body was left in front of this call, which made the
    delegation unreachable.

    Args:
        v: value to convert.
        dtype: target dtype, forwarded positionally to the C++ helper.
        device: target device, forwarded positionally to the C++ helper.

    Returns:
        Whatever ``convert_single_value_cpp`` returns for these arguments.
    """
    return convert_single_value_cpp(v, dtype, device)
def convert_inputs(*args, device=None):
    """Convert all inputs to a common dtype/device.

    When input conversion is switched off (``_get_convert_inputs()`` is
    falsy) the arguments are returned untouched. Otherwise the work is
    delegated to ``convert_inputs_cpp``, which receives the inputs followed
    by ``device`` as the last positional argument. The previous pure-Python
    body was left in front of this call, which made the delegation
    unreachable.

    Args:
        *args: values to convert.
        device: optional target device; ``None`` is forwarded as-is.

    Returns:
        ``args`` unchanged, or the result of ``convert_inputs_cpp``.
    """
    if not _get_convert_inputs():
        return args
    return convert_inputs_cpp(*args, device)
def cast_tensors(*args, promote=False): | def cast_tensors(*args, promote=False): | ||||
@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
pass | pass | ||||
except ValueError: | except ValueError: | ||||
if dtype is not None and dtype != x.dtype: | if dtype is not None and dtype != x.dtype: | ||||
x = astype(x, dtype) | |||||
x = astype_cpp(x, dtype) | |||||
if device is not None: | if device is not None: | ||||
cn = as_device(device).to_c() | cn = as_device(device).to_c() | ||||
(x,) = apply(builtin.Copy(comp_node=cn), x) | (x,) = apply(builtin.Copy(comp_node=cn), x) | ||||
@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None): | |||||
if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | if any(isinstance(i, (Tensor, SymbolVar)) for i in x): | ||||
x = concatenate(x, device=device) if len(x) > 1 else x[0] | x = concatenate(x, device=device) if len(x) > 1 else x[0] | ||||
if dtype is not None: | if dtype is not None: | ||||
x = astype(x, dtype) | |||||
x = astype_cpp(x, dtype) | |||||
return x | return x | ||||
x = Const(x, dtype, device, reference) | x = Const(x, dtype, device, reference) | ||||
return x | return x | ||||
@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph | |||||
from ..core.tensor.array_method import _elwise_apply | from ..core.tensor.array_method import _elwise_apply | ||||
from ..core.tensor.utils import ( | from ..core.tensor.utils import ( | ||||
astensor1d, | astensor1d, | ||||
astype, | |||||
cast_tensors, | cast_tensors, | ||||
convert_single_value, | convert_single_value, | ||||
make_shape_tuple, | make_shape_tuple, | ||||
@@ -170,6 +170,12 @@ struct _wrap { | |||||
} // anonymous namespace | } // anonymous namespace | ||||
namespace imperative::python { | |||||
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { | |||||
return _is_dtype_equal(dt1, dt2); | |||||
} | |||||
} // namespace imperative::python | |||||
#ifdef METH_FASTCALL | #ifdef METH_FASTCALL | ||||
#define MGE_PY_INTERFACE(NAME, FUN) \ | #define MGE_PY_INTERFACE(NAME, FUN) \ | ||||
{ #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr } | { #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr } | ||||
@@ -26,6 +26,11 @@ | |||||
cb(BFloat16, npy_num_bfloat16()) | cb(BFloat16, npy_num_bfloat16()) | ||||
namespace mgb { | namespace mgb { | ||||
namespace imperative::python { | |||||
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2); | |||||
} // namespace imperative::python | |||||
//! numpy type num for intb1/2/4 type | //! numpy type num for intb1/2/4 type | ||||
#define DEFINE_NPY_INTBX(n) int npy_num_intb##n(); | #define DEFINE_NPY_INTBX(n) int npy_num_intb##n(); | ||||
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) | ||||
@@ -400,223 +400,6 @@ struct TensorWeakRef { | |||||
int _use_cnt() { return wptr.use_count(); } | int _use_cnt() { return wptr.use_count(); } | ||||
}; | }; | ||||
/* ============== convert inputs ============== */ | |||||
// Map a numpy.dtype.kind character to its promotion priority:
// 'f' (floating-point) -> 3, 'i'/'u' (signed/unsigned integer) -> 2,
// 'b' (boolean) -> 1, anything else -> 0.
inline uint8_t category_priority(char c) {
    if (c == 'f') {
        return 3;
    }
    if (c == 'i' || c == 'u') {
        return 2;
    }
    if (c == 'b') {
        return 1;
    }
    return 0;
}
// Returns the highest category priority found among `types`
// (0 when the list is empty).
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    uint8_t best = 0;
    for (auto* descr : types) {
        const uint8_t prio = category_priority(descr->kind);
        if (prio > best) {
            best = prio;
        }
    }
    return best;
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
//
// Return value: new reference. Asserts that at least one entry of `types`
// belongs to category `cat`. Throws py::error_already_set when NumPy fails
// to promote a pair of types.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        if (category_priority(desc->kind) == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        if (!tmp) {
            // NumPy has set a Python error; do not feed NULL into the next
            // round or hand it back to the caller.
            throw py::error_already_set();
        }
        res = tmp;
    }
    return res;
}
// Pick the default numpy descriptor for a plain Python scalar:
// bool -> NPY_BOOL, exact int -> NPY_INT32, exact float -> NPY_FLOAT32.
// Returns nullptr for any other object.
// Return value: New reference
PyArray_Descr* scalar2dtype(PyObject* arg) {
    int type_num;
    // bool is tested first, matching the original check order.
    if (PyBool_Check(arg)) {
        type_num = NPY_BOOL;
    } else if (PyLong_CheckExact(arg)) {
        type_num = NPY_INT32;
    } else if (PyFloat_CheckExact(arg)) {
        type_num = NPY_FLOAT32;
    } else {
        return nullptr;
    }
    return PyArray_DescrFromType(type_num);
}
// Computes the promoted numpy descriptor for a mix of tensors, symbol vars,
// numpy arrays/scalars and plain Python scalars. A single list/tuple argument
// is unpacked and its items are used as the inputs; None items are ignored.
//
// Return value: new reference.
// Throws py::value_error when no input contributes a usable dtype, and
// py::error_already_set when a CPython/NumPy call fails.
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;
    // Every descriptor stored in the two lists is an owned reference; release
    // them on all exit paths, including the throwing ones.
    struct DescrRefs {
        SmallVector<PyArray_Descr*>& descrs;
        ~DescrRefs() {
            for (auto* p : descrs) {
                Py_DECREF(p);
            }
        }
    };
    DescrRefs tensor_refs{tensors};
    DescrRefs scalar_refs{scalars};

    // Owns the tuple view of a single list/tuple argument (null otherwise).
    py::object tuple;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = py::reinterpret_steal<py::object>(PyList_AsTuple(args[0]));
            if (!tuple) {
                throw py::error_already_set();
            }
        } else {
            tuple = py::reinterpret_borrow<py::object>(args[0]);
        }
        nargs = PyTuple_Size(tuple.ptr());
    }
    const bool is_tuple = static_cast<bool>(tuple);

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple.ptr(), i) : args[i];
        if (handle == Py_None) {
            continue;
        }
        if (TensorWrapper* tw = TensorWrapper::try_cast(handle)) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
            continue;
        }
        if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
            PyArray_Descr* descr = PyArray_DescrFromObject(handle, nullptr);
            if (!descr) {
                throw py::error_already_set();
            }
            tensors.emplace_back(descr);
            continue;
        }
        if (py::isinstance<PySymbolVar>(py::handle(handle))) {
            auto var = py::handle(handle).cast<PySymbolVar*>();
            mgb::DType type = var->m_node->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
            continue;
        }
        if (PyArray_Descr* descr = scalar2dtype(handle)) {
            scalars.emplace_back(descr);
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);
    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype avaliable");
    }
    // Scalars only decide the result when their category outranks every
    // tensor category.
    if (max_pri_scalars > max_pri_tensors) {
        return promote_types(scalars, max_pri_scalars);
    }
    return promote_types(tensors, max_pri_tensors);
}
// Determines the comp node shared by all tensor / symbol-var inputs. A single
// list/tuple argument is unpacked and its items are used as the inputs.
//
// Returns the common comp node, or the default device's comp node when no
// input is a tensor or symbol var.
// Throws py::value_error when two inputs live on different comp nodes, and
// py::error_already_set when the list cannot be converted to a tuple.
CompNode _get_device(PyObject* const* args, size_t nargs) {
    // Owns the tuple view so it is released on every exit path (the original
    // leaked it on the default-device return and on the throw).
    py::object tuple;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = py::reinterpret_steal<py::object>(PyList_AsTuple(args[0]));
            if (!tuple) {
                throw py::error_already_set();
            }
        } else {
            tuple = py::reinterpret_borrow<py::object>(args[0]);
        }
        nargs = PyTuple_Size(tuple.ptr());
    }
    const bool is_tuple = static_cast<bool>(tuple);

    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple.ptr(), i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (!tw && !py::isinstance<PySymbolVar>(py::handle(handle))) {
            continue;
        }
        CompNode cur =
                tw ? tw->m_tensor->comp_node()
                   : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
        if (!valid) {
            cn = cur;
            valid = true;
        } else if (cur != cn) {
            throw py::value_error(ssprintf(
                    "ambiguous device: %s (from %s) vs %s (from %s)",
                    cn.to_string().c_str(), cn.to_string_logical().c_str(),
                    cur.to_string().c_str(), cur.to_string_logical().c_str()));
        }
    }
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    return cn;
}
// Python entry point: returns the dtype that would result from performing an
// arithmetic operation on the provided input tensors and scalars.
// An empty argument list is rejected with TypeError; C++ exceptions raised
// while promoting are handled by PYEXT17_TRANSLATE_EXC_RET.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (nargs == 0) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* promoted = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(promoted)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// Python entry point: returns the comp node computed by _get_device for the
// given inputs. An empty argument list is rejected with TypeError; C++
// exceptions are handled by PYEXT17_TRANSLATE_EXC_RET.
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (nargs == 0) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
#ifdef METH_FASTCALL | #ifdef METH_FASTCALL | ||||
#define MGE_PY_INTERFACE(NAME, FUNC) \ | #define MGE_PY_INTERFACE(NAME, FUNC) \ | ||||
{ #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr } | { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr } | ||||
@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp); | |||||
WRAP_FUNC_PY35(broadcast_cpp); | WRAP_FUNC_PY35(broadcast_cpp); | ||||
WRAP_FUNC_PY35(reshape_cpp); | WRAP_FUNC_PY35(reshape_cpp); | ||||
WRAP_FUNC_PY35(Const); | WRAP_FUNC_PY35(Const); | ||||
WRAP_FUNC_PY35(astype_cpp); | |||||
WRAP_FUNC_PY35(convert_single_value_cpp); | |||||
WRAP_FUNC_PY35(convert_inputs_cpp); | |||||
#undef WRAP_FUNC_PY35 | #undef WRAP_FUNC_PY35 | ||||
#define MGE_PY_INTERFACE(NAME, FUNC) \ | #define MGE_PY_INTERFACE(NAME, FUNC) \ | ||||
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | ||||
@@ -779,6 +565,9 @@ void init_tensor(py::module m) { | |||||
MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), | MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), | ||||
MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), | MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), | ||||
MGE_PY_INTERFACE(Const, Const), | MGE_PY_INTERFACE(Const, Const), | ||||
MGE_PY_INTERFACE(astype_cpp, astype_cpp), | |||||
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), | |||||
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), | |||||
{nullptr, nullptr, 0, nullptr}}; | {nullptr, nullptr, 0, nullptr}}; | ||||
for (auto&& def : method_defs) { | for (auto&& def : method_defs) { | ||||
if (def.ml_meth != nullptr) { | if (def.ml_meth != nullptr) { | ||||
@@ -52,6 +52,223 @@ namespace views = ranges::views; | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
/* ============== convert inputs ============== */ | |||||
// Map a numpy.dtype.kind character to its promotion priority:
// 'f' (floating-point) -> 3, 'i'/'u' (signed/unsigned integer) -> 2,
// 'b' (boolean) -> 1, anything else -> 0.
inline uint8_t category_priority(char c) {
    if (c == 'f') {
        return 3;
    }
    if (c == 'i' || c == 'u') {
        return 2;
    }
    if (c == 'b') {
        return 1;
    }
    return 0;
}
// Returns the highest category priority found among `types`
// (0 when the list is empty).
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    uint8_t best = 0;
    for (auto* descr : types) {
        const uint8_t prio = category_priority(descr->kind);
        if (prio > best) {
            best = prio;
        }
    }
    return best;
}
// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
//
// Return value: new reference. Asserts that at least one entry of `types`
// belongs to category `cat`. Throws py::error_already_set when NumPy fails
// to promote a pair of types.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        if (category_priority(desc->kind) == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);
    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        if (!tmp) {
            // NumPy has set a Python error; do not feed NULL into the next
            // round or hand it back to the caller.
            throw py::error_already_set();
        }
        res = tmp;
    }
    return res;
}
// Pick the default numpy descriptor for a plain Python scalar:
// bool -> NPY_BOOL, exact int -> NPY_INT32, exact float -> NPY_FLOAT32.
// Returns nullptr for any other object.
// Return value: New reference
PyArray_Descr* scalar2dtype(PyObject* arg) {
    int type_num;
    // bool is tested first, matching the original check order.
    if (PyBool_Check(arg)) {
        type_num = NPY_BOOL;
    } else if (PyLong_CheckExact(arg)) {
        type_num = NPY_INT32;
    } else if (PyFloat_CheckExact(arg)) {
        type_num = NPY_FLOAT32;
    } else {
        return nullptr;
    }
    return PyArray_DescrFromType(type_num);
}
// Computes the promoted numpy descriptor for a mix of tensors, symbol vars,
// numpy arrays/scalars and plain Python scalars. A single list/tuple argument
// is unpacked and its items are used as the inputs; None items are ignored.
//
// Return value: new reference.
// Throws py::value_error when no input contributes a usable dtype, and
// py::error_already_set when a CPython/NumPy call fails.
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;
    // Every descriptor stored in the two lists is an owned reference; release
    // them on all exit paths, including the throwing ones.
    struct DescrRefs {
        SmallVector<PyArray_Descr*>& descrs;
        ~DescrRefs() {
            for (auto* p : descrs) {
                Py_DECREF(p);
            }
        }
    };
    DescrRefs tensor_refs{tensors};
    DescrRefs scalar_refs{scalars};

    // Owns the tuple view of a single list/tuple argument (null otherwise).
    py::object tuple;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = py::reinterpret_steal<py::object>(PyList_AsTuple(args[0]));
            if (!tuple) {
                throw py::error_already_set();
            }
        } else {
            tuple = py::reinterpret_borrow<py::object>(args[0]);
        }
        nargs = PyTuple_Size(tuple.ptr());
    }
    const bool is_tuple = static_cast<bool>(tuple);

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple.ptr(), i) : args[i];
        if (handle == Py_None) {
            continue;
        }
        if (TensorWrapper* tw = TensorWrapper::try_cast(handle)) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
            continue;
        }
        if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
            PyArray_Descr* descr = PyArray_DescrFromObject(handle, nullptr);
            if (!descr) {
                throw py::error_already_set();
            }
            tensors.emplace_back(descr);
            continue;
        }
        if (py::isinstance<PySymbolVar>(py::handle(handle))) {
            auto var = py::handle(handle).cast<PySymbolVar*>();
            mgb::DType type = var->m_node->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
            continue;
        }
        if (PyArray_Descr* descr = scalar2dtype(handle)) {
            scalars.emplace_back(descr);
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);
    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype avaliable");
    }
    // Scalars only decide the result when their category outranks every
    // tensor category.
    if (max_pri_scalars > max_pri_tensors) {
        return promote_types(scalars, max_pri_scalars);
    }
    return promote_types(tensors, max_pri_tensors);
}
// Determines the comp node shared by all tensor / symbol-var inputs. A single
// list/tuple argument is unpacked and its items are used as the inputs.
//
// Returns the common comp node, or the default device's comp node when no
// input is a tensor or symbol var.
// Throws py::value_error when two inputs live on different comp nodes, and
// py::error_already_set when the list cannot be converted to a tuple.
CompNode _get_device(PyObject* const* args, size_t nargs) {
    // Owns the tuple view so it is released on every exit path (the original
    // leaked it on the default-device return and on the throw).
    py::object tuple;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = py::reinterpret_steal<py::object>(PyList_AsTuple(args[0]));
            if (!tuple) {
                throw py::error_already_set();
            }
        } else {
            tuple = py::reinterpret_borrow<py::object>(args[0]);
        }
        nargs = PyTuple_Size(tuple.ptr());
    }
    const bool is_tuple = static_cast<bool>(tuple);

    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple.ptr(), i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (!tw && !py::isinstance<PySymbolVar>(py::handle(handle))) {
            continue;
        }
        CompNode cur =
                tw ? tw->m_tensor->comp_node()
                   : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
        if (!valid) {
            cn = cur;
            valid = true;
        } else if (cur != cn) {
            throw py::value_error(ssprintf(
                    "ambiguous device: %s (from %s) vs %s (from %s)",
                    cn.to_string().c_str(), cn.to_string_logical().c_str(),
                    cur.to_string().c_str(), cur.to_string_logical().c_str()));
        }
    }
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    return cn;
}
// Python entry point: returns the dtype that would result from performing an
// arithmetic operation on the provided input tensors and scalars.
// An empty argument list is rejected with TypeError; C++ exceptions raised
// while promoting are handled by PYEXT17_TRANSLATE_EXC_RET.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (nargs == 0) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* promoted = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(promoted)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// Python entry point: returns the comp node computed by _get_device for the
// given inputs. An empty argument list is rejected with TypeError; C++
// exceptions are handled by PYEXT17_TRANSLATE_EXC_RET.
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (nargs == 0) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
bool is_scalar(PyObject* tensor) { | bool is_scalar(PyObject* tensor) { | ||||
if (py::isinstance<PySymbolVar>(py::handle(tensor))) { | if (py::isinstance<PySymbolVar>(py::handle(tensor))) { | ||||
auto var = py::handle(tensor).cast<PySymbolVar*>(); | auto var = py::handle(tensor).cast<PySymbolVar*>(); | ||||
@@ -147,7 +364,6 @@ py::object _Const( | |||||
"dmap_callback"); | "dmap_callback"); | ||||
if (dmap.ptr() != Py_None) { | if (dmap.ptr() != Py_None) { | ||||
device_obj = dmap(device); | device_obj = dmap(device); | ||||
py::print(device_obj); | |||||
} else { | } else { | ||||
device_obj = py::cast(CompNode::load(device.cast<std::string>())); | device_obj = py::cast(CompNode::load(device.cast<std::string>())); | ||||
} | } | ||||
@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) { | |||||
return ret[0]; | return ret[0]; | ||||
} | } | ||||
// Returns the dtype of `tensor`. A TensorWrapper is tried first; any other
// object is cast to PySymbolVar* and the dtype of its node is returned.
mgb::DType _get_dtype(py::handle tensor) {
    TensorWrapper* wrapper = TensorWrapper::try_cast(tensor.ptr());
    if (wrapper) {
        return wrapper->m_tensor->dtype();
    }
    PySymbolVar* sym = tensor.cast<PySymbolVar*>();
    return sym->m_node->dtype();
}
// Converts `tensor` to the dtype described by `dtype_hdl`.
//
// `dtype_hdl` may be anything PyArray_DescrConverter accepts. When the
// requested dtype equals the tensor's current dtype, the input object itself
// is returned; otherwise a TypeCvt op is applied and its first output is
// returned.
// Throws py::value_error when `dtype_hdl` cannot be turned into a numpy dtype.
py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
    PyArray_Descr* descr = nullptr;
    if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
        // Drop NumPy's pending error so that only the ValueError below is
        // reported.
        PyErr_Clear();
        throw py::value_error(ssprintf(
                "can not convert to numpy.dtype from %s",
                dtype_hdl.ptr()->ob_type->tp_name));
    }
    // The converter hands back a new reference; own it so it is released on
    // every path (the original never released it).
    py::object descr_owner =
            py::reinterpret_steal<py::object>(reinterpret_cast<PyObject*>(descr));
    // Keep the holder returned by dtype_mgb2np_descr alive while its raw
    // pointer is in use; taking .get() on the temporary left `cur` pointing
    // at a descriptor whose holder was already destroyed.
    auto&& cur = npy::dtype_mgb2np_descr(_get_dtype(tensor));
    if (dtype_equal(cur.get(), descr)) {
        return py::reinterpret_borrow<py::object>(tensor);
    }
    std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
    py::object Op = py::cast(op);
    PyObject* call_args[2] = {Op.ptr(), tensor.ptr()};
    py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, call_args, 2));
    return ret[0];
}
// Converts a single value to `dtype` / `device`:
//  * anything that is not a tensor or symbol var is wrapped via _Const;
//  * tensors / symbol vars with a QUANTIZED dtype are returned unchanged;
//  * all other tensors / symbol vars go through _astype_cpp.
py::object _convert_single_value_cpp(
        py::handle value, py::handle dtype, py::handle device) {
    if (!is_tensor_or_symbolvar(value)) {
        return _Const(value, dtype, device, py::none());
    }
    if (_get_dtype(value).category() == DTypeCategory::QUANTIZED) {
        return py::reinterpret_borrow<py::object>(value);
    }
    return _astype_cpp(value, dtype);
}
// Converts `nargs` inputs to the common `dtype` on `device` and returns them
// as a tuple.
//
// When at least one input is a symbol var, every input that is not a symbol
// var is first turned into an immutable constant on that var's graph (all
// symbol vars must share one graph). Each entry is then passed through
// _convert_single_value_cpp; None entries are passed through untouched.
py::object _convert_inputs_cpp(
        PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
    ComputingGraph* graph = nullptr;
    py::handle typeobj;
    py::list lis;
    for (size_t i = 0; i < nargs; ++i) {
        py::handle h = py::handle(args[i]);
        lis.append(h);
        if (py::isinstance<PySymbolVar>(h)) {
            auto var = h.cast<PySymbolVar*>();
            auto g = var->m_node->owner_graph();
            if (!graph) {
                graph = g;
                typeobj = h.get_type();
            } else {
                mgb_assert(graph == g);
            }
        }
    }
    if (graph) {
        CompNode cn = device.cast<CompNode>();
        const mgb::DType target = dtype.cast<mgb::DType>();
        for (size_t i = 0; i < nargs; ++i) {
            py::object item = lis[i];
            // Only plain values need to become graph constants. The original
            // test was inverted: it replaced existing symbol vars and left
            // plain values alone. None is skipped here because it is also
            // passed through unchanged by the conversion loop below.
            if (item.is_none() || py::isinstance<PySymbolVar>(item)) {
                continue;
            }
            OperatorNodeConfig config(cn);
            auto hv = npy::np2tensor(item.ptr(), npy::Meth::borrow(cn), target);
            lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
        }
    }
    for (size_t i = 0; i < lis.size(); ++i) {
        py::object item = lis[i];
        if (item.is_none()) {
            continue;
        }
        lis[i] = _convert_single_value_cpp(item, dtype, device);
    }
    return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
try { | try { | ||||
return _make_shape_tuple(py::handle(args[0])).release().ptr(); | return _make_shape_tuple(py::handle(args[0])).release().ptr(); | ||||
@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
} | } | ||||
// Python entry point for _astype_cpp; expects (tensor, dtype).
// Raises TypeError when fewer than two arguments are supplied, rather than
// reading past the end of `args`. C++ exceptions are handled by
// PYEXT17_TRANSLATE_EXC_RET.
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    if (nargs < 2) {
        PyErr_SetString(
                PyExc_TypeError, "astype_cpp expects 2 arguments: (tensor, dtype)");
        return nullptr;
    }
    try {
        return _astype_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// Python entry point for _convert_single_value_cpp; expects
// (value, dtype, device).
// Raises TypeError when fewer than three arguments are supplied, rather than
// reading past the end of `args`. C++ exceptions are handled by
// PYEXT17_TRANSLATE_EXC_RET.
PyObject* convert_single_value_cpp(
        PyObject* self, PyObject* const* args, size_t nargs) {
    if (nargs < 3) {
        PyErr_SetString(
                PyExc_TypeError,
                "convert_single_value_cpp expects 3 arguments: "
                "(value, dtype, device)");
        return nullptr;
    }
    try {
        return _convert_single_value_cpp(
                       py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
                .release()
                .ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
// Python entry point for _convert_inputs_cpp.
//
// The last positional argument is the device (None means "derive it from the
// inputs via get_device"); all preceding arguments are the values to convert.
// The common dtype is obtained from dtype_promotion over those values.
// Returns nullptr with a Python error set when no device argument is given or
// when dtype/device resolution fails.
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
    if (nargs == 0) {
        // Without this guard `nargs - 1` would wrap around.
        PyErr_SetString(
                PyExc_TypeError,
                "convert_inputs_cpp expects the device as its last argument");
        return nullptr;
    }
    try {
        const size_t ninputs = nargs - 1;
        PyObject* device_arg = args[ninputs];
        py::object dtype = py::reinterpret_steal<py::object>(
                dtype_promotion(self, args, ninputs));
        if (!dtype) {
            // dtype_promotion has already set the Python error.
            return nullptr;
        }
        py::object device;
        if (device_arg == Py_None) {
            device = py::reinterpret_steal<py::object>(
                    get_device(self, args, ninputs));
            if (!device) {
                // get_device has already set the Python error.
                return nullptr;
            }
        } else {
            device = py::reinterpret_borrow<py::object>(device_arg);
        }
        return _convert_inputs_cpp(args, ninputs, dtype, device).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |
@@ -2,6 +2,10 @@ | |||||
namespace mgb::imperative::python { | namespace mgb::imperative::python { | ||||
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs); | ||||
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); | ||||
PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
} // namespace mgb::imperative::python | } // namespace mgb::imperative::python |