GitOrigin-RevId: cb5f065e6c
tags/v1.8.0
@@ -22,6 +22,7 @@ from ..core.tensor import amp | |||||
from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph | from ..core.tensor.utils import _normalize_axis, cast_tensors, setscalar, subgraph | ||||
from ..jit import exclude_from_trace | from ..jit import exclude_from_trace | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from ..utils.deprecation import deprecated_kwargs_default | |||||
from .debug_param import get_execution_strategy | from .debug_param import get_execution_strategy | ||||
from .elemwise import clip, minimum | from .elemwise import clip, minimum | ||||
from .tensor import broadcast_to, concat, expand_dims, squeeze | from .tensor import broadcast_to, concat, expand_dims, squeeze | ||||
@@ -684,6 +685,7 @@ def sort(inp: Tensor, descending: bool = False) -> Tuple[Tensor, Tensor]: | |||||
return tns, ind | return tns, ind | ||||
@deprecated_kwargs_default("1.12", "descending", 3) | |||||
def topk( | def topk( | ||||
inp: Tensor, | inp: Tensor, | ||||
k: int, | k: int, | ||||
@@ -712,7 +714,7 @@ def topk( | |||||
import megengine.functional as F | import megengine.functional as F | ||||
x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | x = tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | ||||
top, indices = F.topk(x, 5) | |||||
top, indices = F.topk(x, 5, descending=False) | |||||
print(top.numpy(), indices.numpy()) | print(top.numpy(), indices.numpy()) | ||||
Outputs: | Outputs: | ||||
@@ -7,9 +7,12 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import importlib | import importlib | ||||
import warnings | import warnings | ||||
from functools import wraps | |||||
from deprecated.sphinx import deprecated | from deprecated.sphinx import deprecated | ||||
warnings.filterwarnings(action="default", module="megengine") | |||||
def deprecated_func(version, origin, name, tbd): | def deprecated_func(version, origin, name, tbd): | ||||
r""" | r""" | ||||
@@ -27,16 +30,39 @@ def deprecated_func(version, origin, name, tbd): | |||||
module = importlib.import_module(origin) | module = importlib.import_module(origin) | ||||
func = module.__getattribute__(name) | func = module.__getattribute__(name) | ||||
if should_warning: | if should_warning: | ||||
with warnings.catch_warnings(): | |||||
warnings.simplefilter(action="always") | |||||
warnings.warn( | |||||
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format( | |||||
name, origin, name, version | |||||
), | |||||
category=DeprecationWarning, | |||||
stacklevel=2, | |||||
) | |||||
return func(*args, **kwargs) | |||||
return wrapper | |||||
def deprecated_kwargs_default(version, kwargs_name, kwargs_pos): | |||||
r""" | |||||
Args: | |||||
version: version to deprecate this default | |||||
kwargs_name: kwargs name | |||||
kwargs_pos: kwargs position | |||||
""" | |||||
def deprecated(func): | |||||
@wraps(func) | |||||
def wrapper(*args, **kwargs): | |||||
if len(args) < kwargs_pos and kwargs_name not in kwargs: | |||||
warnings.warn( | warnings.warn( | ||||
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format( | |||||
name, origin, name, version | |||||
"the default behavior for {} will be changed in version {}, please use it in keyword parameter way".format( | |||||
kwargs_name, version | |||||
), | ), | ||||
category=DeprecationWarning, | |||||
category=PendingDeprecationWarning, | |||||
stacklevel=2, | stacklevel=2, | ||||
) | ) | ||||
should_warning = False | |||||
return func(*args, **kwargs) | |||||
return func(*args, **kwargs) | |||||
return wrapper | |||||
return wrapper | |||||
return deprecated |