|
|
@@ -6,6 +6,7 @@ |
|
|
|
# Unless required by applicable law or agreed to in writing, |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import inspect |
|
|
|
import io |
|
|
|
import itertools |
|
|
|
from tempfile import mkstemp |
|
|
@@ -492,3 +493,44 @@ def test_random(shape_mode): |
|
|
|
|
|
|
|
run_test(uniform) |
|
|
|
run_test(normal) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shape_mode", [False, True]) |
|
|
|
def test_trace_advance_indexing(shape_mode): |
|
|
|
funcs = [ |
|
|
|
lambda x, i: x[i], |
|
|
|
# lambda x, i, j: x[i, j], # FIXME |
|
|
|
lambda x, i, j: x[i, :, j, ...], |
|
|
|
# lambda x, start, end: x[start:end], # FIXME |
|
|
|
lambda x, start, end: x[:, 0, start:end, ..., 1], |
|
|
|
lambda x, vec: x[vec], |
|
|
|
lambda x, vec: x[vec, ..., 0, 1:3], |
|
|
|
lambda x, vec: x[vec, vec[0], vec[1]], |
|
|
|
# lambda x, i, start, end, vec: x[i, ..., :, vec, start:end], # FIXME |
|
|
|
lambda x, mask: x[mask], |
|
|
|
] |
|
|
|
|
|
|
|
inputs = { |
|
|
|
"x": np.random.randn(5, 5, 5, 5, 5).astype("float32"), |
|
|
|
"i": 0, |
|
|
|
"j": 2, |
|
|
|
"start": 1, |
|
|
|
"end": 3, |
|
|
|
"vec": [1, 2, 3], |
|
|
|
"mask": np.random.randn(5, 5, 5, 5, 5) >= 0, |
|
|
|
} |
|
|
|
for f in funcs: |
|
|
|
sig = inspect.signature(f) |
|
|
|
param_names = list(sig._parameters.keys()) |
|
|
|
params = {} |
|
|
|
params_np = {} |
|
|
|
f_traced = trace(f, symbolic=False, symbolic_shape=shape_mode) |
|
|
|
for name in param_names: |
|
|
|
params[name] = tensor(inputs[name]) |
|
|
|
params_np[name] = inputs[name] |
|
|
|
expected = f(**params_np) |
|
|
|
result_imperative = f(**params) |
|
|
|
np.testing.assert_equal(expected, result_imperative.numpy()) |
|
|
|
for _ in range(3): |
|
|
|
result_trace = f_traced(**params) |
|
|
|
np.testing.assert_equal(expected, result_trace.numpy()) |