diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 00148234..8133221f 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -19,6 +19,7 @@ from .._imperative_rt.core2 import ( SymbolVar, Tensor, apply, + astype_cpp, broadcast_cpp, dtype_promotion, ) @@ -27,14 +28,7 @@ from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp from ..ops import builtin from . import amp from .indexing import getitem, setitem -from .utils import ( - _normalize_axis, - astensor1d, - astype, - cast_tensors, - make_shape_tuple, - subgraph, -) +from .utils import _normalize_axis, astensor1d, cast_tensors, make_shape_tuple, subgraph _ElwMod = builtin.Elemwise.Mode @@ -605,7 +599,7 @@ class ArrayMethodMixin(abc.ABC): r"""Returns a :class:`Tensor` with the same data and number of elements with the specified :attr:`~.Tensor.dtype`. """ - return astype(self, dtype) + return astype_cpp(self, dtype) def reshape(self, *args): r"""See :func:`~.reshape`.""" diff --git a/imperative/python/megengine/core/tensor/utils.py b/imperative/python/megengine/core/tensor/utils.py index 9b14d227..2396ff89 100644 --- a/imperative/python/megengine/core/tensor/utils.py +++ b/imperative/python/megengine/core/tensor/utils.py @@ -20,6 +20,9 @@ from .._imperative_rt.core2 import ( _get_convert_inputs, _set_convert_inputs, apply, + astype_cpp, + convert_inputs_cpp, + convert_single_value_cpp, dtype_promotion, get_device, make_shape_tuple, @@ -55,53 +58,14 @@ def concatenate(inputs, axis=0, *, device=None): return result -def astype(x, dtype): - dtype = np.dtype(dtype) - if not is_dtype_equal(x.dtype, dtype): - (x,) = apply(builtin.TypeCvt(dtype=dtype), x) - return x - - def convert_single_value(v, *, dtype=None, device=None): - if isinstance(v, (Tensor, SymbolVar)): - if not is_quantize(v.dtype): - v = astype(v, dtype) - else: - v = Const(v, dtype, device, None) - return v + return convert_single_value_cpp(v, dtype, device) def convert_inputs(*args, device=None): if not _get_convert_inputs(): return args - - dtype = dtype_promotion(args) - if device is None: - device = get_device(args) - device = as_device(device) - - graph = None - sym_type = None - for a in args: - if isinstance(a, SymbolVar): - if graph is None: - graph = a.var.graph - sym_type = type(a) - else: - assert graph == a.var.graph - args = list(args) - if graph is not None: - for i in range(len(args)): - if not isinstance(args[i], SymbolVar): - rst = make_const(graph, np.array(args[i]), device.to_c(), dtype) - args[i] = sym_type(rst) - - def convert(value): - if value is None: - return value - return convert_single_value(value, dtype=dtype, device=device.to_c()) - - return tuple(map(convert, args)) + return convert_inputs_cpp(*args, device) def cast_tensors(*args, promote=False): @@ -146,7 +110,7 @@ def astensor1d(x, *reference, dtype=None, device=None): pass except ValueError: if dtype is not None and dtype != x.dtype: - x = astype(x, dtype) + x = astype_cpp(x, dtype) if device is not None: cn = as_device(device).to_c() (x,) = apply(builtin.Copy(comp_node=cn), x) @@ -164,7 +128,7 @@ def astensor1d(x, *reference, dtype=None, device=None): if any(isinstance(i, (Tensor, SymbolVar)) for i in x): x = concatenate(x, device=device) if len(x) > 1 else x[0] if dtype is not None: - x = astype(x, dtype) + x = astype_cpp(x, dtype) return x x = Const(x, dtype, device, reference) return x diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 3110aa08..a0fb5708 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -30,7 +30,6 @@ from ..core.tensor import amp, megbrain_graph from ..core.tensor.array_method import _elwise_apply from ..core.tensor.utils import ( astensor1d, - astype, cast_tensors, convert_single_value, make_shape_tuple, diff --git a/imperative/python/src/numpy_dtypes.cpp b/imperative/python/src/numpy_dtypes.cpp index 2a048b74..87a215cf 100644 --- a/imperative/python/src/numpy_dtypes.cpp +++ b/imperative/python/src/numpy_dtypes.cpp @@ -170,6 +170,12 @@ struct _wrap { } // anonymous namespace +namespace imperative::python { +bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2) { + return _is_dtype_equal(dt1, dt2); +} +} // namespace imperative::python + #ifdef METH_FASTCALL #define MGE_PY_INTERFACE(NAME, FUN) \ { #NAME, (PyCFunction)_wrap < &(FUN)> ::impl, METH_FASTCALL, nullptr } diff --git a/imperative/python/src/numpy_dtypes.h b/imperative/python/src/numpy_dtypes.h index 1b061060..bed646c5 100644 --- a/imperative/python/src/numpy_dtypes.h +++ b/imperative/python/src/numpy_dtypes.h @@ -26,6 +26,11 @@ cb(BFloat16, npy_num_bfloat16()) namespace mgb { + +namespace imperative::python { +bool dtype_equal(PyArray_Descr* dt1, PyArray_Descr* dt2); +} // namespace imperative::python + //! numpy type num for intb1/2/4 type #define DEFINE_NPY_INTBX(n) int npy_num_intb##n(); FOREACH_MGB_LOW_BIT(DEFINE_NPY_INTBX) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index e231dbbe..d466e790 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -400,223 +400,6 @@ struct TensorWeakRef { int _use_cnt() { return wptr.use_count(); } }; -/* ============== convert inputs ============== */ - -// map numpy.dtype.kind to priority -inline uint8_t category_priority(char c) { - switch (c) { - case 'f': - return 3; // floating-point - case 'i': - return 2; // signed integer - case 'u': - return 2; // unsigned integer - case 'b': - return 1; // boolean - default: - return 0; - } -} - -// Returns the maximum value of the priority of each type in the list `types`. -uint8_t max_priority(SmallVector types) { - if (types.size() == 0) { - return 0; - } else { - uint8_t max_p = 0; - for (auto&& desc : types) { - max_p = std::max(max_p, category_priority(desc->kind)); - } - return max_p; - } -} - -// Returns the data type with sufficient size to hold all types of -// category `cat` in the list `types`. -PyArray_Descr* promote_types(SmallVector types, uint8_t cat) { - // Return value: New reference - SmallVector used_types; - for (auto&& desc : types) { - auto&& v = category_priority(desc->kind); - if (v == cat) { - used_types.emplace_back(desc); - } - } - mgb_assert(used_types.size() > 0, "size of used_types is 0"); - PyArray_Descr* res = used_types[0]; - Py_INCREF(res); - - for (size_t i = 1; i < used_types.size(); ++i) { - PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res); - Py_DECREF(res); - res = tmp; - } - return res; -} - -PyArray_Descr* scalar2dtype(PyObject* arg) { - // Return value: New reference - if (PyBool_Check(arg)) { - auto&& descr = PyArray_DescrFromType(NPY_BOOL); - return descr; - } - if (PyLong_CheckExact(arg)) { - auto&& descr = PyArray_DescrFromType(NPY_INT32); - return descr; - } - if (PyFloat_CheckExact(arg)) { - auto&& descr = PyArray_DescrFromType(NPY_FLOAT32); - return descr; - } - return nullptr; -} - -PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) { - // Return value: New reference - SmallVector tensors; - SmallVector scalars; - - bool is_tuple = false; - PyObject* tuple = nullptr; - if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { - if (PyList_Check(args[0])) { - tuple = PyList_AsTuple(args[0]); - } else { - tuple = args[0]; - Py_INCREF(tuple); - } - nargs = PyTuple_Size(tuple); - is_tuple = true; - } - - for (size_t i = 0; i < nargs; ++i) { - PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; - if (handle == Py_None) - continue; - TensorWrapper* tw = TensorWrapper::try_cast(handle); - if (tw) { - mgb::DType type = tw->m_tensor->dtype(); - auto&& descr = npy::dtype_mgb2np_descr(type); - Py_INCREF(descr.get()); - tensors.emplace_back(descr.get()); - } else { - if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) { - auto&& descr = PyArray_DescrFromObject(handle, nullptr); - tensors.emplace_back(descr); - continue; - } - - if (py::isinstance(py::handle(handle))) { - auto var = py::handle(handle).cast(); - mgb::DType type = var->m_node->dtype(); - auto&& descr = npy::dtype_mgb2np_descr(type); - Py_INCREF(descr.get()); - tensors.emplace_back(descr.get()); - continue; - } - - PyArray_Descr* descr = scalar2dtype(handle); - if (descr) { - scalars.emplace_back(descr); - continue; - } - } - } - - auto max_pri_scalars = max_priority(scalars); - auto max_pri_tensors = max_priority(tensors); - - if (max_pri_scalars <= 0 && max_pri_tensors <= 0) { - throw py::value_error("invalid input, no dtype avaliable"); - } - PyArray_Descr* res; - if (max_pri_scalars > max_pri_tensors) { - res = promote_types(scalars, max_pri_scalars); - } else { - res = promote_types(tensors, max_pri_tensors); - } - for (auto* p : tensors) { - Py_DECREF(p); - } - for (auto* p : scalars) { - Py_DECREF(p); - } - Py_XDECREF(tuple); - return res; -} - -CompNode _get_device(PyObject* const* args, size_t nargs) { - bool is_tuple = false; - PyObject* tuple = nullptr; - if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { - if (PyList_Check(args[0])) { - tuple = PyList_AsTuple(args[0]); - } else { - tuple = args[0]; - Py_INCREF(tuple); - } - nargs = PyTuple_Size(tuple); - is_tuple = true; - } - bool valid = false; - CompNode cn; - for (size_t i = 0; i < nargs; ++i) { - PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; - TensorWrapper* tw = TensorWrapper::try_cast(handle); - - bool is_symvar = py::isinstance(py::handle(handle)); - if (tw || is_symvar) { - if (!valid) { - cn = tw ? tw->m_tensor->comp_node() - : py::handle(handle).cast()->m_node->comp_node(); - valid = true; - } else { - CompNode cn1 = tw ? tw->m_tensor->comp_node() - : py::handle(handle) - .cast() - ->m_node->comp_node(); - if (cn1 != cn) { - throw py::value_error(ssprintf( - "ambiguous device: %s (from %s) vs %s (from %s)", - cn.to_string().c_str(), cn.to_string_logical().c_str(), - cn1.to_string().c_str(), cn1.to_string_logical().c_str())); - } - } - } - } - if (!valid) { - return CompNode::load(get_default_device()); - } - Py_XDECREF(tuple); - return cn; -} - -// Returns the dtype that would result from performing an arithmetic -// operation on the provided input tensors and scalars. -PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) { - if (!nargs) { - PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); - return nullptr; - } - try { - PyArray_Descr* res = _dtype_promotion(args, nargs); - return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); - } - PYEXT17_TRANSLATE_EXC_RET(nullptr) -} - -PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) { - if (!nargs) { - PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); - return nullptr; - } - try { - CompNode cn = _get_device(args, nargs); - return py::cast(cn).release().ptr(); - } - PYEXT17_TRANSLATE_EXC_RET(nullptr) -} - #ifdef METH_FASTCALL #define MGE_PY_INTERFACE(NAME, FUNC) \ { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr } @@ -640,6 +423,9 @@ WRAP_FUNC_PY35(transpose_cpp); WRAP_FUNC_PY35(broadcast_cpp); WRAP_FUNC_PY35(reshape_cpp); WRAP_FUNC_PY35(Const); +WRAP_FUNC_PY35(astype_cpp); +WRAP_FUNC_PY35(convert_single_value_cpp); +WRAP_FUNC_PY35(convert_inputs_cpp); #undef WRAP_FUNC_PY35 #define MGE_PY_INTERFACE(NAME, FUNC) \ { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } @@ -779,6 +565,9 @@ void init_tensor(py::module m) { MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), MGE_PY_INTERFACE(Const, Const), + MGE_PY_INTERFACE(astype_cpp, astype_cpp), + MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), + MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), {nullptr, nullptr, 0, nullptr}}; for (auto&& def : method_defs) { if (def.ml_meth != nullptr) { diff --git a/imperative/python/src/tensor_utils.cpp b/imperative/python/src/tensor_utils.cpp index 17f2f391..78ba7481 100644 --- a/imperative/python/src/tensor_utils.cpp +++ b/imperative/python/src/tensor_utils.cpp @@ -52,6 +52,223 @@ namespace views = ranges::views; namespace mgb::imperative::python { +/* ============== convert inputs ============== */ + +// map numpy.dtype.kind to priority +inline uint8_t category_priority(char c) { + switch (c) { + case 'f': + return 3; // floating-point + case 'i': + return 2; // signed integer + case 'u': + return 2; // unsigned integer + case 'b': + return 1; // boolean + default: + return 0; + } +} + +// Returns the maximum value of the priority of each type in the list `types`. +uint8_t max_priority(SmallVector types) { + if (types.size() == 0) { + return 0; + } else { + uint8_t max_p = 0; + for (auto&& desc : types) { + max_p = std::max(max_p, category_priority(desc->kind)); + } + return max_p; + } +} + +// Returns the data type with sufficient size to hold all types of +// category `cat` in the list `types`. +PyArray_Descr* promote_types(SmallVector types, uint8_t cat) { + // Return value: New reference + SmallVector used_types; + for (auto&& desc : types) { + auto&& v = category_priority(desc->kind); + if (v == cat) { + used_types.emplace_back(desc); + } + } + mgb_assert(used_types.size() > 0, "size of used_types is 0"); + PyArray_Descr* res = used_types[0]; + Py_INCREF(res); + + for (size_t i = 1; i < used_types.size(); ++i) { + PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res); + Py_DECREF(res); + res = tmp; + } + return res; +} + +PyArray_Descr* scalar2dtype(PyObject* arg) { + // Return value: New reference + if (PyBool_Check(arg)) { + auto&& descr = PyArray_DescrFromType(NPY_BOOL); + return descr; + } + if (PyLong_CheckExact(arg)) { + auto&& descr = PyArray_DescrFromType(NPY_INT32); + return descr; + } + if (PyFloat_CheckExact(arg)) { + auto&& descr = PyArray_DescrFromType(NPY_FLOAT32); + return descr; + } + return nullptr; +} + +PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) { + // Return value: New reference + SmallVector tensors; + SmallVector scalars; + + bool is_tuple = false; + PyObject* tuple = nullptr; + if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { + if (PyList_Check(args[0])) { + tuple = PyList_AsTuple(args[0]); + } else { + tuple = args[0]; + Py_INCREF(tuple); + } + nargs = PyTuple_Size(tuple); + is_tuple = true; + } + + for (size_t i = 0; i < nargs; ++i) { + PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; + if (handle == Py_None) + continue; + TensorWrapper* tw = TensorWrapper::try_cast(handle); + if (tw) { + mgb::DType type = tw->m_tensor->dtype(); + auto&& descr = npy::dtype_mgb2np_descr(type); + Py_INCREF(descr.get()); + tensors.emplace_back(descr.get()); + } else { + if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) { + auto&& descr = PyArray_DescrFromObject(handle, nullptr); + tensors.emplace_back(descr); + continue; + } + + if (py::isinstance(py::handle(handle))) { + auto var = py::handle(handle).cast(); + mgb::DType type = var->m_node->dtype(); + auto&& descr = npy::dtype_mgb2np_descr(type); + Py_INCREF(descr.get()); + tensors.emplace_back(descr.get()); + continue; + } + + PyArray_Descr* descr = scalar2dtype(handle); + if (descr) { + scalars.emplace_back(descr); + continue; + } + } + } + + auto max_pri_scalars = max_priority(scalars); + auto max_pri_tensors = max_priority(tensors); + + if (max_pri_scalars <= 0 && max_pri_tensors <= 0) { + throw py::value_error("invalid input, no dtype avaliable"); + } + PyArray_Descr* res; + if (max_pri_scalars > max_pri_tensors) { + res = promote_types(scalars, max_pri_scalars); + } else { + res = promote_types(tensors, max_pri_tensors); + } + for (auto* p : tensors) { + Py_DECREF(p); + } + for (auto* p : scalars) { + Py_DECREF(p); + } + Py_XDECREF(tuple); + return res; +} + +CompNode _get_device(PyObject* const* args, size_t nargs) { + bool is_tuple = false; + PyObject* tuple = nullptr; + if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) { + if (PyList_Check(args[0])) { + tuple = PyList_AsTuple(args[0]); + } else { + tuple = args[0]; + Py_INCREF(tuple); + } + nargs = PyTuple_Size(tuple); + is_tuple = true; + } + bool valid = false; + CompNode cn; + for (size_t i = 0; i < nargs; ++i) { + PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i]; + TensorWrapper* tw = TensorWrapper::try_cast(handle); + + bool is_symvar = py::isinstance(py::handle(handle)); + if (tw || is_symvar) { + if (!valid) { + cn = tw ? tw->m_tensor->comp_node() + : py::handle(handle).cast()->m_node->comp_node(); + valid = true; + } else { + CompNode cn1 = tw ? tw->m_tensor->comp_node() + : py::handle(handle) + .cast() + ->m_node->comp_node(); + if (cn1 != cn) { + throw py::value_error(ssprintf( + "ambiguous device: %s (from %s) vs %s (from %s)", + cn.to_string().c_str(), cn.to_string_logical().c_str(), + cn1.to_string().c_str(), cn1.to_string_logical().c_str())); + } + } + } + } + if (!valid) { + return CompNode::load(get_default_device()); + } + Py_XDECREF(tuple); + return cn; +} + +// Returns the dtype that would result from performing an arithmetic +// operation on the provided input tensors and scalars. +PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) { + if (!nargs) { + PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); + return nullptr; + } + try { + PyArray_Descr* res = _dtype_promotion(args, nargs); + return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + +PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) { + if (!nargs) { + PyErr_SetString(PyExc_TypeError, "empty input is not allowed"); + return nullptr; + } + try { + CompNode cn = _get_device(args, nargs); + return py::cast(cn).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + bool is_scalar(PyObject* tensor) { if (py::isinstance(py::handle(tensor))) { auto var = py::handle(tensor).cast(); @@ -147,7 +364,6 @@ py::object _Const( "dmap_callback"); if (dmap.ptr() != Py_None) { device_obj = dmap(device); - py::print(device_obj); } else { device_obj = py::cast(CompNode::load(device.cast())); } @@ -1072,6 +1288,92 @@ py::object _reshape_cpp(py::handle inp_hdl, py::handle args) { return ret[0]; } +mgb::DType _get_dtype(py::handle tensor) { + if (auto tw = TensorWrapper::try_cast(tensor.ptr())) { + return tw->m_tensor->dtype(); + } else { + auto var = tensor.cast(); + return var->m_node->dtype(); + } +} + +py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) { + PyArray_Descr* descr; + if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) { + throw py::value_error(ssprintf( + "can not convert to numpy.dtype from %s", + dtype_hdl.ptr()->ob_type->tp_name)); + } + PyArray_Descr* cur = npy::dtype_mgb2np_descr(_get_dtype(tensor)).get(); + if (!dtype_equal(cur, descr)) { + std::shared_ptr op = TypeCvt::make(npy::dtype_np2mgb_descr(descr)); + py::object Op = py::cast(op); + std::vector p; + p.resize(2); + p[0] = Op.ptr(); + p[1] = tensor.ptr(); + py::tuple ret = + py::reinterpret_steal(py_apply(NULL, p.data(), p.size())); + return ret[0]; + } else { + return py::reinterpret_borrow(tensor); + } +} + +py::object _convert_single_value_cpp( + py::handle value, py::handle dtype, py::handle device) { + if (is_tensor_or_symbolvar(value)) { + if (_get_dtype(value).category() != DTypeCategory::QUANTIZED) { + return _astype_cpp(value, dtype); + } + } else { + return _Const(value, dtype, device, py::none()); + } + return py::reinterpret_borrow(value); +} + +py::object _convert_inputs_cpp( + PyObject* const* args, size_t nargs, py::object dtype, py::object device) { + ComputingGraph* graph = nullptr; + py::handle typeobj; + py::list lis; + for (size_t i = 0; i < nargs; ++i) { + py::handle h = py::handle(args[i]); + lis.append(h); + if (py::isinstance(h)) { + auto var = h.cast(); + auto g = var->m_node->owner_graph(); + if (!graph) { + graph = g; + typeobj = h.get_type(); + } else { + mgb_assert(graph == g); + } + } + } + if (graph) { + CompNode cn = device.cast(); + for (size_t i = 0; i < nargs; ++i) { + OperatorNodeConfig config(cn); + auto hv = npy::np2tensor( + lis[i].ptr(), npy::Meth::borrow(cn), dtype.cast()); + if (py::isinstance(lis[i])) { + lis[i] = typeobj(opr::ImmutableTensor::make(*graph, hv, config).node()); + } + } + } + auto convert = [&](py::object value) { + if (value.ptr() == Py_None) { + return value; + } + return _convert_single_value_cpp(value, dtype, device); + }; + for (size_t i = 0; i < lis.size(); ++i) { + lis[i] = convert(lis[i]); + } + return py::reinterpret_steal(PyList_AsTuple(lis.ptr())); +} + PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { try { return _make_shape_tuple(py::handle(args[0])).release().ptr(); @@ -1152,4 +1454,38 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs) { PYEXT17_TRANSLATE_EXC_RET(nullptr) } +PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _astype_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + +PyObject* convert_single_value_cpp( + PyObject* self, PyObject* const* args, size_t nargs) { + try { + return _convert_single_value_cpp( + py::handle(args[0]), py::handle(args[1]), py::handle(args[2])) + .release() + .ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + +PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs) { + try { + py::object dtype = py::reinterpret_steal( + dtype_promotion(self, args, nargs - 1)); + py::object device; + if (args[nargs - 1] == Py_None) { + device = py::reinterpret_steal( + get_device(self, args, nargs - 1)); + } else { + device = py::reinterpret_borrow(args[nargs - 1]); + } + return _convert_inputs_cpp(args, nargs - 1, dtype, device).release().ptr(); + } + PYEXT17_TRANSLATE_EXC_RET(nullptr) +} + } // namespace mgb::imperative::python diff --git a/imperative/python/src/tensor_utils.h b/imperative/python/src/tensor_utils.h index 90608004..43472ee4 100644 --- a/imperative/python/src/tensor_utils.h +++ b/imperative/python/src/tensor_utils.h @@ -2,6 +2,10 @@ namespace mgb::imperative::python { +PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs); + +PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs); + PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs); PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs); @@ -22,4 +26,10 @@ PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs); PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); +PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); + +PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); + +PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); + } // namespace mgb::imperative::python \ No newline at end of file