GitOrigin-RevId: acb287d48b
tags/v1.0.0-rc1
@@ -14,7 +14,7 @@ from .math import * | |||
from .nn import * | |||
from .quantized import conv_bias_activation | |||
from .tensor import * | |||
from .utils import accuracy, copy, zero_grad | |||
from .utils import accuracy, copy | |||
from . import distributed # isort:skip | |||
@@ -33,6 +33,7 @@ __all__ = [ | |||
"eq", | |||
"exp", | |||
"expm1", | |||
"fast_tanh", | |||
"floor", | |||
"floor_div", | |||
"gt", | |||
@@ -67,7 +68,6 @@ __all__ = [ | |||
"sub", | |||
"tan", | |||
"tanh", | |||
"fast_tanh", | |||
] | |||
@@ -108,13 +108,37 @@ def _elemwise_multi_type(*args, mode, **kwargs): | |||
def add(x, y): | |||
"""Element-wise addition. | |||
At least one operand should be tensor. | |||
same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minmium. | |||
Same for sub/mul/div/floor_div/pow/mod/atan2/eq/ne/lt/le/gt/ge/maximum/minimum. | |||
:param x: input tensor. | |||
:param y: input tensor. | |||
:return: computed tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.add(x, y) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[ 0. 2. 4.] | |||
[ 6. 8. 10.]] | |||
""" | |||
return _elwise(x, y, mode="add") | |||
def sub(x, y): | |||
"""Element-wise subtract.""" | |||
"""Element-wise subtraction.""" | |||
return _elwise(x, y, mode="sub") | |||
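The note above says at least one operand must be a tensor; the other may be a Python scalar. A minimal sketch of that (not part of the original patch, reusing the imports from the surrounding examples):

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3))
out = F.sub(x, 1)  # tensor - scalar, same rule as add/mul/div
print(out.numpy())  # expected roughly [[-1. 0. 1.] [ 2. 3. 4.]]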
@@ -173,24 +197,23 @@ def log1p(x): | |||
return _elwise(x, mode="log1p") | |||
def sqrt(inp: Tensor) -> Tensor: | |||
""" | |||
Return a new tensor with the square-root of the elements of ``inp``. | |||
For negative value, return nan. | |||
def sqrt(x: Tensor) -> Tensor: | |||
"""Element-wise sqrt. | |||
For negative input value, return ``NaN``. | |||
:param inp: The input tensor | |||
:return: The computed tensor | |||
:param x: input tensor. | |||
:return: computed tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = mge.tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.sqrt(data) | |||
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.sqrt(x) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -201,12 +224,12 @@ def sqrt(inp: Tensor) -> Tensor: | |||
[1.7321 2. 2.2361]] | |||
""" | |||
return inp ** 0.5 | |||
return x ** 0.5 | |||
def square(inp: Tensor) -> Tensor: | |||
def square(x: Tensor) -> Tensor: | |||
""" | |||
Return a new tensor with the square of the elements of ``inp`` | |||
Return a new tensor with the square of the elements of input tensor. | |||
:param inp: The input tensor | |||
:return: The computed tensor | |||
@@ -231,92 +254,129 @@ def square(inp: Tensor) -> Tensor: | |||
[ 9. 16. 25.]] | |||
""" | |||
return inp ** 2 | |||
return x ** 2 | |||
def round(x): | |||
"""Round tensor to int element-wise.""" | |||
"""Element-wise rounding to int.""" | |||
return _elwise(x, mode="round") | |||
def ceil(x): | |||
"""Return the ceil of the input, element-wise.""" | |||
"""Element-wise ceiling.""" | |||
return _elwise(x, mode="ceil") | |||
def floor(x): | |||
"""Calculate the floor element-wise""" | |||
"""Element-wise floor.""" | |||
return _elwise(x, mode="floor") | |||
def maximum(x, y): | |||
"""Element-wise maximum of array elements.""" | |||
return _elwise(x, y, mode="max") | |||
def minimum(x, y): | |||
"""Element-wise minimum of array elements.""" | |||
return _elwise(x, y, mode="min") | |||
# trigonometric functions | |||
def cos(x): | |||
"""Cosine, element-wise.""" | |||
"""Element-wise cosine. | |||
:param x: input tensor. | |||
:return: computed tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.cos(x) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[ 1. 0.5403 -0.4161] | |||
[-0.99 -0.6536 0.2837]] | |||
""" | |||
return _elwise(x, mode="cos") | |||
def sin(x): | |||
"""Sine, element-wise.""" | |||
"""Element-wise sine.""" | |||
return _elwise(x, mode="sin") | |||
def tan(x): | |||
"""Element-wise tangent.""" | |||
return sin(x) / cos(x) | |||
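A short illustrative sketch (added by the editor, not in the patch) covering sin/cos/tan together; note that ``tan`` is computed above as ``sin(x) / cos(x)``:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor(np.array([0.0, np.pi / 4], dtype=np.float32))
print(F.sin(x).numpy())  # ~[0.     0.7071]
print(F.cos(x).numpy())  # ~[1.     0.7071]
print(F.tan(x).numpy())  # ~[0. 1.], i.e. sin / cos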
def acos(x): | |||
"""Inverse cosine, element-wise.""" | |||
"""Element-wise inverse cosine.""" | |||
return _elwise(x, mode="acos") | |||
def asin(x): | |||
"""Inverse sine, element-wise.""" | |||
"""Element-wise inverse sine.""" | |||
return _elwise(x, mode="asin") | |||
def atan(x): | |||
"""Element-wise inverse tangent.""" | |||
return _elwise(x, 1, mode="atan2") | |||
def atan2(y, x): | |||
"""Element-wise 2-argument arctangent.""" | |||
return _elwise(y, x, mode="atan2") | |||
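Since ``atan(x)`` is defined above as ``atan2(x, 1)``, a quick sketch (assuming both are exported under ``megengine.functional`` like the neighbouring ops) showing they agree:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

y = tensor(np.array([1.0, -1.0], dtype=np.float32))
x = tensor(np.array([1.0, 1.0], dtype=np.float32))
print(F.atan2(y, x).numpy())  # ~[ 0.7854 -0.7854], i.e. +-pi/4
print(F.atan(y).numpy())      # same values, since atan(t) == atan2(t, 1)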
def cosh(x): | |||
r"""Compute element-wise hyperbolic cosine.""" | |||
r"""Element-wise hyperbolic cosine.""" | |||
return 0.5 * (exp(x) + exp(-x)) | |||
def sinh(x): | |||
r"""Compute element-wise hyperbolic sine.""" | |||
r"""Element-wise hyperbolic sine.""" | |||
u = expm1(x) | |||
return 0.5 * u / (u + 1) * (u + 2) | |||
def tanh(x): | |||
r"""Compute element-wise hyperbolic tangent.""" | |||
r"""Element-wise hyperbolic tangent.""" | |||
return _elwise(x, mode="tanh") | |||
def asinh(x): | |||
r"""Compute element-wise inverse hyperbolic sine.""" | |||
r"""Element-wise inverse hyperbolic sine.""" | |||
return log(x + (x ** 2 + 1) ** 0.5) | |||
def acosh(x): | |||
r"""Compute element-wise inverse hyperbolic cosine.""" | |||
r"""Element-wise inverse hyperbolic cosine.""" | |||
return log(x + (x ** 2 - 1) ** 0.5) | |||
def atanh(x): | |||
r"""Compute element-wise inverse hyperbolic tangent.""" | |||
r"""Element-wise inverse hyperbolic tangent.""" | |||
return log1p(2 * x / (1 - x)) / 2 | |||
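The hyperbolic helpers above are built from ``exp``/``expm1``/``log``; a sketch (not part of the patch) checking the usual identities numerically:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

x = tensor(np.array([-1.0, 0.5, 2.0], dtype=np.float32))
print((F.cosh(x) ** 2 - F.sinh(x) ** 2).numpy())  # ~1 everywhere
print(F.atanh(F.tanh(x)).numpy())                 # recovers x up to float32 error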
def fast_tanh(x): | |||
r"""Compute element-wise fast tanh; this is an approximation: | |||
r"""Element-wise fast tanh; this is an approximation: | |||
.. math:: | |||
\text{fast_tanh}(x) = x * (27. + x * x) / (27. + 9. * x * x) | |||
@@ -328,29 +388,60 @@ def fast_tanh(x): | |||
def left_shift(x, y): | |||
"""Element-wise bitwise binary: x << y. | |||
:param x: input tensor, should be int. | |||
:param y: how many bits to be left-shifted. | |||
:return: computed tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(0, 6, dtype=np.int32).reshape(2, 3)) | |||
out = F.left_shift(x, 2) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[ 0 4 8] | |||
[12 16 20]] | |||
""" | |||
return _elwise(x, y, mode="shl") | |||
def right_shift(x, y): | |||
return _elwise(x, y, mode="shl") | |||
"""Element-wise bitwise binary: x >> y.""" | |||
return _elwise(x, y, mode="shr") | |||
# logical functions | |||
def logical_and(x, y): | |||
"""Element-wise logical and: x && y.""" | |||
return _elwise(x, y, mode="AND") | |||
def logical_not(x): | |||
"""Element-wise logical not: ~x.""" | |||
return _elwise(x, mode="NOT") | |||
def logical_or(x, y): | |||
"""Element-wise logical or: x || y.""" | |||
return _elwise(x, y, mode="OR") | |||
def logical_xor(x, y): | |||
"""Element-wise logical xor: x ^ y.""" | |||
return _elwise(x, y, mode="XOR") | |||
@@ -358,72 +449,112 @@ def logical_xor(x, y): | |||
def eq(x, y): | |||
"""Return (x == y) element-wise.""" | |||
"""Element-wise (x == y). | |||
:param x: input tensor 1. | |||
:param y: input tensor 2. | |||
:return: computed tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
y = tensor(np.arange(0, 6, dtype=np.float32).reshape(2, 3)) | |||
out = F.eq(x, y) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[1. 1. 1.] | |||
[1. 1. 1.]] | |||
""" | |||
return _elwise(x, y, mode="eq") | |||
def ne(x, y): | |||
"""Element-wise (x != y).""" | |||
return x != y | |||
def lt(x, y): | |||
"""Return (x < y) element-wise.""" | |||
"""Element-wise (x < y).""" | |||
return _elwise(x, y, mode="lt") | |||
def le(x, y): | |||
"""Return (x =< y) element-wise.""" | |||
"""Element-wise (x <= y).""" | |||
return _elwise(x, y, mode="leq") | |||
def gt(x, y): | |||
"""Return (x > y) element-wise.""" | |||
"""Element-wise (x > y).""" | |||
return _elwise(y, x, mode="lt") | |||
def ge(x, y): | |||
"""Return (x >= y) element-wise""" | |||
"""Element-wise (x >= y).""" | |||
return _elwise(y, x, mode="leq") | |||
# other functions | |||
def hswish(x): | |||
"""Return x * relu6(x + 3) / 6 element-wise""" | |||
"""Element-wise x * relu6(x + 3) / 6. | |||
:param x: input tensor. | |||
:return: computed tensor. | |||
Example: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.arange(5).astype(np.float32)) | |||
out = F.hswish(x) | |||
print(out.numpy()) | |||
.. testoutput:: | |||
[0. 0.6667 1.6667 3. 4. ] | |||
""" | |||
return _elwise(x, mode="h_swish") | |||
def hsigmoid(x): | |||
"""Return relu6(x + 3) / 6 element-wise""" | |||
"""Element-wise relu6(x + 3) / 6.""" | |||
return relu6(x + 3) / 6 | |||
def relu(x): | |||
"""Return `max(x, 0)` element-wise.""" | |||
"""Element-wise `max(x, 0)`.""" | |||
return _elwise(x, mode="relu") | |||
def relu6(x): | |||
"""Return min(max(x, 0), 6) element-wise.""" | |||
"""Element-wise min(max(x, 0), 6).""" | |||
return minimum(maximum(x, 0), 6) | |||
def sigmoid(x): | |||
"""Return 1 / ( 1 + exp( -x ) ) element-wise.""" | |||
"""Element-wise 1 / ( 1 + exp( -x ) ).""" | |||
return _elwise(x, mode="sigmoid") | |||
def maximum(x, y): | |||
"""Element-wise maximum of array elements.""" | |||
return _elwise(x, y, mode="max") | |||
def minimum(x, y): | |||
"""Element-wise minimum of array elements.""" | |||
return _elwise(x, y, mode="min") | |||
def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||
r""" | |||
Clamp all elements in :attr:`inp` into the range `[` :attr:`lower`, :attr:`upper` `]` and return | |||
def clamp(x: Tensor, lower=None, upper=None) -> Tensor: | |||
r"""Clamps all elements in input tensor into the range `[` :attr:`lower`, :attr:`upper` `]` and returns | |||
a resulting tensor: | |||
.. math:: | |||
@@ -433,9 +564,10 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||
\text{upper} & \text{if } x_i > \text{upper} | |||
\end{cases} | |||
:param inp: the input tensor. | |||
:param lower: lower-bound of the range to be clamped to | |||
:param upper: upper-bound of the range to be clamped to | |||
:param x: input tensor. | |||
:param lower: lower-bound of the range to be clamped to. | |||
:param upper: upper-bound of the range to be clamped to. | |||
:return: output clamped tensor. | |||
Examples: | |||
@@ -444,12 +576,10 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
a = tensor(np.arange(5).astype(np.int32)) | |||
a = tensor(np.arange(5).astype(np.int32)) | |||
print(F.clamp(a, 2, 4).numpy()) | |||
print(F.clamp(a, lower=3).numpy()) | |||
print(F.clamp(a, upper=3).numpy()) | |||
Outputs: | |||
@@ -467,8 +597,8 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||
if lower is not None: | |||
if upper is not None: | |||
assert lower <= upper, "clamp lower bound is bigger than upper bound" | |||
return minimum(maximum(inp, lower), upper) | |||
return minimum(maximum(x, lower), upper) | |||
else: | |||
return maximum(inp, lower) | |||
return maximum(x, lower) | |||
else: | |||
return minimum(inp, upper) | |||
return minimum(x, upper) |
@@ -9,22 +9,22 @@ | |||
# pylint: disable=too-many-lines | |||
from typing import List | |||
from ..core import Tensor | |||
from ..tensor import Tensor | |||
def cambricon_subgraph( | |||
inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool, | |||
) -> List[Tensor]: | |||
"""Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and | |||
"""Loads a serialized Cambricon subgraph (i.e. cnrtModel_t) and | |||
executes the operations defined in the subgraph. | |||
:param inputs: List of input tensors of the subgraph. | |||
:param data: The serialized subgraph. | |||
:param symbol: The name of the function in the subgraph. | |||
:param inputs: list of input tensors of the subgraph. | |||
:param data: the serialized subgraph. | |||
:param symbol: the name of the function in the subgraph. | |||
The function corresponds to a cnmlFusionOp | |||
which is added to the cnmlModel_t/cnrtModel_t. | |||
:param tensor_dim_mutable: Whether the input tensors' shapes are mutalbe | |||
in cnrtModel_t | |||
:param tensor_dim_mutable: whether the input tensors' shapes are mutable | |||
in cnrtModel_t. | |||
""" | |||
raise NotImplementedError | |||
@@ -32,13 +32,13 @@ def cambricon_subgraph( | |||
def extern_opr_subgraph( | |||
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, | |||
) -> List[Tensor]: | |||
"""Load a serialized extern opr subgraph and fake execute the operator | |||
"""Loads a serialized extern opr subgraph and fake execute the operator. | |||
:param inputs: Tensor or list of input tensors. | |||
:param output_shapes: The output shapes. | |||
:param dump_name: The serialized subgraph name. | |||
:param dump_data: The serialized subgraph. | |||
:param inputs: tensor or list of input tensors. | |||
:param output_shapes: the output shapes. | |||
:param dump_name: the serialized subgraph name. | |||
:param dump_data: the serialized subgraph. | |||
:return: List of tensors | |||
:return: list of tensors. | |||
""" | |||
raise NotImplementedError |
@@ -9,7 +9,7 @@ | |||
import collections | |||
from typing import Iterable, Optional, Union | |||
from ..core.tensor import Tensor | |||
from ..tensor import Tensor | |||
def add_update( | |||
@@ -20,7 +20,7 @@ def add_update( | |||
beta: Union[Tensor, float, int] = 1.0, | |||
bias: Union[Tensor, float, int] = 0.0 | |||
): | |||
r"""Inplace modify ``dest`` as follows: | |||
r"""Modify ``dest`` inplace as follows: | |||
.. math:: | |||
dest = alpha * dest + beta * delta + bias | |||
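A hedged sketch of the formula above (editor addition; it assumes ``add_update`` is exposed as ``F.add_update`` and that the keyword names match the signature fragment shown in this hunk):

.. testcode::

import numpy as np
import megengine as mge
import megengine.functional as F

dest = mge.tensor(np.ones((2,), dtype=np.float32))
delta = mge.tensor(np.full((2,), 4.0, dtype=np.float32))
# dest is modified inplace: 1.0 * dest + 0.5 * delta + 0.0 -> [3. 3.]
F.add_update(dest, delta, alpha=1.0, beta=0.5, bias=0.0)
print(dest.numpy())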
@@ -11,9 +11,8 @@ import numpy as np | |||
from ..core.tensor.utils import make_shape_tuple | |||
from ..tensor import Tensor | |||
from .elemwise import abs, eq, exp, log, maximum, pow, relu | |||
from .nn import assert_equal, indexing_one_hot | |||
from .nn import indexing_one_hot | |||
from .tensor import where | |||
from .utils import zero_grad | |||
__all__ = [ | |||
"l1_loss", | |||
@@ -25,8 +24,7 @@ __all__ = [ | |||
def l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
r""" | |||
Calculates the mean absolute error (MAE) between | |||
r"""Calculates the mean absolute error (MAE) between | |||
each element in the pred :math:`x` and label :math:`y`. | |||
The mean absolute error can be described as: | |||
@@ -43,8 +41,9 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||
of :math:`N` elements each. :math:`N` is the batch size. | |||
:param pred: The predicted result from model. | |||
:param label: The ground truth to compare. | |||
:param pred: predicted result from model. | |||
:param label: ground truth to compare. | |||
:return: loss value. | |||
Examples: | |||
@@ -53,9 +52,10 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.functional as F | |||
ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32)) | |||
tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32)) | |||
loss = F.l1_loss(ipt,tgt) | |||
loss = F.l1_loss(ipt, tgt) | |||
print(loss.numpy()) | |||
Outputs: | |||
@@ -70,8 +70,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
def square_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
r""" | |||
Calculates the mean squared error (squared L2 norm) between | |||
r"""Calculates the mean squared error (squared L2 norm) between | |||
each element in the pred :math:`x` and label :math:`y`. | |||
The mean squared error can be described as: | |||
@@ -88,13 +87,33 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
:math:`x` and :math:`y` are tensors of arbitrary shapes with a total | |||
of :math:`N` elements each. :math:`N` is the batch size. | |||
:param pred: The predicted result from model. | |||
:param label: The ground truth to compare. | |||
:param pred: predicted result from model. | |||
:param label: ground truth to compare. | |||
:return: loss value. | |||
Shape: | |||
- pred: :math:`(N, *)` where :math:`*` means any number of additional | |||
dimensions | |||
- label: :math:`(N, *)`. Same shape as ``pred`` | |||
dimensions. | |||
- label: :math:`(N, *)`. Same shape as ``pred``. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.functional as F | |||
ipt = mge.tensor(np.array([3, 3, 3, 3]).astype(np.float32)) | |||
tgt = mge.tensor(np.array([2, 8, 6, 1]).astype(np.float32)) | |||
loss = F.square_loss(ipt, tgt) | |||
print(loss.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[9.75] | |||
""" | |||
diff = pred - label | |||
@@ -104,8 +123,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor: | |||
def cross_entropy_with_softmax( | |||
pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0 | |||
) -> Tensor: | |||
r""" | |||
Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`. | |||
r"""Returns loss after applying :func:`~.softmax` + :func:`~.cross_entropy`. | |||
It has better numerical stability compared with sequential calls to :func:`~.softmax` and :func:`~.cross_entropy`. | |||
@@ -116,10 +134,33 @@ def cross_entropy_with_softmax( | |||
where :math:`y^{LS}` and :math:`y` are the new label distribution and the original label distribution respectively. | |||
:math:`k` is the index of the label distribution. :math:`\alpha` is ``label_smooth`` and :math:`K` is the number of classes. | |||
:param pred: The input tensor representing the predicted probability. | |||
:param label: The input tensor representing the classification label. | |||
:param axis: An axis along which softmax will be applied. Default: 1. | |||
:param label_smooth: A label smoothing of parameter that can re-distribute target distribution. Default: 0. | |||
:param pred: input tensor representing the predicted probability. | |||
:param label: input tensor representing the classification label. | |||
:param axis: an axis along which softmax will be applied. Default: 1 | |||
:param label_smooth: a label smoothing parameter that can re-distribute the target distribution. Default: 0 | |||
:return: loss value. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data_shape = (1, 2) | |||
label_shape = (1, ) | |||
pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape)) | |||
label = tensor(np.ones(label_shape, dtype=np.int32)) | |||
loss = F.cross_entropy_with_softmax(pred, label) | |||
print(loss.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.6931] | |||
""" | |||
n0 = pred.ndim | |||
n1 = label.ndim | |||
@@ -147,26 +188,44 @@ def cross_entropy_with_softmax( | |||
def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor: | |||
r"""Function that measures the Binary Cross Entropy between the target and the prediction. | |||
:param pred: (N,*) where * means, any number of additional dimensions. | |||
:param label: (N,*), same shape as the input. | |||
:param pred: `(N, *)` where `*` means any number of additional dimensions. | |||
:param label: `(N, *)`, same shape as the input. | |||
:return: loss value. | |||
""" | |||
assert make_shape_tuple(pred.shape) == make_shape_tuple(label.shape) | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(1, 2)) | |||
label = tensor(np.ones((1, 2), dtype=np.float32)) | |||
loss = F.binary_cross_entropy(pred, label) | |||
print(loss.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[0.6931] | |||
""" | |||
return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean() | |||
def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor: | |||
r""" | |||
Caculate the hinge loss which is often used in SVMs. | |||
r"""Caculate the hinge loss which is often used in SVMs. | |||
The hinge loss can be described as: | |||
.. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_i_j*y_i_j)) | |||
.. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_{ij}*y_{ij})) | |||
:param pred: The input tensor representing the predicted probability, shape is (N, C). | |||
:param label: The input tensor representing the binary classification label, shape is (N, C). | |||
:param norm: Specify the norm to caculate the loss, should be "L1" or "L2". | |||
:param pred: input tensor representing the predicted probability, shape is `(N, C)`. | |||
:param label: input tensor representing the binary classification label, shape is `(N, C)`. | |||
:param norm: specify the norm to calculate the loss, should be "L1" or "L2". | |||
:return: loss value. | |||
Examples: | |||
@@ -177,9 +236,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor: | |||
pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]], dtype="float32") | |||
label = tensor([[1, -1, -1], [-1, 1, 1]], dtype="float32") | |||
loss = F.hinge_loss(pred, label) | |||
print(loss.numpy()) | |||
Outputs: | |||
@@ -18,7 +18,7 @@ from ..core.tensor import utils | |||
from ..core.tensor.core import apply | |||
from ..tensor import Tensor | |||
from .elemwise import clamp, exp, log, log1p | |||
from .tensor import remove_axis, reshape | |||
from .tensor import add_axis, remove_axis, reshape | |||
__all__ = [ | |||
"argmax", | |||
@@ -42,10 +42,10 @@ __all__ = [ | |||
def isnan(inp: Tensor) -> Tensor: | |||
r"""Returns a new tensor representing if each element is NaN or not. | |||
r"""Returns a new tensor representing if each element is ``NaN`` or not. | |||
:param: inp | |||
:return: a new tensor representing if each element in :attr:`inp` is NaN or not. | |||
:param inp: input tensor. | |||
:return: a new tensor representing if each element in inp is NaN or not. | |||
Examples: | |||
@@ -55,7 +55,6 @@ def isnan(inp: Tensor) -> Tensor: | |||
import megengine.functional as F | |||
x = tensor([1, float("nan"), 0]) | |||
print(F.isnan(x).numpy()) | |||
Outputs: | |||
@@ -69,10 +68,10 @@ def isnan(inp: Tensor) -> Tensor: | |||
def isinf(inp: Tensor) -> Tensor: | |||
r"""Returns a new tensor representing if each element is Inf or not. | |||
r"""Returns a new tensor representing if each element is ``Inf`` or not. | |||
:param: inp | |||
:return: a new tensor representing if each element in :attr:`inp` is Inf or not. | |||
:param inp: input tensor. | |||
:return: a new tensor representing if each element in inp is Inf or not. | |||
Examples: | |||
@@ -82,7 +81,6 @@ def isinf(inp: Tensor) -> Tensor: | |||
import megengine.functional as F | |||
x = tensor([1, float("inf"), 0]) | |||
print(F.isinf(x).numpy()) | |||
Outputs: | |||
@@ -96,10 +94,10 @@ def isinf(inp: Tensor) -> Tensor: | |||
def sign(inp: Tensor): | |||
r"""Returns sign of each element in the input tensor. | |||
r"""Returns a new tensor representing the sign of each element in input tensor. | |||
:param: inp | |||
:return: a sign tensor. | |||
:param inp: input tensor. | |||
:return: the sign of input tensor. | |||
Examples: | |||
@@ -109,8 +107,9 @@ def sign(inp: Tensor): | |||
import megengine.functional as F | |||
x = tensor([1, -1, 0]) | |||
print(F.sign(x).numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
@@ -125,14 +124,15 @@ def sum( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the sum of each row of the ``inp`` tensor in the given ``axis``. | |||
r"""Returns the sum of input tensor along given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor. | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. | |||
Default: None | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. | |||
:param keepdims: whether the output tensor has axis retained or not. | |||
Default: False | |||
:return: The output tensor | |||
:return: output tensor. | |||
Examples: | |||
@@ -142,12 +142,12 @@ def sum( | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.sum(data) | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.sum(x) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[21] | |||
@@ -159,13 +159,13 @@ def sum( | |||
def prod( | |||
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None, keepdims=False | |||
) -> Tensor: | |||
r""" | |||
Returns the element product of input tensor along given *axis*. | |||
r"""Returns the product of input tensor along given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None`` | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: ``False`` | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -175,8 +175,8 @@ def prod( | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.prod(data) | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.prod(x) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -194,13 +194,14 @@ def mean( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
"""Returns the mean value of each row of the ``inp`` tensor in | |||
the given ``axis``. If axis is a list of dimensions, | |||
"""Returns the mean value of input tensor along | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -210,8 +211,8 @@ def mean( | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.mean(data) | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2, 3)) | |||
out = F.mean(x) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -224,27 +225,19 @@ def mean( | |||
return inp.astype("float32").mean(axis=axis, keepdims=keepdims) | |||
def median( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
raise NotImplementedError | |||
def var( | |||
inp: Tensor, | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
"""Returns the variance value of input tensor along | |||
given ``axis``. If axis is a list of dimensions, | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor. | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``. | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``. | |||
:return: The output tensor. | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -278,13 +271,13 @@ def std( | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
"""Returns the standard deviation of input tensor along | |||
given ``axis``. If axis is a list of dimensions, | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor. | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: ``None``. | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: ``False``. | |||
:return: The output tensor. | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -312,13 +305,14 @@ def min( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r""" | |||
Returns the min value of input tensor along given *axis*. | |||
r"""Returns the min value of input tensor along | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -329,8 +323,8 @@ def min( | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.min(x) | |||
print(y.numpy()) | |||
out = F.min(x) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -347,12 +341,14 @@ def max( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the max value of the input tensor along given *axis*. | |||
r"""Returns the max value of the input tensor along | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -363,8 +359,8 @@ def max( | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.max(x) | |||
print(y.numpy()) | |||
out = F.max(x) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -382,13 +378,15 @@ def norm( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims=False, | |||
): | |||
"""Calculate ``p``-norm of input tensor along certain axis. | |||
"""Calculates ``p``-norm of input tensor along | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param p: power of value ``p`` applied to ``inp``. Default: 2 | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has ``axis`` retained or not. Default: False | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param p: power of value applied to inp. Default: 2 | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -399,8 +397,8 @@ def norm( | |||
import megengine.functional as F | |||
x = tensor(np.arange(-3, 3, dtype=np.float32).reshape(2,3)) | |||
y = F.norm(x) | |||
print(y.numpy()) | |||
out = F.norm(x) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -423,12 +421,14 @@ def argmin( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the indices of the minimum values along an axis | |||
r"""Returns the indices of the minimum values along | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -439,8 +439,8 @@ def argmin( | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.argmin(x) | |||
print(y.numpy()) | |||
out = F.argmin(x) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -479,12 +479,14 @@ def argmax( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
keepdims: bool = False, | |||
) -> Tensor: | |||
r"""Returns the indices of the maximum values along an axis | |||
r"""Returns the indices of the maximum values along | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
:param inp: The input tensor | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: Whether the output tensor has *axis* retained or not. Default: False | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced. Default: None | |||
:param keepdims: whether the output tensor has axis retained or not. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -495,9 +497,9 @@ def argmax( | |||
import megengine.functional as F | |||
x = tensor(np.arange(1, 7, dtype=np.int32).reshape(2,3)) | |||
y = F.argmax(x) | |||
print(y.numpy()) | |||
out = F.argmax(x) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
@@ -536,21 +538,22 @@ def normalize( | |||
axis: Optional[Union[int, Sequence[int]]] = None, | |||
eps: float = 1e-12, | |||
) -> Tensor: | |||
r"""Perform :math:`L_p` normalization of input tensor along certain axis. | |||
r"""Performs :math:`L_p` normalization of input tensor along | |||
given axis. If axis is a list of dimensions, | |||
reduce over all of them. | |||
For a tensor :attr:`inp` of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each | |||
For a tensor inp of shape :math:`(n_0, ..., n_{dim}, ..., n_k)`, each | |||
:math:`n_{dim}` -element vector :math:`v` along dimension :attr:`axis` is transformed as: | |||
.. math:: | |||
v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}. | |||
:param inp: the input tensor | |||
:param p: power of value ``p`` applied to ``inp``. Default: 2 | |||
:param axis: The dimension to reduce. If None, all the dimensions will be reduced | |||
:param inp: input tensor. | |||
:param p: power of value applied to inp. Default: 2 | |||
:param axis: dimension to reduce. If None, all the dimensions will be reduced | |||
to calculate the norm. Default: None | |||
:param eps: a small value to avoid division by zero. Default: 1e-12 | |||
:return: the normalized output tensor | |||
:return: normalized output tensor. | |||
""" | |||
if axis is None: | |||
return inp / clamp(norm(inp, p, axis), lower=eps) | |||
@@ -559,12 +562,11 @@ def normalize( | |||
def argsort(inp: Tensor, descending: bool = False) -> Tensor: | |||
r""" | |||
Sort the target 2d matrix by row, return both the sorted tensor and indices. | |||
r"""Sorts the target 2d matrix by row, return both the sorted tensor and indices. | |||
:param inp: The input tensor, if 2d, each row will be sorted | |||
:param descending: Sort in descending order, where the largest comes first. Default: ``False`` | |||
:return: Tuple of two tensors (sorted_tensor, indices_of_int32) | |||
:param inp: input tensor, if 2d, each row will be sorted. | |||
:param descending: Sort in descending order, where the largest comes first. Default: False | |||
:return: Tuple of two tensors `(sorted_tensor, indices_of_int32)`. | |||
Examples: | |||
@@ -573,8 +575,9 @@ def argsort(inp: Tensor, descending: bool = False) -> Tensor: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.array([1,2], dtype=np.float32)) | |||
indices = F.argsort(data) | |||
x = tensor(np.array([1,2], dtype=np.float32)) | |||
indices = F.argsort(x) | |||
print(indices.numpy()) | |||
Outputs: | |||
@@ -622,15 +625,14 @@ def topk( | |||
kth_only: bool = False, | |||
no_sort: bool = False, | |||
) -> Tuple[Tensor, Tensor]: | |||
r""" | |||
Selected the Top-K (by default) smallest elements of 2d matrix by row. | |||
r"""Selects the ``Top-K(by default)`` smallest elements of 2d matrix by row. | |||
:param inp: The input tensor, if 2d, each row will be sorted | |||
:param k: The number of elements needed | |||
:param descending: If true, return the largest elements instead. Default: ``False`` | |||
:param kth_only: If true, only the k-th element will be returned. Default: ``False`` | |||
:param no_sort: If true, the returned elements can be unordered. Default: ``False`` | |||
:return: Tuple of two tensors (topk_tensor, indices_of_int32) | |||
:param inp: input tensor, if 2d, each row will be sorted. | |||
:param k: number of elements needed. | |||
:param descending: if true, return the largest elements instead. Default: False | |||
:param kth_only: if true, only the k-th element will be returned. Default: False | |||
:param no_sort: if true, the returned elements can be unordered. Default: False | |||
:return: tuple of two tensors `(topk_tensor, indices_of_int32)`. | |||
Examples: | |||
@@ -639,8 +641,9 @@ def topk( | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||
top, indices = F.topk(data, 5) | |||
x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||
top, indices = F.topk(x, 5) | |||
print(top.numpy(), indices.numpy()) | |||
Outputs: | |||
@@ -22,41 +22,38 @@ from .debug_param import get_conv_execution_strategy | |||
from .distributed import all_reduce_sum | |||
from .elemwise import exp, floor, log, log1p, maximum, minimum, relu | |||
from .math import argsort, max, sum | |||
from .tensor import add_axis, broadcast, concat, remove_axis, reshape | |||
from .tensor import add_axis, broadcast, concat, full, ones, remove_axis, reshape, zeros | |||
from .types import _pair, _pair_nonzero | |||
__all__ = [ | |||
"linear", | |||
"avg_pool2d", | |||
"batched_nms", | |||
"batch_norm2d", | |||
"conv2d", | |||
"conv_transpose2d", | |||
"local_conv2d", | |||
"max_pool2d", | |||
"avg_pool2d", | |||
"prelu", | |||
"dot", | |||
"dropout", | |||
"embedding", | |||
"indexing_one_hot", | |||
"interpolate", | |||
"leaky_relu", | |||
"softplus", | |||
"log_softmax", | |||
"linear", | |||
"local_conv2d", | |||
"logsigmoid", | |||
"logsumexp", | |||
"flatten", | |||
"softmax", | |||
"batch_norm2d", | |||
"sync_batch_norm", | |||
"one_hot", | |||
"warp_perspective", | |||
"log_softmax", | |||
"matmul", | |||
"interpolate", | |||
"dropout", | |||
"identity", | |||
"embedding", | |||
"roi_pooling", | |||
"max_pool2d", | |||
"nms", | |||
"one_hot", | |||
"prelu", | |||
"roi_align", | |||
"assert_equal", | |||
"indexing_one_hot", | |||
"dot", | |||
"roi_pooling", | |||
"softmax", | |||
"softplus", | |||
"svd", | |||
"nms", | |||
"batched_nms", | |||
"sync_batch_norm", | |||
"warp_perspective", | |||
] | |||
@@ -72,14 +69,14 @@ def expand_hw(x): | |||
def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor: | |||
"""Applies a linear transformation to the input. | |||
"""Applies a linear transformation to the input tensor. | |||
Refer to :class:`~.module.linear.Linear` for more information. | |||
:param inp: the input tensor with shape `(N, in_features)`. | |||
:param weight: the weight with shape `(out_features, in_features)`. | |||
:param bias: the bias with shape `(out_features,)`. | |||
Default: ``None`` | |||
:param inp: input tensor with shape `(N, in_features)`. | |||
:param weight: weight with shape `(out_features, in_features)`. | |||
:param bias: bias with shape `(out_features,)`. | |||
Default: None | |||
""" | |||
ret = matmul(inp, weight, transpose_b=True) | |||
if bias is not None: | |||
@@ -102,28 +99,28 @@ def conv2d( | |||
Refer to :class:`~.Conv2d` for more information. | |||
:param inp: The feature map of the convolution operation | |||
:param weight: The convolution kernel | |||
:param bias: The bias added to the result of convolution (if given) | |||
:param stride: Stride of the 2D convolution operation. Default: 1 | |||
:param padding: Size of the paddings added to the input on both sides of its | |||
:param inp: feature map of the convolution operation. | |||
:param weight: convolution kernel. | |||
:param bias: bias added to the result of convolution (if given). | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: Dilation of the 2D convolution operation. Default: 1 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be ``(groups, out_channel // groups, | |||
in_channels // groups, height, width)``. | |||
:type conv_mode: string or :class:`P.Convolution.Mode` | |||
:param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION'. | |||
so as to perform a ``grouped convolution``. When groups is not 1, | |||
in_channels and out_channels must be divisible by groups, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. | |||
:type conv_mode: string or :class:`P.Convolution.Mode`. | |||
:param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default: | |||
"CROSS_CORRELATION" | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode` | |||
:param compute_mode: When set to 'DEFAULT', no special requirements will be | |||
placed on the precision of intermediate results. When set to 'FLOAT32', | |||
:class:`P.Convolution.ComputeMode`. | |||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||
placed on the precision of intermediate results. When set to "FLOAT32", | |||
Float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of Float16 dtype. | |||
:return: output tensor. | |||
""" | |||
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" | |||
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" | |||
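A worked sketch of the grouped-convolution weight layout described above (editor addition; the concrete sizes are illustrative assumptions):

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

N, in_channels, out_channels, groups = 2, 4, 8, 2
inp = tensor(np.random.rand(N, in_channels, 16, 16).astype(np.float32))
# weight shape is (groups, out_channels // groups, in_channels // groups, kh, kw)
weight = tensor(np.random.rand(groups, out_channels // groups, in_channels // groups, 3, 3).astype(np.float32))
out = F.conv2d(inp, weight, stride=1, padding=1, groups=groups)
print(out.shape)  # expected (2, 8, 16, 16) with a 3x3 kernel and padding=1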
@@ -168,28 +165,28 @@ def conv_transpose2d( | |||
Refer to :class:`~.ConvTranspose2d` for more information. | |||
:param inp: The feature map of the convolution operation | |||
:param weight: The convolution kernel | |||
:param bias: The bias added to the result of convolution (if given) | |||
:param stride: Stride of the 2D convolution operation. Default: 1 | |||
:param padding: Size of the paddings added to the input on both sides of its | |||
:param inp: feature map of the convolution operation. | |||
:param weight: convolution kernel. | |||
:param bias: bias added to the result of convolution (if given). | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: Dilation of the 2D convolution operation. Default: 1 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be ``(groups, out_channel // groups, | |||
in_channels // groups, height, width)``. Default: 1 | |||
:type conv_mode: string or :class:`P.Convolution.Mode` | |||
:param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION'. | |||
so as to perform a ``grouped convolution``. When groups is not 1, | |||
in_channels and out_channels must be divisible by groups, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. Default: 1 | |||
:type conv_mode: string or :class:`P.Convolution.Mode`. | |||
:param conv_mode: supports "CROSS_CORRELATION" or "CONVOLUTION". Default: | |||
"CROSS_CORRELATION" | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode` | |||
:param compute_mode: When set to 'DEFAULT', no special requirements will be | |||
placed on the precision of intermediate results. When set to 'FLOAT32', | |||
:class:`P.Convolution.ComputeMode`. | |||
:param compute_mode: when set to "DEFAULT", no special requirements will be | |||
placed on the precision of intermediate results. When set to "FLOAT32", | |||
Float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of Float16 dtype. | |||
:return: output tensor. | |||
""" | |||
assert conv_mode == "CROSS_CORRELATION" or conv_mode.name == "CROSS_CORRELATION" | |||
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" | |||
@@ -258,16 +255,16 @@ def max_pool2d( | |||
stride: Optional[Union[int, Tuple[int, int]]] = None, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
) -> Tensor: | |||
"""Applies a 2D max pooling over an input. | |||
"""Applies a 2D max pooling over an input tensor. | |||
Refer to :class:`~.MaxPool2d` for more information. | |||
:param inp: The input tensor. | |||
:param kernel_size: The size of the window. | |||
:param stride: The stride of the window. If not provided, its value is set to ``kernel_size``. | |||
:param inp: input tensor. | |||
:param kernel_size: size of the window. | |||
:param stride: stride of the window. If not provided, its value is set to kernel_size. | |||
Default: None | |||
:param padding: Implicit zero padding to be added on both sides. Default: 0 | |||
:param padding: implicit zero padding to be added on both sides. Default: 0 | |||
:return: output tensor. | |||
""" | |||
if stride is None: | |||
stride = kernel_size | |||
@@ -295,17 +292,17 @@ def avg_pool2d( | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
mode: str = "AVERAGE_COUNT_EXCLUDE_PADDING", | |||
) -> Tensor: | |||
""" Applies a 2D average pooling over an input. | |||
"""Applies a 2D average pooling over an input tensor. | |||
Refer to :class:`~.AvgPool2d` for more information. | |||
:param inp: The input tensor. | |||
:param kernel_size: The size of the window. | |||
:param stride: The stride of the window. If not provided, its value is set to ``kernel_size``. | |||
:param inp: input tensor. | |||
:param kernel_size: size of the window. | |||
:param stride: stride of the window. If not provided, its value is set to kernel_size. | |||
Default: None | |||
:param padding: Implicit zero padding to be added on both sides. Default: 0 | |||
:param mode: Whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING" | |||
:param padding: implicit zero padding to be added on both sides. Default: 0 | |||
:param mode: whether to count padding values. Default: "AVERAGE_COUNT_EXCLUDE_PADDING" | |||
:return: output tensor. | |||
""" | |||
if stride is None: | |||
stride = kernel_size | |||
@@ -513,44 +510,6 @@ def logsumexp( | |||
) | |||
def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: | |||
r""" | |||
Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``. | |||
:param inp: The input tensor. | |||
:param start_axis: The start dimension that the sub-tensor to be flattened. Default: 0 | |||
:param end_axis: The end dimension that the sub-tensor to be flattened. Default: -1 | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
inp_shape = (2, 2, 3, 3) | |||
inp = tensor( | |||
np.arange(36, dtype=np.int32).reshape(inp_shape), | |||
) | |||
oup = F.flatten(inp, 2) | |||
print(inp.numpy().shape) | |||
print(oup.numpy().shape) | |||
Outputs: | |||
.. testoutput:: | |||
(2, 2, 3, 3) | |||
(2, 2, 9) | |||
""" | |||
target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,) | |||
if end_axis != -1: | |||
target_shape += (*inp.shape[end_axis + 1 :],) | |||
return inp.reshape(*target_shape) | |||
def _get_softmax_axis(ndim: int) -> int: | |||
if ndim in (0, 1, 3): | |||
return 0 | |||
@@ -602,7 +561,7 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: | |||
def batch_norm2d( | |||
data: Tensor, | |||
inp: Tensor, | |||
running_mean: Tensor = None, | |||
running_var: Tensor = None, | |||
weight: Optional[Tensor] = None, | |||
@@ -621,35 +580,34 @@ def batch_norm2d( | |||
:param running_mean: tensor to store running mean. | |||
:param running_var: tensor to store running variance. | |||
:param weight: scaling tensor in the learnable affine parameters. | |||
See :math:`\gamma` in :class:`~.BatchNorm2d` | |||
See :math:`\gamma` in :class:`~.BatchNorm2d`. | |||
:param bias: bias tensor in the learnable affine parameters. | |||
See :math:`\beta` in :class:`~.BatchNorm2d` | |||
See :math:`\beta` in :class:`~.BatchNorm2d`. | |||
:param training: a boolean value to indicate whether batch norm is performed | |||
in traning mode. Default: ``False`` | |||
:param momentum: the value used for the ``running_mean`` and ``running_var`` | |||
in training mode. Default: False | |||
:param momentum: value used for the ``running_mean`` and ``running_var`` | |||
computation. | |||
Default: 0.9 | |||
:param eps: a value added to the denominator for numerical stability. | |||
Default: 1e-5. | |||
Default: 1e-5 | |||
:param inplace: whether to update running_mean and running_var inplace or return new tensors. | |||
Default: True | |||
:return: output tensor. | |||
""" | |||
from .tensor import add_axis, remove_axis, broadcast | |||
def full(value): | |||
C = data.shape[1] | |||
(x,) = Const(value, dtype=data.dtype, device=data.device)(data) | |||
def full_value(value): | |||
C = inp.shape[1] | |||
(x,) = Const(value, dtype=inp.dtype, device=inp.device)(inp) | |||
return broadcast(x, [1, C, 1, 1]) | |||
def expand_or_full(x, value): | |||
if x is None: | |||
return full(value) | |||
return full_value(value) | |||
return add_axis(x, [0, 2, 3]) | |||
def make_full_if_none(x, value): | |||
if x is None: | |||
return full(value) | |||
return full(shape=(1, inp.shape[1], 1, 1), value=value) | |||
return x | |||
has_mean = running_mean is not None | |||
@@ -664,8 +622,8 @@ def batch_norm2d( | |||
if has_var and running_var.ndim != 4: | |||
raise ValueError | |||
data, weight, bias, running_mean, running_var = utils.convert_inputs( | |||
data, weight, bias, running_mean, running_var | |||
inp, weight, bias, running_mean, running_var = utils.convert_inputs( | |||
inp, weight, bias, running_mean, running_var | |||
) | |||
weight = expand_or_full(weight, 1) | |||
@@ -673,7 +631,7 @@ def batch_norm2d( | |||
if not training: | |||
op = builtin.BatchNorm(fwd_mode="INFERENCE", epsilon=eps, param_dim="DIM_1C11") | |||
ret = apply(op, data, weight, bias, running_mean, running_var)[-1] | |||
ret = apply(op, inp, weight, bias, running_mean, running_var)[-1] | |||
return ret | |||
else: | |||
@@ -684,8 +642,8 @@ def batch_norm2d( | |||
if has_mean or has_var: | |||
running_mean = make_full_if_none(running_mean, 0) | |||
running_var = make_full_if_none(running_var, 1) | |||
new_mean, new_var, _, _, data = apply( | |||
op, data, weight, bias, running_mean, running_var | |||
new_mean, new_var, _, _, inp = apply( | |||
op, inp, weight, bias, running_mean, running_var | |||
) | |||
if not has_mean: | |||
new_mean = None | |||
@@ -698,12 +656,12 @@ def batch_norm2d( | |||
if has_var: | |||
running_var[...] = new_var | |||
return data | |||
return inp | |||
else: | |||
return data, new_mean, new_var | |||
return inp, new_mean, new_var | |||
else: | |||
_, _, data, = apply(op, data, weight, bias) | |||
return data | |||
_, _, inp, = apply(op, inp, weight, bias) | |||
return inp | |||
def sync_batch_norm( | |||
@@ -718,7 +676,7 @@ def sync_batch_norm( | |||
eps_mode="ADDITIVE", | |||
group=WORLD, | |||
) -> Tensor: | |||
""" Applies synchronized batch normalization to the input. | |||
"""Applies synchronized batch normalization to the input. | |||
Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information. | |||
@@ -726,16 +684,17 @@ def sync_batch_norm( | |||
:param running_mean: tensor to store running mean. | |||
:param running_var: tensor to store running variance. | |||
:param weight: scaling tensor in the learnable affine parameters. | |||
See :math:`\gamma` in :class:`~.BatchNorm2d` | |||
See :math:`\gamma` in :class:`~.BatchNorm2d`. | |||
:param bias: bias tensor in the learnable affine parameters. | |||
See :math:`\beta` in :class:`~.BatchNorm2d` | |||
See :math:`\beta` in :class:`~.BatchNorm2d`. | |||
:param training: a boolean value to indicate whether batch norm is performed | |||
in traning mode. Default: ``False`` | |||
:param momentum: the value used for the ``running_mean`` and ``running_var`` | |||
in training mode. Default: False | |||
:param momentum: value used for the ``running_mean`` and ``running_var`` | |||
computation. | |||
Default: 0.9 | |||
:param eps: a value added to the denominator for numerical stability. | |||
Default: 1e-5. | |||
Default: 1e-5 | |||
:return: output tensor. | |||
""" | |||
assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode) | |||
_channels = inp.shape[1] | |||
@@ -786,7 +745,7 @@ def sync_batch_norm( | |||
bias = bias.reshape(*_param_shape) | |||
# outvar = output * weight + bias | |||
# where output = input * invsqrt_channel_variance + ( | |||
# where output = inp * invsqrt_channel_variance + ( | |||
# -channel_mean * invsqrt_channel_variance | |||
# ) | |||
# Manually expand output for gopt | |||
@@ -818,11 +777,11 @@ def sync_batch_norm( | |||
def one_hot(inp: Tensor, num_classes: int) -> Tensor: | |||
r""" | |||
Perform one-hot encoding for the input tensor. | |||
r"""Performs one-hot encoding for the input tensor. | |||
:param inp: input tensor | |||
:param num_classes: number of classes denotes the last dimension of the output tensor | |||
:param inp: input tensor. | |||
:param num_classes: number of classes denotes the last dimension of the output tensor. | |||
:return: output tensor. | |||
Examples: | |||
@@ -832,8 +791,8 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: | |||
from megengine import tensor | |||
import megengine.functional as F | |||
inp = tensor(np.arange(1, 4, dtype=np.int32)) | |||
out = F.one_hot(inp, num_classes=4) | |||
x = tensor(np.arange(1, 4, dtype=np.int32)) | |||
out = F.one_hot(x, num_classes=4) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -845,20 +804,12 @@ def one_hot(inp: Tensor, num_classes: int) -> Tensor: | |||
[0 0 0 1]] | |||
""" | |||
raise NotImplementedError | |||
# comp_node, comp_graph = _decide_comp_node_and_comp_graph(inp) | |||
# zeros = mgb.make_immutable(value=0, comp_node=comp_node, comp_graph=comp_graph) | |||
# zeros_symvar = zeros.broadcast(inp.shapeof(), num_classes) | |||
# ones = mgb.make_immutable(value=1, comp_node=comp_node, comp_graph=comp_graph) | |||
# ones_symvar = ones.broadcast(inp.shapeof(), 1) | |||
zeros_tensor = zeros(list(inp.shape) + [num_classes], inp.dtype, inp.device) | |||
ones_tensor = ones(list(inp.shape) + [1], inp.dtype, inp.device) | |||
# return Tensor( | |||
# mgb.opr.indexing_set_one_hot( | |||
# zeros_symvar, axis=len(inp.shapeof()), index=inp, value=ones_symvar | |||
# ) | |||
# ) | |||
op = builtin.IndexingSetOneHot(axis=inp.ndim) | |||
(result,) = apply(op, zeros_tensor, inp, ones_tensor) | |||
return result | |||
def warp_perspective( | |||
@@ -869,8 +820,7 @@ def warp_perspective( | |||
border_val: float = 0.0, | |||
interp_mode: str = "LINEAR", | |||
): | |||
r""" | |||
Applies perspective transformation to batched 2D images. | |||
r"""Applies perspective transformation to batched 2D images. | |||
The input images are transformed to the output images by the transformation matrix: | |||
@@ -880,12 +830,13 @@ def warp_perspective( | |||
\frac{M_{10}h + M_{11}w + M_{12}}{M_{20}h + M_{21}w + M_{22}} | |||
\right) | |||
:param inp: input image | |||
:param M: (batch, 3, 3) transformation matrix | |||
:param dsize: (h, w) size of the output image | |||
:param border_mode: pixel extrapolation method. Default: ``"REPLICATE"`` | |||
:param border_val: value used in case of a constant border. Default: ``0`` | |||
:param interp_mode: interpolation methods. Default: ``"LINEAR"`` | |||
:param inp: input image. | |||
:param M: `(batch, 3, 3)` transformation matrix. | |||
:param dsize: `(h, w)` size of the output image. | |||
:param border_mode: pixel extrapolation method. Default: "REPLICATE" | |||
:param border_val: value used in case of a constant border. Default: 0 | |||
:param interp_mode: interpolation methods. Default: "LINEAR" | |||
:return: output tensor. | |||
Examples: | |||
@@ -894,14 +845,15 @@ def warp_perspective( | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
inp_shape = (1, 1, 4, 4) | |||
inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||
x = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||
M_shape = (1, 3, 3) | |||
# M defines a translation: dst(1, 1, h, w) = rst(1, 1, h+1, w+1) | |||
M = tensor(np.array([[1., 0., 1.], | |||
[0., 1., 1.], | |||
[0., 0., 1.]], dtype=np.float32).reshape(M_shape)) | |||
out = F.warp_perspective(inp, M, (2, 2)) | |||
out = F.warp_perspective(x, M, (2, 2)) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -1100,15 +1052,15 @@ def interpolate( | |||
mode: str = "BILINEAR", | |||
align_corners: bool = None, | |||
) -> Tensor: | |||
r""" | |||
Down/up samples the input tensor to either the given :attr:`size` or the given | |||
:attr:`scale_factor` | |||
r"""Down/up samples the input tensor to either the given size or the given | |||
scale_factor. | |||
:param inp: input tensor | |||
:param size: size of the output tensor. Default: ``None`` | |||
:param scale_factor: scaling factor of the output tensor. Default: ``None`` | |||
:param inp: input tensor. | |||
:param size: size of the output tensor. Default: None | |||
:param scale_factor: scaling factor of the output tensor. Default: None | |||
:param mode: interpolation methods, acceptable values are: | |||
'BILINEAR', 'LINEAR'. Default: ``BILINEAR`` | |||
"BILINEAR", "LINEAR". Default: "BILINEAR" | |||
:return: output tensor. | |||
Examples: | |||
@@ -1119,11 +1071,10 @@ def interpolate( | |||
import megengine.functional as F | |||
from megengine.test import assertTensorClose | |||
inp = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) | |||
out = F.interpolate(inp, [4, 4], align_corners=False) | |||
x = tensor(np.arange(1, 5, dtype=np.float32).reshape(1, 1, 2, 2)) | |||
out = F.interpolate(x, [4, 4], align_corners=False) | |||
print(out.numpy()) | |||
out2 = F.interpolate(inp, scale_factor=2.) | |||
out2 = F.interpolate(x, scale_factor=2.) | |||
assertTensorClose(out.numpy(), out2.numpy()) | |||
Outputs: | |||
@@ -1245,28 +1196,25 @@ def interpolate( | |||
def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: | |||
""" | |||
Returns a new tensor where each of the elements are randomly set to zero | |||
"""Returns a new tensor where each of the elements are randomly set to zero | |||
with probability P = ``drop_prob``. Optionally rescale the output tensor. | |||
:param inp: The input tensor | |||
:param drop_prob: The probability to drop (set to zero) a single element | |||
:param training: The default behavior of ``dropout`` during training is to rescale the output, | |||
:param inp: input tensor. | |||
:param drop_prob: probability to drop (set to zero) a single element. | |||
:param training: the default behavior of ``dropout`` during training is to rescale the output, | |||
then it can be replaced by an :class:`~.Identity` during inference. Default: True | |||
:return: The output tensor | |||
:return: output tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.functional as F | |||
from megengine import tensor | |||
import megengine.functional as F | |||
data = tensor(np.ones(10, dtype=np.float32)) | |||
out = F.dropout(data, 1./3.) | |||
x = tensor(np.ones(10, dtype=np.float32)) | |||
out = F.dropout(x, 1./3.) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -1286,33 +1234,21 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor: | |||
return inp | |||
def identity(inp: Tensor) -> Tensor: | |||
"""applies an identity transform to the input tensor. | |||
:param inp: The input tensor | |||
""" | |||
op = builtin.Identity() | |||
(data,) = utils.convert_inputs(inp) | |||
(output,) = apply(op, data) | |||
return output | |||
def embedding( | |||
input: Tensor, | |||
inp: Tensor, | |||
weight: Tensor, | |||
padding_idx: Optional[int] = None, | |||
max_norm: Optional[float] = None, | |||
norm_type: Optional[float] = None, | |||
): | |||
""" | |||
Applies lookup table for embedding. | |||
"""Applies lookup table for embedding. | |||
:param input: the tensor with indices. | |||
:param weight: the learnable weights which embedding from. | |||
:param inp: tensor with indices. | |||
:param weight: learnable weights from which the embeddings are looked up. | |||
:param padding_idx: should be set to None, not supported now. | |||
:param max_norm: should be set to None, not supported now. | |||
:param norm_type: should be set to None, not supported now. | |||
:return: output tensor. | |||
Refer to :class:`~.Embedding` for more information. | |||
""" | |||
@@ -1321,8 +1257,8 @@ def embedding( | |||
if max_norm is not None or norm_type is not None: | |||
raise ValueError("Not support weight normlization Now!") | |||
dest_shp = list(input.shape) + [weight.shape[-1]] | |||
return weight[input.reshape(-1)].reshape(dest_shp) | |||
dest_shp = list(inp.shape) + [weight.shape[-1]] | |||
return weight[inp.reshape(-1)].reshape(dest_shp) | |||
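The lookup above is just fancy indexing over the weight table. A minimal NumPy sketch of the same semantics (hypothetical shapes, not the MegEngine code path):

    import numpy as np

    # hypothetical table of 10 embeddings with 4 features each
    weight = np.arange(40, dtype=np.float32).reshape(10, 4)
    inp = np.array([[1, 3], [7, 0]], dtype=np.int32)     # (2, 2) indices

    dest_shp = list(inp.shape) + [weight.shape[-1]]      # [2, 2, 4]
    out = weight[inp.reshape(-1)].reshape(dest_shp)      # row lookup, then restore the index shape
    assert out.shape == (2, 2, 4)
    assert (out[0, 1] == weight[3]).all()                # each index selects one weight row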
def roi_pooling( | |||
@@ -1332,15 +1268,37 @@ def roi_pooling( | |||
mode: str = "max", | |||
scale: float = 1.0, | |||
) -> Tensor: | |||
""" | |||
Apply roi pooling on input feature | |||
"""Applies roi pooling on input feature. | |||
:param inp: tensor that represents the input feature, (N, C, H, W) images | |||
:param rois: (K, 5) boxes. First column is the index into N. The other 4 columns are xyxy | |||
:param output_shape: (height, width) of output rois feature | |||
:param mode: "max" or "average", use max/average align just like max/average pooling. Default: ``"max"`` | |||
:param inp: tensor that represents the input feature, `(N, C, H, W)` images. | |||
:param rois: `(K, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. | |||
:param output_shape: `(height, width)` of output rois feature. | |||
:param mode: "max" or "average", use max/average align just like max/average pooling. Default: "max" | |||
:param scale: scale the input boxes by this number. Default: 1.0 | |||
:return: (K, C, output_shape[0], output_shape[1]) feature of rois | |||
:return: `(K, C, output_shape[0], output_shape[1])` feature of rois. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
np.random.seed(42) | |||
inp = tensor(np.random.randn(1, 1, 128, 128)) | |||
rois = tensor(np.random.random((4, 5))) | |||
y = F.roi_pooling(inp, rois, (2, 2)) | |||
print(y.numpy()[0]) | |||
Outputs: | |||
.. testoutput:: | |||
[[[-0.1383 -0.1383] | |||
[-0.5035 -0.5035]]] | |||
""" | |||
assert mode in ["max", "average"], "only max/average mode is supported" | |||
if isinstance(output_shape, int): | |||
@@ -1355,7 +1313,7 @@ def roi_pooling( | |||
def roi_align( | |||
input: Tensor, | |||
inp: Tensor, | |||
rois: Tensor, | |||
output_shape: Union[int, tuple, list], | |||
mode: str = "average", | |||
@@ -1363,18 +1321,40 @@ def roi_align( | |||
sample_points: Union[int, tuple, list] = 2, | |||
aligned: bool = True, | |||
) -> Tensor: | |||
""" | |||
Apply roi align on input feature | |||
"""Applies roi align on input feature. | |||
:param input: tensor that represents the input feature, (N, C, H, W) images | |||
:param rois: (N, 5) boxes. First column is the index into N. The other 4 columns are xyxy | |||
:param output_shape: (height, width) shape of output rois feature. | |||
:param mode: "max" or "average", use max/average align just like max/average pooling. Default: ``"average"`` | |||
:param inp: tensor that represents the input feature, `(N, C, H, W)` images. | |||
:param rois: `(N, 5)` boxes. First column is the index into N. The other 4 columns are xyxy. | |||
:param output_shape: `(height, width)` shape of output rois feature. | |||
:param mode: "max" or "average", use max/average align just like max/average pooling. Default: "average" | |||
:param spatial_scale: scale the input boxes by this number. Default: 1.0 | |||
:param sample_points: number of input samples to take for each output sample. | |||
0 to take samples densely. Default: 2 | |||
:param aligned: whether to align the input feature; with `aligned=True`, | |||
we first appropriately scale the ROI and then shift it by -0.5. Default: True | |||
:return: output tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
np.random.seed(42) | |||
inp = tensor(np.random.randn(1, 1, 128, 128)) | |||
rois = tensor(np.random.random((4, 5))) | |||
y = F.roi_align(inp, rois, (2, 2)) | |||
print(y.numpy()[0]) | |||
Outputs: | |||
.. testoutput:: | |||
[[[0.175 0.175 ] | |||
[0.1359 0.1359]]] | |||
""" | |||
assert mode in ["max", "average"], "only max/average mode is supported" | |||
if isinstance(output_shape, int): | |||
@@ -1395,58 +1375,21 @@ def roi_align( | |||
sample_height=sample_height, | |||
sample_width=sample_width, | |||
) | |||
input, rois = utils.convert_inputs(input, rois) | |||
result, *_ = apply(op, input, rois) | |||
inp, rois = utils.convert_inputs(inp, rois) | |||
result, *_ = apply(op, inp, rois) | |||
return result | |||
def assert_equal( | |||
get: Tensor, expect: Tensor, max_err: float = 1e-4, verbose: bool = False | |||
) -> Tensor: | |||
r""" | |||
Asserts that ``get`` equals to ``expect``, and returns value of ``expect``. | |||
:param get: tensor to be checked. | |||
:param expect: tensor with expected values. | |||
:param max_err: tolerance that two float values are asserted equal. Default: 1e-4 | |||
:param verbose: whether to print details if two tensors are not equal. Default: False | |||
Examples: | |||
.. testcode:: | |||
import megengine.functional as F | |||
from megengine import tensor | |||
get = tensor([1.0, 2.0]) | |||
max_err = 0.1 | |||
expect = get + max_err / 2.0 | |||
val = F.assert_equal(expect, get, max_err=max_err) | |||
print(val.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[1.05 2.05] | |||
""" | |||
raise NotImplementedError | |||
# op = builtin.AssertEqual(maxerr=max_err, verbose=verbose) | |||
# result, = apply(op, get, expect) | |||
# return result | |||
def indexing_one_hot( | |||
src: Tensor, index: Tensor, axis: int = 1, keepdims=False | |||
) -> Tensor: | |||
r""" | |||
One-hot indexing for some axis. | |||
r"""One-hot indexing for some axis. | |||
:param src: input data tensor. | |||
:param src: input tensor. | |||
:param index: index tensor. | |||
:param axis: the axis on src for which values in index index. Default: 1 | |||
:param keepdims: whether not to remove the axis in result. Default: ``False`` | |||
:param axis: axis of src along which index selects values. Default: 1 | |||
:param keepdims: whether not to remove the axis in result. Default: False | |||
:return: output tensor. | |||
Examples: | |||
@@ -1461,7 +1404,7 @@ def indexing_one_hot( | |||
print(val.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[1.] | |||
@@ -1480,11 +1423,11 @@ def indexing_one_hot( | |||
def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: | |||
r""" | |||
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). | |||
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). | |||
:param boxes: tensor of shape ``(N, 4)``; the boxes to perform nms on; each box is expected to be in (x1, y1, x2, y2) format. | |||
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format. | |||
:param iou_thresh: iou threshold for overlapping. | |||
:param scores: tensor of shape ``(N,)``, the score of boxes. | |||
:param scores: tensor of shape `(N,)`, the score of boxes. | |||
:return: indices of the elements that have been kept by NMS. | |||
Examples: | |||
@@ -1539,10 +1482,10 @@ def batched_nms( | |||
r""" | |||
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union (IoU). | |||
:param boxes: tensor of shape ``(N, 4)``; the boxes to perform nms on; each box is expected to be in (x1, y1, x2, y2) format | |||
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format | |||
:param iou_thresh: iou threshold for overlapping | |||
:param idxs: tensor of shape ``(N,)``, the class indexs of boxes in the batch. | |||
:param scores: tensor of shape ``(N,)``, the score of boxes. | |||
:param idxs: tensor of shape `(N,)`, the class indices of boxes in the batch. | |||
:param scores: tensor of shape `(N,)`, the score of boxes. | |||
:return: indices and the number of the elements that have been kept by NMS | |||
Examples: | |||
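For reference, the suppression rule documented for ``nms`` and ``batched_nms`` can be summarised by a small NumPy sketch of greedy IoU-based NMS (an illustrative re-implementation, not the MegEngine kernel):

    import numpy as np

    def nms_reference(boxes, scores, iou_thresh):
        # boxes: (N, 4) in (x1, y1, x2, y2) format, scores: (N,)
        order = scores.argsort()[::-1]                       # visit boxes by descending score
        areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            # intersection of the current best box with every remaining box
            xx1 = np.maximum(boxes[i, 0], boxes[order[1:], 0])
            yy1 = np.maximum(boxes[i, 1], boxes[order[1:], 1])
            xx2 = np.minimum(boxes[i, 2], boxes[order[1:], 2])
            yy2 = np.minimum(boxes[i, 3], boxes[order[1:], 3])
            inter = np.maximum(xx2 - xx1, 0) * np.maximum(yy2 - yy1, 0)
            iou = inter / (areas[i] + areas[order[1:]] - inter)
            # drop every remaining box that overlaps the kept one too much
            order = order[1:][iou <= iou_thresh]
        return np.array(keep)

    boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
    scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)
    print(nms_reference(boxes, scores, iou_thresh=0.5))      # -> [0 2]

``batched_nms`` follows the same rule but only within each class given by ``idxs``, so boxes of different classes do not suppress each other.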
@@ -29,32 +29,29 @@ def conv_bias_activation( | |||
conv_mode="CROSS_CORRELATION", | |||
compute_mode="DEFAULT", | |||
) -> Tensor: | |||
""" convolution bias with activation operation, only for inference. | |||
"""Convolution bias with activation operation, only for inference. | |||
:param inp: The feature map of the convolution operation | |||
:param weight: The convolution kernel | |||
:param bias: The bias added to the result of convolution | |||
:param stride: Stride of the 2D convolution operation. Default: 1 | |||
:param padding: Size of the paddings added to the input on both sides of its | |||
:param inp: feature map of the convolution operation. | |||
:param weight: convolution kernel. | |||
:param bias: bias added to the result of convolution | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: Dilation of the 2D convolution operation. Default: 1 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be ``(groups, out_channel // groups, | |||
in_channels // groups, height, width)``. | |||
:type conv_mode: string or :class:`P.Convolution.Mode` | |||
:param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION'. | |||
:param dtype: Support for np.dtype, Default: | |||
np.int8. | |||
:param scale: scale if use quantization, Default: | |||
0.0. | |||
:param zero_point: scale if use quantization quint8, Default: | |||
0.0. | |||
so as to perform a "grouped convolution". When groups is not 1, | |||
in_channels and out_channels must be divisible by groups, | |||
and the shape of weight should be `(groups, out_channel // groups, | |||
in_channels // groups, height, width)`. | |||
:type conv_mode: string or :class:`P.Convolution.Mode`. | |||
:param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION' | |||
:param dtype: data type used for the computation (np.dtype). Default: np.int8 | |||
:param scale: scale used if quantization is applied. Default: 0.0 | |||
:param zero_point: zero point used if quint8 quantization is applied. Default: 0.0 | |||
:type compute_mode: string or | |||
:class:`P.Convolution.ComputeMode` | |||
:param compute_mode: When set to 'DEFAULT', no special requirements will be | |||
:class:`P.Convolution.ComputeMode`. | |||
:param compute_mode: when set to 'DEFAULT', no special requirements will be | |||
placed on the precision of intermediate results. When set to 'FLOAT32', | |||
Float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of Float16 dtype. | |||
@@ -36,12 +36,14 @@ __all__ = [ | |||
"broadcast", | |||
"concat", | |||
"cond_take", | |||
"dimshuffle", | |||
"expand_dims", | |||
"transpose", | |||
"add_axis", | |||
"eye", | |||
"flatten", | |||
"full", | |||
"full_like", | |||
"gather", | |||
"identity", | |||
"linspace", | |||
"ones", | |||
"ones_like", | |||
@@ -50,7 +52,6 @@ __all__ = [ | |||
"reshape", | |||
"remove_axis", | |||
"split", | |||
"squeeze", | |||
"stack", | |||
"scatter", | |||
"transpose", | |||
@@ -60,16 +61,14 @@ __all__ = [ | |||
] | |||
def eye(n: int, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: | |||
""" | |||
Returns a 2D tensor with ones on the diagonal and zeros elsewhere. | |||
def eye(shape, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor: | |||
"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere. | |||
:param n: The number of rows | |||
:param m: The number of columns. Default: None | |||
:param dtype: The data type. Default: None | |||
:param device: Compute node of the matrix. Default: None | |||
:param comp_graph: Compute graph of the matrix. Default: None | |||
:return: The eye matrix | |||
:param shape: expected shape of output tensor. | |||
:param m: number of columns. Default: None | |||
:param dtype: data type. Default: None | |||
:param device: compute node of the matrix. Default: None | |||
:return: eye matrix. | |||
Examples: | |||
@@ -79,8 +78,7 @@ def eye(n: int, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor | |||
import megengine.functional as F | |||
data_shape = (4, 6) | |||
n, m = data_shape | |||
out = F.eye([n, m], dtype=np.float32) | |||
out = F.eye(data_shape, dtype=np.float32) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -94,11 +92,13 @@ def eye(n: int, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor | |||
""" | |||
op = builtin.Eye(k=0, dtype=dtype, comp_node=device) | |||
(result,) = apply(op, Tensor(n, dtype="int32", device=device)) | |||
(result,) = apply(op, Tensor(shape, dtype="int32", device=device)) | |||
return result | |||
def full(shape, value, dtype="float32", device=None): | |||
"""Returns a tensor with given shape and value. | |||
""" | |||
if isinstance(shape, int): | |||
shape = (shape,) | |||
if device is None: | |||
@@ -110,18 +110,42 @@ def full(shape, value, dtype="float32", device=None): | |||
def ones(shape, dtype="float32", device=None): | |||
"""Returns a ones tensor with given shape. | |||
:param shape: shape of the output tensor. | |||
:return: output tensor, filled with ones. | |||
Examples: | |||
.. testcode:: | |||
import megengine.functional as F | |||
out = F.ones((2, 1)) | |||
print(out.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[1.] | |||
[1.]] | |||
""" | |||
return full(shape, 1.0, dtype=dtype, device=device) | |||
def zeros(shape, dtype="float32", device=None): | |||
"""Returns a zero tensor with given shape. | |||
""" | |||
return full(shape, 0.0, dtype=dtype, device=device) | |||
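Analogous to the ``ones`` example above, a quick sanity check for ``zeros`` (same calling convention, assuming it is exported from ``megengine.functional``):

    import megengine.functional as F

    out = F.zeros((2, 3))      # tensor of the given shape, filled with 0.0 (float32 by default)
    print(out.numpy())
    # [[0. 0. 0.]
    #  [0. 0. 0.]]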
def zeros_like(inp: Tensor) -> Tensor: | |||
r""" | |||
Returns a zero tensor with the same shape as input tensor | |||
"""Returns a zero tensor with the same shape as input tensor. | |||
:param inp: input tensor | |||
:param inp: input tensor. | |||
:return: output zero tensor. | |||
Examples: | |||
@@ -147,26 +171,36 @@ def zeros_like(inp: Tensor) -> Tensor: | |||
def ones_like(inp: Tensor) -> Tensor: | |||
r""" | |||
Returns a identity tensor with the same shape as input tensor | |||
"""Returns a identity tensor with the same shape as input tensor. | |||
""" | |||
return ones(inp.shape, dtype=inp.dtype, device=inp.device) | |||
def full_like(inp: Tensor, value: Union[int, float]) -> Tensor: | |||
r""" | |||
Returns a tensor filled with value val with the same shape as input tensor | |||
"""Returns a tensor filled with given value with the same shape as input tensor. | |||
""" | |||
return full(inp.shape, value, dtype=inp.dtype, device=inp.device) | |||
def identity(inp: Tensor) -> Tensor: | |||
"""Applies an identity transform to the input tensor. | |||
:param inp: input tensor. | |||
:return: output tensor. | |||
""" | |||
op = builtin.Identity() | |||
(data,) = utils.convert_inputs(inp) | |||
(output,) = apply(op, data) | |||
return output | |||
def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||
""" | |||
Broadcast a tensor to ``shape`` | |||
Broadcasts a tensor to given shape. | |||
:param inp: The input tensor | |||
:param shape: The target shape | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param shape: target shape. | |||
:return: output tensor. | |||
Examples: | |||
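A short usage sketch for ``broadcast``; as with NumPy broadcasting, the target shape is assumed to be broadcast-compatible with the input:

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(3, dtype=np.float32).reshape(1, 3))
    out = F.broadcast(x, (2, 3))    # the single row is repeated to fill the requested shape
    print(out.numpy())
    # [[0. 1. 2.]
    #  [0. 1. 2.]]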
@@ -206,10 +240,10 @@ def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||
r""" | |||
Concat some tensors | |||
:param inps: Input tensors to concat | |||
:param axis: the dimension over which the tensors are concatenated. Default: 0 | |||
:param device: The comp node output on. Default: None | |||
:return: The output tensor | |||
:param inps: input tensors to concat. | |||
:param axis: dimension over which the tensors are concatenated. Default: 0 | |||
:param device: comp node that the output resides on. Default: None | |||
:return: output tensor. | |||
Examples: | |||
@@ -254,10 +288,10 @@ def stack(inps, axis=0, device=None): | |||
"""Concats a sequence of tensors along a new axis. | |||
The input tensors must have the same shape. | |||
:param inps: The input tensors. | |||
:param axis: Which axis will be concatenated. | |||
:param inps: input tensors. | |||
:param axis: which axis will be concatenated. | |||
:param device: the comp node that the output resides on. Default: None | |||
:return: The output concatenated tensor. | |||
:return: output concatenated tensor. | |||
Examples: | |||
@@ -296,10 +330,10 @@ def split(inp, nsplits_or_sections, axis=0): | |||
"""Splits the input tensor into several smaller tensors. | |||
When nsplits_or_sections is int, the last tensor may be smaller than others. | |||
:param inp: The input tensor. | |||
:param nsplits_or_sections: Number of sub tensors or section information list. | |||
:param axis: Which axis will be splited. | |||
:return: The output tensor list. | |||
:param inp: input tensor. | |||
:param nsplits_or_sections: number of sub tensors or section information list. | |||
:param axis: which axis will be split. | |||
:return: output tensor list. | |||
Examples: | |||
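A minimal sketch of the integer form described above (equal-sized pieces, with the last piece possibly smaller); the section-list form follows the same calling convention:

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(12, dtype=np.float32).reshape(6, 2))
    parts = F.split(x, 3, axis=0)                # 3 sub tensors along axis 0
    print([p.numpy().shape for p in parts])      # [(2, 2), (2, 2), (2, 2)]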
@@ -377,8 +411,7 @@ def _get_idx(index, axis): | |||
def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: | |||
r""" | |||
Gather data from :attr:`inp` on :attr:`axis` using :attr:`index`. | |||
r"""Gathers data from inp on axis using index. | |||
For a 3-D tensor, the output is specified by:: | |||
@@ -386,16 +419,16 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: | |||
out[i][j][k] = inp[i][index[i][j][k]][k] # if axis == 1 | |||
out[i][j][k] = inp[i][j][index[i][j][k]] # if axis == 2 | |||
if :attr:`inp` is an n-dimensional tensor with size | |||
if inp is an n-dimensional tensor with size | |||
:math:`(x_0,x_1,...,x_{i-1},x_i,x_{i+1},...,x_{n-1})` and axis=i, | |||
then :attr:`index` must be an n-dimensional tensor with size | |||
then index must be an n-dimensional tensor with size | |||
:math:`(x_0,x_1,...,x_{i-1},y,x_{i+1},...,x_{n-1})` where :math:`y\ge 1` and | |||
output will have the same size as :attr:`index`. | |||
output will have the same size as index. | |||
:param inp: the source tensor | |||
:param axis: the axis along which to index | |||
:param index: the indices of elements to gather | |||
:param inp: input tensor. | |||
:param axis: axis along which to index. | |||
:param index: indices of elements to gather. | |||
:return: output tensor. | |||
Examples: | |||
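The indexing rule quoted above is the same one NumPy exposes as ``np.take_along_axis``; a small sketch of the axis == 0 case (reference semantics only):

    import numpy as np

    inp = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32)
    index = np.array([[0, 2], [1, 0]], dtype=np.int64)

    # out[i][j] = inp[index[i][j]][j] for axis == 0
    out = np.take_along_axis(inp, index, axis=0)
    print(out)
    # [[1. 6.]
    #  [3. 2.]]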
@@ -449,25 +482,25 @@ def gather(inp: Tensor, axis: int, index: Tensor) -> Tensor: | |||
def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | |||
r""" | |||
Writes all values from the tensor :attr:`source` into :attr:`inp` at the indices specified in the :attr:`index` tensor. | |||
r"""Writes all values from the tensor source into inp | |||
at the indices specified in the index tensor. | |||
For each value in :attr:`source`, its output index is specified by its index | |||
in :attr:`source` for ``axis != dimension`` and by the corresponding value in | |||
:attr:`index` for ``axis = dimension``. | |||
For each value in source, its output index is specified by its index | |||
in source for ``axis != dimension`` and by the corresponding value in | |||
index for ``axis = dimension``. | |||
For a 3-D tensor, :attr:`inp` is updated as:: | |||
For a 3-D tensor, inp is updated as:: | |||
inp[index[i][j][k]][j][k] = source[i][j][k] # if axis == 0 | |||
inp[i][index[i][j][k]][k] = source[i][j][k] # if axis == 1 | |||
inp[i][j][index[i][j][k]] = source[i][j][k] # if axis == 2 | |||
:attr:`inp`, :attr:`index` and :attr:`source` should have same number of dimensions. | |||
inp, index and source should have same number of dimensions. | |||
It is also required that ``source.shape(d) <= inp.shape(d)`` and ``index.shape(d) == source.shape(d)`` | |||
for all dimensions ``d``. | |||
Moreover, the values of :attr:`index` must be between ``0`` and ``inp.shape(axis) - 1`` inclusive. | |||
Moreover, the values of index must be between ``0`` and ``inp.shape(axis) - 1`` inclusive. | |||
.. note:: | |||
Please notice that, due to performance issues, the result is uncertain on the GPU device | |||
@@ -478,10 +511,11 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | |||
from source[0][2] (whose value is 0.2256) or source[1][2] (whose value is 0.5339) | |||
if index[1][2] is set from 1 to 0. | |||
:param inp: the inp tensor which to be scattered | |||
:param axis: the axis along which to index | |||
:param index: the indices of elements to scatter | |||
:param source: the source element(s) to scatter | |||
:param inp: input tensor to be scattered onto. | |||
:param axis: axis along which to index. | |||
:param index: indices of elements to scatter. | |||
:param source: source element(s) to scatter. | |||
:return: output tensor. | |||
Examples: | |||
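The update rule above mirrors NumPy's ``np.put_along_axis``; a small sketch of the axis == 0 case (reference semantics, not the MegEngine kernel):

    import numpy as np

    inp = np.zeros((3, 4), dtype=np.float32)
    source = np.array([[10., 20., 30., 40.],
                       [50., 60., 70., 80.]], dtype=np.float32)
    index = np.array([[0, 1, 2, 0],
                      [2, 0, 1, 1]], dtype=np.int64)

    # inp[index[i][j]][j] = source[i][j] for axis == 0
    np.put_along_axis(inp, index, source, axis=0)
    print(inp)
    # [[10. 60.  0. 40.]
    #  [ 0. 20. 70. 80.]
    #  [50.  0. 30.  0.]]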
@@ -553,16 +587,16 @@ def scatter(inp: Tensor, axis: int, index: Tensor, source: Tensor) -> Tensor: | |||
def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor: | |||
r""" | |||
Select elements either from Tensor x or Tensor y, according to mask. | |||
r"""Selects elements either from Tensor x or Tensor y, according to mask. | |||
.. math:: | |||
\textrm{out}_i = x_i \textrm{ if } \textrm{mask}_i \textrm{ is True else } y_i | |||
:param mask: a mask used for choosing x or y | |||
:param x: the first choice | |||
:param y: the second choice | |||
:param mask: a mask used for choosing x or y. | |||
:param x: first choice. | |||
:param y: second choice. | |||
:return: output tensor. | |||
Examples: | |||
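The selection is element-wise and matches ``np.where``; a tiny sketch of the rule above:

    import numpy as np

    mask = np.array([[True, False], [False, True]])
    x = np.array([[1., 2.], [3., 4.]], dtype=np.float32)
    y = np.array([[5., 6.], [7., 8.]], dtype=np.float32)

    out = np.where(mask, x, y)    # out_i = x_i where mask_i is True, else y_i
    print(out)
    # [[1. 6.]
    #  [7. 4.]]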
@@ -620,8 +654,8 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: | |||
and the second is the indices corresponding to those elements; | |||
they are both 1-dimensional. High-dimension input would first be flattened. | |||
:param mask: condition param; must be the same shape with data | |||
:param x: input tensor from which to take elements | |||
:param mask: condition param; must be the same shape as data. | |||
:param x: input tensor from which to take elements. | |||
Examples: | |||
@@ -657,12 +691,13 @@ def cond_take(mask: Tensor, x: Tensor) -> Tensor: | |||
return v, index | |||
def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
def transpose(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
r""" | |||
Swap shapes and strides according to given pattern | |||
Swaps shapes and strides according to given pattern. | |||
:param inp: Input tensor | |||
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1, and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: | |||
:param inp: input tensor. | |||
:param pattern: a list of integers including 0, 1, ... , ``ndim``-1, | |||
and any number of ``'x'`` char in dimensions where this tensor should be broadcasted. For examples: | |||
* (``'x'``) -> make a 0d (scalar) into a 1d vector | |||
* (0, 1) -> identity for 2d vectors | |||
@@ -674,7 +709,7 @@ def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
* (1, ``'x'``, 0) -> AxB to Bx1xA | |||
* (1,) -> This removes dimension 0. It must be a broadcastable dimension (1xA to A) | |||
:return: The output tensor | |||
:return: output tensor. | |||
Examples: | |||
@@ -684,7 +719,7 @@ def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.array([[1, 1], [0, 0]], dtype=np.int32)) | |||
out = F.dimshuffle(x, (1, 0)) | |||
out = F.transpose(x, (1, 0)) | |||
print(out.numpy()) | |||
Outputs: | |||
@@ -701,15 +736,15 @@ def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor: | |||
return result | |||
transpose = dimshuffle | |||
dimshuffle = transpose | |||
def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor: | |||
r""" | |||
Reshape a tensor to given target shape; total number of logical elements must | |||
Reshapes a tensor to given target shape; total number of logical elements must | |||
remain unchanged | |||
:param inp: Input tensor | |||
:param inp: input tensor. | |||
:param target_shape: target shape, the components would be concatenated to form the | |||
target shape, and it can contain an element of -1 representing unspec_axis. | |||
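A short sketch of the -1 convention mentioned above (one axis left unspecified and inferred from the total element count):

    import numpy as np
    from megengine import tensor
    import megengine.functional as F

    x = tensor(np.arange(12, dtype=np.float32))
    out = F.reshape(x, (3, -1))          # the -1 axis is inferred as 12 / 3 = 4
    print(out.numpy().shape)             # (3, 4)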
@@ -764,13 +799,51 @@ AxisAddRemove = builtin.AxisAddRemove | |||
AxisDesc = AxisAddRemove.AxisDesc | |||
def flatten(inp: Tensor, start_axis: int = 0, end_axis: int = -1) -> Tensor: | |||
r"""Reshapes the tensor by flattening the sub-tensor from dimension ``start_axis`` to dimension ``end_axis``. | |||
:param inp: input tensor. | |||
:param start_axis: start dimension of the sub-tensor to be flattened. Default: 0 | |||
:param end_axis: end dimension of the sub-tensor to be flattened. Default: -1 | |||
:return: output tensor. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
inp_shape = (2, 2, 3, 3) | |||
x = tensor( | |||
np.arange(36, dtype=np.int32).reshape(inp_shape), | |||
) | |||
out = F.flatten(x, 2) | |||
print(x.numpy().shape) | |||
print(out.numpy().shape) | |||
Outputs: | |||
.. testoutput:: | |||
(2, 2, 3, 3) | |||
(2, 2, 9) | |||
""" | |||
target_shape = tuple(inp.shape[i] for i in range(start_axis)) + (-1,) | |||
if end_axis != -1: | |||
target_shape += (*inp.shape[end_axis + 1 :],) | |||
return inp.reshape(*target_shape) | |||
def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
r""" | |||
Add dimension before given axis. | |||
Adds dimension before given axis. | |||
:param inp: Input tensor | |||
:param axis: Place of new axes | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param axis: place of new axes. | |||
:return: output tensor. | |||
Examples: | |||
@@ -779,6 +852,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor([1, 2]) | |||
out = F.add_axis(x, 0) | |||
print(out.shape) | |||
@@ -790,7 +864,7 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
(1, 2) | |||
""" | |||
Param = AxisAddRemove.Param | |||
Param = builtin.AxisAddRemove.Param | |||
def get_axes(): | |||
try: | |||
@@ -803,24 +877,24 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor: | |||
ndim = inp.ndim + len(axis) | |||
axis = sorted(i + ndim if i < 0 else i for i in axis) | |||
param = Param(*map(AxisDesc.make_add, axis)) | |||
op = AxisAddRemove(param=param) | |||
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_add, axis)) | |||
op = builtin.AxisAddRemove(param=param) | |||
(result,) = apply(op, inp) | |||
return result | |||
expand_dims = add_axis | |||
add_axis = add_axis | |||
def remove_axis( | |||
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None | |||
) -> Tensor: | |||
r""" | |||
Remove dimension of shape 1. | |||
Removes dimension of shape 1. | |||
:param inp: Input tensor | |||
:param axis: Place of axis to be removed, if None, all axis=1 will be removed. Default: None | |||
:return: The output tensor | |||
:param inp: input tensor. | |||
:param axis: place of axis to be removed. | |||
:return: output tensor. | |||
Examples: | |||
@@ -829,6 +903,7 @@ def remove_axis( | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine.functional as F | |||
x = tensor(np.array([1, 2], dtype=np.int32).reshape(1, 1, 2, 1)) | |||
out = F.remove_axis(x, 3) | |||
print(out.shape) | |||
@@ -840,7 +915,7 @@ def remove_axis( | |||
(1, 1, 2) | |||
""" | |||
Param = AxisAddRemove.Param | |||
Param = builtin.AxisAddRemove.Param | |||
def get_axes(): | |||
if axis is None: | |||
@@ -855,15 +930,12 @@ def remove_axis( | |||
axis = sorted(i + inp.ndim if i < 0 else i for i in axis) | |||
axis = [a - i for i, a in enumerate(axis)] | |||
param = Param(*map(AxisDesc.make_remove, axis)) | |||
op = AxisAddRemove(param=param) | |||
param = Param(*map(builtin.AxisAddRemove.AxisDesc.make_remove, axis)) | |||
op = builtin.AxisAddRemove(param=param) | |||
(result,) = apply(op, inp) | |||
return result | |||
squeeze = remove_axis | |||
def linspace( | |||
start: Union[int, float, Tensor], | |||
stop: Union[int, float, Tensor], | |||
@@ -871,14 +943,13 @@ def linspace( | |||
dtype="float32", | |||
device: Optional[CompNode] = None, | |||
) -> Tensor: | |||
r""" | |||
Return equally spaced numbers over a specified interval | |||
r"""Returns equally spaced numbers over a specified interval. | |||
:param start: Starting value of the squence, shoule be scalar | |||
:param stop: The last value of the squence, shoule be scalar | |||
:param num: number of values to generate | |||
:param dtype: result data type | |||
:return: The generated tensor | |||
:param start: starting value of the sequence, should be scalar. | |||
:param stop: last value of the sequence, should be scalar. | |||
:param num: number of values to generate. | |||
:param dtype: result data type. | |||
:return: generated tensor. | |||
Examples: | |||
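A quick sketch of the documented behaviour (both endpoints included, ``num`` evenly spaced values):

    import megengine.functional as F

    out = F.linspace(0, 1, 5)            # 5 values from 0 to 1 inclusive
    print(out.numpy())
    # [0.   0.25 0.5  0.75 1.  ]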
@@ -916,14 +987,13 @@ def arange( | |||
dtype="float32", | |||
device: Optional[CompNode] = None, | |||
) -> Tensor: | |||
r""" | |||
Returns a Tensor with values from `start` to `end` with adjacent interval `step` | |||
r"""Returns a Tensor with values from start to end with adjacent interval step. | |||
:param start: starting value of the squence, shoule be scalar | |||
:param end: ending value of the squence, shoule be scalar | |||
:param step: the gap between each pair of adjacent values. Default 1 | |||
:param dtype: result data type | |||
:return: The generated tensor | |||
:param start: starting value of the sequence, should be scalar. | |||
:param end: ending value of the sequence, should be scalar. | |||
:param step: gap between each pair of adjacent values. Default: 1 | |||
:param dtype: result data type. | |||
:return: generated tensor. | |||
Examples: | |||
@@ -937,9 +1007,11 @@ def arange( | |||
Outputs: | |||
Outputs: | |||
.. testoutput:: | |||
[1. 2. 3. 4.] | |||
[0. 1. 2. 3. 4.] | |||
""" | |||
if end is None: | |||
@@ -964,12 +1036,12 @@ def param_pack_split(inp: Tensor, offsets: List, shapes: List) -> Tensor: | |||
Returns the input tensor split into a tensor list, as described by offsets and shapes; | |||
only used for parampack. | |||
:param inp: Input tensor | |||
:param inp: input tensor. | |||
:param offsets: offsets of outputs, length of 2 * n, | |||
where n is the number of tensors you want to split, | |||
format [begin0, end0, begin1, end1]. | |||
:param shapes: tensor shapes of outputs | |||
:return: split tensors | |||
format `[begin0, end0, begin1, end1]`. | |||
:param shapes: tensor shapes of outputs. | |||
:return: split tensors. | |||
Examples: | |||
@@ -1004,8 +1076,8 @@ def param_pack_concat(inps: List, offsets: Tensor, offsets_val: List) -> Tensor: | |||
r""" | |||
Returns concat Tensor, only used for parampack. | |||
:param inps: Input tensors | |||
:param offsets: device value of offsets | |||
:param inps: input tensors. | |||
:param offsets: device value of offsets. | |||
:param offsets_val: offsets of inputs, length of 2 * n, | |||
format [begin0, end0, begin1, end1]. | |||
:return: concat tensors | |||
@@ -10,12 +10,15 @@ import collections | |||
import functools | |||
def get_ndtuple(value, *, n, allow_zero=True): | |||
r"""Converts possibly 1D tuple to nd tuple | |||
def get_ndtuple(value, *, n, allow_zero: bool = True): | |||
r"""Converts possibly 1D tuple to n-dim tuple. | |||
:type allow_zero: bool | |||
:param allow_zero: whether to allow zero tuple value""" | |||
if not isinstance(value, collections.abc.Iterable): | |||
:param value: value to fill the generated tuple with. | |||
:param n: how many elements will the tuple have. | |||
:param allow_zero: whether to allow zero tuple value. | |||
:return: a tuple. | |||
""" | |||
if not isinstance(value, collections.abc.Iterable): | |||
value = int(value) | |||
value = tuple([value for i in range(n)]) | |||
else: | |||
@@ -15,7 +15,7 @@ from ..core.ops.builtin import Copy | |||
from ..core.tensor import Tensor | |||
from ..core.tensor.core import apply | |||
from .math import topk as _topk | |||
from .tensor import dimshuffle as _dimshuffle | |||
from .tensor import transpose as _transpose | |||
def accuracy( | |||
@@ -24,11 +24,11 @@ def accuracy( | |||
r""" | |||
Calculate the classification accuracy given predicted logits and ground-truth labels. | |||
:param logits: Model predictions of shape [batch_size, num_classes], | |||
:param logits: model predictions of shape `[batch_size, num_classes]`, | |||
representing the probability (likelihood) of each class. | |||
:param target: Ground-truth labels, 1d tensor of int32 | |||
:param topk: Specifies the topk values, could be an int or tuple of ints. Default: 1 | |||
:return: Tensor(s) of classification accuracy between 0.0 and 1.0 | |||
:param target: ground-truth labels, 1d tensor of int32. | |||
:param topk: specifies the topk values, could be an int or tuple of ints. Default: 1 | |||
:return: tensor(s) of classification accuracy between 0.0 and 1.0. | |||
Examples: | |||
@@ -54,7 +54,7 @@ def accuracy( | |||
_, pred = _topk(logits, k=max(topk), descending=True) | |||
accs = [] | |||
for k in topk: | |||
correct = pred[:, :k].detach() == _dimshuffle(target, (0, "x")).broadcast( | |||
correct = pred[:, :k].detach() == _transpose(target, (0, "x")).broadcast( | |||
target.shape[0], k | |||
) | |||
accs.append(correct.astype(np.float32).sum() / target.shape[0]) | |||
@@ -63,12 +63,25 @@ def accuracy( | |||
return accs | |||
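For intuition, the same top-k computation in plain NumPy (an illustrative re-implementation of what the loop above does, not the MegEngine code path):

    import numpy as np

    def accuracy_reference(logits, target, topk=(1,)):
        # indices of the k highest-scoring classes per sample, best first
        pred = np.argsort(-logits, axis=1)[:, :max(topk)]
        accs = []
        for k in topk:
            correct = (pred[:, :k] == target[:, None]).any(axis=1)
            accs.append(correct.mean())
        return accs

    logits = np.array([[0.1, 0.8, 0.1],
                       [0.6, 0.3, 0.1]], dtype=np.float32)
    target = np.array([1, 2], dtype=np.int32)
    print(accuracy_reference(logits, target, topk=(1, 2)))   # [0.5, 0.5]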
def zero_grad(inp: Tensor) -> Tensor: | |||
r""" | |||
Returns a tensor which is treated as constant during backward gradient calculation, | |||
i.e. its gradient is zero. | |||
:param inp: Input tensor. | |||
See implementation of :func:`~.softmax` for example. | |||
""" | |||
print("zero_grad is obsoleted, please use detach instead") | |||
raise NotImplementedError | |||
def copy(inp, cn): | |||
r""" | |||
Copy tensor to another device. | |||
:param inp: Input tensor. | |||
:param cn: device that you copy to | |||
:param inp: input tensor. | |||
:param cn: device that you copy to. | |||
Examples: | |||
@@ -234,32 +234,33 @@ class BatchNorm2d(_BatchNorm): | |||
less than 4D. | |||
:type eps: float | |||
:param eps: a value added to the denominator for numerical stability. | |||
Default: 1e-5. | |||
Default: 1e-5 | |||
:type momentum: float | |||
:param momentum: the value used for the `running_mean` and `running_var` | |||
computation. | |||
Default: 0.9 | |||
:type affine: bool | |||
:param affine: a boolean value that when set to ``True``, this module has | |||
learnable affine parameters. Default: ``True`` | |||
:param affine: a boolean value that when set to True, this module has | |||
learnable affine parameters. Default: True | |||
:type track_running_stats: bool | |||
:param track_running_stats: when set to ``True``, this module tracks the | |||
running mean and variance. When set to ``False``, this module does not | |||
:param track_running_stats: when set to True, this module tracks the | |||
running mean and variance. When set to False, this module does not | |||
track such statistics and always uses batch statistics in both training | |||
and eval modes. Default: ``True``. | |||
and eval modes. Default: True | |||
:type freeze: bool | |||
:param freeze: when set to ``True``, this module does not update the | |||
:param freeze: when set to True, this module does not update the | |||
running mean and variance, and uses the running mean and variance instead of | |||
the batch mean and batch variance to normalize the input. The parameter takes effect | |||
only when the module is initilized with ``track_running_stats`` as ``True`` and | |||
only when the module is initialized with track_running_stats as True and | |||
the module is in training mode. | |||
Default: ``False``. | |||
Default: False | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
@@ -268,13 +269,13 @@ class BatchNorm2d(_BatchNorm): | |||
inp = mge.tensor(np.random.rand(1, 4, 3, 3).astype("float32")) | |||
oup = m(inp) | |||
print(m.weight.numpy(), m.bias.numpy()) | |||
# Without Learnable Parameters | |||
# Without Learnable Parameters | |||
m = M.BatchNorm2d(4, affine=False) | |||
oup = m(inp) | |||
print(m.weight, m.bias) | |||
Outputs: | |||
.. testoutput:: | |||
[1. 1. 1. 1.] [0. 0. 0. 0.] | |||
@@ -88,8 +88,8 @@ class Conv2d(_ConvNd): | |||
:math:`H` is a height of input planes in pixels, and :math:`W` is | |||
width in pixels. | |||
When ``groups == in_channels`` and ``out_channels == K * in_channels``, | |||
where `K` is a positive integer, this operation is also known as depthwise | |||
When `groups == in_channels` and `out_channels == K * in_channels`, | |||
where K is a positive integer, this operation is also known as depthwise | |||
convolution. | |||
In other words, for an input of size :math:`(N, C_{in}, H_{in}, W_{in})`, | |||
@@ -98,27 +98,47 @@ class Conv2d(_ConvNd): | |||
:param in_channels: number of input channels. | |||
:param out_channels: number of output channels. | |||
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
:param kernel_size: size of weight on spatial dimensions. If kernel_size is | |||
an :class:`int`, the actual kernel size would be | |||
``(kernel_size, kernel_size)``. Default: 1 | |||
`(kernel_size, kernel_size)`. Default: 1 | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
so as to perform a "grouped convolution". When groups is not 1, | |||
in_channels and out_channels must be divisible by groups, | |||
and there would be an extra dimension at the beginning of the weight's | |||
shape. Specifically, the shape of weight would be ``(groups, | |||
out_channel // groups, in_channels // groups, *kernel_size)``. | |||
shape. Specifically, the shape of weight would be `(groups, | |||
out_channel // groups, in_channels // groups, *kernel_size)`. | |||
:param bias: whether to add a bias onto the result of convolution. Default: | |||
True | |||
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
`CROSS_CORRELATION`. | |||
`CROSS_CORRELATION` | |||
:param compute_mode: When set to `DEFAULT`, no special requirements will be | |||
placed on the precision of intermediate results. When set to `FLOAT32`, | |||
float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of float16 dtype. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
m = M.Conv2d(in_channels=3, out_channels=1, kernel_size=3) | |||
inp = mge.tensor(np.arange(0, 96).astype("float32").reshape(2, 3, 4, 4)) | |||
oup = m(inp) | |||
print(oup.shape) | |||
Outputs: | |||
.. testoutput:: | |||
(2, 1, 2, 2) | |||
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
@@ -226,7 +246,7 @@ class ConvTranspose2d(_ConvNd): | |||
:param bias: whether to add a bias onto the result of convolution. Default: | |||
True | |||
:param conv_mode: Supports `CROSS_CORRELATION` or `CONVOLUTION`. Default: | |||
`CROSS_CORRELATION`. | |||
`CROSS_CORRELATION` | |||
:param compute_mode: When set to `DEFAULT`, no special requirements will be | |||
placed on the precision of intermediate results. When set to `FLOAT32`, | |||
float32 would be used for accumulator and intermediate result, but only | |||
@@ -314,17 +334,17 @@ class LocalConv2d(Conv2d): | |||
:param out_channels: number of output channels. | |||
:param input_height: the height of the input images. | |||
:param input_width: the width of the input images. | |||
:param kernel_size: size of weight on spatial dimensions. If ``kernel_size`` is | |||
:param kernel_size: size of weight on spatial dimensions. If kernel_size is | |||
an :class:`int`, the actual kernel size would be | |||
``(kernel_size, kernel_size)``. Default: 1 | |||
`(kernel_size, kernel_size)`. Default: 1 | |||
:param stride: stride of the 2D convolution operation. Default: 1 | |||
:param padding: size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``. | |||
The shape of weight is ``(groups, output_height, output_width, | |||
in_channels // groups, *kernel_size, out_channels // groups)``. | |||
so as to perform a "grouped convolution". When groups is not 1, | |||
in_channels and out_channels must be divisible by groups. | |||
The shape of weight is `(groups, output_height, output_width, | |||
in_channels // groups, *kernel_size, out_channels // groups)`. | |||
""" | |||
_conv_mode_type = P.Convolution.Mode | |||
@@ -11,7 +11,8 @@ from .module import Module | |||
class Dropout(Module): | |||
r"""Randomly set input elements to zeros with the probability :math:`drop\_prob` during training. Commonly used in large networks to prevent overfitting. | |||
r"""Randomly set input elements to zeros with the probability :math:`drop\_prob` during training. | |||
Commonly used in large networks to prevent overfitting. | |||
Note that we perform dropout only during training, we also rescale(multiply) the output tensor | |||
by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.Identity`. | |||
@@ -67,6 +67,10 @@ class Elemwise(Module): | |||
* "H_SWISH": h_swish | |||
* "FUSE_ADD_H_SWISH": h_swish(x+y) | |||
* "H_SWISH_GRAD": h_swish_grad | |||
* "AND": bool binary: x && y | |||
* "OR": bool binary: x || y | |||
* "XOR": bool binary: x ^ y | |||
* "NOT": bool unary: ~x | |||
""" | |||
_elemwise_mode_type = P.Elemwise.Mode | |||
@@ -78,7 +78,7 @@ def calculate_gain( | |||
Sigmoid :math:`1` | |||
Tanh :math:`\frac{5}{3}` | |||
ReLU :math:`\sqrt{2}` | |||
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative_{slope}}^2}}` | |||
Leaky Relu :math:`\sqrt{\frac{2}{1 + {\text{negative}_\text{slope}}^2}}` | |||
================= ==================================================== | |||
:param nonlinearity: Name of the non-linear function | |||
@@ -28,6 +28,25 @@ class Linear(Module): | |||
:param bias: If set to ``False``, the layer will not learn an additive bias. | |||
Default: ``True`` | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
m = M.Linear(in_features=3, out_features=1) | |||
inp = mge.tensor(np.arange(0, 6).astype("float32").reshape(2, 3)) | |||
oup = m(inp) | |||
print(oup.shape) | |||
Outputs: | |||
.. testoutput:: | |||
(2, 1) | |||
""" | |||
def __init__( | |||
@@ -48,8 +48,29 @@ class MaxPool2d(_PoolNd): | |||
both sides for :attr:`padding` number of points. | |||
:param kernel_size: the size of the window to take a max over. | |||
:param stride: the stride of the window. Default value is ``kernel_size``. | |||
:param stride: the stride of the window. Default value is kernel_size. | |||
:param padding: implicit zero padding to be added on both sides. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
m = M.MaxPool2d(kernel_size=3, stride=1, padding=0) | |||
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||
oup = m(inp) | |||
print(oup.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[[10. 11.] | |||
[14. 15.]]]] | |||
""" | |||
def forward(self, inp): | |||
@@ -72,8 +93,29 @@ class AvgPool2d(_PoolNd): | |||
both sides for :attr:`padding` number of points. | |||
:param kernel_size: the size of the window. | |||
:param stride: the stride of the window. Default value is ``kernel_size``. | |||
:param stride: the stride of the window. Default value is kernel_size. | |||
:param padding: implicit zero padding to be added on both sides. | |||
Examples: | |||
.. testcode:: | |||
import numpy as np | |||
import megengine as mge | |||
import megengine.module as M | |||
m = M.AvgPool2d(kernel_size=3, stride=1, padding=0) | |||
inp = mge.tensor(np.arange(0, 16).astype("float32").reshape(1, 1, 4, 4)) | |||
oup = m(inp) | |||
print(oup.numpy()) | |||
Outputs: | |||
.. testoutput:: | |||
[[[[ 5. 6.] | |||
[ 9. 10.]]]] | |||
""" | |||
def forward(self, inp): | |||
@@ -23,12 +23,13 @@ class Sequential(Module): | |||
.. testcode:: | |||
import numpy as np | |||
from megengine import tensor | |||
import megengine as mge | |||
import megengine.module as M | |||
import megengine.functional as F | |||
batch_size = 64 | |||
data = tensor(np.zeros((batch_size, 1, 28, 28)), dtype=np.float32) | |||
label = tensor(np.zeros(batch_size,), dtype=np.int32) | |||
data = mge.tensor(np.zeros((batch_size, 1, 28, 28)), dtype=np.float32) | |||
label = mge.tensor(np.zeros(batch_size,), dtype=np.int32) | |||
data = data.reshape(batch_size, -1) | |||
net = M.Sequential( | |||
@@ -192,7 +192,7 @@ def unpack_getitem(inp, tuple_val, *, allow_newaxis=True): | |||
return inp, tensors, items | |||
def dimshuffle(*args, **kwargs): | |||
def transpose(*args, **kwargs): | |||
op = all_ops.Dimshuffle(**kwargs).to_c() | |||
return invoke_op(op, args) | |||
@@ -274,10 +274,10 @@ def batched_incr_mesh_indexing(input, value, tuple_val): | |||
return invoke_op(op, (input, value, *tensors)) | |||
def test_dimshuffle(): | |||
def test_transpose(): | |||
x = np.arange(10).reshape(2, 5).astype("int32") | |||
xx = as_raw_tensor(x) | |||
(yy,) = dimshuffle(xx, pattern="1x0") | |||
(yy,) = transpose(xx, pattern="1x0") | |||
np.testing.assert_equal(np.expand_dims(x.transpose(), axis=1), yy.numpy()) | |||
@@ -0,0 +1,29 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import pytest | |||
from megengine.core import Tensor | |||
# from megengine.core.interpreter.hints import function | |||
@pytest.mark.skip(reason="under rewrite") | |||
def test_1(): | |||
@function | |||
def f(x, p): | |||
x = x + 1 | |||
if p: | |||
return x * x | |||
return x * 2 | |||
x = Tensor(0) | |||
for _ in range(5): | |||
assert f(x, 0).numpy() == 2 | |||
assert f(x, 1).numpy() == 1 |
@@ -0,0 +1,43 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import os | |||
import sys | |||
import numpy as np | |||
import pytest | |||
from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset | |||
def test_abstract_cls(): | |||
with pytest.raises(TypeError): | |||
Dataset() | |||
with pytest.raises(TypeError): | |||
MapDataset() | |||
with pytest.raises(TypeError): | |||
StreamDataset() | |||
def test_array_dataset(): | |||
size = (10,) | |||
data_shape = (3, 256, 256) | |||
label_shape = (1,) | |||
data = np.random.randint(0, 255, size + data_shape) | |||
label = np.random.randint(0, 9, size + label_shape) | |||
dataset = ArrayDataset(data, label) | |||
assert dataset[0][0].shape == data_shape | |||
assert dataset[0][1].shape == label_shape | |||
assert len(dataset) == size[0] | |||
def test_array_dataset_dim_error(): | |||
data = np.random.randint(0, 255, (10, 3, 256, 256)) | |||
label = np.random.randint(0, 9, (1,)) | |||
with pytest.raises(ValueError): | |||
ArrayDataset(data, label) |
@@ -0,0 +1,81 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import copy | |||
import os | |||
import sys | |||
import numpy as np | |||
import pytest | |||
from megengine.data.dataset import ArrayDataset | |||
from megengine.data.sampler import RandomSampler, ReplacementSampler, SequentialSampler | |||
def test_sequential_sampler(): | |||
indices = list(range(100)) | |||
sampler = SequentialSampler(ArrayDataset(indices)) | |||
assert indices == list(each[0] for each in sampler) | |||
def test_RandomSampler(): | |||
indices = list(range(20)) | |||
indices_copy = copy.deepcopy(indices) | |||
sampler = RandomSampler(ArrayDataset(indices_copy)) | |||
sample_indices = sampler | |||
assert indices != list(each[0] for each in sample_indices) | |||
assert indices == sorted(list(each[0] for each in sample_indices)) | |||
def test_random_sampler_seed(): | |||
seed = [0, 1] | |||
indices = list(range(20)) | |||
indices_copy1 = copy.deepcopy(indices) | |||
indices_copy2 = copy.deepcopy(indices) | |||
indices_copy3 = copy.deepcopy(indices) | |||
sampler1 = RandomSampler(ArrayDataset(indices_copy1), seed=seed[0]) | |||
sampler2 = RandomSampler(ArrayDataset(indices_copy2), seed=seed[0]) | |||
sampler3 = RandomSampler(ArrayDataset(indices_copy3), seed=seed[1]) | |||
assert indices != list(each[0] for each in sampler1) | |||
assert indices != list(each[0] for each in sampler2) | |||
assert indices != list(each[0] for each in sampler3) | |||
assert indices == sorted(list(each[0] for each in sampler1)) | |||
assert indices == sorted(list(each[0] for each in sampler2)) | |||
assert indices == sorted(list(each[0] for each in sampler3)) | |||
assert list(each[0] for each in sampler1) == list(each[0] for each in sampler2) | |||
assert list(each[0] for each in sampler1) != list(each[0] for each in sampler3) | |||
def test_ReplacementSampler(): | |||
num_samples = 30 | |||
indices = list(range(20)) | |||
weights = list(range(20)) | |||
sampler = ReplacementSampler( | |||
ArrayDataset(indices), num_samples=num_samples, weights=weights | |||
) | |||
assert len(list(each[0] for each in sampler)) == num_samples | |||
def test_sampler_drop_last_false(): | |||
batch_size = 5 | |||
drop_last = False | |||
indices = list(range(24)) | |||
sampler = SequentialSampler( | |||
ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last | |||
) | |||
assert len([each for each in sampler]) == len(sampler) | |||
def test_sampler_drop_last_true(): | |||
batch_size = 5 | |||
drop_last = True | |||
indices = list(range(24)) | |||
sampler = SequentialSampler( | |||
ArrayDataset(indices), batch_size=batch_size, drop_last=drop_last | |||
) | |||
assert len([each for each in sampler]) == len(sampler) |
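The two drop_last tests above only check that iterating the sampler yields len(sampler) batches; the sketch below spells out the batch counts they imply, under the assumption that drop_last=True discards the final incomplete batch (24 samples with batch_size=5 would then give 4 full batches instead of 5).

from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import SequentialSampler

dataset = ArrayDataset(list(range(24)))

keep_last = SequentialSampler(dataset, batch_size=5, drop_last=False)
drop_tail = SequentialSampler(dataset, batch_size=5, drop_last=True)

# Assumed counts: ceil(24 / 5) = 5 batches when keeping the tail, floor(24 / 5) = 4 when dropping it.
print(len(keep_last), len(list(keep_last)))  # expected: 5 5
print(len(drop_tail), len(list(drop_tail)))  # expected: 4 4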
@@ -0,0 +1,108 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
from megengine.data.transform import * | |||
data_shape = (100, 100, 3) | |||
label_shape = (4,) | |||
ToMode_target_shape = (3, 100, 100) | |||
CenterCrop_size = (90, 70) | |||
CenterCrop_target_shape = CenterCrop_size + (3,) | |||
RandomResizedCrop_size = (50, 50) | |||
RandomResizedCrop_target_shape = RandomResizedCrop_size + (3,) | |||
def generate_data(): | |||
return [ | |||
( | |||
(np.random.rand(*data_shape) * 255).astype(np.uint8), | |||
np.random.randint(10, size=label_shape), | |||
) | |||
for _ in range(*label_shape) | |||
] | |||
def test_ToMode(): | |||
t = ToMode(mode="CHW") | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [(ToMode_target_shape, label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_CenterCrop(): | |||
t = CenterCrop(output_size=CenterCrop_size) | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [(CenterCrop_target_shape, label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_ColorJitter(): | |||
t = ColorJitter() | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [(data_shape, label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_RandomHorizontalFlip(): | |||
t = RandomHorizontalFlip(prob=1) | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [(data_shape, label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_RandomVerticalFlip(): | |||
t = RandomVerticalFlip(prob=1) | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [(data_shape, label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_RandomResizedCrop(): | |||
t = RandomResizedCrop(output_size=RandomResizedCrop_size) | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [(RandomResizedCrop_target_shape, label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_Normalize(): | |||
t = Normalize() | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [(data_shape, label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_RandomCrop(): | |||
t = RandomCrop((150, 120), padding_size=10, padding_value=[1, 2, 3]) | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
target_shape = [((150, 120, 3), label_shape)] * 4 | |||
assert aug_data_shape == target_shape | |||
def test_Compose(): | |||
t = Compose( | |||
[ | |||
CenterCrop(output_size=CenterCrop_size), | |||
RandomHorizontalFlip(prob=1), | |||
ToMode(mode="CHW"), | |||
] | |||
) | |||
aug_data = t.apply_batch(generate_data()) | |||
aug_data_shape = [(a.shape, b.shape) for a, b in aug_data] | |||
print(aug_data_shape) | |||
target_shape = [((3, 90, 70), label_shape)] * 4 | |||
assert aug_data_shape == target_shape |
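As a usage note for test_Compose above, the pipeline it builds can also be applied to a single hand-made batch; this sketch assumes CenterCrop, RandomHorizontalFlip, ToMode and Compose are importable by name from megengine.data.transform (the test itself relies on a star import).

import numpy as np
from megengine.data.transform import CenterCrop, Compose, RandomHorizontalFlip, ToMode

# One HWC uint8 image plus a small integer label, mirroring generate_data() above.
batch = [
    (
        (np.random.rand(100, 100, 3) * 255).astype(np.uint8),
        np.random.randint(10, size=(4,)),
    )
]

pipeline = Compose(
    [
        CenterCrop(output_size=(90, 70)),  # (100, 100, 3) -> (90, 70, 3)
        RandomHorizontalFlip(prob=1),      # shape unchanged, always flips
        ToMode(mode="CHW"),                # (90, 70, 3) -> (3, 90, 70)
    ]
)

image, label = pipeline.apply_batch(batch)[0]
print(image.shape)  # expected (3, 90, 70), matching test_Compose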
@@ -83,48 +83,6 @@ def opr_test(cases, func, compare_fn=_default_compare_fn, ref_fn=None, **kwargs) | |||
check_results(results, outp) | |||
def test_flatten(): | |||
data0_shape = (2, 3, 4, 5) | |||
data1_shape = (4, 5, 6, 7) | |||
data0 = np.random.random(data0_shape).astype(np.float32) | |||
data1 = np.random.random(data1_shape).astype(np.float32) | |||
def compare_fn(x, y): | |||
assert x.numpy().shape == y | |||
output0 = (2 * 3 * 4 * 5,) | |||
output1 = (4 * 5 * 6 * 7,) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn) | |||
output0 = (2, 3 * 4 * 5) | |||
output1 = (4, 5 * 6 * 7) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1) | |||
output0 = (2, 3, 4 * 5) | |||
output1 = (4, 5, 6 * 7) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2) | |||
output0 = (2, 3 * 4, 5) | |||
output1 = (4, 5 * 6, 7) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) | |||
def test_where(): | |||
maskv0 = np.array([[1, 0], [0, 1]], dtype=np.bool_) | |||
xv0 = np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32) | |||
@@ -155,45 +113,6 @@ def test_where(): | |||
opr_test(cases, F.where, ref_fn=np.where) | |||
def test_matmul(): | |||
shape1 = 3 | |||
shape2 = 3 | |||
shape3 = (3, 5) | |||
shape4 = (5, 6) | |||
data1 = np.random.random(shape1).astype("float32") | |||
data2 = np.random.random(shape2).astype("float32") | |||
data3 = np.random.random(shape3).astype("float32") | |||
data4 = np.random.random(shape4).astype("float32") | |||
cases = [ | |||
{"input": [data1, data2]}, | |||
{"input": [data2, data3]}, | |||
{"input": [data3, data4]}, | |||
] | |||
opr_test(cases, F.matmul, ref_fn=np.matmul) | |||
batch_size = 10 | |||
shape1 = (batch_size, 2, 3) | |||
shape2 = (batch_size, 3, 4) | |||
shape3 = (batch_size, 10, 4, 5) | |||
data1 = np.random.random(shape1).astype("float32") | |||
data2 = np.random.random(shape2).astype("float32") | |||
data3 = np.random.random(shape3).astype("float32") | |||
cases = [{"input": [data1, data2]}, {"input": [data2, data3]}] | |||
for i in range(0, batch_size): | |||
def compare_fn(x, y): | |||
x.numpy()[i, ...] == y | |||
opr_test( | |||
cases, | |||
F.matmul, | |||
compare_fn=compare_fn, | |||
ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]), | |||
) | |||
def test_interpolate(): | |||
def linear_interpolate(): | |||
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) | |||
@@ -303,28 +222,28 @@ def test_roi_pooling(): | |||
assert make_shape_tuple(inp_feat.grad.shape) == make_shape_tuple(inp_feat.shape) | |||
# def test_one_hot(): | |||
# def onehot_low_dimension(): | |||
# inp = tensor(np.arange(1, 4, dtype=np.int32)) | |||
# out = F.one_hot(inp, num_classes=4) | |||
# assertTensorClose( | |||
# out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)] | |||
# ) | |||
def test_one_hot(): | |||
def onehot_low_dimension(): | |||
inp = tensor(np.arange(1, 4, dtype=np.int32)) | |||
out = F.one_hot(inp, num_classes=4) | |||
assertTensorClose( | |||
out.numpy(), np.eye(4, dtype=np.int32)[np.arange(1, 4, dtype=np.int32)] | |||
) | |||
# def onehot_high_dimension(): | |||
# arr = np.array( | |||
# [[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]], dtype=np.int32 | |||
# ) | |||
def onehot_high_dimension(): | |||
arr = np.array( | |||
[[3, 2, 4, 4, 2, 4, 0, 4, 4, 1], [4, 1, 1, 3, 2, 2, 4, 2, 4, 3]], | |||
dtype=np.int32, | |||
) | |||
# inp = tensor(arr) | |||
# out = F.one_hot(inp, 10) | |||
inp = tensor(arr) | |||
out = F.one_hot(inp, 10) | |||
# assertTensorClose(out.numpy(), np.eye(10, dtype=np.int32)[arr]) | |||
assertTensorClose(out.numpy(), np.eye(10, dtype=np.int32)[arr]) | |||
# onehot_low_dimension() | |||
# onehot_high_dimension() | |||
onehot_low_dimension() | |||
onehot_high_dimension() | |||
def test_add_update(): | |||
@@ -554,7 +473,7 @@ def test_conv_bias(): | |||
var = F.reshape( | |||
var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3]) | |||
) | |||
var = F.dimshuffle(var, (0, 1, 3, 4, 2)) | |||
var = F.transpose(var, (0, 1, 3, 4, 2)) | |||
return var | |||
def run_conv2d(inp, w, b): | |||
@@ -591,7 +510,7 @@ def test_conv_bias(): | |||
"float32" | |||
) | |||
if format == "NCHW4": | |||
result = F.dimshuffle(result, (0, 1, 4, 2, 3)) | |||
result = F.transpose(result, (0, 1, 4, 2, 3)) | |||
expected = F.flatten(expected) | |||
result = F.flatten(result) | |||
assertTensorClose(result.numpy(), expected.numpy(), max_err=outp_scale) | |||
@@ -608,22 +527,6 @@ def test_conv_bias(): | |||
run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU") | |||
# def test_softplus(): | |||
# x = np.arange(1000).astype(np.float32) | |||
# out = F.softplus(tensor(x)) | |||
# mask = x <= 20 | |||
# with np.errstate(over="ignore"): | |||
# expected = np.where(mask, np.log(1 + np.exp(x)), x) | |||
# assertTensorClose(out, expected) | |||
# beta = 2 | |||
# out = F.softplus(tensor(x), beta=beta, threshold=30) | |||
# mask = beta * x <= 30 | |||
# # ignore overflow | |||
# with np.errstate(over="ignore"): | |||
# expected = np.where(mask, np.log(1 + np.exp(x * beta)) / beta, x) | |||
# assertTensorClose(out, expected) | |||
def test_condtake(): | |||
x = np.array([[1, 2, 3], [4, 5, 6]]) | |||
y = np.array([[True, False, True], [False, True, True]]) | |||
@@ -12,7 +12,6 @@ import megengine.functional as F | |||
from megengine import tensor | |||
# XXX need to test label_smooth | |||
def test_cross_entropy_with_softmax(): | |||
data = tensor([1, 100]).astype(np.float32).reshape((1, 2)) | |||
label = tensor([1]).astype(np.int32) |
@@ -14,8 +14,6 @@ import megengine.functional as F | |||
from megengine import tensor | |||
from megengine.test import assertTensorClose | |||
# from helpers import opr_test | |||
def _default_compare_fn(x, y): | |||
assertTensorClose(x.numpy(), y) | |||
@@ -207,6 +205,45 @@ def test_normalize(): | |||
opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) | |||
def test_matmul(): | |||
shape1 = 3 | |||
shape2 = 3 | |||
shape3 = (3, 5) | |||
shape4 = (5, 6) | |||
data1 = np.random.random(shape1).astype("float32") | |||
data2 = np.random.random(shape2).astype("float32") | |||
data3 = np.random.random(shape3).astype("float32") | |||
data4 = np.random.random(shape4).astype("float32") | |||
cases = [ | |||
{"input": [data1, data2]}, | |||
{"input": [data2, data3]}, | |||
{"input": [data3, data4]}, | |||
] | |||
opr_test(cases, F.matmul, ref_fn=np.matmul) | |||
batch_size = 10 | |||
shape1 = (batch_size, 2, 3) | |||
shape2 = (batch_size, 3, 4) | |||
shape3 = (batch_size, 10, 4, 5) | |||
data1 = np.random.random(shape1).astype("float32") | |||
data2 = np.random.random(shape2).astype("float32") | |||
data3 = np.random.random(shape3).astype("float32") | |||
cases = [{"input": [data1, data2]}, {"input": [data2, data3]}] | |||
for i in range(0, batch_size): | |||
def compare_fn(x, y): | |||
assertTensorClose(x.numpy()[i, ...], y) | |||
opr_test( | |||
cases, | |||
F.matmul, | |||
compare_fn=compare_fn, | |||
ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]), | |||
) | |||
# def test_logsumexp(): | |||
# x = np.arange(10).astype(np.float32) | |||
# expected = np.log(np.sum(np.exp(x))) | |||
@@ -165,7 +165,7 @@ def test_squeeze(): | |||
for axis in [None, 3, -4, (3, -4)]: | |||
y = np.squeeze(x, axis) | |||
yy = F.squeeze(xx, axis) | |||
yy = F.remove_axis(xx, axis) | |||
np.testing.assert_equal(y, yy.numpy()) | |||
@@ -175,7 +175,7 @@ def test_expand_dims(): | |||
for axis in [2, -3, (3, -4), (1, -4)]: | |||
y = np.expand_dims(x, axis) | |||
yy = F.expand_dims(xx, axis) | |||
yy = F.add_axis(xx, axis) | |||
np.testing.assert_equal(y, yy.numpy()) | |||
@@ -258,6 +258,48 @@ def test_round(): | |||
opr_test(cases, F.round, ref_fn=np.round) | |||
def test_flatten(): | |||
data0_shape = (2, 3, 4, 5) | |||
data1_shape = (4, 5, 6, 7) | |||
data0 = np.random.random(data0_shape).astype(np.float32) | |||
data1 = np.random.random(data1_shape).astype(np.float32) | |||
def compare_fn(x, y): | |||
assert x.numpy().shape == y[0] | |||
output0 = (2 * 3 * 4 * 5,) | |||
output1 = (4 * 5 * 6 * 7,) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn) | |||
output0 = (2, 3 * 4 * 5) | |||
output1 = (4, 5 * 6 * 7) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1) | |||
output0 = (2, 3, 4 * 5) | |||
output1 = (4, 5, 6 * 7) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2) | |||
output0 = (2, 3 * 4, 5) | |||
output1 = (4, 5 * 6, 7) | |||
cases = [ | |||
{"input": data0, "output": (output0,)}, | |||
{"input": data1, "output": (output1,)}, | |||
] | |||
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, end_axis=2) | |||
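To make the shape arithmetic in the restored test_flatten easier to follow, here is a small reference helper (plain numpy, not part of the test) that computes the shape F.flatten is expected to produce when it collapses the axes from start_axis through end_axis into one.

import numpy as np

def flatten_shape(shape, start_axis=0, end_axis=-1):
    # Reference shape computation only; mirrors what the cases above assert.
    end_axis = end_axis % len(shape)
    merged = int(np.prod(shape[start_axis:end_axis + 1]))
    return shape[:start_axis] + (merged,) + shape[end_axis + 1:]

print(flatten_shape((2, 3, 4, 5)))                            # (120,)
print(flatten_shape((2, 3, 4, 5), start_axis=1))              # (2, 60)
print(flatten_shape((2, 3, 4, 5), start_axis=2))              # (2, 3, 20)
print(flatten_shape((2, 3, 4, 5), start_axis=1, end_axis=2))  # (2, 12, 5)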
def test_broadcast(): | |||
input1_shape = (20, 30) | |||
output1_shape = (30, 20, 30) | |||