|
|
@@ -99,11 +99,10 @@ def add_loader(expr): |
|
|
|
("megengine.module.batchnorm", "SyncBatchNorm"), |
|
|
|
) |
|
|
|
def bn2d_module_loader(expr): |
|
|
|
# mge 1.6 |
|
|
|
if not hasattr(expr, "version"): |
|
|
|
module = expr.inputs[0].owner |
|
|
|
if not hasattr(module, "param_dim"): |
|
|
|
module.param_dim = "dim_1c11" |
|
|
|
module = expr.inputs[0].owner |
|
|
|
if hasattr(module, "param_dim"): |
|
|
|
assert module.param_dim == "dim_1c11" |
|
|
|
delattr(module, "param_dim") |
|
|
|
|
|
|
|
|
|
|
|
@register_module_loader( |
|
|
@@ -113,12 +112,10 @@ def bn2d_module_loader(expr): |
|
|
|
("megengine.module.qat.conv_bn", "ConvBnRelu2d"), |
|
|
|
) |
|
|
|
def convbn2d_module_loader(expr): |
|
|
|
# mge 1.6 |
|
|
|
if not hasattr(expr, "version"): |
|
|
|
module = expr.inputs[0].owner |
|
|
|
if not hasattr(module.bn, "param_dim"): |
|
|
|
module.bn.param_dim = "dim_1c11" |
|
|
|
module = expr.inputs[0].owner |
|
|
|
if hasattr(module.bn, "param_dim"): |
|
|
|
assert module.bn.param_dim == "dim_1c11" |
|
|
|
delattr(module.bn, "param_dim") |
|
|
|
if not hasattr(module.conv, "padding_mode"): |
|
|
|
module.conv.padding_mode = "zeros" |
|
|
|
|
|
|
@@ -167,6 +164,26 @@ def pad_func_loader(expr): |
|
|
|
expr.set_args_kwargs(*expr.args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
@register_functional_loader(("megengine.functional.nn", "batch_norm")) |
|
|
|
def bn_func_loader(expr): |
|
|
|
kwargs = expr.kwargs |
|
|
|
if "compute_mode" in kwargs: |
|
|
|
assert kwargs["compute_mode"] == "default" |
|
|
|
kwargs.pop("compute_mode") |
|
|
|
if "param_dim" in kwargs: |
|
|
|
assert kwargs["param_dim"] == "dim_1c11" |
|
|
|
kwargs.pop("param_dim") |
|
|
|
expr.set_args_kwargs(*expr.args, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
@register_functional_loader(("megengine.functional.math", "matmul")) |
|
|
|
def matmul_func_loader(expr): |
|
|
|
args = expr.args |
|
|
|
if len(args) == 6: |
|
|
|
assert args[5] == "default" |
|
|
|
expr.set_args_kwargs(*args[0:5]) |
|
|
|
|
|
|
|
|
|
|
|
@register_module_loader( |
|
|
|
("megengine.module.conv", "Conv1d"), |
|
|
|
("megengine.module.conv", "Conv2d"), |
|
|
|