GitOrigin-RevId: ee5984c52d
release-1.4
@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
 def _remove_axis(inp: Tensor, axis) -> Tensor:
     def get_axes():
         if axis is None:
-            return [i for i, s in enumerate(inp.shape) if s == 1]
+            shp = inp.shape
+            return [i for i, s in enumerate(shp) if s == 1]
         try:
             return [int(axis)]
         except (TypeError, ValueError):
@@ -6,9 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import time
 from typing import List, Optional, Tuple
 from ..device import set_default_device, what_is_xpu
+from ..random import seed
 from .server import Client, Server
@@ -156,6 +158,7 @@ def init_process_group(
     WORLD.reset(list(range(world_size)))
     set_default_device("{}{}".format(device_type, device))
+    seed(int(time.time()) + rank)
 def is_distributed() -> bool:
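With this change, `init_process_group` seeds MegEngine's host-side RNG per process: wall-clock time plus the rank, so every worker draws from a distinct stream instead of all ranks sharing one implicit seed. A minimal sketch of the scheme in plain Python (the fixed timestamp is a hypothetical stand-in for `int(time.time())`):

    import time
    from typing import Optional

    def per_rank_seed(rank: int, now: Optional[int] = None) -> int:
        # mirrors the added line: seed(int(time.time()) + rank)
        now = int(time.time()) if now is None else now
        return now + rank

    # for a single launch timestamp, every rank gets a distinct seed
    seeds = [per_rank_seed(r, now=1600000000) for r in range(4)]
    assert len(set(seeds)) == len(seeds)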
@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .distribution import normal, uniform
-from .rng import seed
+from .rng import RNG, seed
 # pylint: disable=undefined-variable
 del distribution, rng  # type: ignore[name-defined]
@@ -9,11 +9,8 @@
 from typing import Iterable, Optional
 from .. import Tensor
-from ..core._imperative_rt import invoke_op
-from ..core._imperative_rt.core2 import apply
-from ..core.ops.builtin import GaussianRNG, UniformRNG
-from ..core.tensor import utils
-from .rng import _random_seed_generator
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from .rng import _normal, _uniform
 __all__ = ["normal", "uniform"]
@@ -48,14 +45,14 @@ def normal(
     [-1.4939808 -1.5824696 ]]
     """
-    if size is None:
-        size = (1,)
-    op = GaussianRNG(mean, std)
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return output
+    return _normal(
+        mean=mean,
+        std=std,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )
 def uniform(
@@ -88,14 +85,11 @@ def uniform(
     [0.09365904 0.62957656]]
     """
-    assert low < high, "Uniform is not defined when low >= high"
-    if size is None:
-        size = (1,)
-    op = UniformRNG()
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return low + (high - low) * output
+    return _uniform(
+        low=low,
+        high=high,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )
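The public `normal` and `uniform` keep their signatures but now forward to the shared `_normal`/`_uniform` helpers, passing the global seed, no explicit device, and handle 0 (the per-comp-node default handle). A minimal usage sketch, assuming a MegEngine build that includes this change:

    import megengine.random as rand

    rand.seed(42)  # now also propagates to the global graph-side RNG seed
    x = rand.uniform(low=0, high=1, size=(2, 2))
    y = rand.normal(mean=0, std=1, size=(2, 2))
    print(x.numpy())
    print(y.numpy())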
@@ -7,17 +7,94 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from typing import Iterable, Optional
 from numpy.random import MT19937
+from .. import Tensor
+from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
+from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
+from ..core.ops.builtin import GaussianRNG, UniformRNG
+from ..core.tensor import utils
+from ..device import get_default_device
 _rng = None
-def _random_seed_generator():
-    if _rng is None:
-        from ..distributed.group import get_rank
-        seed(seed=int(time.time()) + get_rank())
+def _normal(
+    mean: float,
+    std: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    if size is None:
+        size = (1,)
+    op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return output
+def _uniform(
+    low: float,
+    high: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    assert low < high, "Uniform is not defined when low >= high"
+    if size is None:
+        size = (1,)
+    op = UniformRNG(seed=seed, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return low + (high - low) * output
+class RNG:
+    def __init__(self, seed=0, device=None):
+        self.seed = seed
+        self.device = device if device else get_default_device()
+        self.handle = _new_rng_handle(self.device, self.seed)
+    def uniform(
+        self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _uniform(
+            low=low,
+            high=high,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+    def normal(
+        self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _normal(
+            mean=mean,
+            std=std,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+    def __del__(self):
+        _delete_rng_handle(self.handle)
+def _random_seed_generator():
+    assert _rng
     while True:
         yield _rng.random_raw()
@@ -25,3 +102,7 @@ def _random_seed_generator():
 def seed(seed: int):
     global _rng  # pylint: disable=global-statement
     _rng = MT19937(seed=seed)
+    _set_global_rng_seed(seed)
+seed(int(time.time()))
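Together these pieces give `megengine.random` two levels of control: the module-level `seed()` now also updates the graph-side global seed (and runs once at import time), while the new `RNG` class owns a private native handle bound to one seed and one device and releases it in `__del__`. A short usage sketch mirroring the new tests further below (device names `xpu0`/`xpu1` are assumptions about the local machine):

    from megengine.random import RNG

    gen_a = RNG(seed=111, device="xpu0")
    gen_b = RNG(seed=111, device="xpu1")
    u_a = gen_a.uniform(size=(100,))
    u_b = gen_b.uniform(size=(100,))
    # identical seeds yield identical streams, even across devices
    assert (u_a.numpy() == u_b.numpy()).all()
    n_a = gen_a.normal(mean=0, std=1, size=(10, 10))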
@@ -10,7 +10,10 @@
  */
 #include "./ops.h"
+#include "./helper.h"
+#include "./tensor.h"
+#include "megbrain/common.h"
 #include "megbrain/imperative.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
@@ -491,21 +494,15 @@ void init_ops(py::module m) {
     _init_py_op_base(m);
     INIT_ALL_OP(m)
-    m.def("new_rng_handle", &RNGMixin::new_handle);
-    // FIXME: RNG op might execute after handle released due to async dispatch,
-    // which would cause memory leak or use-after-free
-    m.def("delete_rng_handle", &RNGMixin::delete_handle);
-    m.def("set_rng_seed", &set_rng_seed);
-    py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<RNGMixin::Handle>());
-    py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<float ,float>())
-        .def(py::init<float ,float, mgb::CompNode>())
-        .def(py::init<float ,float, RNGMixin::Handle>());
+    m.def("new_rng_handle", &rng::new_handle);
+    m.def("delete_rng_handle", [](size_t handle){
+        // The RNG op might execute after its handle is released due to async
+        // dispatch, so sync before deleting the handle to avoid a memory leak
+        // or use-after-free.
+        python::interpreter_for_py->sync();
+        mgb::CompNode::sync_all();
+        py_task_q.wait_all_task_finish();
+        rng::delete_handle(handle);
+    }, py::call_guard<py::gil_scoped_release>());
+    m.def("set_global_rng_seed", &rng::set_global_rng_seed);
+    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
 }
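From Python, the rebound handle API looks like the sketch below; the key behavioral change is that `delete_rng_handle` now synchronizes the interpreter, all comp nodes, and the Python task queue before freeing, so it is safe even while RNG ops are still in flight. The `"xpux"` device and seed value are placeholders taken from the new test file:

    from megengine.core._imperative_rt import CompNode
    from megengine.core._imperative_rt.ops import (
        delete_rng_handle,
        get_global_rng_seed,
        new_rng_handle,
    )

    h = new_rng_handle(CompNode("xpux"), 233333)
    # ... apply UniformRNG/GaussianRNG ops built with handle=h ...
    delete_rng_handle(h)  # syncs before freeing; no use-after-free
    print(get_global_rng_seed())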
@@ -0,0 +1,121 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import megengine
+from megengine import tensor
+from megengine.core._imperative_rt import CompNode
+from megengine.core._imperative_rt.core2 import apply
+from megengine.core._imperative_rt.ops import (
+    delete_rng_handle,
+    get_global_rng_seed,
+    new_rng_handle,
+)
+from megengine.core.ops.builtin import GaussianRNG, UniformRNG
+from megengine.random import RNG
+from megengine.random.rng import _normal, _uniform
+def test_gaussian_op():
+    shape = (
+        8,
+        9,
+        11,
+        12,
+    )
+    shape = tensor(shape, dtype="int32")
+    op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
+    assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
+    assert str(output.device) == str(CompNode("xpux"))
+    cn = CompNode("xpu2")
+    seed = 233333
+    h = new_rng_handle(cn, seed)
+    op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
+    (output,) = apply(op, shape)
+    delete_rng_handle(h)
+    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
+    assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
+    assert str(output.device) == str(cn)
+def test_uniform_op():
+    shape = (
+        8,
+        9,
+        11,
+        12,
+    )
+    shape = tensor(shape, dtype="int32")
+    op = UniformRNG(seed=get_global_rng_seed())
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
+    assert str(output.device) == str(CompNode("xpux"))
+    cn = CompNode("xpu2")
+    seed = 233333
+    h = new_rng_handle(cn, seed)
+    op = UniformRNG(seed=seed, handle=h)
+    (output,) = apply(op, shape)
+    delete_rng_handle(h)
+    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
+    assert str(output.device) == str(cn)
+def test_UniformRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.uniform(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.uniform(size=(100,))
+    out3 = m3.uniform(size=(100,))
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+    low = -234
+    high = 123
+    out = m1.uniform(low=low, high=high, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
+def test_NormalRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.normal(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.normal(size=(100,))
+    out3 = m3.normal(size=(100,))
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+    mean = -1
+    std = 2
+    out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - mean) / std < 0.1
+    assert np.abs(np.std(out.numpy()) - std) < 0.1
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import numpy as np
-from megengine import tensor
-from megengine.core._imperative_rt import CompNode
-from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
-from megengine.core.ops.builtin import GaussianRNG, UniformRNG
-from megengine.core.tensor.core import apply
-def test_gaussian_rng():
-    shape = (
-        8,
-        9,
-        11,
-        12,
-    )
-    shape = tensor(shape, dtype="int32")
-    op = GaussianRNG(1.0, 3.0)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
-    assert str(output.device) == str(CompNode("xpux"))
-    cn = CompNode("xpu1")
-    op = GaussianRNG(-1.0, 2.0, cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
-    assert str(output.device) == str(cn)
-    cn = CompNode("xpu2")
-    seed = 233333
-    h = new_rng_handle(cn, seed)
-    op = GaussianRNG(3.0, 1.0, h)
-    (output,) = apply(op, shape)
-    delete_rng_handle(h)
-    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
-    assert str(output.device) == str(cn)
-def test_uniform_rng():
-    shape = (
-        8,
-        9,
-        11,
-        12,
-    )
-    shape = tensor(shape, dtype="int32")
-    op = UniformRNG()
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(CompNode("xpux"))
-    cn = CompNode("xpu1")
-    op = UniformRNG(cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(cn)
-    cn = CompNode("xpu2")
-    seed = 233333
-    h = new_rng_handle(cn, seed)
-    op = UniformRNG(h)
-    (output,) = apply(op, shape)
-    delete_rng_handle(h)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(cn)
@@ -2,7 +2,7 @@
  * \file imperative/src/impl/ops/rng.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
@@ -10,23 +10,23 @@
  */
 #include "megbrain/imperative/ops/rng.h"
+#include <bits/stdint-uintn.h>
 #include "megbrain/comp_node_env.h"
 #include "megbrain/graph/helper.h"
 #include "megbrain/opr/rand.h"
+//#include "megbrain/common.h"
 #include "../op_trait.h"
+#include "../dnn_op_helper.h"
-namespace mgb {
-namespace imperative {
+namespace mgb::imperative::rng {
 namespace {
 template <typename HandleFactory, typename THandle>
 class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
 public:
+    using DT = CompNode::DeviceType;
     using Handle = THandle;
+    using OpTypeInfo = size_t;
     template <typename... Args>
     Handle new_handle(Args&&... args) {
@@ -38,27 +38,26 @@ public:
         size_t removed = 0;
         if (!is_finalized()) {
             MGB_LOCK_GUARD(m_mtx);
-            removed = m_handle2op.erase(handle);
+            removed = m_handle2ops.erase(handle);
         }
         static_cast<HandleFactory*>(this)->do_delete_handle(handle);
         return removed;
     }
     template <typename DnnOp>
-    auto get_dnn_op(Handle handle, CompNode cn) {
+    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
         mgb_assert(!is_finalized());
         DnnOpWithMutex* dnn_op_with_mtx;
         {
             MGB_LOCK_GUARD(m_mtx);
-            dnn_op_with_mtx = &m_handle2op[handle];
+            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
         }
         auto dnn_handle =
                 MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
-        DnnOp* dnn_op;
         std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
         bool initialized = false;
-        if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
-                nullptr) {
+        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
+        if (dnn_op != nullptr) {
             mgb_assert(dnn_op->handle() == dnn_handle);
             initialized = true;
         } else {
@@ -77,35 +76,30 @@ private:
     struct DnnOpWithMutex {
         std::mutex mtx;
         std::unique_ptr<megdnn::OperatorBase> op;
+        DnnOpWithMutex(): op{nullptr} {}
     };
     std::shared_ptr<void> on_comp_node_finalize() override {
         MGB_LOCK_GUARD(m_mtx);
-        m_handle2op.clear();
+        m_handle2ops.clear();
         return {};
     }
-    std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
+    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
     std::mutex m_mtx;
 };
 class RNGDnnOpManager final
-        : public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
+        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
 public:
+    Handle new_handle(CompNode comp_node, uint64_t seed) {
+        MGB_LOCK_GUARD(sm_mtx);
+        return DnnOpManagerBase::new_handle(comp_node, seed);
+    }
     size_t delete_handle(Handle handle) {
-        size_t ret = 0;
-        {
-            MGB_LOCK_GUARD(sm_mtx);
-            auto iter = sm_partial2full.find(handle);
-            if (iter != sm_partial2full.end()) {
-                for (auto&& h : iter->second) {
-                    ret += DnnOpManagerBase::delete_handle(h.second);
-                }
-                sm_partial2full.erase(iter);
-            }
-        }
-        ret += DnnOpManagerBase::delete_handle(handle);
-        return ret;
+        MGB_LOCK_GUARD(sm_mtx);
+        return DnnOpManagerBase::delete_handle(handle);
     }
     Handle do_new_handle(CompNode comp_node, uint64_t seed) {
@@ -118,32 +112,26 @@ public:
     }
     static uint64_t get_seed(Handle handle) {
-        if (!handle) { return glob_default_seed; }
         return reinterpret_cast<HandleData*>(handle)->seed;
     }
     static CompNode get_comp_node(Handle handle) {
+        mgb_assert(handle, "invalid handle");
         return reinterpret_cast<HandleData*>(handle)->comp_node;
     }
-    static Handle get_full_handle(Handle handle, CompNode comp_node) {
-        if (get_comp_node(handle).valid()) {
-            return handle;
-        }
-        MGB_LOCK_GUARD(sm_mtx);
-        auto&& full = sm_partial2full[handle][comp_node];
-        if (!full) {
-            full = inst().new_handle(comp_node, get_seed(handle));
-        }
-        return full;
-    }
     static Handle get_default_handle(CompNode comp_node) {
-        static Handle glob_partial_handle =
-                inst().new_handle(CompNode{}, glob_default_seed);
-        if (!comp_node.valid()) {
-            return glob_partial_handle;
+        mgb_assert(comp_node.valid());
+        MGB_LOCK_GUARD(sm_mtx);
+        auto&& glob_handle = glob_default_handles[comp_node];
+        if (!glob_handle) {
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
+        } else if (get_seed(glob_handle) != glob_default_seed) {
+            inst().DnnOpManagerBase::delete_handle(glob_handle);
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
         }
-        return get_full_handle(glob_partial_handle, comp_node);
+        return glob_handle;
     }
     static RNGDnnOpManager& inst() {
@@ -152,9 +140,15 @@
     }
     static void set_glob_default_seed(uint64_t seed) {
+        MGB_LOCK_GUARD(sm_mtx);
         glob_default_seed = seed;
     }
+    static uint64_t get_glob_default_seed() {
+        MGB_LOCK_GUARD(sm_mtx);
+        return glob_default_seed;
+    }
 private:
     struct HandleData {
         CompNode comp_node;
@@ -165,16 +159,13 @@ private:
     MemPool<HandleData> m_handle_pool;
     static std::mutex sm_mtx;
-    static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
-            sm_partial2full;
+    static CompNode::UnorderedMap<Handle> glob_default_handles;
     static uint64_t glob_default_seed;
 };
 uint64_t RNGDnnOpManager::glob_default_seed = 0;
 std::mutex RNGDnnOpManager::sm_mtx;
-std::unordered_map<RNGDnnOpManager::Handle,
-        CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
-        RNGDnnOpManager::sm_partial2full;
+CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;
 template <typename Op>
 struct OpMeth;
@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::UniformRNG;
     static Param make_param(const UniformRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle())};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+                "inconsistent rng seed: rng op: %lu handle: %lu",
+                handle_seed, rng.seed);
+        return {handle_seed};
     }
 };
@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::GaussianRNG;
     static Param make_param(const GaussianRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+                "inconsistent rng seed: rng op: %lu handle: %lu",
+                handle_seed, rng.seed);
+        return {handle_seed, rng.mean, rng.std};
     }
 };
@@ -206,23 +205,22 @@
     auto dest = outputs[0];
     auto cn = dest->comp_node();
-    auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
-    {
-        auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
-        mgb_assert(cn == handle_cn,
-                "inconsistent comp_node: handle: %s, output: %s",
-                cn.to_string().c_str(), handle_cn.to_string().c_str());
+    auto handle = rng.handle;
+    if (!handle) {
+        handle = RNGDnnOpManager::get_default_handle(cn);
     }
     // retrieve dnn_op from glob cache
     auto dnn_op_thread_safe = RNGDnnOpManager::inst()
-            .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
+            .get_dnn_op<typename OpMeth<Op>::DnnOp>(
+                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
+                    cn);
     auto initialized = std::get<0>(dnn_op_thread_safe);
     auto dnn_op = std::get<1>(dnn_op_thread_safe);
     if (initialized) {
         auto handle_seed = RNGDnnOpManager::get_seed(handle);
         mgb_assert(dnn_op->param().seed == handle_seed,
-                "inconsistent rng seed: handle: %zu, dnn_op: %zu",
+                "inconsistent rng seed: handle: %lu, dnn_op: %lu",
                 handle_seed, dnn_op->param().seed);
     }
     dnn_op->param() = OpMeth<Op>::make_param(rng);
@@ -239,9 +237,12 @@ template <typename Op>
 SmallVector<LogicalTensorDesc> infer_output_attrs(
         const OpDef& op, const SmallVector<TensorPtr>& inputs) {
     LogicalTensorDesc dest;
-    dest.comp_node = op.cast_final_safe<Op>().comp_node();
-    if (!dest.comp_node.valid())
+    auto handle = op.cast_final_safe<Op>().handle;
+    if (handle) {
+        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
+    } else {
         dest.comp_node = inputs[0]->comp_node();
+    }
     auto hv = inputs[0]->get_value().proxy_to_default_cpu();
     TensorShape tshape;
@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 }
 template<typename Op>
-cg::OperatorNodeBase* apply_on_var_node(
-        const OpDef& def, const VarNodeArray& inputs) {
+SymbolVar apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually",
-            nr_inp);
     auto&& rng = def.cast_final_safe<Op>();
+    mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
+            rng.dyn_typeinfo()->name,
+            nr_inp);
     auto param = OpMeth<Op>::make_param(rng);
-    return OpMeth<Op>::OpNode::make(
-            inputs[0], param, {rng.comp_node()}).node()->owner_opr();
+    OperatorNodeConfig config;
+    if (rng.handle) {
+        config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
+    } else {
+        config = {rng.make_name()};
+    }
+    return OpMeth<Op>::OpNode::make(inputs[0], param, config);
 }
 template<typename T>
@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 } // anonymous namespace
-RNGMixin::RNGMixin(CompNode cn):
-    m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
-uint64_t RNGMixin::seed() const {
-    return RNGDnnOpManager::get_seed(m_handle);
-}
-CompNode RNGMixin::comp_node() const {
-    return RNGDnnOpManager::get_comp_node(m_handle);
-}
-RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
+Handle new_handle(CompNode comp_node, uint64_t seed) {
     return RNGDnnOpManager::inst().new_handle(comp_node, seed);
 }
-size_t RNGMixin::delete_handle(Handle handle) {
+size_t delete_handle(Handle handle) {
     return RNGDnnOpManager::inst().delete_handle(handle);
 }
-void set_rng_seed(uint64_t seed) {
+void set_global_rng_seed(uint64_t seed) {
     RNGDnnOpManager::set_glob_default_seed(seed);
 }
+uint64_t get_global_rng_seed() {
+    return RNGDnnOpManager::get_glob_default_seed();
+}
 #define REG_RNG_OP(NAME)\
 namespace { \
 OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
     .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
     .fallback(); \
 } \
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);
 REG_RNG_OP(UniformRNG)
 REG_RNG_OP(GaussianRNG)
-} // namespace imperative
-} // namespace mgb
+} // namespace mgb::imperative::rng
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
     .fallback();
 }} // assert_equal
-namespace { namespace uniform_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const UniformRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::UniformRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(UniformRNG, UniformRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // uniform_rng
-namespace { namespace gaussian_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const GaussianRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::GaussianRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(GaussianRNG, GaussianRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // gaussian_rng
 namespace { namespace roi_align {
 VarNodeArray apply_on_var_node(
         const OpDef& def,
@@ -2,7 +2,7 @@
  * \file imperative/src/include/megbrain/imperative/ops/rng.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
@@ -12,84 +12,15 @@
 #pragma once
 #include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/ops/autogen.h"
-namespace mgb::imperative {
+namespace mgb::imperative::rng {
-class RNGMixin {
-public:
-    using Handle = size_t;
+using Handle = size_t;
-    static Handle new_handle(
-            CompNode comp_node={}, uint64_t seed=0);
+Handle new_handle(CompNode comp_node, uint64_t seed);
+size_t delete_handle(Handle handle);
+void set_global_rng_seed(uint64_t seed);
+uint64_t get_global_rng_seed();
-    static size_t delete_handle(Handle handle);
-    Handle handle() const {
-        return m_handle;
-    }
-    uint64_t seed() const;
-    CompNode comp_node() const;
-protected:
-    RNGMixin(Handle handle): m_handle(handle) {}
-    RNGMixin(CompNode comp_node);
-private:
-    Handle m_handle;
-};
-class GaussianRNG : public OpDefImplBase<GaussianRNG>,
-                    public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-public:
-    float mean = 1.0f, std = 0.0;
-    GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
-    GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
-        GaussianRNG(comp_node_) { mean = mean_; std = std_; }
-    GaussianRNG(float mean_, float std_, Handle handle):
-        RNGMixin(handle), mean(mean_), std(std_) {}
-    size_t hash() const override {
-        XXHash xxhash{};
-        auto append = [&xxhash](auto field){
-            auto hash_val = HashTrait<decltype(field)>::eval(field);
-            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
-        };
-        append(dyn_typeinfo());
-        append(seed());
-        append(mean);
-        append(std);
-        return xxhash.digest();
-    }
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
-        return rhs.seed() == seed()
-            && rhs.mean == mean
-            && rhs.std == std;
-    }
-};
-class UniformRNG : public OpDefImplBase<UniformRNG>,
-                   public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-public:
-    UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
-    UniformRNG(Handle handle): RNGMixin(handle) {}
-    size_t hash() const override {
-        return hash_pair_combine(
-                mgb::hash(seed()),
-                reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
-    }
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const UniformRNG&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-            && rhs.seed() == seed();
-    }
-};
-void set_rng_seed(uint64_t seed);
-} // namespace mgb::imperative
+} // namespace mgb::imperative::rng
@@ -14,6 +14,7 @@
 using namespace mgb;
 using namespace imperative;
+using namespace imperative::rng;
 template<typename Op, typename ...Args>
 void check_rng_basic(Args&& ...args) {
@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
             {3, 4, 5, 6},
             {2333}})
     for (auto&& cn: {
-            CompNode::load("cpu0"),
-            CompNode::load("xpu0")})
+            CompNode::load("xpu0"),
+            CompNode::load("xpu1")})
     {
-        auto op = Op::make(std::forward<Args>(args)..., cn);
+        Handle h = new_handle(cn, 123);
+        auto op = Op::make(std::forward<Args>(args)..., h);
         DeviceTensorND tshape_dev;
         cg::copy_shape_to_tensor_value(tshape_dev, tshape);
-        auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
+        SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
+        auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
         ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
         ASSERT_TRUE(cn == outputs[0]->comp_node());
+        // sync before deleting the handle
+        for (auto&& p: outputs) {
+            p->get_value();
+        }
+        delete_handle(h);
     }
 }
 TEST(TestImperative, UniformRNGBasic) {
-    check_rng_basic<UniformRNG>();
+    check_rng_basic<UniformRNG>(123);
 }
 TEST(TestImperative, GaussianRNGBasic) {
-    check_rng_basic<GaussianRNG>(2.f, 3.f);
+    check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
 }
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;
 def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;
 def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
-  let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
-  let cmpFunction = [{return true;}];
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
+  let hashFunction = [{
+    return mgb::hash_pair_combine(
+      mgb::hash($_self.dyn_typeinfo()),
+      mgb::hash($_self.handle));
+  }];
+  let cmpFunction = [{return $0.handle == $1.handle;}];
 }
 def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
   let hashFunction = [{
     return mgb::hash_pair_combine(
       mgb::hash($_self.dyn_typeinfo()),
-      mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
+      mgb::hash_pair_combine(
+        mgb::hash($_self.handle),
+        mgb::hash_pair_combine(
+          mgb::hash($_self.mean),
+          mgb::hash($_self.std))
+      )
+    );
   }];
-  let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
+  let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
 }
 def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {