Browse Source

fix(traced_module): fix ones/zeros functional compatiable

GitOrigin-RevId: 7ec2c4d3f5
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
7ecdbf251a
1 changed files with 28 additions and 1 deletions
  1. +28
    -1
      imperative/python/megengine/traced_module/compat.py

+ 28
- 1
imperative/python/megengine/traced_module/compat.py View File

@@ -8,7 +8,8 @@

import numpy as np

from .. import tensor
from megengine.functional.tensor import zeros

from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
@@ -135,3 +136,29 @@ def bn_opdef_loader(expr):
output = expr.outputs[-1]
oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
expr.outputs.insert(4, oup)


@register_functional_loader(
("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros")
)
def tensor_gen_func_loader(expr):
if hasattr(expr, "version") and expr.version == "1.7.0":
expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2])
if not hasattr(expr, "version"):
# compatiable for version 1.6
shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"]

if len(expr.args) > 1:
dtype = expr.args[1]
elif "dtype" in expr.kwargs:
dtype = expr.kwargs["dtype"]
else:
dtype = "float32"

if len(expr.args) > 2:
device = expr.args[2]
elif "device" in expr.kwargs:
device = expr.kwargs["device"]
else:
device = None
expr.set_args_kwargs(shape, dtype=dtype, device=device)

Loading…
Cancel
Save