Browse Source

test(trace): test subtensor on unknown shape

GitOrigin-RevId: 1b5cfa4e0a
release-1.8
Megvii Engine Team “wenjuan” 3 years ago
parent
commit
1add4517ad
1 changed files with 27 additions and 0 deletions
  1. +27
    -0
      imperative/python/test/unit/core/test_indexing_op.py

+ 27
- 0
imperative/python/test/unit/core/test_indexing_op.py View File

@@ -7,6 +7,8 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
import platform
from tempfile import NamedTemporaryFile


import numpy as np import numpy as np
import pytest import pytest
@@ -16,6 +18,8 @@ import megengine
import megengine.core.tensor.megbrain_graph as G import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F import megengine.functional as F
import megengine.jit as jit import megengine.jit as jit
import megengine.random as rand
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin from megengine.core.ops import builtin
@@ -724,3 +728,26 @@ def test_nd_int_indexing(symbolic):
np.testing.assert_equal(out.numpy(), npy_out) np.testing.assert_equal(out.numpy(), npy_out)


run_test([inp, idx], lambda inp, idx: inp[idx]) run_test([inp, idx], lambda inp, idx: inp[idx])


@pytest.mark.skipif(
platform.system() == "Windows", reason="windows temp file issue, fixme later"
)
def test_subtensor_when_shape_invalid():
@jit.trace(symbolic=True, capture_as_const=True)
def fun(inp):
shape = inp.shape
H = shape[-1]
NH = H * 8 + 4
arr = F.arange(4, NH, 8)
arr_shape = arr.shape
return arr_shape[0]

inp = rand.uniform(size=[1, 3, 224, 224])
fun(inp)

with NamedTemporaryFile() as f:
fun.dump(f.name, arg_names=["data"], optimize_for_inference=True)
inp = rand.uniform(size=[1, 3, 512, 512])
net = cgtools.GraphInference(f.name)
net.run(inp_dict={"data": inp})

Loading…
Cancel
Save