Browse Source

fix(mge/module): tensor shape will not work when constructing numpy array

GitOrigin-RevId: 5a0d705970
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
2beb65b19d
2 changed files with 17 additions and 4 deletions
  1. +5
    -3
      imperative/python/megengine/module/init.py
  2. +12
    -1
      imperative/python/test/unit/module/test_init.py

+ 5
- 3
imperative/python/megengine/module/init.py View File

@@ -12,6 +12,8 @@ from typing import Optional, Tuple, Union


import numpy as np import numpy as np


from ..functional import full
from ..random import gaussian, uniform
from ..tensor import Tensor from ..tensor import Tensor




@@ -21,7 +23,7 @@ def fill_(tensor: Tensor, val: Union[float, int]) -> None:
:param tensor: An n-dimentional tensor to be initialized :param tensor: An n-dimentional tensor to be initialized
:param val: The value to be filled throughout the tensor :param val: The value to be filled throughout the tensor
""" """
tensor.set_value(np.full(tensor.shape, val, tensor.dtype))
tensor.set_value(full(shape=tensor.shape, value=val, dtype=tensor.dtype))




def zeros_(tensor: Tensor) -> None: def zeros_(tensor: Tensor) -> None:
@@ -48,7 +50,7 @@ def uniform_(tensor: Tensor, a: float = 0.0, b: float = 1.0) -> None:
:param a: Lower bound of the sampling interval :param a: Lower bound of the sampling interval
:param b: Upper bound of the sampling interval :param b: Upper bound of the sampling interval
""" """
tensor.set_value(np.random.uniform(a, b, tensor.shape).astype(tensor.dtype))
tensor.set_value(uniform(tensor.shape, low=a, high=b).astype(tensor.dtype))




def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None: def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
@@ -59,7 +61,7 @@ def normal_(tensor: Tensor, mean: float = 0.0, std: float = 1.0) -> None:
:param mean: The mean of the normal distribution :param mean: The mean of the normal distribution
:param std: The standard deviation of the normal distribution :param std: The standard deviation of the normal distribution
""" """
tensor.set_value(np.random.normal(mean, std, tensor.shape).astype(np.float32))
tensor.set_value(gaussian(tensor.shape, mean=mean, std=std).astype(tensor.dtype))




def calculate_gain( def calculate_gain(


+ 12
- 1
imperative/python/test/unit/module/test_init.py View File

@@ -6,10 +6,21 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest import pytest


from megengine import tensor
from megengine.module import Conv2d, Linear from megengine.module import Conv2d, Linear
from megengine.module.init import calculate_fan_in_and_fan_out
from megengine.module.init import calculate_fan_in_and_fan_out, fill_


def test_fill_():
x = tensor(np.zeros((2, 3, 4)), dtype=np.float32)
fill_(x, 5.0)

np.testing.assert_array_equal(
x.numpy(), np.full(shape=(2, 3, 4), fill_value=5.0, dtype=np.float32)
)




def test_calculate_fan_in_and_fan_out(): def test_calculate_fan_in_and_fan_out():


Loading…
Cancel
Save