Browse Source

fix(mge/functional): functional api fixes

GitOrigin-RevId: fa206c4ff6
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
39d4328309
15 changed files with 139 additions and 312 deletions
  1. +1
    -11
      imperative/python/megengine/functional/__init__.py
  2. +0
    -20
      imperative/python/megengine/functional/elemwise.py
  3. +8
    -199
      imperative/python/megengine/functional/loss.py
  4. +30
    -31
      imperative/python/megengine/functional/math.py
  5. +5
    -5
      imperative/python/megengine/functional/nn.py
  6. +35
    -23
      imperative/python/megengine/functional/tensor.py
  7. +8
    -0
      imperative/python/megengine/jit/__init__.py
  8. +0
    -1
      imperative/python/megengine/jit/sublinear_memory_config.py
  9. +8
    -0
      imperative/python/megengine/jit/tracing.py
  10. +3
    -3
      imperative/python/megengine/module/init.py
  11. +2
    -2
      imperative/python/megengine/random/__init__.py
  12. +19
    -11
      imperative/python/megengine/random/distribution.py
  13. +2
    -2
      imperative/python/megengine/random/rng.py
  14. +10
    -4
      imperative/python/test/unit/functional/test_elemwise.py
  15. +8
    -0
      imperative/python/test/unit/functional/test_tensor.py

+ 1
- 11
imperative/python/megengine/functional/__init__.py View File

@@ -9,17 +9,7 @@
# pylint: disable=redefined-builtin
from .elemwise import *
from .graph import add_update
from .loss import (
binary_cross_entropy,
cross_entropy,
cross_entropy_with_softmax,
hinge_loss,
l1_loss,
nll_loss,
smooth_l1_loss,
square_loss,
triplet_margin_loss,
)
from .loss import *
from .math import *
from .nn import *
from .quantized import conv_bias_activation


+ 0
- 20
imperative/python/megengine/functional/elemwise.py View File

@@ -25,10 +25,6 @@ __all__ = [
"asinh",
"acosh",
"atanh",
"bitwise_and", # TODO
"bitwise_not", # TODO
"bitwise_or", # TODO
"bitwise_xor", # TODO
"ceil",
"clamp",
"cos",
@@ -339,22 +335,6 @@ def right_shift(x, y):
return _elwise(x, y, mode="shl")


def bitwise_and(x, y):
raise NotImplementedError


def bitwise_not(x):
raise NotImplementedError


def bitwise_or(x, y):
raise NotImplementedError


def bitwise_xor(x, y):
raise NotImplementedError


# logical functions




+ 8
- 199
imperative/python/megengine/functional/loss.py View File

@@ -15,6 +15,14 @@ from .nn import assert_equal, indexing_one_hot
from .tensor import where
from .utils import zero_grad

__all__ = [
"l1_loss",
"square_loss",
"cross_entropy_with_softmax",
"binary_cross_entropy",
"hinge_loss",
]


def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
r"""
@@ -93,59 +101,6 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
return (diff ** 2).mean()


def cross_entropy(
inp: Tensor, target: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
r"""
Returns the cross entropy loss in a classification problem.

.. math:: \textrm{CrossEntropy}(x, y) = - \sum_{i} y_i\log(x_i)

:param inp: The input tensor representing the predicted probability.
:param label: The input tensor representing the classification label.
:param axis: An axis along which cross_entropy will be applied. Default: 1
:param ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. Default: -1

Examples:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F

data_shape = (1, 2)
label_shape = (1, )

pred = tensor(np.array([0.5, 0.5], dtype=np.float32).reshape(data_shape))
label = tensor(np.ones(label_shape, dtype=np.int32))
loss = F.cross_entropy(pred, label)
print(loss.numpy())

Outputs:

.. testoutput::

[0.69]

"""
raise NotImplementedError
# n0 = inp.ndim
# n1 = target.ndim
# assert n0 == n1 + 1, (
# "target ndim must be one less than input ndim; input_ndim={} "
# "target_ndim={}".format(n0, n1)
# )

# if ignore_index != -1:
# mask = 1 - equal(target, ignore_index)
# target = target * mask
# loss = -log(indexing_one_hot(inp, target, axis)) * mask
# return loss.sum() / maximum(mask.sum(), 1.0)
# else:
# return -log(indexing_one_hot(inp, target, axis)).mean()


def cross_entropy_with_softmax(
pred: Tensor, label: Tensor, axis: int = 1, label_smooth: float = 0
) -> Tensor:
@@ -189,49 +144,6 @@ def cross_entropy_with_softmax(
return (log(down) - up).mean()


def triplet_margin_loss(
anchor: Tensor, positive: Tensor, negative: Tensor, margin: float = 1.0, p: int = 2
) -> Tensor:
r"""
Creates a criterion that measures the triplet loss given an input tensors.

.. math::

L(a, p, n) = max\left\{d\left(a_{i},p_{i}\right)-d\left(a_{i}, n_{i}\right)+margin, 0\right\},\
d\left(x_{i},y_{i}\right)=\left\|x_{i}-y_{i}\right\|_{p}

:param anchor: The input tensor representing the anchor samples.
:param positive: The input tensor representing the positive samples.
:param negative: The input tensor representing the negative samples.
:param margin: Default: 1.0
:param p: The norm degree for pairwise distance. Default: 2.0
"""
s0 = anchor.shapeof()
s1 = positive.shapeof()
s2 = negative.shapeof()
assert_equal(s0, s1)
assert_equal(s1, s2)

n0 = anchor.ndim
n1 = positive.ndim
n2 = negative.ndim
assert n0 == 2 and n1 == 2 and n2 == 2, (
"anchor ndim, positive ndim, and negative ndim must be 2; "
"anchor_ndim={} positive_ndim={} negative_ndim={}".format(n0, n1, n2)
)
assert p > 0, "a margin with a value greater than 0; p={}".format(p)

diff0 = abs(anchor - positive)
diff1 = abs(anchor - negative)

d1 = power(power(diff0, p).sum(axis=1, keepdims=True), 1 / p)
d2 = power(power(diff1, p).sum(axis=1, keepdims=True), 1 / p)

loss = maximum(d1 - d2 + margin, 0)

return loss.mean()


def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
r"""Function that measures the Binary Cross Entropy between the target and the prediction.

@@ -244,59 +156,6 @@ def binary_cross_entropy(pred: Tensor, label: Tensor) -> Tensor:
return -1.0 * (label * log(pred) + (1.0 - label) * log(1 - pred)).mean()


def nll_loss(
pred: Tensor, label: Tensor, axis: int = 1, ignore_index: int = -1
) -> Tensor:
r"""
The negative log likelihood loss.

:param pred: The predicted result from model.
:param label: The ground truth to compare.

Examples:

.. testcode::

import numpy as np
from megengine import tensor
import megengine.functional as F
data_shape = (2, 2)
label_shape = (2, )

data = tensor(
np.array([[1, 0.5], [0.3, 1.2]], dtype=np.float32).reshape(data_shape),
)
label = tensor(
np.ones(label_shape, dtype=np.int32)
)
pred = F.log(F.softmax(data))
loss1 = F.nll_loss(pred, label)
loss2 = F.cross_entropy_with_softmax(data, label)
print(loss1.numpy(), loss2.numpy())

Outputs:

.. testoutput::

[0.6576154] [0.6576154]

"""
raise NotImplementedError
# n0 = pred.ndim
# n1 = label.ndim
# assert n0 == n1 + 1, (
# "target ndim must be one less than input ndim; input_ndim={} "
# "target_ndim={}".format(n0, n1)
# )

# mask = 1.0 - equal(label, ignore_index)
# label = label * mask

# loss = indexing_one_hot(pred, label, axis) * mask

# return -1.0 * loss.sum() / maximum(mask.sum(), 1.0)


def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
r"""
Caculate the hinge loss which is often used in SVMs.
@@ -337,53 +196,3 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
return loss.sum(axis=1).mean()
else:
return (loss ** 2).sum(axis=1).mean()


def smooth_l1_loss(pred: Tensor, label: Tensor) -> Tensor:
r"""
Caculate the smooth l1 loss proposed in `Fast R-CNN paper by Ross Girshick`.

The smooth l1 loss can be described as:

.. math::
\text{loss}(x, y) = \frac{1}{n} \sum_{i} l_{i}

where :math:`l_{i}` is given by:

.. math::
l_{i} =
\begin{cases}
0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
|x_i - y_i| - 0.5, & \text{otherwise }
\end{cases}

:param pred: The predicted result from model.
:param label: The ground truth to compare.

Examples:

.. testcode::

from megengine import tensor
import megengine.functional as F

pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
label = tensor([[0.4, 1.5, 1.2], [0., 0.1, 2.2]])

loss = F.smooth_l1_loss(pred, label)

print(loss.numpy())

Outputs:

.. testoutput::

[0.5608334]
"""
raise NotImplementedError
# diff = abs(pred - label)
# l2_loss = 0.5 * (diff ** 2)
# l1_loss = diff - 0.5
# mask = diff < 1
# loss = where(mask, l2_loss, l1_loss)
# return loss.mean()

+ 30
- 31
imperative/python/megengine/functional/math.py View File

@@ -21,47 +21,26 @@ from .elemwise import clamp, exp, log, log1p
from .tensor import remove_axis, reshape

__all__ = [
"all", # TODO
"all_close", # TODO
"any", # TODO
"argmax",
"argmin",
"argsort",
"isinf",
"isnan", # TODO
"isnan",
"max",
"mean",
"median", # TODO
"min",
"norm",
"normalize",
"prod",
"sign", # TODO
"sign",
"sort",
"std",
"sum",
"topk",
"unique", # TODO
"var",
]


def all(inp):
raise NotImplementedError


def all_close(inp):
raise NotImplementedError


def any(inp):
raise NotImplementedError


def unique(inp):
raise NotImplementedError


def isnan(inp: Tensor) -> Tensor:
r"""Returns a new tensor representing if each element is NaN or not.

@@ -77,15 +56,14 @@ def isnan(inp: Tensor) -> Tensor:

x = tensor([1, float("nan"), 0])

print(F.isnan(x))
print(F.isnan(x).numpy())

.. testoutput::

Tensor([0 1 0], dtype=uint8)
[False True False]

"""
raise NotImplementedError
# return (inp != inp).astype("uint8")
return inp != inp


def isinf(inp: Tensor) -> Tensor:
@@ -103,18 +81,39 @@ def isinf(inp: Tensor) -> Tensor:

x = tensor([1, float("inf"), 0])

print(F.isinf(x))
print(F.isinf(x).numpy())

.. testoutput::

Tensor([0 1 0], dtype=uint8)
[False True False]

"""
return (abs(inp).astype("float32") == float("inf")).astype("uint8")
return abs(inp).astype("float32") == float("inf")


def sign(inp: Tensor):
raise NotImplementedError
r"""Returns sign of each element in the input tensor.

:param: inp
:return: a sign tensor.

Examples:

.. testcode::

from megengine import tensor
import megengine.functional as F

x = tensor([1, -1, 0])

print(F.sign(x).numpy())

.. testoutput::

[ 1 -1 0]

"""
return (inp > 0).astype(inp.dtype) - (inp < 0).astype(inp.dtype)


def sum(


+ 5
- 5
imperative/python/megengine/functional/nn.py View File

@@ -623,7 +623,7 @@ def batch_norm2d(
Default: True

"""
from .tensor import expand_dims, squeeze, broadcast
from .tensor import add_axis, remove_axis, broadcast

def full(value):
C = data.shape[1]
@@ -633,7 +633,7 @@ def batch_norm2d(
def expand_or_full(x, value):
if x is None:
return full(value)
return expand_dims(x, [0, 2, 3])
return add_axis(x, [0, 2, 3])

def make_full_if_none(x, value):
if x is None:
@@ -1229,14 +1229,14 @@ def interpolate(
return ret


def dropout(inp: Tensor, drop_prob: float, rescale: bool = True) -> Tensor:
def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
"""
Returns a new tensor where each of the elements are randomly set to zero
with probability P = ``drop_prob``. Optionally rescale the output tensor.

:param inp: The input tensor
:param drop_prob: The probability to drop (set to zero) a single element
:param rescale: The default behavior of ``dropout`` during training is to rescale the output,
:param training: The default behavior of ``dropout`` during training is to rescale the output,
then it can be replaced by an :class:`~.Identity` during inference, default to True.
:return: The output tensor

@@ -1266,7 +1266,7 @@ def dropout(inp: Tensor, drop_prob: float, rescale: bool = True) -> Tensor:
rv = uniform(inp.shape)
mask = rv > drop_prob
inp *= mask.astype(inp.dtype)
if rescale:
if training:
inp *= 1 / (1 - drop_prob)
return inp



+ 35
- 23
imperative/python/megengine/functional/tensor.py View File

@@ -14,6 +14,7 @@ from typing import Iterable, List, Optional, Sequence, Tuple, Union
import numpy as np

from ..core._imperative_rt import CompNode
from ..core._wrap import device as as_device
from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
@@ -30,31 +31,32 @@ from ..tensor import Tensor
from .elemwise import ceil

__all__ = [
"add_axis", # expand_dims
"add_axis",
"arange",
"broadcast",
"concat",
"cond_take",
"dimshuffle", # transpose, permute
"dimshuffle",
"expand_dims",
"eye",
"full",
"full_like",
"gather",
"eye",
"linspace",
"ones",
"ones_like",
"remove_axis", # squeeze
"param_pack_concat",
"param_pack_split",
"reshape",
"remove_axis",
"split",
"squeeze",
"stack",
"reshape",
"scatter",
"transpose",
"where",
"zeros",
"zeros_like",
"param_pack_split",
"param_pack_concat",
]


@@ -97,6 +99,8 @@ def eye(n: int, *, dtype=None, device: Optional[CompNode] = None) -> Tensor:


def full(shape, value, dtype="float32", device=None):
if isinstance(shape, int):
shape = (shape,)
if device is None:
device = get_default_device()
(x,) = Const(value, dtype=dtype, device=device)(
@@ -196,16 +200,13 @@ def broadcast(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor:
return result


def concat(
inps: Iterable[Tensor], axis: int = 0, device: Optional[CompNode] = None,
) -> Tensor:
def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor:
r"""
Concat some tensors

:param inps: Input tensors to concat
:param axis: the dimension over which the tensors are concatenated. Default: 0
:param device: The comp node output on. Default: None
:param comp_graph: The graph in which output is. Default: None
:return: The output tensor

Examples:
@@ -235,7 +236,9 @@ def concat(
return inps[0]

dtype = dtype_promotion(inps)
device = get_device(inps)
if device is None:
device = get_device(inps)
device = as_device(device)

def convert(x):
return convert_single_value(x, inps, dtype=dtype)
@@ -245,12 +248,13 @@ def concat(
return result


def stack(inps, axis=0):
def stack(inps, axis=0, device=None):
"""Concats a sequence of tensors along a new axis.
The input tensors must have the same shape.

:param inps: The input tensors.
:param axis: Which axis will be concatenated.
:param device: The comp node output on. Default: None
:return: The output concatenated tensor.

Examples:
@@ -283,7 +287,7 @@ def stack(inps, axis=0):
raise ValueError("All input tensors must have the same shape")

inps = [add_axis(inp, axis=axis) for inp in inps]
return concat(inps, axis=axis)
return concat(inps, axis=axis, device=device)


def split(inp, nsplits_or_sections, axis=0):
@@ -609,7 +613,10 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:

def cond_take(mask: Tensor, x: Tensor) -> Tensor:
r"""
Take elements from data if specific condition is satisfied on mask. This operator has two outputs: the first is the elements taken, and the second is the indices corresponding to those elements; they are both 1-dimensional. High-dimension input would first be flattened.
Take elements from data if specific condition is satisfied on mask.
This operator has two outputs: the first is the elements taken,
and the second is the indices corresponding to those elements;
they are both 1-dimensional. High-dimension input would first be flattened.

:param mask: condition param; must be the same shape with data
:param x: input tensor from which to take elements
@@ -692,6 +699,9 @@ def dimshuffle(inp: Tensor, pattern: Iterable[int]) -> Tensor:
return result


transpose = dimshuffle


def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
r"""
Reshape a tensor to given target shape; total number of logical elements must
@@ -748,9 +758,6 @@ def reshape(inp: Tensor, target_shape: Iterable[int]) -> Tensor:
return x


transpose = dimshuffle


AxisAddRemove = builtin.AxisAddRemove
AxisDesc = AxisAddRemove.AxisDesc

@@ -803,12 +810,14 @@ def add_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
expand_dims = add_axis


def remove_axis(inp: Tensor, axis: Union[int, Sequence[int]]) -> Tensor:
def remove_axis(
inp: Tensor, axis: Optional[Union[int, Sequence[int]]] = None
) -> Tensor:
r"""
Remove dimension of shape 1.

:param inp: Input tensor
:param axis: Place of axis to be removed
:param axis: Place of axis to be removed, if None, all axis=1 will be removed. Default: None
:return: The output tensor

Examples:
@@ -897,8 +906,8 @@ def linspace(


def arange(
start: Union[int, float, Tensor],
end: Union[int, float, Tensor],
start: Union[int, float, Tensor] = 0,
end: Optional[Union[int, float, Tensor]] = None,
step: Union[int, float, Tensor] = 1,
dtype="float32",
device: Optional[CompNode] = None,
@@ -919,7 +928,7 @@ def arange(
import numpy as np
import megengine.functional as F

a = F.arange(1, 5, 1)
a = F.arange(5)
print(a.numpy())

.. testoutput::
@@ -927,6 +936,9 @@ def arange(
[1. 2. 3. 4.]

"""
if end is None:
start, end = 0, start

if isinstance(start, Tensor):
start = start.astype("float32")
if isinstance(end, Tensor):


+ 8
- 0
imperative/python/megengine/jit/__init__.py View File

@@ -1,2 +1,10 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .sublinear_memory_config import SublinearMemoryConfig
from .tracing import exclude_from_trace, trace

+ 0
- 1
imperative/python/megengine/jit/sublinear_memory_config.py View File

@@ -6,7 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ..device import get_device_count




+ 8
- 0
imperative/python/megengine/jit/tracing.py View File

@@ -1,3 +1,11 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import contextlib
import functools


+ 3
- 3
imperative/python/megengine/module/init.py View File

@@ -13,7 +13,7 @@ from typing import Optional, Tuple, Union
import numpy as np

from ..functional import full
from ..random import gaussian, uniform
from ..random import normal, uniform
from ..tensor import Tensor


@@ -50,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
:param a: Lower bound of the sampling interval
:param b: Upper bound of the sampling interval
"""
tensor._reset(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype))
tensor._reset(uniform(size=tensor.shape, low=a, high=b).astype(tensor.dtype))


def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
@@ -61,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
:param mean: The mean of the normal distribution
:param std: The standard deviation of the normal distribution
"""
tensor._reset(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype))
tensor._reset(normal(size=tensor.shape, mean=mean, std=std).astype(tensor.dtype))


def calculate_gain(


+ 2
- 2
imperative/python/megengine/random/__init__.py View File

@@ -6,8 +6,8 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .distribution import gaussian, uniform
from .rng import manual_seed
from .distribution import normal, uniform
from .rng import seed

# pylint: disable=undefined-variable
del distribution, rng # type: ignore[name-defined]

+ 19
- 11
imperative/python/megengine/random/distribution.py View File

@@ -15,13 +15,15 @@ from ..core.tensor import utils
from ..core.tensor.core import apply
from .rng import _random_seed_generator

__all__ = ["gaussian", "uniform"]
__all__ = ["normal", "uniform"]


def gaussian(shape: Iterable[int], mean: float = 0, std: float = 1,) -> Tensor:
def normal(
mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
) -> Tensor:
r"""Random variable with Gaussian distribution $N(\mu, \sigma)$

:param shape: Output tensor shape
:param size: Output tensor size
:param mean: The mean or expectation of the distribution
:param std: The standard deviation of the distribution (variance = $\sigma ^ 2$)
:return: The output tensor
@@ -33,7 +35,7 @@ def gaussian(shape: Iterable[int], mean: float = 0, std: float = 1,) -> Tensor:
import megengine as mge
import megengine.random as rand

x = rand.gaussian((2, 2), mean=0, std=1)
x = rand.normal(mean=0, std=1, size=(2, 2))
print(x.numpy())

.. testoutput::
@@ -43,17 +45,21 @@ def gaussian(shape: Iterable[int], mean: float = 0, std: float = 1,) -> Tensor:
[-1.4939808 -1.5824696 ]]

"""
if size is None:
size = (1,)
seed = _random_seed_generator().__next__()
op = GaussianRNG(seed=seed, mean=mean, std=std)
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)
size = Tensor(size, dtype="int32")
(output,) = apply(op, size)
return output


def uniform(shape: Iterable[int], low: float = 0, high: float = 1,) -> Tensor:
def uniform(
low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
) -> Tensor:
r"""Random variable with uniform distribution $U(0, 1)$

:param shape: Output tensor shape
:param size: Output tensor size
:param low: Lower range
:param high: Upper range
:return: The output tensor
@@ -65,7 +71,7 @@ def uniform(shape: Iterable[int], low: float = 0, high: float = 1,) -> Tensor:
import megengine as mge
import megengine.random as rand

x = rand.uniform((2, 2))
x = rand.uniform(size=(2, 2))
print(x.numpy())

.. testoutput::
@@ -77,9 +83,11 @@ def uniform(shape: Iterable[int], low: float = 0, high: float = 1,) -> Tensor:
"""
assert low < high, "Uniform is not defined when low >= high"

if size is None:
size = (1,)
seed = _random_seed_generator().__next__()
op = UniformRNG(seed=seed)
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)
size = Tensor(size, dtype="int32")
(output,) = apply(op, size)

return low + (high - low) * output

+ 2
- 2
imperative/python/megengine/random/rng.py View File

@@ -17,11 +17,11 @@ def _random_seed_generator():
if _rng is None:
from ..distributed.group import get_rank

manual_seed(seed=int(time.time()) + get_rank())
seed(seed=int(time.time()) + get_rank())
while True:
yield _rng.random_raw()


def manual_seed(seed: int):
def seed(seed: int):
global _rng # pylint: disable=global-statement
_rng = MT19937(seed=seed)

+ 10
- 4
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -55,14 +55,20 @@ def test_clamp():
assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0))


# def test_isnan():
# for case in [[1, float("nan"), 0]]:
# assertTensorClose(F.isnan(tensor(case)), np.isnan(case).astype("uint8"))
def test_isnan():
for case in [[1, float("nan"), 0]]:
assertTensorClose(F.isnan(tensor(case)).numpy(), np.isnan(case))


def test_isinf():
for case in [[1, float("inf"), 0]]:
assertTensorClose(F.isinf(tensor(case)).numpy(), np.isinf(case).astype("uint8"))
assertTensorClose(F.isinf(tensor(case)).numpy(), np.isinf(case))


def test_sign():
for case in [[1, -1, 0]]:
x = tensor(case)
assertTensorClose(F.sign(x).numpy(), np.sign(case).astype(x.dtype))


def test_cosh():


+ 8
- 0
imperative/python/test/unit/functional/test_tensor.py View File

@@ -110,6 +110,14 @@ def test_concat():
opr_test(cases, run, ref_fn=lambda x, y: np.concatenate([x, y]))


def test_concat_device():
data1 = tensor(np.random.random((3, 2, 2)).astype("float32"), device="cpu0")
data2 = tensor(np.random.random((2, 2, 2)).astype("float32"), device="cpu1")

out = F.concat([data1, data2], device="cpu0")
assert str(out.device).split(":")[0] == "cpu0"


def test_stack():
data1 = np.random.random((3, 2, 2)).astype("float32")
data2 = np.random.random((3, 2, 2)).astype("float32")


Loading…
Cancel
Save