diff --git a/python_module/megengine/functional/elemwise.py b/python_module/megengine/functional/elemwise.py index 4e23287e..80d9550a 100644 --- a/python_module/megengine/functional/elemwise.py +++ b/python_module/megengine/functional/elemwise.py @@ -11,6 +11,7 @@ import functools import megengine._internal as mgb +from ..core.graph import _use_default_if_none from ..core.tensor import Tensor, wrap_io_tensor __all__ = [ @@ -45,11 +46,17 @@ __all__ = [ def _elemwise(mode): # DONT export """Decorator helps to wrap megbrain element-wise oprs""" - def elemwise_decorator(func): @functools.wraps(func) @wrap_io_tensor def elemwise_func(*inputs) -> Tensor: + if all(isinstance(i, (int,float)) for i in inputs): + device, comp_graph = _use_default_if_none(None, None) + ret = mgb.opr.elemwise(*inputs, + mode=mode, + comp_node=device, + comp_graph=comp_graph) + return ret.inferred_value[0] return mgb.opr.elemwise(*inputs, mode=mode) return elemwise_func diff --git a/python_module/test/unit/functional/test_elemwise.py b/python_module/test/unit/functional/test_elemwise.py new file mode 100644 index 00000000..67dc84d8 --- /dev/null +++ b/python_module/test/unit/functional/test_elemwise.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np + +import megengine.functional as F +from megengine import tensor + +from megengine.test import assertTensorClose + + +def test_abs(): + assertTensorClose( + F.abs(tensor([-3., -4., -5.])).numpy(), + np.abs(np.array([-3., -4., -5.], dtype=np.float32))) + + assertTensorClose(F.abs(-3.), np.abs(np.float32(-3.))) + + +def test_multiply(): + assertTensorClose(F.multiply(-3., -4.), + np.multiply(np.float32(-3.), np.float32(-4.))) + + assertTensorClose( + F.multiply(tensor([3., 4.]), 4.).numpy(), + np.multiply(np.array([3., 4.], dtype=np.float32), 4.)) + + assertTensorClose( + F.multiply(4., tensor([3., 4.])).numpy(), + np.multiply(4., np.array([3., 4.], dtype=np.float32))) + + assertTensorClose( + F.multiply(tensor([3., 4.]), tensor([3., 4.])).numpy(), + np.multiply(np.array([3., 4.], dtype=np.float32), + np.array([3., 4.], dtype=np.float32)))