
fix(trace): assume result is not scalar when shape is valid

GitOrigin-RevId: beee2d0f28
Branch: release-1.8
Megvii Engine Team “wenjuan” · 3 years ago
Parent commit: 54eef55871
3 changed files with 26 additions and 19 deletions:
  1. imperative/python/megengine/core/tensor/indexing.py (+9, -5)
  2. imperative/python/src/tensor.cpp (+3, -7)
  3. imperative/src/impl/transformations/scalar.cpp (+14, -7)

imperative/python/megengine/core/tensor/indexing.py (+9, -5)

@@ -119,12 +119,16 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True):
                 else 1
             )
         else:
-            if ndim_indexed > inp.ndim:
-                raise IndexError(
-                    "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
-                        inp.ndim, len(tuple_val)
+            try:
+                if ndim_indexed > inp.ndim:
+                    raise IndexError(
+                        "too many indices for tensor: tensor is {}-dimensional, but {} were indexed".format(
+                            inp.ndim, len(tuple_val)
+                        )
                     )
-                )
+            except ValueError:
+                # ignore
+                pass
 
     tuple_val = remove_ellipsis(inp, tuple_val)
     use_subtensor = True
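Why the try/except: under symbolic tracing a tensor's shape can be temporarily unknown, so reading inp.ndim inside unpack_getitem may itself raise ValueError; the dimension check is now skipped in that case instead of aborting the trace. A minimal repro sketch, assuming MegEngine's public jit.trace API (the concrete values are illustrative):

import numpy as np
import megengine
from megengine.jit import trace

@trace(symbolic=True)
def take_first(x):
    # __getitem__ goes through unpack_getitem; if x's shape is not yet
    # known to the trace, the "too many indices" check is skipped rather
    # than failing with ValueError
    return x[0, 0]

data = megengine.tensor(np.arange(6, dtype="float32").reshape(2, 3))
print(take_first(data))  # expected: scalar 0.0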


imperative/python/src/tensor.cpp (+3, -7)

@@ -272,16 +272,12 @@ PyObject* TensorWrapper::device() {
 
 PyObject* TensorWrapper::numpy() {
     auto hv = m_tensor->numpy();
-    // if (!hv) {
-    //     PyErr_SetString(PyExc_ValueError, "tensor invalid");
-    //     return nullptr;
-    // }
-    auto arr = py::reinterpret_steal<py::array>(
-            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
-    if (!arr) {
+    if (!hv) {
         PyErr_SetString(PyExc_ValueError, "tensor invalid");
         return nullptr;
     }
+    auto arr = py::reinterpret_steal<py::array>(
+            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
     if (hv->shape().is_scalar()) {
         mgb_assert(PyArray_Check(arr.ptr()));
         return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
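The reordering makes numpy() validate the host value before converting it: a null hv now raises ValueError up front, and the ndarray is only built afterwards, with scalar-shaped results still squeezed to 0-dim. A small usage sketch of the assumed user-visible contract:

import megengine

x = megengine.tensor([[1.0, 2.0]])
print(x.numpy().shape)   # expected (1, 2)
s = x[0, 0]              # scalar result of integer indexing
print(s.numpy().ndim)    # expected 0: scalar shapes are squeezed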


imperative/src/impl/transformations/scalar.cpp (+14, -7)

@@ -51,6 +51,7 @@ bool is_scalar_shape(ValueRef shape) {
     if (shape.is<ScalarValue>()) {
         return false;
     }
+    // may have performance issue
     auto shape_of_shape = shape.shape();
     if (!shape_of_shape) {
         // assume not scalar
@@ -211,14 +212,21 @@ std::vector<ValueRef> subtensor_rule(
         const Subtensor& subtensor, Span<ValueRef> inputs) {
     mgb_assert(inputs.size() >= 1);
     auto input = inputs[0];
-    size_t ndim = input.is<ScalarValue>() ? 0 : input.shape()->ndim;
-    for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
-        if (idx) {
-            ndim--;
+    bool is_scalar;
+    mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
+    if (auto shape = input.shape()) {
+        size_t ndim = input.shape()->ndim;
+        for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
+            if (idx) {
+                ndim--;
+            }
         }
+        is_scalar = ndim == 0;
+    } else {
+        is_scalar = false;
     }
     auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
-    if (!ndim) {
+    if (is_scalar) {
         return {ScalarValue::make(output)};
     } else {
         return {output};
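The rule above counts how many dimensions the integer indices consume, but only when the input shape is valid; otherwise it now assumes the result is not scalar. A plain-Python analogue (hypothetical names, for illustration only):

def result_is_scalar(input_ndim, items):
    """input_ndim: ndim of the input, or None when its shape is invalid.
    items: per-axis (axis, begin, end, step, idx) entries; idx marks an
    integer index, which removes one dimension from the result."""
    if input_ndim is None:
        return False          # shape invalid: assume not scalar (the fix)
    ndim = input_ndim
    for axis, begin, end, step, idx in items:
        if idx:
            ndim -= 1
    return ndim == 0

# x[0, 0] on a 2-D input: both dims consumed by integer indices -> scalar
print(result_is_scalar(2, [(0, None, None, None, True),
                           (1, None, None, None, True)]))  # True
print(result_is_scalar(None, []))                          # False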
@@ -261,8 +269,7 @@ std::vector<ValueRef> fastpath_copy_rule(
 
 std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
     mgb_assert(inputs.size() == 2);
-    bool is_scalar =
-            (!inputs[1].is<ScalarValue>()) && *inputs[1].shape() == ValueShape{0};
+    bool is_scalar = is_scalar_shape(inputs[1]);
     auto unwrapped_input = inputs[0].is<ScalarValue>()
             ? inputs[0].cast<ScalarValue>().value()
             : inputs[0];
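reshape_rule previously open-coded the check (a target-shape tensor comparing equal to ValueShape{0}, i.e. a zero-length shape tensor, means a scalar result) and now reuses is_scalar_shape, which also covers target shapes that cannot be inspected. A plain-Python analogue under the same hypothetical naming:

def is_scalar_target(shape_of_shape):
    """shape_of_shape: shape of the reshape target-shape tensor,
    or None when it cannot be determined."""
    if shape_of_shape is None:
        return False               # unknown: assume the result is not scalar
    return shape_of_shape == (0,)  # zero-length shape tensor -> scalar

print(is_scalar_target((0,)))   # True: reshape to () yields a scalar
print(is_scalar_target((2,)))   # False: reshape to a 2-element shape
print(is_scalar_target(None))   # False: shape not inspectable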

