Browse Source

fix(imperative/opr): close SCALAR_IDX warning of IndexingMultiAxisVec for proxy_graph

GitOrigin-RevId: 6e0d888f85
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
798ae5e58c
2 changed files with 44 additions and 1 deletions
  1. +42
    -0
      imperative/python/test/unit/test_tracing.py
  2. +2
    -1
      src/opr/impl/indexing.cpp

+ 42
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -6,6 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import inspect
import io
import itertools
from tempfile import mkstemp
@@ -492,3 +493,44 @@ def test_random(shape_mode):

run_test(uniform)
run_test(normal)


@pytest.mark.parametrize("shape_mode", [False, True])
def test_trace_advance_indexing(shape_mode):
funcs = [
lambda x, i: x[i],
# lambda x, i, j: x[i, j], # FIXME
lambda x, i, j: x[i, :, j, ...],
# lambda x, start, end: x[start:end], # FIXME
lambda x, start, end: x[:, 0, start:end, ..., 1],
lambda x, vec: x[vec],
lambda x, vec: x[vec, ..., 0, 1:3],
lambda x, vec: x[vec, vec[0], vec[1]],
# lambda x, i, start, end, vec: x[i, ..., :, vec, start:end], # FIXME
lambda x, mask: x[mask],
]

inputs = {
"x": np.random.randn(5, 5, 5, 5, 5).astype("float32"),
"i": 0,
"j": 2,
"start": 1,
"end": 3,
"vec": [1, 2, 3],
"mask": np.random.randn(5, 5, 5, 5, 5) >= 0,
}
for f in funcs:
sig = inspect.signature(f)
param_names = list(sig._parameters.keys())
params = {}
params_np = {}
f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode)
for name in param_names:
params[name] = tensor(inputs[name])
params_np[name] = inputs[name]
expected = f(**params_np)
result_imperative = f(**params)
np.testing.assert_equal(expected, result_imperative.numpy())
for _ in range(3):
result_trace = f_traced(**params)
np.testing.assert_equal(expected, result_trace.numpy())

+ 2
- 1
src/opr/impl/indexing.cpp View File

@@ -248,7 +248,8 @@ intl::MultiAxisVecFancyIndexingHelper::make_megdnn_index_desc(
}
}

if (!m_scalar_idx_warn_printed && warn_all_scalar) {
if (!m_scalar_idx_warn_printed && warn_all_scalar &&
!this->owner_graph()->options().imperative_proxy_graph) {
bool all_scalar = true;
for (auto &&i: index) {
if (!i.vec.layout.is_scalar()) {


Loading…
Cancel
Save