|
|
@@ -14,8 +14,9 @@ from typing import Optional, Sequence, Tuple, Union |
|
|
|
|
|
|
|
from ..core.ops import builtin |
|
|
|
from ..core.ops._internal import param_defs as P |
|
|
|
from ..core.ops.special import Const |
|
|
|
from ..core.tensor import utils |
|
|
|
from ..core.tensor.core import apply |
|
|
|
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply |
|
|
|
from ..tensor import Tensor |
|
|
|
from .elemwise import clamp, exp, log, log1p |
|
|
|
from .tensor import add_axis, remove_axis, reshape |
|
|
@@ -665,15 +666,18 @@ def topk( |
|
|
|
mode = Mode.VALUE_IDX_SORTED |
|
|
|
op = builtin.TopK(mode=mode) |
|
|
|
|
|
|
|
if not isinstance(k, (TensorBase, TensorWrapperBase)): |
|
|
|
(k,) = Const(k, dtype="int32", device=inp.device)(inp) |
|
|
|
|
|
|
|
if len(inp.shape) == 1: |
|
|
|
inp = inp.reshape(1, -1) |
|
|
|
res = apply(op, inp, Tensor(k, dtype="int32")) |
|
|
|
res = apply(op, inp, k) |
|
|
|
if kth_only: |
|
|
|
tns = res[0] |
|
|
|
else: |
|
|
|
tns, ind = res[0][0], res[1][0] |
|
|
|
else: |
|
|
|
res = apply(op, inp, Tensor(k, dtype="int32")) |
|
|
|
res = apply(op, inp, k) |
|
|
|
if kth_only: |
|
|
|
tns = res |
|
|
|
else: |
|
|
|