GitOrigin-RevId: 57be5cec7b
release-1.1
@@ -72,7 +72,27 @@ __all__ = [ | |||||
] | ] | ||||
class _ElemwiseMode(Elemwise.Mode): | |||||
@classmethod | |||||
def __normalize(cls, val): | |||||
if isinstance(val, str): | |||||
if not hasattr(cls, "__member_upper_dict__"): | |||||
cls.__member_upper_dict__ = { | |||||
k.upper(): v for k, v in cls.__members__.items() | |||||
} | |||||
val = cls.__member_upper_dict__.get(val.upper(), val) | |||||
return val | |||||
@classmethod | |||||
def convert(cls, val): | |||||
val = cls.__normalize(val) | |||||
if isinstance(val, cls): | |||||
return val | |||||
return cls(val) | |||||
def _elwise(*args, mode): | def _elwise(*args, mode): | ||||
mode = _ElemwiseMode.convert(mode) | |||||
op = builtin.Elemwise(mode) | op = builtin.Elemwise(mode) | ||||
tensor_args = list( | tensor_args = list( | ||||
filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | filter(lambda x: isinstance(x, (Tensor, megbrain_graph.VarNode)), args) | ||||
@@ -73,11 +73,9 @@ class Elemwise(Module): | |||||
* "NOT": bool unary: ~x | * "NOT": bool unary: ~x | ||||
""" | """ | ||||
_elemwise_mode_type = P.Elemwise.Mode | |||||
def __init__(self, method): | def __init__(self, method): | ||||
super().__init__() | super().__init__() | ||||
self.method = self._elemwise_mode_type.convert(method) | |||||
self.method = method | |||||
def forward(self, *inps): | def forward(self, *inps): | ||||
return _elwise(*inps, mode=self.method) | return _elwise(*inps, mode=self.method) |
@@ -28,4 +28,4 @@ class Elemwise(Float.Elemwise, QATModule): | |||||
Return a :class:`~.QATModule` instance converted from | Return a :class:`~.QATModule` instance converted from | ||||
a float :class:`~.Module` instance. | a float :class:`~.Module` instance. | ||||
""" | """ | ||||
return cls(float_module.method.name) | |||||
return cls(float_module.method) |
@@ -33,4 +33,4 @@ class Elemwise(QuantizedModule): | |||||
Return a :class:`~.QuantizedModule` instance converted from a | Return a :class:`~.QuantizedModule` instance converted from a | ||||
:class:`~.QATModule` instance. | :class:`~.QATModule` instance. | ||||
""" | """ | ||||
return cls(qat_module.method.name, qat_module.get_activation_dtype()) | |||||
return cls(qat_module.method, qat_module.get_activation_dtype()) |
@@ -10,6 +10,7 @@ import numpy as np | |||||
import megengine.functional as F | import megengine.functional as F | ||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.functional.elemwise import _elwise | |||||
def test_abs(): | def test_abs(): | ||||
@@ -21,6 +22,17 @@ def test_abs(): | |||||
np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0))) | np.testing.assert_allclose(F.abs(-3.0).numpy(), np.abs(np.float32(-3.0))) | ||||
def test_elemwise_mode_string(): | |||||
np.testing.assert_allclose( | |||||
_elwise(tensor([-3.0, -4.0, -5.0]), mode="ABS").numpy(), | |||||
np.abs(np.array([-3.0, -4.0, -5.0], dtype=np.float32)), | |||||
) | |||||
np.testing.assert_allclose( | |||||
_elwise(-3.0, mode="ABS").numpy(), np.abs(np.float32(-3.0)) | |||||
) | |||||
def test_multiply(): | def test_multiply(): | ||||
np.testing.assert_allclose( | np.testing.assert_allclose( | ||||
F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0)) | F.mul(-3.0, -4.0).numpy(), np.multiply(np.float32(-3.0), np.float32(-4.0)) | ||||
@@ -0,0 +1,30 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import numpy as np | |||||
import megengine.functional as F | |||||
from megengine import tensor | |||||
from megengine.module import Elemwise | |||||
def test_module_elemwise(): | |||||
def test_func(method, *inps): | |||||
elemwise = Elemwise(method) | |||||
outputs = elemwise(*inps) | |||||
return outputs.numpy() | |||||
x = np.random.rand(100).astype("float32") | |||||
y = np.random.rand(100).astype("float32") | |||||
x, y = tensor(x), tensor(y) | |||||
np.testing.assert_almost_equal( | |||||
test_func("H_SWISH", x), F.hswish(x).numpy(), decimal=6 | |||||
) | |||||
np.testing.assert_almost_equal( | |||||
test_func("ADD", x, y), F.add(x, y).numpy(), decimal=6 | |||||
) |