@@ -103,6 +103,12 @@ def test_training_converge():
     assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"
 
     data, _ = next(train_dataset)
+    ngrid = 10
+    x = np.linspace(-1.0, 1.0, ngrid)
+    xx, yy = np.meshgrid(x, x)
+    xx = xx.reshape((ngrid * ngrid, 1))
+    yy = yy.reshape((ngrid * ngrid, 1))
+    data = np.concatenate((xx, yy), axis=1).astype(np.float32)
 
     pred = infer(data).numpy()
-    assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
+    assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"
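
For context on this hunk: the added lines build a fixed 10 x 10 grid over [-1, 1] x [-1, 1], flatten it into a (100, 2) float32 array, and the assertion now requires exact precision (== 1.0) on that grid rather than > 0.95. Below is a minimal, self-contained sketch of that kind of grid-based check; toy_infer and the unit-circle rule inside calculate_precision are illustrative assumptions, not this repository's infer model or metric.

# Minimal sketch of a grid-based precision check. `toy_infer` and the
# unit-circle ground truth inside `calculate_precision` are hypothetical
# stand-ins for the project's trained model and metric.
import numpy as np


def toy_infer(points: np.ndarray) -> np.ndarray:
    # Hypothetical classifier: predict 1 for points inside the unit circle,
    # 0 otherwise.
    return (np.sum(points ** 2, axis=1) < 1.0).astype(np.int64)


def calculate_precision(points: np.ndarray, pred: np.ndarray) -> float:
    # Hypothetical metric: fraction of points whose prediction matches the
    # ground-truth rule (here, the same unit-circle rule).
    truth = (np.sum(points ** 2, axis=1) < 1.0).astype(np.int64)
    return float(np.mean(pred == truth))


# Same 10 x 10 evaluation grid over [-1, 1] x [-1, 1] as in the hunk above.
ngrid = 10
x = np.linspace(-1.0, 1.0, ngrid)
xx, yy = np.meshgrid(x, x)
xx = xx.reshape((ngrid * ngrid, 1))
yy = yy.reshape((ngrid * ngrid, 1))
data = np.concatenate((xx, yy), axis=1).astype(np.float32)

pred = toy_infer(data)
assert calculate_precision(data, pred) == 1.0, "Test precision must be high enough"

Note that == 1.0 on 100 grid points tolerates no misclassification, whereas a > 0.95 threshold would still pass with up to four wrong points.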