diff --git a/python_module/test/integration/test_converge.py b/python_module/test/integration/test_converge.py index ec0efa87..e7e4f6c4 100644 --- a/python_module/test/integration/test_converge.py +++ b/python_module/test/integration/test_converge.py @@ -103,6 +103,12 @@ def test_training_converge(): assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" - data, _ = next(train_dataset) + ngrid = 10 + x = np.linspace(-1.0, 1.0, ngrid) + xx, yy = np.meshgrid(x, x) + xx = xx.reshape((ngrid * ngrid, 1)) + yy = yy.reshape((ngrid * ngrid, 1)) + data = np.concatenate((xx, yy), axis=1).astype(np.float32) + pred = infer(data).numpy() - assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"