Browse Source

chore(mge/functional): add compatible code for functional api

GitOrigin-RevId: 3b2f829cc5
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
e9982d613f
5 changed files with 67 additions and 38 deletions
  1. +1
    -0
      imperative/python/megengine/functional/__init__.py
  2. +8
    -0
      imperative/python/megengine/functional/elemwise.py
  3. +17
    -38
      imperative/python/megengine/functional/nn.py
  4. +9
    -0
      imperative/python/megengine/functional/utils.py
  5. +32
    -0
      imperative/python/megengine/utils/deprecation.py

+ 1
- 0
imperative/python/megengine/functional/__init__.py View File

@@ -12,6 +12,7 @@ from .elemwise import *
from .math import *
from .nn import *
from .tensor import *
from .utils import *

from . import distributed # isort:skip



+ 8
- 0
imperative/python/megengine/functional/elemwise.py View File

@@ -19,6 +19,7 @@ from ..core.tensor.utils import astype
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func

__all__ = [
"abs",
@@ -567,3 +568,10 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
return maximum(x, lower)
else:
return minimum(x, upper)


sigmoid = deprecated_func("1.3", "megengine.functional.nn", "sigmoid", True)
hsigmoid = deprecated_func("1.3", "megengine.functional.nn", "hsigmoid", True)
relu = deprecated_func("1.3", "megengine.functional.nn", "relu", True)
relu6 = deprecated_func("1.3", "megengine.functional.nn", "relu6", True)
hswish = deprecated_func("1.3", "megengine.functional.nn", "hswish", True)

+ 17
- 38
imperative/python/megengine/functional/nn.py View File

@@ -22,10 +22,11 @@ from ..device import get_default_device
from ..distributed import WORLD, is_distributed
from ..random import uniform
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from ..utils.tuple_function import _pair, _pair_nonzero, _triple, _triple_nonzero
from .debug_param import get_execution_strategy
from .distributed import all_reduce_sum
from .elemwise import exp, floor, log, log1p, maximum, minimum
from .elemwise import _elwise, exp, floor, log, log1p, maximum, minimum
from .math import argsort, matmul, max, prod, sum
from .tensor import (
broadcast_to,
@@ -70,6 +71,10 @@ __all__ = [
"relu",
"relu6",
"hswish",
"resize",
"remap",
"warp_affine",
"warp_perspective",
]


@@ -1434,43 +1439,6 @@ def nvof(src: Tensor, precision: int = 1) -> Tensor:
return apply(op, src)[0]


def _elwise(*args, mode):
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args))
if len(tensor_args) == 0:
dtype = utils.dtype_promotion(args)
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device())
args = utils.convert_inputs(first_arg, *args[1:])
else:
args = utils.convert_inputs(*args)
if mode in (
Elemwise.Mode.TRUE_DIV,
Elemwise.Mode.EXP,
Elemwise.Mode.POW,
Elemwise.Mode.LOG,
Elemwise.Mode.EXPM1,
Elemwise.Mode.LOG1P,
Elemwise.Mode.TANH,
Elemwise.Mode.ACOS,
Elemwise.Mode.ASIN,
Elemwise.Mode.ATAN2,
Elemwise.Mode.CEIL,
Elemwise.Mode.COS,
Elemwise.Mode.FLOOR,
Elemwise.Mode.H_SWISH,
Elemwise.Mode.ROUND,
Elemwise.Mode.SIGMOID,
Elemwise.Mode.SIN,
):
if mode in (
Elemwise.Mode.CEIL,
Elemwise.Mode.FLOOR,
Elemwise.Mode.ROUND,
) and np.issubdtype(args[0].dtype, np.integer):
return args[0]
args = tuple(map(lambda x: astype(x, "float32"), args))
return _elwise_apply(args, mode)


def hswish(x):
"""
Element-wise `x * relu6(x + 3) / 6`.
@@ -1518,5 +1486,16 @@ def relu6(x):
return minimum(maximum(x, 0), 6)


interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
nms = deprecated_func("1.3", "megengine.functional.vision", "nms", True)
resize = deprecated_func("1.3", "megengine.functional.vision", "resize", True)
remap = deprecated_func("1.3", "megengine.functional.vision", "remap", True)
warp_affine = deprecated_func("1.3", "megengine.functional.vision", "warp_affine", True)
warp_perspective = deprecated_func(
"1.3", "megengine.functional.vision", "warp_perspective", True
)

from .loss import * # isort:skip
from .quantized import conv_bias_activation # isort:skip

+ 9
- 0
imperative/python/megengine/functional/utils.py View File

@@ -10,8 +10,11 @@ from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.core2 import sync as _sync
from ..core.ops.builtin import AssertEqual
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
from .elemwise import abs, maximum, minimum

__all__ = ["topk_accuracy"]


def _assert_equal(
expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False
@@ -55,3 +58,9 @@ def _assert_equal(
result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0]
_sync() # sync interpreter to get exception
return result


topk_accuracy = deprecated_func(
"1.3", "megengine.functional.metric", "topk_accuracy", True
)
copy = deprecated_func("1.3", "megengine.functional.tensor", "copy", True)

+ 32
- 0
imperative/python/megengine/utils/deprecation.py View File

@@ -5,4 +5,36 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import importlib
import warnings

from deprecated.sphinx import deprecated


def deprecated_func(version, origin, name, tbd):
"""
:param version: version to deprecate this function
:param origin: origin module path
:param name: function name
:param tbd: to be discussed, if true, ignore warnings
"""
should_warning = not tbd

def wrapper(*args, **kwargs):
nonlocal should_warning
module = importlib.import_module(origin)
func = module.__getattribute__(name)
if should_warning:
with warnings.catch_warnings():
warnings.simplefilter(action="always")
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
),
category=DeprecationWarning,
stacklevel=2,
)
should_warning = False
return func(*args, **kwargs)

return wrapper

Loading…
Cancel
Save