Browse Source

refactor(mge/module): use mge rng to sample from uniform/normal distribution

GitOrigin-RevId: 6ec8f99af5
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
ea00b57d93
1 changed files with 6 additions and 3 deletions
  1. +6
    -3
      python_module/megengine/module/init.py

+ 6
- 3
python_module/megengine/module/init.py View File

@@ -12,7 +12,8 @@ from typing import Optional, Tuple, Union

import numpy as np

from ..core import Tensor
from ..core import Tensor, Graph
from ..random import gaussian, uniform


def fill_(tensor: Tensor, val: Union[float, int]) -> None:
@@ -48,7 +49,8 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
:param a: Lower bound of the sampling interval
:param b: Upper bound of the sampling interval
"""
tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype))
with Graph(eager_evaluation=True):
tensor.set_value((b - a) * uniform(tensor.shape) + a)


def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
@@ -59,7 +61,8 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
:param mean: The mean of the normal distribution
:param std: The standard deviation of the normal distribution
"""
tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32))
with Graph(eager_evaluation=True):
tensor.set_value(gaussian(tensor.shape, mean=mean, std=std))


def calculate_gain(


Loading…
Cancel
Save