Browse Source

test(trace): test subtensor on unknown shape

GitOrigin-RevId: 1b5cfa4e0a
release-1.8
Megvii Engine Team “wenjuan” 3 years ago
parent
commit
1add4517ad
1 changed files with 27 additions and 0 deletions
  1. +27
    -0
      imperative/python/test/unit/core/test_indexing_op.py

+ 27
- 0
imperative/python/test/unit/core/test_indexing_op.py View File

@@ -7,6 +7,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import platform
from tempfile import NamedTemporaryFile

import numpy as np
import pytest
@@ -16,6 +18,8 @@ import megengine
import megengine.core.tensor.megbrain_graph as G
import megengine.functional as F
import megengine.jit as jit
import megengine.random as rand
import megengine.utils.comp_graph_tools as cgtools
from megengine.core._imperative_rt.core2 import apply
from megengine.core._trace_option import use_symbolic_shape
from megengine.core.ops import builtin
@@ -724,3 +728,26 @@ def test_nd_int_indexing(symbolic):
np.testing.assert_equal(out.numpy(), npy_out)

run_test([inp, idx], lambda inp, idx: inp[idx])


@pytest.mark.skipif(
platform.system() == "Windows", reason="windows temp file issue, fixme later"
)
def test_subtensor_when_shape_invalid():
@jit.trace(symbolic=True, capture_as_const=True)
def fun(inp):
shape = inp.shape
H = shape[-1]
NH = H * 8 + 4
arr = F.arange(4, NH, 8)
arr_shape = arr.shape
return arr_shape[0]

inp = rand.uniform(size=[1, 3, 224, 224])
fun(inp)

with NamedTemporaryFile() as f:
fun.dump(f.name, arg_names=["data"], optimize_for_inference=True)
inp = rand.uniform(size=[1, 3, 512, 512])
net = cgtools.GraphInference(f.name)
net.run(inp_dict={"data": inp})

Loading…
Cancel
Save