Browse Source

fix(mge/functional): fix linspace device and open other trace tests

GitOrigin-RevId: 4667c4adec
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
4485e780ae
3 changed files with 27 additions and 23 deletions
  1. +15
    -4
      imperative/python/megengine/functional/tensor.py
  2. +1
    -7
      imperative/python/test/unit/functional/test_functional.py
  3. +11
    -12
      imperative/python/test/unit/functional/test_tensor.py

+ 15
- 4
imperative/python/megengine/functional/tensor.py View File

@@ -910,7 +910,7 @@ def linspace(
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F


a = F.linspace(3,10,5)
a = F.linspace(3, 10, 5)
print(a.numpy()) print(a.numpy())


Outputs: Outputs:
@@ -920,9 +920,20 @@ def linspace(
[ 3. 4.75 6.5 8.25 10. ] [ 3. 4.75 6.5 8.25 10. ]


""" """
start = Tensor(start, device=device)
stop = Tensor(stop, device=device)
num = Tensor(num, device=device)
for item in (start, stop, num):
cur_device = getattr(item, "device", None)
if device is None:
device = cur_device
else:
if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr")

if not isinstance(start, Tensor):
start = Tensor(start, device=device)
if not isinstance(stop, Tensor):
stop = Tensor(stop, device=device)
if not isinstance(num, Tensor):
num = Tensor(num, device=device)


op = builtin.Linspace(comp_node=device) op = builtin.Linspace(comp_node=device)
(result,) = apply(op, start, stop, num) (result,) = apply(op, start, stop, num)


+ 1
- 7
imperative/python/test/unit/functional/test_functional.py View File

@@ -114,18 +114,12 @@ def test_matmul():
{"input": [data3, data4]}, {"input": [data3, data4]},
{"input": [data4, data5]}, {"input": [data4, data5]},
] ]
for _ in range(0, batch_size):
# FIXME: remove test_trace=False in the future
opr_test(
cases, F.matmul, test_trace=False, ref_fn=np.matmul,
)
opr_test(cases, F.matmul, ref_fn=np.matmul)


# FIXME: remove test_trace=False in the future
opr_test( opr_test(
[{"input": [data1, data4]}], [{"input": [data1, data4]}],
F.matmul, F.matmul,
ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
test_trace=False,
transpose_b=True, transpose_b=True,
) )




+ 11
- 12
imperative/python/test/unit/functional/test_tensor.py View File

@@ -162,24 +162,30 @@ def test_linspace():
{"input": [1, 9, 9]}, {"input": [1, 9, 9]},
{"input": [3, 10, 8]}, {"input": [3, 10, 8]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
) )


cases = [ cases = [
{"input": [9, 1, 9]}, {"input": [9, 1, 9]},
{"input": [10, 3, 8]}, {"input": [10, 3, 8]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
)

cases = [
{"input": [1, tensor(9), 9]},
{"input": [tensor(1), 9, tensor(9)]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
) )




@@ -188,36 +194,30 @@ def test_arange():
{"input": [1, 9, 1]}, {"input": [1, 9, 1]},
{"input": [2, 10, 2]}, {"input": [2, 10, 2]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )


cases = [ cases = [
{"input": [9, 1, -1]}, {"input": [9, 1, -1]},
{"input": [10, 2, -2]}, {"input": [10, 2, -2]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )


cases = [ cases = [
{"input": [9.3, 1.2, -0.5]}, {"input": [9.3, 1.2, -0.5]},
{"input": [10.3, 2.1, -1.7]}, {"input": [10.3, 2.1, -1.7]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )




@@ -289,8 +289,7 @@ def test_broadcast():
{"input": [data1, output1_shape], "output": output1_shape}, {"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape}, {"input": [data2, output2_shape], "output": output2_shape},
] ]
# FIXME: remove test_trace=False in the future
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)


x = F.ones((2, 1, 3)) x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):


Loading…
Cancel
Save