Browse Source

fix(mge): fix advanced indexing grad

GitOrigin-RevId: 8033c9322d
release-1.10
Megvii Engine Team 3 years ago
parent
commit
99a85c4079
2 changed files with 3 additions and 3 deletions
  1. +1
    -1
      imperative/python/src/grad_override.cpp
  2. +2
    -2
      imperative/python/test/unit/core/test_autodiff.py

+ 1
- 1
imperative/python/src/grad_override.cpp View File

@@ -233,7 +233,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items);
auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items);
SmallVector<ValueRef> inputs2;
if (inputs_require_grad[0]) {
inputs2.push_back(get_shape(inputs[0]));


+ 2
- 2
imperative/python/test/unit/core/test_autodiff.py View File

@@ -316,7 +316,7 @@ def test_IndexingMultiAxisVec():

def f(x):
x = x * 1
y = x[[0, 2], [0, 2]]
y = x[[0, 0, 2, 1], [2, 2, 1, 0]]
refs["x"] = TensorWeakRef(x)
return y

@@ -326,7 +326,7 @@ def test_IndexingMultiAxisVec():
grad(y, F.ones_like(y))

np.testing.assert_equal(
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy()
np.array([[0, 0, 2], [1, 0, 0], [0, 1, 0]], dtype=np.float32), x.grad.numpy()
)




Loading…
Cancel
Save