From bb92ee26e302c09e1dcdf899e15f2cdbe4af9634 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 20 Apr 2020 13:20:12 +0800 Subject: [PATCH] fix(mge/functional): refine doc and add test for linspace GitOrigin-RevId: d6c7eea2c23ea37fa3cda668132aa6360fb24e51 --- python_module/megengine/functional/tensor.py | 2 +- .../test/unit/functional/test_functional.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/python_module/megengine/functional/tensor.py b/python_module/megengine/functional/tensor.py index 31aedc27..2cc58a57 100644 --- a/python_module/megengine/functional/tensor.py +++ b/python_module/megengine/functional/tensor.py @@ -515,7 +515,7 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: def linspace( start: Union[int, float, Tensor], stop: Union[int, float, Tensor], - num: int = 100, + num: Union[int, Tensor], dtype=np.float32, device: Optional[CompNode] = None, comp_graph: Optional[CompGraph] = None, diff --git a/python_module/test/unit/functional/test_functional.py b/python_module/test/unit/functional/test_functional.py index 3b88cd8c..b9f0cebf 100644 --- a/python_module/test/unit/functional/test_functional.py +++ b/python_module/test/unit/functional/test_functional.py @@ -157,6 +157,28 @@ def test_broadcast_to(): opr_test(cases, F.broadcast_to, compare_fn=compare_fn) +def test_linspace(): + cases = [ + {"input": [1, 9, 9]}, + {"input": [3, 10, 8]}, + ] + opr_test( + cases, + F.linspace, + ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + ) + + cases = [ + {"input": [9, 1, 9]}, + {"input": [10, 3, 8]}, + ] + opr_test( + cases, + F.linspace, + ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + ) + + def test_arange(): cases = [ {"input": [1, 9, 1]},