GitOrigin-RevId: ee5984c52d
release-1.4
@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
def _remove_axis(inp: Tensor, axis) -> Tensor:
    def get_axes():
        if axis is None:
            return [i for i, s in enumerate(inp.shape) if s == 1]
            shp = inp.shape
            return [i for i, s in enumerate(shp) if s == 1]
        try:
            return [int(axis)]
        except (TypeError, ValueError):
@@ -6,9 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from typing import List, Optional, Tuple
from ..device import set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server
@@ -156,6 +158,7 @@ def init_process_group(
    WORLD.reset(list(range(world_size)))
    set_default_device("{}{}".format(device_type, device))
    seed(int(time.time()) + rank)
def is_distributed() -> bool:
@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .distribution import normal, uniform
from .rng import seed
from .rng import RNG, seed
# pylint: disable=undefined-variable
del distribution, rng  # type: ignore[name-defined]
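With `RNG` re-exported from `megengine.random` alongside `seed`, seeded device-local generators become part of the public API. A minimal usage sketch (the device string is an assumption; any name accepted by `CompNode` works):

    import megengine.random as rand

    rand.seed(42)                            # reseed the global default generator
    gen = rand.RNG(seed=42, device="xpu0")   # independent generator bound to one device
    x = gen.uniform(size=(2, 3))             # draws come from gen's own handle
    y = gen.normal(mean=0.0, std=1.0, size=(2, 3))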
@@ -9,11 +9,8 @@
from typing import Iterable, Optional
from .. import Tensor
from ..core._imperative_rt import invoke_op
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from .rng import _random_seed_generator
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from .rng import _normal, _uniform
__all__ = ["normal", "uniform"]
@@ -48,14 +45,14 @@ def normal(
        [-1.4939808 -1.5824696 ]]
    """
    if size is None:
        size = (1,)
    op = GaussianRNG(mean, std)
    _ref = Tensor([], dtype="int32")
    shape = utils.astensor1d(size, _ref, dtype="int32")
    shape = Tensor(shape, dtype="int32")
    (output,) = apply(op, shape)
    return output
    return _normal(
        mean=mean,
        std=std,
        size=size,
        seed=_get_global_rng_seed(),
        device=None,
        handle=0,
    )
def uniform(
@@ -88,14 +85,11 @@ def uniform(
        [0.09365904 0.62957656]]
    """
    assert low < high, "Uniform is not defined when low >= high"
    if size is None:
        size = (1,)
    op = UniformRNG()
    _ref = Tensor([], dtype="int32")
    shape = utils.astensor1d(size, _ref, dtype="int32")
    shape = Tensor(shape, dtype="int32")
    (output,) = apply(op, shape)
    return low + (high - low) * output
    return _uniform(
        low=low,
        high=high,
        size=size,
        seed=_get_global_rng_seed(),
        device=None,
        handle=0,
    )
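After this change `normal` and `uniform` are thin wrappers over the shared `_normal`/`_uniform` kernels: `handle=0` selects the per-device default handle, and the seed baked into each op is whatever `get_global_rng_seed()` returns at call time. A rough sanity sketch under that assumption (the tolerances are illustrative, not part of the API):

    import megengine.random as rand

    rand.seed(123)
    a = rand.uniform(low=0, high=1, size=(1000,))
    b = rand.normal(mean=0, std=1, size=(1000,))
    assert abs(float(a.mean().numpy()) - 0.5) < 0.1   # roughly centered in [0, 1)
    assert abs(float(b.mean().numpy())) < 0.2         # roughly zero-mean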
@@ -7,17 +7,94 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from typing import Iterable, Optional
from numpy.random import MT19937
from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from ..device import get_default_device
_rng = None
def _random_seed_generator():
    if _rng is None:
        from ..distributed.group import get_rank
def _normal(
    mean: float,
    std: float,
    size: Optional[Iterable[int]],
    seed: int,
    device: str,
    handle: int,
) -> Tensor:
    if size is None:
        size = (1,)
    op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
    _ref = Tensor([], dtype="int32", device=device)
    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
    (output,) = apply(op, shape)
    return output
def _uniform(
    low: float,
    high: float,
    size: Optional[Iterable[int]],
    seed: int,
    device: str,
    handle: int,
) -> Tensor:
    assert low < high, "Uniform is not defined when low >= high"
    if size is None:
        size = (1,)
    op = UniformRNG(seed=seed, handle=handle)
    _ref = Tensor([], dtype="int32", device=device)
    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
    (output,) = apply(op, shape)
    return low + (high - low) * output
class RNG:
    def __init__(self, seed=0, device=None):
        self.seed = seed
        self.device = device if device else get_default_device()
        self.handle = _new_rng_handle(self.device, self.seed)
    def uniform(
        self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
    ):
        return _uniform(
            low=low,
            high=high,
            size=size,
            seed=self.seed,
            device=self.device,
            handle=self.handle,
        )
        seed(seed=int(time.time()) + get_rank())
    def normal(
        self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
    ):
        return _normal(
            mean=mean,
            std=std,
            size=size,
            seed=self.seed,
            device=self.device,
            handle=self.handle,
        )
    def __del__(self):
        _delete_rng_handle(self.handle)
def _random_seed_generator():
    assert _rng
    while True:
        yield _rng.random_raw()
@@ -25,3 +102,7 @@ def _random_seed_generator():
def seed(seed: int):
    global _rng  # pylint: disable=global-statement
    _rng = MT19937(seed=seed)
    _set_global_rng_seed(seed)
seed(int(time.time()))
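Each `RNG` instance owns a native handle created by `new_rng_handle(device, seed)` and released in `__del__`; two instances with the same seed produce the same stream (even on different devices), while consecutive draws from one instance advance its state. A sketch of that contract, mirroring the tests added further down:

    from megengine.random import RNG

    g1 = RNG(seed=7, device="xpu0")
    g2 = RNG(seed=7, device="xpu0")
    a = g1.uniform(size=(16,)).numpy()
    b = g2.uniform(size=(16,)).numpy()
    c = g1.uniform(size=(16,)).numpy()
    assert (a == b).all()        # same seed -> same stream
    assert not (a == c).all()    # repeated draws differ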
@@ -10,7 +10,10 @@
 */
#include "./ops.h"
#include "./helper.h"
#include "./tensor.h"
#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
@@ -491,21 +494,15 @@ void init_ops(py::module m) {
    _init_py_op_base(m);
    INIT_ALL_OP(m)
    m.def("new_rng_handle", &RNGMixin::new_handle);
    // FIXME: RNG op might execute after handle released due to async dispatch,
    // which would cause memory leak or use-after-free
    m.def("delete_rng_handle", &RNGMixin::delete_handle);
    m.def("set_rng_seed", &set_rng_seed);
    py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
        .def(py::init<>())
        .def(py::init<mgb::CompNode>())
        .def(py::init<RNGMixin::Handle>());
    py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
        .def(py::init<>())
        .def(py::init<mgb::CompNode>())
        .def(py::init<float, float>())
        .def(py::init<float, float, mgb::CompNode>())
        .def(py::init<float, float, RNGMixin::Handle>());
    m.def("new_rng_handle", &rng::new_handle);
    m.def("delete_rng_handle", [](size_t handle){
        // RNG op might execute after handle released due to async dispatch, so
        // we need sync before delete a handle to avoid memory leak or use-after-free
        python::interpreter_for_py->sync();
        mgb::CompNode::sync_all();
        py_task_q.wait_all_task_finish();
        rng::delete_handle(handle);
    }, py::call_guard<py::gil_scoped_release>());
    m.def("set_global_rng_seed", &rng::set_global_rng_seed);
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
}
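Because ops are dispatched asynchronously, the new `delete_rng_handle` binding drains the interpreter, all comp nodes, and the Python task queue before freeing, which makes manual handle management from Python safe even with ops in flight. A sketch through the low-level API (device string and seed here are arbitrary choices):

    from megengine import tensor
    from megengine.core._imperative_rt import CompNode
    from megengine.core._imperative_rt.core2 import apply
    from megengine.core._imperative_rt.ops import new_rng_handle, delete_rng_handle
    from megengine.core.ops.builtin import UniformRNG

    cn = CompNode("xpu0")
    h = new_rng_handle(cn, 1234)
    op = UniformRNG(seed=1234, handle=h)   # the op's seed must match the handle's seed
    (out,) = apply(op, tensor((8, 8), dtype="int32"))
    out.numpy()                            # force the async op to complete
    delete_rng_handle(h)                   # the binding itself syncs again before freeing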
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine
from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import (
    delete_rng_handle,
    get_global_rng_seed,
    new_rng_handle,
)
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.random import RNG
from megengine.random.rng import _normal, _uniform
def test_gaussian_op():
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = tensor(shape, dtype="int32")
    op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
    assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
    assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
    assert str(output.device) == str(cn)
def test_uniform_op():
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = tensor(shape, dtype="int32")
    op = UniformRNG(seed=get_global_rng_seed())
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = UniformRNG(seed=seed, handle=h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(cn)
def test_UniformRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.uniform(size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.uniform(size=(100,))
    out3 = m3.uniform(size=(100,))
    np.testing.assert_equal(out1.numpy(), out2.numpy())
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()
    low = -234
    high = 123
    out = m1.uniform(low=low, high=high, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40]))
    assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
def test_NormalRNG():
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = m1.normal(size=(100,))
    out1_ = m1.uniform(size=(100,))
    out2 = m2.normal(size=(100,))
    out3 = m3.normal(size=(100,))
    np.testing.assert_equal(out1.numpy(), out2.numpy())
    assert out1.device == "xpu0" and out2.device == "xpu1"
    assert not (out1.numpy() == out3.numpy()).all()
    assert not (out1.numpy() == out1_.numpy()).all()
    mean = -1
    std = 2
    out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
    out_shp = out.shape
    if isinstance(out_shp, tuple):
        assert out_shp == (20, 30, 40)
    else:
        assert all(out.shape.numpy() == np.array([20, 30, 40]))
    assert np.abs(out.mean().numpy() - mean) / std < 0.1
    assert np.abs(np.std(out.numpy()) - std) < 0.1
@@ -1,76 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.core.tensor.core import apply
def test_gaussian_rng():
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = tensor(shape, dtype="int32")
    op = GaussianRNG(1.0, 3.0)
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
    assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    cn = CompNode("xpu1")
    op = GaussianRNG(-1.0, 2.0, cn)
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
    assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
    assert str(output.device) == str(cn)
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = GaussianRNG(3.0, 1.0, h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
    assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
    assert str(output.device) == str(cn)
def test_uniform_rng():
    shape = (
        8,
        9,
        11,
        12,
    )
    shape = tensor(shape, dtype="int32")
    op = UniformRNG()
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(CompNode("xpux"))
    cn = CompNode("xpu1")
    op = UniformRNG(cn)
    (output,) = apply(op, shape)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(cn)
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    op = UniformRNG(h)
    (output,) = apply(op, shape)
    delete_rng_handle(h)
    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
    assert str(output.device) == str(cn)
@@ -2,7 +2,7 @@
 * \file imperative/src/impl/ops/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
@@ -10,23 +10,23 @@
 */
#include "megbrain/imperative/ops/rng.h"
#include <bits/stdint-uintn.h>
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"
//#include "megbrain/common.h"
#include "../op_trait.h"
#include "../dnn_op_helper.h"
namespace mgb {
namespace imperative {
namespace mgb::imperative::rng {
namespace {
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
    using DT = CompNode::DeviceType;
    using Handle = THandle;
    using OpTypeInfo = size_t;
    template <typename... Args>
    Handle new_handle(Args&&... args) {
@@ -38,27 +38,26 @@ public:
        size_t removed = 0;
        if (!is_finalized()) {
            MGB_LOCK_GUARD(m_mtx);
            removed = m_handle2op.erase(handle);
            removed = m_handle2ops.erase(handle);
        }
        static_cast<HandleFactory*>(this)->do_delete_handle(handle);
        return removed;
    }
    template <typename DnnOp>
    auto get_dnn_op(Handle handle, CompNode cn) {
    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
        mgb_assert(!is_finalized());
        DnnOpWithMutex* dnn_op_with_mtx;
        {
            MGB_LOCK_GUARD(m_mtx);
            dnn_op_with_mtx = &m_handle2op[handle];
            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
        }
        auto dnn_handle =
                MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
        DnnOp* dnn_op;
        std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
        bool initialized = false;
        if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
                nullptr) {
        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
        if (dnn_op != nullptr) {
            mgb_assert(dnn_op->handle() == dnn_handle);
            initialized = true;
        } else {
@@ -77,35 +76,30 @@ private:
    struct DnnOpWithMutex {
        std::mutex mtx;
        std::unique_ptr<megdnn::OperatorBase> op;
        DnnOpWithMutex(): op{nullptr} {}
    };
    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_mtx);
        m_handle2op.clear();
        m_handle2ops.clear();
        return {};
    }
    std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
    std::mutex m_mtx;
};
class RNGDnnOpManager final
        : public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
    Handle new_handle(CompNode comp_node, uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::new_handle(comp_node, seed);
    }
    size_t delete_handle(Handle handle) {
        size_t ret = 0;
        {
            MGB_LOCK_GUARD(sm_mtx);
            auto iter = sm_partial2full.find(handle);
            if (iter != sm_partial2full.end()) {
                for (auto&& h : iter->second) {
                    ret += DnnOpManagerBase::delete_handle(h.second);
                }
                sm_partial2full.erase(iter);
            }
        }
        ret += DnnOpManagerBase::delete_handle(handle);
        return ret;
        MGB_LOCK_GUARD(sm_mtx);
        return DnnOpManagerBase::delete_handle(handle);
    }
    Handle do_new_handle(CompNode comp_node, uint64_t seed) {
@@ -118,32 +112,26 @@ public:
    }
    static uint64_t get_seed(Handle handle) {
        if (!handle) { return glob_default_seed; }
        return reinterpret_cast<HandleData*>(handle)->seed;
    }
    static CompNode get_comp_node(Handle handle) {
        mgb_assert(handle, "invalid handle");
        return reinterpret_cast<HandleData*>(handle)->comp_node;
    }
    static Handle get_full_handle(Handle handle, CompNode comp_node) {
        if (get_comp_node(handle).valid()) {
            return handle;
        }
        MGB_LOCK_GUARD(sm_mtx);
        auto&& full = sm_partial2full[handle][comp_node];
        if (!full) {
            full = inst().new_handle(comp_node, get_seed(handle));
        }
        return full;
    }
    static Handle get_default_handle(CompNode comp_node) {
        static Handle glob_partial_handle =
                inst().new_handle(CompNode{}, glob_default_seed);
        if (!comp_node.valid()) {
            return glob_partial_handle;
        mgb_assert(comp_node.valid());
        MGB_LOCK_GUARD(sm_mtx);
        auto&& glob_handle = glob_default_handles[comp_node];
        if (!glob_handle) {
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        } else if (get_seed(glob_handle) != glob_default_seed) {
            inst().DnnOpManagerBase::delete_handle(glob_handle);
            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
        }
        return get_full_handle(glob_partial_handle, comp_node);
        return glob_handle;
    }
    static RNGDnnOpManager& inst() {
@@ -152,9 +140,15 @@ public:
    }
    static void set_glob_default_seed(uint64_t seed) {
        MGB_LOCK_GUARD(sm_mtx);
        glob_default_seed = seed;
    }
    static uint64_t get_glob_default_seed() {
        MGB_LOCK_GUARD(sm_mtx);
        return glob_default_seed;
    }
private:
    struct HandleData {
        CompNode comp_node;
@@ -165,16 +159,13 @@ private:
    MemPool<HandleData> m_handle_pool;
    static std::mutex sm_mtx;
    static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
            sm_partial2full;
    static CompNode::UnorderedMap<Handle> glob_default_handles;
    static uint64_t glob_default_seed;
};
uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
std::unordered_map<RNGDnnOpManager::Handle,
        CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
        RNGDnnOpManager::sm_partial2full;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;
template <typename Op>
struct OpMeth;
@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::UniformRNG;
    static Param make_param(const UniformRNG& rng) {
        return {RNGDnnOpManager::get_seed(rng.handle())};
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu",
                handle_seed, rng.seed);
        return {handle_seed};
    }
};
@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::GaussianRNG;
    static Param make_param(const GaussianRNG& rng) {
        return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        mgb_assert(handle_seed == rng.seed,
                "inconsistent rng seed: rng op: %lu handle: %lu",
                handle_seed, rng.seed);
        return {handle_seed, rng.mean, rng.std};
    }
};
@@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
    auto dest = outputs[0];
    auto cn = dest->comp_node();
    auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
    {
        auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
        mgb_assert(cn == handle_cn,
                "inconsistent comp_node: handle: %s, output: %s",
                cn.to_string().c_str(), handle_cn.to_string().c_str());
    auto handle = rng.handle;
    if (!handle) {
        handle = RNGDnnOpManager::get_default_handle(cn);
    }
    // retrieve dnn_op from glob cache
    auto dnn_op_thread_safe = RNGDnnOpManager::inst()
            .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
            .get_dnn_op<typename OpMeth<Op>::DnnOp>(
                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
                    cn);
    auto initialized = std::get<0>(dnn_op_thread_safe);
    auto dnn_op = std::get<1>(dnn_op_thread_safe);
    if (initialized) {
        auto handle_seed = RNGDnnOpManager::get_seed(handle);
        mgb_assert(dnn_op->param().seed == handle_seed,
                "inconsistent rng seed: handle: %zu, dnn_op: %zu",
                "inconsistent rng seed: handle: %lu, dnn_op: %lu",
                handle_seed, dnn_op->param().seed);
    }
    dnn_op->param() = OpMeth<Op>::make_param(rng);
@@ -239,9 +237,12 @@ template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
        const OpDef& op, const SmallVector<TensorPtr>& inputs) {
    LogicalTensorDesc dest;
    dest.comp_node = op.cast_final_safe<Op>().comp_node();
    if (!dest.comp_node.valid())
    auto handle = op.cast_final_safe<Op>().handle;
    if (handle) {
        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
    } else {
        dest.comp_node = inputs[0]->comp_node();
    }
    auto hv = inputs[0]->get_value().proxy_to_default_cpu();
    TensorShape tshape;
@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}
template<typename Op>
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
SymbolVar apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    size_t nr_inp = inputs.size();
    mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually",
            nr_inp);
    auto&& rng = def.cast_final_safe<Op>();
    mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
            rng.dyn_typeinfo()->name,
            nr_inp);
    auto param = OpMeth<Op>::make_param(rng);
    return OpMeth<Op>::OpNode::make(
            inputs[0], param, {rng.comp_node()}).node()->owner_opr();
    OperatorNodeConfig config;
    if (rng.handle) {
        config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
    } else {
        config = {rng.make_name()};
    }
    return OpMeth<Op>::OpNode::make(inputs[0], param, config);
}
template<typename T>
@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
} // anonymous namespace
RNGMixin::RNGMixin(CompNode cn):
    m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
uint64_t RNGMixin::seed() const {
    return RNGDnnOpManager::get_seed(m_handle);
}
CompNode RNGMixin::comp_node() const {
    return RNGDnnOpManager::get_comp_node(m_handle);
}
RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
Handle new_handle(CompNode comp_node, uint64_t seed) {
    return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}
size_t RNGMixin::delete_handle(Handle handle) {
size_t delete_handle(Handle handle) {
    return RNGDnnOpManager::inst().delete_handle(handle);
}
void set_rng_seed(uint64_t seed) {
void set_global_rng_seed(uint64_t seed) {
    RNGDnnOpManager::set_glob_default_seed(seed);
}
uint64_t get_global_rng_seed() {
    return RNGDnnOpManager::get_glob_default_seed();
}
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
    .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
    .fallback(); \
} \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);
REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
} // namespace imperative
} // namespace mgb
} // namespace mgb::imperative::rng
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
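On the C++ side, `make_param` now asserts that the seed recorded on the op equals the seed of its handle, and `get_default_handle` lazily rebuilds the per-device default handle whenever `set_global_rng_seed` has changed the global seed. The Python-visible consequence, sketched here, is that `megengine.random.seed` keeps the global seed and subsequent default-handle ops in step:

    import megengine.random as rand
    from megengine.core._imperative_rt.ops import get_global_rng_seed

    rand.seed(2021)
    assert get_global_rng_seed() == 2021   # this seed is embedded in new default-handle ops
    x = rand.uniform(size=(8,))            # UniformRNG built with seed=get_global_rng_seed()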
@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
    .fallback();
}} // assert_equal
namespace { namespace uniform_rng {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const UniformRNG&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::UniformRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(UniformRNG, UniformRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // uniform_rng
namespace { namespace gaussian_rng {
auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const GaussianRNG&>(def);
    mgb_assert(inputs.size() == 1);
    OperatorNodeConfig config{op.make_name()};
    return opr::GaussianRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(GaussianRNG, GaussianRNG)
    .apply_on_var_node(apply_on_var_node)
    .fallback();
}} // gaussian_rng
namespace { namespace roi_align {
VarNodeArray apply_on_var_node(
        const OpDef& def,
@@ -2,7 +2,7 @@
 * \file imperative/src/include/megbrain/imperative/ops/rng.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
@@ -12,84 +12,15 @@
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb::imperative {
namespace mgb::imperative::rng {
class RNGMixin {
public:
    using Handle = size_t;
using Handle = size_t;
    static Handle new_handle(
            CompNode comp_node={}, uint64_t seed=0);
Handle new_handle(CompNode comp_node, uint64_t seed);
size_t delete_handle(Handle handle);
void set_global_rng_seed(uint64_t seed);
uint64_t get_global_rng_seed();
    static size_t delete_handle(Handle handle);
    Handle handle() const {
        return m_handle;
    }
    uint64_t seed() const;
    CompNode comp_node() const;
protected:
    RNGMixin(Handle handle): m_handle(handle) {}
    RNGMixin(CompNode comp_node);
private:
    Handle m_handle;
};
class GaussianRNG : public OpDefImplBase<GaussianRNG>,
                    public RNGMixin {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    float mean = 1.0f, std = 0.0;
    GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
    GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
        GaussianRNG(comp_node_) { mean = mean_; std = std_; }
    GaussianRNG(float mean_, float std_, Handle handle):
        RNGMixin(handle), mean(mean_), std(std_) {}
    size_t hash() const override {
        XXHash xxhash{};
        auto append = [&xxhash](auto field){
            auto hash_val = HashTrait<decltype(field)>::eval(field);
            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
        };
        append(dyn_typeinfo());
        append(seed());
        append(mean);
        append(std);
        return xxhash.digest();
    }
    bool is_same_st(const Hashable& rhs_) const override {
        auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
        return rhs.seed() == seed()
            && rhs.mean == mean
            && rhs.std == std;
    }
};
class UniformRNG : public OpDefImplBase<UniformRNG>,
                   public RNGMixin {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
    UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
    UniformRNG(Handle handle): RNGMixin(handle) {}
    size_t hash() const override {
        return hash_pair_combine(
                mgb::hash(seed()),
                reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
    }
    bool is_same_st(const Hashable& rhs_) const override {
        auto&& rhs = static_cast<const UniformRNG&>(rhs_);
        return rhs.dyn_typeinfo() == dyn_typeinfo()
            && rhs.seed() == seed();
    }
};
void set_rng_seed(uint64_t seed);
} // namespace mgb::imperative
} // namespace mgb::imperative::rng
@@ -14,6 +14,7 @@
using namespace mgb;
using namespace imperative;
using namespace imperative::rng;
template<typename Op, typename ...Args>
void check_rng_basic(Args&& ...args) {
@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
            {3, 4, 5, 6},
            {2333}})
    for (auto&& cn: {
            CompNode::load("cpu0"),
            CompNode::load("xpu0")})
            CompNode::load("xpu0"),
            CompNode::load("xpu1")})
    {
        auto op = Op::make(std::forward<Args>(args)..., cn);
        Handle h = new_handle(cn, 123);
        auto op = Op::make(std::forward<Args>(args)..., h);
        DeviceTensorND tshape_dev;
        cg::copy_shape_to_tensor_value(tshape_dev, tshape);
        auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
        SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
        auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
        ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
        ASSERT_TRUE(cn == outputs[0]->comp_node());
        // sync before delete handle
        for (auto&& p: outputs) {
            p->get_value();
        }
        delete_handle(h);
    }
}
TEST(TestImperative, UniformRNGBasic) {
    check_rng_basic<UniformRNG>();
    check_rng_basic<UniformRNG>(123);
}
TEST(TestImperative, GaussianRNGBasic) {
    check_rng_basic<GaussianRNG>(2.f, 3.f);
    check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;
def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;
def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
    let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
    let cmpFunction = [{return true;}];
    let extraArguments = (ins
        MgbSizeTAddr:$handle
    );
    let hashFunction = [{
        return mgb::hash_pair_combine(
            mgb::hash($_self.dyn_typeinfo()),
            mgb::hash($_self.handle));
    }];
    let cmpFunction = [{return $0.handle == $1.handle;}];
}
def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
    let extraArguments = (ins
        MgbSizeTAddr:$handle
    );
    let hashFunction = [{
        return mgb::hash_pair_combine(
            mgb::hash($_self.dyn_typeinfo()),
            mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
            mgb::hash_pair_combine(
                mgb::hash($_self.handle),
                mgb::hash_pair_combine(
                    mgb::hash($_self.mean),
                    mgb::hash($_self.std))
            )
        );
    }];
    let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
    let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
}
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {