GitOrigin-RevId: ba251f452a
tags/v1.0.0-rc1
@@ -76,6 +76,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
from .tensor import Parameter, Tensor, tensor
from .version import __version__
from .core import cgtools

_set_fork_exec_path_for_timed_func(
    sys.executable,
@@ -10,3 +10,5 @@ import os
import sys

from .tensor import Tensor
from .tensor.megbrain_graph import Graph
from .utils import comp_graph_tools as cgtools
@@ -7,6 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import json
import threading
import weakref

from concurrent.futures import Future, ThreadPoolExecutor
@@ -162,14 +163,42 @@ def optimize_for_inference(dest_vars, **kwargs):
    return [VarNode(i) for i in res_vars]


def dump(*args):
def dump_graph(*args):
    return _imperative_rt.dump_graph([i._node for i in args])
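# Example: a minimal sketch of dumping a graph to a file, assuming ``var`` is a
# VarNode; "model.mge" is a hypothetical path. This mirrors what trace.dump
# does further below.
#
#   with open("model.mge", "wb") as fout:
#       fout.write(dump_graph(var))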
CompGraphLoadResult = collections.namedtuple(
    "CompGraphLoadResult", ["graph", "output_vars_dict", "output_vars_list"]
)
def load_graph(fpath):
    """Load a serialized computing graph from file.

    :param fpath: path or file handle of the input file
    :return: an instance of namedtuple :class:`CompGraphLoadResult`,
        whose fields are:

        * ``graph``: the loaded CompGraph
        * ``output_vars_dict``: a Python dict mapping names to output SymbolVars
        * ``output_vars_list``: a Python list containing the output vars in the
          order passed to serialize_comp_graph_to_file
    """
    output_vars_map = []
    output_vars_list = []
    if isinstance(fpath, str):
        with open(fpath, "rb") as fin:
            buf = fin.read()
    else:
        buf = fpath.read()
    cg = _imperative_rt.load_graph(buf, output_vars_map, output_vars_list)
    return CompGraphLoadResult(cg, dict(output_vars_map), output_vars_list)
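# Example: a minimal sketch of loading a dump back; "model.mge" is a
# hypothetical path.
#
#   res = load_graph("model.mge")
#   print(res.output_vars_dict.keys())  # names of the dumped output vars
#   print(len(res.output_vars_list))    # output vars, in dump order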
class VarNode(TensorBase):
    def __init__(self, node: _imperative_rt.VarNode):
        self._node = node
        self.graph._var_cache[node] = self
        if hasattr(self.graph, "_var_cache"):
            self.graph._var_cache[node] = self
    @property
    def graph(self) -> Graph:
@@ -177,12 +206,19 @@ class VarNode(TensorBase):
    @property
    def op(self):
        return self.graph._wrap(self._node.owner)
        if hasattr(self.graph, "_wrap"):
            return self.graph._wrap(self._node.owner)
        else:
            return self._node.owner

    @property
    def name(self):
        return self._node.name

    @property
    def id(self):
        return self._node.id

    @name.setter
    def name(self, name):
        self._node.name = name
@@ -207,7 +243,8 @@ class VarNode(TensorBase):
class OpNode:
    def __init__(self, node: _imperative_rt.OperatorNode):
        self._node = node
        self.graph._op_cache[node] = self
        if hasattr(self.graph, "_op_cache"):
            self.graph._op_cache[node] = self

    @property
    def graph(self) -> Graph:
@@ -217,29 +254,53 @@ class OpNode:
    def name(self):
        return self._node.name

    @property
    def id(self):
        return self._node.id

    @name.setter
    def name(self, name):
        self._node.name = name

    @property
    def inputs(self):
        return tuple(map(self.graph._wrap, self._node.inputs))
        if hasattr(self.graph, "_wrap"):
            return tuple(map(self.graph._wrap, self._node.inputs))
        else:
            return self._node.inputs

    @property
    def outputs(self):
        return tuple(map(self.graph._wrap, self._node.outputs))
        if hasattr(self.graph, "_wrap"):
            return tuple(map(self.graph._wrap, self._node.outputs))
        else:
            return self._node.outputs

    @property
    def params(self):
        return json.loads(self._node.params)

    @property
    def type(self):
        return self._node.type
def _wrap(x):
    if isinstance(x, collections.abc.Sequence):
        return type(x)(map(_wrap, x))
    return x.graph._wrap(x)
    if hasattr(x.graph, "_wrap"):
        return x.graph._wrap(x)
    else:
        return x


def _unwrap(x):
    if isinstance(x, collections.abc.Sequence):
        return type(x)(map(_unwrap, x))
    return x._node
    if isinstance(x, VarNode):
        return x._node
    else:
        return x


@apply.register()
@@ -0,0 +1,253 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import Dict, List

from .. import _imperative_rt
from .._imperative_rt import OperatorNode, VarNode
def get_dep_vars(var: VarNode, var_type: str = None) -> List[VarNode]:
    """Return the :class:`.tensor.core.megbrain_graph.VarNode`\\s of type ``var_type``
    that the input ``var`` depends on. If ``var_type`` is None, return vars of all
    types. ``var`` may also be a list of VarNodes.
    """
    outputs = []
    memo = set()

    if isinstance(var, VarNode):
        var = [var]
    if isinstance(var_type, str):
        var_type = [var_type]

    q = list(var)
    while q:
        v = q.pop()
        if v in memo:
            continue
        memo.add(v)
        q.extend(get_owner_opr_inputs(v))
        if var_type is not None:
            if get_owner_opr_type(v) in var_type:
                outputs.append(v)
        else:
            outputs.append(v)

    return outputs
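# Example: a minimal sketch; ``out`` is assumed to be an output VarNode of a
# loaded graph (e.g. from load_graph above).
#
#   inp_vars = get_dep_vars(out, "Host2DeviceCopy")  # all graph inputs
#   all_vars = get_dep_vars(out)                     # every dependency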
def get_owner_opr_inputs(var: VarNode) -> List[VarNode]:
    """Get the inputs of the owner opr of a variable.
    """
    assert isinstance(var, VarNode)
    return var.owner.inputs


def get_owner_opr_type(var: VarNode) -> str:
    """Get the type of the owner opr of a variable.
    """
    assert isinstance(var, VarNode)
    return var.owner.type


def get_opr_type(opr: OperatorNode) -> str:
    """Get the type of an opr.
    """
    assert isinstance(opr, OperatorNode)
    return opr.type
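# Example: a minimal sketch combining the helpers above; ``v`` is assumed to be
# a VarNode.
#
#   print(get_owner_opr_type(v))           # e.g. "Elemwise"
#   for inp in get_owner_opr_inputs(v):
#       print(get_opr_type(inp.owner))     # types of the producing oprs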
def graph_traversal(outputs: List[VarNode]):
    """Helper function that traverses the computing graph and gathers its metadata.

    :param outputs: model outputs
    :return: tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)
        WHERE
        map_oprs is a dict from opr_id to the actual opr
        map_vars is a dict from var_id to the actual var
        var2oprs is a dict from a var to its dest oprs along with the input index
        opr2receivers is a dict from an opr to its receiver oprs
        indegree2opr is a dict from in_degree to the oprs in the computing graph
        opr2indegree is a dict from the oprs in the computing graph to their in_degree

        (indegree2opr, opr2indegree) are only used by the topological sort in get_oprs_seq
    """
    # meta information for the comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)
    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)

    queue = list(map(lambda x: x.owner, outputs))
    visited = set(map(lambda x: x.id, queue))

    # iterate through the whole comp_graph, filling in the meta information
    indegree2opr = collections.defaultdict(set)
    opr2indegree = {}

    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        map_oprs[cur_opr.id] = cur_opr

        idx += 1

        indegree = 0
        for var_idx, var in enumerate(cur_opr.inputs):
            map_vars[var.id] = var
            var2oprs[var.id].append((cur_opr.id, var_idx))

            pre_opr = var.owner

            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)

            indegree += 1
            opr2receivers[pre_opr.id].append(cur_opr.id)

        indegree2opr[indegree].add(cur_opr.id)
        opr2indegree[cur_opr.id] = indegree

    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree
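# Example: a minimal sketch; ``outputs`` is assumed to come from load_graph.
#
#   map_oprs, map_vars, var2oprs, opr2receivers, *_ = graph_traversal(outputs)
#   for opr_id, opr in map_oprs.items():
#       print(opr_id, opr.type, [v.id for v in opr.inputs])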
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]:
    """Get the oprs of a dumped model in topological order.

    :param outputs: model outputs
    :param prune_reshape: whether to prune operators that are useless during inference
    :return: opr list in a correct execution order
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with the topological sort algorithm
        oprs_seq = []
        nr_remain = len(map_oprs)
        while indegree2opr[0]:
            opr_id = indegree2opr[0].pop()
            opr = map_oprs[opr_id]
            nr_remain -= 1

            # skip const value generation operators
            if get_opr_type(opr) != "ImmutableTensor":
                oprs_seq.append(opr)

            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape opr definition: reshape(input_tensor, dest_shape) -> output_tensor
    # at inference time the shape of output_tensor is already known, so the
    # operators that only compute dest_shape can be pruned from the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids):
            useless = True
            for oup in cur_opr.outputs:
                if "workspace" not in oup.name:
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

                for inp in cur_opr.inputs:
                    iterative_pruning(inp.owner, cur_opr, marked_opr_ids)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner for var in reshape_vars]

        marked_opr_ids = []
        for reshape_opr in reshape_oprs:
            iterative_pruning(reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids)

        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    if prune_reshape:
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq
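# Example: a minimal sketch; ``outputs`` is assumed to come from load_graph.
#
#   for opr in get_oprs_seq(outputs, prune_reshape=True):
#       print(opr.id, opr.type)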
def replace_vars(dst: List[VarNode], varmap: Dict[VarNode, VarNode]) -> List[VarNode]:
    """Replace vars in the graph.

    :param dst: target vars representing the graph
    :param varmap: the map that specifies how to replace the vars
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced
    """
    dst_vec = []
    repl_src_vec = []
    repl_dst_vec = []
    for i in dst:
        assert isinstance(i, VarNode)
        dst_vec.append(i)

    for i, j in getattr(varmap, "items", lambda: varmap)():
        assert isinstance(i, VarNode)
        assert isinstance(j, VarNode)
        repl_src_vec.append(i)
        repl_dst_vec.append(j)

    return _imperative_rt.graph._replace_vars(repl_src_vec, repl_dst_vec, dst_vec)
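# Example: a minimal sketch; ``out`` is an output VarNode and ``old_var`` /
# ``new_var`` are hypothetical vars from the same graph (see test_replace_vars
# below for a complete, runnable version).
#
#   (new_out,) = replace_vars([out], {old_var: new_var})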
def replace_oprs(
    dst: List[VarNode], oprmap: Dict[OperatorNode, OperatorNode]
) -> List[VarNode]:
    """Replace operators in the graph.

    :param dst: target vars representing the graph
    :param oprmap: the map that specifies how to replace the operators
    :return: new vars that correspond to ``dst`` with all the dependencies
        replaced
    """
    dst_vec = []
    repl_src_vec = []
    repl_dst_vec = []
    for i in dst:
        assert isinstance(i, VarNode)
        dst_vec.append(i)

    for i, j in getattr(oprmap, "items", lambda: oprmap)():
        assert isinstance(i, OperatorNode)
        assert isinstance(j, OperatorNode)
        repl_src_vec.append(i)
        repl_dst_vec.append(j)

    return _imperative_rt.graph._replace_oprs(repl_src_vec, repl_dst_vec, dst_vec)
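# Example: a minimal sketch; ``old_opr`` / ``new_opr`` are hypothetical
# OperatorNodes from the same graph (see test_replace_oprs below for a
# complete, runnable version).
#
#   (new_out,) = replace_oprs([out], {old_opr: new_opr})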
def set_priority_to_id(dest_vars):
    """For all oprs in the subgraph constructed by dest_vars,
    set each opr's priority to its id if its original priority is zero.

    :param dest_vars: target vars representing the graph
    """
    dest_vec = []
    for i in dest_vars:
        assert isinstance(i, VarNode)
        dest_vec.append(i)
    _imperative_rt.graph._set_priority_to_id(dest_vec)
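# Example: a minimal sketch; ``outputs`` is assumed to come from load_graph.
#
#   set_priority_to_id(outputs)  # oprs with priority 0 get priority = their id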
@@ -569,7 +569,7 @@ class trace:
        if isinstance(file, str):
            permission = "wb" if append == False else "ab"
            file = open(file, permission)
        file.write(G.dump(*dest_vars))
        file.write(G.dump_graph(*dest_vars))

    def _process_inputs(self, *args, **kwargs):
        if self._untraced:
@@ -64,7 +64,60 @@ auto def_rendezvous(py::object m, const char* name) {
using TensorAttr = LogicalTensorDesc;
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;

std::vector<mgb::cg::VarNode*> _replace_vars(const std::vector<mgb::cg::VarNode*>& repl_src,
                                             const std::vector<mgb::cg::VarNode*>& repl_dst,
                                             const std::vector<mgb::cg::VarNode*>& vars) {
    mgb::ThinHashMap<SymbolVar, SymbolVar> varmap;
    for (size_t i = 0; i < repl_src.size(); ++i) {
        varmap[SymbolVar(repl_src[i])] = SymbolVar(repl_dst[i]);
    }
    SymbolVarArray symvars(vars.begin(), vars.end());
    auto sym_result = mgb::cg::replace_vars(symvars, varmap);
    std::vector<mgb::cg::VarNode*> result;
    for (auto symvar : sym_result) {
        result.push_back(symvar.node());
    }
    return result;
}
typedef std::vector<mgb::cg::OperatorNodeBase*> OperatorArray;

std::vector<mgb::cg::VarNode*> _replace_oprs(const OperatorArray& repl_src,
                                             const OperatorArray& repl_dst,
                                             const std::vector<mgb::cg::VarNode*>& vars) {
    mgb::ThinHashMap<mgb::cg::OperatorNodeBase*, mgb::cg::OperatorNodeBase*>
            oprmap;
    for (size_t i = 0; i < repl_src.size(); ++i) {
        oprmap[repl_src[i]] = repl_dst[i];
    }
    const SymbolVarArray symvars(vars.begin(), vars.end());
    auto sym_result = mgb::cg::replace_oprs(symvars, oprmap);
    std::vector<mgb::cg::VarNode*> result;
    for (auto symvar : sym_result) {
        result.push_back(symvar.node());
    }
    return result;
}
void _set_priority_to_id(const std::vector<mgb::cg::VarNode*>& dest_vars) {
    auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
        if (opr->node_prop().attribute().priority == 0) {
            opr->node_prop().attribute().priority = opr->id();
        }
    };
    mgb::cg::DepOprIter dep_iter{on_opr};
    for (const auto& var : dest_vars) {
        dep_iter.add(SymbolVar(var));
    }
}

void init_graph_rt(py::module m) {
    static const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr{std::make_unique<mgb::OprFootprint>()};

    def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
    def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
@@ -99,7 +152,10 @@ void init_graph_rt(py::module m) {
                    return py::none();
                }
                return py::cast(*val).attr("numpy")();
            });
            })
            .def_property_readonly("id", [](cg::VarNode* v) {
                return v->id();
            });

    py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
            .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) { return opr->owner_graph(); })

@@ -110,7 +166,17 @@ void init_graph_rt(py::module m) {
            })
            .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
                return to_tuple(opr->usable_output());
            });
            })
            .def_property_readonly("id", [](cg::OperatorNodeBase* opr) {
                return opr->id();
            })
            .def_property_readonly("params", [](cg::OperatorNodeBase* opr) {
                return _imperative_sm_opr_footprint_ptr->calc_footprint(opr).param->to_string();
            })
            .def_property_readonly("type", [](cg::OperatorNodeBase* opr) {
                return opr->dyn_typeinfo()->name;
            });
    py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
            .def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())

@@ -174,6 +240,44 @@ void init_graph_rt(py::module m) {
    });

    m.def("load_graph", [](std::string& buf, py::list& _output_var_map, py::list& _output_var_list) {
        using namespace mgb::serialization;
        auto file = InputFile::make_mem_proxy(buf.c_str(), buf.length());
        auto format = GraphLoader::identify_graph_dump_format(*file);
        auto loader = GraphLoader::make(std::move(file), format.val());
        GraphLoader::LoadConfig config;
        auto rst = loader->load(config);

        std::vector<std::pair<std::string, SymbolVar>> output_var_map;
        SymbolVarArray output_var_list;
        output_var_map = {rst.output_var_map.begin(), rst.output_var_map.end()};
        output_var_list = std::move(rst.output_var_list);
        for (auto i : output_var_list) {
            _output_var_list.append(i.node());
        }
        for (auto i : output_var_map) {
            _output_var_map.append(py::make_tuple(i.first, i.second.node()));
        }

        // restore the names of Host2DeviceCopy outputs from the loaded tensor map
        std::unordered_map<HostTensorND*, const std::string*> tensor2name;
        for (const auto& pair : rst.tensor_map) {
            tensor2name[pair.second.get()] = &pair.first;
        }
        auto cb = [&tensor2name, graph=rst.graph](cg::OperatorNodeBase* opr) {
            if (!opr->same_type<opr::Host2DeviceCopy>())
                return;
            auto& h2d = opr->cast_final_safe<opr::Host2DeviceCopy>();
            auto it = tensor2name.find(h2d.host_data().get());
            mgb_throw_if(it == tensor2name.end(), GraphError,
                         "unbound Host2DeviceCopy in loaded graph");
            h2d.output(0)->name(*it->second);
        };
        cg::DepOprIter iter{cb};
        for (const auto& var : output_var_list) {
            iter.add(var.node()->owner_opr());
        }
        return rst.graph;
    });
#define CURRENT_CLASS cg::ComputingGraph::Options

auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")

@@ -287,6 +391,10 @@ void init_graph_rt(py::module m) {
        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());

    m.def("_replace_vars", &_replace_vars, py::arg(), py::arg(), py::arg());
    m.def("_replace_oprs", &_replace_oprs, py::arg(), py::arg(), py::arg());
    m.def("_set_priority_to_id", &_set_priority_to_id, py::arg());

    m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                             const CompNode& comp_node,
                                             const DType& dtype,
@@ -16,7 +16,7 @@
#include <memory>
#include <mutex>
#include <future>

#include "megbrain/plugin/opr_footprint.h"
#include "megbrain/graph.h"

template<typename T>
@@ -0,0 +1,90 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io

import numpy as np

import megengine
import megengine.functional as F
import megengine.module as M
from megengine import cgtools
from megengine.core.tensor import megbrain_graph as mgb_graph
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace


def make_dev_tensor(value, dtype=None, device=None):
    return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
def test_replace_vars():
    g = mgb_graph.Graph()
    g.options.async_exec_level = 0b100
    device = "xpux"
    dtype = np.float32
    a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
    const = g.make_const(1.234)
    a_plus_a = F.add(a.outputs[0], a.outputs[0])
    a_plus_a_mul_const = F.mul(a_plus_a, const)
    rst = F.add(a_plus_a_mul_const, a.outputs[0])
    (new,) = cgtools.replace_vars([rst._node], {const._node: a_plus_a._node})
    out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
    func = g.compile(out.outputs[0])
    func.execute()
    x = make_dev_tensor(5.0, device=device)
    a.set_value(x)
    res = out.get_value().numpy()
    np.testing.assert_equal(res, np.array([105.0]))
def test_replace_oprs():
    g = mgb_graph.Graph()
    g.options.async_exec_level = 0b100
    device = "xpux"
    dtype = np.float32
    a = mgb_graph.InputNode(device=device, dtype=dtype, graph=g)
    const = g.make_const(1.25)
    a_plus_a = F.add(a.outputs[0], a.outputs[0])
    old_opr = a_plus_a.op
    a_plus_a_mul_const = F.mul(a_plus_a, const)
    a_mul_a = F.mul(a.outputs[0], a.outputs[0])
    new_opr = a_mul_a.op
    (new,) = cgtools.replace_oprs(
        [a_plus_a_mul_const._node], {old_opr._node: new_opr._node}
    )
    out = mgb_graph.OutputNode(mgb_graph.VarNode(new))
    func = g.compile(out.outputs[0])
    func.execute()
    x = make_dev_tensor(5.0, device=device)
    a.set_value(x)
    res = out.get_value().numpy()
    np.testing.assert_equal(res, np.array([5.0 * 5.0 * 1.25]))
def test_graph_traversal():
    net = M.Conv2d(3, 32, 3)

    @trace(symbolic=True, capture_as_const=True)
    def fun(data):
        x = net(data)
        return x

    data = np.random.random([1, 3, 224, 224]).astype(np.float32)
    for i in range(3):
        fun(megengine.tensor(data))

    file = io.BytesIO()
    fun.dump(file)
    file.seek(0)

    cg, _, outputs = mgb_graph.load_graph(file)
    _, map_vars, var2oprs, *_ = cgtools.graph_traversal(outputs)

    input_var = map_vars[1]
    _, var_idx = var2oprs[input_var.id][0]
    assert var_idx == 0
@@ -13,6 +13,10 @@ import numpy as np
import pytest

from megengine import tensor
import megengine
import megengine.core.tensor.megbrain_graph as mgb_graph
import megengine.module as M
from megengine import cgtools
from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply

@@ -21,6 +25,29 @@ from megengine.functional import exp, log
from megengine.jit import exclude_from_trace, trace
def load_and_inference(file, inp_data):
    cg, _, out_list = mgb_graph.load_graph(file)
    inputs = cgtools.get_dep_vars(out_list, "Host2DeviceCopy")
    replace_dict = {}
    inp_node_list = []
    for i in inputs:
        inp_node = mgb_graph.InputNode(
            device="xpux", dtype=i.dtype, graph=i.graph
        )
        replace_dict[i] = inp_node.outputs[0]
        inp_node_list.append(inp_node)
    new_out = cgtools.replace_vars(out_list, replace_dict)
    out_node_list = [mgb_graph.OutputNode(i) for i in new_out]
    new_out_list = [i.outputs[0] for i in out_node_list]
    new_cg = new_out_list[0].graph
    func = new_cg.compile(new_out_list)
    for node, value in zip(inp_node_list, inp_data):
        node.set_value(as_raw_tensor(value)._dev_tensor())
    func.execute()
    out_data_list = [o.get_value().numpy() for o in out_node_list]
    return out_data_list
def test_trace():
    for symbolic in [False, True]:

@@ -82,12 +109,57 @@ def test_print_in_trace():
def test_dump():
    @trace(symbolic=True, capture_as_const=True)
    def f(a, b):
        op = ops.Elemwise(mode="add")
        (y,) = apply(op, a, b)
        return y

    a = as_raw_tensor([2]).numpy()
    b = as_raw_tensor([4]).numpy()
    y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy()

    for i in range(3):
        np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y)

    file = io.BytesIO()
    f.dump(file)
    file.seek(0)
    result = load_and_inference(file, [a, b])
    np.testing.assert_equal(result[0], y)
def test_capture_dump():
    a = as_raw_tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        op = ops.Elemwise(mode="mul")
        (y,) = apply(op, x, a)
        return y

    x = as_raw_tensor([3]).numpy()
    y = f.__wrapped__(as_raw_tensor(x)).numpy()

    for i in range(3):
        np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)

    file = io.BytesIO()
    f.dump(file)
    file.seek(0)
    result = load_and_inference(file, [x])
    np.testing.assert_equal(result[0], y)
def test_dump_volatile():
    p = as_raw_tensor([2])

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        op = ops.Elemwise(mode="negate")
        (y,) = apply(op, x)
        op = ops.Elemwise(mode="mul")
        (y,) = apply(op, x, p)
        return y

    x = as_raw_tensor([1]).numpy()
    x = as_raw_tensor([3]).numpy()
    y = f.__wrapped__(as_raw_tensor(x)).numpy()

    for i in range(3):

@@ -95,6 +167,13 @@ def test_dump():
    file = io.BytesIO()
    f.dump(file)
    file.seek(0)
    cg, _, outputs = mgb_graph.load_graph(file)
    (out,) = outputs
    assert (
        cgtools.get_owner_opr_type(cgtools.get_owner_opr_inputs(out)[1])
        == "SharedDeviceTensor"
    )
def test_trace_profiler():

@@ -471,11 +471,9 @@ def main():
    assert not testcase, 'extra inputs provided in testcase: {}'.format(
        testcase.keys()
    )
    mgb.serialize_comp_graph_to_file(
        args.output,
        output_mgbvars,
        append=True,
        output_strip_info=args.output_strip_info)
    with open(args.output, "ab") as fout:
        fout.write(G.dump_graph(*output_mgbvars))


if __name__ == '__main__':
    main()