|
|
@@ -673,7 +673,7 @@ def topk( |
|
|
|
:param descending: if True, return the largest elements instead. Default: False |
|
|
|
:param kth_only: if True, only the k-th element will be returned. Default: False |
|
|
|
:param no_sort: if True, the returned elements can be unordered. Default: False |
|
|
|
:return: tuple of two tensors `(topk_tensor, indices_of_int32)`. |
|
|
|
:return: tuple of two tensors ``(topk_tensor, indices_of_int32)`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
|
|
|
@@ -695,7 +695,7 @@ def topk( |
|
|
|
|
|
|
|
""" |
|
|
|
if descending: |
|
|
|
inp = -inp |
|
|
|
k = -k |
|
|
|
|
|
|
|
if kth_only: |
|
|
|
mode = "kth_only" |
|
|
@@ -709,21 +709,25 @@ def topk( |
|
|
|
(k,) = Const(k, dtype="int32", device=inp.device)() |
|
|
|
|
|
|
|
if len(inp.shape) == 1: |
|
|
|
inp = inp.reshape(1, -1) |
|
|
|
res = apply(op, inp, k) |
|
|
|
if kth_only: |
|
|
|
tns = res[0] |
|
|
|
(tns,) = apply(op, expand_dims(inp, 0), k) |
|
|
|
# FIXME: |
|
|
|
# could use a dedicated kernel |
|
|
|
# gradient may be routed to other indices if k-th value is not unique |
|
|
|
ind = argmax((tns == inp).astype("int8")) |
|
|
|
tns = squeeze(tns, 0) |
|
|
|
else: |
|
|
|
tns, ind = res[0][0], res[1][0] |
|
|
|
tns, ind = apply(op, expand_dims(inp, 0), k) |
|
|
|
tns = squeeze(tns, 0) |
|
|
|
ind = squeeze(ind, 0) |
|
|
|
else: |
|
|
|
res = apply(op, inp, k) |
|
|
|
if kth_only: |
|
|
|
tns = res |
|
|
|
(tns,) = apply(op, inp, k) |
|
|
|
# FIXME: same as above |
|
|
|
ind = argmax((expand_dims(tns, 1) == inp).astype("int8"), 1) |
|
|
|
else: |
|
|
|
tns, ind = res[0], res[1] |
|
|
|
tns, ind = apply(op, inp, k) |
|
|
|
|
|
|
|
if descending: |
|
|
|
tns = -tns |
|
|
|
return tns, ind |
|
|
|
|
|
|
|
|
|
|
|