diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index 9f1817ba..810246dd 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -442,3 +442,18 @@ def test_removeAxis():
     y = f(x)
     grad(y, F.ones_like(y))
     np.testing.assert_equal(np.ones((3, 3, 1, 1), dtype=np.float32), x.grad.numpy())
+
+
+def test_dot():
+    x = np.random.rand(2, 2).astype("float32")
+    x = mge.Tensor(x)
+    u = F.ones((2,))
+    v = F.ones((2,))
+    grad = Grad().wrt(x, callback=save_to(x))
+
+    def f(x):
+        return F.dot(u, F.matmul(x, v))
+
+    y = f(x)
+    grad(y, F.ones_like(y))
+    np.testing.assert_equal(np.ones((2, 2), dtype=np.float32), x.grad.numpy())
diff --git a/imperative/src/impl/ops/tensor_manip.cpp b/imperative/src/impl/ops/tensor_manip.cpp
index 5569bbb7..2bc29bbe 100644
--- a/imperative/src/impl/ops/tensor_manip.cpp
+++ b/imperative/src/impl/ops/tensor_manip.cpp
@@ -33,7 +33,7 @@ DispatchMode decide_dispatch_mode(
         const SmallVector<LogicalTensorDesc>& inputs) {
     bool host_computable = true;
     for (auto&& inp : inputs) {
-        // FIXME(czh): remove value chech after proxy graph's
+        // FIXME(czh): remove value check after proxy graph's
         // apply_on_device_tensornd is supported and output Tensor
         // is made before add_task.
         // then if layout is valid, ptr->layout must be ready
@@ -50,9 +50,18 @@ void apply_on_device_tensornd(
         const SmallVector<DeviceTensorND>& inputs,
         SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
-    mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
-    auto&& inp = inputs[0];
-    auto&& shp = inp.layout();
+
+    TensorShape shp;
+    if (inputs.size() == 1) {
+        shp = inputs[0].layout();
+    } else {
+        TensorShapeArray src(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            src[i] = inputs[i].layout();
+        }
+        megdnn::Elemwise::deduce_shape(src, shp);
+    }
+    mgb_assert(shp.ndim != 0, "input shape invalid");
     mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
         "GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");
@@ -99,27 +108,36 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
-    mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& desc = inputs[0];
-    if (!desc.layout.ndim) {
+    TensorShape shp;
+    if (inputs.size() == 1) {
+        shp = desc.layout;
+    } else {
+        TensorShapeArray src(inputs.size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            src[i] = inputs[i].layout;
+        }
+        megdnn::Elemwise::deduce_shape(src, shp);
+    }
+    if (!shp.ndim) {
         return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
     if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
-        value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
+        value = DeviceTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
         auto* ptr = value.ptr<dt_int32>();
-        for (size_t i = 0; i < desc.layout.ndim; ++i) {
-            ptr[i] = desc.layout[i];
+        for (size_t i = 0; i < shp.ndim; ++i) {
+            ptr[i] = shp[i];
         }
     }else{
         int32_t axis = op_def.axis;
         if (axis < 0) {
-            axis += desc.layout.ndim;
+            axis += shp.ndim;
         }
-        mgb_assert(axis >= 0 && axis < (int32_t)desc.layout.ndim);
+        mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
         value = DeviceTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
         auto* ptr = value.ptr<dt_int32>();
-        ptr[0] = desc.layout[axis];
+        ptr[0] = shp[axis];
     }
     return {{{value.layout(), desc.comp_node, std::move(value)}}, true};
 }
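
Note on the megdnn::Elemwise::deduce_shape fallback used above: with more than
one input, GetVarShape now reports the broadcast shape deduced from all input
layouts. A minimal sketch of the expected result shapes, assuming Elemwise
follows standard NumPy-style broadcasting rules (numpy.broadcast_shapes is
only an analogy here, not part of the patch):

    import numpy as np

    # analogous to deduce_shape({2, 3}, {3}) -> {2, 3}
    print(np.broadcast_shapes((2, 3), (3,)))       # (2, 3)

    # analogous to deduce_shape({4, 1, 5}, {2, 1}) -> {4, 2, 5}
    print(np.broadcast_shapes((4, 1, 5), (2, 1)))  # (4, 2, 5)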