|
|
@@ -8,7 +8,8 @@ |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from .. import tensor |
|
|
|
from megengine.functional.tensor import zeros |
|
|
|
|
|
|
|
from ..core.ops.builtin import BatchNorm |
|
|
|
from .expr import CallMethod, Constant |
|
|
|
from .node import TensorNode |
|
|
@@ -135,3 +136,29 @@ def bn_opdef_loader(expr): |
|
|
|
output = expr.outputs[-1] |
|
|
|
oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,) |
|
|
|
expr.outputs.insert(4, oup) |
|
|
|
|
|
|
|
|
|
|
|
@register_functional_loader( |
|
|
|
("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros") |
|
|
|
) |
|
|
|
def tensor_gen_func_loader(expr): |
|
|
|
if hasattr(expr, "version") and expr.version == "1.7.0": |
|
|
|
expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2]) |
|
|
|
if not hasattr(expr, "version"): |
|
|
|
# compatiable for version 1.6 |
|
|
|
shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"] |
|
|
|
|
|
|
|
if len(expr.args) > 1: |
|
|
|
dtype = expr.args[1] |
|
|
|
elif "dtype" in expr.kwargs: |
|
|
|
dtype = expr.kwargs["dtype"] |
|
|
|
else: |
|
|
|
dtype = "float32" |
|
|
|
|
|
|
|
if len(expr.args) > 2: |
|
|
|
device = expr.args[2] |
|
|
|
elif "device" in expr.kwargs: |
|
|
|
device = expr.kwargs["device"] |
|
|
|
else: |
|
|
|
device = None |
|
|
|
expr.set_args_kwargs(shape, dtype=dtype, device=device) |