Browse Source

fix(mge/functional): fix trace topk

GitOrigin-RevId: c88ca8219b
release-1.1
Megvii Engine Team 4 years ago
parent
commit
f4860b9345
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      imperative/python/megengine/functional/math.py

+ 7
- 3
imperative/python/megengine/functional/math.py View File

@@ -14,8 +14,9 @@ from typing import Optional, Sequence, Tuple, Union

from ..core.ops import builtin
from ..core.ops._internal import param_defs as P
from ..core.ops.special import Const
from ..core.tensor import utils
from ..core.tensor.core import apply
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor import Tensor
from .elemwise import clamp, exp, log, log1p
from .tensor import add_axis, remove_axis, reshape
@@ -665,15 +666,18 @@ def topk(
mode = Mode.VALUE_IDX_SORTED
op = builtin.TopK(mode=mode)

if not isinstance(k, (TensorBase, TensorWrapperBase)):
(k,) = Const(k, dtype="int32", device=inp.device)(inp)

if len(inp.shape) == 1:
inp = inp.reshape(1, -1)
res = apply(op, inp, Tensor(k, dtype="int32"))
res = apply(op, inp, k)
if kth_only:
tns = res[0]
else:
tns, ind = res[0][0], res[1][0]
else:
res = apply(op, inp, Tensor(k, dtype="int32"))
res = apply(op, inp, k)
if kth_only:
tns = res
else:


Loading…
Cancel
Save