|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- import numpy as np
-
- from megengine.functional.tensor import zeros
-
- from ..core.ops.builtin import BatchNorm
- from .expr import CallMethod, Constant
- from .node import TensorNode
- from .serialization import (
- register_functional_loader,
- register_module_loader,
- register_opdef_loader,
- register_tensor_method_loader,
- )
-
-
- """
- # Expr loaders examples
-
- from ..core.ops.builtin import Elemwise
-
- @register_opdef_loader(Elemwise)
- def add_opdef_loader(expr):
- if expr.opdef_state["mode"] == "ADD":
- expr.opdef_state["mode"] == "MUL"
- node = expr.inputs[1]
- astype_expr = CallMethod(node, "astype")
- oup = TensorNode(
- astype_expr,
- shape=node.shape,
- dtype=expr.inputs[0].dtype,
- qparams=node.qparams,
- )
-
- astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
- astype_expr.return_val = (oup,)
- expr.inputs[1] = oup
-
-
- @register_functional_loader(("megengine.functional.nn", "conv2d"))
- def conv2df_loader(expr):
- # expr.func = ("megengine.functional.nn","conv2d")
- kwargs = expr.kwargs
- orig_weight = expr.named_args["weight"]
-
- astype_expr = CallMethod(orig_weight, "astype")
- oup = TensorNode(
- astype_expr,
- shape=orig_weight.shape,
- dtype=orig_weight.dtype,
- qparams=orig_weight.qparams,
- )
-
- astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
- astype_expr.return_val = (oup,)
-
- expr.set_arg("weight", oup)
-
-
- @register_module_loader(("megengine.module.conv", "Conv2d"))
- def conv2dm_loader(expr):
- module = expr.inputs[0].owner
- args = list(expr.args)
- orig_inp = args[1]
- astype_expr = CallMethod(orig_inp, "astype")
- oup = TensorNode(
- astype_expr,
- shape=orig_inp.shape,
- dtype=orig_inp.dtype,
- qparams=orig_inp.qparams,
- )
- astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
- astype_expr.return_val = (oup,)
- args[1] = oup
- expr.set_args_kwargs(*args)
-
-
- @register_tensor_method_loader("__add__")
- def add_loader(expr):
- args = list(expr.args)
- if not isinstance(args[1], TensorNode):
- args[1] = tensor(args[1])
- node = Constant(args[1], "const").outputs[0]
-
- astype_expr = CallMethod(node, "astype")
- oup = TensorNode(
- astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
- )
-
- astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
- astype_expr.return_val = (oup,)
- args[1] = oup
- expr.set_args_kwargs(*args)
- """
-
-
- @register_module_loader(
- ("megengine.module.batchnorm", "BatchNorm1d"),
- ("megengine.module.batchnorm", "BatchNorm2d"),
- ("megengine.module.batchnorm", "SyncBatchNorm"),
- )
- def bn2d_module_loader(expr):
- # mge 1.6
- if not hasattr(expr, "version"):
- module = expr.inputs[0].owner
- if not hasattr(module, "param_dim"):
- module.param_dim = "dim_1c11"
-
-
- @register_module_loader(
- ("megengine.module.conv_bn", "ConvBn2d"),
- ("megengine.module.conv_bn", "ConvBnRelu2d"),
- ("megengine.module.qat.conv_bn", "ConvBn2d"),
- ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
- )
- def convbn2d_module_loader(expr):
- # mge 1.6
- if not hasattr(expr, "version"):
- module = expr.inputs[0].owner
- if not hasattr(module.bn, "param_dim"):
- module.bn.param_dim = "dim_1c11"
- module = expr.inputs[0].owner
- if not hasattr(module.conv, "padding_mode"):
- module.conv.padding_mode = "zeros"
-
-
- @register_opdef_loader(BatchNorm)
- def bn_opdef_loader(expr):
- # mge 1.6
- if not hasattr(expr, "version") and len(expr.outputs) != 6:
- assert len(expr.outputs) == 5
- output = expr.outputs[-1]
- oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
- expr.outputs.insert(4, oup)
-
-
- @register_functional_loader(
- ("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros")
- )
- def tensor_gen_func_loader(expr):
- if hasattr(expr, "version") and expr.version == "1.7.0":
- expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2])
- if not hasattr(expr, "version"):
- # compatiable for version 1.6
- shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"]
-
- if len(expr.args) > 1:
- dtype = expr.args[1]
- elif "dtype" in expr.kwargs:
- dtype = expr.kwargs["dtype"]
- else:
- dtype = "float32"
-
- if len(expr.args) > 2:
- device = expr.args[2]
- elif "device" in expr.kwargs:
- device = expr.kwargs["device"]
- else:
- device = None
- expr.set_args_kwargs(shape, dtype=dtype, device=device)
-
-
- @register_functional_loader(("megengine.functional.nn", "pad"))
- def pad_func_loader(expr):
- if "pad_witdth" in expr.kwargs:
- kwargs = expr.kwargs
- kwargs["pad_width"] = kwargs.pop("pad_witdth")
- expr.set_args_kwargs(*expr.args, **kwargs)
-
-
- @register_module_loader(
- ("megengine.module.conv", "Conv1d"),
- ("megengine.module.conv", "Conv2d"),
- ("megengine.module.conv", "ConvRelu2d"),
- ("megengine.module.qat.conv", "Conv2d"),
- ("megengine.module.qat.conv", "ConvRelu2d"),
- ("megengine.module.quantized.conv", "Conv2d"),
- ("megengine.module.quantized.conv", "ConvRelu2d"),
- )
- def conv2d_module_loader(expr):
- module = expr.inputs[0].owner
- if not hasattr(module, "padding_mode"):
- module.padding_mode = "zeros"
-
-
- @register_module_loader(
- ("megengine.module.quantized.conv_bn", "ConvBn2d"),
- ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
- )
- def quantized_convbn2d_module_loader(expr):
- module = expr.inputs[0].owner
- if not hasattr(module, "padding_mode"):
- module.padding_mode = "zeros"
|