GitOrigin-RevId: 6a9d5beba2
tags/v1.3.0
@@ -11,6 +11,7 @@ from typing import Iterable, Union | |||
import numpy as np | |||
from .._imperative_rt import VarNode | |||
from .._imperative_rt.core2 import Tensor, apply, dtype_promotion, get_device | |||
from ..ops import builtin | |||
from ..ops.special import Const | |||
@@ -59,7 +60,7 @@ def astype(x, dtype): | |||
def convert_single_value(v, *, dtype=None, device=None): | |||
if isinstance(v, Tensor): | |||
if isinstance(v, (Tensor, VarNode)): | |||
if not is_quantize(v.dtype): | |||
v = astype(v, dtype) | |||
else: | |||
@@ -12,11 +12,12 @@ import functools | |||
import numpy as np | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core._imperative_rt.graph import VarNode | |||
from ..core.ops import builtin | |||
from ..core.ops.builtin import Elemwise | |||
from ..core.tensor import utils | |||
from ..core.tensor.array_method import _elwise_apply | |||
from ..core.tensor.utils import isscalar, setscalar | |||
from ..core.tensor.utils import astype, isscalar, setscalar | |||
from ..device import get_default_device | |||
from ..jit.tracing import is_tracing | |||
from ..tensor import Tensor | |||
@@ -77,7 +78,7 @@ __all__ = [ | |||
def _elwise(*args, mode): | |||
tensor_args = list(filter(lambda x: isinstance(x, Tensor), args)) | |||
tensor_args = list(filter(lambda x: isinstance(x, (Tensor, VarNode)), args)) | |||
if len(tensor_args) == 0: | |||
dtype = utils.dtype_promotion(args) | |||
first_arg = Tensor(args[0], dtype=dtype, device=get_default_device()) | |||
@@ -109,7 +110,7 @@ def _elwise(*args, mode): | |||
Elemwise.Mode.ROUND, | |||
) and np.issubdtype(args[0].dtype, np.integer): | |||
return args[0] | |||
args = tuple(map(lambda x: x.astype("float32"), args)) | |||
args = tuple(map(lambda x: astype(x, "float32"), args)) | |||
return _elwise_apply(args, mode) | |||
@@ -65,7 +65,6 @@ def get_owner_opr_inputs(var: VarNode) -> List[VarNode]: | |||
""" | |||
Gets the inputs of owner opr of a variable. | |||
""" | |||
assert isinstance(var, VarNode) | |||
return var.owner.inputs | |||
@@ -74,7 +73,6 @@ def get_owner_opr_type(var: VarNode) -> str: | |||
Gets the type of owner opr of a variable. | |||
""" | |||
assert isinstance(var, VarNode) | |||
return var.owner.type | |||
@@ -109,7 +107,7 @@ def graph_traversal(outputs: VarNode): | |||
var2oprs = collections.defaultdict(list) | |||
opr2receivers = collections.defaultdict(list) | |||
queue = list(map(lambda x: x.owner, outputs)) | |||
queue = list(set(map(lambda x: x.owner, outputs))) | |||
visited = set(map(lambda x: x.id, queue)) | |||
# iterate through whole comp_graph, fill in meta information | |||
@@ -143,12 +141,15 @@ def graph_traversal(outputs: VarNode): | |||
return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree | |||
def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNode]: | |||
def get_oprs_seq( | |||
outputs: List[VarNode], prune_reshape=False, prune_immtensor=True | |||
) -> List[OperatorNode]: | |||
""" | |||
Gets oprs in some topological order for a dumped model. | |||
:param outputs: model outputs. | |||
:param prune_reshape: whether to prune the useless operators during inference. | |||
:param prune_reshape: whether to prune operators that are only used by Reshape oprs during inference.
:param prune_immtensor: whether to prune the ImmutableTensor opr. | |||
:return: opr list with some correct execution order. | |||
""" | |||
@@ -160,9 +161,7 @@ def get_oprs_seq(outputs: List[VarNode], prune_reshape=False) -> List[OperatorNo | |||
opr_id = indegree2opr[0].pop() | |||
opr = map_oprs[opr_id] | |||
nr_remain -= 1 | |||
# skip const value generation operator | |||
if get_opr_type(opr) != "ImmutableTensor": | |||
if opr.type != "ImmutableTensor" or not prune_immtensor: | |||
oprs_seq.append(opr) | |||
for post_id in opr2receivers[opr_id]: | |||
@@ -0,0 +1,682 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections
import collections.abc
import fnmatch | |||
import itertools | |||
import re | |||
from collections import OrderedDict | |||
from typing import Dict, List | |||
import numpy as np | |||
from ..core._imperative_rt import ComputingGraph | |||
from ..core.tensor import megbrain_graph as G | |||
from .comp_graph_tools import get_dep_vars, get_opr_type, get_oprs_seq | |||
from .network_node import ( | |||
NetworkNode, | |||
Host2DeviceCopy, | |||
ImmutableTensor, | |||
OpNode, | |||
VarNode, | |||
str_to_mge_class, | |||
) | |||
class Network: | |||
def __init__(self): | |||
self.input_vars = [] # input var of graph | |||
self._orig_inputs = [] | |||
self.output_vars = [] # output var of graph | |||
self._orig_outputs = [] | |||
self.all_oprs_map = OrderedDict() | |||
self.all_vars_map = OrderedDict() | |||
self.graph = ComputingGraph() | |||
@classmethod | |||
def load(cls, model_path: str, outspec: List[str] = None): | |||
""" | |||
Loads a computing graph as a Network object. | |||
:param model_path: file path of mge model. | |||
:param outspec: only load the subgraph with outspec as its endpoints. | |||
""" | |||
self = cls() | |||
_, _, outputs = G.load_graph(model_path) | |||
if outspec is not None: | |||
output_spec = outspec.copy() | |||
all_vars = get_dep_vars(outputs) + outputs | |||
new_outputs = {} | |||
for i in all_vars: | |||
if i.name in output_spec: | |||
new_outputs[i.name] = i | |||
output_spec.remove(i.name) | |||
assert len(output_spec) == 0, "Can not find {} in this model".format( | |||
output_spec | |||
) | |||
outputs = [new_outputs[i] for i in outspec] | |||
self._orig_outputs = outputs | |||
self.add_dep_oprs(*outputs) | |||
for x in self._orig_inputs: | |||
self.input_vars.append(self._get_var(x)) | |||
for x in self._orig_outputs: | |||
self.output_vars.append(self._get_var(x)) | |||
self.graph = self._orig_outputs[0].graph | |||
return self | |||
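# Hedged usage sketch: load a whole serialized model, or only the subgraph
# ending at a named var; "model.mge" and "fc" are hypothetical names.
from megengine.utils.network import Network as Net

net = Net.load("model.mge")
sub_net = Net.load("model.mge", outspec=["fc"])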
def _compile(self): | |||
self.all_oprs_map = {} | |||
self.all_vars_map = {} | |||
for opr in self.all_oprs: | |||
if isinstance(opr, (ImmutableTensor, Host2DeviceCopy)): | |||
opr.compile(self.graph) | |||
else: | |||
opr.compile() | |||
if opr.name is not None: | |||
opr._opr.name = opr.name | |||
self.all_oprs_map[opr._opr.id] = opr | |||
for o in opr.outputs: | |||
self.all_vars_map[o.var.id] = o | |||
def dump( | |||
self, | |||
file, | |||
*, | |||
keep_var_name: int = 1, | |||
keep_opr_name: bool = False, | |||
keep_param_name: bool = False, | |||
keep_opr_priority: bool = False, | |||
strip_info_file=None, | |||
append_json=False, | |||
optimize_for_inference=True, | |||
append=False, | |||
**kwargs | |||
): | |||
""" | |||
Serializes graph to file. | |||
:param file: output file, could be file object or filename. | |||
:param append: whether output is appended to ``file``. | |||
Only works when ``file`` is str. | |||
:param keep_var_name: level for keeping variable names: | |||
* 0: none of the names are kept | |||
* 1: (default) keep names of output vars
* 2: keep names of all (output and internal) vars | |||
:param keep_opr_name: whether to keep operator names. | |||
:param keep_param_name: whether to keep param names, so param values can be | |||
easily manipulated after loading model | |||
:param keep_opr_priority: whether to keep priority setting for operators | |||
:param strip_info_file: a path string or a file handler. If not None,
the dump information for code strip will be written to ``strip_info_file``.
:param append_json: only checked when ``strip_info_file`` is not None. If set
to True, the code-strip information will be appended to ``strip_info_file``;
if set to False, ``strip_info_file`` will be overwritten.
:param optimize_for_inference: enable optimizations;
all optimize options are skipped if this is False. Default: True
:Keyword Arguments: | |||
* enable_io16xc32 -- | |||
whether to use float16 for I/O between oprs and use | |||
float32 as internal computation precision. Note the output var would be | |||
changed to float16. | |||
* enable_ioc16 -- | |||
whether to use float16 for both I/O and computation | |||
precision. | |||
* enable_hwcd4 -- | |||
whether to use NHWCD4 data layout. This is faster on some | |||
OpenCL backends.
* enable_nchw88 -- | |||
whether to use NCHW88 data layout, currently | |||
used in X86 AVX backend. | |||
* enable_nchw44 -- | |||
whether to use NCHW44 data layout, currently | |||
used in arm backend. | |||
* enable_nchw44_dot -- | |||
whether to use NCHW44_dot data layout, currently | |||
used in armv8.2+dotprod backend. | |||
* enable_nchw4 -- | |||
whether to use NCHW4 data layout, currently | |||
used in nvidia backend(based on cudnn). | |||
* enable_nchw32 -- | |||
whether to use NCHW32 data layout, currently | |||
used in nvidia backend with tensorcore(based on cudnn). | |||
* enable_chwn4 -- | |||
whether to use CHWN4 data layout, currently | |||
used in nvidia backend with tensorcore. | |||
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
into one opr. | |||
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend (this optimization pass will
cause a precision mismatch between the outputs of training and
inference)
""" | |||
self._compile() | |||
out = [G.VarNode(var.var) for var in self.output_vars] | |||
if optimize_for_inference: | |||
out = G.optimize_for_inference(out, **kwargs) | |||
dump_content, _ = G.dump_graph( | |||
out, | |||
keep_var_name=keep_var_name, | |||
keep_opr_name=keep_opr_name, | |||
keep_param_name=keep_param_name, | |||
keep_opr_priority=keep_opr_priority, | |||
strip_info_file=strip_info_file, | |||
append_json=append_json, | |||
) | |||
if isinstance(file, str): | |||
permission = "wb" if append == False else "ab" | |||
file = open(file, permission) | |||
file.write(dump_content) | |||
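# Hedged sketch of the dump/run round trip used throughout the tests below:
# serialize the (possibly modified) Network and execute it with GraphInference.
import io
from megengine.utils.comp_graph_tools import GraphInference

buf = io.BytesIO()
net.dump(buf, optimize_for_inference=False)  # ``net`` is a loaded Network
buf.seek(0)
result = GraphInference(buf).run(inp)  # ``inp``: hypothetical numpy input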
def make_const(self, data, name=None, device=None): | |||
"""Makes an ImmutableTensor OpNode to provide a parameter for the network. | |||
""" | |||
node = ImmutableTensor(data, name, device, self.graph) | |||
node.compile(self.graph) | |||
return node.outputs[0] | |||
def make_input_node(self, shape, dtype, name=None, device=None): | |||
"""Makes a Host2DeviceCopy OpNode to provide an input varnode for the network. | |||
""" | |||
node = Host2DeviceCopy(shape, dtype, name, device) | |||
node.compile(self.graph) | |||
return node.outputs[0] | |||
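# Hedged sketch: add a constant parameter and a fresh input var to a loaded
# Network (mirrors test_make_const / test_add_input; names are hypothetical).
import numpy as np

const_b = net.make_const(np.array([0.0, 0.0]), name="b")
inp_c = net.make_input_node((2,), np.int32, name="c")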
def add_output(self, *vars: VarNode): | |||
"""Adds vars into the network output node list | |||
""" | |||
for var in vars: | |||
if var not in self.output_vars: | |||
self.output_vars.append(var) | |||
def remove_output(self, *vars: VarNode): | |||
"""Removes vars from the network output node list. | |||
""" | |||
for var in vars: | |||
if var in self.output_vars: | |||
self.output_vars.remove(var) | |||
def add_dep_oprs(self, *vars): | |||
"""Adds dependent opnodes and varnodes of vars into network | |||
""" | |||
oprs = get_oprs_seq(vars, False, False) | |||
for mge_opr in oprs: | |||
if get_opr_type(mge_opr) == "Host2DeviceCopy": | |||
self._orig_inputs.extend(mge_opr.outputs) | |||
opr = self._add_opr(mge_opr) | |||
if opr is not None: | |||
for x in mge_opr.inputs: | |||
opr.add_inp_var(self._get_var(x)) | |||
# set out var | |||
for x in mge_opr.outputs: | |||
opr.add_out_var(self._get_var(x)) | |||
return [self.all_vars_map[var.id] for var in vars] | |||
def modify_opr_names(self, modifier): | |||
"""Modifies names of operators **inplace**; useful for merging loaded | |||
network into another network | |||
:param modifier: a string to be prepended to the name, or a function | |||
that maps from name to name | |||
:type modifier: str or callable | |||
""" | |||
if isinstance(modifier, str): | |||
om = modifier | |||
modifier = lambda v: "{}.{}".format(om, v) | |||
assert callable(modifier)
for i in self.all_oprs: | |||
v0 = i.name | |||
v1 = modifier(v0) | |||
assert isinstance(v1, str) | |||
i.name = v1 | |||
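# Hedged sketch of both accepted modifier forms (mirrors test_modify_opr_name):
net.modify_opr_names("net")  # prepend "net." to every opr name
net.modify_opr_names(lambda name: "net1." + name)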
def reset_batch_size(self, batchsize, *, blacklist=()): | |||
"""Helper for reset batch size; first dimension of all data providers | |||
not in blacklist are assumed to be the batch size | |||
:param blacklist: data provider names whose first dimension is not | |||
batchbatch size | |||
""" | |||
blacklist = set(blacklist) | |||
prev_batchsize = None | |||
for i in self.data_providers_filter: | |||
if i.name in blacklist: | |||
blacklist.remove(i.name) | |||
else: | |||
shp = list(i.shape) | |||
if prev_batchsize is None: | |||
prev_batchsize = shp[0] | |||
else: | |||
assert prev_batchsize == shp[0], ( | |||
"batchsize mismatch: batchsize={} " | |||
"shape={} dp={}".format(prev_batchsize, shp, i.name) | |||
) | |||
shp[0] = batchsize | |||
i.shape = tuple(shp) | |||
assert prev_batchsize is not None, "no data provider found" | |||
assert not blacklist, "unused items in blacklist: {}".format(blacklist) | |||
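# Hedged sketch: set every data provider's batch size to 1, except a
# hypothetical provider "anchor" whose first dimension is not a batch axis.
net.reset_batch_size(1, blacklist=["anchor"])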
def replace_vars(self, repl_dict: Dict[VarNode, VarNode]): | |||
""" | |||
Replaces vars in the graph. | |||
:param repl_dict: the map {old_var: new_var} that specifies how to replace the vars. | |||
""" | |||
for var in self.all_vars: | |||
if var in repl_dict: | |||
repl_var = repl_dict[var] | |||
owner = repl_var.owner | |||
idx = owner.outputs.index(repl_var) | |||
owner.outputs[idx] = var | |||
var.__dict__.update(repl_var.__dict__) | |||
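# Hedged sketch of var replacement (mirrors test_replace_var): build a new
# expression from existing vars, register its oprs, then swap it in.
import megengine.functional as F

vara = net.var_filter.name("a").as_unique()
varb = net.var_filter.name("b").as_unique()
new_var = net.add_dep_oprs(F.mul(vara.var, varb.var))[0]
old_var = net.opr_filter.has_input(vara).as_unique().outputs[0]
net.replace_vars({old_var: new_var})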
def replace_oprs(self, repl_dict: Dict[OpNode, OpNode]): | |||
""" | |||
Replaces operators in the graph. | |||
:param repl_dict: the map {old_opr: new_opr} that specifies how to replace the operators.
""" | |||
for opr in self.all_oprs: | |||
if opr in repl_dict: | |||
assert len(opr.outputs) == len( | |||
repl_dict[opr].outputs | |||
), "can not replace {} with {}".format(type(opr), type(repl_dict[opr])) | |||
repl_dict[opr].outputs = opr.outputs | |||
for ind, var in enumerate(opr.outputs): | |||
var.owner = repl_dict[opr] | |||
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) | |||
def get_opr_by_type(self, oprcls, unique=True): | |||
assert issubclass(oprcls, OpNode) | |||
rst = self.opr_filter.type(oprcls).as_list() | |||
if unique: | |||
assert len(rst) == 1, "{} operators of type {} found".format( | |||
len(rst), oprcls | |||
) | |||
(rst,) = rst | |||
return rst | |||
def get_opr_by_name(self, name, unique=True): | |||
rst = self.opr_filter.name(name).as_list() | |||
if unique: | |||
assert len(rst) == 1, "{} operators of type {} found".format(len(rst), name) | |||
(rst,) = rst | |||
return rst | |||
def get_var_by_name(self, name, unique=True): | |||
rst = self.var_filter.name(name).as_list() | |||
if unique: | |||
assert len(rst) == 1, "{} operators of type {} found".format(len(rst), name) | |||
(rst,) = rst | |||
return rst | |||
def get_var_receive_oprs(self, var): | |||
""" Gets all oprs which use var as input | |||
""" | |||
return self.opr_filter.has_input(var).as_list() | |||
def get_dep_oprs(self, var): | |||
"""Gets dependent oprs of var | |||
""" | |||
return get_oprs_seq(var, False, False) | |||
@property | |||
def opr_filter(self): | |||
"""Filter on all opnodes of the Network. | |||
""" | |||
oprs = self.all_oprs | |||
return NodeFilter(itertools.islice(oprs, len(oprs))) | |||
@property | |||
def var_filter(self): | |||
"""Filter on all varnode of the Network. | |||
""" | |||
vars = self.all_vars | |||
return NodeFilter(itertools.islice(vars, len(vars))) | |||
@property | |||
def params_filter(self): # all immutable tensor | |||
"""Filter on all parameters (ImmutableTensor Opr) of the Network | |||
""" | |||
return self.opr_filter.param_provider() | |||
@property | |||
def data_providers_filter(self): # all host2devicecopy | |||
"""Filter on all input nodes (Host2DeviceCopy Opr) of the Network | |||
""" | |||
return self.opr_filter.data_provider() | |||
@property | |||
def dest_vars(self): | |||
"""Output varnodes of the Network. | |||
""" | |||
return self.output_vars | |||
@property | |||
def all_oprs(self): | |||
return get_oprs_seq(self.output_vars, False, False) | |||
@property | |||
def all_vars(self): | |||
return get_dep_vars(self.output_vars) | |||
@property | |||
def all_vars_dict(self): | |||
return self.var_filter.as_dict() | |||
@property | |||
def all_oprs_dict(self): | |||
return self.opr_filter.as_dict() | |||
# used for loading and building graph | |||
def _add_opr(self, x): | |||
# TODO: use megbrain C++ RTTI to replace type string | |||
if x.id not in self.all_oprs_map: | |||
self.all_oprs_map[x.id] = str_to_mge_class(get_opr_type(x)).load(x) | |||
return self.all_oprs_map[x.id] | |||
else: | |||
return None | |||
def _get_opr(self, x): | |||
if x.id in self.all_oprs_map: | |||
return self.all_oprs_map[x.id] | |||
else: | |||
return None | |||
def _get_var(self, x): | |||
# auto convert to VarNode of Network | |||
if x.id not in self.all_vars_map: | |||
self.all_vars_map[x.id] = VarNode.load(x, self._get_opr(x.owner)) | |||
return self.all_vars_map[x.id] | |||
def as_varnode(obj): | |||
"""convert a :class:`.VarNode` compatible object to :class:`.VarNode`. | |||
:param obj: it must be one of the following: | |||
1. a :class:`.VarNode` object | |||
2. a :class:`.OpNode` object that has unique output | |||
3. an iterable that produces either type 1 or 2, with length 1 | |||
:rtype: :class:`.VarNode` | |||
""" | |||
if type(obj) is VarNode: | |||
return obj | |||
if isinstance(obj, OpNode): | |||
assert len(obj.outputs) == 1, ( | |||
"operator {} must have one output to be converted to VarNode; " | |||
"got {} actually".format(obj, len(obj.outputs)) | |||
) | |||
ret = obj.outputs[0] | |||
assert type(ret) is VarNode | |||
return ret | |||
assert isinstance( | |||
obj, collections.abc.Iterable
), "{} is not compatible with VarNode".format(obj) | |||
val = list(obj) | |||
assert ( | |||
len(val) == 1 | |||
), "can not convert sequence of length {} to VarNode ({})".format( | |||
len(val), (lambda s: s if len(s) < 50 else s[:50] + " ...")(str(val)) | |||
) | |||
return as_varnode(val[0]) | |||
def as_oprnode(obj): | |||
"""convert a :class:`.OpNode` compatible object to | |||
:class:`.OpNode`; it works like :func:`as_varnode`.""" | |||
if type(obj) is VarNode: | |||
return obj.owner | |||
if isinstance(obj, OpNode): | |||
return obj | |||
assert isinstance( | |||
obj, collections.abc.Iterable
), "{} is not compatible with OpNode".format(obj) | |||
val = list(obj) | |||
assert ( | |||
len(val) == 1 | |||
), "can not convert sequence of length {} to " "OpNode({})".format(len(val), val) | |||
return as_oprnode(val[0]) | |||
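# Hedged sketch (mirrors test_replace_opr): the length-1 var list returned by
# add_dep_oprs can be converted to its owner OpNode and used with replace_oprs.
# ``out1`` is a hypothetical symbolic var built with functional ops on graph vars.
var_list = net.add_dep_oprs(out1)
repl_opr = as_oprnode(var_list)
orig_opr = net.opr_filter.has_input(vara).as_unique()
net.replace_oprs({orig_opr: repl_opr})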
class NodeFilter: | |||
"""Filter on node iterator. This class is an iterator of | |||
:class:`.NetworkNode` objects and multiple filtering conditions and | |||
mappers can be chained. | |||
Example:: | |||
# find all :class:`.ImmutableTensor` nodes | |||
for i in NodeFilter(node_iter).param_provider(): | |||
print(i) | |||
# find all :class:`.ImmutableTensor` nodes that end with ':W' | |||
for i in NodeFilter(node_iter).param_provider().name('*:W'): | |||
print(i) | |||
# number of inputs | |||
nr_input = NodeFilter(node_iter).data_provider().as_count() | |||
""" | |||
_iter = None | |||
def __init__(self, node_iter): | |||
""" | |||
:param node_iter: iterator to :class:`.NetworkNode`, or a | |||
:class:`.VarNode`-compatible object; in the latter case, its
dependent oprs would be used | |||
""" | |||
if isinstance(node_iter, VarNode): | |||
oprs = get_oprs_seq(node_iter, False, False) | |||
node_iter = itertools.islice(oprs, len(oprs) - 1) | |||
if isinstance(node_iter, OpNode): | |||
oprs = get_oprs_seq(node_iter.inputs, False, False) | |||
node_iter = itertools.islice(oprs, len(oprs) - 1) | |||
assert isinstance(node_iter, collections.abc.Iterable)
if (not isinstance(node_iter, NodeFilter)) and type( | |||
self | |||
) is not NodeFilterCheckType: | |||
node_iter = NodeFilterCheckType(node_iter, NetworkNode) | |||
self._iter = node_iter | |||
@classmethod | |||
def make_all_deps(cls, *dest_vars): | |||
"""make a :class:`NodeFilter` that contains all deps of given vars""" | |||
return cls(list(get_oprs_seq(dest_vars, False, False))) | |||
def __iter__(self): | |||
"""to be overwritten by subclass to implement filters""" | |||
return iter(self._iter) | |||
def type(self, node_type): | |||
"""filter by specific node type | |||
:param node_type: node type class | |||
:return: a new :class:`NodeFilter` object | |||
""" | |||
return NodeFilterType(self, node_type) | |||
def check_type(self, node_type): | |||
"""assert that all oprs produced by this iterator are instances of | |||
certain type | |||
:param node_type: node type class | |||
:return: a new :class:`NodeFilter` object | |||
:raises TypeError: if type check failed | |||
""" | |||
return NodeFilterCheckType(self, node_type) | |||
def not_type(self, node_type): | |||
"""remove oprs of specific type | |||
:param node_type: node type class | |||
:return: a new :class:`NodeFilter` object | |||
""" | |||
return NodeFilterNotType(self, node_type) | |||
def param_provider(self): | |||
"""get :class:`.ParamProvider` oprs; shorthand for | |||
``.type(ParamProvider)``""" | |||
return self.type(ImmutableTensor) | |||
def data_provider(self): | |||
"""get :class:`.DataProvider` oprs; shorthand for | |||
``.type(DataProvider)``""" | |||
return self.type(Host2DeviceCopy) | |||
def name(self, pattern, ignorecase=True): | |||
"""filter by node name | |||
:param pattern: a string in glob syntax that can contain ``?`` and | |||
``*`` to match a single or arbitrary characters. | |||
:type pattern: :class:`str` | |||
:param ignorecase: whether to ignore case
:type ignorecase: bool | |||
:return: a new :class:`NodeFilter` object | |||
""" | |||
return NodeFilterName(self, pattern, ignorecase) | |||
def has_input(self, var): | |||
"""an opr is kept if it has given var as one of its inputs | |||
:param var: var node to be checked
:return: a new :class:`NodeFilter` object | |||
""" | |||
return NodeFilterHasInput(self, var) | |||
def as_list(self): | |||
"""consume this iterator and return its content as a list | |||
:rtype: [:class:`.GraphNodeBase`] | |||
""" | |||
return list(self) | |||
def as_unique(self): | |||
"""assert that this iterator yields only one node and return it | |||
:return: the unique node | |||
:rtype: :class:`.GraphNodeBase` | |||
:raises ValueError: if this iterator does not yield a unique node | |||
""" | |||
(opr,) = self | |||
return opr | |||
def as_dict(self): | |||
"""construct an ordered dict to map from node names to objects in | |||
this iterator | |||
:rtype: :class:`OrderedDict` | |||
""" | |||
return collections.OrderedDict((i.name, i) for i in self) | |||
def as_count(self): | |||
"""consume this iterator and get the number of elements | |||
:rtype: int | |||
""" | |||
return sum(1 for _ in self) | |||
class NodeFilterType(NodeFilter): | |||
"""see :meth:`NodeFilter.type`""" | |||
_node_type = None | |||
def __init__(self, node_iter, node_type): | |||
assert issubclass(node_type, NetworkNode), "bad opr type: {}".format( | |||
node_type | |||
) | |||
super().__init__(node_iter) | |||
self._node_type = node_type | |||
def __iter__(self): | |||
for i in self._iter: | |||
if isinstance(i, self._node_type): | |||
yield i | |||
class NodeFilterNotType(NodeFilterType): | |||
"""see :meth:`NodeFilter.not_type`""" | |||
def __iter__(self): | |||
for i in self._iter: | |||
if not isinstance(i, self._node_type): | |||
yield i | |||
class NodeFilterCheckType(NodeFilterType): | |||
"""see :meth:`NodeFilter.check_type`""" | |||
def __iter__(self): | |||
for i in self._iter: | |||
if not isinstance(i, self._node_type): | |||
raise TypeError( | |||
"all nodes should be {}; got {!r}".format(self._node_type, i) | |||
) | |||
yield i | |||
class NodeFilterHasInput(NodeFilter): | |||
"""see :meth:`NodeFilter.has_input`""" | |||
_var = None | |||
def __init__(self, node_iter, var): | |||
var = as_varnode(var) | |||
super().__init__(node_iter) | |||
self.var = var | |||
def __iter__(self): | |||
for i in self._iter: | |||
assert isinstance( | |||
i, OpNode | |||
), "has_input() must be used with OpNode; " "got {!r}".format(i) | |||
if self.var in i.inputs: | |||
yield i | |||
class NodeFilterName(NodeFilter): | |||
"""see :meth:`NodeFilter.name`""" | |||
_re = None | |||
def __init__(self, node_iter, pattern, ignorecase): | |||
super().__init__(node_iter) | |||
self._re = self.make_re(pattern, ignorecase) | |||
@classmethod | |||
def make_re(cls, pattern, ignorecase=True): | |||
assert isinstance(pattern, str), "bad pattern: {!r}".format(pattern) | |||
assert isinstance(ignorecase, bool) | |||
flags = 0 | |||
if ignorecase: | |||
flags |= re.IGNORECASE | |||
return re.compile(fnmatch.translate(pattern), flags=flags) | |||
def __iter__(self): | |||
for i in self._iter: | |||
if self._re.match(i.name): | |||
yield i |
@@ -0,0 +1,628 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import json | |||
import sys | |||
from typing import Callable
import numpy as np
from ..core import _imperative_rt as rt | |||
from ..core._wrap import Device | |||
from ..core.ops import builtin | |||
from ..core.tensor.megbrain_graph import InputNode | |||
from ..tensor import Tensor | |||
from .comp_graph_tools import replace_vars | |||
class NetworkNode: | |||
pass | |||
class VarNode(NetworkNode): | |||
def __init__(self, owner_opr=None, name=None): | |||
self.var = None | |||
self.owner = owner_opr | |||
self.name = name | |||
self.id = id(self) | |||
@classmethod | |||
def load(cls, sym_var, owner_opr): | |||
obj = cls() | |||
obj.var = sym_var # mgb varnode | |||
obj.name = sym_var.name | |||
obj.owner = owner_opr | |||
return obj | |||
@property | |||
def shape(self): | |||
rst = None | |||
if self.var: | |||
try: | |||
rst = self.var.shape | |||
except Exception:
rst = None | |||
return rst | |||
@property | |||
def dtype(self): | |||
return self.var.dtype if self.var else None | |||
def set_owner_opr(self, owner_opr): | |||
self.owner = owner_opr
class OpNode(NetworkNode): | |||
opdef = None | |||
type = None | |||
def __init__(self): | |||
self.inputs = [] | |||
self.outputs = [] | |||
self.params = {} | |||
self._opr = None # mgb opnode | |||
self.id = id(self) | |||
@classmethod | |||
def load(cls, opr): | |||
obj = cls() | |||
obj.params = json.loads(opr.params) | |||
obj.name = opr.name | |||
obj._opr = opr | |||
return obj | |||
def compile(self, graph=None): | |||
op = self.opdef(**self.params) | |||
args = [i.var for i in self.inputs] | |||
outputs = rt.invoke_op(op, args) | |||
assert len(outputs) == len(self.outputs) | |||
self._opr = outputs[0].owner | |||
for i in range(len(self.outputs)): | |||
self.outputs[i].var = outputs[i] | |||
self.outputs[i].var.name = self.outputs[i].name | |||
assert self.outputs[i].owner is self | |||
def add_inp_var(self, x): | |||
self.inputs.append(x) | |||
def add_out_var(self, x): | |||
self.outputs.append(x) | |||
def str_to_mge_class(classname): | |||
# TODO: use megbrain C++ RTTI to replace type string | |||
if classname == "RNGOpr<MegDNNOpr>": | |||
classname = "RNGOpr" | |||
oprcls = getattr(sys.modules[__name__], classname, None) | |||
return oprcls if oprcls else ReadOnlyOpNode | |||
class Host2DeviceCopy(OpNode): | |||
type = "Host2DeviceCopy" | |||
def __init__(self, shape=None, dtype=None, name=None, device=None): | |||
super().__init__() | |||
self.shape = shape | |||
self.dtype = dtype | |||
self.name = name | |||
self.device = Device(device).to_c() if device else Device("xpux").to_c() | |||
self.outputs = [] | |||
@classmethod | |||
def load(cls, opr): | |||
self = cls() | |||
self.outputs = [] | |||
assert len(opr.outputs) == 1, "wrong number of outputs" | |||
self.shape = opr.outputs[0].shape | |||
self.dtype = opr.outputs[0].dtype | |||
self.name = opr.outputs[0].name | |||
self.device = opr.outputs[0].comp_node | |||
self._opr = opr | |||
return self | |||
def compile(self, graph): | |||
outputs = rt.make_h2d(graph, self.device, self.dtype, self.shape, self.name) | |||
self._opr = outputs.owner | |||
if len(self.outputs) == 0: | |||
self.outputs.append(VarNode(self, self.name)) | |||
self.outputs[0].var = outputs | |||
assert self.outputs[0].owner is self | |||
class ImmutableTensor(OpNode): | |||
type = "ImmutableTensor" | |||
def __init__(self, data=None, name=None, device=None, graph=None): | |||
super().__init__() | |||
self.name = name | |||
self.outputs = [] | |||
self.graph = graph | |||
if data is not None: | |||
self.set_value(data, device) | |||
@property | |||
def device(self): | |||
return self._opr.outputs[0].comp_node if self._opr else None | |||
@device.setter | |||
def device(self, device): | |||
self.set_value(self.numpy(), device) | |||
@property | |||
def shape(self): | |||
return self.outputs[0].shape | |||
@property | |||
def dtype(self): | |||
return self._opr.outputs[0].dtype if self._opr else None | |||
def numpy(self): | |||
return self._opr.outputs[0].value if self._opr else None | |||
def set_value(self, data, device=None): | |||
assert self.graph is not None | |||
cn = device if device else self.device | |||
assert isinstance(data, (int, float, np.ndarray)) | |||
if isinstance(data, (int, float)): | |||
data = np.array(data) | |||
if data.dtype == np.float64: | |||
data = data.astype(np.float32) | |||
elif data.dtype == np.int64: | |||
data = data.astype(np.int32) | |||
varnode = rt.make_const(self.graph, data, cn, data.dtype, self.name) | |||
if len(self.outputs) == 0: | |||
self.outputs.append(VarNode(self, self.name)) | |||
self.outputs[0].var = varnode | |||
self._opr = varnode.owner | |||
@classmethod | |||
def load(cls, opr): | |||
self = cls() | |||
self.outputs = [] | |||
self._opr = opr | |||
self.name = opr.outputs[0].name | |||
self.graph = opr.graph | |||
return self | |||
def compile(self, graph): | |||
assert self.outputs[0].var is self._opr.outputs[0] | |||
assert self.outputs[0].owner is self | |||
if self.graph != graph: | |||
self.graph = graph | |||
self.set_value(self.numpy()) | |||
if self.name is not None: | |||
self.outputs[0].var.name = self.name | |||
class ReadOnlyOpNode(OpNode): | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(ReadOnlyOpNode, cls).load(opr) | |||
obj.type = opr.type | |||
return obj | |||
def compile(self): | |||
assert self._opr is not None | |||
assert len(self.inputs) == len(self._opr.inputs) | |||
assert len(self.outputs) == len(self._opr.outputs) | |||
repl_dict = {} | |||
for ind, i in enumerate(self.inputs): | |||
if i.var != self._opr.inputs[ind]: | |||
repl_dict[self._opr.inputs[ind]] = i.var | |||
if bool(repl_dict): | |||
out_vars = replace_vars(self._opr.outputs, repl_dict) | |||
for ind, o in enumerate(self.outputs): | |||
o.var = out_vars[ind] | |||
class Elemwise(OpNode): | |||
type = "Elemwise" | |||
opdef = builtin.Elemwise | |||
class Reduce(OpNode): | |||
type = "Reduce" | |||
opdef = builtin.Reduce | |||
class TypeCvt(OpNode): | |||
type = "TypeCvt" | |||
opdef = builtin.TypeCvt | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(TypeCvt, cls).load(opr) | |||
t_dtype = opr.outputs[0].dtype | |||
obj.params["dtype"] = t_dtype | |||
return obj | |||
class MatrixInverse(OpNode): | |||
type = "MatrixInverse" | |||
opdef = builtin.MatrixInverse | |||
class MatrixMul(OpNode): | |||
type = "MatrixMul" | |||
opdef = builtin.MatrixMul | |||
class BatchedMatrixMul(OpNode): | |||
type = "BatchedMatmul" | |||
opdef = builtin.BatchedMatrixMul | |||
class Dot(OpNode): | |||
type = "Dot" | |||
opdef = builtin.Dot | |||
class SVD(OpNode): | |||
type = "SVD" | |||
opdef = builtin.SVD | |||
class ConvolutionForward(OpNode): | |||
type = "Convolution" | |||
opdef = builtin.Convolution | |||
class ConvolutionBackwardData(OpNode): | |||
type = "ConvTranspose" | |||
opdef = builtin.ConvolutionBackwardData | |||
class DeformableConvForward(OpNode): | |||
type = "DeformableConv" | |||
opdef = builtin.DeformableConv | |||
class GroupLocalForward(OpNode): | |||
type = "GroupLocal" | |||
opdef = builtin.GroupLocal | |||
class PoolingForward(OpNode): | |||
type = "Pooling" | |||
opdef = builtin.Pooling | |||
class AdaptivePoolingForward(OpNode): | |||
type = "AdaptivePooling" | |||
opdef = builtin.AdaptivePooling | |||
class ROIPoolingForward(OpNode): | |||
type = "ROIPooling" | |||
opdef = builtin.ROIPooling | |||
class DeformablePSROIPoolingForward(OpNode): | |||
type = "DeformablePSROIPooling" | |||
opdef = builtin.DeformablePSROIPooling | |||
class ConvBiasForward(OpNode): | |||
type = "ConvBias" | |||
opdef = builtin.ConvBias | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(ConvBiasForward, cls).load(opr) | |||
obj.params["dtype"] = opr.outputs[0].dtype | |||
return obj | |||
class BatchConvBiasForward(OpNode): | |||
type = "BatchConvBias" | |||
opdef = builtin.BatchConvBias | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(BatchConvBiasForward, cls).load(opr) | |||
obj.params["dtype"] = opr.outputs[0].dtype | |||
return obj | |||
class BatchNormForward(OpNode): | |||
type = "BatchNorm" | |||
opdef = builtin.BatchNorm | |||
class ROIAlignForward(OpNode): | |||
type = "ROIAlign" | |||
opdef = builtin.ROIAlign | |||
class WarpPerspectiveForward(OpNode): | |||
type = "WarpPerspective" | |||
opdef = builtin.WarpPerspective | |||
class WarpAffineForward(OpNode): | |||
type = "WarpAffine" | |||
opdef = builtin.WarpAffine | |||
class RemapForward(OpNode): | |||
type = "Remap" | |||
opdef = builtin.Remap | |||
class ResizeForward(OpNode): | |||
type = "Resize" | |||
opdef = builtin.Resize | |||
class IndexingOneHot(OpNode): | |||
type = "IndexingOneHot" | |||
opdef = builtin.IndexingOneHot | |||
class IndexingSetOneHot(OpNode): | |||
type = "IndexingSetOneHot" | |||
opdef = builtin.IndexingSetOneHot | |||
class Copy(OpNode): | |||
type = "Copy" | |||
opdef = builtin.Copy | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(Copy, cls).load(opr) | |||
obj.params["comp_node"] = opr.outputs[0].comp_node | |||
return obj | |||
class ArgsortForward(OpNode): | |||
type = "Argsort" | |||
opdef = builtin.Argsort | |||
class Argmax(OpNode): | |||
type = "Argmax" | |||
opdef = builtin.Argmax | |||
class Argmin(OpNode): | |||
type = "Argmin" | |||
opdef = builtin.Argmin | |||
class CondTake(OpNode): | |||
type = "CondTake" | |||
opdef = builtin.CondTake | |||
class TopK(OpNode): | |||
type = "TopK" | |||
opdef = builtin.TopK | |||
class NvOf(OpNode): | |||
type = "NvOf" | |||
opdef = builtin.NvOf | |||
class RNGOpr(OpNode): | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(RNGOpr, cls).load(opr) | |||
if len(obj.params) == 3: | |||
obj.opdef = builtin.GaussianRNG | |||
obj.type = "GaussianRNG" | |||
else: | |||
obj.opdef = builtin.UniformRNG | |||
obj.type = "UniformRNG" | |||
return obj | |||
class Linspace(OpNode): | |||
type = "Linspace" | |||
opdef = builtin.Linspace | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(Linspace, cls).load(opr) | |||
obj.params["comp_node"] = opr.outputs[0].comp_node | |||
return obj | |||
class Eye(OpNode): | |||
type = "Eye" | |||
opdef = builtin.Eye | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(Eye, cls).load(opr) | |||
obj.params["dtype"] = opr.outputs[0].dtype | |||
obj.params["comp_node"] = opr.outputs[0].comp_node | |||
return obj | |||
class GetVarShape(OpNode): | |||
type = "GetVarShape" | |||
opdef = builtin.GetVarShape | |||
class Concat(OpNode): | |||
type = "Concat" | |||
opdef = builtin.Concat | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(Concat, cls).load(opr) | |||
obj.params["comp_node"] = Device("xpux").to_c() | |||
return obj | |||
class Broadcast(OpNode): | |||
type = "Broadcast" | |||
opdef = builtin.Broadcast | |||
class Identity(OpNode): | |||
type = "Identity" | |||
opdef = builtin.Identity | |||
class NMSKeep(OpNode): | |||
type = "NMSKeep" | |||
opdef = builtin.NMSKeep | |||
# class ParamPackSplit | |||
# class ParamPackConcat | |||
class Dimshuffle(OpNode): | |||
type = "Dimshuffle" | |||
opdef = builtin.Dimshuffle | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(Dimshuffle, cls).load(opr) | |||
del obj.params["ndim"] | |||
return obj | |||
class Reshape(OpNode): | |||
type = "Reshape" | |||
opdef = builtin.Reshape | |||
class AxisAddRemove(OpNode): | |||
type = "AxisAddRemove" | |||
@classmethod | |||
def load(cls, opr): | |||
obj = cls() | |||
obj.name = opr.name | |||
obj._opr = opr | |||
params = json.loads(opr.params) | |||
desc = params["desc"] | |||
method = None | |||
axis = [] | |||
for i in desc: | |||
if method is None: | |||
method = i["method"] | |||
assert method == i["method"] | |||
axis.append(i["axisnum"]) | |||
obj.params = {"axis": axis} | |||
obj.opdef = builtin.AddAxis if desc[0]["method"] == 0 else builtin.RemoveAxis | |||
return obj | |||
class IndexingBase(OpNode): | |||
@classmethod | |||
def load(cls, opr): | |||
obj = cls() | |||
obj.name = opr.name | |||
obj._opr = opr | |||
params = json.loads(opr.params) | |||
items = [ | |||
[ | |||
p["axis"], | |||
bool(p["begin"]), | |||
bool(p["end"]), | |||
bool(p["step"]), | |||
bool(p["idx"]), | |||
] | |||
for p in params | |||
] | |||
obj.params["items"] = items | |||
return obj | |||
class Subtensor(IndexingBase): | |||
type = "Subtensor" | |||
opdef = builtin.Subtensor | |||
class SetSubtensor(IndexingBase): | |||
type = "SetSubtensor" | |||
opdef = builtin.SetSubtensor | |||
class IncrSubtensor(IndexingBase): | |||
type = "IncrSubtensor" | |||
opdef = builtin.IncrSubtensor | |||
class IndexingMultiAxisVec(IndexingBase): | |||
type = "IndexingMultiAxisVec" | |||
opdef = builtin.IndexingMultiAxisVec | |||
class IndexingSetMultiAxisVec(IndexingBase): | |||
type = "IndexingSetMultiAxisVec" | |||
opdef = builtin.IndexingSetMultiAxisVec | |||
class IndexingIncrMultiAxisVec(IndexingBase): | |||
type = "IndexingIncrMultiAxisVec" | |||
opdef = builtin.IndexingIncrMultiAxisVec | |||
class MeshIndexing(IndexingBase): | |||
type = "MeshIndexing" | |||
opdef = builtin.MeshIndexing | |||
class SetMeshIndexing(IndexingBase): | |||
type = "SetMeshIndexing" | |||
opdef = builtin.SetMeshIndexing | |||
class IncrMeshIndexing(IndexingBase): | |||
type = "IncrMeshIndexing" | |||
opdef = builtin.IncrMeshIndexing | |||
class BatchedMeshIndexing(IndexingBase): | |||
type = "BatchedMeshIndexing" | |||
opdef = builtin.BatchedMeshIndexing | |||
class BatchedSetMeshIndexing(IndexingBase): | |||
type = "BatchedSetMeshIndexing" | |||
opdef = builtin.BatchedSetMeshIndexing | |||
class BatchedIncrMeshIndexing(IndexingBase): | |||
type = "BatchedIncrMeshIndexing" | |||
opdef = builtin.BatchedIncrMeshIndexing | |||
# class CollectiveComm | |||
# class RemoteSend | |||
# class RemoteRecv | |||
# class TQT | |||
# class FakeQuant | |||
# class InplaceAdd | |||
class AssertEqual(OpNode): | |||
type = "AssertEqual" | |||
opdef = builtin.AssertEqual | |||
class ElemwiseMultiType(OpNode): | |||
type = "ElemwiseMultiType" | |||
opdef = builtin.ElemwiseMultiType | |||
@classmethod | |||
def load(cls, opr): | |||
obj = super(ElemwiseMultiType, cls).load(opr) | |||
obj.params["dtype"] = opr.outputs[0].dtype | |||
return obj | |||
class CvtColorForward(OpNode): | |||
type = "CvtColor" | |||
opdef = builtin.CvtColor |
@@ -160,6 +160,16 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
if (ctx.op->same_type<BackwardGraph>()) { | |||
ctx.backward = true; | |||
} | |||
if (py::isinstance<cg::VarNode>(py::handle(args[0]))){ | |||
SmallVector<cg::VarNode*> vinputs(nargs); | |||
for (size_t i = 0; i < nargs; ++i) { | |||
vinputs[i] = py::handle(args[i]).cast<cg::VarNode *>(); | |||
} | |||
auto op = ctx.op.get(); | |||
return to_tuple(OpDef::apply_on_var_node(*op, vinputs)).release().ptr(); | |||
} | |||
for (size_t i = 0; i < nargs; ++i) { | |||
if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) { | |||
@@ -675,6 +685,16 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) { | |||
tensors.emplace_back(descr); | |||
continue; | |||
} | |||
if (py::isinstance<cg::VarNode>(py::handle(handle))){ | |||
auto var = py::handle(handle).cast<cg::VarNode *>(); | |||
mgb::DType type = var->dtype(); | |||
auto && descr = npy::dtype_mgb2np_descr(type); | |||
Py_INCREF(descr.get()); | |||
tensors.emplace_back(descr.get()); | |||
continue; | |||
} | |||
PyArray_Descr* descr = scalar2dtype(handle); | |||
if (descr) { | |||
scalars.emplace_back(descr); | |||
@@ -719,12 +739,14 @@ CompNode _get_device(PyObject*const* args, size_t nargs) { | |||
for (size_t i = 0; i < nargs; ++i) { | |||
PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i]; | |||
TensorWrapper* tw = TensorWrapper::try_cast(handle); | |||
if (tw) { | |||
bool is_var = py::isinstance<cg::VarNode>(py::handle(handle)); | |||
if (tw || is_var) { | |||
if (!valid) { | |||
cn = tw->m_tensor->comp_node(); | |||
cn = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
valid = true; | |||
} else { | |||
CompNode cn1 = tw->m_tensor->comp_node(); | |||
CompNode cn1 = tw ? tw->m_tensor->comp_node() : py::handle(handle).cast<cg::VarNode *>()->comp_node(); | |||
if (cn1 != cn) { | |||
throw py::value_error(ssprintf("ambiguous device: %s vs %s", | |||
cn.to_string().c_str(), cn1.to_string().c_str())); | |||
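# With the VarNode branches added above, apply(), dtype_promotion() and
# get_device() also accept graph VarNodes, so functional ops can be applied
# directly to vars of a loaded Network (this is what the tests below rely on).
# ``graph`` is a hypothetical name for a Network loaded via Network.load.
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()
out = F.mul(vara.var, varb.var)  # returns graph VarNode(s), not a Tensor
out = F.relu(out)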
@@ -0,0 +1,351 @@ | |||
import io | |||
import numpy as np | |||
import megengine.core.tensor.megbrain_graph as G | |||
import megengine.functional as F | |||
import megengine.module as M | |||
import megengine.utils.network_node as N | |||
from megengine.jit.tracing import trace | |||
from megengine.tensor import Tensor | |||
from megengine.utils.comp_graph_tools import GraphInference | |||
from megengine.utils.network import Network as Net | |||
from megengine.utils.network import as_oprnode | |||
from megengine.utils.network_node import Host2DeviceCopy, VarNode | |||
def test_replace_var(): | |||
a = Tensor([1, 2]) | |||
b = Tensor([3, 4]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(a, b): | |||
return (a + b) * 2 | |||
fwd(a, b) | |||
orig_model = io.BytesIO() | |||
fwd.dump( | |||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
vara = graph.var_filter.name("a").as_unique() | |||
varb = graph.var_filter.name("b").as_unique() | |||
out = F.mul(vara.var, varb.var) | |||
out = F.relu(out) | |||
var_list = graph.add_dep_oprs(out) | |||
opnode = list(graph.opr_filter.has_input(vara)) | |||
repl_dict = {opnode[0].outputs[0]: var_list[0]} | |||
graph.replace_vars(repl_dict) | |||
modified_model = io.BytesIO() | |||
graph.dump(modified_model) | |||
modified_model.seek(0) | |||
load_graph = GraphInference(modified_model) | |||
out = load_graph.run(a, b) | |||
np.testing.assert_equal(out["o"], [6, 16]) | |||
def test_replace_opr(): | |||
a = Tensor([1, 2]) | |||
b = Tensor([3, 4]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(a, b): | |||
return (a + b) * 2 | |||
fwd(a, b) | |||
orig_model = io.BytesIO() | |||
fwd.dump( | |||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
vara = graph.var_filter.name("a").as_unique() | |||
varb = graph.var_filter.name("b").as_unique() | |||
out1 = F.sub(vara.var, varb.var) | |||
out1 = F.relu(out1) | |||
var_list = graph.add_dep_oprs(out1) | |||
repl_opr = as_oprnode(var_list) | |||
orig_opr = graph.opr_filter.has_input(vara).as_unique() | |||
repl_dict = {orig_opr: repl_opr} | |||
graph.replace_oprs(repl_dict) | |||
modified_model1 = io.BytesIO() | |||
graph.dump(modified_model1) | |||
modified_model1.seek(0) | |||
load_graph = GraphInference(modified_model1) | |||
out = load_graph.run(a, b) | |||
np.testing.assert_equal(out["o"], [0, 0]) | |||
def test_modify_params(): | |||
a = Tensor([1, 2]) | |||
b = Tensor([3, 4]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(a, b): | |||
return (a + b) * 2 | |||
fwd(a, b) | |||
orig_model = io.BytesIO() | |||
fwd.dump( | |||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
param_const = graph.params_filter.as_unique() | |||
param_const.set_value(3) | |||
modified_model = io.BytesIO() | |||
graph.dump(modified_model) | |||
modified_model.seek(0) | |||
load_graph = GraphInference(modified_model) | |||
out = load_graph.run(a, b) | |||
np.testing.assert_equal(out["o"], [12, 18]) | |||
def test_make_const(): | |||
a = Tensor([1, 2]) | |||
b = Tensor([3, 4]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(a, b): | |||
return (a + b) * 2 | |||
fwd(a, b) | |||
orig_model = io.BytesIO() | |||
fwd.dump( | |||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
const_b = graph.make_const(np.array([0.0, 0.0]), name="b") | |||
varb = graph.var_filter.name("b").as_unique() | |||
repl_dict = {varb: const_b} | |||
graph.replace_vars(repl_dict) | |||
modified_model = io.BytesIO() | |||
graph.dump(modified_model) | |||
modified_model.seek(0) | |||
load_graph = GraphInference(modified_model) | |||
out = load_graph.run(a) | |||
np.testing.assert_equal(out["o"], [2, 4]) | |||
def test_add_input(): | |||
a = Tensor([1, 2]) | |||
b = Tensor([3, 4]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(a, b): | |||
return (a + b) * 2 | |||
fwd(a, b) | |||
orig_model = io.BytesIO() | |||
fwd.dump( | |||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
inp_c = graph.make_input_node((2,), np.int32, name="c") | |||
varo = graph.var_filter.name("o").as_unique() | |||
out = F.add(varo.var, inp_c.var) | |||
out = graph.add_dep_oprs(out)[0] | |||
out.name = "o1" | |||
graph.remove_output(varo) | |||
graph.add_output(out) | |||
modified_model = io.BytesIO() | |||
graph.dump(modified_model) | |||
modified_model.seek(0) | |||
load_graph = GraphInference(modified_model) | |||
out = load_graph.run(a, b, a) | |||
np.testing.assert_equal(out["o1"], ((a + b) * 2 + a).numpy()) | |||
def test_add_output(): | |||
a = Tensor([1.0, 2.0]) | |||
b = Tensor([3.0, 4.0]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(a, b): | |||
return (a + b) * 2 | |||
fwd(a, b) | |||
orig_model = io.BytesIO() | |||
fwd.dump( | |||
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False | |||
) | |||
orig_model.seek(0) | |||
net = Net.load(orig_model) | |||
var_a = net.var_filter.name("a").as_unique() | |||
var_b = net.var_filter.name("b").as_unique() | |||
y = F.add(var_a.var, var_b.var) | |||
y = F.sigmoid(y) | |||
new_vars = net.add_dep_oprs(y)[0] | |||
new_vars.name = "o1" | |||
net.add_output(new_vars) | |||
modified_model = io.BytesIO() | |||
net.dump(modified_model) | |||
modified_model.seek(0) | |||
g = GraphInference(modified_model) | |||
out = g.run(a.numpy(), b.numpy()) | |||
np.testing.assert_equal(out["o"], ((a + b) * 2).numpy()) | |||
np.testing.assert_equal(out["o1"], (F.sigmoid((a + b))).numpy()) | |||
def test_query(): | |||
class Model(M.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv1 = M.Conv2d(3, 32, 3) | |||
self.conv2 = M.Conv2d(32, 32, 3) | |||
self.conv3 = M.Conv2d(32, 32, 3) | |||
def forward(self, data): | |||
x = self.conv1(data) | |||
x = self.conv2(x) | |||
x = self.conv3(x) | |||
return x | |||
n = Model() | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return n(data) | |||
fwd(Tensor(np.random.random((1, 3, 224, 224)))) | |||
orig_model = io.BytesIO() | |||
fwd.dump( | |||
orig_model, | |||
arg_names=["data"], | |||
output_names="o", | |||
keep_opr_name=True, | |||
keep_var_name=True, | |||
optimize_for_inference=False, | |||
) | |||
orig_model.seek(0) | |||
graph = Net.load(orig_model) | |||
r = graph.data_providers_filter.as_count() | |||
assert r == 1 | |||
opr = graph.get_opr_by_type(Host2DeviceCopy) | |||
assert isinstance(opr, Host2DeviceCopy) | |||
r1 = graph.params_filter.as_count() | |||
assert r1 == 6 | |||
r2 = graph.opr_filter.type(N.ConvolutionForward).as_count() | |||
assert r2 == 3 | |||
r3 = graph.opr_filter.not_type(N.ConvolutionForward).as_count() | |||
assert r3 == len(graph.all_oprs) - r2 | |||
var = graph.var_filter.name("data").as_unique() | |||
r4 = graph.opr_filter.has_input(var).as_count() | |||
assert r4 == 1 | |||
r5 = graph.opr_filter.name("data").as_count() | |||
assert r5 == 1 | |||
opr = graph.get_opr_by_name("data") | |||
assert isinstance(opr, Host2DeviceCopy) | |||
var = graph.get_var_by_name("data") | |||
assert isinstance(var, VarNode) | |||
r6 = graph.var_filter.name("*bias").as_count() | |||
assert r6 == 3 | |||
def test_optimize_for_inference(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(x): | |||
return F.exp(x) | |||
orig_model = io.BytesIO() | |||
f(Tensor(5.0)) | |||
f.dump(orig_model, optimize_for_inference=False) | |||
orig_model.seek(0) | |||
optimize_model = io.BytesIO() | |||
net = Net.load(orig_model) | |||
net.dump(optimize_model, enable_io16xc32=True) | |||
optimize_model.seek(0) | |||
res = G.load_graph(optimize_model) | |||
computing_input = res.output_vars_list[0].owner.inputs[0] | |||
assert computing_input.dtype == np.float16 | |||
def test_reset_batchsize(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(x): | |||
return F.exp(x) | |||
orig_model = io.BytesIO() | |||
f(Tensor(np.random.random((3, 3, 224, 224)))) | |||
f.dump(orig_model, optimize_for_inference=False) | |||
orig_model.seek(0) | |||
modified_model = io.BytesIO() | |||
net = Net.load(orig_model) | |||
net.reset_batch_size(1) | |||
net.dump(modified_model, optimize_for_inference=False) | |||
modified_model.seek(0) | |||
net1 = Net.load(modified_model) | |||
assert net1.data_providers_filter.as_unique().shape[0] == 1 | |||
def test_modify_opr_name(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(x): | |||
return F.exp(x) | |||
orig_model = io.BytesIO() | |||
f(Tensor(np.random.random((3, 3, 224, 224)))) | |||
f.dump(orig_model, arg_names=["a"], optimize_for_inference=False) | |||
orig_model.seek(0) | |||
modified_model = io.BytesIO() | |||
net = Net.load(orig_model) | |||
net.modify_opr_names("net") | |||
net.modify_opr_names(lambda x: "net1." + x) | |||
net.dump(modified_model, optimize_for_inference=False) | |||
modified_model.seek(0) | |||
net1 = Net.load(modified_model) | |||
assert net1.data_providers_filter.as_unique().name == "net1.net.a" |
@@ -0,0 +1,712 @@ | |||
import io | |||
import os | |||
import platform | |||
import numpy as np | |||
import pytest | |||
import megengine.core.tensor.dtype as dtype | |||
import megengine.core.tensor.megbrain_graph as G | |||
import megengine.functional as F | |||
import megengine.module as M | |||
import megengine.random as rand | |||
from megengine.core._imperative_rt.core2 import apply | |||
from megengine.core._wrap import Device | |||
from megengine.core.ops import builtin | |||
from megengine.device import is_cuda_available | |||
from megengine.functional.external import tensorrt_runtime_opr | |||
from megengine.jit.tracing import trace | |||
from megengine.tensor import Tensor | |||
from megengine.utils.comp_graph_tools import GraphInference | |||
from megengine.utils.network import Network as Net | |||
def check_pygraph_dump(trace_func, inp_data, expect_results): | |||
orig_model = io.BytesIO() | |||
inp_size = len(inp_data) | |||
out_size = len(expect_results) | |||
arg_names = ["arg_{}".format(i) for i in range(inp_size)] | |||
output_names = ["out_{}".format(i) for i in range(out_size)] | |||
trace_func.dump( | |||
orig_model, | |||
arg_names=arg_names, | |||
output_names=output_names, | |||
optimize_for_inference=False, | |||
) | |||
orig_model.seek(0) | |||
net = Net.load(orig_model) | |||
file = io.BytesIO() | |||
net.dump(file, optimize_for_inference=False) | |||
file.seek(0) | |||
graph = GraphInference(file) | |||
inp_dict = dict([(arg_names[i], inp_data[i].numpy()) for i in range(inp_size)]) | |||
results = graph.run(inp_dict=inp_dict) | |||
for ind, tensor in enumerate(expect_results): | |||
np.testing.assert_equal(tensor.numpy(), results[output_names[ind]]) | |||
assert tensor.dtype == results[output_names[ind]].dtype | |||
def test_elemwise(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x, y): | |||
z1 = x * y | |||
z2 = x + y | |||
z3 = z1 / z2 | |||
z3 = z3 ** 3 | |||
return z3 | |||
x = Tensor([1.0, 2.0]) | |||
y = Tensor([3.0, 5.0]) | |||
result = fwd(x, y) | |||
check_pygraph_dump(fwd, [x, y], [result]) | |||
def test_reduce(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
x = data.sum(axis=2) | |||
x = x.mean(axis=1) | |||
return x | |||
data = Tensor(np.random.random((1, 32, 32))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_typecvt(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return data.astype(dtype.qint8(0.8)) | |||
x = Tensor(np.random.random((2, 3)) * 255) | |||
result = fwd(x) | |||
check_pygraph_dump(fwd, [x], [result]) | |||
def test_matinv(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return F.matinv(data) | |||
data = Tensor(np.random.random((5, 5))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_matmul(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data1, data2): | |||
return F.matmul(data1, data2) | |||
data1 = Tensor(np.random.random((32, 64))) | |||
data2 = Tensor(np.random.random((64, 16))) | |||
result = fwd(data1, data2) | |||
check_pygraph_dump(fwd, [data1, data2], [result]) | |||
def test_batchmatmul(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x, y): | |||
return F.matmul(x, y) | |||
x = Tensor(np.random.random((3, 3, 5))) | |||
y = Tensor(np.random.random((3, 5, 3))) | |||
result = fwd(x, y) | |||
check_pygraph_dump(fwd, [x, y], [result]) | |||
def test_dot(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x, y): | |||
return F.dot(x, y) | |||
x = Tensor([1.0, 2.0, 3.0]) | |||
y = Tensor([3.0, 4.0, 5.0]) | |||
result = fwd(x, y) | |||
check_pygraph_dump(fwd, [x, y], [result]) | |||
def test_svd(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
_, out, _ = F.svd(data) | |||
return out | |||
input = Tensor(np.random.random((1, 1, 3, 3))) | |||
result = fwd(input) | |||
check_pygraph_dump(fwd, [input], [result]) | |||
def test_conv(): | |||
conv = M.Conv2d(3, 32, 3) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return conv(data) | |||
data = Tensor(np.random.random((1, 3, 32, 32))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_deformable_conv(): | |||
if not is_cuda_available(): | |||
return | |||
conv = M.DeformableConv2d(3, 32, 3) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data, offset, mask): | |||
return conv(data, offset, mask) | |||
data = Tensor(np.random.random((1, 3, 32, 32))) | |||
offset = Tensor(np.ones((32, 3 * 3 * 2, 30, 30)).astype("int32") * 5) | |||
mask = Tensor(np.ones((32, 3 * 3, 30, 30)).astype("int32")) | |||
out = fwd(data, offset, mask) | |||
check_pygraph_dump(fwd, [data, offset, mask], [out]) | |||
def test_convtranspose(): | |||
deconv = M.ConvTranspose2d(32, 32, 3) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return deconv(data) | |||
data = Tensor(np.random.random((1, 32, 32, 32))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
@pytest.mark.skip(reason="pytest aborted") | |||
def test_grouplocal(): | |||
n = M.LocalConv2d(3, 32, 32, 32, 3) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return n(data) | |||
input = Tensor(np.random.random((1, 3, 32, 32))) | |||
result = fwd(input) | |||
check_pygraph_dump(fwd, [input], [result]) | |||
def test_pooling(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
out = F.max_pool2d(data, 2, 2) | |||
out = F.avg_pool2d(out, 2, 2) | |||
return out | |||
data = Tensor(np.random.random((1, 3, 64, 64))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_adaptivepooling(): | |||
pool1 = M.AdaptiveMaxPool2d((2, 2)) | |||
pool2 = M.AdaptiveAvgPool2d((2, 2)) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
out = pool1(data) | |||
out = pool2(out) | |||
return out | |||
input = Tensor(np.random.random((1, 3, 32, 32))) | |||
result = fwd(input) | |||
check_pygraph_dump(fwd, [input], [result]) | |||
def test_roipooling(): | |||
inp = Tensor(np.random.random((1, 1, 128, 128))) | |||
rois = Tensor(np.random.random((4, 5))) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp, rois): | |||
return F.nn.roi_pooling(inp, rois, (2, 2), scale=2.0) | |||
output = fwd(inp, rois) | |||
check_pygraph_dump(fwd, [inp, rois], [output]) | |||
def test_deformable_ps_roi_pooling(): | |||
inp = Tensor(np.random.random((1, 256, 64, 64)).astype("float32")) | |||
rois = Tensor(np.random.random((1, 5)).astype("float32")) | |||
trans = Tensor(np.random.random((24, 2, 7, 7)).astype("float32")) | |||
pooled_h = 7 | |||
pooled_w = 7 | |||
sample_per_part = 4 | |||
no_trans = False | |||
part_size = 7 | |||
spatial_scale = 1.0 / 64 | |||
trans_std = 0.1 | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp, rois, trans): | |||
y = F.deformable_psroi_pooling( | |||
inp, | |||
rois, | |||
trans, | |||
no_trans, | |||
part_size, | |||
pooled_h, | |||
pooled_w, | |||
sample_per_part, | |||
spatial_scale, | |||
trans_std, | |||
) | |||
return y | |||
result = fwd(inp, rois, trans) | |||
check_pygraph_dump(fwd, [inp, rois, trans], [result]) | |||
def test_convbias(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp, weight, bias): | |||
return F.quantized.conv_bias_activation( | |||
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||
) | |||
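# inputs below: qint8 data and weight plus a qint32 bias, all with scale 1.0, matching the qint8 output dtype requested above | |||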
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | |||
weight = Tensor(np.random.random((32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) | |||
bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) | |||
result = fwd(inp, weight, bias) | |||
check_pygraph_dump(fwd, [inp, weight, bias], [result]) | |||
def test_batch_convbias(): | |||
if is_cuda_available(): | |||
return | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp, weight, bias): | |||
return F.quantized.batch_conv_bias_activation( | |||
inp, weight, bias, dtype=dtype.qint8(scale=1.0), nonlinear_mode="RELU" | |||
) | |||
inp = Tensor(np.random.random((1, 3, 64, 64)), dtype=dtype.qint8(scale=1.0)) | |||
weight = Tensor(np.random.random((1, 32, 3, 3, 3)), dtype=dtype.qint8(scale=1.0)) | |||
bias = Tensor(np.random.random((1, 32, 1, 1)), dtype=dtype.qint32(scale=1.0)) | |||
result = fwd(inp, weight, bias) | |||
check_pygraph_dump(fwd, [inp, weight, bias], [result]) | |||
def test_batchnorm(): | |||
bn = M.BatchNorm2d(32) | |||
bn.eval() | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return bn(data) | |||
data = Tensor(np.random.random((1, 32, 32, 32))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_roialign(): | |||
inp = Tensor(np.random.randn(1, 1, 128, 128)) | |||
rois = Tensor(np.random.random((4, 5))) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp, rois): | |||
return F.nn.roi_align(inp, rois, (2, 2)) | |||
output = fwd(inp, rois) | |||
check_pygraph_dump(fwd, [inp, rois], [output]) | |||
def test_warpperspective(): | |||
inp_shape = (1, 1, 4, 4) | |||
x = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||
M_shape = (1, 3, 3) | |||
# M defines a translation: dst(1, 1, h, w) = src(1, 1, h+1, w+1) | |||
M = Tensor( | |||
np.array( | |||
[[1.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 0.0, 1.0]], dtype=np.float32 | |||
).reshape(M_shape) | |||
) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x, M): | |||
return F.warp_perspective(x, M, (2, 2)) | |||
result = fwd(x, M) | |||
check_pygraph_dump(fwd, [x, M], [result]) | |||
def test_warpaffine(): | |||
inp_shape = (1, 3, 3, 3) | |||
x = Tensor(np.arange(27, dtype=np.float32).reshape(inp_shape)) | |||
weightv = Tensor([[[1.26666667, 0.6, -83.33333333], [-0.33333333, 1, 66.66666667]]]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x, weightv): | |||
return F.warp_affine(x, weightv, (2, 2), border_mode="WRAP") | |||
outp = fwd(x, weightv) | |||
check_pygraph_dump(fwd, [x, weightv], [outp]) | |||
def test_remap(): | |||
inp_shape = (1, 1, 4, 4) | |||
inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||
map_xy_shape = (1, 2, 2, 2) | |||
map_xy = Tensor( | |||
np.array( | |||
[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], dtype=np.float32 | |||
).reshape(map_xy_shape) | |||
) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp, map_xy): | |||
return F.remap(inp, map_xy) | |||
out = fwd(inp, map_xy) | |||
check_pygraph_dump(fwd, [inp, map_xy], [out]) | |||
def test_resize(): | |||
x = Tensor(np.random.randn(10, 3, 32, 32)) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x): | |||
return F.nn.interpolate(x, size=(16, 16), mode="BILINEAR") | |||
out = fwd(x) | |||
check_pygraph_dump(fwd, [x], [out]) | |||
def test_index_onehot(): | |||
src = Tensor([[1.0, 2.0]]) | |||
index = Tensor([0]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(src, index): | |||
return F.indexing_one_hot(src, index) | |||
out = fwd(src, index) | |||
check_pygraph_dump(fwd, [src, index], [out]) | |||
def test_set_onehot(): | |||
x = Tensor(np.arange(1, 4, dtype=np.int32)) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x): | |||
return F.one_hot(x, num_classes=4) | |||
out = fwd(x) | |||
check_pygraph_dump(fwd, [x], [out]) | |||
def test_copy(): | |||
x = Tensor([1, 2, 3]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x): | |||
return x.to("cpu0:0") | |||
o = fwd(x) | |||
check_pygraph_dump(fwd, [x], [o]) | |||
def test_argsort(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return F.argsort(data, True) | |||
data = Tensor([1.0, 2.0, 3.0, 5.0]) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_argmax_min(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return F.argmax(data), F.argmin(data) | |||
data = Tensor(np.random.random((10, 10))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], result) | |||
def test_condtake(): | |||
mask = Tensor(np.array([[True, False], [False, True]], dtype=np.bool_)) | |||
x = Tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32)) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(mask, x): | |||
v, index = F.cond_take(mask, x) | |||
return v, index | |||
v, index = fwd(mask, x) | |||
check_pygraph_dump(fwd, [mask, x], [v, index]) | |||
def test_topk(): | |||
x = Tensor(np.array([2, 4, 6, 8, 7, 5, 3, 1], dtype=np.float32)) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x): | |||
top, indices = F.topk(x, 5) | |||
return top, indices | |||
top, indices = fwd(x) | |||
check_pygraph_dump(fwd, [x], [top, indices]) | |||
def test_random(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(): | |||
x = rand.uniform(size=(2, 2)) | |||
y = rand.normal(size=(1, 3, 3, 3)) | |||
return x, y | |||
x, y = fwd() | |||
check_pygraph_dump(fwd, [], [x, y]) | |||
def test_tensor_gen(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(): | |||
a = F.linspace(3, 10, 3, device=Device("xpux").to_c()) | |||
b = F.eye(3, device=Device("xpux").to_c()) | |||
return a, b | |||
a, b = fwd() | |||
check_pygraph_dump(fwd, [], [a, b]) | |||
def test_getvarshape(): | |||
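# GetVarShape with an explicit axis is assumed to return only that dimension of the input shape, rather than the full shape vector | |||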
op = builtin.GetVarShape(axis=1) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return apply(op, data)[0] | |||
data = Tensor(np.random.random((1, 2, 3, 4))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_concat(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data1, data2): | |||
return F.concat([data1, data2], axis=1) | |||
x = Tensor(np.random.random((2, 3))) | |||
y = Tensor(np.random.random((2, 5))) | |||
result = fwd(x, y) | |||
check_pygraph_dump(fwd, [x, y], [result]) | |||
def test_broadcast(): | |||
inp = Tensor([[1], [2], [3], [4]]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp): | |||
return F.broadcast_to(inp, (4, 4)) | |||
out = fwd(inp) | |||
check_pygraph_dump(fwd, [inp], [out]) | |||
def test_identity(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return F.copy(data) | |||
data = Tensor([1.0, 2.0]) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
@pytest.mark.skip(reason="advance indexing trace error") | |||
def test_nms(): | |||
x = np.zeros((100, 4)) | |||
np.random.seed(42) | |||
x[:, :2] = np.random.rand(100, 2) * 20 | |||
x[:, 2:] = np.random.rand(100, 2) * 20 + 100 | |||
scores = Tensor(np.random.rand(100)) | |||
inp = Tensor(x) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp, scores): | |||
return F.nn.nms(inp, scores, iou_thresh=0.7, max_output=3) | |||
result = fwd(inp, scores) | |||
check_pygraph_dump(fwd, [inp, scores], [result]) | |||
def test_dimshuffle(): | |||
inp = Tensor([1, 2, 3, 4]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp): | |||
return inp.T | |||
out = fwd(inp) | |||
check_pygraph_dump(fwd, [inp], [out]) | |||
def test_reshape(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
return data.reshape((1, 8)) | |||
data = Tensor(np.random.random((1, 2, 2, 2))) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
def test_add_remove_axis(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(data): | |||
x = F.expand_dims(data, [0, 0]) | |||
y = F.squeeze(x, 0) | |||
return y | |||
data = Tensor([1.0, 2.0]) | |||
result = fwd(data) | |||
check_pygraph_dump(fwd, [data], [result]) | |||
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
def test_subtensor(mode): | |||
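# each item is assumed to read as [axis, has_begin, has_end, has_step, has_idx]; the True flags mark which of the scalar tensors below supply those index components | |||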
items = [[0, True, True, True, False], [1, False, False, False, True]] | |||
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random(2))] | |||
if mode == "get": | |||
op = builtin.Subtensor(items) | |||
data = data[:1] | |||
if mode == "set": | |||
op = builtin.SetSubtensor(items) | |||
if mode == "inc": | |||
op = builtin.IncrSubtensor(items) | |||
tensors = [Tensor(0), Tensor(4), Tensor(2), Tensor(3)] | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(*tensors): | |||
return apply(op, *tensors)[0] | |||
result = fwd(*data, *tensors) | |||
check_pygraph_dump(fwd, data + tensors, [result]) | |||
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
def test_advance_indexing(mode): | |||
items = [[0, False, False, False, True]] | |||
tensors = [Tensor([0, 4, 2])] | |||
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 5)))] | |||
if mode == "get": | |||
op = builtin.IndexingMultiAxisVec(items) | |||
data = data[:1] | |||
if mode == "set": | |||
op = builtin.IndexingSetMultiAxisVec(items) | |||
if mode == "inc": | |||
op = builtin.IndexingIncrMultiAxisVec(items) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(*tensors): | |||
return apply(op, *tensors)[0] | |||
result = fwd(*data, *tensors) | |||
check_pygraph_dump(fwd, data + tensors, [result]) | |||
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
def test_mesh_indexing(mode): | |||
items = [[0, True, True, True, False], [1, False, False, False, True]] | |||
tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])] | |||
data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))] | |||
if mode == "get": | |||
op = builtin.IndexingMultiAxisVec(items) | |||
data = data[:1] | |||
if mode == "set": | |||
op = builtin.IndexingSetMultiAxisVec(items) | |||
if mode == "inc": | |||
op = builtin.IndexingIncrMultiAxisVec(items) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(*tensors): | |||
return apply(op, *tensors)[0] | |||
result = fwd(*data, *tensors) | |||
check_pygraph_dump(fwd, data + tensors, [result]) | |||
@pytest.mark.parametrize("mode", ["get", "set", "inc"]) | |||
def test_batch_mesh_indexing(mode): | |||
items = [[1, False, False, False, True], [2, False, False, False, True]] | |||
tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])] | |||
data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))] | |||
if mode == "get": | |||
op = builtin.BatchedMeshIndexing(items) | |||
data = data[:1] | |||
if mode == "set": | |||
op = builtin.BatchedSetMeshIndexing(items) | |||
if mode == "inc": | |||
op = builtin.BatchedIncrMeshIndexing(items) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(*tensors): | |||
return apply(op, *tensors)[0] | |||
result = fwd(*data, *tensors) | |||
check_pygraph_dump(fwd, data + tensors, [result]) | |||
@pytest.mark.skip(reason="tmp skip") | |||
def test_assert_equal(): | |||
g = G.Graph() | |||
inp1 = g.make_h2d(dtype=np.float32, device="xpux") | |||
inp2 = g.make_h2d(dtype=np.float32, device="xpux") | |||
op = builtin.AssertEqual(maxerr=1e-5) | |||
out = G.apply_normal_varnode(op, inp1._node, inp2._node)[0] | |||
print(out) | |||
g.compile(out) | |||
file = io.BytesIO() | |||
out_model = G.dump_graph([out]) | |||
file.write(out_model[0]) | |||
file.seek(0) | |||
net = Net.load(file) | |||
dump_file = io.BytesIO() | |||
net.dump(dump_file) | |||
dump_file.seek(0) | |||
g = GraphInference(dump_file) | |||
g.run(np.array([1.0, 2.0]), np.array([1.0, 2.0])) | |||
def test_elemwise_multitype(): | |||
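# QADD: quantized add of the two qint8 inputs, written out in the qint32(2.0) dtype given below (assumed reading of the mode name) | |||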
op = builtin.ElemwiseMultiType(mode="QADD", dtype=dtype.qint32(2.0)) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(x, y): | |||
return apply(op, x, y)[0] | |||
x = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) | |||
y = Tensor(np.random.random(10) * 10, dtype=dtype.qint8(2.0)) | |||
result = fwd(x, y) | |||
check_pygraph_dump(fwd, [x, y], [result]) | |||
def test_cvtcolor(): | |||
inp = np.random.randn(3, 3, 3, 3).astype(np.float32) | |||
x = Tensor(inp) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def fwd(inp): | |||
return F.img_proc.cvt_color(inp, mode="RGB2GRAY") | |||
result = fwd(x) | |||
check_pygraph_dump(fwd, [x], [result]) |
@@ -17,9 +17,20 @@ | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/dnn/lrn.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/roi_pooling.h" | |||
#include "megbrain/opr/dnn/roi_align.h" | |||
#include "megbrain/opr/imgproc.h" | |||
#include "megbrain/opr/standalone/nms_opr.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/rand.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "megbrain/opr/misc.h" | |||
#include "megbrain/opr/indexing.h" | |||
#include "megbrain/opr/internal/indexing_helper.h" | |||
#include "megbrain/opr/nn_int.h" | |||
#include "megbrain/opr/tensor_gen.h" | |||
#if MGB_ENABLE_JSON | |||
#include "megdnn/opr_param_json.h" | |||
#endif | |||
@@ -354,7 +365,7 @@ uint64_t opr_footprint_func<opr::DeformableConvForward>( | |||
auto&& out_shape = opr->output()[0]->shape(); | |||
auto&& filter_shape = opr->input()[1]->shape(); | |||
using Param = opr::DeformableConvForward::Param; | |||
auto&& param = opr->cast_final_safe<opr::Convolution>().param(); | |||
auto&& param = opr->cast_final_safe<opr::DeformableConvForward>().param(); | |||
size_t fh, fw, icpg; | |||
mgb_assert(param.format == Param::Format::NCHW); | |||
if (param.sparse == Param::Sparse::GROUP) { | |||
@@ -425,9 +436,11 @@ uint64_t opr_footprint_func<opr::BatchConvBiasForward>( | |||
auto&& filter_shape = opr->input()[1]->shape(); | |||
using Param = opr::BatchConvBiasForward::Param; | |||
auto&& param = opr->cast_final_safe<opr::BatchConvBiasForward>().param(); | |||
mgb_assert(param.format == Param::Format::NCHW4); | |||
size_t packed_channels = 4; | |||
size_t packed_channels = 1; | |||
size_t kern_spatial_pos = 3; | |||
if (param.format == Param::Format::NCHW4) { | |||
packed_channels = 4; | |||
} | |||
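// NCHW4 packs 4 input channels per storage element; other formats keep packed_channels = 1, which presumably feeds the MAC estimate below | |||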
size_t fh = filter_shape[kern_spatial_pos], | |||
fw = filter_shape[kern_spatial_pos + 1]; | |||
return out_shape.total_nr_elems() * fh * fw * src_shape[1] * | |||
@@ -508,7 +521,29 @@ REGISTE_PARAM_JSON_FUNC(LocalShareBackwardFilter) | |||
REGISTE_PARAM_JSON_FUNC(DeformableConvForward) | |||
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardFilter) | |||
REGISTE_PARAM_JSON_FUNC(DeformableConvBackwardData) | |||
REGISTE_PARAM_JSON_FUNC(DeformablePSROIPoolingForward) | |||
REGISTE_PARAM_JSON_FUNC(BatchConvBiasForward) | |||
REGISTE_PARAM_JSON_FUNC(BatchNormForward) | |||
REGISTE_PARAM_JSON_FUNC(ElemwiseMultiType) | |||
REGISTE_PARAM_JSON_FUNC(Argsort) | |||
REGISTE_PARAM_JSON_FUNC(Argmax) | |||
REGISTE_PARAM_JSON_FUNC(Argmin) | |||
REGISTE_PARAM_JSON_FUNC(AdaptivePooling) | |||
REGISTE_PARAM_JSON_FUNC(ROIPooling) | |||
REGISTE_PARAM_JSON_FUNC(ROIAlign) | |||
REGISTE_PARAM_JSON_FUNC(WarpPerspective) | |||
REGISTE_PARAM_JSON_FUNC(WarpAffine) | |||
REGISTE_PARAM_JSON_FUNC(Remap) | |||
REGISTE_PARAM_JSON_FUNC(Resize) | |||
REGISTE_PARAM_JSON_FUNC(IndexingOneHot) | |||
REGISTE_PARAM_JSON_FUNC(IndexingSetOneHot) | |||
REGISTE_PARAM_JSON_FUNC(TopK) | |||
REGISTE_PARAM_JSON_FUNC(UniformRNG) | |||
REGISTE_PARAM_JSON_FUNC(GaussianRNG) | |||
REGISTE_PARAM_JSON_FUNC(Linspace) | |||
REGISTE_PARAM_JSON_FUNC(Eye) | |||
REGISTE_PARAM_JSON_FUNC(CvtColor) | |||
template <> | |||
std::shared_ptr<json::Value> opr_param_json_func<opr::Dimshuffle>( | |||
@@ -547,24 +582,83 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::AxisAddRemove>( | |||
}); | |||
} | |||
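// shared serializer for subtensor/mesh-indexing oprs: each AxisIndexer becomes {axis, begin, end, step, idx}, where the last four are 0/1 flags recording whether that index var is present rather than its value | |||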
std::shared_ptr<json::Value> indexing_param_to_json( | |||
const std::vector<opr::indexing::AxisIndexer>& indices) { | |||
auto desc = json::Array::make(); | |||
for (auto& index : indices) { | |||
desc->add(json::Object::make({ | |||
{"axis", json::NumberInt::make(index.axis.get_raw())}, | |||
{"begin", | |||
json::NumberInt::make(index.begin.node() != nullptr)}, | |||
{"end", json::NumberInt::make(index.end.node() != nullptr)}, | |||
{"step", | |||
json::NumberInt::make(index.step.node() != nullptr)}, | |||
{"idx", json::NumberInt::make(index.idx.node() != nullptr)}, | |||
})); | |||
} | |||
return desc; | |||
} | |||
#define REGISTE_INDEXING_PARAM_JSON_FUNC(cls) \ | |||
template <> \ | |||
std::shared_ptr<json::Value> opr_param_json_func<opr::cls>( \ | |||
cg::OperatorNodeBase * opr) { \ | |||
auto indices = opr->cast_final_safe<opr::cls>().index_desc(); \ | |||
return indexing_param_to_json(indices); \ | |||
} | |||
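// register the shared serializer for every indexing-style opr | |||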
REGISTE_INDEXING_PARAM_JSON_FUNC(Subtensor); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(SetSubtensor); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(IncrSubtensor); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingMultiAxisVec); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingSetMultiAxisVec); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(IndexingIncrMultiAxisVec); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(MeshIndexing); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(IncrMeshIndexing); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(SetMeshIndexing); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedMeshIndexing); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedIncrMeshIndexing); | |||
REGISTE_INDEXING_PARAM_JSON_FUNC(BatchedSetMeshIndexing); | |||
template <> | |||
std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>( | |||
std::shared_ptr<json::Value> opr_param_json_func<opr::Reshape>( | |||
cg::OperatorNodeBase * opr) { | |||
auto desc = json::Array::make(); | |||
auto indices = opr->cast_final_safe<opr::Subtensor>().index_desc(); | |||
for (auto &index : indices){ | |||
desc->add( | |||
json::Object::make({ | |||
{"axis", json::NumberInt::make(index.axis.get_raw())}, | |||
{"begin", json::NumberInt::make(index.begin.node() != nullptr)}, | |||
{"end", json::NumberInt::make(index.end.node() != nullptr)}, | |||
{"step", json::NumberInt::make(index.step.node() != nullptr)}, | |||
{"idx", json::NumberInt::make(index.idx.node() != nullptr)}, | |||
})); | |||
auto axis_param = opr->cast_final_safe<opr::Reshape>().param(); | |||
if (axis_param.axis != axis_param.MAX_NDIM){ | |||
return json::Object::make({ | |||
{"axis", json::NumberInt::make(axis_param.axis)}, | |||
}); | |||
} else { | |||
return json::Object::make(); | |||
} | |||
} | |||
return desc; | |||
template <> | |||
std::shared_ptr<json::Value> opr_param_json_func<opr::GetVarShape>( | |||
cg::OperatorNodeBase * opr) { | |||
auto desc = json::Array::make(); | |||
auto axis_param = opr->cast_final_safe<opr::GetVarShape>().param(); | |||
if (axis_param.axis != axis_param.MAX_NDIM){ | |||
return json::Object::make({ | |||
{"axis", json::NumberInt::make(axis_param.axis)}, | |||
}); | |||
} else { | |||
return json::Object::make(); | |||
} | |||
} | |||
template <> | |||
std::shared_ptr<json::Value> opr_param_json_func<opr::standalone::NMSKeep>( | |||
cg::OperatorNodeBase * opr) { | |||
auto nms_param = opr->cast_final_safe<opr::standalone::NMSKeep>().param(); | |||
return json::Object::make({ | |||
{"iou_thresh", json::Number::make(nms_param.iou_thresh)}, | |||
{"max_output", json::Number::make(nms_param.max_output)}, | |||
}); | |||
} | |||
#endif // MGB_ENABLE_JSON | |||
} // namespace | |||
@@ -632,6 +726,17 @@ void OprFootprint::init_all_footprints() { | |||
add_single_param_json<opr::Dimshuffle>(); | |||
add_single_param_json<opr::AxisAddRemove>(); | |||
add_single_param_json<opr::Subtensor>(); | |||
add_single_param_json<opr::SetSubtensor>(); | |||
add_single_param_json<opr::IncrSubtensor>(); | |||
add_single_param_json<opr::IndexingMultiAxisVec>(); | |||
add_single_param_json<opr::IndexingSetMultiAxisVec>(); | |||
add_single_param_json<opr::IndexingIncrMultiAxisVec>(); | |||
add_single_param_json<opr::MeshIndexing>(); | |||
add_single_param_json<opr::SetMeshIndexing>(); | |||
add_single_param_json<opr::IncrMeshIndexing>(); | |||
add_single_param_json<opr::BatchedMeshIndexing>(); | |||
add_single_param_json<opr::BatchedSetMeshIndexing>(); | |||
add_single_param_json<opr::BatchedIncrMeshIndexing>(); | |||
add_single_param_json<opr::Reduce>(); | |||
add_single_param_json<opr::LocalShareForward>(); | |||
add_single_param_json<opr::LocalShareBackwardData>(); | |||
@@ -639,7 +744,31 @@ void OprFootprint::init_all_footprints() { | |||
add_single_param_json<opr::DeformableConvForward>(); | |||
add_single_param_json<opr::DeformableConvBackwardFilter>(); | |||
add_single_param_json<opr::DeformableConvBackwardData>(); | |||
add_single_param_json<opr::DeformablePSROIPoolingForward>(); | |||
add_single_param_json<opr::BatchConvBiasForward>(); | |||
add_single_param_json<opr::BatchNormForward>(); | |||
add_single_param_json<opr::Reshape>(); | |||
add_single_param_json<opr::GetVarShape>(); | |||
add_single_param_json<opr::Argsort>(); | |||
add_single_param_json<opr::Argmin>(); | |||
add_single_param_json<opr::Argmax>(); | |||
add_single_param_json<opr::ElemwiseMultiType>(); | |||
add_single_param_json<opr::AdaptivePooling>(); | |||
add_single_param_json<opr::ROIPooling>(); | |||
add_single_param_json<opr::ROIAlign>(); | |||
add_single_param_json<opr::WarpPerspective>(); | |||
add_single_param_json<opr::Remap>(); | |||
add_single_param_json<opr::Resize>(); | |||
add_single_param_json<opr::IndexingOneHot>(); | |||
add_single_param_json<opr::IndexingSetOneHot>(); | |||
add_single_param_json<opr::WarpAffine>(); | |||
add_single_param_json<opr::TopK>(); | |||
add_single_param_json<opr::UniformRNG>(); | |||
add_single_param_json<opr::GaussianRNG>(); | |||
add_single_param_json<opr::Linspace>(); | |||
add_single_param_json<opr::Eye>(); | |||
add_single_param_json<opr::standalone::NMSKeep>(); | |||
add_single_param_json<opr::CvtColor>(); | |||
#endif | |||
} | |||