|
|
@@ -11,6 +11,7 @@ import itertools |
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
|
import megengine.core.ops.builtin as builtin |
|
|
|
import megengine.core.tensor.dtype as dtype |
|
|
|
import megengine.functional as F |
|
|
|
from megengine import Buffer, Parameter, is_cuda_available, tensor |
|
|
@@ -631,3 +632,20 @@ def test_condtake(): |
|
|
|
val, idx = F.cond_take(yy, xx) |
|
|
|
np.testing.assert_equal(val.numpy(), x[y]) |
|
|
|
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) |
|
|
|
|
|
|
|
|
|
|
|
def test_condtake_is_same(): |
|
|
|
op1 = builtin.CondTake() |
|
|
|
op2 = builtin.CondTake() |
|
|
|
assert op1 == op2 |
|
|
|
|
|
|
|
|
|
|
|
def test_nms_is_same(): |
|
|
|
op1 = builtin.NMSKeep(0.7, 100) |
|
|
|
op2 = builtin.NMSKeep(0.7, 100) |
|
|
|
op3 = builtin.NMSKeep(0.8, 100) |
|
|
|
op4 = builtin.NMSKeep(0.7, 200) |
|
|
|
assert op1 == op2 |
|
|
|
assert op1 != op3 |
|
|
|
assert op1 != op4 |
|
|
|
assert op3 != op4 |