@@ -6,7 +6,8 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Dict, List
from collections import OrderedDict
from typing import Dict, List, Optional
import numpy
@@ -27,6 +28,7 @@ __all__ = [
"replace_oprs",
"set_priority_to_id",
"load_and_inference",
"GraphInference",
]
@@ -46,7 +48,7 @@ def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
q = list(var)
while q:
v = q.pop()
v = q.pop(0 )
if v in memo:
continue
memo.add(v)
@@ -281,23 +283,77 @@ def load_and_inference(file, inp_data_list: List[numpy.ndarray]) -> List[numpy.n
:return: list of inference results.
"""
*_, out_list = G.load_graph(file)
inputs = get_dep_vars(out_list, "Host2DeviceCopy")
replace_dict = {}
inp_node_list = []
for i in inputs:
inp_node = G.InputNode(
device="xpux", dtype=inputs[0].dtype, graph=inputs[0].graph
)
replace_dict[i] = inp_node.outputs[0]
inp_node_list.append(inp_node)
new_out = replace_vars(out_list, replace_dict)
out_node_list = [G.OutputNode(i) for i in new_out]
new_out_list = [i.outputs[0] for i in out_node_list]
cg = new_out_list[0].graph
func = cg.compile(new_out_list)
for node, value in zip(inp_node_list, inp_data_list):
node.set_value(Tensor(value)._dev_tensor())
func.execute()
out_data_list = [o.get_value().numpy() for o in out_node_list]
graph = GraphInference(file)
result = graph.run(*inp_data_list)
out_data_list = list(result.values())
return out_data_list
class GraphInference:
    """
    Loads a serialized computing graph as a GraphInference object which can be
    used to execute the computing graph.

    ``GraphInference.run()`` accepts positional arrays ``inp_args`` and/or a
    dict ``inp_dict`` {input_name: input_value} as input and returns an
    ordered dict {output_name: output_value}.

    :param file: could be file object or filename.
    :param outputs: only compile the subgraph with outputs as its endpoints.
    """

    def __init__(self, file, outputs: Optional[List[str]] = None):
        """
        Load the graph from ``file`` and compile it for repeated execution.

        :param file: file object or filename, forwarded to ``G.load_graph``.
        :param outputs: optional names of the variables to use as endpoints;
            when given, only these variables become outputs of the compiled
            function, in the order listed.
        :raises AssertionError: if a requested output name is not found among
            the loaded outputs or the variables they depend on.
        """
        *_, output_nodes = G.load_graph(file)

        if outputs is not None:
            # Work on a copy so the caller's sequence is never mutated; names
            # are removed as they are matched, leaving the unresolved ones.
            output_name = list(outputs)
            all_vars = get_dep_vars(output_nodes) + output_nodes
            new_outputs = {}
            for var in all_vars:
                if var.name in output_name:
                    new_outputs[var.name] = var
                    output_name.remove(var.name)
            assert (
                len(output_name) == 0
            ), "Can not find varnode {} in this model".format(output_name)
            # Preserve the order requested by the caller.
            output_nodes = [new_outputs[name] for name in outputs]

        # Variables selected with the "Host2DeviceCopy" type filter are the
        # ones swapped for feedable InputNode placeholders below.
        inputs = get_dep_vars(output_nodes, "Host2DeviceCopy")
        self._inp_dict = OrderedDict()
        replace_dict = {}
        for var in inputs:
            # Each placeholder takes the dtype of the variable it replaces.
            # Reusing inputs[0].dtype for all of them (the previous behaviour)
            # gives wrongly typed placeholders when inputs differ in dtype.
            inp_node = G.InputNode(
                device="xpux", dtype=var.dtype, graph=inputs[0].graph
            )
            self._inp_dict[var.name] = inp_node
            replace_dict[var] = inp_node.outputs[0]

        new_output_nodes = replace_vars(output_nodes, replace_dict)
        # Carry the original names over to the rebuilt variables so that the
        # result keys of run() match the names of the requested outputs.
        for old, new in zip(output_nodes, new_output_nodes):
            new.name = old.name

        self._out_dict = OrderedDict(
            (var.name, G.OutputNode(var)) for var in new_output_nodes
        )
        new_out_list = [node.outputs[0] for node in self._out_dict.values()]
        cg = new_out_list[0].graph
        self._func = cg.compile(new_out_list)

    def run(
        self,
        *inp_args: numpy.ndarray,
        inp_dict: Optional[Dict[str, numpy.ndarray]] = None
    ):
        """
        Execute the compiled graph on the given input data.

        :param inp_args: input values given positionally; the k-th value is
            bound to the k-th input name in the order recorded at
            construction time.
        :param inp_dict: input values given by name; an entry here overrides
            a positional value bound to the same name.
        :return: an ``OrderedDict`` mapping every output name to the value
            obtained from ``get_value().numpy()`` on its output node.
        :raises AssertionError: if more positional values are given than the
            model has inputs, or if the supplied names do not match the
            model's input names exactly.
        """
        assert len(inp_args) <= len(
            self._inp_dict
        ), "This model expects {} inputs".format(len(self._inp_dict))

        # Bind positional values to input names first, then let the named
        # values fill in or override them.
        inputs = dict(zip(self._inp_dict.keys(), inp_args))
        if inp_dict is not None:
            inputs.update(inp_dict)

        assert (
            inputs.keys() == self._inp_dict.keys()
        ), "This model expects inputs {}, but gets inputs {}".format(
            list(self._inp_dict.keys()), list(inputs.keys())
        )

        for key, node in self._inp_dict.items():
            node.set_value(Tensor(inputs[key])._dev_tensor())

        self._func.execute()

        result = OrderedDict()
        for key, node in self._out_dict.items():
            result[key] = node.get_value().numpy()
        return result