Browse Source

fix(mge/functional): fix clip under trace(symbolic=True)

GitOrigin-RevId: 5b6f537327
release-1.1
Megvii Engine Team 4 years ago
parent
commit
37e56f4b04
2 changed files with 15 additions and 1 deletions
  1. +3
    -1
      imperative/python/megengine/functional/elemwise.py
  2. +12
    -0
      imperative/python/test/unit/test_tracing.py

+ 3
- 1
imperative/python/megengine/functional/elemwise.py View File

@@ -13,6 +13,7 @@ from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor

__all__ = [
@@ -580,7 +581,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
), "At least one of 'lower' or 'upper' must not be None"
if lower is not None:
if upper is not None:
assert lower <= upper, "clip lower bound is bigger that upper bound"
if not is_tracing():
assert lower <= upper, "clip lower bound is bigger that upper bound"
return minimum(maximum(x, lower), upper)
else:
return maximum(x, lower)


+ 12
- 0
imperative/python/test/unit/test_tracing.py View File

@@ -394,3 +394,15 @@ def test_trace_valid_broadcast():

f(x1, shape)
f(x2, shape)


def test_clip():
x = tensor(np.random.randn(10, 10))

@trace(symbolic=True)
def f(x, lower, upper):
y = F.clip(x, lower, upper)
return y

for i in range(3):
f(x, tensor([0]), tensor([1]))

Loading…
Cancel
Save