@@ -233,7 +233,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule( | |||||
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | ||||
CustomBackward& backward) { | CustomBackward& backward) { | ||||
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | ||||
auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items); | |||||
auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items); | |||||
SmallVector<ValueRef> inputs2; | SmallVector<ValueRef> inputs2; | ||||
if (inputs_require_grad[0]) { | if (inputs_require_grad[0]) { | ||||
inputs2.push_back(get_shape(inputs[0])); | inputs2.push_back(get_shape(inputs[0])); | ||||
@@ -316,7 +316,7 @@ def test_IndexingMultiAxisVec(): | |||||
def f(x): | def f(x): | ||||
x = x * 1 | x = x * 1 | ||||
y = x[[0, 2], [0, 2]] | |||||
y = x[[0, 0, 2, 1], [2, 2, 1, 0]] | |||||
refs["x"] = TensorWeakRef(x) | refs["x"] = TensorWeakRef(x) | ||||
return y | return y | ||||
@@ -326,7 +326,7 @@ def test_IndexingMultiAxisVec(): | |||||
grad(y, F.ones_like(y)) | grad(y, F.ones_like(y)) | ||||
np.testing.assert_equal( | np.testing.assert_equal( | ||||
np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() | |||||
np.array([[0, 0, 2], [1, 0, 0], [0, 1, 0]], dtype=np.float32), x.grad.numpy() | |||||
) | ) | ||||