diff --git a/imperative/python/test/unit/utils/test_network_node.py b/imperative/python/test/unit/utils/test_network_node.py index f6bcddac..15cb7c1d 100644 --- a/imperative/python/test/unit/utils/test_network_node.py +++ b/imperative/python/test/unit/utils/test_network_node.py @@ -105,17 +105,25 @@ def test_matinv(): check_pygraph_dump(fwd, [data], [result]) -def test_matmul(): +@pytest.mark.parametrize( + "execution_strategy", ["HEURISTIC_REPRODUCIBLE", "PROFILE_REPRODUCIBLE"] +) +def test_matmul(execution_strategy): @trace(symbolic=True, capture_as_const=True) def fwd(data1, data2): return F.matmul(data1, data2) old = get_execution_strategy() - set_execution_strategy("HEURISTIC_REPRODUCIBLE") + set_execution_strategy(execution_strategy) + + max_err = None + if execution_strategy == "PROFILE_REPRODUCIBLE": + max_err = 1e-5 + data1 = Tensor(np.random.random((32, 64))) data2 = Tensor(np.random.random((64, 16))) result = fwd(data1, data2) - check_pygraph_dump(fwd, [data1, data2], [result]) + check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err) set_execution_strategy(old)