From af349d6102bd17be3881ae9ccdee057bf4ab546b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 29 Sep 2020 14:48:57 +0800 Subject: [PATCH] fix(mge/functional): fix op mismatch when tracing NMSKeep GitOrigin-RevId: e8f2cbb7557b7482df936faca80f4fcc15eef22b --- imperative/python/megengine/functional/nn.py | 18 +++++++++++++++--- imperative/python/megengine/jit/tracing.py | 7 +++++++ imperative/python/test/unit/test_tracing.py | 22 ++++++++++++++++++++++ 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 6ba8ee0e..7d120b8e 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -17,6 +17,7 @@ from ..core.tensor import megbrain_graph, utils from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.utils import astensor1d from ..distributed import WORLD, is_distributed +from ..jit.tracing import is_tracing from ..random import uniform from ..tensor import Tensor from .debug_param import get_conv_execution_strategy @@ -1470,13 +1471,17 @@ def indexing_one_hot( return result -def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: +def nms( + boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None +) -> Tensor: r""" Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU). :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format. :param iou_thresh: IoU threshold for overlapping. :param scores: tensor of shape `(N,)`, the score of boxes. + :param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced + otherwise it required to be specified; if it is not specified, all boxes are kept. :return: indices of the elements that have been kept by NMS. Examples: @@ -1515,12 +1520,19 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: scores = scores.detach() sorted_idx = argsort(scores, descending=True) boxes = boxes[sorted_idx] - max_output = boxes.shape[0] + + if is_tracing(): + assert ( + max_output is not None and max_output > 0 + ), "max_output should be specified under tracing" + + if max_output is None: + max_output = boxes.shape[0] op = builtin.NMSKeep(iou_thresh, max_output) inp = utils.convert_inputs(boxes.reshape(1, -1, 4)) indices, count = apply(op, *inp) - indices = indices[0][: count.item()] + indices = indices[0][: count[0]] keep_inds = sorted_idx[indices] return keep_inds diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 1fe93c71..6fe59b46 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -36,6 +36,13 @@ active_trace = None skip_tracing = False +def is_tracing(): + if active_trace is None: + return False + else: + return not skip_tracing + + @contextlib.contextmanager def exclude_from_trace(): global skip_tracing diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index d54206d9..805d0121 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -357,3 +357,25 @@ def test_trace_broadcast(): f(x1) f(x2) f(x3) + + +def test_trace_nms(): + def make_inputs(n): + boxes = np.zeros((n, 4)) + boxes[:, :2] = np.random.rand(n, 2) * 100 + boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100 + + scores = np.random.rand(n) + + return tensor(boxes), tensor(scores) + + @trace(symbolic=False) + def f(boxes, scores): + results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) + with exclude_from_trace(): + _ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5) + return results + + f(*make_inputs(10)) + f(*make_inputs(20)) + f(*make_inputs(30))