Browse Source

feat(mge/functional): add elemwise mode support string input

GitOrigin-RevId: 57be5cec7b
release-1.1
Megvii Engine Team 4 years ago
parent
commit
9cb3c07cd4
6 changed files with 65 additions and 5 deletions
  1. +20
    -0
      imperative/python/megengine/functional/elemwise.py
  2. +1
    -3
      imperative/python/megengine/module/elemwise.py
  3. +1
    -1
      imperative/python/megengine/module/qat/elemwise.py
  4. +1
    -1
      imperative/python/megengine/module/quantized/elemwise.py
  5. +12
    -0
      imperative/python/test/unit/functional/test_elemwise.py
  6. +30
    -0
      imperative/python/test/unit/module/test_elemwise.py

+ 20
- 0
imperative/python/megengine/functional/elemwise.py View File

@@ -72,7 +72,27 @@ __all__ = [
]


class _ElemwiseMode(Elemwise.Mode):
@classmethod
def __normalize(cls, val):
if isinstance(val, str):
if not hasattr(cls, "__member_upper_dict__"):
cls.__member_upper_dict__ = {
k.upper(): v for k, v in cls.__members__.items()
}
val = cls.__member_upper_dict__.get(val.upper(), val)
return val

@classmethod
def convert(cls, val):
val = cls.__normalize(val)
if isinstance(val, cls):
return val
return cls(val)


def _elwise(*args, mode):
mode = _ElemwiseMode.convert(mode)
op = builtin.Elemwise(mode)
tensor_args = list(
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args)


+ 1
- 3
imperative/python/megengine/module/elemwise.py View File

@@ -73,11 +73,9 @@ class Elemwise(Module):
* "NOT": bool unary: ~x
"""

_elemwise_mode_type = P.Elemwise.Mode

def __init__(self, method):
super().__init__()
self.method = self._elemwise_mode_type.convert(method)
self.method = method

def forward(self, *inps):
return _elwise(*inps, mode=self.method)

+ 1
- 1
imperative/python/megengine/module/qat/elemwise.py View File

@@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule):
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls(float_module.method.name)
return cls(float_module.method)

+ 1
- 1
imperative/python/megengine/module/quantized/elemwise.py View File

@@ -33,4 +33,4 @@ class Elemwise(QuantizedModule):
Return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.method.name, qat_module.get_activation_dtype())
return cls(qat_module.method, qat_module.get_activation_dtype())

+ 12
- 0
imperative/python/test/unit/functional/test_elemwise.py View File

@@ -10,6 +10,7 @@ import numpy as np

import megengine.functional as F
from megengine import tensor
from megengine.functional.elemwise import _elwise


def test_abs():
@@ -21,6 +22,17 @@ def test_abs():
np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0)))


def test_elemwise_mode_string():
np.testing.assert_allclose(
_elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(),
np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)),
)

np.testing.assert_allclose(
_elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0))
)


def test_multiply():
np.testing.assert_allclose(
F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0))


+ 30
- 0
imperative/python/test/unit/module/test_elemwise.py View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.functional as F
from megengine import tensor
from megengine.module import Elemwise


def test_module_elemwise():
def test_func(method, *inps):
elemwise = Elemwise(method)
outputs = elemwise(*inps)
return outputs.numpy()

x = np.random.rand(100).astype("float32")
y = np.random.rand(100).astype("float32")
x, y = tensor(x), tensor(y)
np.testing.assert_almost_equal(
test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6
)
np.testing.assert_almost_equal(
test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6
)

Loading…
Cancel
Save