Browse Source

refactor(mge/functional): remove dependence to trace in functional implementations

GitOrigin-RevId: 0b18479fcc
release-1.4
Megvii Engine Team 4 years ago
parent
commit
a29bf679a8
2 changed files with 6 additions and 9 deletions
  1. +2
    -3
      imperative/python/megengine/functional/elemwise.py
  2. +4
    -6
      imperative/python/megengine/functional/vision.py

+ 2
- 3
imperative/python/megengine/functional/elemwise.py View File

@@ -16,7 +16,6 @@ from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astype
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func

@@ -560,8 +559,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
), "At least one of 'lower' or 'upper' must not be None"
if lower is not None:
if upper is not None:
if not is_tracing():
assert lower <= upper, "clip lower bound is bigger that upper bound"
# FIXME: following assertion won't work during trace if upper and lower are Tensors
# assert lower <= upper, "clip lower bound is bigger that upper bound"
return minimum(maximum(x, lower), upper)
else:
return maximum(x, lower)


+ 4
- 6
imperative/python/megengine/functional/vision.py View File

@@ -12,7 +12,6 @@ from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import astensor1d
from ..jit.tracing import is_tracing
from ..tensor import Tensor
from .elemwise import floor
from .math import argsort
@@ -226,6 +225,10 @@ def nms(
otherwise it required to be specified; if it is not specified, all boxes are kept.
:return: indices of the elements that have been kept by NMS, sorted by scores.

.. note::

max_output should be specified and should have valid positive value under tracing

Examples:

.. testcode::
@@ -263,11 +266,6 @@ def nms(
sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx]

if is_tracing():
assert (
max_output is not None and max_output > 0
), "max_output should be specified under tracing"

if max_output is None:
max_output = boxes.shape[0]



Loading…
Cancel
Save