Browse Source

fix(traced_module): fix ones/zeros functional compatiable

GitOrigin-RevId: 7ec2c4d3f5
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
7ecdbf251a
1 changed files with 28 additions and 1 deletions
  1. +28
    -1
      imperative/python/megengine/traced_module/compat.py

+ 28
- 1
imperative/python/megengine/traced_module/compat.py View File

@@ -8,7 +8,8 @@


import numpy as np import numpy as np


from .. import tensor
from megengine.functional.tensor import zeros

from ..core.ops.builtin import BatchNorm from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant from .expr import CallMethod, Constant
from .node import TensorNode from .node import TensorNode
@@ -135,3 +136,29 @@ def bn_opdef_loader(expr):
output = expr.outputs[-1] output = expr.outputs[-1]
oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,) oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
expr.outputs.insert(4, oup) expr.outputs.insert(4, oup)


@register_functional_loader(
("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros")
)
def tensor_gen_func_loader(expr):
if hasattr(expr, "version") and expr.version == "1.7.0":
expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2])
if not hasattr(expr, "version"):
# compatiable for version 1.6
shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"]

if len(expr.args) > 1:
dtype = expr.args[1]
elif "dtype" in expr.kwargs:
dtype = expr.kwargs["dtype"]
else:
dtype = "float32"

if len(expr.args) > 2:
device = expr.args[2]
elif "device" in expr.kwargs:
device = expr.kwargs["device"]
else:
device = None
expr.set_args_kwargs(shape, dtype=dtype, device=device)

Loading…
Cancel
Save