diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py
index a8710731..102af7f2 100644
--- a/imperative/python/megengine/core/tensor/array_method.py
+++ b/imperative/python/megengine/core/tensor/array_method.py
@@ -13,6 +13,7 @@ from .._imperative_rt.core2 import (
     astype_cpp,
     batched_matmul_cpp,
     broadcast_cpp,
+    expand_dims_cpp,
     getitem_cpp,
     matmul_cpp,
     reshape_cpp,
@@ -62,7 +63,6 @@ def _matmul(
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result
@@ -72,34 +72,44 @@ def _matmul(
         # 2x2
         # nx1(transpose_a=False), n>=3
         # nx2(transpose_a=False), n>=3
-        return matmul_cpp(
-            inp1,
-            inp2,
-            dim1,
-            dim2,
+        ret = matmul_cpp(
+            inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0),
+            inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1),
+            max(dim1, 2),
+            max(dim2, 2),
             transpose_a,
             transpose_b,
             compute_mode,
             _config._benchmark_kernel,
             _config._deterministic_kernel,
         )
+        if dim1 == 1:
+            ret = squeeze_cpp(ret, -2)
+        elif dim2 == 1:
+            ret = squeeze_cpp(ret, -1)
+        return ret
     else:  # dispath to BatchedMatrixMul
         # nx1(transpose_a=True), n>=3
         # nx2(transpose_a=True), n>=3
         # nxm,n>=3,m>=3
         # 1xm,m>=3
         # 2xm,m>=3
-        return batched_matmul_cpp(
-            inp1,
-            inp2,
-            dim1,
-            dim2,
+        ret = batched_matmul_cpp(
+            inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0),
+            inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1),
+            max(dim1, 2),
+            max(dim2, 2),
             transpose_a,
             transpose_b,
             compute_mode,
             _config._benchmark_kernel,
             _config._deterministic_kernel,
        )
+        if dim1 == 1:
+            ret = squeeze_cpp(ret, -2)
+        elif dim2 == 1:
+            ret = squeeze_cpp(ret, -1)
+        return ret
 
 
 def _unary_elwise(mode):
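Note on the `_matmul` change above: instead of letting the C++ ops special-case vectors, the wrapper now promotes a 1-D lhs to a row matrix (`expand_dims_cpp(inp1, 0)`) and a 1-D rhs to a column matrix (`expand_dims_cpp(inp2, -1)`), runs the ordinary 2-D/batched kernel, and squeezes the temporary axis back out of the result. A minimal NumPy sketch of that convention (illustrative only; the helper name is made up, and the 1-D @ 1-D case is still dispatched to Dot earlier in `_matmul`):

import numpy as np

def matmul_with_1d_promotion(a, b):
    # mirrors the expand_dims_cpp / squeeze_cpp calls in _matmul
    squeeze_row = a.ndim == 1   # 1-D lhs -> (1, k) row
    squeeze_col = b.ndim == 1   # 1-D rhs -> (k, 1) column
    if squeeze_row:
        a = np.expand_dims(a, 0)
    if squeeze_col:
        b = np.expand_dims(b, -1)
    out = a @ b
    if squeeze_row:
        out = np.squeeze(out, -2)
    elif squeeze_col:
        out = np.squeeze(out, -1)
    return out

a = np.random.rand(4).astype("float32")
b = np.random.rand(2, 4, 5).astype("float32")
np.testing.assert_allclose(matmul_with_1d_promotion(a, b), a @ b, rtol=1e-5)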
diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp
index 62228ac6..3da9f618 100644
--- a/imperative/python/src/grad_override.cpp
+++ b/imperative/python/src/grad_override.cpp
@@ -87,6 +87,136 @@ ValueRef make_empty_tensor(
     return res;
 }
 
+std::optional<ValueRefList> matrix_mul_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    auto&& matmul = op.cast_final_safe<MatrixMul>();
+    size_t dimA = matmul.dimA;
+    size_t dimB = matmul.dimB;
+    auto&& param = matmul.param();
+    auto&& policy = matmul.policy();
+    mgb_assert(inputs.size() == 2);
+    std::array<ValueRef, 2> inps, input_shapes;
+    for (size_t i = 0; i < 2; ++i) {
+        if (inputs_require_grad[i ^ 1]) {
+            inps[i] = inputs[i];
+            input_shapes[i] = get_shape(inputs[i]);
+        }
+    }
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
+                    param, policy, dimA, dimB](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(2);
+        if (!grad) {
+            return ret;
+        }
+        size_t dimG = std::max(dimA, dimB);
+        if (inps_[1]) {
+            if (param.transposeA) {
+                // A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul
+                auto&& grad_op = MatrixMul::make(
+                        param.transposeB, true, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimB, dimG);
+                ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
+            } else {
+                // A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul
+                auto&& grad_op = MatrixMul::make(
+                        false, !param.transposeB, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimB);
+                ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
+            }
+        }
+        if (inps_[0]) {
+            if (param.transposeB) {
+                // A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul
+                // (specialized)
+                auto&& grad_op = MatrixMul::make(
+                        true, param.transposeA, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimA);
+                ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
+            } else {
+                // A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul
+                // (specialized)
+                auto&& grad_op = MatrixMul::make(
+                        !param.transposeA, false, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimA, dimG);
+                ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
+            }
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(ApplyOp(op), inputs);
+}
+
+std::optional<ValueRefList> batched_matrix_mul_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    auto&& bmm = op.cast_final_safe<BatchedMatrixMul>();
+    size_t dimA = bmm.dimA;
+    size_t dimB = bmm.dimB;
+    auto&& param = bmm.param();
+    auto&& policy = bmm.policy();
+    mgb_assert(inputs.size() == 2);
+    std::array<ValueRef, 2> inps, input_shapes;
+    for (size_t i = 0; i < 2; ++i) {
+        if (inputs_require_grad[i ^ 1]) {
+            inps[i] = inputs[i];
+            input_shapes[i] = get_shape(inputs[i]);
+        }
+    }
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
+                    param, policy, dimA, dimB](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(2);
+        if (!grad) {
+            return ret;
+        }
+        size_t dimG = std::max(dimA, dimB);
+        if (inps_[1]) {
+            if (param.transposeA) {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        param.transposeB, true, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimB, dimG);
+                ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
+            } else {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        false, !param.transposeB, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimB);
+                ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
+            }
+            if (dimG != dimA) {
+                ret[0] = reduce_to(ret[0], input_shapes_[0]);
+            }
+        }
+        if (inps_[0]) {
+            if (param.transposeB) {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        true, param.transposeA, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimA);
+                ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
+            } else {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        !param.transposeA, false, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimA, dimG);
+                ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
+            }
+            if (dimG != dimB) {
+                ret[1] = reduce_to(ret[1], input_shapes_[1]);
+            }
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(ApplyOp(op), inputs);
+}
+
 std::optional<ValueRefList> elemwise_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
@@ -395,6 +525,9 @@ struct Init {
                 FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
         CustomBackward::register_grad_rule(
                 PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
+        CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule);
+        CustomBackward::register_grad_rule(
+                BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule);
     }
 } _;
 
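Note on the two grad rules above: for C = A @ B with incoming gradient G they emit a second MatrixMul / BatchedMatrixMul computing dA = G @ B^T and dB = A^T @ G (the forward transpose flags are folded into the new op's transpose arguments), and `reduce_to` sums away the batch axes that a lower-rank operand was broadcast over. A NumPy sketch of those identities with a finite-difference spot check (illustrative only, not MegEngine API); `test_matmul` below exercises the real rules across ranks and transpose flags:

import numpy as np

A = np.random.rand(2, 3, 4)    # batched lhs
B = np.random.rand(4, 5)       # 2-D rhs, broadcast over the batch axis
G = np.random.rand(2, 3, 5)    # incoming gradient of C = A @ B

dA = G @ np.swapaxes(B, -1, -2)                  # same shape as A
dB = (np.swapaxes(A, -1, -2) @ G).sum(axis=0)    # batch axis reduced, like reduce_to

# finite-difference check on a single entry of B
eps = 1e-6
B_pert = B.copy()
B_pert[0, 0] += eps
numeric = ((A @ B_pert - A @ B) * G).sum() / eps
np.testing.assert_allclose(numeric, dB[0, 0], rtol=1e-3)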
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index 07cf66b9..358aa3a5 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -511,3 +511,45 @@ def test_pixel_shuffle():
         y = f(x)
         grad(y, F.ones_like(y))
     np.testing.assert_equal(2 * x.numpy(), x.grad.numpy())
+
+
+def test_matmul():
+    def test_one(xdim, ydim, transposeA, transposeB):
+        xshape = (1, 4) if xdim == 1 else (2,) * (xdim - 2) + (3, 4)
+        yshape = (4, 1) if ydim == 1 else (2,) * (ydim - 2) + (4, 5)
+        x = np.random.rand(*xshape).astype("float32")
+        y = np.random.rand(*yshape).astype("float32")
+        gshape = (x @ y).shape
+        g = np.random.rand(*gshape).astype("float32")
+        dx = g @ np.swapaxes(y, -1, -2)
+        dy = np.swapaxes(x, -1, -2) @ g
+        while dx.shape != x.shape:
+            dx = dx.sum(0)
+        while dy.shape != y.shape:
+            dy = dy.sum(0)
+        if transposeA:
+            x = np.swapaxes(x, -1, -2)
+            dx = np.swapaxes(dx, -1, -2)
+        if transposeB:
+            y = np.swapaxes(y, -1, -2)
+            dy = np.swapaxes(dy, -1, -2)
+        x = mge.Tensor(x.squeeze())
+        y = mge.Tensor(y.squeeze())
+        g = mge.Tensor(g.squeeze())
+        with Grad() as grad:
+            grad.wrt(x, callback=save_to(x))
+            grad.wrt(y, callback=save_to(y))
+            z = F.matmul(x, y, transpose_a=transposeA, transpose_b=transposeB)
+            grad(z, g)
+        np.testing.assert_almost_equal(dx.squeeze(), x.grad.numpy(), decimal=5)
+        np.testing.assert_almost_equal(dy.squeeze(), y.grad.numpy(), decimal=5)
+
+    for xdim in [1, 2, 3, 4]:
+        for ydim in [1, 2, 3, 4]:
+            for transposeA in [False, True]:
+                if xdim == 1 and transposeA == True:
+                    continue
+                for transposeB in [False, True]:
+                    if ydim == 1 and transposeB == True:
+                        continue
+                    test_one(xdim, ydim, transposeA, transposeB)
diff --git a/imperative/src/impl/ops/matmul.cpp b/imperative/src/impl/ops/matmul.cpp
index 145dabfd..d30d4786 100644
--- a/imperative/src/impl/ops/matmul.cpp
+++ b/imperative/src/impl/ops/matmul.cpp
@@ -31,18 +31,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     DTypeScalar vi{-1};
     auto graph = inputs[0]->owner_graph();
 
-    bool remove_row = false, remove_col = false;
-    if (dim1 == 1) {
-        dim1 = 2;
-        remove_row = true;
-        inp1 = inp1.add_axis(0);
-    }
-    if (dim2 == 1) {
-        dim2 = 2;
-        remove_col = true;
-        inp2 = inp2.add_axis(1);
-    }
-
     SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
     if (dim1 > 2) {
         auto idx = opr::ImmutableTensor::make(*graph, vi, config);
@@ -91,17 +79,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
         result = result.reshape(tshp);
     }
 
-    auto maxdim = dim1 > dim2 ? dim1 : dim2;
-    if (remove_row) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 2));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
-    if (remove_col) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 1));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
     return result;
 }
 
@@ -113,6 +90,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
 
     DType dst_dtype;
+    if (dim1 == dim2 && dim2 >= 3) {  // only happens in backward
+        for (size_t i = 1; i + 1 < layout1.ndim; ++i) {
+            layout1[0] *= layout1[i];
+            layout2[0] *= layout2[i];
+        }
+        layout1[1] = layout1[layout1.ndim - 1];
+        layout1.ndim = 2;
+        layout1.init_contiguous_stride();
+        layout2[1] = layout2[layout2.ndim - 1];
+        layout2.ndim = 2;
+        layout2.init_contiguous_stride();
+        dim1 = dim2 = 2;
+    }
 
     DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node);
     dnn_opr.op->param() = matmul.param();
@@ -156,6 +146,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
     dnn_opr.op->param() = matmul.param();
 
+    if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) {  // only happens in backward
+        for (size_t i = 1; i + 1 < layout1.ndim; ++i) {
+            layout1[0] *= layout1[i];
+            layout2[0] *= layout2[i];
+        }
+        layout1[1] = layout1[layout1.ndim - 1];
+        layout1.ndim = 2;
+        layout1.init_contiguous_stride();
+        layout2[1] = layout2[layout2.ndim - 1];
+        layout2.ndim = 2;
+        layout2.init_contiguous_stride();
+    }
+
     DType dst_dtype;
     dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
 
@@ -191,11 +194,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     }
 
     TensorLayout layout_a = layout1, layout_b = layout2;
-    if (dim1 == 1) {
-        layout_a.add_axis_cont_inplace(0);
-        inp_tensornds[0] = inputs[0]->dnn_tensor();
-        inp_tensornds[0].layout = layout_a;
-    } else if (dim1 > 2) {
+    if (dim1 > 2) {
         size_t batch = std::accumulate(
                 layout1.shape, layout1.shape + dim1 - 1, (size_t)1,
                 std::multiplies<size_t>());
@@ -216,13 +215,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         inp_tensornds[0] = inputs[0]->dnn_tensor();
     }
 
-    if (dim2 == 1) {
-        layout_b.add_axis_inplace(1, 1, 1);
-        inp_tensornds[1] = inputs[1]->dnn_tensor();
-        inp_tensornds[1].layout = layout_b;
-    } else {
-        inp_tensornds[1] = inputs[1]->dnn_tensor();
-    }
+    inp_tensornds[1] = inputs[1]->dnn_tensor();
 
     TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype);
     dst_layout.init_contiguous_stride();
@@ -232,6 +225,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     if (matmul.transposeB)
         std::swap(layout_b.shape[0], layout_b.shape[1]);
 
+    if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) {  // only happens in backward
+        inp_tensornds[0].layout = layout_a;
+        inp_tensornds[1].layout = layout_b;
+    }
+
     DeviceTensorND out =
             BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
     size_t sz = setup_algo(
@@ -279,18 +277,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto graph = inputs[0]->owner_graph();
     auto idx = opr::ImmutableTensor::make(*graph, vi, config);
 
-    bool remove_row = false, remove_col = false;
-    if (dim1 == 1) {
-        dim1 = 2;
-        remove_row = true;
-        inp1 = inp1.add_axis(0);
-    }
-    if (dim2 == 1) {
-        dim2 = 2;
-        remove_col = true;
-        inp2 = inp2.add_axis(1);
-    }
-
     auto shp1 = inp1.symshape();
     auto shp2 = inp2.symshape();
     SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
@@ -349,16 +335,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
         result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn);
         result = result.reshape(result_shp);
     }
-    if (remove_row) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 2));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
-    if (remove_col) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 1));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
     return result;
 }
 
@@ -418,21 +394,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     DType dst_dtype;
     dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
 
-    bool remove_row = false, remove_col = false;
-    if (dim1 == 1) {
-        dim1 = 2;
-        remove_row = true;
-    }
-    if (dim2 == 1) {
-        dim2 = 2;
-        remove_col = true;
-    }
-
-    if (remove_row)
-        layout1.add_axis_cont_inplace(0);
-    if (remove_col)
-        layout2.add_axis_inplace(1, 1, 1);
-
     TensorShape tshp, batch_shp;
     size_t j = 0;
     auto inp1 = inputs[0], inp2 = inputs[1];
@@ -530,12 +491,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     if (maxdim > 3) {
         dst_layout = dst_layout.reshape(shp1);
     }
-    if (remove_row) {
-        dst_layout = dst_layout.remove_axis(maxdim - 2);
-    }
-    if (remove_col) {
-        dst_layout = dst_layout.remove_axis(maxdim - 1);
-    }
 
     return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))};
 }
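Note on the matmul.cpp branches above marked "only happens in backward": per those comments, a plain MatrixMul only sees two operands of equal rank >= 3 when the op was built by matrix_mul_grad_rule, e.g. dB = A^T @ G with A and G sharing the same leading axes. Collapsing everything but the last axis into the first dimension reduces that contraction to a single 2-D GEMM, which appears to be what the layout rewriting does. A NumPy sketch of the equivalence (illustrative only):

import numpy as np

A = np.random.rand(2, 3, 6, 4)   # (..., m, k)
G = np.random.rand(2, 3, 6, 5)   # (..., m, n), gradient of A @ B for a 2-D B

# summing the per-batch products A_b^T @ G_b ...
per_batch = (np.swapaxes(A, -1, -2) @ G).sum(axis=(0, 1))
# ... equals one flat GEMM on the collapsed layouts, as in the backward-only branch
flat = A.reshape(-1, A.shape[-1]).T @ G.reshape(-1, G.shape[-1])

np.testing.assert_allclose(per_batch, flat, rtol=1e-10)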