Browse Source

feat(mge/functional): add matinv

GitOrigin-RevId: d4fa8a8277
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
df976782fa
4 changed files with 90 additions and 0 deletions
  1. +33
    -0
      imperative/python/megengine/functional/nn.py
  2. +19
    -0
      imperative/python/test/unit/functional/test_functional.py
  3. +36
    -0
      imperative/src/impl/ops/matrix_inverse.cpp
  4. +2
    -0
      src/core/include/megbrain/ir/ops.td

+ 33
- 0
imperative/python/megengine/functional/nn.py View File

@@ -52,6 +52,7 @@ __all__ = [
"logsigmoid",
"logsumexp",
"logsoftmax",
"matinv",
"matmul",
"max_pool2d",
"one_hot",
@@ -1002,6 +1003,38 @@ def remap(
return result


def matinv(inp: Tensor) -> Tensor:
"""
Computes the inverse of a batch of matrices; input must has shape [..., n, n].

:param inp: input tensor.
:return: output tensor.

Examples:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

data = tensor([[1.0, 0.0], [1.0, 1.0]])
out = F.matinv(data)
print(out.numpy())

Outputs:

.. testoutput::

[[ 1. 0.]
[-1. 1.]]

"""

(result,) = apply(builtin.MatrixInverse(), inp)
return result


def matmul(
inp1: Tensor,
inp2: Tensor,


+ 19
- 0
imperative/python/test/unit/functional/test_functional.py View File

@@ -60,6 +60,25 @@ def test_dropout():
assert out.numpy().sum() >= 0.0


def test_matinv():
shape1 = (5, 5)
shape2 = (3, 9, 9)
data1 = np.random.random(shape1).astype("float32")
data2 = np.random.random(shape2).astype("float32")

cases = [
{"input": data1},
{"input": data2},
]

opr_test(
cases,
F.matinv,
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5),
ref_fn=np.linalg.inv,
)


def test_matmul():
shape1 = 3
shape2 = 3


+ 36
- 0
imperative/src/impl/ops/matrix_inverse.cpp View File

@@ -0,0 +1,36 @@
/**
* \file imperative/src/impl/ops/matrix_inverse.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/

#include "../op_trait.h"

#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/blas.h"

namespace mgb{
namespace imperative {

namespace {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
mgb_assert(inputs.size() == 1);
return opr::MatrixInverse::make(inputs[0]);
}
OP_TRAIT_REG(MatrixInverse, MatrixInverse)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // anonymous namespace

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}


+ 2
- 0
src/core/include/megbrain/ir/ops.td View File

@@ -34,6 +34,8 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
let results = (outs AnyType);
}

def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>;

def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;

def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;


Loading…
Cancel
Save