Browse Source

refactor(mge/test): make test_parampack more stable

GitOrigin-RevId: d82230ea07
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
888895e971
1 changed files with 24 additions and 6 deletions
  1. +24
    -6
      python_module/test/integration/test_parampack.py

+ 24
- 6
python_module/test/integration/test_parampack.py View File

@@ -105,9 +105,15 @@ def test_static_graph_parampack():


assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"


data, _ = next(train_dataset)
ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)

pred = infer(data).numpy() pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"




@pytest.mark.slow @pytest.mark.slow
@@ -140,9 +146,15 @@ def test_nopack_parampack():
losses.append(loss.numpy()) losses.append(loss.numpy())
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"


data, _ = next(train_dataset)
ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)

pred = infer(data).numpy() pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"




@pytest.mark.slow @pytest.mark.slow
@@ -178,9 +190,15 @@ def test_dynamic_graph_parampack():


assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"


data, _ = next(train_dataset)
ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)

pred = infer(data).numpy() pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"




@pytest.mark.slow @pytest.mark.slow


Loading…
Cancel
Save