Browse Source

fix(traced_module): fix traced module compatible issues

GitOrigin-RevId: 67e68ef5ea
release-1.10
Megvii Engine Team 3 years ago
parent
commit
e6dcfbe806
2 changed files with 30 additions and 12 deletions
  1. +27
    -10
      imperative/python/megengine/traced_module/compat.py
  2. +3
    -2
      imperative/python/megengine/utils/deprecation.py

+ 27
- 10
imperative/python/megengine/traced_module/compat.py View File

@@ -99,11 +99,10 @@ def add_loader(expr):
("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
module = expr.inputs[0].owner
if not hasattr(module, "param_dim"):
module.param_dim = "dim_1c11"
module = expr.inputs[0].owner
if hasattr(module, "param_dim"):
assert module.param_dim == "dim_1c11"
delattr(module, "param_dim")


@register_module_loader(
@@ -113,12 +112,10 @@ def bn2d_module_loader(expr):
("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
# mge 1.6
if not hasattr(expr, "version"):
module = expr.inputs[0].owner
if not hasattr(module.bn, "param_dim"):
module.bn.param_dim = "dim_1c11"
module = expr.inputs[0].owner
if hasattr(module.bn, "param_dim"):
assert module.bn.param_dim == "dim_1c11"
delattr(module.bn, "param_dim")
if not hasattr(module.conv, "padding_mode"):
module.conv.padding_mode = "zeros"

@@ -167,6 +164,26 @@ def pad_func_loader(expr):
expr.set_args_kwargs(*expr.args, **kwargs)


@register_functional_loader(("megengine.functional.nn", "batch_norm"))
def bn_func_loader(expr):
kwargs = expr.kwargs
if "compute_mode" in kwargs:
assert kwargs["compute_mode"] == "default"
kwargs.pop("compute_mode")
if "param_dim" in kwargs:
assert kwargs["param_dim"] == "dim_1c11"
kwargs.pop("param_dim")
expr.set_args_kwargs(*expr.args, **kwargs)


@register_functional_loader(("megengine.functional.math", "matmul"))
def matmul_func_loader(expr):
args = expr.args
if len(args) == 6:
assert args[5] == "default"
expr.set_args_kwargs(*args[0:5])


@register_module_loader(
("megengine.module.conv", "Conv1d"),
("megengine.module.conv", "Conv2d"),


+ 3
- 2
imperative/python/megengine/utils/deprecation.py View File

@@ -17,11 +17,12 @@ def deprecated_func(version, origin, name, tbd):
tbd: to be discussed, if true, ignore warnings
"""
should_warning = not tbd
module = importlib.import_module(origin)
func = module.__getattribute__(name)

@wraps(func)
def wrapper(*args, **kwargs):
nonlocal should_warning
module = importlib.import_module(origin)
func = module.__getattribute__(name)
if should_warning:
warnings.warn(
"Call to deprecated function {}. (use {}.{} instead) -- Deprecated since version {}.".format(


Loading…
Cancel
Save