Browse Source

fix(mge/functional): support scalar inputs in elemwise functions

GitOrigin-RevId: 7bce561ee1
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
8d2bbf7383
2 changed files with 48 additions and 1 deletions
  1. +8
    -1
      python_module/megengine/functional/elemwise.py
  2. +40
    -0
      python_module/test/unit/functional/test_elemwise.py

+ 8
- 1
python_module/megengine/functional/elemwise.py View File

@@ -11,6 +11,7 @@ import functools

import megengine._internal as mgb

from ..core.graph import _use_default_if_none
from ..core.tensor import Tensor, wrap_io_tensor

__all__ = [
@@ -45,11 +46,17 @@ __all__ = [

def _elemwise(mode): # DONT export
"""Decorator helps to wrap megbrain element-wise oprs"""

def elemwise_decorator(func):
@functools.wraps(func)
@wrap_io_tensor
def elemwise_func(*inputs) -> Tensor:
if all(isinstance(i, (int,float)) for i in inputs):
device, comp_graph = _use_default_if_none(None, None)
ret = mgb.opr.elemwise(*inputs,
mode=mode,
comp_node=device,
comp_graph=comp_graph)
return ret.inferred_value[0]
return mgb.opr.elemwise(*inputs, mode=mode)

return elemwise_func


+ 40
- 0
python_module/test/unit/functional/test_elemwise.py View File

@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.functional as F
from megengine import tensor

from megengine.test import assertTensorClose


def test_abs():
assertTensorClose(
F.abs(tensor([-3., -4., -5.])).numpy(),
np.abs(np.array([-3., -4., -5.], dtype=np.float32)))

assertTensorClose(F.abs(-3.), np.abs(np.float32(-3.)))


def test_multiply():
assertTensorClose(F.multiply(-3., -4.),
np.multiply(np.float32(-3.), np.float32(-4.)))

assertTensorClose(
F.multiply(tensor([3., 4.]), 4.).numpy(),
np.multiply(np.array([3., 4.], dtype=np.float32), 4.))

assertTensorClose(
F.multiply(4., tensor([3., 4.])).numpy(),
np.multiply(4., np.array([3., 4.], dtype=np.float32)))

assertTensorClose(
F.multiply(tensor([3., 4.]), tensor([3., 4.])).numpy(),
np.multiply(np.array([3., 4.], dtype=np.float32),
np.array([3., 4.], dtype=np.float32)))

Loading…
Cancel
Save