|
|
@@ -191,19 +191,34 @@ class LSQ_numpy: |
|
|
|
|
|
|
|
|
|
|
|
def test_lsq(): |
|
|
|
def preprocess(scale, eps): |
|
|
|
scale = np.array([0]) if scale < eps else scale - eps |
|
|
|
return np.abs(scale) + eps |
|
|
|
|
|
|
|
g = [] |
|
|
|
|
|
|
|
def cb(grad): |
|
|
|
g.append(grad) |
|
|
|
|
|
|
|
x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") |
|
|
|
s = np.random.rand(1) |
|
|
|
# FIXME: use random number when LSQ is fixed |
|
|
|
# x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") |
|
|
|
# s = np.random.rand(1) |
|
|
|
x = np.array( |
|
|
|
[ |
|
|
|
[ |
|
|
|
[ |
|
|
|
[4.0, 38.0, -121.0, 38.0], |
|
|
|
[15.0, -115.0, -112.0, 24.0], |
|
|
|
[23.0, -65.0, 109.0, -115.0], |
|
|
|
], |
|
|
|
[ |
|
|
|
[-66.0, -90.0, -45.0, -101.0], |
|
|
|
[68.0, -98.0, 108.0, -79.0], |
|
|
|
[54.0, 63.0, -10.0, -50.0], |
|
|
|
], |
|
|
|
] |
|
|
|
], |
|
|
|
dtype="float32", |
|
|
|
) |
|
|
|
s = np.array([0.02918224], dtype="float32") |
|
|
|
eps = np.array([1e-5], dtype="float32") |
|
|
|
s = preprocess(s, eps) |
|
|
|
s = np.abs(s) if np.abs(s) > eps else eps |
|
|
|
zero_point = np.array([1.0], dtype="float32") |
|
|
|
grad_s = np.array([2.0], dtype="float32") |
|
|
|
|
|
|
|