@@ -233,9 +233,11 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: | |||||
[0 1 2 3 3] | [0 1 2 3 3] | ||||
""" | """ | ||||
assert lower or upper, "At least one of 'lower' or 'upper' must not be None" | |||||
if lower: | |||||
if upper: | |||||
assert ( | |||||
lower is not None or upper is not None | |||||
), "At least one of 'lower' or 'upper' must not be None" | |||||
if lower is not None: | |||||
if upper is not None: | |||||
assert lower <= upper, "clamp lower bound is bigger that upper bound" | assert lower <= upper, "clamp lower bound is bigger that upper bound" | ||||
return minimum(maximum(inp, lower), upper) | return minimum(maximum(inp, lower), upper) | ||||
else: | else: | ||||
@@ -44,3 +44,12 @@ def test_multiply(): | |||||
np.array([3.0, 4.0], dtype=np.float32), | np.array([3.0, 4.0], dtype=np.float32), | ||||
), | ), | ||||
) | ) | ||||
def test_clamp(): | |||||
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and | |||||
`F.clamp` will fall into wrong conditions unexpectedly. | |||||
""" | |||||
x = np.linspace(-6, 6, dtype="float32") | |||||
assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6)) | |||||
assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0)) |