Browse Source

feat(mge/utils): add GraphInference in cgtools

GitOrigin-RevId: 72f2201169
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
8e50a6daa7
2 changed files with 122 additions and 21 deletions
  1. +77
    -21
      imperative/python/megengine/utils/comp_graph_tools.py
  2. +45
    -0
      imperative/python/test/unit/test_cgtools.py

+ 77
- 21
imperative/python/megengine/utils/comp_graph_tools.py View File

@@ -6,7 +6,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Dict, List
from collections import OrderedDict
from typing import Dict, List, Optional

import numpy

@@ -27,6 +28,7 @@ __all__ = [
"replace_oprs",
"set_priority_to_id",
"load_and_inference",
"GraphInference",
]


@@ -46,7 +48,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:

q = list(var)
while q:
v = q.pop()
v = q.pop(0)
if v in memo:
continue
memo.add(v)
@@ -281,23 +283,77 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n
:return: list of inference results.

"""
*_, out_list = G.load_graph(file)
inputs = get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = replace_vars(out_list, replace_dict)
out_node_list = [G.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
cg = new_out_list[0].graph
func = cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data_list):
node.set_value(Tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
graph = GraphInference(file)
result = graph.run(*inp_data_list)
out_data_list = list(result.values())
return out_data_list


class GraphInference:
"""
Loads a serialized computing graph as a GraphInference object which can be used to execute the computing graph.
The `GraphInference.run()` accepts a list `inp_args` or a dict `inp_dict` {input_name: input_value} as input and returns a dict {output_name: output_value}.

:param file: could be file object or filename.
:param outputs: only compile the subgraph with outputs as its endpoints.
"""

def __init__(self, file, outputs: Optional[List[str]] = None):
*_, output_nodes = G.load_graph(file)
if outputs is not None:
output_name = outputs.copy()
all_vars = get_dep_vars(output_nodes) + output_nodes
new_outputs = {}
for i in all_vars:
if i.name in output_name:
new_outputs[i.name] = i
output_name.remove(i.name)
assert (
len(output_name) == 0
), "Can not find varnode {} in this model".format(output_name)
output_nodes = [new_outputs[i] for i in outputs]
inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
self._inp_dict = OrderedDict()
replace_dict = {}
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
self._inp_dict[i.name] = inp_node
replace_dict[i] = inp_node.outputs[0]
new_output_nodes = replace_vars(output_nodes, replace_dict)
for old, new in zip(output_nodes, new_output_nodes):
new.name = old.name
self._out_dict = OrderedDict(
[(i.name, G.OutputNode(i)) for i in new_output_nodes]
)
new_out_list = [i.outputs[0] for i in self._out_dict.values()]
cg = new_out_list[0].graph
self._func = cg.compile(new_out_list)

def run(
self,
*inp_args: numpy.ndarray,
inp_dict: Optional[Dict[str, numpy.ndarray]] = None
):
assert len(inp_args) <= len(
self._inp_dict
), "This model expects {} inputs".format(len(self._inp_dict))
inputs = {}
inp_keys = list(self._inp_dict.keys())
for ind, data in enumerate(inp_args):
inputs[inp_keys[ind]] = data
if inp_dict is not None:
inputs.update(inp_dict)
assert (
inputs.keys() == self._inp_dict.keys()
), "This model expects inputs {}, but gets inputs {}".format(
list(self._inp_dict.keys()), list(inputs.keys())
)
for key in self._inp_dict:
self._inp_dict[key].set_value(Tensor(inputs[key])._dev_tensor())
self._func.execute()
result = OrderedDict()
for key in self._out_dict:
result[key] = self._out_dict[key].get_value().numpy()
return result

+ 45
- 0
imperative/python/test/unit/test_cgtools.py View File

@@ -139,3 +139,48 @@ def test_get_opr_seq():

seq_2 = cgtools.get_oprs_seq(outputs, False)
assert len(seq_2) == 6


def test_graph_function():
class Net(M.Module):
def forward(self, a, b):
return a - b, a * b

net = Net()

@trace(symbolic=True, capture_as_const=True)
def function(a, b, *, net=None):
return net(a, b)

a = np.array([1, 2, 3])
b = np.array([3])
x, y = function(megengine.tensor(a), megengine.tensor(b), net=net)

file = io.BytesIO()
function.dump(
file,
arg_names=["a", "b"],
output_names=["x", "y"],
optimize_for_inference=False,
)
file.seek(0)

graph = cgtools.GraphInference(file)
results = graph.run(inp_dict={"a": a, "b": b})
np.testing.assert_equal(x.numpy(), results["x"])
np.testing.assert_equal(y.numpy(), results["y"])

results = graph.run(a, inp_dict={"b": b})
np.testing.assert_equal(x.numpy(), results["x"])
np.testing.assert_equal(y.numpy(), results["y"])

results = graph.run(a, b)
np.testing.assert_equal(x.numpy(), results["x"])
np.testing.assert_equal(y.numpy(), results["y"])

file.seek(0)

graph1 = cgtools.GraphInference(file, outputs=["x"])
results = graph1.run(inp_dict={"a": a, "b": b})
np.testing.assert_equal(x.numpy(), results["x"])
assert "y" not in results

Loading…
Cancel
Save