Browse Source

fix(imperative/python): add the default warning for args descending

GitOrigin-RevId: cb5f065e6c
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
d23d1352e7
2 changed files with 37 additions and 9 deletions
  1. +3
    -1
      imperative/python/megengine/functional/math.py
  2. +34
    -8
      imperative/python/megengine/utils/deprecation.py

+ 3
- 1
imperative/python/megengine/functional/math.py View File

@@ -22,6 +22,7 @@ from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
from ..jit import exclude_from_trace
from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .debug_param import get_execution_strategy
from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze
@@ -684,6 +685,7 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
return tns, ind


@deprecated_kwargs_default("1.12", "descending", 3)
def topk(
inp: Tensor,
k: int,
@@ -712,7 +714,7 @@ def topk(
import megengine.functional as F

x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
top, indices = F.topk(x, 5)
top, indices = F.topk(x, 5, descending=False)
print(top.numpy(), indices.numpy())

Outputs:


+ 34
- 8
imperative/python/megengine/utils/deprecation.py View File

@@ -7,9 +7,12 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import importlib
import warnings
from functools import wraps

from deprecated.sphinx import deprecated

warnings.filterwarnings(action="default", module="megengine")


def deprecated_func(version, origin, name, tbd):
r"""
@@ -27,16 +30,39 @@ def deprecated_func(version, origin, name, tbd):
module = importlib.import_module(origin)
func = module.__getattribute__(name)
if should_warning:
with warnings.catch_warnings():
warnings.simplefilter(action="always")
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
),
category=DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)

return wrapper


def deprecated_kwargs_default(version, kwargs_name, kwargs_pos):
r"""
Args:
version: version to deprecate this default
kwargs_name: kwargs name
kwargs_pos: kwargs position
"""

def deprecated(func):
@wraps(func)
def wrapper(*args, **kwargs):
if len(args) < kwargs_pos and kwargs_name not in kwargs:
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
"the default behavior for {} will be changed in version {}, please use it in keyword parameter way".format(
kwargs_name, version
),
category=DeprecationWarning,
category=PendingDeprecationWarning,
stacklevel=2,
)
should_warning = False
return func(*args, **kwargs)
return func(*args, **kwargs)

return wrapper
return wrapper

return deprecated

Loading…
Cancel
Save