From 37e56f4b04174837c922a140c9111fa2828466dc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 13 Oct 2020 16:10:39 +0800 Subject: [PATCH] fix(mge/functional): fix clip under trace(symbolic=True) GitOrigin-RevId: 5b6f5373270bf4699574feacc4b391b08ecdf6e9 --- imperative/python/megengine/functional/elemwise.py | 4 +++- imperative/python/test/unit/test_tracing.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 3b71291c..686ddf4c 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -13,6 +13,7 @@ from ..core.ops import builtin from ..core.tensor import megbrain_graph, utils from ..core.tensor.core import apply from ..device import get_default_device +from ..jit.tracing import is_tracing from ..tensor import Tensor __all__ = [ @@ -580,7 +581,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor: ), "At least one of 'lower' or 'upper' must not be None" if lower is not None: if upper is not None: - assert lower <= upper, "clip lower bound is bigger that upper bound" + if not is_tracing(): + assert lower <= upper, "clip lower bound is bigger that upper bound" return minimum(maximum(x, lower), upper) else: return maximum(x, lower) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index bca796a3..abc2463f 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -394,3 +394,15 @@ def test_trace_valid_broadcast(): f(x1, shape) f(x2, shape) + + +def test_clip(): + x = tensor(np.random.randn(10, 10)) + + @trace(symbolic=True) + def f(x, lower, upper): + y = F.clip(x, lower, upper) + return y + + for i in range(3): + f(x, tensor([0]), tensor([1]))