diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py index 9f93a175..612e03ab 100644 --- a/imperative/python/test/unit/quantization/test_fake_quant.py +++ b/imperative/python/test/unit/quantization/test_fake_quant.py @@ -191,19 +191,34 @@ class LSQ_numpy: def test_lsq(): - def preprocess(scale, eps): - scale = np.array([0]) if scale < eps else scale - eps - return np.abs(scale) + eps - g = [] def cb(grad): g.append(grad) - x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") - s = np.random.rand(1) + # FIXME: use random number when LSQ is fixed + # x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") + # s = np.random.rand(1) + x = np.array( + [ + [ + [ + [4.0, 38.0, -121.0, 38.0], + [15.0, -115.0, -112.0, 24.0], + [23.0, -65.0, 109.0, -115.0], + ], + [ + [-66.0, -90.0, -45.0, -101.0], + [68.0, -98.0, 108.0, -79.0], + [54.0, 63.0, -10.0, -50.0], + ], + ] + ], + dtype="float32", + ) + s = np.array([0.02918224], dtype="float32") eps = np.array([1e-5], dtype="float32") - s = preprocess(s, eps) + s = np.abs(s) if np.abs(s) > eps else eps zero_point = np.array([1.0], dtype="float32") grad_s = np.array([2.0], dtype="float32")