Browse Source

fix(mge/functional): refine doc and add test for linspace

GitOrigin-RevId: d6c7eea2c2
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
bb92ee26e3
2 changed files with 23 additions and 1 deletions
  1. +1
    -1
      python_module/megengine/functional/tensor.py
  2. +22
    -0
      python_module/test/unit/functional/test_functional.py

+ 1
- 1
python_module/megengine/functional/tensor.py View File

@@ -515,7 +515,7 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
def linspace( def linspace(
start: Union[int, float, Tensor], start: Union[int, float, Tensor],
stop: Union[int, float, Tensor], stop: Union[int, float, Tensor],
num: int = 100,
num: Union[int, Tensor],
dtype=np.float32, dtype=np.float32,
device: Optional[CompNode] = None, device: Optional[CompNode] = None,
comp_graph: Optional[CompGraph] = None, comp_graph: Optional[CompGraph] = None,


+ 22
- 0
python_module/test/unit/functional/test_functional.py View File

@@ -157,6 +157,28 @@ def test_broadcast_to():
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) opr_test(cases, F.broadcast_to, compare_fn=compare_fn)




def test_linspace():
cases = [
{"input": [1, 9, 9]},
{"input": [3, 10, 8]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
)

cases = [
{"input": [9, 1, 9]},
{"input": [10, 3, 8]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
)


def test_arange(): def test_arange():
cases = [ cases = [
{"input": [1, 9, 1]}, {"input": [1, 9, 1]},


Loading…
Cancel
Save