Browse Source

fix(mge/functional): fix op mismatch when tracing NMSKeep

GitOrigin-RevId: e8f2cbb755
release-1.1
Megvii Engine Team 4 years ago
parent
commit
af349d6102
3 changed files with 44 additions and 3 deletions
  1. +15
    -3
      imperative/python/megengine/functional/nn.py
  2. +7
    -0
      imperative/python/megengine/jit/tracing.py
  3. +22
    -0
      imperative/python/test/unit/test_tracing.py

+ 15
- 3
imperative/python/megengine/functional/nn.py View File

@@ -17,6 +17,7 @@ from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing
from ..random import uniform
from ..tensor import Tensor
from .debug_param import get_conv_execution_strategy
@@ -1470,13 +1471,17 @@ def indexing_one_hot(
return result


def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
def nms(
boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
) -> Tensor:
r"""
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).

:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
:param iou_thresh: IoU threshold for overlapping.
:param scores: tensor of shape `(N,)`, the score of boxes.
:param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced
otherwise it required to be specified; if it is not specified, all boxes are kept.
:return: indices of the elements that have been kept by NMS.

Examples:
@@ -1515,12 +1520,19 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
scores = scores.detach()
sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx]
max_output = boxes.shape[0]

if is_tracing():
assert (
max_output is not None and max_output > 0
), "max_output should be specified under tracing"

if max_output is None:
max_output = boxes.shape[0]

op = builtin.NMSKeep(iou_thresh, max_output)
inp = utils.convert_inputs(boxes.reshape(1, -1, 4))
indices, count = apply(op, *inp)
indices = indices[0][: count.item()]
indices = indices[0][: count[0]]
keep_inds = sorted_idx[indices]
return keep_inds



+ 7
- 0
imperative/python/megengine/jit/tracing.py View File

@@ -36,6 +36,13 @@ active_trace = None
skip_tracing = False


def is_tracing():
if active_trace is None:
return False
else:
return not skip_tracing


@contextlib.contextmanager
def exclude_from_trace():
global skip_tracing


+ 22
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -357,3 +357,25 @@ def test_trace_broadcast():
f(x1)
f(x2)
f(x3)


def test_trace_nms():
def make_inputs(n):
boxes = np.zeros((n, 4))
boxes[:, :2] = np.random.rand(n, 2) * 100
boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100

scores = np.random.rand(n)

return tensor(boxes), tensor(scores)

@trace(symbolic=False)
def f(boxes, scores):
results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
with exclude_from_trace():
_ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5)
return results

f(*make_inputs(10))
f(*make_inputs(20))
f(*make_inputs(30))

Loading…
Cancel
Save