From ea00b57d93af277ab1ecb754a5a4969f2a774331 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 23 Mar 2020 17:45:07 +0800 Subject: [PATCH] refactor(mge/module): use mge rng to sample from uniform/normal distribution GitOrigin-RevId: 6ec8f99af5cc3f82fe70fe357e274f67eb1d4c36 --- python_module/megengine/module/init.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python_module/megengine/module/init.py b/python_module/megengine/module/init.py index 69cf257b..38945bb4 100644 --- a/python_module/megengine/module/init.py +++ b/python_module/megengine/module/init.py @@ -12,7 +12,8 @@ from typing import Optional, Tuple, Union import numpy as np -from ..core import Tensor +from ..core import Tensor, Graph +from ..random import gaussian, uniform def fill_(tensor: Tensor, val: Union[float, int]) -> None: @@ -48,7 +49,8 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None: :param a: Lower bound of the sampling interval :param b: Upper bound of the sampling interval """ - tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype)) + with Graph(eager_evaluation=True): + tensor.set_value((b - a) * uniform(tensor.shape) + a) def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: @@ -59,7 +61,8 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: :param mean: The mean of the normal distribution :param std: The standard deviation of the normal distribution """ - tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32)) + with Graph(eager_evaluation=True): + tensor.set_value(gaussian(tensor.shape, mean=mean, std=std)) def calculate_gain(