GitOrigin-RevId: d6c7eea2c2
tags/v0.4.0
@@ -515,7 +515,7 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: | |||||
def linspace( | def linspace( | ||||
start: Union[int, float, Tensor], | start: Union[int, float, Tensor], | ||||
stop: Union[int, float, Tensor], | stop: Union[int, float, Tensor], | ||||
num: int = 100, | |||||
num: Union[int, Tensor], | |||||
dtype=np.float32, | dtype=np.float32, | ||||
device: Optional[CompNode] = None, | device: Optional[CompNode] = None, | ||||
comp_graph: Optional[CompGraph] = None, | comp_graph: Optional[CompGraph] = None, | ||||
@@ -157,6 +157,28 @@ def test_broadcast_to(): | |||||
opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | opr_test(cases, F.broadcast_to, compare_fn=compare_fn) | ||||
def test_linspace(): | |||||
cases = [ | |||||
{"input": [1, 9, 9]}, | |||||
{"input": [3, 10, 8]}, | |||||
] | |||||
opr_test( | |||||
cases, | |||||
F.linspace, | |||||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||||
) | |||||
cases = [ | |||||
{"input": [9, 1, 9]}, | |||||
{"input": [10, 3, 8]}, | |||||
] | |||||
opr_test( | |||||
cases, | |||||
F.linspace, | |||||
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), | |||||
) | |||||
def test_arange(): | def test_arange(): | ||||
cases = [ | cases = [ | ||||
{"input": [1, 9, 1]}, | {"input": [1, 9, 1]}, | ||||