@@ -52,6 +52,7 @@ __all__ = [ | |||||
"logsigmoid", | "logsigmoid", | ||||
"logsumexp", | "logsumexp", | ||||
"logsoftmax", | "logsoftmax", | ||||
"matinv", | |||||
"matmul", | "matmul", | ||||
"max_pool2d", | "max_pool2d", | ||||
"one_hot", | "one_hot", | ||||
@@ -1002,6 +1003,38 @@ def remap( | |||||
return result | return result | ||||
def matinv(inp: Tensor) -> Tensor: | |||||
""" | |||||
Computes the inverse of a batch of matrices; input must has shape [..., n, n]. | |||||
:param inp: input tensor. | |||||
:return: output tensor. | |||||
Examples: | |||||
.. testcode:: | |||||
import numpy as np | |||||
from megengine import tensor | |||||
import megengine.functional as F | |||||
data = tensor([[1.0, 0.0], [1.0, 1.0]]) | |||||
out = F.matinv(data) | |||||
print(out.numpy()) | |||||
Outputs: | |||||
.. testoutput:: | |||||
[[ 1. 0.] | |||||
[-1. 1.]] | |||||
""" | |||||
(result,) = apply(builtin.MatrixInverse(), inp) | |||||
return result | |||||
def matmul( | def matmul( | ||||
inp1: Tensor, | inp1: Tensor, | ||||
inp2: Tensor, | inp2: Tensor, | ||||
@@ -60,6 +60,25 @@ def test_dropout(): | |||||
assert out.numpy().sum() >= 0.0 | assert out.numpy().sum() >= 0.0 | ||||
def test_matinv(): | |||||
shape1 = (5, 5) | |||||
shape2 = (3, 9, 9) | |||||
data1 = np.random.random(shape1).astype("float32") | |||||
data2 = np.random.random(shape2).astype("float32") | |||||
cases = [ | |||||
{"input": data1}, | |||||
{"input": data2}, | |||||
] | |||||
opr_test( | |||||
cases, | |||||
F.matinv, | |||||
compare_fn=lambda x, y: np.testing.assert_allclose(x.numpy(), y, rtol=1e-5), | |||||
ref_fn=np.linalg.inv, | |||||
) | |||||
def test_matmul(): | def test_matmul(): | ||||
shape1 = 3 | shape1 = 3 | ||||
shape2 = 3 | shape2 = 3 | ||||
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* \file imperative/src/impl/ops/matrix_inverse.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "../op_trait.h" | |||||
#include "megbrain/imperative/ops/autogen.h" | |||||
#include "megbrain/opr/blas.h" | |||||
namespace mgb{ | |||||
namespace imperative { | |||||
namespace { | |||||
auto apply_on_var_node( | |||||
const OpDef& def, | |||||
const VarNodeArray& inputs) { | |||||
mgb_assert(inputs.size() == 1); | |||||
return opr::MatrixInverse::make(inputs[0]); | |||||
} | |||||
OP_TRAIT_REG(MatrixInverse, MatrixInverse) | |||||
.apply_on_var_node(apply_on_var_node) | |||||
.fallback(); | |||||
} // anonymous namespace | |||||
} // namespace imperative | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
@@ -34,6 +34,8 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { | |||||
let results = (outs AnyType); | let results = (outs AnyType); | ||||
} | } | ||||
def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>; | |||||
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | ||||
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | ||||