Browse Source

fix(mgb/dnn): fix backward computation of tqt

GitOrigin-RevId: 850d11a5ce
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
652ec9f251
3 changed files with 7 additions and 7 deletions
  1. +1
    -1
      dnn/src/cuda/tqt/kern.cuh
  2. +1
    -1
      dnn/src/naive/tqt/opr_impl.cpp
  3. +5
    -5
      imperative/python/test/unit/quantization/test_fake_quant.py

+ 1
- 1
dnn/src/cuda/tqt/kern.cuh View File

@@ -58,7 +58,7 @@ struct TQTBwdKernOp {
ctype scaled = input[idx] / t;
ctype rounded = round(scaled);
rounded = fmaxf(fminf(rounded, qmax), qmin);
- bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax;
+ bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax);
bool mask_quant = !mask_clip;

grad_x[idx] = diff[idx] * mask_quant;


+ 1
- 1
dnn/src/naive/tqt/opr_impl.cpp View File

@@ -53,7 +53,7 @@ void backward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) {
T rounded = round(scaled);
rounded = rounded <= qmin ? qmin : rounded;
rounded = rounded >= qmax ? qmax : rounded;
- bool mask_clip = scaled < -0.5 + qmin && scaled > 0.5 + qmax;
+ bool mask_clip = (scaled < -0.5 + qmin) + (scaled > 0.5 + qmax);
bool mask_quant = !mask_clip;

*grad_x = *diff * mask_quant;


+ 5
- 5
imperative/python/test/unit/quantization/test_fake_quant.py View File

@@ -69,8 +69,8 @@ def test_tqt():
def cb(grad):
g.append(grad)

- x = np.random.normal(size=(1, 2, 3, 4))
- s = np.random.rand(1) + 1
+ x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32")
+ s = np.random.rand(1) - 1
g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32")

n = TQT_numpy(-127, 127)
@@ -85,9 +85,9 @@ def test_tqt():
grad(y, g_y)
g_x, g_s = g

- np.testing.assert_allclose(y.numpy(), y_np, atol=1e-6)
- np.testing.assert_allclose(g_x.numpy(), g_x_np, atol=1e-6)
- np.testing.assert_allclose(g_s.numpy(), g_s_np, atol=1e-6)
+ np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5)
+ np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-5, atol=1e-5)
+ np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-5, atol=5e-5)





Loading…
Cancel
Save