@@ -58,7 +58,7 @@ struct TQTBwdKernOp { | |||||
ctype scaled = input[idx] / t; | ctype scaled = input[idx] / t; | ||||
ctype rounded = round(scaled); | ctype rounded = round(scaled); | ||||
rounded = fmaxf(fminf(rounded, qmax), qmin); | rounded = fmaxf(fminf(rounded, qmax), qmin); | ||||
bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; | |||||
bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax); | |||||
bool mask_quant = !mask_clip; | bool mask_quant = !mask_clip; | ||||
grad_x[idx] = diff[idx] * mask_quant; | grad_x[idx] = diff[idx] * mask_quant; | ||||
@@ -53,7 +53,7 @@ void backward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) { | |||||
T rounded = round(scaled); | T rounded = round(scaled); | ||||
rounded = rounded <= qmin ? qmin : rounded; | rounded = rounded <= qmin ? qmin : rounded; | ||||
rounded = rounded >= qmax ? qmax : rounded; | rounded = rounded >= qmax ? qmax : rounded; | ||||
bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax; | |||||
bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax); | |||||
bool mask_quant = !mask_clip; | bool mask_quant = !mask_clip; | ||||
*grad_x = *diff * mask_quant; | *grad_x = *diff * mask_quant; | ||||
@@ -69,8 +69,8 @@ def test_tqt(): | |||||
def cb(grad): | def cb(grad): | ||||
g.append(grad) | g.append(grad) | ||||
x = np.random.normal(size=(1, 2, 3, 4)) | |||||
s = np.random.rand(1) + 1 | |||||
x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") | |||||
s = np.random.rand(1) - 1 | |||||
g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32") | g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32") | ||||
n = TQT_numpy(-127, 127) | n = TQT_numpy(-127, 127) | ||||
@@ -85,9 +85,9 @@ def test_tqt(): | |||||
grad(y, g_y) | grad(y, g_y) | ||||
g_x, g_s = g | g_x, g_s = g | ||||
np.testing.assert_allclose(y.numpy(), y_np, atol=1e-6) | |||||
np.testing.assert_allclose(g_x.numpy(), g_x_np, atol=1e-6) | |||||
np.testing.assert_allclose(g_s.numpy(), g_s_np, atol=1e-6) | |||||
np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) | |||||
np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-5, atol=1e-5) | |||||
np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-5, atol=5e-5) | |||||