From 99a85c40796175bb06cdfd9be52ee2a5e08b70e6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 18 Apr 2022 14:05:55 +0800 Subject: [PATCH] fix(mge): fix advanced indexing grad GitOrigin-RevId: 8033c9322dd79db2b72a8cc7df66cdaf270b0b60 --- imperative/python/src/grad_override.cpp | 2 +- imperative/python/test/unit/core/test_autodiff.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index 51982c1a..6ddd7ca5 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -233,7 +233,7 @@ std::optional indexingMultiAxisVec_grad_rule( const OpDef& op, Span inputs, Span inputs_require_grad, CustomBackward& backward) { auto&& indexingMultiAxisVec = op.cast_final_safe(); - auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items); + auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items); SmallVector inputs2; if (inputs_require_grad[0]) { inputs2.push_back(get_shape(inputs[0])); diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 63a8304a..cd07f654 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -316,7 +316,7 @@ def test_IndexingMultiAxisVec(): def f(x): x = x * 1 - y = x[[0, 2], [0, 2]] + y = x[[0, 0, 2, 1], [2, 2, 1, 0]] refs["x"] = TensorWeakRef(x) return y @@ -326,7 +326,7 @@ def test_IndexingMultiAxisVec(): grad(y, F.ones_like(y)) np.testing.assert_equal( - np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() + np.array([[0, 0, 2], [1, 0, 0], [0, 1, 0]], dtype=np.float32), x.grad.numpy() )