Browse Source

feat(mge/pytest): add more tests for specialized grad rules

GitOrigin-RevId: 509ef5a220
release-1.2
Megvii Engine Team 4 years ago
parent
commit
697f70c011
1 changed files with 72 additions and 6 deletions
  1. +72
    -6
      imperative/python/test/unit/core/test_autodiff.py

+ 72
- 6
imperative/python/test/unit/core/test_autodiff.py View File

@@ -259,7 +259,18 @@ def test_reshape():
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x.reshape(5, 2)

refs = {}

def f(x):
x = x * 1
y = x.reshape(5, 2)
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None


grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((2, 5), dtype=np.float32), x.grad.numpy())
@@ -270,7 +281,18 @@ def test_subtensor():
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x[1:-1, :2]

refs = {}

def f(x):
x = x * 1
y = x[1:-1, :2]
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None


grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(
@@ -283,7 +305,18 @@ def test_IndexingMultiAxisVec():
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = x[[0, 2], [0, 2]]

refs = {}

def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None


grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(
@@ -296,7 +329,18 @@ def test_AxisAddRemove():
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = F.squeeze(F.expand_dims(x, 2), 0)

refs = {}

def f(x):
x = x * 1
y = F.squeeze(F.expand_dims(x, 2), 0)
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None


grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal( np.testing.assert_equal(
@@ -342,7 +386,18 @@ def test_addAxis():
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = F.expand_dims(x, [2, 3])

refs = {}

def f(x):
x = x * 1
y = F.expand_dims(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None


grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((3, 3), dtype=np.float32), x.grad.numpy())
@@ -353,7 +408,18 @@ def test_removeAxis():
x = mge.Tensor(x_np) x = mge.Tensor(x_np)


grad = Grad().wrt(x, callback=save_to(x)) grad = Grad().wrt(x, callback=save_to(x))
y = F.squeeze(x, [2, 3])

refs = {}

def f(x):
x = x * 1
y = F.squeeze(x, [2, 3])
refs["x"] = TensorWeakRef(x)
return y

y = f(x)
for _, r in refs.items():
assert r() is None


grad(y, F.ones_like(y)) grad(y, F.ones_like(y))
np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy()) np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())

Loading…
Cancel
Save