diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 7d1ca44e..e98135e5 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -52,6 +52,7 @@ __all__ = [
     "logsigmoid",
     "logsumexp",
     "logsoftmax",
+    "matinv",
     "matmul",
     "max_pool2d",
     "one_hot",
@@ -1002,6 +1003,38 @@ def remap(
     return result
 
 
+def matinv(inp: Tensor) -> Tensor:
+    """
+    Computes the inverse of a batch of matrices; the input must have shape [..., n, n].
+
+    :param inp: input tensor.
+    :return: output tensor.
+
+    Examples:
+
+    .. testcode::
+
+        import numpy as np
+        from megengine import tensor
+        import megengine.functional as F
+
+        data = tensor([[1.0, 0.0], [1.0, 1.0]])
+        out = F.matinv(data)
+        print(out.numpy())
+
+    Outputs:
+
+    .. testoutput::
+
+        [[ 1.  0.]
+         [-1.  1.]]
+
+    """
+
+    (result,) = apply(builtin.MatrixInverse(), inp)
+    return result
+
+
 def matmul(
     inp1: Tensor,
     inp2: Tensor,
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 49e3e817..9bcfe276 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -60,6 +60,25 @@ def test_dropout():
     assert out.numpy().sum() >= 0.0
 
 
+def test_matinv():
+    shape1 = (5, 5)
+    shape2 = (3, 9, 9)
+    data1 = np.random.random(shape1).astype("float32")
+    data2 = np.random.random(shape2).astype("float32")
+
+    cases = [
+        {"input": data1},
+        {"input": data2},
+    ]
+
+    opr_test(
+        cases,
+        F.matinv,
+        compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
+        ref_fn=np.linalg.inv,
+    )
+
+
 def test_matmul():
     shape1 = 3
     shape2 = 3
diff --git a/imperative/src/impl/ops/matrix_inverse.cpp b/imperative/src/impl/ops/matrix_inverse.cpp
new file mode 100644
index 00000000..b20794e0
--- /dev/null
+++ b/imperative/src/impl/ops/matrix_inverse.cpp
@@ -0,0 +1,36 @@
+/**
+ * \file imperative/src/impl/ops/matrix_inverse.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "../op_trait.h"
+
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/blas.h"
+
+namespace mgb {
+namespace imperative {
+
+namespace {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    mgb_assert(inputs.size() == 1);
+    return opr::MatrixInverse::make(inputs[0]);
+}
+OP_TRAIT_REG(MatrixInverse, MatrixInverse)
+    .apply_on_var_node(apply_on_var_node)
+    .fallback();
+} // anonymous namespace
+
+} // namespace imperative
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 7ac9f049..6e71da49 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -34,6 +34,8 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
   let results = (outs AnyType);
 }
 
+def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>;
+
 def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
 
 def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
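
Note for reviewers: below is a minimal usage sketch of the new op, assuming this patch is applied and a MatrixInverse kernel is available on the target device. The batched input and the comparison against np.linalg.inv mirror the unit test added above; the 4x4 batch size and the rtol value are illustrative choices, not part of the patch.

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # a batch of three 4x4 matrices; F.matinv inverts each matrix independently
    batch = np.random.random((3, 4, 4)).astype("float32")
    inv = F.matinv(tensor(batch))

    # cross-check against NumPy's reference implementation, as the unit test does
    np.testing.assert_allclose(inv.numpy(), np.linalg.inv(batch), rtol=1e-4)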