Browse Source

fix(mge/functional): fix F.topk(kth_only=True)

GitOrigin-RevId: ddecd1d14b
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
1040b77843
2 changed files with 51 additions and 11 deletions
  1. +15
    -11
      imperative/python/megengine/functional/math.py
  2. +36
    -0
      imperative/python/test/unit/functional/test_math.py

+ 15
- 11
imperative/python/megengine/functional/math.py View File

@@ -673,7 +673,7 @@ def topk(
:param descending: if True, return the largest elements instead. Default: False :param descending: if True, return the largest elements instead. Default: False
:param kth_only: if True, only the k-th element will be returned. Default: False :param kth_only: if True, only the k-th element will be returned. Default: False
:param no_sort: if True, the returned elements can be unordered. Default: False :param no_sort: if True, the returned elements can be unordered. Default: False
:return: tuple of two tensors `(topk_tensor, indices_of_int32)`.
:return: tuple of two tensors ``(topk_tensor, indices_of_int32)``


Examples: Examples:


@@ -695,7 +695,7 @@ def topk(


""" """
if descending: if descending:
inp = -inp
k = -k


if kth_only: if kth_only:
mode = "kth_only" mode = "kth_only"
@@ -709,21 +709,25 @@ def topk(
(k,) = Const(k, dtype="int32", device=inp.device)() (k,) = Const(k, dtype="int32", device=inp.device)()


if len(inp.shape) == 1: if len(inp.shape) == 1:
inp = inp.reshape(1, -1)
res = apply(op, inp, k)
if kth_only: if kth_only:
tns = res[0]
(tns,) = apply(op, expand_dims(inp, 0), k)
# FIXME:
# could use a dedicated kernel
# gradient may be routed to other indices if k-th value is not unique
ind = argmax((tns == inp).astype("int8"))
tns = squeeze(tns, 0)
else: else:
tns, ind = res[0][0], res[1][0]
tns, ind = apply(op, expand_dims(inp, 0), k)
tns = squeeze(tns, 0)
ind = squeeze(ind, 0)
else: else:
res = apply(op, inp, k)
if kth_only: if kth_only:
tns = res
(tns,) = apply(op, inp, k)
# FIXME: same as above
ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1)
else: else:
tns, ind = res[0], res[1]
tns, ind = apply(op, inp, k)


if descending:
tns = -tns
return tns, ind return tns, ind






+ 36
- 0
imperative/python/test/unit/functional/test_math.py View File

@@ -168,3 +168,39 @@ def test_has_inf():
data[0][0][0][0] = float("inf") data[0][0][0][0] = float("inf")
rst = F.math._has_inf(tensor(data)) rst = F.math._has_inf(tensor(data))
np.testing.assert_equal(rst.numpy(), [1]) np.testing.assert_equal(rst.numpy(), [1])


@pytest.mark.parametrize("descending", [True, False])
@pytest.mark.parametrize("sorted", [True, False])
@pytest.mark.parametrize("inp1d", [True, False])
@pytest.mark.parametrize("kth_only", [True, False])
def test_topk(descending, sorted, inp1d, kth_only):
k = 3
if inp1d:
data = np.random.permutation(7)
else:
data = np.random.permutation(5 * 7).reshape(5, 7)
data = data.astype(np.int32)

def np_sort(x):
if descending:
return np.sort(x)[..., ::-1]
return np.sort(x)

res = F.topk(
tensor(data), k, descending=descending, no_sort=(not sorted), kth_only=kth_only
)

values, indices = res
values = values.numpy()
indices = indices.numpy()
if kth_only:
np.testing.assert_equal(
values, np.take_along_axis(data, indices[..., None], -1).squeeze(-1)
)
np.testing.assert_equal(values, np_sort(data)[..., k - 1])
else:
np.testing.assert_equal(values, np.take_along_axis(data, indices, -1))
if not sorted:
values = np_sort(values)
np.testing.assert_equal(values, np_sort(data)[..., :k])

Loading…
Cancel
Save