Browse Source

feat(mge/utils): add GraphInference in cgtools

GitOrigin-RevId: 72f2201169
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
8e50a6daa7
2 changed files with 122 additions and 21 deletions
  1. +77
    -21
      imperative/python/megengine/utils/comp_graph_tools.py
  2. +45
    -0
      imperative/python/test/unit/test_cgtools.py

+ 77
- 21
imperative/python/megengine/utils/comp_graph_tools.py View File

@@ -6,7 +6,8 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections import collections
from typing import Dict, List
from collections import OrderedDict
from typing import Dict, List, Optional


import numpy import numpy


@@ -27,6 +28,7 @@ __all__ = [
"replace_oprs", "replace_oprs",
"set_priority_to_id", "set_priority_to_id",
"load_and_inference", "load_and_inference",
"GraphInference",
] ]




@@ -46,7 +48,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:


q = list(var) q = list(var)
while q: while q:
v = q.pop()
v = q.pop(0)
if v in memo: if v in memo:
continue continue
memo.add(v) memo.add(v)
@@ -281,23 +283,77 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n
:return: list of inference results. :return: list of inference results.


""" """
*_, out_list = G.load_graph(file)
inputs = get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = replace_vars(out_list, replace_dict)
out_node_list = [G.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
cg = new_out_list[0].graph
func = cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data_list):
node.set_value(Tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
graph = GraphInference(file)
result = graph.run(*inp_data_list)
out_data_list = list(result.values())
return out_data_list return out_data_list


class GraphInference:
    """
    Loads a serialized computing graph as a GraphInference object which can be
    used to execute the computing graph.

    ``GraphInference.run()`` accepts positional ``inp_args`` and/or a dict
    ``inp_dict`` of {input_name: input_value} as input and returns an
    OrderedDict of {output_name: output_value}.

    :param file: could be file object or filename.
    :param outputs: only compile the subgraph with outputs as its endpoints.
    """

    def __init__(self, file, outputs: Optional[List[str]] = None):
        *_, output_nodes = G.load_graph(file)
        if outputs is not None:
            output_name = outputs.copy()
            # Search every var in the graph (dependencies + endpoints) for the
            # requested output names, so a subgraph can end at any var node.
            all_vars = get_dep_vars(output_nodes) + output_nodes
            new_outputs = {}
            for i in all_vars:
                if i.name in output_name:
                    new_outputs[i.name] = i
                    output_name.remove(i.name)
            assert (
                len(output_name) == 0
            ), "Can not find varnode {} in this model".format(output_name)
            # Preserve the caller-specified output order.
            output_nodes = [new_outputs[i] for i in outputs]
        inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
        self._inp_dict = OrderedDict()
        replace_dict = {}
        for i in inputs:
            # BUG FIX: use each input's own dtype/graph instead of always
            # inputs[0]'s — previously inputs with differing dtypes were
            # silently created with the first input's dtype.
            inp_node = G.InputNode(device="xpux", dtype=i.dtype, graph=i.graph)
            self._inp_dict[i.name] = inp_node
            replace_dict[i] = inp_node.outputs[0]
        new_output_nodes = replace_vars(output_nodes, replace_dict)
        # Keep the original var names so run() can key results by output name.
        for old, new in zip(output_nodes, new_output_nodes):
            new.name = old.name
        self._out_dict = OrderedDict(
            [(i.name, G.OutputNode(i)) for i in new_output_nodes]
        )
        new_out_list = [i.outputs[0] for i in self._out_dict.values()]
        cg = new_out_list[0].graph
        self._func = cg.compile(new_out_list)

    def run(
        self,
        *inp_args: numpy.ndarray,
        inp_dict: Optional[Dict[str, numpy.ndarray]] = None
    ):
        """
        Executes the compiled graph.

        :param inp_args: positional inputs, matched to the model's inputs in
            declaration order.
        :param inp_dict: named inputs {input_name: value}; merged over (and
            taking precedence for duplicate names) the positional ones.
        :return: OrderedDict mapping each output name to its numpy value.
        """
        assert len(inp_args) <= len(
            self._inp_dict
        ), "This model expects {} inputs".format(len(self._inp_dict))
        inputs = {}
        inp_keys = list(self._inp_dict.keys())
        # Positional args bind to input names in order; named args may
        # fill in (or override) the rest.
        for ind, data in enumerate(inp_args):
            inputs[inp_keys[ind]] = data
        if inp_dict is not None:
            inputs.update(inp_dict)
        assert (
            inputs.keys() == self._inp_dict.keys()
        ), "This model expects inputs {}, but gets inputs {}".format(
            list(self._inp_dict.keys()), list(inputs.keys())
        )
        for key in self._inp_dict:
            self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
        self._func.execute()
        result = OrderedDict()
        for key in self._out_dict:
            result[key] = self._out_dict[key].get_value().numpy()
        return result

+ 45
- 0
imperative/python/test/unit/test_cgtools.py View File

@@ -139,3 +139,48 @@ def test_get_opr_seq():


seq_2 = cgtools.get_oprs_seq(outputs, False) seq_2 = cgtools.get_oprs_seq(outputs, False)
assert len(seq_2) == 6 assert len(seq_2) == 6


def test_graph_function():
    """Check GraphInference over a dumped trace: positional, mixed, and
    dict-only calling conventions must agree, and a subgraph restricted to
    one output must not expose the other."""

    class Net(M.Module):
        def forward(self, a, b):
            return a - b, a * b

    model = Net()

    @trace(symbolic=True, capture_as_const=True)
    def function(a, b, *, net=None):
        return net(a, b)

    arr_a = np.array([1, 2, 3])
    arr_b = np.array([3])
    x, y = function(megengine.tensor(arr_a), megengine.tensor(arr_b), net=model)

    dump_file = io.BytesIO()
    function.dump(
        dump_file,
        arg_names=["a", "b"],
        output_names=["x", "y"],
        optimize_for_inference=False,
    )
    dump_file.seek(0)

    graph = cgtools.GraphInference(dump_file)
    # All three supported input conventions must produce identical results.
    for outputs in (
        graph.run(inp_dict={"a": arr_a, "b": arr_b}),
        graph.run(arr_a, inp_dict={"b": arr_b}),
        graph.run(arr_a, arr_b),
    ):
        np.testing.assert_equal(x.numpy(), outputs["x"])
        np.testing.assert_equal(y.numpy(), outputs["y"])

    dump_file.seek(0)

    sub_graph = cgtools.GraphInference(dump_file, outputs=["x"])
    outputs = sub_graph.run(inp_dict={"a": arr_a, "b": arr_b})
    np.testing.assert_equal(x.numpy(), outputs["x"])
    assert "y" not in outputs

Loading…
Cancel
Save