Browse Source

fix(mge/utils): filter out parameter "arg_names" and "output_name" in network.dump

GitOrigin-RevId: 408f52ad2d
release-1.3
Megvii Engine Team 4 years ago
parent
commit
c92317edc0
1 changed files with 12 additions and 0 deletions
  1. +12
    -0
      imperative/python/megengine/utils/network.py

+ 12
- 0
imperative/python/megengine/utils/network.py View File

@@ -18,6 +18,7 @@ import numpy as np
from ..core._imperative_rt import ComputingGraph
from ..core._imperative_rt.core2 import SymbolVar
from ..core.tensor import megbrain_graph as G
from ..logger import get_logger
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq
from .network_node import (
Host2DeviceCopy,
@@ -28,6 +29,8 @@ from .network_node import (
str_to_mge_class,
)

logger = get_logger(__name__)


class Network:
def __init__(self):
@@ -164,6 +167,15 @@ class Network:
self._compile()
out = [G.VarNode(var.var) for var in self.output_vars]

if kwargs.pop("arg_names", False):
logger.warning(
'"arg_names" is not supported in Network.dump, rename input vars directly'
)
if kwargs.pop("output_names", False):
logger.warning(
'"output_names" is not supported in Network.dump, rename output vars directly'
)

if optimize_for_inference:
out = G.optimize_for_inference(out, **kwargs)



Loading…
Cancel
Save