Browse Source

refactor(mge/test): make test_converge more stable

GitOrigin-RevId: 710e3ede40
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
2174fd12b9
1 changed files with 8 additions and 2 deletions
  1. +8
    -2
      python_module/test/integration/test_converge.py

+ 8
- 2
python_module/test/integration/test_converge.py View File

@@ -103,6 +103,12 @@ def test_training_converge():

assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"

data, _ = next(train_dataset)
ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)

pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"

Loading…
Cancel
Save