Browse Source

fix(imperative/python): add the default warning for args descending

GitOrigin-RevId: cb5f065e6c
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
d23d1352e7
2 changed files with 37 additions and 9 deletions
  1. +3
    -1
      imperative/python/megengine/functional/math.py
  2. +34
    -8
      imperative/python/megengine/utils/deprecation.py

+ 3
- 1
imperative/python/megengine/functional/math.py View File

@@ -22,6 +22,7 @@ from ..core.tensor import amp
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph
from ..jit import exclude_from_trace from ..jit import exclude_from_trace
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.deprecation import deprecated_kwargs_default
from .debug_param import get_execution_strategy from .debug_param import get_execution_strategy
from .elemwise import clip, minimum from .elemwise import clip, minimum
from .tensor import broadcast_to, concat, expand_dims, squeeze from .tensor import broadcast_to, concat, expand_dims, squeeze
@@ -684,6 +685,7 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]:
return tns, ind return tns, ind




@deprecated_kwargs_default("1.12", "descending", 3)
def topk( def topk(
inp: Tensor, inp: Tensor,
k: int, k: int,
@@ -712,7 +714,7 @@ def topk(
import megengine.functional as F import megengine.functional as F


x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32))
top, indices = F.topk(x, 5)
top, indices = F.topk(x, 5, descending=False)
print(top.numpy(), indices.numpy()) print(top.numpy(), indices.numpy())


Outputs: Outputs:


+ 34
- 8
imperative/python/megengine/utils/deprecation.py View File

@@ -7,9 +7,12 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import importlib import importlib
import warnings import warnings
from functools import wraps


from deprecated.sphinx import deprecated from deprecated.sphinx import deprecated


warnings.filterwarnings(action="default", module="megengine")



def deprecated_func(version, origin, name, tbd): def deprecated_func(version, origin, name, tbd):
r""" r"""
@@ -27,16 +30,39 @@ def deprecated_func(version, origin, name, tbd):
module = importlib.import_module(origin) module = importlib.import_module(origin)
func = module.__getattribute__(name) func = module.__getattribute__(name)
if should_warning: if should_warning:
with warnings.catch_warnings():
warnings.simplefilter(action="always")
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
),
category=DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)

return wrapper


def deprecated_kwargs_default(version, kwargs_name, kwargs_pos):
r"""
Args:
version: version to deprecate this default
kwargs_name: kwargs name
kwargs_pos: kwargs position
"""

def deprecated(func):
@wraps(func)
def wrapper(*args, **kwargs):
if len(args) < kwargs_pos and kwargs_name not in kwargs:
warnings.warn( warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(
name, origin, name, version
"the default behavior for {} will be changed in version {}, please use it in keyword parameter way".format(
kwargs_name, version
), ),
category=DeprecationWarning,
category=PendingDeprecationWarning,
stacklevel=2, stacklevel=2,
) )
should_warning = False
return func(*args, **kwargs)
return func(*args, **kwargs)


return wrapper
return wrapper

return deprecated

Loading…
Cancel
Save