Browse Source

perf(mge/utils): move convert_input into C++

GitOrigin-RevId: 0d1cd36251
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
fa62f6c06e
8 changed files with 374 additions and 271 deletions
  1. +3
    -9
      imperative/python/megengine/core/tensor/array_method.py
  2. +7
    -43
      imperative/python/megengine/core/tensor/utils.py
  3. +0
    -1
      imperative/python/megengine/functional/nn.py
  4. +6
    -0
      imperative/python/src/numpy_dtypes.cpp
  5. +5
    -0
      imperative/python/src/numpy_dtypes.h
  6. +6
    -217
      imperative/python/src/tensor.cpp
  7. +337
    -1
      imperative/python/src/tensor_utils.cpp
  8. +10
    -0
      imperative/python/src/tensor_utils.h

+ 3
- 9
imperative/python/megengine/core/tensor/array_method.py View File

@@ -19,6 +19,7 @@ from .._imperative_rt.core2 import (
SymbolVar, SymbolVar,
Tensor, Tensor,
apply, apply,
astype_cpp,
broadcast_cpp, broadcast_cpp,
dtype_promotion, dtype_promotion,
) )
@@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp
from ..ops import builtin from ..ops import builtin
from . import amp from . import amp
from .indexing import getitem, setitem from .indexing import getitem, setitem
from .utils import (
_normalize_axis,
astensor1d,
astype,
cast_tensors,
make_shape_tuple,
subgraph,
)
from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph


_ElwMod = builtin.Elemwise.Mode _ElwMod = builtin.Elemwise.Mode


@@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC):
r"""Returns a :class:`Tensor` with the same data and number of elements r"""Returns a :class:`Tensor` with the same data and number of elements
with the specified :attr:`~.Tensor.dtype`. with the specified :attr:`~.Tensor.dtype`.
""" """
return astype(self, dtype)
return astype_cpp(self, dtype)


def reshape(self, *args): def reshape(self, *args):
r"""See :func:`~.reshape`.""" r"""See :func:`~.reshape`."""


+ 7
- 43
imperative/python/megengine/core/tensor/utils.py View File

@@ -20,6 +20,9 @@ from .._imperative_rt.core2 import (
_get_convert_inputs, _get_convert_inputs,
_set_convert_inputs, _set_convert_inputs,
apply, apply,
astype_cpp,
convert_inputs_cpp,
convert_single_value_cpp,
dtype_promotion, dtype_promotion,
get_device, get_device,
make_shape_tuple, make_shape_tuple,
@@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None):
return result return result




def astype(x, dtype):
dtype = np.dtype(dtype)
if not is_dtype_equal(x.dtype, dtype):
(x,) = apply(builtin.TypeCvt(dtype=dtype), x)
return x


def convert_single_value(v, *, dtype=None, device=None): def convert_single_value(v, *, dtype=None, device=None):
if isinstance(v, (Tensor, SymbolVar)):
if not is_quantize(v.dtype):
v = astype(v, dtype)
else:
v = Const(v, dtype, device, None)
return v
return convert_single_value_cpp(v, dtype, device)




def convert_inputs(*args, device=None): def convert_inputs(*args, device=None):
if not _get_convert_inputs(): if not _get_convert_inputs():
return args return args

dtype = dtype_promotion(args)
if device is None:
device = get_device(args)
device = as_device(device)

graph = None
sym_type = None
for a in args:
if isinstance(a, SymbolVar):
if graph is None:
graph = a.var.graph
sym_type = type(a)
else:
assert graph == a.var.graph
args = list(args)
if graph is not None:
for i in range(len(args)):
if not isinstance(args[i], SymbolVar):
rst = make_const(graph, np.array(args[i]), device.to_c(), dtype)
args[i] = sym_type(rst)

def convert(value):
if value is None:
return value
return convert_single_value(value, dtype=dtype, device=device.to_c())

return tuple(map(convert, args))
return convert_inputs_cpp(*args, device)




def cast_tensors(*args, promote=False): def cast_tensors(*args, promote=False):
@@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
pass pass
except ValueError: except ValueError:
if dtype is not None and dtype != x.dtype: if dtype is not None and dtype != x.dtype:
x = astype(x, dtype)
x = astype_cpp(x, dtype)
if device is not None: if device is not None:
cn = as_device(device).to_c() cn = as_device(device).to_c()
(x,) = apply(builtin.Copy(comp_node=cn), x) (x,) = apply(builtin.Copy(comp_node=cn), x)
@@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None):
if any(isinstance(i, (Tensor, SymbolVar)) for i in x): if any(isinstance(i, (Tensor, SymbolVar)) for i in x):
x = concatenate(x, device=device) if len(x) > 1 else x[0] x = concatenate(x, device=device) if len(x) > 1 else x[0]
if dtype is not None: if dtype is not None:
x = astype(x, dtype)
x = astype_cpp(x, dtype)
return x return x
x = Const(x, dtype, device, reference) x = Const(x, dtype, device, reference)
return x return x


+ 0
- 1
imperative/python/megengine/functional/nn.py View File

@@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import ( from ..core.tensor.utils import (
astensor1d, astensor1d,
astype,
cast_tensors, cast_tensors,
convert_single_value, convert_single_value,
make_shape_tuple, make_shape_tuple,


+ 6
- 0
imperative/python/src/numpy_dtypes.cpp View File

@@ -170,6 +170,12 @@ struct _wrap {


} // anonymous namespace } // anonymous namespace


namespace imperative::python {
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) {
return _is_dtype_equal(dt1, dt2);
}
} // namespace imperative::python

#ifdef METH_FASTCALL #ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUN) \ #define MGE_PY_INTERFACE(NAME, FUN) \
{ #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr } { #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr }


+ 5
- 0
imperative/python/src/numpy_dtypes.h View File

@@ -26,6 +26,11 @@
cb(BFloat16, npy_num_bfloat16()) cb(BFloat16, npy_num_bfloat16())


namespace mgb { namespace mgb {

namespace imperative::python {
bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2);
} // namespace imperative::python

//! numpy type num for intb1/2/4 type //! numpy type num for intb1/2/4 type
#define DEFINE_NPY_INTBX(n) int npy_num_intb##n(); #define DEFINE_NPY_INTBX(n) int npy_num_intb##n();
FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX)


+ 6
- 217
imperative/python/src/tensor.cpp View File

@@ -400,223 +400,6 @@ struct TensorWeakRef {
int _use_cnt() { return wptr.use_count(); } int _use_cnt() { return wptr.use_count(); }
}; };


/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
switch (c) {
case 'f':
return 3; // floating-point
case 'i':
return 2; // signed integer
case 'u':
return 2; // unsigned integer
case 'b':
return 1; // boolean
default:
return 0;
}
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
if (types.size() == 0) {
return 0;
} else {
uint8_t max_p = 0;
for (auto&& desc : types) {
max_p = std::max(max_p, category_priority(desc->kind));
}
return max_p;
}
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
// Return value: New reference
SmallVector<PyArray_Descr*> used_types;
for (auto&& desc : types) {
auto&& v = category_priority(desc->kind);
if (v == cat) {
used_types.emplace_back(desc);
}
}
mgb_assert(used_types.size() > 0, "size of used_types is 0");
PyArray_Descr* res = used_types[0];
Py_INCREF(res);

for (size_t i = 1; i < used_types.size(); ++i) {
PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
Py_DECREF(res);
res = tmp;
}
return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
// Return value: New reference
if (PyBool_Check(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_BOOL);
return descr;
}
if (PyLong_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_INT32);
return descr;
}
if (PyFloat_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
return descr;
}
return nullptr;
}

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
// Return value: New reference
SmallVector<PyArray_Descr*> tensors;
SmallVector<PyArray_Descr*> scalars;

bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}

for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
if (handle == Py_None)
continue;
TensorWrapper* tw = TensorWrapper::try_cast(handle);
if (tw) {
mgb::DType type = tw->m_tensor->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
} else {
if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
auto&& descr = PyArray_DescrFromObject(handle, nullptr);
tensors.emplace_back(descr);
continue;
}

if (py::isinstance<PySymbolVar>(py::handle(handle))) {
auto var = py::handle(handle).cast<PySymbolVar*>();
mgb::DType type = var->m_node->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
continue;
}

PyArray_Descr* descr = scalar2dtype(handle);
if (descr) {
scalars.emplace_back(descr);
continue;
}
}
}

auto max_pri_scalars = max_priority(scalars);
auto max_pri_tensors = max_priority(tensors);

if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
throw py::value_error("invalid input, no dtype avaliable");
}
PyArray_Descr* res;
if (max_pri_scalars > max_pri_tensors) {
res = promote_types(scalars, max_pri_scalars);
} else {
res = promote_types(tensors, max_pri_tensors);
}
for (auto* p : tensors) {
Py_DECREF(p);
}
for (auto* p : scalars) {
Py_DECREF(p);
}
Py_XDECREF(tuple);
return res;
}

CompNode _get_device(PyObject* const* args, size_t nargs) {
bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
bool valid = false;
CompNode cn;
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);

bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
if (tw || is_symvar) {
if (!valid) {
cn = tw ? tw->m_tensor->comp_node()
: py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
valid = true;
} else {
CompNode cn1 = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf(
"ambiguous device: %s (from %s) vs %s (from %s)",
cn.to_string().c_str(), cn.to_string_logical().c_str(),
cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
}
}
}
}
if (!valid) {
return CompNode::load(get_default_device());
}
Py_XDECREF(tuple);
return cn;
}

// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
PyArray_Descr* res = _dtype_promotion(args, nargs);
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
CompNode cn = _get_device(args, nargs);
return py::cast(cn).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

#ifdef METH_FASTCALL #ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \ #define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr } { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
@@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp); WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp); WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(Const); WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
#undef WRAP_FUNC_PY35 #undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \ #define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
@@ -779,6 +565,9 @@ void init_tensor(py::module m) {
MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
MGE_PY_INTERFACE(Const, Const), MGE_PY_INTERFACE(Const, Const),
MGE_PY_INTERFACE(astype_cpp, astype_cpp),
MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
{nullptr, nullptr, 0, nullptr}}; {nullptr, nullptr, 0, nullptr}};
for (auto&& def : method_defs) { for (auto&& def : method_defs) {
if (def.ml_meth != nullptr) { if (def.ml_meth != nullptr) {


+ 337
- 1
imperative/python/src/tensor_utils.cpp View File

@@ -52,6 +52,223 @@ namespace views = ranges::views;


namespace mgb::imperative::python { namespace mgb::imperative::python {


/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
switch (c) {
case 'f':
return 3; // floating-point
case 'i':
return 2; // signed integer
case 'u':
return 2; // unsigned integer
case 'b':
return 1; // boolean
default:
return 0;
}
}

// Returns the maximum value of the priority of each type in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
if (types.size() == 0) {
return 0;
} else {
uint8_t max_p = 0;
for (auto&& desc : types) {
max_p = std::max(max_p, category_priority(desc->kind));
}
return max_p;
}
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
// Return value: New reference
SmallVector<PyArray_Descr*> used_types;
for (auto&& desc : types) {
auto&& v = category_priority(desc->kind);
if (v == cat) {
used_types.emplace_back(desc);
}
}
mgb_assert(used_types.size() > 0, "size of used_types is 0");
PyArray_Descr* res = used_types[0];
Py_INCREF(res);

for (size_t i = 1; i < used_types.size(); ++i) {
PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
Py_DECREF(res);
res = tmp;
}
return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
// Return value: New reference
if (PyBool_Check(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_BOOL);
return descr;
}
if (PyLong_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_INT32);
return descr;
}
if (PyFloat_CheckExact(arg)) {
auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
return descr;
}
return nullptr;
}

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
// Return value: New reference
SmallVector<PyArray_Descr*> tensors;
SmallVector<PyArray_Descr*> scalars;

bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}

for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
if (handle == Py_None)
continue;
TensorWrapper* tw = TensorWrapper::try_cast(handle);
if (tw) {
mgb::DType type = tw->m_tensor->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
} else {
if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
auto&& descr = PyArray_DescrFromObject(handle, nullptr);
tensors.emplace_back(descr);
continue;
}

if (py::isinstance<PySymbolVar>(py::handle(handle))) {
auto var = py::handle(handle).cast<PySymbolVar*>();
mgb::DType type = var->m_node->dtype();
auto&& descr = npy::dtype_mgb2np_descr(type);
Py_INCREF(descr.get());
tensors.emplace_back(descr.get());
continue;
}

PyArray_Descr* descr = scalar2dtype(handle);
if (descr) {
scalars.emplace_back(descr);
continue;
}
}
}

auto max_pri_scalars = max_priority(scalars);
auto max_pri_tensors = max_priority(tensors);

if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
throw py::value_error("invalid input, no dtype avaliable");
}
PyArray_Descr* res;
if (max_pri_scalars > max_pri_tensors) {
res = promote_types(scalars, max_pri_scalars);
} else {
res = promote_types(tensors, max_pri_tensors);
}
for (auto* p : tensors) {
Py_DECREF(p);
}
for (auto* p : scalars) {
Py_DECREF(p);
}
Py_XDECREF(tuple);
return res;
}

CompNode _get_device(PyObject* const* args, size_t nargs) {
bool is_tuple = false;
PyObject* tuple = nullptr;
if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
if (PyList_Check(args[0])) {
tuple = PyList_AsTuple(args[0]);
} else {
tuple = args[0];
Py_INCREF(tuple);
}
nargs = PyTuple_Size(tuple);
is_tuple = true;
}
bool valid = false;
CompNode cn;
for (size_t i = 0; i < nargs; ++i) {
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
TensorWrapper* tw = TensorWrapper::try_cast(handle);

bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
if (tw || is_symvar) {
if (!valid) {
cn = tw ? tw->m_tensor->comp_node()
: py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
valid = true;
} else {
CompNode cn1 = tw ? tw->m_tensor->comp_node()
: py::handle(handle)
.cast<PySymbolVar*>()
->m_node->comp_node();
if (cn1 != cn) {
throw py::value_error(ssprintf(
"ambiguous device: %s (from %s) vs %s (from %s)",
cn.to_string().c_str(), cn.to_string_logical().c_str(),
cn1.to_string().c_str(), cn1.to_string_logical().c_str()));
}
}
}
}
if (!valid) {
return CompNode::load(get_default_device());
}
Py_XDECREF(tuple);
return cn;
}

// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
PyArray_Descr* res = _dtype_promotion(args, nargs);
return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
if (!nargs) {
PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
return nullptr;
}
try {
CompNode cn = _get_device(args, nargs);
return py::cast(cn).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

bool is_scalar(PyObject* tensor) { bool is_scalar(PyObject* tensor) {
if (py::isinstance<PySymbolVar>(py::handle(tensor))) { if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
auto var = py::handle(tensor).cast<PySymbolVar*>(); auto var = py::handle(tensor).cast<PySymbolVar*>();
@@ -147,7 +364,6 @@ py::object _Const(
"dmap_callback"); "dmap_callback");
if (dmap.ptr() != Py_None) { if (dmap.ptr() != Py_None) {
device_obj = dmap(device); device_obj = dmap(device);
py::print(device_obj);
} else { } else {
device_obj = py::cast(CompNode::load(device.cast<std::string>())); device_obj = py::cast(CompNode::load(device.cast<std::string>()));
} }
@@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) {
return ret[0]; return ret[0];
} }


mgb::DType _get_dtype(py::handle tensor) {
if (auto tw = TensorWrapper::try_cast(tensor.ptr())) {
return tw->m_tensor->dtype();
} else {
auto var = tensor.cast<PySymbolVar*>();
return var->m_node->dtype();
}
}

py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
PyArray_Descr* descr;
if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
throw py::value_error(ssprintf(
"can not convert to numpy.dtype from %s",
dtype_hdl.ptr()->ob_type->tp_name));
}
PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get();
if (!dtype_equal(cur, descr)) {
std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
py::object Op = py::cast(op);
std::vector<PyObject*> p;
p.resize(2);
p[0] = Op.ptr();
p[1] = tensor.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
} else {
return py::reinterpret_borrow<py::object>(tensor);
}
}

py::object _convert_single_value_cpp(
py::handle value, py::handle dtype, py::handle device) {
if (is_tensor_or_symbolvar(value)) {
if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) {
return _astype_cpp(value, dtype);
}
} else {
return _Const(value, dtype, device, py::none());
}
return py::reinterpret_borrow<py::object>(value);
}

py::object _convert_inputs_cpp(
PyObject* const* args, size_t nargs, py::object dtype, py::object device) {
ComputingGraph* graph = nullptr;
py::handle typeobj;
py::list lis;
for (size_t i = 0; i < nargs; ++i) {
py::handle h = py::handle(args[i]);
lis.append(h);
if (py::isinstance<PySymbolVar>(h)) {
auto var = h.cast<PySymbolVar*>();
auto g = var->m_node->owner_graph();
if (!graph) {
graph = g;
typeobj = h.get_type();
} else {
mgb_assert(graph == g);
}
}
}
if (graph) {
CompNode cn = device.cast<CompNode>();
for (size_t i = 0; i < nargs; ++i) {
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
if (py::isinstance<PySymbolVar>(lis[i])) {
lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
}
}
auto convert = [&](py::object value) {
if (value.ptr() == Py_None) {
return value;
}
return _convert_single_value_cpp(value, dtype, device);
};
for (size_t i = 0; i < lis.size(); ++i) {
lis[i] = convert(lis[i]);
}
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(lis.ptr()));
}

PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try { try {
return _make_shape_tuple(py::handle(args[0])).release().ptr(); return _make_shape_tuple(py::handle(args[0])).release().ptr();
@@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr) PYEXT17_TRANSLATE_EXC_RET(nullptr)
} }


PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _astype_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* convert_single_value_cpp(
PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _convert_single_value_cpp(
py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
py::object dtype = py::reinterpret_steal<py::object>(
dtype_promotion(self, args, nargs - 1));
py::object device;
if (args[nargs - 1] == Py_None) {
device = py::reinterpret_steal<py::object>(
get_device(self, args, nargs - 1));
} else {
device = py::reinterpret_borrow<py::object>(args[nargs - 1]);
}
return _convert_inputs_cpp(args, nargs - 1, dtype, device).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

} // namespace mgb::imperative::python } // namespace mgb::imperative::python

+ 10
- 0
imperative/python/src/tensor_utils.h View File

@@ -2,6 +2,10 @@


namespace mgb::imperative::python { namespace mgb::imperative::python {


PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs); PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs);


PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
@@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs);


PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs);


PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs);

PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs);

} // namespace mgb::imperative::python } // namespace mgb::imperative::python

Loading…
Cancel
Save