GitOrigin-RevId: 7ed0447bfe
release-1.7
@@ -225,7 +225,7 @@ def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor: | |||||
assert inp.size > 0, "size needs to be greater than 0" | assert inp.size > 0, "size needs to be greater than 0" | ||||
op = ShuffleRNG(seed=seed, handle=handle) | op = ShuffleRNG(seed=seed, handle=handle) | ||||
output, _ = apply(op, inp) | output, _ = apply(op, inp) | ||||
inp._reset(output) | |||||
return output | |||||
class RNG: | class RNG: | ||||
@@ -554,12 +554,15 @@ class RNG: | |||||
_seed = self._seed() if callable(self._seed) else self._seed | _seed = self._seed() if callable(self._seed) else self._seed | ||||
return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) | return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) | ||||
def permutation(self, n: int, *, dtype: str = "int32"): | |||||
r"""Generates a random permutation of integers from :math:`0` to :math:`n - 1`. | |||||
def permutation(self, n: Union[int, Tensor], *, dtype: str = "int32"): | |||||
r"""Randomly permute a sequence, or return a permuted range. | |||||
If ``n`` is a multi-dimensional tensor, it is only shuffled along its first index. | |||||
Args: | Args: | ||||
n: the upper bound. Must be larger than 0. | |||||
dtype: the output data type. int32, int16 and float32 are supported. Default: int32 | |||||
n: If ``n`` is an integer, return a random permutation of the integers from :math:`0` to :math:`n - 1`. | |||||
If ``n`` is a tensor, make a copy and shuffle the elements randomly. | |||||
dtype: the output data type when ``n`` is an integer. | |||||
int32, int16 and float32 are supported. Default: int32 | |||||
Returns: | Returns: | ||||
the output tensor. | the output tensor. | ||||
@@ -568,13 +571,18 @@ class RNG: | |||||
.. testcode:: | .. testcode:: | ||||
import numpy as np | |||||
import megengine as mge | import megengine as mge | ||||
import megengine.random as rand | import megengine.random as rand | ||||
x = rand.permutation(n=10, dtype="int32") | |||||
x = rand.permutation(10, dtype="int32") | |||||
print(x.numpy()) | |||||
x = rand.permutation(10, dtype="float32") | |||||
print(x.numpy()) | print(x.numpy()) | ||||
x = rand.permutation(n=10, dtype="float32") | |||||
x = mge.tensor(np.arange(18)).reshape(6,3) | |||||
x = rand.permutation(x) | |||||
print(x.numpy()) | print(x.numpy()) | ||||
Outputs: | Outputs: | ||||
@@ -584,11 +592,20 @@ class RNG: | |||||
[4 5 0 7 3 8 6 1 9 2] | [4 5 0 7 3 8 6 1 9 2] | ||||
[3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] | [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] | ||||
[[12 13 14] | |||||
[ 3 4 5] | |||||
[15 16 17] | |||||
[ 0 1 2] | |||||
[ 9 10 11] | |||||
[ 6 7 8]] | |||||
""" | """ | ||||
_seed = self._seed() if callable(self._seed) else self._seed | _seed = self._seed() if callable(self._seed) else self._seed | ||||
return _permutation( | |||||
n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype | |||||
) | |||||
if isinstance(n, int): | |||||
return _permutation( | |||||
n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype | |||||
) | |||||
assert isinstance(n, Tensor) | |||||
return _shuffle(inp=n, seed=_seed, handle=self._handle) | |||||
def shuffle(self, inp: Tensor): | def shuffle(self, inp: Tensor): | ||||
r"""Modify a sequence in-place by shuffling its contents. | r"""Modify a sequence in-place by shuffling its contents. | ||||
@@ -627,7 +644,7 @@ class RNG: | |||||
[ 6. 7. 8.]] | [ 6. 7. 8.]] | ||||
""" | """ | ||||
_seed = self._seed() if callable(self._seed) else self._seed | _seed = self._seed() if callable(self._seed) else self._seed | ||||
_shuffle(inp=inp, seed=_seed, handle=self._handle) | |||||
inp._reset(_shuffle(inp=inp, seed=_seed, handle=self._handle)) | |||||
def __del__(self): | def __del__(self): | ||||
if self._handle != 0: | if self._handle != 0: | ||||
@@ -28,6 +28,7 @@ from megengine.core.ops.builtin import ( | |||||
UniformRNG, | UniformRNG, | ||||
) | ) | ||||
from megengine.device import get_device_count | from megengine.device import get_device_count | ||||
from megengine.jit import trace | |||||
from megengine.random import RNG | from megengine.random import RNG | ||||
from megengine.random import seed as set_global_seed | from megengine.random import seed as set_global_seed | ||||
from megengine.random import uniform | from megengine.random import uniform | ||||
@@ -370,21 +371,22 @@ def test_PoissonRNG(): | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | get_device_count("xpu") <= 1, reason="xpu counts need > 1", | ||||
) | ) | ||||
def test_PermutationRNG(): | |||||
@pytest.mark.parametrize("symbolic", [True, False]) | |||||
def test_PermutationRNG(symbolic): | |||||
m1 = RNG(seed=111, device="xpu0") | m1 = RNG(seed=111, device="xpu0") | ||||
m2 = RNG(seed=111, device="xpu1") | m2 = RNG(seed=111, device="xpu1") | ||||
m3 = RNG(seed=222, device="xpu0") | m3 = RNG(seed=222, device="xpu0") | ||||
out1 = m1.permutation(n=1000) | |||||
out1 = m1.permutation(1000) | |||||
out1_ = m1.uniform(size=(1000,)) | out1_ = m1.uniform(size=(1000,)) | ||||
out2 = m2.permutation(n=1000) | |||||
out3 = m3.permutation(n=1000) | |||||
out2 = m2.permutation(1000) | |||||
out3 = m3.permutation(1000) | |||||
np.testing.assert_equal(out1.numpy(), out2.numpy()) | np.testing.assert_equal(out1.numpy(), out2.numpy()) | ||||
assert out1.device == "xpu0" and out2.device == "xpu1" | assert out1.device == "xpu0" and out2.device == "xpu1" | ||||
assert not (out1.numpy() == out3.numpy()).all() | assert not (out1.numpy() == out3.numpy()).all() | ||||
assert not (out1.numpy() == out1_.numpy()).all() | assert not (out1.numpy() == out1_.numpy()).all() | ||||
out = m1.permutation(n=1000) | |||||
out = m1.permutation(1000) | |||||
out_shp = out.shape | out_shp = out.shape | ||||
if isinstance(out_shp, tuple): | if isinstance(out_shp, tuple): | ||||
assert out_shp == (1000,) | assert out_shp == (1000,) | ||||
@@ -397,6 +399,24 @@ def test_PermutationRNG(): | |||||
assert sum_result(out, lambda x: x) < 500 | assert sum_result(out, lambda x: x) < 500 | ||||
assert sum_result(out, np.sort) == 1000 | assert sum_result(out, np.sort) == 1000 | ||||
def func(): | |||||
out = m1.permutation(Tensor(7)) | |||||
out_shp = out.shape | |||||
if isinstance(out_shp, tuple): | |||||
assert out_shp == (1,) | |||||
else: | |||||
assert all(out.shape.numpy() == np.array([1])) | |||||
n, m = 6, 3 | |||||
out = m1.permutation(Tensor(np.arange(n * m), dtype="float32").reshape(n, m)) | |||||
out_shp = out.shape | |||||
if isinstance(out_shp, tuple): | |||||
assert out_shp == (n, m) | |||||
else: | |||||
assert all(out.shape.numpy() == np.array([n, m])) | |||||
func = trace(symbolic=symbolic)(func) | |||||
func() | |||||
@pytest.mark.skipif( | @pytest.mark.skipif( | ||||
get_device_count("xpu") <= 1, reason="xpu counts need > 1", | get_device_count("xpu") <= 1, reason="xpu counts need > 1", | ||||
@@ -214,8 +214,12 @@ ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param, | |||||
const OperatorNodeConfig& config) | const OperatorNodeConfig& config) | ||||
: Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { | : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) { | ||||
add_input({data}); | add_input({data}); | ||||
add_output(None)->dtype(data->dtype()); | |||||
add_output(None)->dtype(dtype::Int32{}); | |||||
add_output(None) | |||||
->dtype(data->dtype()) | |||||
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
add_output(None) | |||||
->dtype(dtype::Int32{}) | |||||
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); | |||||
cg::add_workspace_output(this); | cg::add_workspace_output(this); | ||||
add_equivalence_component<ScalarHash<void*>>(this); | add_equivalence_component<ScalarHash<void*>>(this); | ||||
} | } | ||||
@@ -266,12 +270,27 @@ void ShuffleRNGForward::add_input_layout_constraint() { | |||||
}; | }; | ||||
void ShuffleRNGForward::scn_do_execute() { | void ShuffleRNGForward::scn_do_execute() { | ||||
auto&& ret = output(0); | |||||
if (ret->layout().is_empty()) { | |||||
mgb_assert(ret->dev_tensor().empty()); | |||||
return; | |||||
} | |||||
m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), | m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(), | ||||
output(0)->dev_tensor().as_megdnn(), | output(0)->dev_tensor().as_megdnn(), | ||||
output(1)->dev_tensor().as_megdnn(), | output(1)->dev_tensor().as_megdnn(), | ||||
get_megdnn_workspace_from_var(output(2))); | get_megdnn_workspace_from_var(output(2))); | ||||
} | } | ||||
cg::OperatorNodeBase::NodeProp* ShuffleRNGForward::do_make_node_prop() const { | |||||
auto prop = Super::do_make_node_prop(); | |||||
prop->add_flag(NodeProp::Flag::IMPURE_FUNC); | |||||
for (auto i : input()) { | |||||
prop->add_dep_type_existing_var(i, | |||||
NodeProp::DepType::VALUE_ALLOW_EMPTY); | |||||
} | |||||
return prop; | |||||
} | |||||
#if MGB_ENABLE_GRAD | #if MGB_ENABLE_GRAD | ||||
MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { | MGB_IMPL_OPR_GRAD(ShuffleRNGForward) { | ||||
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | ||||