From dbce6526d685f58e28049a8373211557053e5138 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 26 Apr 2022 19:31:34 +0800 Subject: [PATCH] fix(mge/functional): fix return dtype of comparison function GitOrigin-RevId: 810e32a829ea2b1d0835b3791a6521549f867de1 --- imperative/python/megengine/functional/elemwise.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 2d33414f..b77feced 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -626,7 +626,7 @@ def logaddexp(x: Tensor, y: Tensor) -> Tensor: def equal(x, y): r"""Element-wise `(x == y)`.""" - return _elwise(x, y, mode=Elemwise.Mode.EQ) + return x == y def not_equal(x, y): @@ -636,22 +636,22 @@ def not_equal(x, y): def less(x, y): r"""Element-wise `(x < y)`.""" - return _elwise(x, y, mode=Elemwise.Mode.LT) + return x < y def less_equal(x, y): r"""Element-wise `(x <= y)`.""" - return _elwise(x, y, mode=Elemwise.Mode.LEQ) + return x <= y def greater(x, y): r"""Element-wise `(x > y)`.""" - return _elwise(y, x, mode=Elemwise.Mode.LT) + return x > y def greater_equal(x, y): r"""Element-wise `(x >= y)`.""" - return _elwise(y, x, mode=Elemwise.Mode.LEQ) + return x >= y # other functions