From 94960ecf426ceb64ce47f72bf0a7b55e159a8ba4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 30 Mar 2022 15:39:46 +0800 Subject: [PATCH] fix(imperative): restrict using convert_inputs in py_apply GitOrigin-RevId: b021aac8a6f35dfe3b87dbc98a8007fd1a5b54b2 --- imperative/python/src/tensor.cpp | 9 +++++++-- imperative/python/test/unit/functional/test_elemwise.py | 5 +++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index dac67b00..73512a0e 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -169,7 +169,7 @@ PyObject* py_apply( } HostTensorND ht(target_cn); ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype); - if (PyArray_Check(args[i])) { // non scaler + if (PyArray_Check(args[i]) || PyList_Check(args[i])) { // non scaler return imperative::apply( CreateTensor(CreateTensor::Const, target_cn, ht.layout()), HostStorage::make(ht.storage()))[0]; @@ -205,8 +205,13 @@ PyObject* py_apply( for (size_t i = 0; i < nargs; ++i) { if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { tensors[i] = tw->m_tensor->data(); - } else { + } else if ( + DTypePromoteCfg::convert_input_enabled && + op->same_type()) { tensors[i] = convert_pyinput_to_tensor(i); + } else { + PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs"); + return nullptr; } } diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index d2fbaf08..3f18de80 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -77,6 +77,11 @@ def test_div(): np.floor_divide(np.array([-5, -7], dtype=np.int32), 2), ) + np.testing.assert_allclose( + (tensor([[5, 4, 3], [4, 2, 6]]) // [1, 2, 1]).numpy(), + np.floor_divide(np.array([[5, 4, 3], [4, 2, 6]], dtype=np.int32), [1, 2, 1]), + ) + def test_clamp(): """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and