Browse Source

fix(mge/clamp): fix `F.clamp`

GitOrigin-RevId: 1efac8add6
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
207527d108
2 changed files with 14 additions and 3 deletions
  1. +5
    -3
      python_module/megengine/functional/elemwise.py
  2. +9
    -0
      python_module/test/unit/functional/test_elemwise.py

+ 5
- 3
python_module/megengine/functional/elemwise.py View File

@@ -233,9 +233,11 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
[0 1 2 3 3]

"""
assert lower or upper, "At least one of 'lower' or 'upper' must not be None"
if lower:
if upper:
assert (
lower is not None or upper is not None
), "At least one of 'lower' or 'upper' must not be None"
if lower is not None:
if upper is not None:
assert lower <= upper, "clamp lower bound is bigger that upper bound"
return minimum(maximum(inp, lower), upper)
else:


+ 9
- 0
python_module/test/unit/functional/test_elemwise.py View File

@@ -44,3 +44,12 @@ def test_multiply():
np.array([3.0, 4.0], dtype=np.float32),
),
)


def test_clamp():
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and
`F.clamp` will fall into wrong conditions unexpectedly.
"""
x = np.linspace(-6, 6, dtype="float32")
assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6))
assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0))

Loading…
Cancel
Save