|
|
@@ -44,3 +44,12 @@ def test_multiply(): |
|
|
|
np.array([3.0, 4.0], dtype=np.float32), |
|
|
|
), |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def test_clamp(): |
|
|
|
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and |
|
|
|
`F.clamp` will fall into wrong conditions unexpectedly. |
|
|
|
""" |
|
|
|
x = np.linspace(-6, 6, dtype="float32") |
|
|
|
assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6)) |
|
|
|
assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0)) |