Browse Source

refactor(mge): loose the error bound of fastrun

GitOrigin-RevId: 9bf9b9d4ca
release-1.5
Megvii Engine Team 3 years ago
parent
commit
f141159088
1 changed files with 11 additions and 3 deletions
  1. +11
    -3
      imperative/python/test/unit/utils/test_network_node.py

+ 11
- 3
imperative/python/test/unit/utils/test_network_node.py View File

@@ -105,17 +105,25 @@ def test_matinv():
check_pygraph_dump(fwd, [data], [result]) check_pygraph_dump(fwd, [data], [result])




def test_matmul():
@pytest.mark.parametrize(
"execution_strategy", ["HEURISTIC_REPRODUCIBLE", "PROFILE_REPRODUCIBLE"]
)
def test_matmul(execution_strategy):
@trace(symbolic=True, capture_as_const=True) @trace(symbolic=True, capture_as_const=True)
def fwd(data1, data2): def fwd(data1, data2):
return F.matmul(data1, data2) return F.matmul(data1, data2)


old = get_execution_strategy() old = get_execution_strategy()
set_execution_strategy("HEURISTIC_REPRODUCIBLE")
set_execution_strategy(execution_strategy)

max_err = None
if execution_strategy == "PROFILE_REPRODUCIBLE":
max_err = 1e-5

data1 = Tensor(np.random.random((32, 64))) data1 = Tensor(np.random.random((32, 64)))
data2 = Tensor(np.random.random((64, 16))) data2 = Tensor(np.random.random((64, 16)))
result = fwd(data1, data2) result = fwd(data1, data2)
check_pygraph_dump(fwd, [data1, data2], [result])
check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err)
set_execution_strategy(old) set_execution_strategy(old)






Loading…
Cancel
Save