|
|
@@ -18,6 +18,7 @@ import numpy as np |
|
|
|
from ..core._imperative_rt import ComputingGraph |
|
|
|
from ..core._imperative_rt.core2 import SymbolVar |
|
|
|
from ..core.tensor import megbrain_graph as G |
|
|
|
from ..logger import get_logger |
|
|
|
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq |
|
|
|
from .network_node import ( |
|
|
|
Host2DeviceCopy, |
|
|
@@ -28,6 +29,8 @@ from .network_node import ( |
|
|
|
str_to_mge_class, |
|
|
|
) |
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class Network: |
|
|
|
def __init__(self): |
|
|
@@ -164,6 +167,15 @@ class Network: |
|
|
|
self._compile() |
|
|
|
out = [G.VarNode(var.var) for var in self.output_vars] |
|
|
|
|
|
|
|
if kwargs.pop("arg_names", False): |
|
|
|
logger.warning( |
|
|
|
'"arg_names" is not supported in Network.dump, rename input vars directly' |
|
|
|
) |
|
|
|
if kwargs.pop("output_names", False): |
|
|
|
logger.warning( |
|
|
|
'"output_names" is not supported in Network.dump, rename output vars directly' |
|
|
|
) |
|
|
|
|
|
|
|
if optimize_for_inference: |
|
|
|
out = G.optimize_for_inference(out, **kwargs) |
|
|
|
|
|
|
|