GitOrigin-RevId: 15eb08bacb
tags/v1.3.0
@@ -13,19 +13,23 @@ import numbers | |||
from typing import Optional, Sequence, Tuple, Union | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._trace_option import use_symbolic_shape | |||
from ..core.ops import builtin | |||
from ..core.ops.special import Const | |||
from ..core.tensor import utils | |||
from ..tensor import Tensor | |||
from .debug_param import get_conv_execution_strategy | |||
from .elemwise import clip, exp, log, log1p | |||
from .tensor import reshape, squeeze | |||
from .tensor import broadcast_to, concat, expand_dims, reshape, squeeze | |||
__all__ = [ | |||
"argmax", | |||
"argmin", | |||
"argsort", | |||
"dot", | |||
"isinf", | |||
"isnan", | |||
"matmul", | |||
"max", | |||
"mean", | |||
"min", | |||
@@ -36,6 +40,7 @@ __all__ = [ | |||
"sort", | |||
"std", | |||
"sum", | |||
"svd", | |||
"topk", | |||
"var", | |||
] | |||
@@ -663,7 +668,7 @@ def topk( | |||
no_sort: bool = False, | |||
) -> Tuple[Tensor, Tensor]: | |||
r""" | |||
Selects the ``Top-K``(by default) smallest elements of 2d matrix by row. | |||
Selects the ``Top-K`` (by default) smallest elements of 2d matrix by row. | |||
:param inp: input tensor. If input tensor is 2d, each row will be sorted. | |||
:param k: number of elements needed. | |||
@@ -722,3 +727,204 @@ def topk( | |||
if descending: | |||
tns = -tns | |||
return tns, ind | |||
def matmul( | |||
inp1: Tensor, | |||
inp2: Tensor, | |||
transpose_a=False, | |||
transpose_b=False, | |||
compute_mode="DEFAULT", | |||
format="DEFAULT", | |||
) -> Tensor: | |||
""" | |||
Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``. | |||
With different inputs dim, this function behaves differently: | |||
- Both 1-D tensor, simply forward to ``dot``. | |||
- Both 2-D tensor, normal matrix multiplication. | |||
- If one input tensor is 1-D, matrix vector multiplication. | |||
- If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will be broadcasted. For example: | |||
- inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)` | |||
- inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)` | |||
- inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)` | |||
:param inp1: first matrix to be multiplied. | |||
:param inp2: second matrix to be multiplied. | |||
:return: output tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2)) | |||
out = F.matmul(data1, data2) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[10. 13.] | |||
[28. 40.]] | |||
""" | |||
remove_row, remove_col = False, False | |||
inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
dim1, dim2 = inp1.ndim, inp2.ndim | |||
# handle dim=1 cases, dot and matrix-vector multiplication | |||
if dim1 == 1 and dim2 == 1: | |||
return dot(inp1, inp2) | |||
# the underlying matmul op requires input dims to be at least 2 | |||
if dim1 == 1: | |||
inp1 = expand_dims(inp1, 0) | |||
dim1 = 2 | |||
remove_row = True | |||
if dim2 == 1: | |||
inp2 = expand_dims(inp2, 1) | |||
dim2 = 2 | |||
remove_col = True | |||
batch_shape = None | |||
shape1 = inp1.shape | |||
shape2 = inp2.shape | |||
maxdim = dim1 if dim1 > dim2 else dim2 | |||
if dim1 >= 3 or dim2 >= 3: | |||
if use_symbolic_shape(): | |||
if dim1 > dim2: | |||
shape2 = concat([shape1[:-2], shape2[-2:]]) | |||
inp2 = broadcast_to(inp2, shape2) | |||
if dim1 < dim2: | |||
shape1 = concat([shape2[:-2], shape1[-2:]]) | |||
inp1 = broadcast_to(inp1, shape1) | |||
if maxdim > 3: | |||
batch_shape = shape1[:-2] | |||
# compress inputs to 3d | |||
(inp1,) = apply( | |||
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) | |||
) | |||
(inp2,) = apply( | |||
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) | |||
) | |||
else: | |||
if dim1 > dim2: | |||
shape2 = shape1[:-2] + shape2[-2:] | |||
inp2 = broadcast_to(inp2, shape2) | |||
if dim1 < dim2: | |||
shape1 = shape2[:-2] + shape1[-2:] | |||
inp1 = broadcast_to(inp1, shape1) | |||
if maxdim > 3: | |||
batch_shape = shape1[:-2] | |||
# compress inputs to 3d | |||
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) | |||
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) | |||
op = builtin.BatchedMatrixMul( | |||
transposeA=transpose_a, | |||
transposeB=transpose_b, | |||
compute_mode=compute_mode, | |||
format=format, | |||
strategy=get_conv_execution_strategy(), | |||
) | |||
else: | |||
op = builtin.MatrixMul( | |||
transposeA=transpose_a, | |||
transposeB=transpose_b, | |||
compute_mode=compute_mode, | |||
format=format, | |||
strategy=get_conv_execution_strategy(), | |||
) | |||
(result,) = apply(op, inp1, inp2) | |||
if maxdim > 3: | |||
if use_symbolic_shape(): | |||
(result,) = apply( | |||
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) | |||
) | |||
else: | |||
result = result.reshape(batch_shape + result.shape[-2:]) | |||
if remove_row: | |||
result = squeeze(result, axis=-2) | |||
if remove_col: | |||
result = squeeze(result, axis=-1) | |||
return result | |||
def dot(inp1: Tensor, inp2: Tensor) -> Tensor: | |||
""" | |||
Computes dot-product of two vectors ``inp1`` and ``inp2``. | |||
inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted. | |||
Refer to :func:`~.matmul` for more general usage. | |||
:param inp1: first vector. | |||
:param inp2: second vector. | |||
:return: output value. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data1 = tensor(np.arange(0, 6, dtype=np.float32)) | |||
data2 = tensor(np.arange(0, 6, dtype=np.float32)) | |||
out = F.dot(data1, data2) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
55. | |||
""" | |||
op = builtin.Dot() | |||
inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
assert ( | |||
inp1.ndim <= 1 and inp2.ndim <= 1 | |||
), "Input tensors for dot must be 1-dimensional or scalar" | |||
(result,) = apply(op, inp1, inp2) | |||
utils.setscalar(result) | |||
return result | |||
def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | |||
""" | |||
Computes the singular value decompositions of input matrix. | |||
:param inp: input matrix, must has shape `[..., M, N]`. | |||
:return: output matrices, `(U, sigma, V)`. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3)) | |||
_, y, _ = F.svd(x) | |||
print(y.numpy().round(decimals=3)) | |||
Outputs: | |||
.. testoutput:: | |||
[7.348 1. ] | |||
""" | |||
op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) | |||
U, sigma, V = apply(op, inp) | |||
return U, sigma, V |
@@ -25,7 +25,7 @@ from ..utils.tuple_function import _pair, _pair_nonzero | |||
from .debug_param import get_conv_execution_strategy | |||
from .distributed import all_reduce_sum | |||
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu | |||
from .math import argsort, max, prod, sum | |||
from .math import argsort, matmul, max, prod, sum | |||
from .tensor import ( | |||
broadcast_to, | |||
concat, | |||
@@ -46,7 +46,6 @@ __all__ = [ | |||
"conv_transpose2d", | |||
"deformable_conv2d", | |||
"deformable_psroi_pooling", | |||
"dot", | |||
"dropout", | |||
"indexing_one_hot", | |||
"leaky_relu", | |||
@@ -55,7 +54,6 @@ __all__ = [ | |||
"logsumexp", | |||
"logsoftmax", | |||
"matinv", | |||
"matmul", | |||
"max_pool2d", | |||
"one_hot", | |||
"prelu", | |||
@@ -63,7 +61,6 @@ __all__ = [ | |||
"resize", | |||
"softmax", | |||
"softplus", | |||
"svd", | |||
"warp_affine", | |||
"warp_perspective", | |||
"conv1d", | |||
@@ -1221,207 +1218,6 @@ def matinv(inp: Tensor) -> Tensor: | |||
return result | |||
def matmul( | |||
inp1: Tensor, | |||
inp2: Tensor, | |||
transpose_a=False, | |||
transpose_b=False, | |||
compute_mode="DEFAULT", | |||
format="DEFAULT", | |||
) -> Tensor: | |||
""" | |||
Performs a matrix multiplication of the matrices ``inp1`` and ``inp2``. | |||
With different inputs dim, this function behaves differently: | |||
- Both 1-D tensor, simply forward to ``dot``. | |||
- Both 2-D tensor, normal matrix multiplication. | |||
- If one input tensor is 1-D, matrix vector multiplication. | |||
- If at least one tensor are 3-dimensional or >3-dimensional, the other tensor should have dim >= 2, the batched matrix-matrix is returned, and the tensor with smaller dimension will | |||
be broadcasted. For example: | |||
- inp1: `(n, k, m)`, inp2: `(n, m, p)`, return: `(n, k, p)` | |||
- inp1: `(n, k, m)`, inp2: `(m, p)`, return: `(n, k, p)` | |||
- inp1: `(n, j, k, m)`, inp2: `(n, j, m, p)`, return: `(n, j, k, p)` | |||
:param inp1: first matrix to be multiplied. | |||
:param inp2: second matrix to be multiplied. | |||
:return: output tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data1 = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
data2 = tensor(np.arange(0, 6, dtype=np.float32).reshape(3, 2)) | |||
out = F.matmul(data1, data2) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[10. 13.] | |||
[28. 40.]] | |||
""" | |||
remove_row, remove_col = False, False | |||
inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
dim1, dim2 = inp1.ndim, inp2.ndim | |||
# handle dim=1 cases, dot and matrix-vector multiplication | |||
if dim1 == 1 and dim2 == 1: | |||
return dot(inp1, inp2) | |||
# the underlying matmul op requires input dims to be at least 2 | |||
if dim1 == 1: | |||
inp1 = expand_dims(inp1, 0) | |||
dim1 = 2 | |||
remove_row = True | |||
if dim2 == 1: | |||
inp2 = expand_dims(inp2, 1) | |||
dim2 = 2 | |||
remove_col = True | |||
batch_shape = None | |||
shape1 = inp1.shape | |||
shape2 = inp2.shape | |||
maxdim = dim1 if dim1 > dim2 else dim2 | |||
if dim1 >= 3 or dim2 >= 3: | |||
if use_symbolic_shape(): | |||
if dim1 > dim2: | |||
shape2 = concat([shape1[:-2], shape2[-2:]]) | |||
inp2 = broadcast_to(inp2, shape2) | |||
if dim1 < dim2: | |||
shape1 = concat([shape2[:-2], shape1[-2:]]) | |||
inp1 = broadcast_to(inp1, shape1) | |||
if maxdim > 3: | |||
batch_shape = shape1[:-2] | |||
# compress inputs to 3d | |||
(inp1,) = apply( | |||
builtin.Reshape(), inp1, concat([prod(shape1[:-2]), shape1[-2:]]) | |||
) | |||
(inp2,) = apply( | |||
builtin.Reshape(), inp2, concat([prod(shape2[:-2]), shape2[-2:]]) | |||
) | |||
else: | |||
if dim1 > dim2: | |||
shape2 = shape1[:-2] + shape2[-2:] | |||
inp2 = broadcast_to(inp2, shape2) | |||
if dim1 < dim2: | |||
shape1 = shape2[:-2] + shape1[-2:] | |||
inp1 = broadcast_to(inp1, shape1) | |||
if maxdim > 3: | |||
batch_shape = shape1[:-2] | |||
# compress inputs to 3d | |||
inp1 = inp1.reshape((-1, shape1[-2], shape1[-1])) | |||
inp2 = inp2.reshape((-1, shape2[-2], shape2[-1])) | |||
op = builtin.BatchedMatrixMul( | |||
transposeA=transpose_a, | |||
transposeB=transpose_b, | |||
compute_mode=compute_mode, | |||
format=format, | |||
strategy=get_conv_execution_strategy(), | |||
) | |||
else: | |||
op = builtin.MatrixMul( | |||
transposeA=transpose_a, | |||
transposeB=transpose_b, | |||
compute_mode=compute_mode, | |||
format=format, | |||
strategy=get_conv_execution_strategy(), | |||
) | |||
(result,) = apply(op, inp1, inp2) | |||
if maxdim > 3: | |||
if use_symbolic_shape(): | |||
(result,) = apply( | |||
builtin.Reshape(), result, concat([batch_shape, result.shape[-2:]]) | |||
) | |||
else: | |||
result = result.reshape(batch_shape + result.shape[-2:]) | |||
if remove_row: | |||
result = squeeze(result, axis=-2) | |||
if remove_col: | |||
result = squeeze(result, axis=-1) | |||
return result | |||
def dot(inp1: Tensor, inp2: Tensor) -> Tensor: | |||
""" | |||
Computes dot-product of two vectors ``inp1`` and ``inp2``. | |||
inputs must be 1-dimensional or scalar. A scalar input is automatically broadcasted. | |||
Refer to :func:`~.matmul` for more general usage. | |||
:param inp1: first vector. | |||
:param inp2: second vector. | |||
:return: output value. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data1 = tensor(np.arange(0, 6, dtype=np.float32)) | |||
data2 = tensor(np.arange(0, 6, dtype=np.float32)) | |||
out = F.dot(data1, data2) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
55. | |||
""" | |||
op = builtin.Dot() | |||
inp1, inp2 = utils.convert_inputs(inp1, inp2) | |||
assert ( | |||
inp1.ndim <= 1 and inp2.ndim <= 1 | |||
), "Input tensors for dot must be 1-dimensional or scalar" | |||
(result,) = apply(op, inp1, inp2) | |||
setscalar(result) | |||
return result | |||
def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | |||
""" | |||
Computes the singular value decompositions of input matrix. | |||
:param inp: input matrix, must has shape `[..., M, N]`. | |||
:return: output matrices, `(U, sigma, V)`. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2,3)) | |||
_, y, _ = F.svd(x) | |||
print(y.numpy().round(decimals=3)) | |||
Outputs: | |||
.. testoutput:: | |||
[7.348 1. ] | |||
""" | |||
op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) | |||
U, sigma, V = apply(op, inp) | |||
return U, sigma, V | |||
def interpolate( | |||
inp: Tensor, | |||
size: Optional[Union[int, Tuple[int, int]]] = None, | |||
@@ -707,7 +707,7 @@ def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
:param inp: input tensor. | |||
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1, | |||
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: | |||
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: | |||
* (``'x'``) -> make a 0d (scalar) into a 1d vector | |||
* (0, 1) -> identity for 2d vectors | |||