
feat(imperative/opr): rebase rng refactoring to dev & add python module

GitOrigin-RevId: ee5984c52d
release-1.4
Megvii Engine Team, 4 years ago
parent · commit 0d332cf005
13 changed files with 371 additions and 323 deletions
1. +2 -1    imperative/python/megengine/core/tensor/array_method.py
2. +3 -0    imperative/python/megengine/distributed/group.py
3. +1 -1    imperative/python/megengine/random/__init__.py
4. +18 -24  imperative/python/megengine/random/distribution.py
5. +85 -4   imperative/python/megengine/random/rng.py
6. +14 -17  imperative/python/src/ops.cpp
7. +121 -0  imperative/python/test/unit/random/test_rng.py
8. +0 -76   imperative/python/test/unit/test_rng.py
9. +84 -84  imperative/src/impl/ops/rng.cpp
10. +0 -28   imperative/src/impl/ops/specializations.cpp
11. +9 -78   imperative/src/include/megbrain/imperative/ops/rng.h
12. +14 -6   imperative/src/test/rng.cpp
13. +20 -4   src/core/include/megbrain/ir/ops.td
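
The user-facing piece of this commit is the RNG class now exported from megengine.random (see the rng.py and __init__.py hunks below). A minimal usage sketch, modeled on the new tests in imperative/python/test/unit/random/test_rng.py; the device strings are illustrative and depend on what devices are available:

    import numpy as np
    from megengine.random import RNG, seed

    # Generators created with the same seed draw identical sequences, even on
    # different devices; a different seed gives an independent stream.
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")

    out1 = m1.uniform(low=0, high=1, size=(100,))
    out2 = m2.uniform(low=0, high=1, size=(100,))
    out3 = m3.uniform(low=0, high=1, size=(100,))
    np.testing.assert_equal(out1.numpy(), out2.numpy())
    assert not (out1.numpy() == out3.numpy()).all()

    # normal() works the same way; each RNG owns a per-device handle that is
    # released when the object is collected (__del__ calls delete_rng_handle).
    sample = m1.normal(mean=-1, std=2, size=(20, 30, 40))

    # The module-level seed() still reseeds the global generator used by
    # megengine.random.uniform / normal.
    seed(42)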

+2 -1  imperative/python/megengine/core/tensor/array_method.py

@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
def _remove_axis(inp: Tensor, axis) -> Tensor: def _remove_axis(inp: Tensor, axis) -> Tensor:
def get_axes(): def get_axes():
if axis is None: if axis is None:
return [i for i, s in enumerate(inp.shape) if s == 1]
shp = inp.shape
return [i for i, s in enumerate(shp) if s == 1]
try: try:
return [int(axis)] return [int(axis)]
except (TypeError, ValueError): except (TypeError, ValueError):


+3 -0  imperative/python/megengine/distributed/group.py

@@ -6,9 +6,11 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from typing import List, Optional, Tuple from typing import List, Optional, Tuple


from ..device import set_default_device, what_is_xpu from ..device import set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server from .server import Client, Server




@@ -156,6 +158,7 @@ def init_process_group(
WORLD.reset(list(range(world_size))) WORLD.reset(list(range(world_size)))


set_default_device("{}{}".format(device_type, device)) set_default_device("{}{}".format(device_type, device))
seed(int(time.time()) + rank)




def is_distributed() -> bool: def is_distributed() -> bool:


+1 -1  imperative/python/megengine/random/__init__.py

@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .distribution import normal, uniform from .distribution import normal, uniform
from .rng import seed
from .rng import RNG, seed


# pylint: disable=undefined-variable # pylint: disable=undefined-variable
del distribution, rng # type: ignore[name-defined] del distribution, rng # type: ignore[name-defined]

+18 -24  imperative/python/megengine/random/distribution.py

@@ -9,11 +9,8 @@
from typing import Iterable, Optional from typing import Iterable, Optional


from .. import Tensor from .. import Tensor
from ..core._imperative_rt import invoke_op
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from .rng import _random_seed_generator
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from .rng import _normal, _uniform


__all__ = ["normal", "uniform"] __all__ = ["normal", "uniform"]


@@ -48,14 +45,14 @@ def normal(
[-1.4939808 -1.5824696 ]] [-1.4939808 -1.5824696 ]]


""" """
if size is None:
size = (1,)
op = GaussianRNG(mean, std)
_ref = Tensor([], dtype="int32")
shape = utils.astensor1d(size, _ref, dtype="int32")
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)
return output
return _normal(
mean=mean,
std=std,
size=size,
seed=_get_global_rng_seed(),
device=None,
handle=0,
)




def uniform( def uniform(
@@ -88,14 +85,11 @@ def uniform(
[0.09365904 0.62957656]] [0.09365904 0.62957656]]


""" """
assert low < high, "Uniform is not defined when low >= high"

if size is None:
size = (1,)
op = UniformRNG()
_ref = Tensor([], dtype="int32")
shape = utils.astensor1d(size, _ref, dtype="int32")
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)

return low + (high - low) * output
return _uniform(
low=low,
high=high,
size=size,
seed=_get_global_rng_seed(),
device=None,
handle=0,
)
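
normal() and uniform() keep their public signatures; they now just forward to the _normal / _uniform helpers in rng.py, passing the global seed and handle=0, which on the C++ side means "use the per-device default handle". A minimal sketch of the unchanged module-level usage (values vary with device and seed):

    import megengine.random as rand

    rand.seed(42)                                # also calls set_global_rng_seed(42)
    x = rand.normal(mean=0, std=1, size=(2, 2))
    y = rand.uniform(low=0, high=1, size=(2, 2))
    print(x.numpy())
    print(y.numpy())

Note that rng.py now calls seed(int(time.time())) at import time, so runs that never call seed() explicitly start from a time-based seed.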

+85 -4  imperative/python/megengine/random/rng.py

@@ -7,17 +7,94 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time import time
from typing import Iterable, Optional


from numpy.random import MT19937 from numpy.random import MT19937


from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from ..device import get_default_device

_rng = None _rng = None




def _random_seed_generator():
if _rng is None:
from ..distributed.group import get_rank
def _normal(
mean: float,
std: float,
size: Optional[Iterable[int]],
seed: int,
device: str,
handle: int,
) -> Tensor:
if size is None:
size = (1,)
op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
_ref = Tensor([], dtype="int32", device=device)
shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
(output,) = apply(op, shape)
return output


def _uniform(
low: float,
high: float,
size: Optional[Iterable[int]],
seed: int,
device: str,
handle: int,
) -> Tensor:
assert low < high, "Uniform is not defined when low >= high"
if size is None:
size = (1,)
op = UniformRNG(seed=seed, handle=handle)
_ref = Tensor([], dtype="int32", device=device)
shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
(output,) = apply(op, shape)
return low + (high - low) * output


class RNG:
def __init__(self, seed=0, device=None):
self.seed = seed
self.device = device if device else get_default_device()
self.handle = _new_rng_handle(self.device, self.seed)

def uniform(
self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
):
return _uniform(
low=low,
high=high,
size=size,
seed=self.seed,
device=self.device,
handle=self.handle,
)


seed(seed=int(time.time()) + get_rank())
def normal(
self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
):
return _normal(
mean=mean,
std=std,
size=size,
seed=self.seed,
device=self.device,
handle=self.handle,
)

def __del__(self):
_delete_rng_handle(self.handle)


def _random_seed_generator():
assert _rng
while True: while True:
yield _rng.random_raw() yield _rng.random_raw()


@@ -25,3 +102,7 @@ def _random_seed_generator():
def seed(seed: int): def seed(seed: int):
global _rng # pylint: disable=global-statement global _rng # pylint: disable=global-statement
_rng = MT19937(seed=seed) _rng = MT19937(seed=seed)
_set_global_rng_seed(seed)


seed(int(time.time()))

+14 -17  imperative/python/src/ops.cpp

@@ -10,7 +10,10 @@
*/ */


#include "./ops.h" #include "./ops.h"
#include "./helper.h"
#include "./tensor.h"


#include "megbrain/common.h"
#include "megbrain/imperative.h" #include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/ops/opr_attr.h"
@@ -491,21 +494,15 @@ void init_ops(py::module m) {
_init_py_op_base(m); _init_py_op_base(m);
INIT_ALL_OP(m) INIT_ALL_OP(m)


m.def("new_rng_handle", &RNGMixin::new_handle);
// FIXME: RNG op might execute after handle released due to async dispatch,
// which would cause memory leak or use-after-free
m.def("delete_rng_handle", &RNGMixin::delete_handle);
m.def("set_rng_seed", &set_rng_seed);

py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
.def(py::init<>())
.def(py::init<mgb::CompNode>())
.def(py::init<RNGMixin::Handle>());

py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
.def(py::init<>())
.def(py::init<mgb::CompNode>())
.def(py::init<float ,float>())
.def(py::init<float ,float, mgb::CompNode>())
.def(py::init<float ,float, RNGMixin::Handle>());
m.def("new_rng_handle", &rng::new_handle);
m.def("delete_rng_handle", [](size_t handle){
// RNG op might execute after handle released due to async dispatch, so
// we need sync before delete a handle to avoid memory leak or use-after-free
python::interpreter_for_py->sync();
mgb::CompNode::sync_all();
py_task_q.wait_all_task_finish();
rng::delete_handle(handle);
}, py::call_guard<py::gil_scoped_release>());
m.def("set_global_rng_seed", &rng::set_global_rng_seed);
m.def("get_global_rng_seed", &rng::get_global_rng_seed);
} }
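
For code that needs explicit handle control, these bindings expose the raw API directly. Because RNG ops dispatch asynchronously, delete_rng_handle now synchronizes (interpreter, comp nodes, and the Python task queue) before freeing the handle, which replaces the old FIXME. A sketch following the new test file below; the device string and seed are illustrative:

    from megengine import tensor
    from megengine.core._imperative_rt import CompNode
    from megengine.core._imperative_rt.core2 import apply
    from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
    from megengine.core.ops.builtin import GaussianRNG

    cn = CompNode("xpu0")                   # any available comp node
    handle = new_rng_handle(cn, 233333)     # per-device handle with its own seed
    # The op's seed must match the handle's seed (rng.cpp asserts this).
    op = GaussianRNG(seed=233333, mean=0.0, std=1.0, handle=handle)
    shape = tensor((8, 9), dtype="int32")
    (out,) = apply(op, shape)
    print(out.numpy().mean())
    delete_rng_handle(handle)               # syncs before freeing, so this is safe here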

+121 -0  imperative/python/test/unit/random/test_rng.py

@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine
from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import (
delete_rng_handle,
get_global_rng_seed,
new_rng_handle,
)
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.random import RNG
from megengine.random.rng import _normal, _uniform


def test_gaussian_op():
shape = (
8,
9,
11,
12,
)
shape = tensor(shape, dtype="int32")
op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
assert str(output.device) == str(CompNode("xpux"))

cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
assert str(output.device) == str(cn)


def test_uniform_op():
shape = (
8,
9,
11,
12,
)
shape = tensor(shape, dtype="int32")
op = UniformRNG(seed=get_global_rng_seed())
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(CompNode("xpux"))

cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = UniformRNG(seed=seed, handle=h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(cn)


def test_UniformRNG():
m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0")
out1 = m1.uniform(size=(100,))
out1_ = m1.uniform(size=(100,))
out2 = m2.uniform(size=(100,))
out3 = m3.uniform(size=(100,))

np.testing.assert_equal(out1.numpy(), out2.numpy())
assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all()

low = -234
high = 123
out = m1.uniform(low=low, high=high, size=(20, 30, 40))
out_shp = out.shape
if isinstance(out_shp, tuple):
assert out_shp == (20, 30, 40)
else:
assert all(out.shape.numpy() == np.array([20, 30, 40]))
assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1


def test_NormalRNG():
m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0")
out1 = m1.normal(size=(100,))
out1_ = m1.uniform(size=(100,))
out2 = m2.normal(size=(100,))
out3 = m3.normal(size=(100,))

np.testing.assert_equal(out1.numpy(), out2.numpy())
assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all()

mean = -1
std = 2
out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
out_shp = out.shape
if isinstance(out_shp, tuple):
assert out_shp == (20, 30, 40)
else:
assert all(out.shape.numpy() == np.array([20, 30, 40]))
assert np.abs(out.mean().numpy() - mean) / std < 0.1
assert np.abs(np.std(out.numpy()) - std) < 0.1

+0 -76  imperative/python/test/unit/test_rng.py

@@ -1,76 +0,0 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.core.tensor.core import apply


def test_gaussian_rng():
shape = (
8,
9,
11,
12,
)
shape = tensor(shape, dtype="int32")
op = GaussianRNG(1.0, 3.0)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
assert str(output.device) == str(CompNode("xpux"))

cn = CompNode("xpu1")
op = GaussianRNG(-1.0, 2.0, cn)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
assert str(output.device) == str(cn)

cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = GaussianRNG(3.0, 1.0, h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
assert str(output.device) == str(cn)


def test_uniform_rng():
shape = (
8,
9,
11,
12,
)
shape = tensor(shape, dtype="int32")
op = UniformRNG()
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(CompNode("xpux"))

cn = CompNode("xpu1")
op = UniformRNG(cn)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(cn)

cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = UniformRNG(h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(cn)

+84 -84  imperative/src/impl/ops/rng.cpp

@@ -2,7 +2,7 @@
* \file imperative/src/impl/ops/rng.cpp * \file imperative/src/impl/ops/rng.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
@@ -10,23 +10,23 @@
*/ */


#include "megbrain/imperative/ops/rng.h" #include "megbrain/imperative/ops/rng.h"
#include <bits/stdint-uintn.h>
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h" #include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h" #include "megbrain/opr/rand.h"
//#include "megbrain/common.h"


#include "../op_trait.h" #include "../op_trait.h"
#include "../dnn_op_helper.h"


namespace mgb {
namespace imperative {
namespace mgb::imperative::rng {


namespace { namespace {


template <typename HandleFactory, typename THandle> template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj { class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public: public:
using DT = CompNode::DeviceType;
using Handle = THandle; using Handle = THandle;
using OpTypeInfo = size_t;


template <typename... Args> template <typename... Args>
Handle new_handle(Args&&... args) { Handle new_handle(Args&&... args) {
@@ -38,27 +38,26 @@ public:
size_t removed = 0; size_t removed = 0;
if (!is_finalized()) { if (!is_finalized()) {
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_mtx);
removed = m_handle2op.erase(handle);
removed = m_handle2ops.erase(handle);
} }
static_cast<HandleFactory*>(this)->do_delete_handle(handle); static_cast<HandleFactory*>(this)->do_delete_handle(handle);
return removed; return removed;
} }


template <typename DnnOp> template <typename DnnOp>
auto get_dnn_op(Handle handle, CompNode cn) {
auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
mgb_assert(!is_finalized()); mgb_assert(!is_finalized());
DnnOpWithMutex* dnn_op_with_mtx; DnnOpWithMutex* dnn_op_with_mtx;
{ {
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_mtx);
dnn_op_with_mtx = &m_handle2op[handle];
dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
} }
auto dnn_handle = auto dnn_handle =
MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle(); MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
DnnOp* dnn_op;
std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx); std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
bool initialized = false; bool initialized = false;
if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
nullptr) {
DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
if (dnn_op != nullptr) {
mgb_assert(dnn_op->handle() == dnn_handle); mgb_assert(dnn_op->handle() == dnn_handle);
initialized = true; initialized = true;
} else { } else {
@@ -77,35 +76,30 @@ private:
struct DnnOpWithMutex { struct DnnOpWithMutex {
std::mutex mtx; std::mutex mtx;
std::unique_ptr<megdnn::OperatorBase> op; std::unique_ptr<megdnn::OperatorBase> op;
DnnOpWithMutex(): op{nullptr} {}
}; };


std::shared_ptr<void> on_comp_node_finalize() override { std::shared_ptr<void> on_comp_node_finalize() override {
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_mtx);
m_handle2op.clear();
m_handle2ops.clear();
return {}; return {};
} }


std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
std::mutex m_mtx; std::mutex m_mtx;
}; };


class RNGDnnOpManager final class RNGDnnOpManager final
: public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
: public DnnOpManagerT<RNGDnnOpManager, Handle> {
public: public:
Handle new_handle(CompNode comp_node, uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
return DnnOpManagerBase::new_handle(comp_node, seed);
}

size_t delete_handle(Handle handle) { size_t delete_handle(Handle handle) {
size_t ret = 0;
{
MGB_LOCK_GUARD(sm_mtx);
auto iter = sm_partial2full.find(handle);
if (iter != sm_partial2full.end()) {
for (auto&& h : iter->second) {
ret += DnnOpManagerBase::delete_handle(h.second);
}
sm_partial2full.erase(iter);
}
}
ret += DnnOpManagerBase::delete_handle(handle);
return ret;
MGB_LOCK_GUARD(sm_mtx);
return DnnOpManagerBase::delete_handle(handle);
} }


Handle do_new_handle(CompNode comp_node, uint64_t seed) { Handle do_new_handle(CompNode comp_node, uint64_t seed) {
@@ -118,32 +112,26 @@ public:
} }


static uint64_t get_seed(Handle handle) { static uint64_t get_seed(Handle handle) {
if (!handle) { return glob_default_seed; }
return reinterpret_cast<HandleData*>(handle)->seed; return reinterpret_cast<HandleData*>(handle)->seed;
} }


static CompNode get_comp_node(Handle handle) { static CompNode get_comp_node(Handle handle) {
mgb_assert(handle, "invalid handle");
return reinterpret_cast<HandleData*>(handle)->comp_node; return reinterpret_cast<HandleData*>(handle)->comp_node;
} }


static Handle get_full_handle(Handle handle, CompNode comp_node) {
if (get_comp_node(handle).valid()) {
return handle;
}
MGB_LOCK_GUARD(sm_mtx);
auto&& full = sm_partial2full[handle][comp_node];
if (!full) {
full = inst().new_handle(comp_node, get_seed(handle));
}
return full;
}

static Handle get_default_handle(CompNode comp_node) { static Handle get_default_handle(CompNode comp_node) {
static Handle glob_partial_handle =
inst().new_handle(CompNode{}, glob_default_seed);
if (!comp_node.valid()) {
return glob_partial_handle;
mgb_assert(comp_node.valid());
MGB_LOCK_GUARD(sm_mtx);
auto&& glob_handle = glob_default_handles[comp_node];
if (!glob_handle) {
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
} else if (get_seed(glob_handle) != glob_default_seed) {
inst().DnnOpManagerBase::delete_handle(glob_handle);
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
} }
return get_full_handle(glob_partial_handle, comp_node);
return glob_handle;
} }


static RNGDnnOpManager& inst() { static RNGDnnOpManager& inst() {
@@ -152,9 +140,15 @@ public:
} }


static void set_glob_default_seed(uint64_t seed) { static void set_glob_default_seed(uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
glob_default_seed = seed; glob_default_seed = seed;
} }


static uint64_t get_glob_default_seed() {
MGB_LOCK_GUARD(sm_mtx);
return glob_default_seed;
}

private: private:
struct HandleData { struct HandleData {
CompNode comp_node; CompNode comp_node;
@@ -165,16 +159,13 @@ private:
MemPool<HandleData> m_handle_pool; MemPool<HandleData> m_handle_pool;


static std::mutex sm_mtx; static std::mutex sm_mtx;
static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
sm_partial2full;
static CompNode::UnorderedMap<Handle> glob_default_handles;
static uint64_t glob_default_seed; static uint64_t glob_default_seed;
}; };


uint64_t RNGDnnOpManager::glob_default_seed = 0; uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx; std::mutex RNGDnnOpManager::sm_mtx;
std::unordered_map<RNGDnnOpManager::Handle,
CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
RNGDnnOpManager::sm_partial2full;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;


template <typename Op> template <typename Op>
struct OpMeth; struct OpMeth;
@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
using Param = DnnOp::Param; using Param = DnnOp::Param;
using OpNode = mgb::opr::UniformRNG; using OpNode = mgb::opr::UniformRNG;
static Param make_param(const UniformRNG& rng) { static Param make_param(const UniformRNG& rng) {
return {RNGDnnOpManager::get_seed(rng.handle())};
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
mgb_assert(handle_seed == rng.seed,
"inconsistent rng seed: rng op: %lu handle: %lu",
handle_seed, rng.seed);
return {handle_seed};
} }
}; };


@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
using Param = DnnOp::Param; using Param = DnnOp::Param;
using OpNode = mgb::opr::GaussianRNG; using OpNode = mgb::opr::GaussianRNG;
static Param make_param(const GaussianRNG& rng) { static Param make_param(const GaussianRNG& rng) {
return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
mgb_assert(handle_seed == rng.seed,
"inconsistent rng seed: rng op: %lu handle: %lu",
handle_seed, rng.seed);
return {handle_seed, rng.mean, rng.std};
} }
}; };


@@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
auto dest = outputs[0]; auto dest = outputs[0];


auto cn = dest->comp_node(); auto cn = dest->comp_node();
auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
{
auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
mgb_assert(cn == handle_cn,
"inconsistent comp_node: handle: %s, output: %s",
cn.to_string().c_str(), handle_cn.to_string().c_str());
auto handle = rng.handle;
if (!handle) {
handle = RNGDnnOpManager::get_default_handle(cn);
} }


// retrieve dnn_op from glob cache // retrieve dnn_op from glob cache
auto dnn_op_thread_safe = RNGDnnOpManager::inst() auto dnn_op_thread_safe = RNGDnnOpManager::inst()
.get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
.get_dnn_op<typename OpMeth<Op>::DnnOp>(
handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
cn);
auto initialized = std::get<0>(dnn_op_thread_safe); auto initialized = std::get<0>(dnn_op_thread_safe);
auto dnn_op = std::get<1>(dnn_op_thread_safe); auto dnn_op = std::get<1>(dnn_op_thread_safe);
if (initialized) { if (initialized) {
auto handle_seed = RNGDnnOpManager::get_seed(handle); auto handle_seed = RNGDnnOpManager::get_seed(handle);
mgb_assert(dnn_op->param().seed == handle_seed, mgb_assert(dnn_op->param().seed == handle_seed,
"inconsistent rng seed: handle: %zu, dnn_op: %zu",
"inconsistent rng seed: handle: %lu, dnn_op: %lu",
handle_seed, dnn_op->param().seed); handle_seed, dnn_op->param().seed);
} }
dnn_op->param() = OpMeth<Op>::make_param(rng); dnn_op->param() = OpMeth<Op>::make_param(rng);
@@ -239,9 +237,12 @@ template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs( SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& op, const SmallVector<TensorPtr>& inputs) { const OpDef& op, const SmallVector<TensorPtr>& inputs) {
LogicalTensorDesc dest; LogicalTensorDesc dest;
dest.comp_node = op.cast_final_safe<Op>().comp_node();
if (!dest.comp_node.valid())
auto handle = op.cast_final_safe<Op>().handle;
if (handle) {
dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
} else {
dest.comp_node = inputs[0]->comp_node(); dest.comp_node = inputs[0]->comp_node();
}


auto hv = inputs[0]->get_value().proxy_to_default_cpu(); auto hv = inputs[0]->get_value().proxy_to_default_cpu();
TensorShape tshape; TensorShape tshape;
@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
} }


template<typename Op> template<typename Op>
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
SymbolVar apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
size_t nr_inp = inputs.size(); size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually",
nr_inp);
auto&& rng = def.cast_final_safe<Op>(); auto&& rng = def.cast_final_safe<Op>();
mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
rng.dyn_typeinfo()->name,
nr_inp);
auto param = OpMeth<Op>::make_param(rng); auto param = OpMeth<Op>::make_param(rng);
return OpMeth<Op>::OpNode::make(
inputs[0], param, {rng.comp_node()}).node()->owner_opr();
OperatorNodeConfig config;
if (rng.handle) {
config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
} else {
config = {rng.make_name()};
}
return OpMeth<Op>::OpNode::make(inputs[0], param, config);
} }


template<typename T> template<typename T>
@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(


} // anonymous namespace } // anonymous namespace


RNGMixin::RNGMixin(CompNode cn):
m_handle(RNGDnnOpManager::get_default_handle(cn)) {}

uint64_t RNGMixin::seed() const {
return RNGDnnOpManager::get_seed(m_handle);
}

CompNode RNGMixin::comp_node() const {
return RNGDnnOpManager::get_comp_node(m_handle);
}

RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
Handle new_handle(CompNode comp_node, uint64_t seed) {
return RNGDnnOpManager::inst().new_handle(comp_node, seed); return RNGDnnOpManager::inst().new_handle(comp_node, seed);
} }


size_t RNGMixin::delete_handle(Handle handle) {
size_t delete_handle(Handle handle) {
return RNGDnnOpManager::inst().delete_handle(handle); return RNGDnnOpManager::inst().delete_handle(handle);
} }


void set_rng_seed(uint64_t seed) {
void set_global_rng_seed(uint64_t seed) {
RNGDnnOpManager::set_glob_default_seed(seed); RNGDnnOpManager::set_glob_default_seed(seed);
} }

uint64_t get_global_rng_seed() {
return RNGDnnOpManager::get_glob_default_seed();
}

#define REG_RNG_OP(NAME)\ #define REG_RNG_OP(NAME)\
namespace { \ namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \ .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.fallback(); \ .fallback(); \
} \ } \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);


REG_RNG_OP(UniformRNG) REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG) REG_RNG_OP(GaussianRNG)


} // namespace imperative
} // namespace mgb
} // namespace mgb::imperative::rng


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+0 -28  imperative/src/impl/ops/specializations.cpp

@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
.fallback(); .fallback();
}} // assert_equal }} // assert_equal


namespace { namespace uniform_rng {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const UniformRNG&>(def);
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::UniformRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(UniformRNG, UniformRNG)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // uniform_rng

namespace { namespace gaussian_rng {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const GaussianRNG&>(def);
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::GaussianRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(GaussianRNG, GaussianRNG)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // gaussian_rng

namespace { namespace roi_align { namespace { namespace roi_align {
VarNodeArray apply_on_var_node( VarNodeArray apply_on_var_node(
const OpDef& def, const OpDef& def,


+9 -78  imperative/src/include/megbrain/imperative/ops/rng.h

@@ -2,7 +2,7 @@
* \file imperative/src/include/megbrain/imperative/ops/rng.h * \file imperative/src/include/megbrain/imperative/ops/rng.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
@@ -12,84 +12,15 @@
#pragma once #pragma once


#include "megbrain/imperative/op_def.h" #include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/autogen.h"


namespace mgb::imperative {
namespace mgb::imperative::rng {


class RNGMixin {
public:
using Handle = size_t;
using Handle = size_t;


static Handle new_handle(
CompNode comp_node={}, uint64_t seed=0);
Handle new_handle(CompNode comp_node, uint64_t seed);
size_t delete_handle(Handle handle);
void set_global_rng_seed(uint64_t seed);
uint64_t get_global_rng_seed();


static size_t delete_handle(Handle handle);

Handle handle() const {
return m_handle;
}

uint64_t seed() const;

CompNode comp_node() const;
protected:
RNGMixin(Handle handle): m_handle(handle) {}
RNGMixin(CompNode comp_node);
private:
Handle m_handle;
};

class GaussianRNG : public OpDefImplBase<GaussianRNG>,
public RNGMixin {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float mean = 1.0f, std = 0.0;
GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
GaussianRNG(comp_node_) { mean = mean_; std = std_; }
GaussianRNG(float mean_, float std_, Handle handle):
RNGMixin(handle), mean(mean_), std(std_) {}
size_t hash() const override {
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(dyn_typeinfo());
append(seed());
append(mean);
append(std);
return xxhash.digest();
}


bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
return rhs.seed() == seed()
&& rhs.mean == mean
&& rhs.std == std;
}
};

class UniformRNG : public OpDefImplBase<UniformRNG>,
public RNGMixin {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
UniformRNG(Handle handle): RNGMixin(handle) {}

size_t hash() const override {
return hash_pair_combine(
mgb::hash(seed()),
reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
}

bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const UniformRNG&>(rhs_);
return rhs.dyn_typeinfo() == dyn_typeinfo()
&& rhs.seed() == seed();
}

};

void set_rng_seed(uint64_t seed);
} // namespace mgb::imperative
} // namespace mgb::imperative::rng

+14 -6  imperative/src/test/rng.cpp

@@ -14,6 +14,7 @@


using namespace mgb; using namespace mgb;
using namespace imperative; using namespace imperative;
using namespace imperative::rng;


template<typename Op, typename ...Args> template<typename Op, typename ...Args>
void check_rng_basic(Args&& ...args) { void check_rng_basic(Args&& ...args) {
@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
{3, 4, 5, 6}, {3, 4, 5, 6},
{2333}}) {2333}})
for (auto&& cn: { for (auto&& cn: {
CompNode::load("cpu0"),
CompNode::load("xpu0")})
CompNode::load("xpu0"),
CompNode::load("xpu1")})
{ {
auto op = Op::make(std::forward<Args>(args)..., cn);
Handle h = new_handle(cn, 123);
auto op = Op::make(std::forward<Args>(args)..., h);
DeviceTensorND tshape_dev; DeviceTensorND tshape_dev;
cg::copy_shape_to_tensor_value(tshape_dev, tshape); cg::copy_shape_to_tensor_value(tshape_dev, tshape);
auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape)); ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
ASSERT_TRUE(cn == outputs[0]->comp_node()); ASSERT_TRUE(cn == outputs[0]->comp_node());
// sync before delete handle
for (auto&& p: outputs) {
p->get_value();
}
delete_handle(h);
} }
} }


TEST(TestImperative, UniformRNGBasic) { TEST(TestImperative, UniformRNGBasic) {
check_rng_basic<UniformRNG>();
check_rng_basic<UniformRNG>(123);
} }


TEST(TestImperative, GaussianRNGBasic) { TEST(TestImperative, GaussianRNGBasic) {
check_rng_basic<GaussianRNG>(2.f, 3.f);
check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
} }


// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+20 -4  src/core/include/megbrain/ir/ops.td

@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;
def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>; def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;


def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> { def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
let cmpFunction = [{return true;}];
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash($_self.handle));
}];
let cmpFunction = [{return $0.handle == $1.handle;}];
} }


def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> { def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{ let hashFunction = [{
return mgb::hash_pair_combine( return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()), mgb::hash($_self.dyn_typeinfo()),
mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
mgb::hash_pair_combine(
mgb::hash($_self.handle),
mgb::hash_pair_combine(
mgb::hash($_self.mean),
mgb::hash($_self.std))
)
);
}]; }];
let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
} }


def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {

