|
|
@@ -495,7 +495,7 @@ class InternalGraph: |
|
|
|
|
|
|
|
inp = F.zeros(shape = (3, 4)) |
|
|
|
traced_module = tm.trace_module(net, inp) |
|
|
|
|
|
|
|
|
|
|
|
Will produce the following ``InternalGraph``:: |
|
|
|
|
|
|
|
print(traced_module.graph) |
|
|
@@ -728,7 +728,7 @@ class InternalGraph: |
|
|
|
# graph : InternalGraph |
|
|
|
print("{:p}".format(graph)) |
|
|
|
print(graph.__format__("p")) |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
name: a string in glob syntax that can contain ``?`` and |
|
|
|
``*`` to match a single or arbitrary characters. |
|
|
@@ -749,7 +749,7 @@ class InternalGraph: |
|
|
|
|
|
|
|
def get_dep_exprs(self, nodes: Sequence[Node]) -> List[Expr]: |
|
|
|
r"""Get the dependent Exprs of the ``nodes``. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
nodes: a list of :class:`Node`. |
|
|
|
Returns: |
|
|
@@ -817,7 +817,7 @@ class InternalGraph: |
|
|
|
shape: the shape of the new input Node. |
|
|
|
dtype: the dtype of the new input Node. |
|
|
|
Default: float32 |
|
|
|
name: the name of the new input Node. When the name is used in the graph, |
|
|
|
name: the name of the new input Node. When the name is used in the graph, |
|
|
|
a suffix will be added to it. |
|
|
|
""" |
|
|
|
forma_mnode = self.inputs[0] |
|
|
@@ -850,16 +850,16 @@ class InternalGraph: |
|
|
|
|
|
|
|
def reset_outputs(self, outputs): |
|
|
|
r"""Reset the output Nodes of the graph. |
|
|
|
|
|
|
|
|
|
|
|
.. note:: |
|
|
|
|
|
|
|
This method only supports resetting the output of graphs |
|
|
|
that do not have a parent graph. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
outputs: an object which inner element is Node. Support tuple, list |
|
|
|
dict, etc. |
|
|
|
|
|
|
|
|
|
|
|
For example, the following code |
|
|
|
|
|
|
|
.. code-block:: |
|
|
@@ -882,7 +882,7 @@ class InternalGraph: |
|
|
|
out_node = graph.outputs[0] |
|
|
|
graph.reset_outputs((out_node, {"input": inp_node})) |
|
|
|
out = traced_module(inp) |
|
|
|
|
|
|
|
|
|
|
|
Will produce the following ``InternalGraph`` and ``out``:: |
|
|
|
|
|
|
|
print(graph) |
|
|
@@ -910,9 +910,9 @@ class InternalGraph: |
|
|
|
|
|
|
|
def add_output_node(self, node: TensorNode): |
|
|
|
r"""Add an output node to the Graph. |
|
|
|
|
|
|
|
|
|
|
|
The Graph output will become a ``tuple`` after calling ``add_output_node``. |
|
|
|
The first element of the ``tuple`` is the original output, and the second |
|
|
|
The first element of the ``tuple`` is the original output, and the second |
|
|
|
is the ``node``. |
|
|
|
|
|
|
|
For example, the following code |
|
|
@@ -938,7 +938,7 @@ class InternalGraph: |
|
|
|
graph.add_output_node(inp_node) |
|
|
|
graph.add_output_node(out_node) |
|
|
|
out = traced_module(inp) |
|
|
|
|
|
|
|
|
|
|
|
Will produce the following ``InternalGraph`` and ``out``:: |
|
|
|
|
|
|
|
print(graph) |
|
|
@@ -977,11 +977,11 @@ class InternalGraph: |
|
|
|
... # inert exprs into graph and resotre normal mode |
|
|
|
|
|
|
|
Args: |
|
|
|
expr: the ``expr`` after which to insert. If None, the insertion position will be |
|
|
|
expr: the ``expr`` after which to insert. If None, the insertion position will be |
|
|
|
automatically set based on the input node. |
|
|
|
|
|
|
|
Returns: |
|
|
|
A resource manager that will initialize trace mode on ``__enter__`` and |
|
|
|
A resource manager that will initialize trace mode on ``__enter__`` and |
|
|
|
restore normal mode on ``__exit__``. |
|
|
|
""" |
|
|
|
if expr is not None: |
|
|
@@ -990,7 +990,7 @@ class InternalGraph: |
|
|
|
|
|
|
|
def replace_node(self, repl_dict: Dict[Node, Node]): |
|
|
|
r"""Replace the Nodes in the graph. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
repl_dict: the map {old_Node: new_Node} that specifies how to replace the Nodes. |
|
|
|
""" |
|
|
@@ -1752,7 +1752,7 @@ class BaseFilter: |
|
|
|
|
|
|
|
class ExprFilter(BaseFilter): |
|
|
|
"""Filter on Expr iterator. |
|
|
|
This class is an iterator of :class:`.Expr` objects and multiple |
|
|
|
This class is an iterator of :class:`.Expr` objects and multiple |
|
|
|
filtering conditions and mappers can be chained. |
|
|
|
""" |
|
|
|
|
|
|
@@ -1777,7 +1777,7 @@ class ExprFilter(BaseFilter): |
|
|
|
|
|
|
|
class NodeFilter(BaseFilter): |
|
|
|
"""Filter on Node iterator. |
|
|
|
This class is an iterator of :class:`.Node` objects and multiple |
|
|
|
This class is an iterator of :class:`~.traced_module.Node` objects and multiple |
|
|
|
filtering conditions and mappers can be chained. |
|
|
|
""" |
|
|
|
|
|
|
@@ -1905,14 +1905,14 @@ class ExprFilterExprId(ExprFilter): |
|
|
|
|
|
|
|
|
|
|
|
class TracedModule(Module): |
|
|
|
r"""``TracedModule`` is the Module created by tracing normal module. |
|
|
|
|
|
|
|
It owns an argdef to graph(InternalGraph) map. The forward method of ``TracedModule`` |
|
|
|
will get a graph from ``argdef_graph_map`` according to the argdef of input ``args/kwargs`` |
|
|
|
r"""``TracedModule`` is the Module created by tracing normal module. |
|
|
|
|
|
|
|
It owns an argdef to graph(InternalGraph) map. The forward method of ``TracedModule`` |
|
|
|
will get a graph from ``argdef_graph_map`` according to the argdef of input ``args/kwargs`` |
|
|
|
and interpret it. |
|
|
|
|
|
|
|
|
|
|
|
.. note:: |
|
|
|
``TracedModule`` can only be created by :func:`~.trace_module`. See :func:`~.trace_module` |
|
|
|
``TracedModule`` can only be created by :func:`~.trace_module`. See :func:`~.trace_module` |
|
|
|
for more details. |
|
|
|
""" |
|
|
|
# m_node = None # type: ModuleNode |
|
|
@@ -1956,10 +1956,10 @@ class TracedModule(Module): |
|
|
|
r"""Initialize the :attr:`~.TracedModule.watch_points`. |
|
|
|
|
|
|
|
You can call this function to get the ``Tensor/Module`` corresponding to a ``Node`` at runtime. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
nodes: a list of ``Node``. |
|
|
|
|
|
|
|
|
|
|
|
For example, the following code |
|
|
|
|
|
|
|
.. code-block:: |
|
|
@@ -1981,7 +1981,7 @@ class TracedModule(Module): |
|
|
|
traced_module.set_watch_points(add_1_node) |
|
|
|
|
|
|
|
out = traced_module(inp) |
|
|
|
|
|
|
|
|
|
|
|
Will get the following ``watch_node_value``:: |
|
|
|
|
|
|
|
print(traced_module.watch_node_value) |
|
|
@@ -2010,7 +2010,7 @@ class TracedModule(Module): |
|
|
|
r"""Initialize the :attr:`~.TracedModule.end_points`. |
|
|
|
|
|
|
|
When all the ``nodes`` are generated, the Module will stop execution and return directly. |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
nodes: a list of ``Node``. |
|
|
|
|
|
|
@@ -2035,7 +2035,7 @@ class TracedModule(Module): |
|
|
|
traced_module.set_end_points(add_1_node) |
|
|
|
|
|
|
|
out = traced_module(inp) |
|
|
|
|
|
|
|
|
|
|
|
Will get the following ``out``:: |
|
|
|
|
|
|
|
print(out) |
|
|
@@ -2354,7 +2354,7 @@ def wrap(func: Callable): |
|
|
|
@tm.wrap |
|
|
|
def my_func(x, y): |
|
|
|
return x + y |
|
|
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
func: the function of the global function to insert into the graph when it's called. |
|
|
|
""" |
|
|
|