From e6dcfbe8060ddbbba727bd3261bb329cadd34942 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 10 Jun 2022 16:57:39 +0800 Subject: [PATCH] fix(traced_module): fix traced module compatible issues GitOrigin-RevId: 67e68ef5eae78d93a167d8d32ac78837932f3b45 --- .../python/megengine/traced_module/compat.py | 37 ++++++++++++++++------ imperative/python/megengine/utils/deprecation.py | 5 +-- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/imperative/python/megengine/traced_module/compat.py b/imperative/python/megengine/traced_module/compat.py index 1cbf87d8..3941b470 100644 --- a/imperative/python/megengine/traced_module/compat.py +++ b/imperative/python/megengine/traced_module/compat.py @@ -99,11 +99,10 @@ def add_loader(expr): ("megengine.module.batchnorm", "SyncBatchNorm"), ) def bn2d_module_loader(expr): - # mge 1.6 - if not hasattr(expr, "version"): - module = expr.inputs[0].owner - if not hasattr(module, "param_dim"): - module.param_dim = "dim_1c11" + module = expr.inputs[0].owner + if hasattr(module, "param_dim"): + assert module.param_dim == "dim_1c11" + delattr(module, "param_dim") @register_module_loader( @@ -113,12 +112,10 @@ def bn2d_module_loader(expr): ("megengine.module.qat.conv_bn", "ConvBnRelu2d"), ) def convbn2d_module_loader(expr): - # mge 1.6 - if not hasattr(expr, "version"): - module = expr.inputs[0].owner - if not hasattr(module.bn, "param_dim"): - module.bn.param_dim = "dim_1c11" module = expr.inputs[0].owner + if hasattr(module.bn, "param_dim"): + assert module.bn.param_dim == "dim_1c11" + delattr(module.bn, "param_dim") if not hasattr(module.conv, "padding_mode"): module.conv.padding_mode = "zeros" @@ -167,6 +164,26 @@ def pad_func_loader(expr): expr.set_args_kwargs(*expr.args, **kwargs) +@register_functional_loader(("megengine.functional.nn", "batch_norm")) +def bn_func_loader(expr): + kwargs = expr.kwargs + if "compute_mode" in kwargs: + assert kwargs["compute_mode"] == "default" + kwargs.pop("compute_mode") + if "param_dim" in kwargs: + assert kwargs["param_dim"] == "dim_1c11" + kwargs.pop("param_dim") + expr.set_args_kwargs(*expr.args, **kwargs) + + +@register_functional_loader(("megengine.functional.math", "matmul")) +def matmul_func_loader(expr): + args = expr.args + if len(args) == 6: + assert args[5] == "default" + expr.set_args_kwargs(*args[0:5]) + + @register_module_loader( ("megengine.module.conv", "Conv1d"), ("megengine.module.conv", "Conv2d"), diff --git a/imperative/python/megengine/utils/deprecation.py b/imperative/python/megengine/utils/deprecation.py index ea58d71a..f9ed8bdf 100644 --- a/imperative/python/megengine/utils/deprecation.py +++ b/imperative/python/megengine/utils/deprecation.py @@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd): tbd: to be discussed, if true, ignore warnings """ should_warning = not tbd + module = importlib.import_module(origin) + func = module.__getattribute__(name) + @wraps(func) def wrapper(*args, **kwargs): nonlocal should_warning - module = importlib.import_module(origin) - func = module.__getattribute__(name) if should_warning: warnings.warn( "Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(